Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f2380c1

Browse files
yuejiaointelmihaic
andauthoredMar 28, 2025··
feature!: verify and set default VamanaBuildParameters (#96)
BREAKING CHANGE: Removed the deprecated `num_threads` argument and added a `use_full_search_history` argument in `VamanaBuildParameters` in Python bindings. This change ensures consistency between the Python API and the C++ implementation in `include/svs/index/vamana/build_params.h`. Added default value setting and checking based on doc requirement, and added additional tests. Pin CMake < 4 as a workaround until our dependencies require CMake >= 3.5. --------- Co-authored-by: Mihai Capotă <mihai@mihaic.ro>
1 parent 1e59bf6 commit f2380c1

File tree

10 files changed

+365
-59
lines changed

10 files changed

+365
-59
lines changed
 

‎bindings/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
requires = [
1717
"setuptools>=42",
1818
"scikit-build",
19-
"cmake>=3.21", # Keep in-sync with `CMakeLists.txt`
19+
"cmake>=3.21, <4", # Keep in-sync with `CMakeLists.txt`
2020
"numpy>=1.10.0, <2", # Keep in-sync with `setup.py`
2121
"archspec>=0.2.0", # Keep in-sync with `setup.py`
2222
"toml>=0.10.2", # Keep in-sync with `setup.py` required for the tests

‎bindings/python/src/vamana.cpp

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "svs/lib/dispatcher.h"
3131
#include "svs/lib/float16.h"
3232
#include "svs/lib/meta.h"
33+
#include "svs/lib/preprocessor.h"
3334
#include "svs/orchestrators/vamana.h"
3435

3536
// pybind
@@ -420,40 +421,22 @@ void wrap(py::module& m) {
420421
size_t window_size,
421422
size_t max_candidate_pool_size,
422423
size_t prune_to,
423-
size_t num_threads) {
424-
if (num_threads != std::numeric_limits<size_t>::max()) {
425-
PyErr_WarnEx(
426-
PyExc_DeprecationWarning,
427-
"Constructing VamanaBuildParameters with the \"num_threads\" "
428-
"keyword "
429-
"argument is deprecated, no longer has any effect, and will be "
430-
"removed "
431-
"from future versions of the library. Use the \"num_threads\" "
432-
"keyword "
433-
"argument of \"svs.Vamana.build\" instead!",
434-
1
435-
);
436-
}
437-
438-
// Default the `prune_to` argument appropriately.
439-
if (prune_to == std::numeric_limits<size_t>::max()) {
440-
prune_to = graph_max_degree;
441-
}
442-
424+
bool use_full_search_history) {
443425
return svs::index::vamana::VamanaBuildParameters{
444426
alpha,
445427
graph_max_degree,
446428
window_size,
447429
max_candidate_pool_size,
448430
prune_to,
449-
true};
431+
use_full_search_history};
450432
}),
451-
py::arg("alpha") = 1.2,
452-
py::arg("graph_max_degree") = 32,
453-
py::arg("window_size") = 64,
454-
py::arg("max_candidate_pool_size") = 80,
455-
py::arg("prune_to") = std::numeric_limits<size_t>::max(),
456-
py::arg("num_threads") = std::numeric_limits<size_t>::max(),
433+
py::arg("alpha") = svs::FLOAT_PLACEHOLDER,
434+
py::arg("graph_max_degree") = svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT,
435+
py::arg("window_size") = svs::VAMANA_WINDOW_SIZE_DEFAULT,
436+
py::arg("max_candidate_pool_size") = svs::UNSIGNED_INTEGER_PLACEHOLDER,
437+
py::arg("prune_to") = svs::UNSIGNED_INTEGER_PLACEHOLDER,
438+
py::arg("use_full_search_history") =
439+
svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT,
457440
R"(
458441
Construct a new instance from keyword arguments.
459442
@@ -462,6 +445,7 @@ void wrap(py::module& m) {
462445
For distance types favoring minimization, set this to a number
463446
greater than 1.0 (typically, 1.2 is sufficient). For distance types
464447
preferring maximization, set to a value less than 1.0 (such as 0.95).
448+
The default value is 1.2 for L2 distance type and 0.95 for MIP/Cosine.
465449
graph_max_degree: The maximum out-degree in the final graph. Graphs with
466450
a higher degree tend to yield better accuracy and performance at the cost
467451
of a larger memory footprint.
@@ -470,10 +454,15 @@ void wrap(py::module& m) {
470454
longer construction time. Should be larger than `graph_max_degree`.
471455
max_candidate_pool_size: Limit on the number of candidates to consider
472456
for neighbor updates. Should be larger than `window_size`.
457+
The default value is ``graph_max_degree`` * 2.
473458
prune_to: Amount candidate lists will be pruned to when exceeding the
474459
target max degree. In general, setting this to slightly less than
475-
`graph_max_degree` will yield faster index building times. Default:
476-
`graph_max_degree`.
460+
``graph_max_degree`` will yield faster index building times. Default:
461+
` `graph_max_degree`` - 4 if
462+
``graph_max_degree`` is at least 16, otherwise ``graph_max_degree``.
463+
use_full_search_history: When true, uses the full search history during
464+
graph construction, which can improve graph quality at the expense of
465+
additional memory and potentially longer build times.
477466
)"
478467
)
479468
.def_readwrite("alpha", &svs::index::vamana::VamanaBuildParameters::alpha)
@@ -557,4 +546,4 @@ overwritten when saving the index to this directory.
557546
)"
558547
);
559548
}
560-
} // namespace svs::python::vamana
549+
} // namespace svs::python::vamana

