Skip to content

Commit d47f238

Browse files
MrPresent-Han and claude committed
fix: HashTable dynamic rehash and group count limit for GROUP BY aggregation (#47569)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: MrPresent-Han <chun.han@gmail.com>
1 parent 77829de commit d47f238

11 files changed

Lines changed: 260 additions & 4 deletions

File tree

configs/milvus.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ queryNode:
485485
buildParallelRate: 0.5 # the ratio of building interim index parallel matched with cpu num
486486
multipleChunkedEnable: true # Deprecated. Enable multiple chunked search
487487
enableGeometryCache: false # Enable geometry cache for geometry data
488+
maxGroupByGroups: 100000 # Maximum number of groups allowed in GROUP BY aggregation per segment. Exceeding this limit fails the query.
488489
tieredStorage:
489490
warmup:
490491
# options: sync, async, disable.

internal/agg/aggregate_reducer.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
1010
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
1111
"github.com/milvus-io/milvus/pkg/v2/proto/segcorepb"
12+
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
1213
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
1314
)
1415

@@ -426,6 +427,12 @@ processResults:
426427
bucket.Accumulate(newRow, rowIdx, numGroupingKeys, aggs)
427428
}
428429
}
430+
maxGroupByGroups := paramtable.Get().QueryNodeCfg.MaxGroupByGroups.GetAsInt64()
431+
if maxGroupByGroups > 0 && totalRowCount > maxGroupByGroups {
432+
return nil, fmt.Errorf("GROUP BY produced too many groups (%d). "+
433+
"Add filters or increase queryNode.segcore.maxGroupByGroups (current: %d)",
434+
totalRowCount, maxGroupByGroups)
435+
}
429436
// Don't guarantee specific groups to be returned before milvus support order by
430437
}
431438
}

internal/core/src/exec/HashTable.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "common/SimdUtil.h"
2424
#include "exec/VectorHasher.h"
25+
#include "fmt/format.h"
2526

