Skip to content

Add Encode with Options for add_special_tokens=True use-case #1504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
byte[] /* const char* */ strings,
IntPtr /* OgaSequences* */ sequences);

// Encodes a single UTF-8 string and appends the resulting token sequence to `sequences`.
// `add_special_tokens` maps to a native C/C++ `bool` (1 byte). Without an explicit MarshalAs the CLR
// marshals System.Boolean as a 4-byte Win32 BOOL, so the size is pinned to one byte here.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerEncodeWithOptions(IntPtr /* const OgaTokenizer* */ tokenizer,
                                                                           byte[] /* const char* */ strings,
                                                                           IntPtr /* OgaSequences* */ sequences,
                                                                           [MarshalAs(UnmanagedType.I1)] bool /* bool */ add_special_tokens);

// This function is used to decode the given token into a string. The caller is responsible for freeing the
// returned string using the OgaDestroyString function when it is no longer needed.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
Expand Down
15 changes: 15 additions & 0 deletions src/csharp/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ public Sequences Encode(string str)
}
}

/// <summary>
/// Encodes <paramref name="str"/> into a new <see cref="Sequences"/> object, forwarding
/// <paramref name="add_special_tokens"/> to the native tokenizer.
/// </summary>
/// <param name="str">Text to encode; it is converted to UTF-8 before the native call.</param>
/// <param name="add_special_tokens">Option passed through unchanged to OgaTokenizerEncodeWithOptions.</param>
/// <returns>A <see cref="Sequences"/> wrapping the native sequences handle.</returns>
public Sequences EncodeWithOptions(string str, bool add_special_tokens)
{
    Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr sequencesHandle));
    try
    {
        byte[] utf8Text = StringUtils.ToUtf8(str);
        IntPtr encodeResult = NativeMethods.OgaTokenizerEncodeWithOptions(
            _tokenizerHandle, utf8Text, sequencesHandle, add_special_tokens);
        Result.VerifySuccess(encodeResult);
        return new Sequences(sequencesHandle);
    }
    catch
    {
        // Ownership was never handed to a Sequences wrapper, so release the native object here.
        NativeMethods.OgaDestroySequences(sequencesHandle);
        throw;
    }
}

public string Decode(ReadOnlySpan<int> sequence)
{
IntPtr outStr = IntPtr.Zero;
Expand Down
10 changes: 10 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ std::vector<int32_t> Tokenizer::Encode(const char* text) const {
return {tokens, tokens + count};
}

std::vector<int32_t> Tokenizer::EncodeWithOptions(const char* text, bool add_special_tokens) const {
OrtxPtr<OrtxTokenId2DArray> ids;
CheckResult(OrtxTokenizeWithOptions(tokenizer_, &text, 1, ids.Address(), add_special_tokens));

const extTokenId_t* tokens;
size_t count;
CheckResult(OrtxTokenId2DArrayGetItem(ids, 0, &tokens, &count));
return {tokens, tokens + count};
}

std::string Tokenizer::Decode(std::span<const int32_t> tokens) const {
OrtxPtr<OrtxStringArray> ortx_string_array;
CheckResult(OrtxDetokenize1D(tokenizer_, reinterpret_cast<const uint32_t*>(tokens.data()), tokens.size(), ortx_string_array.Address()));
Expand Down
1 change: 1 addition & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct Tokenizer : std::enable_shared_from_this<Tokenizer>, LeakChecked<Tokenize
std::unique_ptr<TokenizerStream> CreateStream() const;

std::vector<int32_t> Encode(const char* text) const;
// Encodes `text` into token ids; `add_special_tokens` is forwarded to the underlying tokenizer call.
std::vector<int32_t> EncodeWithOptions(const char* text, bool add_special_tokens) const;
std::string Decode(std::span<const int32_t> tokens) const;
std::string ApplyChatTemplate(const char* template_str, const char* messages, const char* tools, bool add_generation_prompt) const;

Expand Down
4 changes: 4 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ struct OgaTokenizer : OgaAbstract {
OgaCheckResult(OgaTokenizerEncode(this, str, &sequences));
}

void EncodeWithOptions(const char* str, OgaSequences& sequences, bool add_special_tokens) const {
OgaCheckResult(OgaTokenizerEncodeWithOptions(this, str, &sequences, add_special_tokens));
}

std::unique_ptr<OgaTensor> EncodeBatch(const char** strings, size_t count) const {
OgaTensor* out;
OgaCheckResult(OgaTokenizerEncodeBatch(this, strings, count, &out));
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* tokenizer, const
OGA_CATCH
}

// C API entry point: encodes `str` with `tokenizer` and appends the resulting token
// sequence to `sequences`. `add_special_tokens` is passed through unchanged to
// tokenizer->EncodeWithOptions. Returns nullptr on the success path.
// NOTE(review): OGA_TRY / OGA_CATCH are defined elsewhere; presumably they turn a thrown
// exception into a non-null OgaResult* — confirm against the macro definitions.
OgaResult* OGA_API_CALL OgaTokenizerEncodeWithOptions(const OgaTokenizer* tokenizer, const char* str, OgaSequences* sequences, bool add_special_tokens) {
OGA_TRY
// The encoded ids are added as a new element at the end of `sequences`.
sequences->emplace_back(tokenizer->EncodeWithOptions(str, add_special_tokens));
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer* tokenizer, const char** strings, size_t count, OgaTensor** out) {
OGA_TRY
auto tensor = tokenizer->EncodeBatch(std::span<const char*>(strings, count));
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,13 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyMultiModalProcessor(OgaMultiModalProcesso
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);

/**
 * Encodes a single string and appends the encoded sequence of tokens to the given OgaSequences.
 * Unlike OgaTokenizerEncode, this variant takes tokenization options; currently the only option is
 * add_special_tokens, which selects whether special tokens are added to or omitted from the output.
 * The OgaSequences must be freed with OgaDestroySequences when it is no longer needed.
 * Note: further options may be added to this function in the future, to leverage
 * OrtxTokenizeWithOptions from ORT Extensions.
 */
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeWithOptions(const OgaTokenizer*, const char* str, OgaSequences* sequences, bool add_special_tokens);

/**
* Batch encode an array of strings and return a single tensor output
*/
Expand Down
5 changes: 5 additions & 0 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
t.Encode(s.c_str(), *sequences);
return ToPython(sequences->Get(0));
})
// Python binding `encode_with_options(text, add_special_tokens)`: encodes one string into a
// fresh OgaSequences and returns its first sequence as an int32 array.
.def("encode_with_options", [](const OgaTokenizer& tokenizer, std::string text, bool add_special_tokens) -> pybind11::array_t<int32_t> {
  auto encoded = OgaSequences::Create();
  tokenizer.EncodeWithOptions(text.c_str(), *encoded, add_special_tokens);
  return ToPython(encoded->Get(0));
})
.def("to_token_id", &OgaTokenizer::ToTokenId)
.def("decode", [](const OgaTokenizer& t, pybind11::array_t<int32_t> tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; })
.def("apply_chat_template", [](const OgaTokenizer& t, const char* template_str, const char* messages, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("template_str") = nullptr, pybind11::arg("messages"), pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt"))
Expand Down
Loading