Skip to content

Commit 5dc36ca

Browse files
authored
Fix graph building with SQdata (#152)
During graph construction, both the query and dataset vectors are compressed. However, this compression was not handled correctly in the computation pipeline, leading to significantly reduced recall in some cases—especially when using normalized data with MIP (Maximum Inner Product) distance.
1 parent 9d1314c commit 5dc36ca

File tree

14 files changed

+264
-65
lines changed

14 files changed

+264
-65
lines changed

include/svs/core/distance/distance_core.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ namespace svs::distance {
2828

2929
enum class AVX_AVAILABILITY { NONE, AVX2, AVX512 };
3030

31-
constexpr std::array<size_t, 9> supported_dim_list{64, 96, 100, 128, 160, 200, 512, 768, svs::Dynamic};
31+
constexpr std::array<size_t, 9> supported_dim_list{
32+
64, 96, 100, 128, 160, 200, 512, 768, svs::Dynamic};
3233

3334
template <size_t N> constexpr bool is_dim_supported() {
3435
for (auto i : supported_dim_list) {

include/svs/extensions/flat/scalar.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "svs/index/flat/flat.h"
18+
#include "svs/quantization/scalar/scalar.h"
19+
20+
namespace svs::quantization::scalar {
21+
22+
template <IsSQData Data, typename Distance>
23+
auto svs_invoke(
24+
svs::tag_t<svs::index::flat::extensions::distance>,
25+
const Data& data,
26+
const Distance& SVS_UNUSED(distance)
27+
) {
28+
return compressed_distance_t<Distance, typename Data::element_type>(
29+
data.get_scale(), data.get_bias(), data.dimensions()
30+
);
31+
}
32+
33+
} // namespace svs::quantization::scalar

include/svs/extensions/vamana/scalar.h

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
namespace svs::quantization::scalar {
2121

2222
template <IsSQData Data>
23-
SVS_FORCE_INLINE data::GetDatumAccessor svs_invoke(
23+
SVS_FORCE_INLINE scalar::DecompressionAccessor svs_invoke(
2424
svs::tag_t<svs::index::vamana::extensions::reconstruct_accessor> SVS_UNUSED(cpo),
25-
const Data& SVS_UNUSED(data)
25+
const Data& data
2626
) {
27-
return data::GetDatumAccessor();
27+
return scalar::DecompressionAccessor{data};
2828
}
2929

3030
template <IsSQData Data, typename Distance>
@@ -38,4 +38,58 @@ auto svs_invoke(
3838
);
3939
}
4040

41+
/////
42+
///// Vamana Build
43+
/////
44+
45+
template <IsSQData Data, typename Distance> struct VamanaBuildAdaptor {
46+
public:
47+
using distance_type =
48+
DecompressionAdaptor<compressed_distance_t<Distance, typename Data::element_type>>;
49+
using search_distance_type = distance_type;
50+
using general_distance_type = distance_type;
51+
52+
auto access_query_for_graph_search(const Data& data, size_t i) const {
53+
return data.get_datum(i);
54+
}
55+
56+
template <typename Query>
57+
SVS_FORCE_INLINE const Query& modify_post_search_query(
58+
const Data& SVS_UNUSED(data), size_t SVS_UNUSED(i), const Query& query
59+
) const {
60+
return query;
61+
}
62+
63+
static constexpr bool refix_argument_after_search = false;
64+
65+
data::GetDatumAccessor graph_search_accessor() const {
66+
return data::GetDatumAccessor{};
67+
}
68+
search_distance_type& graph_search_distance() { return distance_; }
69+
general_distance_type& general_distance() { return distance_; }
70+
data::GetDatumAccessor general_accessor() const { return data::GetDatumAccessor{}; }
71+
72+
template <typename Query, NeighborLike N>
73+
SVS_FORCE_INLINE Neighbor<typename N::index_type> post_search_modify(
74+
const Data& SVS_UNUSED(data),
75+
general_distance_type& SVS_UNUSED(distance),
76+
const Query& SVS_UNUSED(query),
77+
const N& n
78+
) const {
79+
return n;
80+
}
81+
82+
public:
83+
distance_type distance_{};
84+
};
85+
86+
template <IsSQData Data, typename Distance>
87+
VamanaBuildAdaptor<Data, Distance> svs_invoke(
88+
svs::tag_t<svs::index::vamana::extensions::build_adaptor>,
89+
const Data& data,
90+
const Distance& distance
91+
) {
92+
return VamanaBuildAdaptor<Data, Distance>{adapt_for_self(data, distance)};
93+
}
94+
4195
} // namespace svs::quantization::scalar

include/svs/index/vamana/dynamic_index.h

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,7 @@ class MutableVamanaIndex {
304304
sp.search_buffer_visited_set_
305305
),
306306
extensions::single_search_setup(data_, distance_),
307-
{sp.prefetch_lookahead_, sp.prefetch_step_}
308-
};
307+
{sp.prefetch_lookahead_, sp.prefetch_step_}};
309308
}
310309

311310
scratchspace_type scratchspace() const { return scratchspace(get_search_parameters()); }
@@ -512,8 +511,7 @@ class MutableVamanaIndex {
512511
search_buffer_type{sp.buffer_config_, distance::comparator(distance_)};
513512

514513
auto prefetch_parameters = GreedySearchPrefetchParameters{
515-
sp.prefetch_lookahead_, sp.prefetch_step_
516-
};
514+
sp.prefetch_lookahead_, sp.prefetch_step_};
517515

518516
// Legalize search buffer for this search.
519517
if (buffer.target() < num_neighbors) {
@@ -696,8 +694,7 @@ class MutableVamanaIndex {
696694
construction_window_size_,
697695
max_candidates_,
698696
prune_to_,
699-
use_full_search_history_
700-
};
697+
use_full_search_history_};
701698

702699
auto sp = get_search_parameters();
703700
auto prefetch_parameters =
@@ -710,8 +707,7 @@ class MutableVamanaIndex {
710707
threadpool_,
711708
prefetch_parameters,
712709
logger_,
713-
logging::Level::Trace
714-
};
710+
logging::Level::Trace};
715711
builder.construct(alpha_, entry_point(), slots, logging::Level::Trace, logger_);
716712
// Mark all added entries as valid.
717713
for (const auto& i : slots) {
@@ -1013,8 +1009,7 @@ class MutableVamanaIndex {
10131009
get_max_candidates(),
10141010
prune_to_,
10151011
get_full_search_history()},
1016-
get_search_parameters()
1017-
};
1012+
get_search_parameters()};
10181013