‎bindings/python/tests/test_dynamic_vamana.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_loop(self):
9898
# here, we set an expected mid-point for the recall and allow it to wander up and
9999
# down by a little.
100100
expected_recall = 0.845
101-
expected_recall_delta = 0.03
101+
expected_recall_delta = 0.05
102102

103103
reference = ReferenceDataset(num_threads = num_threads)
104104
data, ids = reference.new_ids(5000)

‎bindings/python/tests/test_vamana.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,6 @@ def test_basic(self):
281281
self._test_basic(loader, matcher, first_iter = first_iter)
282282
first_iter = False
283283

284-
def test_deprecation(self):
285-
with warnings.catch_warnings(record = True) as w:
286-
p = svs.VamanaBuildParameters(num_threads = 1)
287-
self.assertTrue(len(w) == 1)
288-
self.assertTrue(issubclass(w[0].category, DeprecationWarning))
289-
self.assertTrue("VamanaBuildParameters" in str(w[0].message))
290-
291284
def _groundtruth_map(self):
292285
return {
293286
svs.DistanceType.L2: test_groundtruth_l2,

‎include/svs/index/vamana/build_params.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
// svs
20+
#include "svs/lib/preprocessor.h"
2021
#include "svs/lib/saveload.h"
2122

2223
// stl
@@ -44,33 +45,33 @@ struct VamanaBuildParameters {
4445
, use_full_search_history{use_full_search_history_} {}
4546

4647
/// The pruning parameter.
47-
float alpha;
48+
float alpha = svs::FLOAT_PLACEHOLDER;
4849

4950
/// The maximum degree in the graph. A higher max degree may yield a higher quality
5051
/// graph in terms of recall for performance, but the memory footprint of the graph is
5152
/// directly proportional to the maximum degree.
52-
size_t graph_max_degree;
53+
size_t graph_max_degree = svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT;
5354

5455
/// The search window size to use during graph construction. A higher search window
5556
/// size will yield a higher quality graph since more overall vertices are considered,
5657
/// but will increase construction time.
57-
size_t window_size;
58+
size_t window_size = svs::VAMANA_WINDOW_SIZE_DEFAULT;
5859

5960
/// Set a limit on the number of neighbors considered during pruning. In practice, set
6061
/// this to a high number (at least 5 times greater than the window_size) and forget
6162
/// about it.
62-
size_t max_candidate_pool_size;
63+
size_t max_candidate_pool_size = svs::UNSIGNED_INTEGER_PLACEHOLDER;
6364

6465
/// This is the amount that candidates will be pruned to after certain pruning
6566
/// procedures. Setting this to less than ``graph_max_degree`` can result in significant
6667
/// speedups in index building.
67-
size_t prune_to;
68+
size_t prune_to = svs::UNSIGNED_INTEGER_PLACEHOLDER;
6869

6970
/// When building, either the contents of the search buffer can be used or the entire
7071
/// search history can be used.
7172
///
7273
/// The latter case may yield a slightly better graph as the cost of more search time.
73-
bool use_full_search_history = true;
74+
bool use_full_search_history = svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT;
7475

7576
///// Comparison
7677
friend bool
@@ -129,4 +130,4 @@ struct VamanaBuildParameters {
129130
);
130131
}
131132
};
132-
} // namespace svs::index::vamana
133+
} // namespace svs::index::vamana

