Skip to content

Commit cc9c60e

Browse files
MrPresent-HanMrPresent-Hanclaude
authored
fix: HashTable dynamic rehash and group count limit for GROUP BY aggregation (#48174)
## Summary - HashTable now dynamically rehashes (doubles capacity) when the load factor exceeds 7/8, fixing a crash when GROUP BY cardinality > ~1792 - Added configurable `common.groupBy.maxGroups` (default 100K) to cap total groups and prevent OOM on both the C++ (per-segment HashTable) and Go (cross-segment agg reducer) layers - Added 4 C++ unit tests covering rehash basic/correctness, max groups limit, and multiple rehash rounds issue: #47569 ## Test plan - [ ] C++ unit tests: `--gtest_filter="*HashTableRehash*:*MaxGroups*"` - [ ] E2E: GROUP BY aggregation with >2K unique values should succeed - [ ] E2E: Set `common.groupBy.maxGroups` to a small value and verify that the error message is clear 🤖 Generated with [Claude Code](https://claude.com/claude-code) Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 786bb24 commit cc9c60e

File tree

12 files changed

+405
-12
lines changed

12 files changed

+405
-12
lines changed

configs/milvus.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,8 @@ common:
10551055
mode:
10561056
queryNode: sync # File resource mode for query node, options: [sync, close]. Default is sync.
10571057
dataNode: sync # File resource mode for data node, options: [sync, ref, close]. Default is sync.
1058+
groupBy:
1059+
maxGroups: 100000 # Maximum number of groups allowed in GROUP BY aggregation, enforced both per segment and during cross-segment merge. Exceeding this limit fails the query.
10581060

10591061
# QuotaConfig, configurations of Milvus quota and limits.
10601062
# By default, we enable:

internal/agg/aggregate_reducer.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
1010
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
1111
"github.com/milvus-io/milvus/pkg/v2/proto/segcorepb"
12+
"github.com/milvus-io/milvus/pkg/v2/util/merr"
13+
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
1214
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
1315
)
1416

@@ -345,11 +347,12 @@ func (reducer *GroupAggReducer) Reduce(ctx context.Context, results []*Aggregati
345347
}
346348