10191014
return lib::SaveTable(
10201015
"vamana_dynamic_auxiliary_parameters",
@@ -1328,8 +1323,7 @@ struct VamanaStateLoader {
13281323
if (debug_load_from_static) {
13291324
return VamanaStateLoader{
13301325
lib::load<VamanaIndexParameters>(table),
1331-
IDTranslator::Identity(assume_datasize)
1332-
};
1326+
IDTranslator::Identity(assume_datasize)};
13331327
}
13341328

13351329
return VamanaStateLoader{
@@ -1430,8 +1424,7 @@ auto auto_dynamic_assemble(
14301424
std::move(distance),
14311425
std::move(translator),
14321426
std::move(threadpool),
1433-
std::move(logger)
1434-
};
1427+
std::move(logger)};
14351428
}
14361429

14371430
} // namespace svs::index::vamana

include/svs/index/vamana/index.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -143,17 +143,13 @@ struct VamanaIndexParameters {
143143
lib::load_at<size_t>(table, "construction_window_size"),
144144
lib::load_at<size_t>(table, "max_candidates"),
145145
prune_to,
146-
use_full_search_history
147-
},
146+
use_full_search_history},
148147
VamanaSearchParameters{
149148
SearchBufferConfig{
150-
lib::load_at<size_t>(table, "default_search_window_size")
151-
},
149+
lib::load_at<size_t>(table, "default_search_window_size")},
152150
lib::load_at<bool>(table, "visited_set"),
153151
4,
154-
1
155-
}
156-
};
152+
1}};
157153
}
158154

159155
static VamanaIndexParameters load(const lib::ContextFreeLoadTable& table) {
@@ -410,8 +406,7 @@ class VamanaIndex {
410406
entry_point,
411407
std::move(distance_function),
412408
std::move(threadpool),
413-
logger
414-
} {
409+
logger} {
415410
if (graph_.n_nodes() != data_.size()) {
416411
throw ANNEXCEPTION("Wrong sizes!");
417412
}
@@ -455,8 +450,7 @@ class VamanaIndex {
455450
sp.search_buffer_visited_set_
456451
),
457452
extensions::single_search_setup(data_, distance_),
458-
{sp.prefetch_lookahead_, sp.prefetch_step_}
459-
};
453+
{sp.prefetch_lookahead_, sp.prefetch_step_}};
460454
}
461455

462456
/// @brief Return scratch-space resources for external threading with default parameters
@@ -574,12 +568,11 @@ class VamanaIndex {
574568
auto search_buffer = search_buffer_type{
575569
SearchBufferConfig(search_parameters.buffer_config_),
576570
distance::comparator(distance_),
577-
search_parameters.search_buffer_visited_set_
578-
};
571+
search_parameters.search_buffer_visited_set_};
579572

580573
auto prefetch_parameters = GreedySearchPrefetchParameters{
581-
search_parameters.prefetch_lookahead_, search_parameters.prefetch_step_
582-
};
574+
search_parameters.prefetch_lookahead_,
575+
search_parameters.prefetch_step_};
583576

