Skip to content

Commit 82ee0fd

Browse files
samthanawallagopherbot
authored andcommitted
internal/mcp: change paginateList to a generic helper
This CL simplifies the paginateList function in server.go to use a generic helper for tools, resources, and prompts. Change-Id: Ide0d2a90d715374280067e094d8870882e6bddfa Reviewed-on: https://go-review.googlesource.com/c/tools/+/678055 Auto-Submit: Sam Thanawalla <[email protected]> Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 64bfecc commit 82ee0fd

File tree

4 files changed

+96
-81
lines changed

4 files changed

+96
-81
lines changed

internal/mcp/client.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -391,23 +391,13 @@ func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams)
391391
})
392392
}
393393

394-
type ListParams interface {
395-
// Returns a pointer to the param's Cursor field.
396-
cursorPtr() *string
397-
}
398-
399-
type ListResult[T any] interface {
400-
// Returns a pointer to the param's NextCursor field.
401-
nextCursorPtr() *string
402-
}
403-
404394
// paginate is a generic helper function to provide a paginated iterator.
405-
func paginate[P ListParams, R ListResult[E], E any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*E) iter.Seq2[E, error] {
406-
return func(yield func(E, error) bool) {
395+
func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[T, error] {
396+
return func(yield func(T, error) bool) {
407397
for {
408398
res, err := listFunc(ctx, params)
409399
if err != nil {
410-
var zero E
400+
var zero T
411401
yield(zero, err)
412402
return
413403
}

internal/mcp/server.go

Lines changed: 41 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -182,21 +182,15 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] {
182182
func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) {
183183
s.mu.Lock()
184184
defer s.mu.Unlock()
185-
var cursor string
186-
if params != nil {
187-
cursor = params.Cursor
185+
if params == nil {
186+
params = &ListPromptsParams{}
188187
}
189-
prompts, nextCursor, err := paginateList(s.prompts, cursor, s.opts.PageSize)
190-
if err != nil {
191-
return nil, err
192-
}
193-
res := new(ListPromptsResult)
194-
res.NextCursor = nextCursor
195-
res.Prompts = []*Prompt{} // avoid JSON null
196-
for _, p := range prompts {
197-
res.Prompts = append(res.Prompts, p.Prompt)
198-
}
199-
return res, nil
188+
return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*ServerPrompt) {
189+
res.Prompts = []*Prompt{} // avoid JSON null
190+
for _, p := range prompts {
191+
res.Prompts = append(res.Prompts, p.Prompt)
192+
}
193+
})
200194
}
201195

202196
func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) {
@@ -213,21 +207,15 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr
213207
func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) {
214208
s.mu.Lock()
215209
defer s.mu.Unlock()
216-
var cursor string
217-
if params != nil {
218-
cursor = params.Cursor
219-
}
220-
tools, nextCursor, err := paginateList(s.tools, cursor, s.opts.PageSize)
221-
if err != nil {
222-
return nil, err
223-
}
224-
res := new(ListToolsResult)
225-
res.NextCursor = nextCursor
226-
res.Tools = []*Tool{} // avoid JSON null
227-
for _, t := range tools {
228-
res.Tools = append(res.Tools, t.Tool)
210+
if params == nil {
211+
params = &ListToolsParams{}
229212
}
230-
return res, nil
213+
return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) {
214+
res.Tools = []*Tool{} // avoid JSON null
215+
for _, t := range tools {
216+
res.Tools = append(res.Tools, t.Tool)
217+
}
218+
})
231219
}
232220

233221
func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) {
@@ -243,21 +231,15 @@ func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallTo
243231
func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) {
244232
s.mu.Lock()
245233
defer s.mu.Unlock()
246-
var cursor string
247-
if params != nil {
248-
cursor = params.Cursor
249-
}
250-
resources, nextCursor, err := paginateList(s.resources, cursor, s.opts.PageSize)
251-
if err != nil {
252-
return nil, err
234+
if params == nil {
235+
params = &ListResourcesParams{}
253236
}
254-
res := new(ListResourcesResult)
255-
res.NextCursor = nextCursor
256-
res.Resources = []*Resource{} // avoid JSON null
257-
for _, r := range resources {
258-
res.Resources = append(res.Resources, r.Resource)
259-
}
260-
return res, nil
237+
return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*ServerResource) {
238+
res.Resources = []*Resource{} // avoid JSON null
239+
for _, r := range resources {
240+
res.Resources = append(res.Resources, r.Resource)
241+
}
242+
})
261243
}
262244

