Skip to content

Commit 24bf60c

Browse files
fix: add logger everywhere we need to print (#122)
Co-authored-by: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 235fea2 commit 24bf60c

File tree

12 files changed

+156
-42
lines changed

12 files changed

+156
-42
lines changed

include/svs/core/kmeans.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ data::SimpleData<float> train_impl(
158158
const KMeansParameters& parameters,
159159
const Data& data,
160160
Pool& threadpool,
161-
Callback&& post_epoch_callback = lib::donothing()
161+
Callback&& post_epoch_callback = lib::donothing(),
162+
svs::logging::logger_ptr logger = svs::logging::get()
162163
) {
163164
size_t ndims = data.dimensions();
164165
auto num_clusters = parameters.clusters;
@@ -211,7 +212,7 @@ data::SimpleData<float> train_impl(
211212
old_counts[i] = 0;
212213
}
213214
}
214-
svs::logging::debug("{}", timer);
215+
svs::logging::debug(logger, "{}", timer);
215216
return centroids;
216217
}
217218

@@ -223,11 +224,12 @@ data::SimpleData<float> train(
223224
const KMeansParameters& parameters,
224225
const Data& data,
225226
ThreadPoolProto threadpool_proto,
226-
Callback&& post_epoch_callback = lib::donothing()
227+
Callback&& post_epoch_callback = lib::donothing(),
228+
svs::logging::logger_ptr logger = svs::logging::get()
227229
) {
228230
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
229231
return train_impl(
230-
parameters, data, threadpool, std::forward<Callback>(post_epoch_callback)
232+
parameters, data, threadpool, std::forward<Callback>(post_epoch_callback), logger
231233
);
232234
}
233235

@@ -239,10 +241,11 @@ data::SimpleData<float> train(
239241
const KMeansParameters& parameters,
240242
const Data& data,
241243
Pool& threadpool,
242-
Callback&& post_epoch_callback = lib::donothing()
244+
Callback&& post_epoch_callback = lib::donothing(),
245+
svs::logging::logger_ptr logger = svs::logging::get()
243246
) {
244247
return train_impl(
245-
parameters, data, threadpool, std::forward<Callback>(post_epoch_callback)
248+
parameters, data, threadpool, std::forward<Callback>(post_epoch_callback), logger
246249
);
247250
}
248251
} // namespace svs

include/svs/index/inverted/clustering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,8 @@ auto build_primary_index(
778778
std::span<const I> ids,
779779
const vamana::VamanaBuildParameters& vamana_parameters,
780780
const Distance& distance,
781-
Pool threadpool
781+
Pool threadpool,
782+
svs::logging::logger_ptr logger = svs::logging::get()
782783
) {
783784
return vamana::auto_build(
784785
vamana_parameters,
@@ -792,7 +793,8 @@ auto build_primary_index(
792793
}),
793794
distance,
794795
std::move(threadpool),
795-
HugepageAllocator<I>()
796+
HugepageAllocator<I>(),
797+
logger
796798
);
797799
}
798800

include/svs/index/inverted/memory_based.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ auto auto_build(
577577
lib::as_const_span(centroids),
578578
parameters.primary_parameters_,
579579
distance,
580-
std::move(threadpool)
580+
std::move(threadpool),
581+
logger
581582
);
582583

583584
// Cluster the dataset with the help of the primary index.