‎include/svs/index/vamana/dynamic_index.h

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "svs/index/vamana/index.h"
3939
#include "svs/index/vamana/vamana_build.h"
4040
#include "svs/lib/boundscheck.h"
41+
#include "svs/lib/preprocessor.h"
4142
#include "svs/lib/threads.h"
4243

4344
namespace svs::index::vamana {
@@ -157,6 +158,9 @@ class MutableVamanaIndex {
157158
float alpha_ = 1.2;
158159
bool use_full_search_history_ = true;
159160

161+
// Construction parameters
162+
VamanaBuildParameters build_parameters_{};
163+
160164
// SVS logger for per index logging
161165
svs::logging::logger_ptr logger_;
162166

@@ -210,12 +214,19 @@ class MutableVamanaIndex {
210214
, distance_(std::move(distance_function))
211215
, threadpool_(threads::as_threadpool(std::move(threadpool_proto)))
212216
, search_parameters_(vamana::construct_default_search_parameters(data_))
213-
, construction_window_size_(parameters.window_size)
214-
, max_candidates_(parameters.max_candidate_pool_size)
215-
, prune_to_(parameters.prune_to)
216-
, alpha_(parameters.alpha)
217-
, use_full_search_history_{parameters.use_full_search_history}
217+
, build_parameters_(parameters)
218218
, logger_{std::move(logger)} {
219+
// Verify and set defaults directly on the input parameters
220+
verify_and_set_default_index_parameters(build_parameters_, distance_function);
221+
222+
// Set graph again as verify function might change graph_max_degree parameter
223+
graph_ = Graph{data_.size(), build_parameters_.graph_max_degree};
224+
construction_window_size_ = build_parameters_.window_size;
225+
max_candidates_ = build_parameters_.max_candidate_pool_size;
226+
prune_to_ = build_parameters_.prune_to;
227+
alpha_ = build_parameters_.alpha;
228+
use_full_search_history_ = build_parameters_.use_full_search_history;
229+
219230
// Setup the initial translation of external to internal ids.
220231
translator_.insert(external_ids, threads::UnitRange<Idx>(0, external_ids.size()));
221232

@@ -227,10 +238,12 @@ class MutableVamanaIndex {
227238
auto prefetch_parameters =
228239
GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_};
229240
auto builder = VamanaBuilder(
230-
graph_, data_, distance_, parameters, threadpool_, prefetch_parameters
241+
graph_, data_, distance_, build_parameters_, threadpool_, prefetch_parameters
231242
);
232243
builder.construct(1.0f, entry_point_[0], logging::Level::Info, logger_);
233-
builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger_);
244+
builder.construct(
245+
build_parameters_.alpha, entry_point_[0], logging::Level::Info, logger_
246+
);
234247
}
235248

236249
/// @brief Post re-load constructor.
@@ -1346,4 +1359,4 @@ auto auto_dynamic_assemble(
13461359
std::move(logger)};
13471360
}
13481361

1349-
} // namespace svs::index::vamana
1362+
} // namespace svs::index::vamana

