@@ -2,15 +2,22 @@ package agg
22
33import (
44 "context"
5+ "fmt"
6+ "strings"
57 "testing"
68
79 "github.com/stretchr/testify/assert"
810 "github.com/stretchr/testify/require"
911
1012 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
1113 "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
14+ "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
1215)
1316
17+ func init () {
18+ paramtable .Init ()
19+ }
20+
1421func makeTestSchema () * schemapb.CollectionSchema {
1522 return & schemapb.CollectionSchema {
1623 Fields : []* schemapb.FieldSchema {
@@ -177,3 +184,116 @@ func TestReduceSingleResult(t *testing.T) {
177184 require .NoError (t , err )
178185 assert .Equal (t , singleResult , out )
179186}
187+
188+ // buildTestSchema creates a simple schema with an INT64 groupBy field and an INT64 agg field.
189+ func buildTestSchema () * schemapb.CollectionSchema {
190+ return & schemapb.CollectionSchema {
191+ Fields : []* schemapb.FieldSchema {
192+ {FieldID : 100 , Name : "group_field" , DataType : schemapb .DataType_Int64 },
193+ {FieldID : 101 , Name : "agg_field" , DataType : schemapb .DataType_Int64 },
194+ },
195+ }
196+ }
197+
198+ // buildAggResult creates an AggregationResult with N distinct groups.
199+ // Each group has group key = startKey+i and count = 1.
200+ func buildAggResult (startKey int64 , numGroups int ) * AggregationResult {
201+ groupKeys := make ([]int64 , numGroups )
202+ counts := make ([]int64 , numGroups )
203+ for i := 0 ; i < numGroups ; i ++ {
204+ groupKeys [i ] = startKey + int64 (i )
205+ counts [i ] = 1
206+ }
207+ return NewAggregationResult ([]* schemapb.FieldData {
208+ {
209+ Type : schemapb .DataType_Int64 ,
210+ FieldName : "group_field" ,
211+ Field : & schemapb.FieldData_Scalars {
212+ Scalars : & schemapb.ScalarField {
213+ Data : & schemapb.ScalarField_LongData {
214+ LongData : & schemapb.LongArray {Data : groupKeys },
215+ },
216+ },
217+ },
218+ },
219+ {
220+ Type : schemapb .DataType_Int64 ,
221+ FieldName : "agg_field" ,
222+ Field : & schemapb.FieldData_Scalars {
223+ Scalars : & schemapb.ScalarField {
224+ Data : & schemapb.ScalarField_LongData {
225+ LongData : & schemapb.LongArray {Data : counts },
226+ },
227+ },
228+ },
229+ },
230+ }, int64 (numGroups ))
231+ }
232+
233+ func TestGroupAggReducer_MaxGroupByGroupsExceeded (t * testing.T ) {
234+ maxGroups := int64 (10 )
235+ paramtable .Get ().Save (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key , fmt .Sprintf ("%d" , maxGroups ))
236+ defer paramtable .Get ().Reset (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key )
237+
238+ schema := buildTestSchema ()
239+ aggregates := []* planpb.Aggregate {
240+ {Op : planpb .AggregateOp_count , FieldId : 101 },
241+ }
242+ reducer := NewGroupAggReducer ([]int64 {100 }, aggregates , - 1 , schema )
243+
244+ // Two results each with 10 distinct groups (20 total > 10 limit)
245+ results := []* AggregationResult {
246+ buildAggResult (0 , 10 ),
247+ buildAggResult (10 , 10 ),
248+ }
249+
250+ _ , err := reducer .Reduce (context .Background (), results )
251+ require .Error (t , err )
252+ assert .True (t , strings .Contains (err .Error (), "too many groups" ))
253+ }
254+
255+ func TestGroupAggReducer_MaxGroupByGroupsExactlyAtLimit (t * testing.T ) {
256+ maxGroups := int64 (10 )
257+ paramtable .Get ().Save (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key , fmt .Sprintf ("%d" , maxGroups ))
258+ defer paramtable .Get ().Reset (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key )
259+
260+ schema := buildTestSchema ()
261+ aggregates := []* planpb.Aggregate {
262+ {Op : planpb .AggregateOp_count , FieldId : 101 },
263+ }
264+ reducer := NewGroupAggReducer ([]int64 {100 }, aggregates , - 1 , schema )
265+
266+ // Exactly 10 groups = limit, should succeed
267+ // Use 2 results to force cross-segment merge path (single result fast-returns)
268+ results := []* AggregationResult {
269+ buildAggResult (0 , 5 ),
270+ buildAggResult (5 , 5 ),
271+ }
272+
273+ result , err := reducer .Reduce (context .Background (), results )
274+ require .NoError (t , err )
275+ assert .NotNil (t , result )
276+ }
277+
278+ func TestGroupAggReducer_MaxGroupByGroupsJustOverLimit (t * testing.T ) {
279+ maxGroups := int64 (10 )
280+ paramtable .Get ().Save (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key , fmt .Sprintf ("%d" , maxGroups ))
281+ defer paramtable .Get ().Reset (paramtable .Get ().CommonCfg .GroupByMaxGroups .Key )
282+
283+ schema := buildTestSchema ()
284+ aggregates := []* planpb.Aggregate {
285+ {Op : planpb .AggregateOp_count , FieldId : 101 },
286+ }
287+ reducer := NewGroupAggReducer ([]int64 {100 }, aggregates , - 1 , schema )
288+
289+ // 6 + 5 = 11 distinct groups > 10 limit, should fail
290+ // Need 2 results to trigger cross-segment merge path (single result fast-returns)
291+ results := []* AggregationResult {
292+ buildAggResult (0 , 6 ),
293+ buildAggResult (6 , 5 ),
294+ }
295+
296+ _ , err := reducer .Reduce (context .Background (), results )
297+ require .Error (t , err )
298+ assert .True (t , strings .Contains (err .Error (), "too many groups" ))
299+ }
0 commit comments