263245
func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) {
@@ -618,22 +600,25 @@ func decodeCursor(cursor string) (*pageToken, error) {
618600
return &token, nil
619601
}
620602

621-
// paginateList returns a slice of features from the given featureSet, based on
622-
// the provided cursor and page size. It also returns a new cursor for the next
623-
// page, or an empty string if there are no more pages.
624-
func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (features []T, nextCursor string, err error) {
603+
// paginateList is a generic helper that returns a paginated slice of items
604+
// from a featureSet. It populates the provided result res with the items
605+
// and sets its next cursor for subsequent pages.
606+
// If there are no more pages, the next cursor within the result will be an empty string.
607+
func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageSize int, params P, res R, setFunc func(R, []T)) (R, error) {
625608
var seq iter.Seq[T]
626-
if cursor == "" {
609+
if params.cursorPtr() == nil || *params.cursorPtr() == "" {
627610
seq = fs.all()
628611
} else {
629-
pageToken, err := decodeCursor(cursor)
612+
pageToken, err := decodeCursor(*params.cursorPtr())
630613
// According to the spec, invalid cursors should return Invalid params.
631614
if err != nil {
632-
return nil, "", jsonrpc2.ErrInvalidParams
615+
var zero R
616+
return zero, jsonrpc2.ErrInvalidParams
633617
}
634618
seq = fs.above(pageToken.LastUID)
635619
}
636620
var count int
621+
var features []T
637622
for f := range seq {
638623
count++
639624
// If we've seen pageSize + 1 elements, we've gathered enough info to determine
@@ -643,13 +628,16 @@ func paginateList[T any](fs *featureSet[T], cursor string, pageSize int) (featur
643628
}
644629
features = append(features, f)
645630
}
631+
setFunc(res, features)
646632
// No remaining pages.
647633
if count < pageSize+1 {
648-
return features, "", nil
634+
return res, nil
649635
}
650-
nextCursor, err = encodeCursor(fs.uniqueID(features[len(features)-1]))
636+
nextCursor, err := encodeCursor(fs.uniqueID(features[len(features)-1]))
651637
if err != nil {
652-
return nil, "", err
638+
var zero R
639+
return zero, err
653640
}
654-
return features, nextCursor, nil
641+
*res.nextCursorPtr() = nextCursor
642+
return res, nil
655643
}

internal/mcp/server_test.go

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,29 @@ import (
1212
"github.com/google/go-cmp/cmp"
1313
)
1414

15-
type TestItem struct {
15+
type testItem struct {
1616
Name string
1717
Value string
1818
}
1919

20-
var allTestItems = []*TestItem{
20+
type testListParams struct {
21+
Cursor string
22+
}
23+
24+
func (p *testListParams) cursorPtr() *string {
25+
return &p.Cursor
26+
}
27+
28+
type testListResult struct {
29+
Items []*testItem
30+
NextCursor string
31+
}
32+
33+
func (r *testListResult) nextCursorPtr() *string {
34+
return &r.NextCursor
35+
}
36+
37+
var allTestItems = []*testItem{
2138
{"alpha", "val-A"},
2239
{"bravo", "val-B"},
2340
{"charlie", "val-C"},
@@ -44,10 +61,10 @@ func getCursor(input string) string {
4461
func TestServerPaginateBasic(t *testing.T) {
4562
testCases := []struct {
4663
name string
47-
initialItems []*TestItem
64+
initialItems []*testItem
4865
inputCursor string
4966
inputPageSize int
50-
wantFeatures []*TestItem
67+
wantFeatures []*testItem
5168
wantNextCursor string
5269
wantErr bool
5370
}{
@@ -154,41 +171,51 @@ func TestServerPaginateBasic(t *testing.T) {
154171

155172
for _, tc := range testCases {
156173
t.Run(tc.name, func(t *testing.T) {
157-
fs := newFeatureSet(func(t *TestItem) string { return t.Name })
174+
fs := newFeatureSet(func(t *testItem) string { return t.Name })
158175
fs.add(tc.initialItems...)
159-
gotFeatures, gotNextCursor, err := paginateList(fs, tc.inputCursor, tc.inputPageSize)
176+
params := &testListParams{Cursor: tc.inputCursor}
177+
gotResult, err := paginateList(fs, tc.inputPageSize, params, &testListResult{}, func(res *testListResult, items []*testItem) {
178+
res.Items = items
179+
})
160180
if (err != nil) != tc.wantErr {
161181
t.Errorf("paginateList(%s) error, got %v, wantErr %v", tc.name, err, tc.wantErr)
162182
}
163-
if diff := cmp.Diff(tc.wantFeatures, gotFeatures); diff != "" {
183+
if tc.wantErr {
184+
return
185+
}
186+
if diff := cmp.Diff(tc.wantFeatures, gotResult.Items); diff != "" {
164187
t.Errorf("paginateList(%s) mismatch (-want +got):\n%s", tc.name, diff)
165188
}
166-
if tc.wantNextCursor != gotNextCursor {
167-
t.Errorf("paginateList(%s) nextCursor, got %v, want %v", tc.name, gotNextCursor, tc.wantNextCursor)
189+
if tc.wantNextCursor != gotResult.NextCursor {
190+
t.Errorf("paginateList(%s) nextCursor, got %v, want %v", tc.name, gotResult.NextCursor, tc.wantNextCursor)
168191
}
169192
})
170193
}
171194
}
172195

173196
func TestServerPaginateVariousPageSizes(t *testing.T) {
174-
fs := newFeatureSet(func(t *TestItem) string { return t.Name })
197+
fs := newFeatureSet(func(t *testItem) string { return t.Name })
175198
fs.add(allTestItems...)
176199
// Try all possible page sizes, ensuring we get the correct list of items.
177200
for pageSize := 1; pageSize < len(allTestItems)+1; pageSize++ {
178-
var gotItems []*TestItem
201+
var gotItems []*testItem
179202
var nextCursor string
180203
wantChunks := slices.Collect(slices.Chunk(allTestItems, pageSize))
181204
index := 0
182205
// Iterate through all pages, comparing sub-slices to the paginated list.
183206
for {
184-
gotFeatures, gotNextCursor, err := paginateList(fs, nextCursor, pageSize)
207+
params := &testListParams{Cursor: nextCursor}
208+
gotResult, err := paginateList(fs, pageSize, params, &testListResult{}, func(res *testListResult, items []*testItem) {
209+
res.Items = items
210+
})
185211
if err != nil {
212+
t.Fatalf("paginateList() unexpected error for pageSize %d, cursor %q: %v", pageSize, nextCursor, err)
186213
}
187-
if diff := cmp.Diff(wantChunks[index], gotFeatures); diff != "" {
214+
if diff := cmp.Diff(wantChunks[index], gotResult.Items); diff != "" {
188215
t.Errorf("paginateList mismatch (-want +got):\n%s", diff)
189216
}
190-
gotItems = append(gotItems, gotFeatures...)
191-
nextCursor = gotNextCursor
217+
gotItems = append(gotItems, gotResult.Items...)
218+
nextCursor = gotResult.NextCursor
192219
if nextCursor == "" {
193220
break
194221
}

internal/mcp/shared.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,13 @@ type Result interface {
271271
type emptyResult struct{}
272272

273273
func (*emptyResult) GetMeta() *Meta { panic("should never be called") }
274+
275+
type listParams interface {
276+
// Returns a pointer to the param's Cursor field.
277+
cursorPtr() *string
278+
}
279+
280+
type listResult[T any] interface {
281+
// Returns a pointer to the param's NextCursor field.
282+
nextCursorPtr() *string
283+
}

0 commit comments

Comments
 (0)