@@ -35,10 +35,12 @@ import (
3535 "github.com/cloudwego/kitex/pkg/event"
3636 "github.com/cloudwego/kitex/pkg/kerrors"
3737 "github.com/cloudwego/kitex/pkg/proxy"
38+ "github.com/cloudwego/kitex/pkg/remote"
3839 "github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
3940 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
4041 "github.com/cloudwego/kitex/pkg/rpcinfo"
4142 "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
43+ "github.com/cloudwego/kitex/transport"
4244)
4345
4446var (
@@ -84,7 +86,7 @@ func TestResolverMW(t *testing.T) {
8486
8587 var invoked bool
8688 cli := newMockClient(t, ctrl).(*kcFinalizerClient)
87- mw := newResolveMWBuilder(cli.lbf)(ctx)
89+ mw := newResolveMWBuilder(cli.lbf, nil )(ctx)
8890 ep := func(ctx context.Context, request, response interface{}) error {
8991 invoked = true
9092 return nil
@@ -114,14 +116,14 @@ func TestResolverMWOutOfInstance(t *testing.T) {
114116 }
115117 var invoked bool
116118 cli := newMockClient(t, ctrl, WithResolver(resolver)).(*kcFinalizerClient)
117- mw := newResolveMWBuilder(cli.lbf)(ctx)
119+ mw := newResolveMWBuilder(cli.lbf, nil )(ctx)
118120 ep := func(ctx context.Context, request, response interface{}) error {
119121 invoked = true
120122 return nil
121123 }
122124
123125 to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
124- ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), nil , rpcinfo.NewRPCStats())
126+ ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig() , rpcinfo.NewRPCStats())
125127
126128 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
127129 req := new(MockTStruct)
@@ -222,7 +224,7 @@ func BenchmarkResolverMW(b *testing.B) {
222224 defer ctrl.Finish()
223225
224226 cli := newMockClient(b, ctrl).(*kcFinalizerClient)
225- mw := newResolveMWBuilder(cli.lbf)(ctx)
227+ mw := newResolveMWBuilder(cli.lbf, nil )(ctx)
226228 ep := func(ctx context.Context, request, response interface{}) error { return nil }
227229 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
228230
@@ -241,7 +243,7 @@ func BenchmarkResolverMWParallel(b *testing.B) {
241243 defer ctrl.Finish()
242244
243245 cli := newMockClient(b, ctrl).(*kcFinalizerClient)
244- mw := newResolveMWBuilder(cli.lbf)(ctx)
246+ mw := newResolveMWBuilder(cli.lbf, nil )(ctx)
245247 ep := func(ctx context.Context, request, response interface{}) error { return nil }
246248 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
247249
@@ -279,3 +281,170 @@ func TestDiscoveryEventHandler(t *testing.T) {
279281 added := extra["Added"].([]*instInfo)
280282 test.Assert(t, len(added) == 1)
281283}
284+
285+ // mockConnStatistics implements remote.ConnStatistics for testing
286+ type mockConnStatistics struct {
287+ activeStreams map[string]int
288+ }
289+
290+ func (m *mockConnStatistics) ActiveStreams(addr string) int {
291+ if m.activeStreams == nil {
292+ return 0
293+ }
294+ return m.activeStreams[addr]
295+ }
296+
297+ // TestResolverMW_WithConnStatistics_StreamingMode tests that ConnStatistics is passed to context
298+ // when in gRPC streaming mode
299+ func TestResolverMW_WithConnStatistics_StreamingMode(t *testing.T) {
300+ ctrl := gomock.NewController(t)
301+ defer ctrl.Finish()
302+
303+ mockStats := &mockConnStatistics{
304+ activeStreams: map[string]int{
305+ "localhost:404": 5,
306+ },
307+ }
308+
309+ var contextPassedToEndpoint context.Context
310+ cli := newMockClient(t, ctrl).(*kcFinalizerClient)
311+ mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
312+ ep := func(ctx context.Context, request, response interface{}) error {
313+ contextPassedToEndpoint = ctx
314+ return nil
315+ }
316+
317+ to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
318+
319+ // Create RPC config with streaming mode and gRPC protocol
320+ cfg := rpcinfo.NewRPCConfig()
321+ rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
322+ rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
323+
324+ ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
325+
326+ ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
327+ req := new(MockTStruct)
328+ res := new(MockTStruct)
329+ err := mw(ep)(ctx, req, res)
330+ test.Assert(t, err == nil)
331+ test.Assert(t, to.GetInstance() == instance404[0])
332+
333+ // Verify ConnStatistics was passed to context
334+ cs := remote.GetConnStatistics(contextPassedToEndpoint)
335+ test.Assert(t, cs != nil, "ConnStatistics should be in context for streaming mode")
336+ test.Assert(t, cs.ActiveStreams("localhost:404") == 5)
337+ }
338+
339+ // TestResolverMW_WithConnStatistics_NonStreamingMode tests that ConnStatistics is NOT passed
340+ // to context when not in streaming mode
341+ func TestResolverMW_WithConnStatistics_NonStreamingMode(t *testing.T) {
342+ ctrl := gomock.NewController(t)
343+ defer ctrl.Finish()
344+
345+ mockStats := &mockConnStatistics{
346+ activeStreams: map[string]int{
347+ "localhost:404": 5,
348+ },
349+ }
350+
351+ var contextPassedToEndpoint context.Context
352+ cli := newMockClient(t, ctrl).(*kcFinalizerClient)
353+ mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
354+ ep := func(ctx context.Context, request, response interface{}) error {
355+ contextPassedToEndpoint = ctx
356+ return nil
357+ }
358+
359+ to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
360+
361+ // Create RPC config with PingPong mode (not streaming)
362+ cfg := rpcinfo.NewRPCConfig()
363+ rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.PingPong)
364+ rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
365+
366+ ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
367+
368+ ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
369+ req := new(MockTStruct)
370+ res := new(MockTStruct)
371+ err := mw(ep)(ctx, req, res)
372+ test.Assert(t, err == nil)
373+
374+ // Verify ConnStatistics was NOT passed to context for non-streaming mode
375+ cs := remote.GetConnStatistics(contextPassedToEndpoint)
376+ test.Assert(t, cs == nil, "ConnStatistics should not be in context for non-streaming mode")
377+ }
378+
379+ // TestResolverMW_WithConnStatistics_NonGRPC tests that ConnStatistics is NOT passed
380+ // for non-gRPC protocols
381+ func TestResolverMW_WithConnStatistics_NonGRPC(t *testing.T) {
382+ ctrl := gomock.NewController(t)
383+ defer ctrl.Finish()
384+
385+ mockStats := &mockConnStatistics{
386+ activeStreams: map[string]int{
387+ "localhost:404": 5,
388+ },
389+ }
390+
391+ var contextPassedToEndpoint context.Context
392+ cli := newMockClient(t, ctrl).(*kcFinalizerClient)
393+ mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
394+ ep := func(ctx context.Context, request, response interface{}) error {
395+ contextPassedToEndpoint = ctx
396+ return nil
397+ }
398+
399+ to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
400+
401+ // Create RPC config with streaming mode but non-gRPC protocol
402+ cfg := rpcinfo.NewRPCConfig()
403+ rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
404+ rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) // Not GRPC
405+
406+ ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
407+
408+ ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
409+ req := new(MockTStruct)
410+ res := new(MockTStruct)
411+ err := mw(ep)(ctx, req, res)
412+ test.Assert(t, err == nil)
413+
414+ // Verify ConnStatistics was NOT passed for non-gRPC protocol
415+ cs := remote.GetConnStatistics(contextPassedToEndpoint)
416+ test.Assert(t, cs == nil, "ConnStatistics should not be in context for non-gRPC protocol")
417+ }
418+
419+ // TestResolverMW_WithoutConnStatistics tests behavior when ConnStatistics is nil
420+ func TestResolverMW_WithoutConnStatistics(t *testing.T) {
421+ ctrl := gomock.NewController(t)
422+ defer ctrl.Finish()
423+
424+ var contextPassedToEndpoint context.Context
425+ cli := newMockClient(t, ctrl).(*kcFinalizerClient)
426+ mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
427+ ep := func(ctx context.Context, request, response interface{}) error {
428+ contextPassedToEndpoint = ctx
429+ return nil
430+ }
431+
432+ to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
433+
434+ // Create RPC config with streaming mode and gRPC protocol
435+ cfg := rpcinfo.NewRPCConfig()
436+ rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
437+ rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
438+
439+ ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
440+
441+ ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
442+ req := new(MockTStruct)
443+ res := new(MockTStruct)
444+ err := mw(ep)(ctx, req, res)
445+ test.Assert(t, err == nil)
446+
447+ // Verify ConnStatistics is nil when not provided
448+ cs := remote.GetConnStatistics(contextPassedToEndpoint)
449+ test.Assert(t, cs == nil, "ConnStatistics should be nil when not provided")
450+ }
0 commit comments