‎include/svs/index/vamana/index.h

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,19 +404,22 @@ class VamanaIndex {
404404
if (graph_.n_nodes() != data_.size()) {
405405
throw ANNEXCEPTION("Wrong sizes!");
406406
}
407-
408407
build_parameters_ = parameters;
408+
// verify the parameters before set local var
409+
verify_and_set_default_index_parameters(build_parameters_, distance_function);
409410
auto builder = VamanaBuilder(
410411
graph_,
411412
data_,
412413
distance_,
413-
parameters,
414+
build_parameters_,
414415
threadpool_,
415416
extensions::estimate_prefetch_parameters(data_)
416417
);
417418

418419
builder.construct(1.0F, entry_point_[0], logging::Level::Info, logger);
419-
builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger);
420+
builder.construct(
421+
build_parameters_.alpha, entry_point_[0], logging::Level::Info, logger
422+
);
420423
}
421424

422425
/// @brief Getter method for logger
@@ -896,10 +899,13 @@ auto auto_build(
896899
auto entry_point = extensions::compute_entry_point(data, threadpool);
897900

898901
// Default graph.
899-
auto graph = default_graph(data.size(), parameters.graph_max_degree, graph_allocator);
902+
auto verified_parameters = parameters;
903+
verify_and_set_default_index_parameters(verified_parameters, distance);
904+
auto graph =
905+
default_graph(data.size(), verified_parameters.graph_max_degree, graph_allocator);
900906
using I = typename decltype(graph)::index_type;
901907
return VamanaIndex{
902-
parameters,
908+
verified_parameters,
903909
std::move(graph),
904910
std::move(data),
905911
lib::narrow<I>(entry_point),
@@ -959,4 +965,57 @@ auto auto_assemble(
959965
index.apply(config);
960966
return index;
961967
}
968+
969+
/// @brief Verify parameters and set defaults if needed
970+
template <typename Dist>
971+
void verify_and_set_default_index_parameters(
972+
VamanaBuildParameters& parameters, Dist distance_function
973+
) {
974+
// Set default values
975+
if (parameters.max_candidate_pool_size == svs::UNSIGNED_INTEGER_PLACEHOLDER) {
976+
parameters.max_candidate_pool_size = 2 * parameters.graph_max_degree;
977+
}
978+
979+
if (parameters.prune_to == svs::UNSIGNED_INTEGER_PLACEHOLDER) {
980+
if (parameters.graph_max_degree >= 16) {
981+
parameters.prune_to = parameters.graph_max_degree - 4;
982+
} else {
983+
parameters.prune_to = parameters.graph_max_degree;
984+
}
985+
}
986+
987+
// Check supported distance type using std::is_same type trait
988+
using dist_type = std::decay_t<decltype(distance_function)>;
989+
// Create type flags for each distance type
990+
constexpr bool is_L2 = std::is_same_v<dist_type, svs::distance::DistanceL2>;
991+
constexpr bool is_IP = std::is_same_v<dist_type, svs::distance::DistanceIP>;
992+
constexpr bool is_Cosine =
993+
std::is_same_v<dist_type, svs::distance::DistanceCosineSimilarity>;
994+
995+
// Handle alpha based on distance type
996+
if constexpr (is_L2) {
997+
if (parameters.alpha == svs::FLOAT_PLACEHOLDER) {
998+
parameters.alpha = svs::VAMANA_ALPHA_MINIMIZE_DEFAULT;
999+
} else if (parameters.alpha < 1.0f) {
1000+
// Check User set values
1001+
throw std::invalid_argument("For L2 distance, alpha must be >= 1.0");
1002+
}
1003+
} else if constexpr (is_IP || is_Cosine) {
1004+
if (parameters.alpha == svs::FLOAT_PLACEHOLDER) {
1005+
parameters.alpha = svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT;
1006+
} else if (parameters.alpha > 1.0f) {
1007+
// Check User set values
1008+
throw std::invalid_argument("For MIP/Cosine distance, alpha must be <= 1.0");
1009+
} else if (parameters.alpha <= 0.0f) {
1010+
throw std::invalid_argument("alpha must be > 0");
1011+
}
1012+
} else {
1013+
throw std::invalid_argument("Unsupported distance type");
1014+
}
1015+
1016+
// Check prune_to <= graph_max_degree
1017+
if (parameters.prune_to > parameters.graph_max_degree) {
1018+
throw std::invalid_argument("prune_to must be <= graph_max_degree");
1019+
}
1020+
}
9621021
} // namespace svs::index::vamana

‎include/svs/lib/preprocessor.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
#pragma once
1818

19+
#include <cstddef>
20+
#include <limits>
21+
1922
namespace svs::preprocessor::detail {
2023

2124
// consteval functions for working with preprocessor defines.
@@ -159,3 +162,14 @@ inline constexpr bool have_avx512_avx2 = true;
159162
#endif
160163

161164
} // namespace svs::arch
165+
166+
namespace svs {
167+
// Maximum values used as default initializers
168+
inline constexpr size_t UNSIGNED_INTEGER_PLACEHOLDER = std::numeric_limits<size_t>::max();
169+
inline constexpr float FLOAT_PLACEHOLDER = std::numeric_limits<float>::max();
170+
inline constexpr float VAMANA_GRAPH_MAX_DEGREE_DEFAULT = 32;
171+
inline constexpr float VAMANA_WINDOW_SIZE_DEFAULT = 64;
172+
inline constexpr bool VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT = true;
173+
inline constexpr float VAMANA_ALPHA_MINIMIZE_DEFAULT = 1.2;
174+
inline constexpr float VAMANA_ALPHA_MAXIMIZE_DEFAULT = 0.95;
175+
} // namespace svs

‎tests/svs/index/vamana/dynamic_index_2.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "svs/core/recall.h"
2020
#include "svs/index/flat/flat.h"
2121
#include "svs/index/vamana/dynamic_index.h"
22+
#include "svs/lib/preprocessor.h"
2223
#include "svs/lib/timing.h"
2324

2425
#include "svs/misc/dynamic_helper.h"
@@ -476,4 +477,144 @@ CATCH_TEST_CASE("Dynamic MutableVamanaIndex Default Logger Test", "[logging]") {
476477
// Verify that the default logger is used
477478
auto default_logger = svs::logging::get();
478479
CATCH_REQUIRE(index.get_logger() == default_logger);
480+
}
481+
482+
CATCH_TEST_CASE("Dynamic Vamana Index Default Parameters", "[parameter][vamana]") {
483+
using Catch::Approx;
484+
std::filesystem::path data_path = test_dataset::data_svs_file();
485+
486+
CATCH_SECTION("L2 Distance Defaults") {
487+
auto expected_result = test_dataset::vamana::expected_build_results(
488+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
489+
);
490+
auto build_params = expected_result.build_parameters_.value();
491+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
492+
493+
// Get IDs for all points in the dataset
494+
std::vector<size_t> indices(data_loader.size());
495+
std::iota(indices.begin(), indices.end(), 0);
496+
497+
// Build dynamic index with L2 distance
498+
auto index = svs::index::vamana::MutableVamanaIndex(
499+
build_params, std::move(data_loader), indices, svs::distance::DistanceL2(), 2
500+
);
501+
502+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT));
503+
}
504+
505+
CATCH_SECTION("MIP Distance Defaults") {
506+
auto expected_result = test_dataset::vamana::expected_build_results(
507+
svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32)
508+
);
509+
auto build_params = expected_result.build_parameters_.value();
510+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
511+
512+
// Get IDs for all points in the dataset
513+
std::vector<size_t> indices(data_loader.size());
514+
std::iota(indices.begin(), indices.end(), 0);
515+
516+
// Build dynamic index with MIP distance
517+
auto index = svs::index::vamana::MutableVamanaIndex(
518+
build_params, std::move(data_loader), indices, svs::distance::DistanceIP(), 2
519+
);
520+
521+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT));
522+
}
523+
524+
CATCH_SECTION("Invalid Alpha for L2") {
525+
auto expected_result = test_dataset::vamana::expected_build_results(
526+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
527+
);
528+
auto build_params = expected_result.build_parameters_.value();
529+
build_params.alpha = 0.8f;
530+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
531+
532+
// Get IDs for all points in the dataset
533+
std::vector<size_t> indices(data_loader.size());
534+
std::iota(indices.begin(), indices.end(), 0);
535+
536+
CATCH_REQUIRE_THROWS_WITH(
537+
svs::index::vamana::MutableVamanaIndex(
538+
build_params,
539+
std::move(data_loader),
540+
indices,
541+
svs::distance::DistanceL2(),
542+
2
543+
),
544+
"For L2 distance, alpha must be >= 1.0"
545+
);
546+
}
547+
548+
CATCH_SECTION("Invalid Alpha for MIP") {
549+
auto expected_result = test_dataset::vamana::expected_build_results(
550+
svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32)
551+
);
552+
auto build_params = expected_result.build_parameters_.value();
553+
build_params.alpha = 1.2f;
554+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
555+
556+
// Get IDs for all points in the dataset
557+
std::vector<size_t> indices(data_loader.size());
558+
std::iota(indices.begin(), indices.end(), 0);
559+
560+
CATCH_REQUIRE_THROWS_WITH(
561+
svs::index::vamana::MutableVamanaIndex(
562+
build_params,
563+
std::move(data_loader),
564+
indices,
565+
svs::distance::DistanceIP(),
566+
2
567+
),
568+
"For MIP/Cosine distance, alpha must be <= 1.0"
569+
);
570+
}
571+
572+
CATCH_SECTION("Invalid prune_to > graph_max_degree") {
573+
auto expected_result = test_dataset::vamana::expected_build_results(
574+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
575+
);
576+
auto build_params = expected_result.build_parameters_.value();
577+
build_params.prune_to = build_params.graph_max_degree + 10;
578+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
579+
580+
// Get IDs for all points in the dataset
581+
std::vector<size_t> indices(data_loader.size());
582+
std::iota(indices.begin(), indices.end(), 0);
583+
584+
CATCH_REQUIRE_THROWS_WITH(
585+
svs::index::vamana::MutableVamanaIndex(
586+
build_params,
587+
std::move(data_loader),
588+
indices,
589+
svs::distance::DistanceL2(),
590+
2
591+
),
592+
"prune_to must be <= graph_max_degree"
593+
);
594+
}
595+
596+
CATCH_SECTION("L2 Distance Empty Params") {
597+
svs::index::vamana::VamanaBuildParameters params;
598+
std::vector<float> data(32);
599+
for (size_t i = 0; i < data.size(); i++) {
600+
data[i] = static_cast<float>(i + 1);
601+
}
602+
auto data_view = svs::data::SimpleDataView<float>(data.data(), 8, 4);
603+
std::vector<size_t> indices = {0, 1, 2, 3, 4, 5, 6, 7};
604+
auto index = svs::index::vamana::MutableVamanaIndex(
605+
params, std::move(data_view), indices, svs::distance::DistanceL2(), 1
606+
);
607+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT));
608+
CATCH_REQUIRE(index.get_graph_max_degree() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT);
609+
CATCH_REQUIRE(index.get_prune_to() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT - 4);
610+
CATCH_REQUIRE(
611+
index.get_construction_window_size() == svs::VAMANA_WINDOW_SIZE_DEFAULT
612+
);
613+
CATCH_REQUIRE(
614+
index.get_max_candidates() == 2 * svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT
615+
);
616+
CATCH_REQUIRE(
617+
index.get_full_search_history() == svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT
618+
);
619+
}
479620
}