347349
// 2. compute hash values for all rows in the result retrieved
348-
var totalRowCount int64 = 0
350+
var totalGroupCount int64 = 0
351+
maxGroupByGroups := paramtable.Get().CommonCfg.GroupByMaxGroups.GetAsInt64()
349352
processResults:
350353
for _, result := range results {
351354
// Check limit before processing each shard to avoid unnecessary work
352-
if reducer.groupLimit != -1 && totalRowCount >= reducer.groupLimit {
355+
if reducer.groupLimit != -1 && totalGroupCount >= reducer.groupLimit {
353356
break processResults
354357
}
355358

@@ -387,7 +390,7 @@ processResults:
387390

388391
for row := 0; row < rowCount; row++ {
389392
// Check limit before processing each row to avoid unnecessary hashing and copying
390-
if reducer.groupLimit != -1 && totalRowCount >= reducer.groupLimit {
393+
if reducer.groupLimit != -1 && totalGroupCount >= reducer.groupLimit {
391394
break processResults
392395
}
393396
rowFieldValues := make([]*FieldValue, outputColumnCount)
@@ -416,24 +419,29 @@ processResults:
416419
if bucket := reducer.hashValsMap[hashVal]; bucket == nil {
417420
newBucket := NewBucket()
418421
newBucket.AddRow(newRow)
419-
totalRowCount++
422+
totalGroupCount++
420423
reducer.hashValsMap[hashVal] = newBucket
421424
} else {
422425
if rowIdx := bucket.Find(newRow, numGroupingKeys); rowIdx == NONE {
423426
bucket.AddRow(newRow)
424-
totalRowCount++
427+
totalGroupCount++
425428
} else {
426429
if err := bucket.Accumulate(newRow, rowIdx, numGroupingKeys, aggs); err != nil {
427430
return nil, err
428431
}
429432
}
430433
}
434+
if totalGroupCount > maxGroupByGroups {
435+
return nil, merr.WrapErrServiceInternal(fmt.Sprintf("GROUP BY produced too many groups (%d). "+
436+
"Add filters or increase common.groupBy.maxGroups (current: %d)",
437+
totalGroupCount, maxGroupByGroups))
438+
}
431439
// Don't guarantee specific groups to be returned before milvus support order by
432440
}
433441
}
434442

435443
// 3. assemble reduced buckets into retrievedResult
436-
reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, totalRowCount)
444+
reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, totalGroupCount)
437445
for _, bucket := range reducer.hashValsMap {
438446
err := AssembleBucket(bucket, reducedResult.GetFieldDatas())
439447
if err != nil {

internal/agg/aggregate_reducer_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,22 @@ package agg
22

33
import (
44
"context"
5+
"fmt"
6+
"strings"
57
"testing"
68

79
"github.com/stretchr/testify/assert"
810
"github.com/stretchr/testify/require"
911

1012
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
1113
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
14+
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
1215
)
1316

17+
// init calls paramtable.Init() once when the test package is loaded. The tests
// in this file and GroupAggReducer.Reduce read settings through
// paramtable.Get(), so this presumably has to run before any of them — confirm
// against the paramtable package.
func init() {
18+
paramtable.Init()
19+
}
20+
1421
func makeTestSchema() *schemapb.CollectionSchema {
1522
return &schemapb.CollectionSchema{
1623
Fields: []*schemapb.FieldSchema{
@@ -177,3 +184,116 @@ func TestReduceSingleResult(t *testing.T) {
177184
require.NoError(t, err)
178185
assert.Equal(t, singleResult, out)
179186
}
187+
188+
// buildTestSchema creates a simple schema with an INT64 groupBy field and an INT64 agg field.
189+
func buildTestSchema() *schemapb.CollectionSchema {
190+
return &schemapb.CollectionSchema{
191+
Fields: []*schemapb.FieldSchema{
192+
{FieldID: 100, Name: "group_field", DataType: schemapb.DataType_Int64},
193+
{FieldID: 101, Name: "agg_field", DataType: schemapb.DataType_Int64},
194+
},
195+
}
196+
}
197+
198+
// buildAggResult creates an AggregationResult with N distinct groups.
199+
// Each group has group key = startKey+i and count = 1.
200+
func buildAggResult(startKey int64, numGroups int) *AggregationResult {
201+
groupKeys := make([]int64, numGroups)
202+
counts := make([]int64, numGroups)
203+
for i := 0; i < numGroups; i++ {
204+
groupKeys[i] = startKey + int64(i)
205+
counts[i] = 1
206+
}
207+
return NewAggregationResult([]*schemapb.FieldData{
208+
{
209+
Type: schemapb.DataType_Int64,
210+
FieldName: "group_field",
211+
Field: &schemapb.FieldData_Scalars{
212+
Scalars: &schemapb.ScalarField{
213+
Data: &schemapb.ScalarField_LongData{
214+
LongData: &schemapb.LongArray{Data: groupKeys},
215+
},
216+
},
217+
},
218+
},
219+
{
220+
Type: schemapb.DataType_Int64,
221+
FieldName: "agg_field",
222+
Field: &schemapb.FieldData_Scalars{
223+
Scalars: &schemapb.ScalarField{
224+
Data: &schemapb.ScalarField_LongData{
225+
LongData: &schemapb.LongArray{Data: counts},
226+
},
227+
},
228+
},
229+
},
230+
}, int64(numGroups))
231+
}
232+
233+
func TestGroupAggReducer_MaxGroupByGroupsExceeded(t *testing.T) {
234+
maxGroups := int64(10)
235+
paramtable.Get().Save(paramtable.Get().CommonCfg.GroupByMaxGroups.Key, fmt.Sprintf("%d", maxGroups))
236+
defer paramtable.Get().Reset(paramtable.Get().CommonCfg.GroupByMaxGroups.Key)
237+
238+
schema := buildTestSchema()
239+
aggregates := []*planpb.Aggregate{
240+
{Op: planpb.AggregateOp_count, FieldId: 101},
241+
}
242+
reducer := NewGroupAggReducer([]int64{100}, aggregates, -1, schema)
243+
244+
// Two results each with 10 distinct groups (20 total > 10 limit)
245+
results := []*AggregationResult{
246+
buildAggResult(0, 10),
247+
buildAggResult(10, 10),
248+
}
249+
250+
_, err := reducer.Reduce(context.Background(), results)
251+
require.Error(t, err)
252+
assert.True(t, strings.Contains(err.Error(), "too many groups"))
253+
}
254+
255+
func TestGroupAggReducer_MaxGroupByGroupsExactlyAtLimit(t *testing.T) {
256+
maxGroups := int64(10)
257+
paramtable.Get().Save(paramtable.Get().CommonCfg.GroupByMaxGroups.Key, fmt.Sprintf("%d", maxGroups))
258+
defer paramtable.Get().Reset(paramtable.Get().CommonCfg.GroupByMaxGroups.Key)
259+
260+
schema := buildTestSchema()
261+
aggregates := []*planpb.Aggregate{
262+
{Op: planpb.AggregateOp_count, FieldId: 101},
263+
}
264+
reducer := NewGroupAggReducer([]int64{100}, aggregates, -1, schema)
265+
266+
// Exactly 10 groups = limit, should succeed
267+
// Use 2 results to force cross-segment merge path (single result fast-returns)
268+
results := []*AggregationResult{
269+
buildAggResult(0, 5),
270+
buildAggResult(5, 5),
271+
}
272+
273+
result, err := reducer.Reduce(context.Background(), results)
274+
require.NoError(t, err)
275+
assert.NotNil(t, result)
276+
}
277+
278+
func TestGroupAggReducer_MaxGroupByGroupsJustOverLimit(t *testing.T) {
279+
maxGroups := int64(10)
280+
paramtable.Get().Save(paramtable.Get().CommonCfg.GroupByMaxGroups.Key, fmt.Sprintf("%d", maxGroups))
281+
defer paramtable.Get().Reset(paramtable.Get().CommonCfg.GroupByMaxGroups.Key)
282+
283+
schema := buildTestSchema()
284+
aggregates := []*planpb.Aggregate{
285+
{Op: planpb.AggregateOp_count, FieldId: 101},
286+
}
287+
reducer := NewGroupAggReducer([]int64{100}, aggregates, -1, schema)
288+
289+
// 6 + 5 = 11 distinct groups > 10 limit, should fail
290+
// Need 2 results to trigger cross-segment merge path (single result fast-returns)
291+
results := []*AggregationResult{
292+
buildAggResult(0, 6),
293+
buildAggResult(6, 5),
294+
}
295+
296+
_, err := reducer.Reduce(context.Background(), results)
297+
require.Error(t, err)
298+
assert.True(t, strings.Contains(err.Error(), "too many groups"))
299+
}

internal/core/src/exec/HashTable.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "common/SimdUtil.h"
2424
#include "exec/VectorHasher.h"
25+
#include "fmt/format.h"
2526

2627
namespace milvus {
2728
namespace exec {
@@ -222,10 +223,20 @@ char*
222223
HashTable::insertEntry(milvus::exec::HashLookup& lookup,
223224
uint64_t index,
224225
milvus::vector_size_t row) {
226+
if (numDistinct_ >= maxNumGroups_) {
227+
ThrowInfo(
228+
UnexpectedError,
229+
fmt::format("GROUP BY produced too many groups ({}). "
230+
"Add filters or increase common.groupBy.maxGroups "
231+
"(current: {})",
232+
numDistinct_ + 1,
233+
maxNumGroups_));
234+
}
225235
char* group = rows_->newRow();
226236
lookup.hits_[row] = group;
227237
storeKeys(lookup, row);
228238
storeRowPointer(index, lookup.hashes_[row], group);
239+
rowHashes_.push_back(lookup.hashes_[row]);
229240
numDistinct_++;
230241
lookup.newGroups_.push_back(row);
231242
return group;
@@ -250,6 +261,9 @@ HashTable::groupProbe(milvus::exec::HashLookup& lookup) {
250261
checkSizeAndAllocateTable(0);
251262
ProbeState state;
252263
for (int32_t idx = 0; idx < lookup.hashes_.size(); idx++) {
264+
if (numDistinct_ >= rehashSize()) {
265+
rehash();
266+
}
253267
state.preProbe(*this, lookup.hashes_[idx], idx);
254268
state.firstProbe<ProbeState::Operation::kInsert>(*this);
255269
fullProbe(lookup, state);
@@ -272,6 +286,40 @@ HashTable::clear(bool freeTable) {
272286
numBuckets_ = 0;
273287
sizeMask_ = 0;
274288
bucketOffsetMask_ = 0;
289+
rowHashes_.clear();
290+
}
291+
292+
void
293+
HashTable::insertForRehash(char* row, uint64_t hash) {
294+
const auto tag = hashTag(hash);
295+
const auto kEmptyGroup = TagVector::broadcast(0);
296+
int64_t bktOffset = bucketOffset(hash);
297+
for (int64_t i = 0; i < numBuckets_; i++) {
298+
auto tags = loadTags(bktOffset);
299+
uint16_t empty = toBitMask(tags == kEmptyGroup) & 0xffff;
300+
if (empty > 0) {
301+
auto pos = bits::getAndClearLastSetBit(empty);
302+
auto* bucket = bucketAt(bktOffset);
303+
bucket->setTag(pos, tag);
304+
bucket->setPointer(pos, row);
305+
return;
306+
}
307+
bktOffset = nextBucketOffset(bktOffset);
308+
}
309+
AssertInfo(false, "Failed to insert during rehash");
310+
}
311+
312+
void
313+
HashTable::rehash() {
314+
// allRows is safe to reference across allocateTables() because
315+
// allocateTables() only rebuilds the hash bucket array (table_),
316+
// it does not touch the RowContainer (rows_) that owns the row data.
317+
const auto& allRows = rows_->allRows();
318+
allocateTables(capacity_ * 2);
319+
for (size_t i = 0; i < allRows.size(); i++) {
320+
insertForRehash(allRows[i], rowHashes_[i]);
321+
}
322+
numRehashes_++;
275323
}
276324

277325
} // namespace exec

internal/core/src/exec/HashTable.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "common/Vector.h"
3535
#include "exec/operator/query-agg/RowContainer.h"
3636
#include "folly/CPortability.h"
37+
#include "segcore/SegcoreConfig.h"
3738
#include "xsimd/xsimd.hpp"
3839

3940
namespace milvus {
@@ -174,9 +175,11 @@ class ProbeState;
174175

175176
class HashTable : public BaseHashTable {
176177
public:
177-
HashTable(std::vector<std::unique_ptr<VectorHasher>>&& hashers,
178-
const std::vector<Accumulator>& accumulators)
179-
: BaseHashTable(std::move(hashers)) {
178+
HashTable(
179+
std::vector<std::unique_ptr<VectorHasher>>&& hashers,
180+
const std::vector<Accumulator>& accumulators,
181+
int64_t maxNumGroups = segcore::SegcoreConfig::kDefaultMaxGroupByGroups)
182+
: BaseHashTable(std::move(hashers)), maxNumGroups_(maxNumGroups) {
180183
std::vector<DataType> keyTypes;
181184
for (auto& hasher : hashers_) {
182185
keyTypes.push_back(hasher->ChannelDataType());
@@ -305,6 +308,12 @@ class HashTable : public BaseHashTable {
305308
void
306309
checkSizeAndAllocateTable(int32_t numNew);
307310

311+
void
312+
rehash();
313+
314+
void
315+
insertForRehash(char* row, uint64_t hash);
316+
308317
// Returns the number of entries after which the table gets rehashed.
309318
static uint64_t
310319
rehashSize(int64_t size) {
@@ -344,6 +353,8 @@ class HashTable : public BaseHashTable {
344353

345354
[[maybe_unused]] int64_t numRehashes_{0};
346355
char* table_ = nullptr;
356+
std::vector<uint64_t> rowHashes_;
357+
int64_t maxNumGroups_;
347358

348359
HashMode
349360
hashMode() const override {

internal/core/src/exec/operator/query-agg/GroupingSet.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "exec/operator/query-agg/AggregateInfo.h"
2828
#include "exec/operator/query-agg/RowContainer.h"
2929
#include "folly/Range.h"
30+
#include "segcore/SegcoreConfig.h"
3031

3132
namespace milvus {
3233
namespace exec {
@@ -248,8 +249,10 @@ initializeAggregates(const std::vector<AggregateInfo>& aggregates,
248249

249250
void
250251
GroupingSet::createHashTable() {
251-
hash_table_ =
252-
std::make_unique<HashTable>(std::move(hashers_), accumulators());
252+
auto maxGroups =
253+
segcore::SegcoreConfig::default_config().get_max_group_by_groups();
254+
hash_table_ = std::make_unique<HashTable>(
255+
std::move(hashers_), accumulators(), maxGroups);
253256
auto& rows = *(hash_table_->rows());
254257
initializeAggregates(aggregates_, rows);
255258
lookup_ = std::make_unique<HashLookup>(hash_table_->hashers());

internal/core/src/segcore/SegcoreConfig.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@ class SegcoreConfig {
156156
return enable_geometry_cache_;
157157
}
158158

159+
static constexpr int64_t kDefaultMaxGroupByGroups = 100000;
160+
161+
// Returns the current value of max_group_by_groups_, the configured maximum
// number of GROUP BY groups (initialised to kDefaultMaxGroupByGroups).
int64_t
162+
get_max_group_by_groups() const {
163+
return max_group_by_groups_;
164+
}
165+
166+
// Overwrites max_group_by_groups_ with v. The value is stored as-is; no range
// check is applied here, so callers are responsible for passing a sensible
// limit.
void
167+
set_max_group_by_groups(int64_t v) {
168+
max_group_by_groups_ = v;
169+
}
170+
159171
void
160172
set_interim_index_mem_expansion_rate(float rate) {
161173
interim_index_mem_expansion_rate_ = rate;
@@ -186,6 +198,7 @@ class SegcoreConfig {
186198
inline static bool refine_with_quant_flag_ = false;
187199
inline static bool enable_geometry_cache_ = false;
188200
inline static float interim_index_mem_expansion_rate_ = 1.15f;
201+
inline static int64_t max_group_by_groups_ = kDefaultMaxGroupByGroups;
189202
};
190203

191204
} // namespace milvus::segcore

0 commit comments

Comments
 (0)