584577
// Increase the search window size if the defaults are not suitable for the
585578
// requested number of neighbors.
@@ -809,8 +802,7 @@ class VamanaIndex {
809802
) const {
810803
// Construct and save runtime parameters.
811804
auto parameters = VamanaIndexParameters{
812-
entry_point_.front(), build_parameters_, get_search_parameters()
813-
};
805+
entry_point_.front(), build_parameters_, get_search_parameters()};
814806

815807
// Config
816808
lib::save_to_disk(parameters, config_directory);
@@ -957,8 +949,7 @@ auto auto_build(
957949
lib::narrow<I>(entry_point),
958950
std::move(distance),
959951
std::move(threadpool),
960-
logger
961-
};
952+
logger};
962953
}
963954

964955
///
@@ -1007,8 +998,7 @@ auto auto_assemble(
1007998
I{},
1008999
std::move(distance),
10091000
std::move(threadpool),
1010-
std::move(logger)
1011-
};
1001+
std::move(logger)};
10121002
auto config = lib::load_from_disk<VamanaIndexParameters>(config_path);
10131003
index.apply(config);
10141004
return index;

include/svs/index/vamana/vamana_build.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,8 @@ template <typename Idx> class BackedgeBuffer {
124124
, bucket_locks_{parameters.num_buckets_} {}
125125

126126
BackedgeBuffer(size_t num_elements, size_t bucket_size)
127-
: BackedgeBuffer(
128-
BackedgeBufferParameters{
129-
bucket_size, lib::div_round_up(num_elements, bucket_size)
130-
}
131-
) {}
127+
: BackedgeBuffer(BackedgeBufferParameters{
128+
bucket_size, lib::div_round_up(num_elements, bucket_size)}) {}
132129

133130
// Add a point.
134131
void add_edge(Idx src, Idx dst) {
@@ -339,7 +336,9 @@ class VamanaBuilder {
339336
update_type updates{threadpool_.size()};
340337
auto main = timer.push_back("main");
341338
threads::parallel_for(
342-
threadpool_, range, [&](const auto& local_indices, uint64_t tid) {
339+
threadpool_,
340+
range,
341+
[&](const auto& local_indices, uint64_t tid) {
343342
// Thread local variables
344343
auto& thread_local_updates = updates.at(tid);
345344

@@ -490,7 +489,9 @@ class VamanaBuilder {
490489
auto range = threads::StaticPartition{indices};
491490
backedge_buffer_.reset();
492491
threads::parallel_for(
493-
threadpool_, range, [&](const auto& is, uint64_t SVS_UNUSED(tid)) {
492+
threadpool_,
493+
range,
494+
[&](const auto& is, uint64_t SVS_UNUSED(tid)) {
494495
for (auto node_id : is) {
495496
for (auto other_id : graph_.get_node(node_id)) {
496497
std::lock_guard lock{vertex_locks_[other_id]};
@@ -539,8 +540,7 @@ class VamanaBuilder {
539540
i,
540541
distance::compute(
541542
general_distance, src_data, general_accessor(data_, i)
542-
)
543-
};
543+
)};
544544
};
545545

546546
candidates.clear();

include/svs/lib/saveload/save.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ concept HasZeroArgSaveTo = requires(const T& x) {
198198
///
199199
/// The expected return type is either ``svs::lib::SaveTable`` or ``svs::lib::SaveNode``.
200200
///
201-
/// This class is automatically defined for classes ``T`` with appropriate ``save()`` methods.
201+
/// This class is automatically defined for classes ``T`` with appropriate ``save()``
202+
/// methods.
202203
///
203204
template <typename T> struct Saver {
204205
static SaveTable save(const T& x)

include/svs/lib/threads/types.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,11 +435,11 @@ DynamicPartition(const R&, size_t) -> DynamicPartition<typename R::const_iterato
435435
// Comment out the code until the issue is resolved in an upcoming fmt update
436436
// Related test: tests/svs/lib/threads/types.cpp::printing
437437
// Formatting
438-
//template <typename T>
439-
//struct fmt::formatter<svs::threads::UnitRange<T>> : svs::format_empty {
440-
//auto format(const svs::threads::UnitRange<T>& x, auto& ctx) const {
441-
//return fmt::format_to(
442-
//ctx.out(), "UnitRange<{}>({}, {})", svs::datatype_v<T>, x.start(), x.stop()
443-
//);
444-
//}
438+
// template <typename T>
439+
// struct fmt::formatter<svs::threads::UnitRange<T>> : svs::format_empty {
440+
// auto format(const svs::threads::UnitRange<T>& x, auto& ctx) const {
441+
// return fmt::format_to(
442+
// ctx.out(), "UnitRange<{}>({}, {})", svs::datatype_v<T>, x.start(), x.stop()
443+
//);
444+
//}
445445
//};

include/svs/multi-arch/x86/preprocessor.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,3 @@
8787

8888
#define DISTANCE_CS_EXTERN_TEMPLATE(N, AVX) \
8989
DISTANCE_CS_TEMPLATE_HELPER(extern template, N, AVX);
90-

0 commit comments

Comments
 (0)