2627
namespace milvus {
2728
namespace exec {
@@ -222,10 +223,21 @@ char*
222223
HashTable::insertEntry(milvus::exec::HashLookup& lookup,
223224
uint64_t index,
224225
milvus::vector_size_t row) {
226+
if (numDistinct_ >= maxNumGroups_) {
227+
ThrowInfo(
228+
UnexpectedError,
229+
fmt::format(
230+
"GROUP BY produced too many groups ({}). "
231+
"Add filters or increase queryNode.segcore.maxGroupByGroups "
232+
"(current: {})",
233+
numDistinct_ + 1,
234+
maxNumGroups_));
235+
}
225236
char* group = rows_->newRow();
226237
lookup.hits_[row] = group;
227238
storeKeys(lookup, row);
228239
storeRowPointer(index, lookup.hashes_[row], group);
240+
rowHashes_.push_back(lookup.hashes_[row]);
229241
numDistinct_++;
230242
lookup.newGroups_.push_back(row);
231243
return group;
@@ -250,6 +262,9 @@ HashTable::groupProbe(milvus::exec::HashLookup& lookup) {
250262
checkSizeAndAllocateTable(0);
251263
ProbeState state;
252264
for (int32_t idx = 0; idx < lookup.hashes_.size(); idx++) {
265+
if (numDistinct_ >= rehashSize()) {
266+
rehash();
267+
}
253268
state.preProbe(*this, lookup.hashes_[idx], idx);
254269
state.firstProbe<ProbeState::Operation::kInsert>(*this);
255270
fullProbe(lookup, state);
@@ -272,6 +287,37 @@ HashTable::clear(bool freeTable) {
272287
numBuckets_ = 0;
273288
sizeMask_ = 0;
274289
bucketOffsetMask_ = 0;
290+
rowHashes_.clear();
291+
}
292+
293+
void
294+
HashTable::insertForRehash(char* row, uint64_t hash) {
295+
const auto tag = hashTag(hash);
296+
const auto kEmptyGroup = TagVector::broadcast(0);
297+
int64_t bktOffset = bucketOffset(hash);
298+
for (int64_t i = 0; i < numBuckets_; i++) {
299+
auto tags = loadTags(bktOffset);
300+
uint16_t empty = toBitMask(tags == kEmptyGroup) & 0xffff;
301+
if (empty > 0) {
302+
auto pos = bits::getAndClearLastSetBit(empty);
303+
auto* bucket = bucketAt(bktOffset);
304+
bucket->setTag(pos, tag);
305+
bucket->setPointer(pos, row);
306+
return;
307+
}
308+
bktOffset = nextBucketOffset(bktOffset);
309+
}
310+
AssertInfo(false, "Failed to insert during rehash");
311+
}
312+
313+
void
314+
HashTable::rehash() {
315+
const auto& allRows = rows_->allRows();
316+
allocateTables(capacity_ * 2);
317+
for (size_t i = 0; i < allRows.size(); i++) {
318+
insertForRehash(allRows[i], rowHashes_[i]);
319+
}
320+
numRehashes_++;
275321
}
276322

277323
} // namespace exec

internal/core/src/exec/HashTable.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ class ProbeState;
175175
class HashTable : public BaseHashTable {
176176
public:
177177
HashTable(std::vector<std::unique_ptr<VectorHasher>>&& hashers,
178-
const std::vector<Accumulator>& accumulators)
179-
: BaseHashTable(std::move(hashers)) {
178+
const std::vector<Accumulator>& accumulators,
179+
int64_t maxNumGroups = 100000)
180+
: BaseHashTable(std::move(hashers)), maxNumGroups_(maxNumGroups) {
180181
std::vector<DataType> keyTypes;
181182
for (auto& hasher : hashers_) {
182183
keyTypes.push_back(hasher->ChannelDataType());
@@ -305,6 +306,12 @@ class HashTable : public BaseHashTable {
305306
void
306307
checkSizeAndAllocateTable(int32_t numNew);
307308

309+
void
310+
rehash();
311+
312+
void
313+
insertForRehash(char* row, uint64_t hash);
314+
308315
// Returns the number of entries after which the table gets rehashed.
309316
static uint64_t
310317
rehashSize(int64_t size) {
@@ -344,6 +351,8 @@ class HashTable : public BaseHashTable {
344351

345352
int64_t numRehashes_{0};
346353
char* table_ = nullptr;
354+
std::vector<uint64_t> rowHashes_;
355+
int64_t maxNumGroups_;
347356

348357
HashMode
349358
hashMode() const override {

internal/core/src/exec/operator/query-agg/GroupingSet.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "exec/operator/query-agg/AggregateInfo.h"
2828
#include "exec/operator/query-agg/RowContainer.h"
2929
#include "folly/Range.h"
30+
#include "segcore/SegcoreConfig.h"
3031

3132
namespace milvus {
3233
namespace exec {
@@ -248,8 +249,10 @@ initializeAggregates(const std::vector<AggregateInfo>& aggregates,
248249

249250
void
250251
GroupingSet::createHashTable() {
251-
hash_table_ =
252-
std::make_unique<HashTable>(std::move(hashers_), accumulators());
252+
auto maxGroups =
253+
segcore::SegcoreConfig::default_config().get_max_group_by_groups();
254+
hash_table_ = std::make_unique<HashTable>(
255+
std::move(hashers_), accumulators(), maxGroups);
253256
auto& rows = *(hash_table_->rows());
254257
initializeAggregates(aggregates_, rows);
255258
lookup_ = std::make_unique<HashLookup>(hash_table_->hashers());

internal/core/src/segcore/SegcoreConfig.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ class SegcoreConfig {
156156
return enable_geometry_cache_;
157157
}
158158

159+
int64_t
160+
get_max_group_by_groups() const {
161+
return max_group_by_groups_;
162+
}
163+
164+
void
165+
set_max_group_by_groups(int64_t v) {
166+
max_group_by_groups_ = v;
167+
}
168+
159169
void
160170
set_interim_index_mem_expansion_rate(float rate) {
161171
interim_index_mem_expansion_rate_ = rate;
@@ -186,6 +196,7 @@ class SegcoreConfig {
186196
inline static bool refine_with_quant_flag_ = false;
187197
inline static bool enable_geometry_cache_ = false;
188198
inline static float interim_index_mem_expansion_rate_ = 1.15f;
199+
inline static int64_t max_group_by_groups_ = 100000;
189200
};
190201

191202
} // namespace milvus::segcore

internal/core/src/segcore/segcore_init_c.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ SegcoreSetInterimIndexMemExpansionRate(const float value) {
114114
config.set_interim_index_mem_expansion_rate(value);
115115
}
116116

117+
extern "C" void
118+
SegcoreSetMaxGroupByGroups(const int64_t value) {
119+
milvus::segcore::SegcoreConfig& config =
120+
milvus::segcore::SegcoreConfig::default_config();
121+
config.set_max_group_by_groups(value);
122+
}
123+
117124
extern "C" void
118125
SegcoreSetSubDim(const int64_t value) {
119126
milvus::segcore::SegcoreConfig& config =

internal/core/src/segcore/segcore_init_c.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ SegcoreSetDenseVectorInterminIndexRefineWithQuantFlag(const bool);
6262
void
6363
SegcoreSetInterimIndexMemExpansionRate(const float);
6464

65+
void
66+
SegcoreSetMaxGroupByGroups(const int64_t);
67+
6568
// return value must be freed by the caller
6669
char*
6770
SegcoreSetSimdType(const char*);

internal/core/unittest/test_query_group_by.cpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "test_utils/storage_test_utils.h"
1919
#include "exec/expression/function/FunctionFactory.h"
2020
#include "exec/operator/query-agg/CountAggregateBase.h"
21+
#include "exec/HashTable.h"
22+
#include "exec/VectorHasher.h"
2123
#include "query/PlanImpl.h"
2224
#include "query/PlanNode.h"
2325

@@ -1212,4 +1214,158 @@ TEST_P(QueryAggTest, GroupByEmptyResultMultipleAggs) {
12121214
EXPECT_EQ(
12131215
retrieve_results->fields_data(3).scalars().double_data().data_size(),
12141216
0);
1217+
}
1218+
1219+
// ============================================================
1220+
// HashTable rehash and group limit tests
1221+
// ============================================================
1222+
1223+
namespace {
1224+
// Helper to create a HashTable with a single INT64 key column, no accumulators.
1225+
// Inserts 'numGroups' distinct INT64 values via groupProbe in batches.
1226+
// Returns the HashTable.
1227+
std::unique_ptr<milvus::exec::HashTable>
1228+
createAndInsertGroups(int64_t numGroups,
1229+
int64_t maxNumGroups,
1230+
int batchSize = 1024) {
1231+
std::vector<milvus::exec::Accumulator> accumulators;
1232+
std::vector<std::unique_ptr<milvus::exec::VectorHasher>> hashers;
1233+
hashers.push_back(
1234+
milvus::exec::VectorHasher::create(milvus::DataType::INT64, 0));
1235+
auto table = std::make_unique<milvus::exec::HashTable>(
1236+
std::move(hashers), accumulators, maxNumGroups);
1237+
1238+
int64_t inserted = 0;
1239+
while (inserted < numGroups) {
1240+
int64_t thisCount = std::min((int64_t)batchSize, numGroups - inserted);
1241+
auto col = std::make_shared<milvus::ColumnVector>(
1242+
milvus::DataType::INT64, thisCount);
1243+
auto* data = reinterpret_cast<int64_t*>(col->GetRawData());
1244+
for (int64_t i = 0; i < thisCount; i++) {
1245+
data[i] = inserted + i;
1246+
}
1247+
1248+
std::vector<milvus::DataType> dtypes = {milvus::DataType::INT64};
1249+
std::vector<milvus::VectorPtr> children = {col};
1250+
auto input = std::make_shared<milvus::RowVector>(children);
1251+
1252+
milvus::exec::HashLookup lookup(table->hashers());
1253+
table->prepareForGroupProbe(lookup, input);
1254+
table->groupProbe(lookup);
1255+
1256+
inserted += thisCount;
1257+
}
1258+
return table;
1259+
}
1260+
} // namespace
1261+
1262+
TEST(HashTableRehashTest, TestHashTableRehashBasic) {
1263+
// Insert 5000 distinct groups (exceeds initial 2048 capacity),
1264+
// verify no crash and numDistinct == 5000
1265+
const int64_t numGroups = 5000;
1266+
auto table = createAndInsertGroups(numGroups, 1000000);
1267+
ASSERT_NE(table->rows(), nullptr);
1268+
EXPECT_EQ(table->rows()->allRows().size(), numGroups);
1269+
}
1270+
1271+
TEST(HashTableRehashTest, TestHashTableRehashCorrectness) {
1272+
// Insert N groups, record row pointers.
1273+
// Then re-probe the same keys and verify we get back the same row pointers.
1274+
const int64_t numGroups = 3000;
1275+
std::vector<milvus::exec::Accumulator> accumulators;
1276+
std::vector<std::unique_ptr<milvus::exec::VectorHasher>> hashers;
1277+
hashers.push_back(
1278+
milvus::exec::VectorHasher::create(milvus::DataType::INT64, 0));
1279+
auto table = std::make_unique<milvus::exec::HashTable>(
1280+
std::move(hashers), accumulators, 1000000);
1281+
1282+
// Insert numGroups distinct values
1283+
auto col = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1284+
numGroups);
1285+
auto* data = reinterpret_cast<int64_t*>(col->GetRawData());
1286+
for (int64_t i = 0; i < numGroups; i++) {
1287+
data[i] = i;
1288+
}
1289+
std::vector<milvus::VectorPtr> children = {col};
1290+
auto input = std::make_shared<milvus::RowVector>(children);
1291+
1292+
milvus::exec::HashLookup lookup1(table->hashers());
1293+
table->prepareForGroupProbe(lookup1, input);
1294+
table->groupProbe(lookup1);
1295+
1296+
// Record row pointers
1297+
std::vector<char*> firstHits(lookup1.hits_.begin(), lookup1.hits_.end());
1298+
1299+
// Re-probe with the same keys — should get the same row pointers
1300+
auto col2 = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1301+
numGroups);
1302+
auto* data2 = reinterpret_cast<int64_t*>(col2->GetRawData());
1303+
for (int64_t i = 0; i < numGroups; i++) {
1304+
data2[i] = i;
1305+
}
1306+
std::vector<milvus::VectorPtr> children2 = {col2};
1307+
auto input2 = std::make_shared<milvus::RowVector>(children2);
1308+
1309+
milvus::exec::HashLookup lookup2(table->hashers());
1310+
table->prepareForGroupProbe(lookup2, input2);
1311+
table->groupProbe(lookup2);
1312+
1313+
ASSERT_EQ(lookup2.hits_.size(), numGroups);
1314+
for (int64_t i = 0; i < numGroups; i++) {
1315+
EXPECT_EQ(lookup2.hits_[i], firstHits[i])
1316+
<< "Row pointer mismatch for group " << i;
1317+
}
1318+
// No new groups should have been created on re-probe
1319+
EXPECT_TRUE(lookup2.newGroups_.empty());
1320+
}
1321+
1322+
TEST(HashTableRehashTest, TestHashTableMaxGroupsLimit) {
1323+
// Set maxNumGroups=100, insert >100 groups, verify exception
1324+
const int64_t maxGroups = 100;
1325+
EXPECT_THROW(
1326+
{
1327+
try {
1328+
createAndInsertGroups(200, maxGroups);
1329+
} catch (const std::exception& e) {
1330+
// Verify the error message mentions the limit
1331+
std::string msg = e.what();
1332+
EXPECT_NE(msg.find("too many groups"), std::string::npos)
1333+
<< "Expected 'too many groups' in: " << msg;
1334+
EXPECT_NE(msg.find("maxGroupByGroups"), std::string::npos)
1335+
<< "Expected 'maxGroupByGroups' in: " << msg;
1336+
throw;
1337+
}
1338+
},
1339+
std::exception);
1340+
}
1341+
1342+
TEST(HashTableRehashTest, TestHashTableRehashMultipleRounds) {
1343+
// Insert 50K groups forcing ~5 rehash rounds (2048->4096->...->65536)
1344+
// Verify all groups found correctly
1345+
const int64_t numGroups = 50000;
1346+
auto table = createAndInsertGroups(numGroups, 1000000, 2048);
1347+
ASSERT_NE(table->rows(), nullptr);
1348+
EXPECT_EQ(table->rows()->allRows().size(), numGroups);
1349+
1350+
// Verify we can look up all groups
1351+
auto col = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1352+
numGroups);
1353+
auto* data = reinterpret_cast<int64_t*>(col->GetRawData());
1354+
for (int64_t i = 0; i < numGroups; i++) {
1355+
data[i] = i;
1356+
}
1357+
std::vector<milvus::VectorPtr> children = {col};
1358+
auto input = std::make_shared<milvus::RowVector>(children);
1359+
1360+
milvus::exec::HashLookup lookup(table->hashers());
1361+
table->prepareForGroupProbe(lookup, input);
1362+
table->groupProbe(lookup);
1363+
1364+
// All should be existing groups, no new groups created
1365+
EXPECT_TRUE(lookup.newGroups_.empty());
1366+
// All hits should be non-null
1367+
for (int64_t i = 0; i < numGroups; i++) {
1368+
EXPECT_NE(lookup.hits_[i], nullptr)
1369+
<< "Group " << i << " not found after multiple rehashes";
1370+
}
12151371
}

internal/util/initcore/query_node.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ func doInitQueryNodeOnce(ctx context.Context) error {
7272
cChunkRows := C.int64_t(paramtable.Get().QueryNodeCfg.ChunkRows.GetAsInt64())
7373
C.SegcoreSetChunkRows(cChunkRows)
7474

75+
cMaxGroupByGroups := C.int64_t(paramtable.Get().QueryNodeCfg.MaxGroupByGroups.GetAsInt64())
76+
C.SegcoreSetMaxGroupByGroups(cMaxGroupByGroups)
77+
7578
cKnowhereThreadPoolSize := C.uint32_t(paramtable.Get().QueryNodeCfg.KnowhereThreadPoolSize.GetAsUint32())
7679
C.SegcoreSetKnowhereSearchThreadPoolNum(cKnowhereThreadPoolSize)
7780

0 commit comments

Comments (0)