Skip to content

Commit e43ca0c

Browse files
committed
internal/mcp: validate tool input schemas
A tool now validates its input with its InputSchema. Change schema inference to allow explicit nulls for fields of pointer type. We assume the schema has no external references. For example, the following schema cannot be handled: { "$ref": "https://example.com/other.json" } Schemas with internal references, like to a "$defs", are fine. Change-Id: I6ee7c18c2c5cb609df0b22a66da986f7ea64bbe4 Reviewed-on: https://go-review.googlesource.com/c/tools/+/670676 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]>
1 parent 61f37dc commit e43ca0c

File tree

9 files changed

+160
-17
lines changed

9 files changed

+160
-17
lines changed

internal/mcp/jsonschema/infer.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,21 @@ func For[T any]() (*Schema, error) {
4242
// - complex numbers
4343
// - unsafe pointers
4444
//
45+
// The cannot be any cycles in the types.
4546
// TODO(rfindley): we could perhaps just skip these incompatible fields.
4647
func ForType(t reflect.Type) (*Schema, error) {
4748
return typeSchema(t)
4849
}
4950

5051
func typeSchema(t reflect.Type) (*Schema, error) {
51-
if t.Kind() == reflect.Pointer {
52+
// Follow pointers: the schema for *T is almost the same as for T, except that
53+
// an explicit JSON "null" is allowed for the pointer.
54+
allowNull := false
55+
for t.Kind() == reflect.Pointer {
56+
allowNull = true
5257
t = t.Elem()
5358
}
59+
5460
var (
5561
s = new(Schema)
5662
err error
@@ -121,6 +127,10 @@ func typeSchema(t reflect.Type) (*Schema, error) {
121127
default:
122128
return nil, fmt.Errorf("type %v is unsupported by jsonschema", t)
123129
}
130+
if allowNull && s.Type != "" {
131+
s.Types = []string{"null", s.Type}
132+
s.Type = ""
133+
}
124134
return s, nil
125135
}
126136

internal/mcp/jsonschema/infer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestForType(t *testing.T) {
5757
Properties: map[string]*schema{
5858
"f": {Type: "integer"},
5959
"G": {Type: "array", Items: &schema{Type: "number"}},
60-
"P": {Type: "boolean"},
60+
"P": {Types: []string{"null", "boolean"}},
6161
"NoSkip": {Type: "string"},
6262
},
6363
Required: []string{"f", "G", "P"},

internal/mcp/jsonschema/resolve.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ type Resolved struct {
2727
resolvedURIs map[string]*Schema
2828
}
2929

30+
// Schema returns the schema that was resolved.
31+
// It must not be modified.
32+
func (r *Resolved) Schema() *Schema { return r.root }
33+
3034
// A Loader reads and unmarshals the schema at uri, if any.
3135
type Loader func(uri *url.URL) (*Schema, error)
3236

internal/mcp/jsonschema/util.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func jsonType(v reflect.Value) (string, bool) {
270270
return "string", true
271271
case reflect.Slice, reflect.Array:
272272
return "array", true
273-
case reflect.Map:
273+
case reflect.Map, reflect.Struct:
274274
return "object", true
275275
default:
276276
return "", false

internal/mcp/jsonschema/validate.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,8 @@ func property(v reflect.Value, name string) reflect.Value {
662662
case reflect.Struct:
663663
props := structPropertiesOf(v.Type())
664664
// Ignore nonexistent properties.
665-
if index, ok := props[name]; ok {
666-
return v.FieldByIndex(index)
665+
if sf, ok := props[name]; ok {
666+
return v.FieldByIndex(sf.Index)
667667
}
668668
return reflect.Value{}
669669
default:
@@ -673,6 +673,8 @@ func property(v reflect.Value, name string) reflect.Value {
673673

674674
// properties returns an iterator over the names and values of all properties
675675
// in v, which must be a map or a struct.
676+
// If a struct, zero-valued properties that are marked omitempty or omitzero
677+
// are excluded.
676678
func properties(v reflect.Value) iter.Seq2[string, reflect.Value] {
677679
return func(yield func(string, reflect.Value) bool) {
678680
switch v.Kind() {
@@ -683,8 +685,14 @@ func properties(v reflect.Value) iter.Seq2[string, reflect.Value] {
683685
}
684686
}
685687
case reflect.Struct:
686-
for name, index := range structPropertiesOf(v.Type()) {
687-
if !yield(name, v.FieldByIndex(index)) {
688+
for name, sf := range structPropertiesOf(v.Type()) {
689+
val := v.FieldByIndex(sf.Index)
690+
if val.IsZero() {
691+
if tag, ok := sf.Tag.Lookup("json"); ok && (strings.Contains(tag, "omitempty") || strings.Contains(tag, "omitzero")) {
692+
continue
693+
}
694+
}
695+
if !yield(name, val) {
688696
return
689697
}
690698
}
@@ -707,8 +715,8 @@ func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int)
707715
case reflect.Struct:
708716
sp := structPropertiesOf(v.Type())
709717
min := 0
710-
for prop, index := range sp {
711-
if !v.FieldByIndex(index).IsZero() || isRequired[prop] {
718+
for prop, sf := range sp {
719+
if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] {
712720
min++
713721
}
714722
}
@@ -719,7 +727,7 @@ func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int)
719727
}
720728

721729
// A propertyMap is a map from property name to struct field index.
722-
type propertyMap = map[string][]int
730+
type propertyMap = map[string]reflect.StructField
723731

724732
var structProperties sync.Map // from reflect.Type to propertyMap
725733

@@ -730,10 +738,10 @@ func structPropertiesOf(t reflect.Type) propertyMap {
730738
if props, ok := structProperties.Load(t); ok {
731739
return props.(propertyMap)
732740
}
733-
props := map[string][]int{}
741+
props := map[string]reflect.StructField{}
734742
for _, sf := range reflect.VisibleFields(t) {
735743
if name, ok := jsonName(sf); ok {
736-
props[name] = sf.Index
744+
props[name] = sf
737745
}
738746
}
739747
structProperties.Store(t, props)

internal/mcp/mcp_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,4 +709,5 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] {
709709
}
710710
}
711711

712+
// A function, because schemas must form a tree (they have hidden state).
712713
func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} }

internal/mcp/prompt.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context,
4242
if schema.Type != "object" || !reflect.DeepEqual(schema.AdditionalProperties, &jsonschema.Schema{Not: &jsonschema.Schema{}}) {
4343
panic(fmt.Sprintf("handler request type must be a struct"))
4444
}
45+
resolved, err := schema.Resolve(nil)
46+
if err != nil {
47+
panic(err)
48+
}
4549
prompt := &ServerPrompt{
4650
Prompt: &Prompt{
4751
Name: name,
@@ -70,7 +74,7 @@ func NewPrompt[TReq any](name, description string, handler func(context.Context,
7074
return nil, err
7175
}
7276
var v TReq
73-
if err := unmarshalSchema(rawArgs, schema, &v); err != nil {
77+
if err := unmarshalSchema(rawArgs, resolved, &v); err != nil {
7478
return nil, err
7579
}
7680
return handler(ctx, ss, v, params)

internal/mcp/tool.go

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
package mcp
66

77
import (
8+
"bytes"
89
"context"
910
"encoding/json"
11+
"fmt"
1012
"slices"
1113

1214
"golang.org/x/tools/internal/mcp/jsonschema"
@@ -39,10 +41,17 @@ func NewTool[TReq any](name, description string, handler ToolHandler[TReq], opts
3941
if err != nil {
4042
panic(err)
4143
}
44+
// We must resolve the schema after the ToolOptions have had a chance to update it.
45+
// But the handler needs access to the resolved schema, and the options may change
46+
// the handler too.
47+
// The best we can do is use the resolved schema in our own wrapped handler,
48+
// and hope that no ToolOption replaces it.
49+
// TODO(jba): at a minimum, document this.
50+
var resolved *jsonschema.Resolved
4251
wrapped := func(ctx context.Context, cc *ServerSession, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) {
4352
var params2 CallToolParams[TReq]
4453
if params.Arguments != nil {
45-
if err := unmarshalSchema(params.Arguments, schema, &params2.Arguments); err != nil {
54+
if err := unmarshalSchema(params.Arguments, resolved, &params2.Arguments); err != nil {
4655
return nil, err
4756
}
4857
}
@@ -68,15 +77,38 @@ func NewTool[TReq any](name, description string, handler ToolHandler[TReq], opts
6877
for _, opt := range opts {
6978
opt.set(t)
7079
}
80+
if schema := t.Tool.InputSchema; schema != nil {
81+
// Resolve the schema, with no base URI. We don't expect tool schemas to
82+
// refer outside of themselves.
83+
resolved, err = schema.Resolve(nil)
84+
if err != nil {
85+
panic(fmt.Errorf("resolving input schema %s: %w", schemaJSON(schema), err))
86+
}
87+
}
7188
return t
7289
}
7390

7491
// unmarshalSchema unmarshals data into v and validates the result according to
75-
// the given schema.
76-
func unmarshalSchema(data json.RawMessage, _ *jsonschema.Schema, v any) error {
92+
// the given resolved schema.
93+
func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) error {
7794
// TODO: use reflection to create the struct type to unmarshal into.
7895
// Separate validation from assignment.
79-
return json.Unmarshal(data, v)
96+
97+
// Disallow unknown fields.
98+
// Otherwise, if the tool was built with a struct, the client could send extra
99+
// fields and json.Unmarshal would ignore them, so the schema would never get
100+
// a chance to declare the extra args invalid.
101+
dec := json.NewDecoder(bytes.NewReader(data))
102+
dec.DisallowUnknownFields()
103+
if err := dec.Decode(v); err != nil {
104+
return fmt.Errorf("unmarshaling: %w", err)
105+
}
106+
if resolved != nil {
107+
if err := resolved.Validate(v); err != nil {
108+
return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err)
109+
}
110+
}
111+
return nil
80112
}
81113

82114
// A ToolOption configures the behavior of a Tool.
@@ -177,3 +209,12 @@ func Schema(schema *jsonschema.Schema) SchemaOption {
177209
*s = *schema
178210
})
179211
}
212+
213+
// schemaJSON returns the JSON value for s as a string, or a string indicating an error.
214+
func schemaJSON(s *jsonschema.Schema) string {
215+
m, err := json.Marshal(s)
216+
if err != nil {
217+
return fmt.Sprintf("<!%s>", err)
218+
}
219+
return string(m)
220+
}

internal/mcp/tool_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ package mcp_test
66

77
import (
88
"context"
9+
"encoding/json"
10+
"strings"
911
"testing"
1012

1113
"github.com/google/go-cmp/cmp"
@@ -88,3 +90,76 @@ func TestNewTool(t *testing.T) {
8890
}
8991
}
9092
}
93+
94+
func TestNewToolValidate(t *testing.T) {
95+
// Check that the tool returned from NewTool properly validates its input schema.
96+
97+
type req struct {
98+
I int
99+
B bool
100+
S string `json:",omitempty"`
101+
P *int `json:",omitempty"`
102+
}
103+
104+
dummyHandler := func(context.Context, *mcp.ServerSession, *mcp.CallToolParams[req]) (*mcp.CallToolResult, error) {
105+
return nil, nil
106+
}
107+
108+
tool := mcp.NewTool("test", "test", dummyHandler)
109+
for _, tt := range []struct {
110+
desc string
111+
args map[string]any
112+
want string // error should contain this string; empty for success
113+
}{
114+
{
115+
"both required",
116+
map[string]any{"I": 1, "B": true},
117+
"",
118+
},
119+
{
120+
"optional",
121+
map[string]any{"I": 1, "B": true, "S": "foo"},
122+
"",
123+
},
124+
{
125+
"wrong type",
126+
map[string]any{"I": 1.5, "B": true},
127+
"cannot unmarshal",
128+
},
129+
{
130+
"extra property",
131+
map[string]any{"I": 1, "B": true, "C": 2},
132+
"unknown field",
133+
},
134+
{
135+
"value for pointer",
136+
map[string]any{"I": 1, "B": true, "P": 3},
137+
"",
138+
},
139+
{
140+
"null for pointer",
141+
map[string]any{"I": 1, "B": true, "P": nil},
142+
"",
143+
},
144+
} {
145+
t.Run(tt.desc, func(t *testing.T) {
146+
raw, err := json.Marshal(tt.args)
147+
if err != nil {
148+
t.Fatal(err)
149+
}
150+
_, err = tool.Handler(context.Background(), nil,
151+
&mcp.CallToolParams[json.RawMessage]{Arguments: json.RawMessage(raw)})
152+
if err == nil && tt.want != "" {
153+
t.Error("got success, wanted failure")
154+
}
155+
if err != nil {
156+
if tt.want == "" {
157+
t.Fatalf("failed with:\n%s\nwanted success", err)
158+
}
159+
if !strings.Contains(err.Error(), tt.want) {
160+
t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want)
161+
}
162+
}
163+
})
164+
}
165+
}

0 commit comments

Comments
 (0)