Skip to content

Commit e1f1b47

Browse files
authored
optimize listByPagination (#246)
1 parent 46bfb6f commit e1f1b47

File tree

5 files changed

+94
-4
lines changed

5 files changed

+94
-4
lines changed

mcp/prompts.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ type Prompt struct {
5050
Arguments []PromptArgument `json:"arguments,omitempty"`
5151
}
5252

53+
// GetName returns the name of the prompt.
54+
func (p Prompt) GetName() string {
55+
return p.Name
56+
}
57+
5358
// PromptArgument describes an argument that a prompt template can accept.
5459
// When a prompt includes arguments, clients must provide values for all
5560
// required arguments when making a prompts/get request.

mcp/tools.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ type Tool struct {
7979
Annotations ToolAnnotation `json:"annotations"`
8080
}
8181

82+
// GetName returns the name of the tool.
83+
func (t Tool) GetName() string {
84+
return t.Name
85+
}
86+
8287
// MarshalJSON implements the json.Marshaler interface for Tool.
8388
// It handles marshaling either InputSchema or RawInputSchema based on which is set.
8489
func (t Tool) MarshalJSON() ([]byte, error) {

mcp/types.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,11 @@ type Resource struct {
523523
MIMEType string `json:"mimeType,omitempty"`
524524
}
525525

526+
// GetName returns the name of the resource.
527+
func (r Resource) GetName() string {
528+
return r.Name
529+
}
530+
526531
// ResourceTemplate represents a template description for resources available
527532
// on the server.
528533
type ResourceTemplate struct {
@@ -544,6 +549,11 @@ type ResourceTemplate struct {
544549
MIMEType string `json:"mimeType,omitempty"`
545550
}
546551

552+
// GetName returns the name of the resourceTemplate.
553+
func (rt ResourceTemplate) GetName() string {
554+
return rt.Name
555+
}
556+
547557
// ResourceContents represents the contents of a specific resource or sub-
548558
// resource.
549559
type ResourceContents interface {
@@ -893,3 +903,7 @@ type ServerNotification any
893903

894904
// ServerResult represents any result that can be sent from server to client.
895905
type ServerResult any
906+
907+
type Named interface {
908+
GetName() string
909+
}

server/server.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/base64"
77
"encoding/json"
88
"fmt"
9-
"reflect"
109
"sort"
1110
"sync"
1211

@@ -541,7 +540,7 @@ func (s *MCPServer) handlePing(
541540
return &mcp.EmptyResult{}, nil
542541
}
543542

544-
func listByPagination[T any](
543+
func listByPagination[T mcp.Named](
545544
ctx context.Context,
546545
s *MCPServer,
547546
cursor mcp.Cursor,
@@ -555,7 +554,7 @@ func listByPagination[T any](
555554
}
556555
cString := string(c)
557556
startPos = sort.Search(len(allElements), func(i int) bool {
558-
return reflect.ValueOf(allElements[i]).FieldByName("Name").String() > cString
557+
return allElements[i].GetName() > cString
559558
})
560559
}
561560
endPos := len(allElements)
@@ -568,7 +567,7 @@ func listByPagination[T any](
568567
// set the next cursor
569568
nextCursor := func() mcp.Cursor {
570569
if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit {
571-
nc := reflect.ValueOf(elementsToReturn[len(elementsToReturn)-1]).FieldByName("Name").String()
570+
nc := elementsToReturn[len(elementsToReturn)-1].GetName()
572571
toString := base64.StdEncoding.EncodeToString([]byte(nc))
573572
return mcp.Cursor(toString)
574573
}

server/server_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"reflect"
10+
"sort"
911
"testing"
1012
"time"
1113

@@ -1557,3 +1559,68 @@ func TestMCPServer_WithRecover(t *testing.T) {
15571559
assert.Equal(t, "panic recovered in panic-tool tool handler: test panic", errorResponse.Error.Message)
15581560
assert.Nil(t, errorResponse.Error.Data)
15591561
}
1562+
1563+
func getTools(length int) []mcp.Tool {
1564+
list := make([]mcp.Tool, 0, 10000)
1565+
for i := 0; i < length; i++ {
1566+
list = append(list, mcp.Tool{
1567+
Name: fmt.Sprintf("tool%d", i),
1568+
Description: fmt.Sprintf("tool%d", i),
1569+
})
1570+
}
1571+
return list
1572+
}
1573+
1574+
func listByPaginationForReflect[T any](
1575+
ctx context.Context,
1576+
s *MCPServer,
1577+
cursor mcp.Cursor,
1578+
allElements []T,
1579+
) ([]T, mcp.Cursor, error) {
1580+
startPos := 0
1581+
if cursor != "" {
1582+
c, err := base64.StdEncoding.DecodeString(string(cursor))
1583+
if err != nil {
1584+
return nil, "", err
1585+
}
1586+
cString := string(c)
1587+
startPos = sort.Search(len(allElements), func(i int) bool {
1588+
return reflect.ValueOf(allElements[i]).FieldByName("Name").String() > cString
1589+
})
1590+
}
1591+
endPos := len(allElements)
1592+
if s.paginationLimit != nil {
1593+
if len(allElements) > startPos+*s.paginationLimit {
1594+
endPos = startPos + *s.paginationLimit
1595+
}
1596+
}
1597+
elementsToReturn := allElements[startPos:endPos]
1598+
// set the next cursor
1599+
nextCursor := func() mcp.Cursor {
1600+
if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit {
1601+
nc := reflect.ValueOf(elementsToReturn[len(elementsToReturn)-1]).FieldByName("Name").String()
1602+
toString := base64.StdEncoding.EncodeToString([]byte(nc))
1603+
return mcp.Cursor(toString)
1604+
}
1605+
return ""
1606+
}()
1607+
return elementsToReturn, nextCursor, nil
1608+
}
1609+
1610+
func BenchmarkMCPServer_Pagination(b *testing.B) {
1611+
list := getTools(10000)
1612+
ctx := context.Background()
1613+
server := createTestServer()
1614+
for i := 0; i < b.N; i++ {
1615+
_, _, _ = listByPagination[mcp.Tool](ctx, server, "dG9vbDY1NA==", list)
1616+
}
1617+
}
1618+
1619+
func BenchmarkMCPServer_PaginationForReflect(b *testing.B) {
1620+
list := getTools(10000)
1621+
ctx := context.Background()
1622+
server := createTestServer()
1623+
for i := 0; i < b.N; i++ {
1624+
_, _, _ = listByPaginationForReflect[mcp.Tool](ctx, server, "dG9vbDY1NA==", list)
1625+
}
1626+
}

0 commit comments

Comments
 (0)