1818#include " test_utils/storage_test_utils.h"
1919#include " exec/expression/function/FunctionFactory.h"
2020#include " exec/operator/query-agg/CountAggregateBase.h"
21+ #include " exec/HashTable.h"
22+ #include " exec/VectorHasher.h"
2123#include " query/PlanImpl.h"
2224#include " query/PlanNode.h"
2325
@@ -1212,4 +1214,158 @@ TEST_P(QueryAggTest, GroupByEmptyResultMultipleAggs) {
12121214 EXPECT_EQ (
12131215 retrieve_results->fields_data (3 ).scalars ().double_data ().data_size (),
12141216 0 );
1217+ }
1218+
1219+ // ============================================================
1220+ // HashTable rehash and group limit tests
1221+ // ============================================================
1222+
1223+ namespace {
1224+ // Helper to create a HashTable with a single INT64 key column, no accumulators.
1225+ // Inserts 'numGroups' distinct INT64 values via groupProbe in batches.
1226+ // Returns the HashTable.
1227+ std::unique_ptr<milvus::exec::HashTable>
1228+ createAndInsertGroups (int64_t numGroups,
1229+ int64_t maxNumGroups,
1230+ int batchSize = 1024 ) {
1231+ std::vector<milvus::exec::Accumulator> accumulators;
1232+ std::vector<std::unique_ptr<milvus::exec::VectorHasher>> hashers;
1233+ hashers.push_back (
1234+ milvus::exec::VectorHasher::create (milvus::DataType::INT64, 0 ));
1235+ auto table = std::make_unique<milvus::exec::HashTable>(
1236+ std::move (hashers), accumulators, maxNumGroups);
1237+
1238+ int64_t inserted = 0 ;
1239+ while (inserted < numGroups) {
1240+ int64_t thisCount = std::min ((int64_t )batchSize, numGroups - inserted);
1241+ auto col = std::make_shared<milvus::ColumnVector>(
1242+ milvus::DataType::INT64, thisCount);
1243+ auto * data = reinterpret_cast <int64_t *>(col->GetRawData ());
1244+ for (int64_t i = 0 ; i < thisCount; i++) {
1245+ data[i] = inserted + i;
1246+ }
1247+
1248+ std::vector<milvus::DataType> dtypes = {milvus::DataType::INT64};
1249+ std::vector<milvus::VectorPtr> children = {col};
1250+ auto input = std::make_shared<milvus::RowVector>(children);
1251+
1252+ milvus::exec::HashLookup lookup (table->hashers ());
1253+ table->prepareForGroupProbe (lookup, input);
1254+ table->groupProbe (lookup);
1255+
1256+ inserted += thisCount;
1257+ }
1258+ return table;
1259+ }
1260+ } // namespace
1261+
1262+ TEST (HashTableRehashTest, TestHashTableRehashBasic) {
1263+ // Insert 5000 distinct groups (exceeds initial 2048 capacity),
1264+ // verify no crash and numDistinct == 5000
1265+ const int64_t numGroups = 5000 ;
1266+ auto table = createAndInsertGroups (numGroups, 1000000 );
1267+ ASSERT_NE (table->rows (), nullptr );
1268+ EXPECT_EQ (table->rows ()->allRows ().size (), numGroups);
1269+ }
1270+
1271+ TEST (HashTableRehashTest, TestHashTableRehashCorrectness) {
1272+ // Insert N groups, record row pointers.
1273+ // Then re-probe the same keys and verify we get back the same row pointers.
1274+ const int64_t numGroups = 3000 ;
1275+ std::vector<milvus::exec::Accumulator> accumulators;
1276+ std::vector<std::unique_ptr<milvus::exec::VectorHasher>> hashers;
1277+ hashers.push_back (
1278+ milvus::exec::VectorHasher::create (milvus::DataType::INT64, 0 ));
1279+ auto table = std::make_unique<milvus::exec::HashTable>(
1280+ std::move (hashers), accumulators, 1000000 );
1281+
1282+ // Insert numGroups distinct values
1283+ auto col = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1284+ numGroups);
1285+ auto * data = reinterpret_cast <int64_t *>(col->GetRawData ());
1286+ for (int64_t i = 0 ; i < numGroups; i++) {
1287+ data[i] = i;
1288+ }
1289+ std::vector<milvus::VectorPtr> children = {col};
1290+ auto input = std::make_shared<milvus::RowVector>(children);
1291+
1292+ milvus::exec::HashLookup lookup1 (table->hashers ());
1293+ table->prepareForGroupProbe (lookup1, input);
1294+ table->groupProbe (lookup1);
1295+
1296+ // Record row pointers
1297+ std::vector<char *> firstHits (lookup1.hits_ .begin (), lookup1.hits_ .end ());
1298+
1299+ // Re-probe with the same keys — should get the same row pointers
1300+ auto col2 = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1301+ numGroups);
1302+ auto * data2 = reinterpret_cast <int64_t *>(col2->GetRawData ());
1303+ for (int64_t i = 0 ; i < numGroups; i++) {
1304+ data2[i] = i;
1305+ }
1306+ std::vector<milvus::VectorPtr> children2 = {col2};
1307+ auto input2 = std::make_shared<milvus::RowVector>(children2);
1308+
1309+ milvus::exec::HashLookup lookup2 (table->hashers ());
1310+ table->prepareForGroupProbe (lookup2, input2);
1311+ table->groupProbe (lookup2);
1312+
1313+ ASSERT_EQ (lookup2.hits_ .size (), numGroups);
1314+ for (int64_t i = 0 ; i < numGroups; i++) {
1315+ EXPECT_EQ (lookup2.hits_ [i], firstHits[i])
1316+ << " Row pointer mismatch for group " << i;
1317+ }
1318+ // No new groups should have been created on re-probe
1319+ EXPECT_TRUE (lookup2.newGroups_ .empty ());
1320+ }
1321+
1322+ TEST (HashTableRehashTest, TestHashTableMaxGroupsLimit) {
1323+ // Set maxNumGroups=100, insert >100 groups, verify exception
1324+ const int64_t maxGroups = 100 ;
1325+ EXPECT_THROW (
1326+ {
1327+ try {
1328+ createAndInsertGroups (200 , maxGroups);
1329+ } catch (const std::exception& e) {
1330+ // Verify the error message mentions the limit
1331+ std::string msg = e.what ();
1332+ EXPECT_NE (msg.find (" too many groups" ), std::string::npos)
1333+ << " Expected 'too many groups' in: " << msg;
1334+ EXPECT_NE (msg.find (" maxGroupByGroups" ), std::string::npos)
1335+ << " Expected 'maxGroupByGroups' in: " << msg;
1336+ throw ;
1337+ }
1338+ },
1339+ std::exception);
1340+ }
1341+
1342+ TEST (HashTableRehashTest, TestHashTableRehashMultipleRounds) {
1343+ // Insert 50K groups forcing ~5 rehash rounds (2048->4096->...->65536)
1344+ // Verify all groups found correctly
1345+ const int64_t numGroups = 50000 ;
1346+ auto table = createAndInsertGroups (numGroups, 1000000 , 2048 );
1347+ ASSERT_NE (table->rows (), nullptr );
1348+ EXPECT_EQ (table->rows ()->allRows ().size (), numGroups);
1349+
1350+ // Verify we can look up all groups
1351+ auto col = std::make_shared<milvus::ColumnVector>(milvus::DataType::INT64,
1352+ numGroups);
1353+ auto * data = reinterpret_cast <int64_t *>(col->GetRawData ());
1354+ for (int64_t i = 0 ; i < numGroups; i++) {
1355+ data[i] = i;
1356+ }
1357+ std::vector<milvus::VectorPtr> children = {col};
1358+ auto input = std::make_shared<milvus::RowVector>(children);
1359+
1360+ milvus::exec::HashLookup lookup (table->hashers ());
1361+ table->prepareForGroupProbe (lookup, input);
1362+ table->groupProbe (lookup);
1363+
1364+ // All should be existing groups, no new groups created
1365+ EXPECT_TRUE (lookup.newGroups_ .empty ());
1366+ // All hits should be non-null
1367+ for (int64_t i = 0 ; i < numGroups; i++) {
1368+ EXPECT_NE (lookup.hits_ [i], nullptr )
1369+ << " Group " << i << " not found after multiple rehashes" ;
1370+ }
12151371}
0 commit comments