Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions src/indexes/text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,28 @@ absl::StatusOr<bool> Text::AddRecord(const InternedStringPtr& key,
absl::string_view data) {
valkey_search::indexes::text::Lexer lexer;

auto tokens =
std::vector<std::string> stemmed_words;
auto original_words =
lexer.Tokenize(data, text_index_schema_->GetPunctuationBitmap(),
text_index_schema_->GetStemmer(), !no_stem_,
min_stem_size_, text_index_schema_->GetStopWordsSet());
min_stem_size_, text_index_schema_->GetStopWordsSet(),
!no_stem_ ? &stemmed_words : nullptr,
text_index_schema_->GetStemmerMutex());

if (!tokens.ok()) {
if (tokens.status().code() == absl::StatusCode::kInvalidArgument) {
if (!original_words.ok()) {
if (original_words.status().code() == absl::StatusCode::kInvalidArgument) {
return false; // UTF-8 errors → hash_indexing_failures
}
return tokens.status();
return original_words.status();
}

for (uint32_t position = 0; position < tokens->size(); ++position) {
const auto& token = (*tokens)[position];
size_t stemmed_index = 0;
for (uint32_t position = 0; position < original_words->size(); ++position) {
const auto& original_word = (*original_words)[position];

// Index the original word in the prefix tree
text_index_schema_->GetTextIndex()->prefix_.Mutate(
token,
original_word,
[&](std::optional<std::shared_ptr<text::Postings>> existing)
-> std::optional<std::shared_ptr<text::Postings>> {
std::shared_ptr<text::Postings> postings;
Expand All @@ -62,6 +68,25 @@ absl::StatusOr<bool> Text::AddRecord(const InternedStringPtr& key,
postings->InsertPosting(key, text_field_number_, position);
return postings;
});

// If stemming occurred, update the stem tree
if (!no_stem_ && stemmed_index < stemmed_words.size()) {
const auto& stemmed_word = stemmed_words[stemmed_index];
text_index_schema_->GetTextIndex()->stem_.Mutate(
stemmed_word,
[&](std::optional<std::shared_ptr<text::StemTarget>> existing)
-> std::optional<std::shared_ptr<text::StemTarget>> {
std::shared_ptr<text::StemTarget> stem_target;
if (existing.has_value()) {
stem_target = existing.value();
} else {
stem_target = std::make_shared<text::StemTarget>();
}
stem_target->insert(original_word);
return stem_target;
});
stemmed_index++;
}
}

return true;
Expand Down Expand Up @@ -129,6 +154,9 @@ std::unique_ptr<Text::EntriesFetcher> Text::Search(
CalculateSize(predicate), text_index_schema_->GetTextIndex(),
negate ? &untracked_keys_ : nullptr);
fetcher->predicate_ = &predicate;
fetcher->text_index_schema_ = text_index_schema_;
fetcher->no_stem_ = no_stem_;
fetcher->min_stem_size_ = min_stem_size_;
// TODO : We only support single field queries for now. Change below when we
// support multiple and all fields.
fetcher->field_mask_ = 1ULL << text_field_number_;
Expand All @@ -141,7 +169,8 @@ std::unique_ptr<EntriesFetcherIteratorBase> Text::EntriesFetcher::Begin() {
if (auto term = dynamic_cast<const query::TermPredicate*>(predicate_)) {
auto iter = text_index_->prefix_.GetWordIterator(term->GetTextString());
auto itr = std::make_unique<text::TermIterator>(
iter, term->GetTextString(), field_mask_, untracked_keys_);
iter, term->GetTextString(), field_mask_, untracked_keys_,
!no_stem_, text_index_schema_, min_stem_size_, text_index_);
itr->Next();
return itr;
} else if (auto prefix =
Expand Down
5 changes: 5 additions & 0 deletions src/indexes/text.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class Text : public IndexBase {
absl::string_view data_;
bool no_field_{false};
text::FieldMaskPredicate field_mask_;

// Stemming configuration
std::shared_ptr<text::TextIndexSchema> text_index_schema_;
bool no_stem_;
int32_t min_stem_size_;
};

// Calculate size based on the predicate.
Expand Down
39 changes: 31 additions & 8 deletions src/indexes/text/lexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "src/indexes/text/lexer.h"

#include <memory>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/ascii.h"
Expand All @@ -17,7 +18,9 @@ namespace valkey_search::indexes::text {
absl::StatusOr<std::vector<std::string>> Lexer::Tokenize(
absl::string_view text, const std::bitset<256>& punct_bitmap,
sb_stemmer* stemmer, bool stemming_enabled, uint32_t min_stem_size,
const absl::flat_hash_set<std::string>& stop_words_set) const {
const absl::flat_hash_set<std::string>& stop_words_set,
std::vector<std::string>* stemmed_words,
std::mutex* stemmer_mutex) const {
if (!IsValidUtf8(text)) {
return absl::InvalidArgumentError("Invalid UTF-8");
}
Expand All @@ -39,15 +42,18 @@ absl::StatusOr<std::vector<std::string>> Lexer::Tokenize(
if (pos > word_start) {
absl::string_view word_view(text.data() + word_start, pos - word_start);

std::string word = absl::AsciiStrToLower(word_view);
std::string original_word = absl::AsciiStrToLower(word_view);

if (Lexer::IsStopWord(word, stop_words_set)) {
if (Lexer::IsStopWord(original_word, stop_words_set)) {
continue; // Skip stop words
}

word = StemWord(word, stemmer, stemming_enabled, min_stem_size);
std::string stemmed_word = StemWord(original_word, stemmer, stemming_enabled, min_stem_size, stemmer_mutex);

tokens.push_back(std::move(word));
if (stemmed_words && original_word != stemmed_word) {
stemmed_words->push_back(stemmed_word);
}
tokens.push_back(std::move(original_word));
}
}

Expand All @@ -56,19 +62,36 @@ absl::StatusOr<std::vector<std::string>> Lexer::Tokenize(

std::string Lexer::StemWord(const std::string& word, sb_stemmer* stemmer,
bool stemming_enabled,
uint32_t min_stem_size) const {
uint32_t min_stem_size,
std::mutex* stemmer_mutex) const {
if (word.empty() || !stemming_enabled || word.length() < min_stem_size) {
return word;
}

CHECK(stemmer) << "Stemmer not initialized";
// If stemmer is not initialized, return the original word
if (!stemmer) {
return word;
}

// Lock the stemmer mutex to ensure thread-safe access (if provided)
std::unique_ptr<std::lock_guard<std::mutex>> lock;
if (stemmer_mutex) {
lock = std::make_unique<std::lock_guard<std::mutex>>(*stemmer_mutex);
}

const sb_symbol* stemmed = sb_stemmer_stem(
stemmer, reinterpret_cast<const sb_symbol*>(word.c_str()), word.length());

DCHECK(stemmed) << "Stemming failed for word: " + word;
// If stemming fails (e.g., for non-English words), return the original word
if (!stemmed) {
return word;
}

int stemmed_length = sb_stemmer_length(stemmer);
if (stemmed_length <= 0) {
return word;
}

return std::string(reinterpret_cast<const char*>(stemmed), stemmed_length);
}

Expand Down
12 changes: 9 additions & 3 deletions src/indexes/text/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Tokenization Pipeline:
*/

#include <bitset>
#include <mutex>
#include <string>
#include <vector>

Expand All @@ -40,7 +41,9 @@ struct Lexer {
absl::StatusOr<std::vector<std::string>> Tokenize(
absl::string_view text, const std::bitset<256>& punct_bitmap,
sb_stemmer* stemmer, bool stemming_enabled, uint32_t min_stem_size,
const absl::flat_hash_set<std::string>& stop_words_set) const;
const absl::flat_hash_set<std::string>& stop_words_set,
std::vector<std::string>* stemmed_words = nullptr,
std::mutex* stemmer_mutex = nullptr) const;

// Punctuation checking API
static bool IsPunctuation(char c, const std::bitset<256>& punct_bitmap) {
Expand All @@ -54,9 +57,12 @@ struct Lexer {
return stop_words_set.contains(lowercase_word);
}

private:
// Stemming API
std::string StemWord(const std::string& word, sb_stemmer* stemmer,
bool stemming_enabled, uint32_t min_stem_size) const;
bool stemming_enabled, uint32_t min_stem_size,
std::mutex* stemmer_mutex) const;

private:

// UTF-8 processing helpers
bool IsValidUtf8(absl::string_view text) const;
Expand Down
100 changes: 90 additions & 10 deletions src/indexes/text/term.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,127 @@
*/

#include "src/indexes/text/term.h"
#include "src/indexes/text/lexer.h"
#include "src/indexes/text/text_index.h"

namespace valkey_search::indexes::text {

TermIterator::TermIterator(const WordIterator& word,
const absl::string_view data,
const FieldMaskPredicate field_mask,
const InternedStringSet* untracked_keys)
const InternedStringSet* untracked_keys,
bool stemming_enabled,
std::shared_ptr<TextIndexSchema> text_index_schema,
uint32_t min_stem_size,
std::shared_ptr<TextIndex> text_index)
: word_(word),
data_(data),
field_mask_(field_mask),
untracked_keys_(untracked_keys) {}
untracked_keys_(untracked_keys),
stemming_enabled_(stemming_enabled),
text_index_schema_(text_index_schema),
min_stem_size_(min_stem_size),
text_index_(text_index) {}

bool TermIterator::Done() const {
if (nomatch_ || word_.GetWord() != data_) {
if (nomatch_) {
return true;
}

if (stemming_enabled_) {
if (key_iter_.IsValid()) {
return false;
}
if (stem_target_ && stem_word_iter_ != stem_target_->end()) {
return false;
}
if (word_.GetWord() != data_) {
return false;
}
return true;
}

if (word_.GetWord() != data_) {
return true;
}
// Check if key iterator is valid
return !key_iter_.IsValid();
}

void TermIterator::Next() {
// On a Begin() call, we initialize the target_posting_ and key_iter_.
if (begin_) {
if (stemming_enabled_ && text_index_schema_ && text_index_) {
Lexer lexer;
std::string search_term = std::string(data_);
std::string stemmed_term = lexer.StemWord(
search_term, text_index_schema_->GetStemmer(), true, min_stem_size_,
text_index_schema_->GetStemmerMutex());

auto stem_iter = text_index_->stem_.GetWordIterator(stemmed_term);
if (!stem_iter.Done() && stem_iter.GetWord() == stemmed_term) {
stem_target_ = stem_iter.GetTarget();
if (stem_target_ && !stem_target_->empty()) {
stem_word_iter_ = stem_target_->begin();
word_ = text_index_->prefix_.GetWordIterator(*stem_word_iter_);
}
}
}

if (word_.Done()) {
nomatch_ = true;
return;
}
target_posting_ = word_.GetTarget();
key_iter_ = target_posting_->GetKeyIterator();
begin_ = false; // Set to false after the first call to Next.
} else {
begin_ = false;
} else if (key_iter_.IsValid()) {
key_iter_.NextKey();
}
// Advance until we find a valid key or reach the end
while (!Done() && !key_iter_.ContainsFields(field_mask_)) {

if (stemming_enabled_) {
while (stem_target_ && stem_word_iter_ != stem_target_->end()) {
Copy link
Member

@KarthikSubbarao KarthikSubbarao Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For stemming, it would be good to consider a separate word iterator. This code should be re-used through the GetStemmedWordIterator and not embedded in each search operation

Also, as discussed, we only will need to use this StemmedWordIterator during the initialization of the search operations once the open PR is merged in

while (key_iter_.IsValid() && !key_iter_.ContainsFields(field_mask_)) {
key_iter_.NextKey();
}
if (key_iter_.IsValid()) {
return;
}

++stem_word_iter_;
if (stem_word_iter_ != stem_target_->end()) {
word_ = text_index_->prefix_.GetWordIterator(*stem_word_iter_);
if (!word_.Done()) {
target_posting_ = word_.GetTarget();
if (target_posting_) {
key_iter_ = target_posting_->GetKeyIterator();
}
}
}
}

if ((!stem_target_ || stem_word_iter_ == stem_target_->end()) &&
(word_.GetWord() != data_ || word_.Done())) {
word_ = text_index_->prefix_.GetWordIterator(data_);
if (!word_.Done() && word_.GetWord() == data_) {
target_posting_ = word_.GetTarget();
if (target_posting_) {
key_iter_ = target_posting_->GetKeyIterator();
while (key_iter_.IsValid() && !key_iter_.ContainsFields(field_mask_)) {
key_iter_.NextKey();
}
if (key_iter_.IsValid()) {
return;
}
}
}
}
}

while (!Done() && key_iter_.IsValid() && !key_iter_.ContainsFields(field_mask_)) {
key_iter_.NextKey();
}
}

const InternedStringPtr& TermIterator::operator*() const {
// Return the current key from the key iterator of the posting object.
return key_iter_.GetKey();
}

Expand Down
15 changes: 14 additions & 1 deletion src/indexes/text/term.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "src/indexes/index_base.h"
#include "src/indexes/text/posting.h"
#include "src/indexes/text/radix_tree.h"
#include "src/indexes/text/text_index.h"
#include "src/utils/string_interning.h"

namespace valkey_search::indexes::text {
Expand All @@ -31,7 +32,11 @@ class TermIterator : public indexes::EntriesFetcherIteratorBase {
public:
TermIterator(const WordIterator& word, const absl::string_view data,
const FieldMaskPredicate field_mask,
const InternedStringSet* untracked_keys = nullptr);
const InternedStringSet* untracked_keys = nullptr,
bool stemming_enabled = false,
std::shared_ptr<TextIndexSchema> text_index_schema = nullptr,
uint32_t min_stem_size = 0,
std::shared_ptr<TextIndex> text_index = nullptr);

bool Done() const override;
void Next() override;
Expand All @@ -49,6 +54,14 @@ class TermIterator : public indexes::EntriesFetcherIteratorBase {
const InternedStringSet* untracked_keys_;
InternedStringPtr current_key_;
FieldMaskPredicate field_mask_;

// Stemming support
bool stemming_enabled_ = false;
std::shared_ptr<TextIndexSchema> text_index_schema_;
uint32_t min_stem_size_ = 0;
std::shared_ptr<TextIndex> text_index_;
std::shared_ptr<StemTarget> stem_target_;
StemTarget::iterator stem_word_iter_;
};

} // namespace valkey_search::indexes::text
Expand Down
Loading
Loading