Skip to content

Commit 99a2276

Browse files
committed
feat: streaming least-conn loadbalance
1 parent c5acf5e commit 99a2276

15 files changed

+3093
-12
lines changed

client/client.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,14 @@ func (kc *kClient) initMiddlewares(ctx context.Context) {
283283
kc.mws = append(kc.mws, builderMWs...)
284284
kc.mws = append(kc.mws, acl.NewACLMiddleware(kc.opt.ACLRules))
285285
if kc.opt.Proxy == nil {
286-
kc.mws = append(kc.mws, newResolveMWBuilder(kc.lbf)(ctx))
286+
cs, _ := kc.opt.RemoteOpt.ConnPool.(remote.ConnStatistics)
287+
kc.mws = append(kc.mws, newResolveMWBuilder(kc.lbf, cs)(ctx))
287288
kc.mws = append(kc.mws, kc.opt.CBSuite.InstanceCBMW())
288289
kc.mws = append(kc.mws, richMWsWithBuilder(ctx, kc.opt.IMWBs)...)
289290
} else {
291+
cs, _ := kc.opt.RemoteOpt.ConnPool.(remote.ConnStatistics)
290292
if kc.opt.Resolver != nil { // customized service discovery
291-
kc.mws = append(kc.mws, newResolveMWBuilder(kc.lbf)(ctx))
293+
kc.mws = append(kc.mws, newResolveMWBuilder(kc.lbf, cs)(ctx))
292294
}
293295
kc.mws = append(kc.mws, newProxyMW(kc.opt.Proxy))
294296
}

client/client_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,9 @@ func Test_kClient_initStreamMiddlewares(t *testing.T) {
11381138
opt: &client.Options{
11391139
TracerCtl: ctl,
11401140
Streaming: internal_stream.StreamingConfig{},
1141+
RemoteOpt: &remote.ClientOption{
1142+
ConnPool: nil,
1143+
},
11411144
},
11421145
}
11431146
c.initStreamMiddlewares(context.Background())

client/middlewares.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
3737
"github.com/cloudwego/kitex/pkg/rpcinfo"
3838
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
39+
"github.com/cloudwego/kitex/transport"
3940
)
4041