include/svs/index/vamana/calibrate.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ VamanaSearchParameters optimize_split_buffer(
177177
VamanaSearchParameters current,
178178
const F& compute_recall,
179179
const DoSearch& do_search,
180-
svs::logging::logger_ptr logger = svs::logging::get()
180+
svs::logging::logger_ptr logger
181181
) {
182182
svs::logging::trace(logger, "Entering split buffer optimization routine");
183183
assert(
@@ -334,7 +334,8 @@ std::pair<VamanaSearchParameters, bool> optimize_search_buffer(
334334
target_recall,
335335
current,
336336
compute_recall,
337-
do_search
337+
do_search,
338+
logger
338339
);
339340
}
340341
return std::make_pair(current, converged);
@@ -346,7 +347,7 @@ VamanaSearchParameters tune_prefetch(
346347
Index& index,
347348
VamanaSearchParameters search_parameters,
348349
const DoSearch& do_search,
349-
svs::logging::logger_ptr logger = svs::logging::get()
350+
svs::logging::logger_ptr logger
350351
) {
351352
svs::logging::trace(logger, "Tuning prefetch parameters");
352353
const auto& prefetch_steps = calibration_parameters.prefetch_steps_;
@@ -480,7 +481,8 @@ VamanaSearchParameters calibrate(
480481
size_t num_neighbors,
481482
double target_recall,
482483
F&& compute_recall,
483-
DoSearch&& do_search
484+
DoSearch&& do_search,
485+
svs::logging::logger_ptr logger = svs::logging::get()
484486
) {
485487
// Get the existing parameters and the default values, then decide which to use as the seed.
486488
auto default_parameters = VamanaSearchParameters();
@@ -492,30 +494,32 @@ VamanaSearchParameters calibrate(
492494

493495
// Step 1: Optimize aspects of the search buffer if desired.
494496
if (calibration_parameters.should_optimize_search_buffer()) {
495-
svs::logging::trace("Optimizing search buffer.");
497+
svs::logging::trace(logger, "Optimizing search buffer.");
496498
auto [best, converged] = calibration::optimize_search_buffer<Index>(
497499
calibration_parameters,
498500
current,
499501
num_neighbors,
500502
target_recall,
501503
compute_recall,
502-
do_search
504+
do_search,
505+
logger
503506
);
504507
current = best;
505508

506509
if (!converged) {
507510
svs::logging::warn(
508-
"Target recall could not be achieved. Exiting optimization early."
511+
logger, "Target recall could not be achieved. Exiting optimization early."
509512
);
510513
return current;
511514
}
512515
}
513516

514517
// Step 2: Optimize prefetch parameters.
515518
if (calibration_parameters.train_prefetchers_) {
516-
svs::logging::trace("Training Prefetchers.");
517-
current =
518-
calibration::tune_prefetch(calibration_parameters, index, current, do_search);
519+
svs::logging::trace(logger, "Training Prefetchers.");
520+
current = calibration::tune_prefetch(
521+
calibration_parameters, index, current, do_search, logger
522+
);
519523
}
520524

521525
// Finish up.

include/svs/index/vamana/dynamic_index.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ class MutableVamanaIndex {
688688
GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_};
689689
VamanaBuilder builder{
690690
graph_, data_, distance_, parameters, threadpool_, prefetch_parameters};
691-
builder.construct(alpha_, entry_point(), slots, logging::Level::Trace);
691+
builder.construct(alpha_, entry_point(), slots, logging::Level::Trace, logger_);
692692
// Mark all added entries as valid.
693693
for (const auto& i : slots) {
694694
status_[i] = SlotMetadata::Valid;
@@ -935,11 +935,10 @@ class MutableVamanaIndex {
935935
assert(entry_point_.size() == 1);
936936
auto entry_point = entry_point_[0];
937937
if (status_.at(entry_point) == SlotMetadata::Deleted) {
938-
auto logger = svs::logging::get();
939-
svs::logging::debug(logger, "Replacing entry point.");
938+
svs::logging::debug(logger_, "Replacing entry point.");
940939
auto new_entry_point =
941940
extensions::compute_entry_point(data_, threadpool_, valid);
942-
svs::logging::debug(logger, "New point: {}", new_entry_point);
941+
svs::logging::debug(logger_, "New point: {}", new_entry_point);
943942
assert(!is_deleted(new_entry_point));
944943
entry_point_[0] = new_entry_point;
945944
}
@@ -1051,7 +1050,8 @@ class MutableVamanaIndex {
10511050
num_neighbors,
10521051
target_recall,
10531052
compute_recall,
1054-
do_search
1053+
do_search,
1054+
logger_
10551055
);
10561056

10571057
set_search_parameters(p);
@@ -1267,6 +1267,15 @@ MutableVamanaIndex(
12671267
svs::logging::logger_ptr
12681268
) -> MutableVamanaIndex<graphs::SimpleBlockedGraph<uint32_t>, Data, Dist>;
12691269

1270+
template <typename Data, typename Dist, typename ExternalIds>
1271+
MutableVamanaIndex(
1272+
const VamanaBuildParameters&,
1273+
Data,
1274+
const ExternalIds&,
1275+
Dist,
1276+
size_t,
1277+
svs::logging::logger_ptr
1278+
) -> MutableVamanaIndex<graphs::SimpleBlockedGraph<uint32_t>, Data, Dist>;
12701279
namespace detail {
12711280

12721281
struct VamanaStateLoader {

include/svs/index/vamana/index.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,12 @@ struct VamanaIndexParameters {
9999
return schema == serialization_schema && version <= save_version;
100100
}
101101

102-
static VamanaIndexParameters load_legacy(const lib::ContextFreeLoadTable& table) {
102+
static VamanaIndexParameters load_legacy(
103+
const lib::ContextFreeLoadTable& table,
104+
svs::logging::logger_ptr logger = svs::logging::get()
105+
) {
103106
svs::logging::warn(
107+
logger,
104108
"Loading a legacy IndexParameters class. Please consider resaving this "
105109
"index to update the save version and prevent future breaking!\n"
106110
);

include/svs/misc/dynamic_helper.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ template <typename Idx, typename Eltype, size_t N, typename Dist> class Referenc
155155
size_t bucket_size,
156156
size_t num_neighbors,
157157
const Queries& queries,
158-
uint64_t rng_seed
158+
uint64_t rng_seed,
159+
svs::logging::logger_ptr logger = svs::logging::get()
159160
)
160161
: data_{std::move(data)}
161162
, num_queries_{queries.size()}
@@ -214,7 +215,7 @@ template <typename Idx, typename Eltype, size_t N, typename Dist> class Referenc
214215
reserve_buckets_.emplace_back(ids, std::move(bucket_groundtruth));
215216
start = stop;
216217
}
217-
svs::logging::debug("{}", timer);
218+
svs::logging::debug(logger, "{}", timer);
218219
}
219220

220221
/// @brief Return the total number of elements in the dataset.

tests/svs/index/flat/flat.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") {
2727
// Vector to store captured log messages
2828
std::vector<std::string> captured_logs;
29+
std::vector<std::string> global_captured_logs;
2930

3031
// Create a callback sink to capture log messages
3132
auto callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
@@ -39,6 +40,15 @@ CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") {
3940
auto test_logger = std::make_shared<spdlog::logger>("test_logger", callback_sink);
4041
test_logger->set_level(spdlog::level::trace);
4142

43+
auto global_callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
44+
[&global_captured_logs](const spdlog::details::log_msg& msg) {
45+
global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size());
46+
}
47+
);
48+
global_callback_sink->set_level(spdlog::level::trace);
49+
auto original_logger = svs::logging::get();
50+
original_logger->sinks().push_back(global_callback_sink);
51+
4252
std::vector<float> data{1.0f, 2.0f};
4353
auto dataView = svs::data::SimpleDataView<float>(data.data(), 2, 1);
4454
svs::distance::DistanceL2 dist;
@@ -52,6 +62,7 @@ CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") {
5262
test_logger->info("Test FlatIndex Logging");
5363

5464
// Verify the log output
65+
CATCH_REQUIRE(global_captured_logs.empty());
5566
CATCH_REQUIRE(captured_logs.size() == 1);
5667
CATCH_REQUIRE(captured_logs[0] == "Test FlatIndex Logging");
5768
}

tests/svs/index/inverted/clustering.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ CATCH_TEST_CASE("Random Clustering - End to End", "[inverted][random_clustering]
396396
CATCH_TEST_CASE("Clustering with Logger", "[logging]") {
397397
// Setup logger
398398
std::vector<std::string> captured_logs;
399+
std::vector<std::string> global_captured_logs;
400+
399401
auto callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
400402
[&captured_logs](const spdlog::details::log_msg& msg) {
401403
captured_logs.emplace_back(msg.payload.data(), msg.payload.size());
@@ -405,6 +407,15 @@ CATCH_TEST_CASE("Clustering with Logger", "[logging]") {
405407
auto test_logger = std::make_shared<spdlog::logger>("test_logger", callback_sink);
406408
test_logger->set_level(spdlog::level::trace);
407409

410+
auto global_callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
411+
[&global_captured_logs](const spdlog::details::log_msg& msg) {
412+
global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size());
413+
}
414+
);
415+
global_callback_sink->set_level(spdlog::level::trace);
416+
auto original_logger = svs::logging::get();
417+
original_logger->sinks().push_back(global_callback_sink);
418+
408419
// Setup cluster
409420
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
410421
auto vamana_parameters =
@@ -427,12 +438,14 @@ CATCH_TEST_CASE("Clustering with Logger", "[logging]") {
427438
svs::lib::as_const_span(centroids),
428439
vamana_parameters,
429440
svs::DistanceL2(),
430-
std::move(threadpool)
441+
std::move(threadpool),
442+
test_logger
431443
);
432444
auto clustering = svs::index::inverted::cluster_with(
433445
data, svs::lib::as_const_span(centroids), clustering_parameters, index, test_logger
434446
);
435447

436448
// Verify the internal log messages
437-
CATCH_REQUIRE(captured_logs[0].find("Processing batch") != std::string::npos);
449+
CATCH_REQUIRE(global_captured_logs.empty());
450+
CATCH_REQUIRE(captured_logs[0].find("Number of syncs") != std::string::npos);
438451
}

tests/svs/index/inverted/memory_based.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {
2727
// Vector to store captured log messages
2828
std::vector<std::string> captured_logs;
29+
std::vector<std::string> global_captured_logs;
2930

3031
// Create a callback sink to capture log messages
3132
auto callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
@@ -39,6 +40,15 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {
3940
auto test_logger = std::make_shared<spdlog::logger>("test_logger", callback_sink);
4041
test_logger->set_level(spdlog::level::trace);
4142

43+
auto global_callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
44+
[&global_captured_logs](const spdlog::details::log_msg& msg) {
45+
global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size());
46+
}
47+
);
48+
global_callback_sink->set_level(spdlog::level::trace);
49+
auto original_logger = svs::logging::get();
50+
original_logger->sinks().push_back(global_callback_sink);
51+
4252
// Setup index
4353
auto distance = svs::DistanceL2();
4454
constexpr auto distance_type = svs::distance_type_v<svs::DistanceL2>;
@@ -59,5 +69,6 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {
5969
);
6070

6171
// Verify the internal log messages
62-
CATCH_REQUIRE(captured_logs[0].find("Processing batch") != std::string::npos);
72+
CATCH_REQUIRE(global_captured_logs.empty());
73+
CATCH_REQUIRE(captured_logs[0].find("Number of syncs") != std::string::npos);
6374
}

0 commit comments

Comments
 (0)