‎tests/svs/index/vamana/index.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,26 @@
1616

1717
// Header under test
1818
#include "svs/index/vamana/index.h"
19+
20+
// Logging
1921
#include "spdlog/sinks/callback_sink.h"
2022
#include "svs/core/logging.h"
2123

24+
// svs
25+
#include "svs/index/vamana/build_params.h"
26+
#include "svs/lib/preprocessor.h"
27+
2228
// catch2
2329
#include "catch2/catch_test_macros.hpp"
30+
#include <catch2/catch_approx.hpp>
2431

32+
// tests
33+
#include "tests/utils/test_dataset.h"
34+
#include "tests/utils/utils.h"
35+
#include "tests/utils/vamana_reference.h"
36+
37+
// svsbenchmark
38+
#include "svs-benchmark/benchmark.h"
2539
// stl
2640
#include <string_view>
2741

@@ -150,4 +164,86 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") {
150164
// Verify the internal log messages
151165
CATCH_REQUIRE(captured_logs[0].find("Number of syncs:") != std::string::npos);
152166
CATCH_REQUIRE(captured_logs[1].find("Batch Size:") != std::string::npos);
167+
}
168+
169+
CATCH_TEST_CASE("Vamana Index Default Parameters", "[parameter][vamana]") {
170+
using Catch::Approx;
171+
std::filesystem::path data_path = test_dataset::data_svs_file();
172+
173+
CATCH_SECTION("L2 Distance Defaults") {
174+
auto expected_result = test_dataset::vamana::expected_build_results(
175+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
176+
);
177+
auto build_params = expected_result.build_parameters_.value();
178+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
179+
svs::Vamana index = svs::Vamana::build<float>(build_params, data_loader, svs::L2);
180+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT));
181+
}
182+
183+
CATCH_SECTION("MIP Distance Defaults") {
184+
auto expected_result = test_dataset::vamana::expected_build_results(
185+
svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32)
186+
);
187+
auto build_params = expected_result.build_parameters_.value();
188+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
189+
svs::Vamana index = svs::Vamana::build<float>(build_params, data_loader, svs::MIP);
190+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT));
191+
}
192+
193+
CATCH_SECTION("Invalid Alpha for L2") {
194+
auto expected_result = test_dataset::vamana::expected_build_results(
195+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
196+
);
197+
auto build_params = expected_result.build_parameters_.value();
198+
build_params.alpha = 0.8f;
199+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
200+
CATCH_REQUIRE_THROWS_WITH(
201+
svs::Vamana::build<float>(build_params, data_loader, svs::L2),
202+
"For L2 distance, alpha must be >= 1.0"
203+
);
204+
}
205+
206+
CATCH_SECTION("Invalid Alpha for MIP") {
207+
auto expected_result = test_dataset::vamana::expected_build_results(
208+
svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32)
209+
);
210+
auto build_params = expected_result.build_parameters_.value();
211+
build_params.alpha = 1.2f;
212+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
213+
CATCH_REQUIRE_THROWS_WITH(
214+
svs::Vamana::build<float>(build_params, data_loader, svs::MIP),
215+
"For MIP/Cosine distance, alpha must be <= 1.0"
216+
);
217+
}
218+
219+
CATCH_SECTION("Invalid prune_to > graph_max_degree") {
220+
auto expected_result = test_dataset::vamana::expected_build_results(
221+
svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32)
222+
);
223+
auto build_params = expected_result.build_parameters_.value();
224+
build_params.prune_to = build_params.graph_max_degree + 10;
225+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
226+
CATCH_REQUIRE_THROWS_WITH(
227+
svs::Vamana::build<float>(build_params, data_loader, svs::L2),
228+
"prune_to must be <= graph_max_degree"
229+
);
230+
}
231+
232+
CATCH_SECTION("L2 Distance Empty Params") {
233+
svs::index::vamana::VamanaBuildParameters empty_params;
234+
auto data_loader = svs::data::SimpleData<float>::load(data_path);
235+
svs::Vamana index = svs::Vamana::build<float>(empty_params, data_loader, svs::L2);
236+
CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT));
237+
CATCH_REQUIRE(index.get_graph_max_degree() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT);
238+
CATCH_REQUIRE(index.get_prune_to() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT - 4);
239+
CATCH_REQUIRE(
240+
index.get_construction_window_size() == svs::VAMANA_WINDOW_SIZE_DEFAULT
241+
);
242+
CATCH_REQUIRE(
243+
index.get_max_candidates() == 2 * svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT
244+
);
245+
CATCH_REQUIRE(
246+
index.get_full_search_history() == svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT
247+
);
248+
}
153249
}

0 commit comments

Comments
 (0)
Please sign in to comment.