4142
const maxRetry = 6
@@ -83,7 +84,7 @@ func discoveryEventHandler(name string, bus event.Bus, queue event.Queue) func(d
8384
// newResolveMWBuilder creates a middleware for service discovery.
8485
// This middleware selects an appropriate instance based on the resolver and loadbalancer given.
8586
// If retryable error is encountered, it will retry until timeout or an unretryable error is returned.
86-
func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilder {
87+
func newResolveMWBuilder(lbf *lbcache.BalancerFactory, cs remote.ConnStatistics) endpoint.MiddlewareBuilder {
8788
return func(ctx context.Context) endpoint.Middleware {
8889
return func(next endpoint.Endpoint) endpoint.Endpoint {
8990
return func(ctx context.Context, request, response interface{}) error {
@@ -94,12 +95,12 @@ func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilde
9495
return kerrors.ErrNoDestService
9596
}
9697

97-
remote := remoteinfo.AsRemoteInfo(dest)
98-
if remote == nil {
98+
ri := remoteinfo.AsRemoteInfo(dest)
99+
if ri == nil {
99100
err := fmt.Errorf("unsupported target EndpointInfo type: %T", dest)
100101
return kerrors.ErrInternalException.WithCause(err)
101102
}
102-
if remote.GetInstance() != nil {
103+
if ri.GetInstance() != nil {
103104
return next(ctx, request, response)
104105
}
105106
lb, err := lbf.Get(ctx, dest)
@@ -118,11 +119,16 @@ func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilde
118119
// we always need to get a new picker every time, because when downstream update deployment,
119120
// we may get an old picker that include all outdated instances which will cause connect always failed.
120121
picker := lb.GetPicker()
122+
cfg := rpcInfo.Config()
123+
// gRPC streaming
124+
if cs != nil && cfg.InteractionMode() == rpcinfo.Streaming && (cfg.TransportProtocol()&transport.GRPC != 0) {
125+
ctx = remote.NewCtxWithConnStatistics(ctx, cs)
126+
}
121127
ins := picker.Next(ctx, request)
122128
if ins == nil {
123129
err = kerrors.ErrNoMoreInstance.WithCause(fmt.Errorf("last error: %w", lastErr))
124130
} else {
125-
remote.SetInstance(ins)
131+
ri.SetInstance(ins)
126132
// TODO: generalize retry strategy
127133
err = next(ctx, request, response)
128134
}

client/middlewares_test.go

Lines changed: 174 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ import (
3535
"github.com/cloudwego/kitex/pkg/event"
3636
"github.com/cloudwego/kitex/pkg/kerrors"
3737
"github.com/cloudwego/kitex/pkg/proxy"
38+
"github.com/cloudwego/kitex/pkg/remote"
3839
"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
3940
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
4041
"github.com/cloudwego/kitex/pkg/rpcinfo"
4142
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
43+
"github.com/cloudwego/kitex/transport"
4244
)
4345

4446
var (
@@ -84,7 +86,7 @@ func TestResolverMW(t *testing.T) {
8486

8587
var invoked bool
8688
cli := newMockClient(t, ctrl).(*kcFinalizerClient)
87-
mw := newResolveMWBuilder(cli.lbf)(ctx)
89+
mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
8890
ep := func(ctx context.Context, request, response interface{}) error {
8991
invoked = true
9092
return nil
@@ -114,14 +116,14 @@ func TestResolverMWOutOfInstance(t *testing.T) {
114116
}
115117
var invoked bool
116118
cli := newMockClient(t, ctrl, WithResolver(resolver)).(*kcFinalizerClient)
117-
mw := newResolveMWBuilder(cli.lbf)(ctx)
119+
mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
118120
ep := func(ctx context.Context, request, response interface{}) error {
119121
invoked = true
120122
return nil
121123
}
122124

123125
to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
124-
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
126+
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
125127

126128
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
127129
req := new(MockTStruct)
@@ -222,7 +224,7 @@ func BenchmarkResolverMW(b *testing.B) {
222224
defer ctrl.Finish()
223225

224226
cli := newMockClient(b, ctrl).(*kcFinalizerClient)
225-
mw := newResolveMWBuilder(cli.lbf)(ctx)
227+
mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
226228
ep := func(ctx context.Context, request, response interface{}) error { return nil }
227229
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
228230

@@ -241,7 +243,7 @@ func BenchmarkResolverMWParallel(b *testing.B) {
241243
defer ctrl.Finish()
242244

243245
cli := newMockClient(b, ctrl).(*kcFinalizerClient)
244-
mw := newResolveMWBuilder(cli.lbf)(ctx)
246+
mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
245247
ep := func(ctx context.Context, request, response interface{}) error { return nil }
246248
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
247249

@@ -279,3 +281,170 @@ func TestDiscoveryEventHandler(t *testing.T) {
279281
added := extra["Added"].([]*instInfo)
280282
test.Assert(t, len(added) == 1)
281283
}
284+
285+
// mockConnStatistics implements remote.ConnStatistics for testing
286+
type mockConnStatistics struct {
287+
activeStreams map[string]int
288+
}
289+
290+
func (m *mockConnStatistics) ActiveStreams(addr string) int {
291+
if m.activeStreams == nil {
292+
return 0
293+
}
294+
return m.activeStreams[addr]
295+
}
296+
297+
// TestResolverMW_WithConnStatistics_StreamingMode tests that ConnStatistics is passed to context
298+
// when in gRPC streaming mode
299+
func TestResolverMW_WithConnStatistics_StreamingMode(t *testing.T) {
300+
ctrl := gomock.NewController(t)
301+
defer ctrl.Finish()
302+
303+
mockStats := &mockConnStatistics{
304+
activeStreams: map[string]int{
305+
"localhost:404": 5,
306+
},
307+
}
308+
309+
var contextPassedToEndpoint context.Context
310+
cli := newMockClient(t, ctrl).(*kcFinalizerClient)
311+
mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
312+
ep := func(ctx context.Context, request, response interface{}) error {
313+
contextPassedToEndpoint = ctx
314+
return nil
315+
}
316+
317+
to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
318+
319+
// Create RPC config with streaming mode and gRPC protocol
320+
cfg := rpcinfo.NewRPCConfig()
321+
rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
322+
rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
323+
324+
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
325+
326+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
327+
req := new(MockTStruct)
328+
res := new(MockTStruct)
329+
err := mw(ep)(ctx, req, res)
330+
test.Assert(t, err == nil)
331+
test.Assert(t, to.GetInstance() == instance404[0])
332+
333+
// Verify ConnStatistics was passed to context
334+
cs := remote.GetConnStatistics(contextPassedToEndpoint)
335+
test.Assert(t, cs != nil, "ConnStatistics should be in context for streaming mode")
336+
test.Assert(t, cs.ActiveStreams("localhost:404") == 5)
337+
}
338+
339+
// TestResolverMW_WithConnStatistics_NonStreamingMode tests that ConnStatistics is NOT passed
340+
// to context when not in streaming mode
341+
func TestResolverMW_WithConnStatistics_NonStreamingMode(t *testing.T) {
342+
ctrl := gomock.NewController(t)
343+
defer ctrl.Finish()
344+
345+
mockStats := &mockConnStatistics{
346+
activeStreams: map[string]int{
347+
"localhost:404": 5,
348+
},
349+
}
350+
351+
var contextPassedToEndpoint context.Context
352+
cli := newMockClient(t, ctrl).(*kcFinalizerClient)
353+
mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
354+
ep := func(ctx context.Context, request, response interface{}) error {
355+
contextPassedToEndpoint = ctx
356+
return nil
357+
}
358+
359+
to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
360+
361+
// Create RPC config with PingPong mode (not streaming)
362+
cfg := rpcinfo.NewRPCConfig()
363+
rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.PingPong)
364+
rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
365+
366+
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
367+
368+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
369+
req := new(MockTStruct)
370+
res := new(MockTStruct)
371+
err := mw(ep)(ctx, req, res)
372+
test.Assert(t, err == nil)
373+
374+
// Verify ConnStatistics was NOT passed to context for non-streaming mode
375+
cs := remote.GetConnStatistics(contextPassedToEndpoint)
376+
test.Assert(t, cs == nil, "ConnStatistics should not be in context for non-streaming mode")
377+
}
378+
379+
// TestResolverMW_WithConnStatistics_NonGRPC tests that ConnStatistics is NOT passed
380+
// for non-gRPC protocols
381+
func TestResolverMW_WithConnStatistics_NonGRPC(t *testing.T) {
382+
ctrl := gomock.NewController(t)
383+
defer ctrl.Finish()
384+
385+
mockStats := &mockConnStatistics{
386+
activeStreams: map[string]int{
387+
"localhost:404": 5,
388+
},
389+
}
390+
391+
var contextPassedToEndpoint context.Context
392+
cli := newMockClient(t, ctrl).(*kcFinalizerClient)
393+
mw := newResolveMWBuilder(cli.lbf, mockStats)(ctx)
394+
ep := func(ctx context.Context, request, response interface{}) error {
395+
contextPassedToEndpoint = ctx
396+
return nil
397+
}
398+
399+
to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
400+
401+
// Create RPC config with streaming mode but non-gRPC protocol
402+
cfg := rpcinfo.NewRPCConfig()
403+
rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
404+
rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) // Not GRPC
405+
406+
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
407+
408+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
409+
req := new(MockTStruct)
410+
res := new(MockTStruct)
411+
err := mw(ep)(ctx, req, res)
412+
test.Assert(t, err == nil)
413+
414+
// Verify ConnStatistics was NOT passed for non-gRPC protocol
415+
cs := remote.GetConnStatistics(contextPassedToEndpoint)
416+
test.Assert(t, cs == nil, "ConnStatistics should not be in context for non-gRPC protocol")
417+
}
418+
419+
// TestResolverMW_WithoutConnStatistics tests behavior when ConnStatistics is nil
420+
func TestResolverMW_WithoutConnStatistics(t *testing.T) {
421+
ctrl := gomock.NewController(t)
422+
defer ctrl.Finish()
423+
424+
var contextPassedToEndpoint context.Context
425+
cli := newMockClient(t, ctrl).(*kcFinalizerClient)
426+
mw := newResolveMWBuilder(cli.lbf, nil)(ctx)
427+
ep := func(ctx context.Context, request, response interface{}) error {
428+
contextPassedToEndpoint = ctx
429+
return nil
430+
}
431+
432+
to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "")
433+
434+
// Create RPC config with streaming mode and gRPC protocol
435+
cfg := rpcinfo.NewRPCConfig()
436+
rpcinfo.AsMutableRPCConfig(cfg).SetInteractionMode(rpcinfo.Streaming)
437+
rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC)
438+
439+
ri := rpcinfo.NewRPCInfo(nil, to, rpcinfo.NewInvocation("", ""), cfg, rpcinfo.NewRPCStats())
440+
441+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
442+
req := new(MockTStruct)
443+
res := new(MockTStruct)
444+
err := mw(ep)(ctx, req, res)
445+
test.Assert(t, err == nil)
446+
447+
// Verify ConnStatistics is nil when not provided
448+
cs := remote.GetConnStatistics(contextPassedToEndpoint)
449+
test.Assert(t, cs == nil, "ConnStatistics should be nil when not provided")
450+
}

pkg/loadbalance/weighted_balancer.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ const (
3030
lbKindInterleaved
3131
lbKindRandom
3232
lbKindRandomWithAliasMethod
33+
lbKindWeightedLeastConn
3334
)
3435

3536
type weightedBalancer struct {
@@ -67,6 +68,11 @@ func NewWeightedRandomWithAliasMethodBalancer() Loadbalancer {
6768
return lb
6869
}
6970

71+
func NewWeightedLeastConnBalancer() Loadbalancer {
72+
lb := &weightedBalancer{kind: lbKindWeightedLeastConn}
73+
return lb
74+
}
75+
7076
// GetPicker implements the Loadbalancer interface.
7177
func (wb *weightedBalancer) GetPicker(e discovery.Result) Picker {
7278
if !e.Cacheable {
@@ -127,6 +133,8 @@ func (wb *weightedBalancer) createPicker(e discovery.Result) (picker Picker) {
127133
} else {
128134
picker = newAliasMethodPicker(instances, weightSum)
129135
}
136+
case lbKindWeightedLeastConn:
137+
picker = newWeightedLeastConnPicker(instances, balance)
130138
default: // random
131139
if balance {
132140
picker = newRandomPicker(instances)
@@ -161,6 +169,8 @@ func (wb *weightedBalancer) Name() string {
161169
return "interleaved_weighted_round_robin"
162170
case lbKindRandomWithAliasMethod:
163171
return "weight_random_with_alias_method"
172+
case lbKindWeightedLeastConn:
173+
return "weight_least_conn"
164174
default:
165175
return "weight_random"
166176
}

0 commit comments

Comments
 (0)