
Commit 02580c6

sayanshaw24 and Sayan Shaw authored
Add Support For Tokenizer Options (#1785)
### Updates

This PR introduces support for the ORT Extensions changes introduced in [this PR](microsoft/onnxruntime-extensions#998), which allow passing an **options map** with `OrtxCreateTokenizerWithOptions` when creating a tokenizer, or using the new `OrtxUpdateTokenizerOptions` method to update the options map on an existing tokenizer object (including those created with `OrtxCreateTokenizer`), enabling more flexible configuration.

It additionally removes the previously added `OrtxTokenizeWithOptions` and `OrtxDetokenize1DWithOptions` functions, which are now redundant. With the new design, **options are set once on the tokenizer object itself**, so there is no longer a need to pass ad-hoc option sets into individual tokenize/detokenize calls, reducing API clutter and simplifying the C interface.

In addition to the C API updates, it also adds bindings for C++, C#, and Python.

### Sample Usage

C++

```
auto tokenizer = OgaTokenizer::Create(*model);

// Define tokenizer options as C-style arrays
const char* keys[] = {"add_bos_token", "trim_offsets"};
const char* values[] = {"true", "false"};

// Update tokenizer options
tokenizer->UpdateOptions(keys, values, 2);
```

C#

```
var tokenizer = new Tokenizer(model);

// Update tokenizer options using a dictionary
var options = new Dictionary<string, string>
{
    { "add_bos_token", "true" },
    { "trim_offsets", "false" }
};
tokenizer.UpdateOptions(options);
```

Python

```
tokenizer = Tokenizer(model)

options = {
    "add_bos_token": "true",
    "trim_offsets": "false"
}
tokenizer.update_options(**options)
```

---

Co-authored-by: Sayan Shaw <[email protected]>
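The samples above cover the language bindings; a minimal sketch of the same call through the raw C API is shown below. This is not part of the PR itself: it assumes an `OgaTokenizer*` created elsewhere (e.g., via `OgaCreateTokenizer`) and uses the existing `OgaResultGetError`/`OgaDestroyResult` helpers to honor the documented error contract of `OgaUpdateTokenizerOptions` (a non-null `OgaResult*` on failure, which must be freed).

```
#include "ort_genai_c.h"

#include <cstdio>

// Sketch only: `tokenizer` is assumed to be a valid OgaTokenizer* created elsewhere.
bool UpdateTokenizerOptions(OgaTokenizer* tokenizer) {
  const char* keys[] = {"add_special_tokens", "skip_special_tokens"};
  const char* values[] = {"true", "false"};

  OgaResult* result = OgaUpdateTokenizerOptions(tokenizer, keys, values, 2);
  if (result != nullptr) {
    // A non-null result signals failure; it carries the error message and must be freed.
    std::fprintf(stderr, "OgaUpdateTokenizerOptions failed: %s\n", OgaResultGetError(result));
    OgaDestroyResult(result);
    return false;
  }
  return true;  // A nullptr return means success.
}
```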
1 parent 246df2e commit 02580c6

File tree

10 files changed: 188 additions, 7 deletions


cmake/deps.txt

Lines changed: 1 addition & 1 deletion
```
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;bd8fb6d86e98c17e397c42fc001913cc2e035597
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;9790faf2838d72cb229475cd2b5edc6fc779b5aa

 # These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
 llguidance;https://github.com/microsoft/llguidance.git;2d2f1de3c87e3289528affc346f734f7471216d9
```

src/csharp/NativeMethods.cs

Lines changed: 7 additions & 0 deletions
```
@@ -217,6 +217,13 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
         public static extern void OgaDestroyTokenizer(IntPtr /* OgaTokenizer* */ tokenizer);

+        [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+        public static extern IntPtr /* OgaResult* */ OgaUpdateTokenizerOptions(
+            IntPtr /* const OgaTokenizer* */ tokenizer,
+            string[] /* const char*[] */ keys,
+            string[] /* const char*[] */ values,
+            UIntPtr /* size_t */ numOptions);
+
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
         public static extern IntPtr /* OgaResult* */ OgaTokenizerEncode(IntPtr /* const OgaTokenizer* */ tokenizer,
                                                                         byte[] /* const char* */ strings,
```

src/csharp/Tokenizer.cs

Lines changed: 26 additions & 0 deletions
```
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 using System;
+using System.Collections.Generic;

 namespace Microsoft.ML.OnnxRuntimeGenAI
 {
@@ -45,6 +46,31 @@ public string[] DecodeBatch(Sequences sequences)
             return result;
         }

+        public void UpdateOptions(Dictionary<string, string> options)
+        {
+            if (options == null)
+                throw new ArgumentNullException(nameof(options));
+
+            // Prepare native arrays
+            string[] keys = new string[options.Count];
+            string[] values = new string[options.Count];
+            int i = 0;
+            foreach (var kvp in options)
+            {
+                keys[i] = kvp.Key;
+                values[i] = kvp.Value;
+                i++;
+            }
+
+            // Call native function
+            Result.VerifySuccess(
+                NativeMethods.OgaUpdateTokenizerOptions(
+                    _tokenizerHandle,
+                    keys,
+                    values,
+                    (UIntPtr)options.Count));
+        }
+
         public Sequences Encode(string str)
         {
             Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
```

src/models/model.cpp

Lines changed: 12 additions & 3 deletions
```
@@ -259,16 +259,25 @@ const std::string& TokenizerStream::Decode(int32_t token) {
 }

 Tokenizer::Tokenizer(Config& config) : pad_token_id_{config.model.pad_token_id} {
-  CheckResult(OrtxCreateTokenizer(tokenizer_.Address(), config.config_path.string().c_str()));
+  // Default tokenizer options
+  const char* keys[] = {"add_special_tokens", "skip_special_tokens"};
+  const char* values[] = {"false", "true"};
+
+  CheckResult(OrtxCreateTokenizerWithOptions(tokenizer_.Address(), config.config_path.string().c_str(), keys, values, 2));
 }

 std::unique_ptr<TokenizerStream> Tokenizer::CreateStream() const {
   return std::make_unique<TokenizerStream>(*this);
 }

+void Tokenizer::UpdateOptions(const char* const* keys, const char* const* values, size_t num_options) {
+  // Tap into ORT Extensions API
+  CheckResult(OrtxUpdateTokenizerOptions(tokenizer_, const_cast<const char**>(keys), const_cast<const char**>(values), num_options));
+}
+
 std::vector<int32_t> Tokenizer::Encode(const char* text) const {
   OrtxPtr<OrtxTokenId2DArray> ids;
-  CheckResult(OrtxTokenizeWithOptions(tokenizer_, &text, 1, ids.Address(), false /* add_special_tokens */));
+  CheckResult(OrtxTokenize(tokenizer_, &text, 1, ids.Address()));

   const extTokenId_t* tokens;
   size_t count;
@@ -278,7 +287,7 @@ std::vector<int32_t> Tokenizer::Encode(const char* text) const {

 std::string Tokenizer::Decode(std::span<const int32_t> tokens) const {
   OrtxPtr<OrtxStringArray> ortx_string_array;
-  CheckResult(OrtxDetokenize1DWithOptions(tokenizer_, reinterpret_cast<const uint32_t*>(tokens.data()), tokens.size(), ortx_string_array.Address(), true /* skip_special_tokens */));
+  CheckResult(OrtxDetokenize1D(tokenizer_, reinterpret_cast<const uint32_t*>(tokens.data()), tokens.size(), ortx_string_array.Address()));

   const char* string;
   CheckResult(OrtxStringArrayGetItem(ortx_string_array, 0, &string));
```
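Because the defaults are now baked into the constructor and the per-call flags are gone from `Encode`/`Decode`, callers that need different behavior flip the options on the tokenizer object instead. A minimal sketch (not from this PR), assuming a `Tokenizer` instance named `tokenizer` built from a `Config` as above:

```
// Override the constructor's "add_special_tokens" default before encoding;
// the key/value strings mirror those set in Tokenizer::Tokenizer above.
const char* keys[] = {"add_special_tokens"};
const char* values[] = {"true"};
tokenizer.UpdateOptions(keys, values, 1);

// Subsequent Encode/Decode calls pick up the updated options; no per-call flags
// are passed now that OrtxTokenizeWithOptions/OrtxDetokenize1DWithOptions are removed.
std::vector<int32_t> ids = tokenizer.Encode("She sells sea shells by the sea shore.");
std::string text = tokenizer.Decode(ids);
```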

src/models/model.h

Lines changed: 1 addition & 0 deletions
```
@@ -85,6 +85,7 @@ struct Tokenizer : std::enable_shared_from_this<Tokenizer>, LeakChecked<Tokenize

   std::unique_ptr<TokenizerStream> CreateStream() const;

+  void UpdateOptions(const char* const* keys, const char* const* values, size_t num_options);
   std::vector<int32_t> Encode(const char* text) const;
   std::string Decode(std::span<const int32_t> tokens) const;
   std::string ApplyChatTemplate(const char* template_str, const char* messages, const char* tools, bool add_generation_prompt) const;
```

src/ort_genai.h

Lines changed: 4 additions & 0 deletions
```
@@ -305,6 +305,10 @@ struct OgaTokenizer : OgaAbstract {
     return std::unique_ptr<OgaTokenizer>(p);
   }

+  void UpdateOptions(const char* const* keys, const char* const* values, size_t num_options) {
+    OgaCheckResult(OgaUpdateTokenizerOptions(this, keys, values, num_options));
+  }
+
   void Encode(const char* str, OgaSequences& sequences) const {
     OgaCheckResult(OgaTokenizerEncode(this, str, &sequences));
   }
```
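The C++ wrapper takes raw C-style arrays, so callers that keep options in a `std::map` need a small adapter. A sketch below; the `UpdateFromMap` helper is hypothetical and not part of the header, but it mirrors the key/value flattening the Python binding does:

```
#include <map>
#include <string>
#include <vector>

// Hypothetical helper: flattens a std::map into the parallel key/value arrays
// that OgaTokenizer::UpdateOptions expects. The string storage must outlive the call.
void UpdateFromMap(OgaTokenizer& tokenizer, const std::map<std::string, std::string>& options) {
  std::vector<const char*> keys;
  std::vector<const char*> values;
  keys.reserve(options.size());
  values.reserve(options.size());
  for (const auto& [key, value] : options) {
    keys.push_back(key.c_str());
    values.push_back(value.c_str());
  }
  tokenizer.UpdateOptions(keys.data(), values.data(), options.size());
}
```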

src/ort_genai_c.cpp

Lines changed: 17 additions & 0 deletions
```
@@ -581,6 +581,23 @@ OgaResult* OGA_API_CALL OgaCreateTokenizer(const OgaModel* model, OgaTokenizer**
   OGA_CATCH
 }

+OgaResult* OGA_API_CALL OgaUpdateTokenizerOptions(
+    OgaTokenizer* tokenizer,
+    const char* const* keys,
+    const char* const* values,
+    size_t num_options) {
+  OGA_TRY
+
+  if (!tokenizer)
+    throw std::runtime_error("Tokenizer pointer is null");
+
+  tokenizer->UpdateOptions(keys, values, num_options);
+
+  return nullptr;
+
+  OGA_CATCH
+}
+
 OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* tokenizer, const char* str, OgaSequences* sequences) {
   OGA_TRY
   sequences->emplace_back(tokenizer->Encode(str));
```

src/ort_genai_c.h

Lines changed: 37 additions & 0 deletions
```
@@ -572,6 +572,43 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateMultiModalProcessor(const OgaModel*

 OGA_EXPORT void OGA_API_CALL OgaDestroyMultiModalProcessor(OgaMultiModalProcessor* processor);

+/**
+ * Updates tokenizer options for the given OgaTokenizer instance.
+ * The provided keys and values must be null-terminated UTF-8 strings.
+ *
+ * This function allows updating tokenizer behavior at runtime by passing
+ * key/value string pairs. Each key corresponds to a configurable tokenizer
+ * option. Both keys and values must remain valid for the duration of this call.
+ *
+ * @param tokenizer Pointer to the OgaTokenizer whose options will be updated.
+ * @param keys Array of option key strings.
+ * @param values Array of corresponding option value strings (same length as keys).
+ * @param num_options Number of key/value pairs provided.
+ *
+ * @return nullptr on success, or an OgaResult* describing the error.
+ *         The returned OgaResult* (if not null) must be freed with OgaDestroyResult.
+ *
+ * Supported options:
+ *
+ * - `add_special_tokens`
+ *   - Purpose: Controls whether to add special tokens (e.g., BOS/EOS) during tokenization.
+ *   - Values: `"true"` / `"false"` or `"1"` / `"0"`.
+ *   - Default: `"false"`. This is the default value set by ORT GenAI prior to any options updating.
+ *
+ * - `skip_special_tokens`
+ *   - Purpose: Controls whether to remove special tokens during detokenization.
+ *   - Values: `"true"` / `"false"` or `"1"` / `"0"`.
+ *   - Default: `"true"`. This is the default value set by ORT GenAI prior to any options updating.
+ *
+ * Future tokenizer options may be added without changing this API signature.
+ * Passing unknown keys will result in an error.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaUpdateTokenizerOptions(
+    OgaTokenizer* tokenizer,
+    const char* const* keys,
+    const char* const* values,
+    size_t num_options);
+
 /**
  * Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences
  * when it is no longer needed.
```
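The comment above also fixes the failure contract: unknown keys are rejected rather than silently ignored. A brief sketch of that path (the key name `not_a_real_option` is deliberately invalid and hypothetical):

```
// Assumes an existing OgaTokenizer* named `tokenizer`.
const char* keys[] = {"not_a_real_option"};  // not in the supported list above
const char* values[] = {"true"};

OgaResult* result = OgaUpdateTokenizerOptions(tokenizer, keys, values, 1);
if (result != nullptr) {
  // Expected: a non-null OgaResult describing the unknown-key error,
  // which must be released with OgaDestroyResult per the comment above.
  OgaDestroyResult(result);
}
```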

src/python/python.cpp

Lines changed: 20 additions & 2 deletions
```
@@ -361,11 +361,29 @@ PYBIND11_MODULE(onnxruntime_genai, m) {

   pybind11::class_<OgaTokenizer>(m, "Tokenizer")
       .def(pybind11::init([](const OgaModel& model) { return OgaTokenizer::Create(model); }))
+      .def("update_options", [](OgaTokenizer& t, pybind11::kwargs kwargs) {
+        std::vector<std::string> key_storage;
+        std::vector<std::string> value_storage;
+        key_storage.reserve(kwargs.size());
+        value_storage.reserve(kwargs.size());
+
+        std::vector<const char*> keys;
+        std::vector<const char*> values;
+        keys.reserve(kwargs.size());
+        values.reserve(kwargs.size());
+
+        for (auto& item : kwargs) {
+          key_storage.emplace_back(pybind11::str(item.first));
+          value_storage.emplace_back(pybind11::str(item.second));
+          keys.push_back(key_storage.back().c_str());
+          values.push_back(value_storage.back().c_str());
+        }
+
+        t.UpdateOptions(keys.data(), values.data(), kwargs.size()); })
       .def("encode", [](const OgaTokenizer& t, std::string s) -> pybind11::array_t<int32_t> {
         auto sequences = OgaSequences::Create();
         t.Encode(s.c_str(), *sequences);
-        return ToPython(sequences->Get(0));
-      })
+        return ToPython(sequences->Get(0)); })
       .def("to_token_id", &OgaTokenizer::ToTokenId)
       .def("decode", [](const OgaTokenizer& t, pybind11::array_t<int32_t> tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; })
       .def("apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true)
```

test/c_api_tests.cpp

Lines changed: 63 additions & 1 deletion
```
@@ -113,6 +113,68 @@ TEST(CAPITests, TokenizerCAPI) {
 #endif
 }

+TEST(CAPITests, TokenizerUpdateOptions) {
+#if TEST_PHI2
+  auto config = OgaConfig::Create(PHI2_PATH);
+  auto model = OgaModel::Create(*config);
+  auto tokenizer = OgaTokenizer::Create(*model);
+
+  // Update tokenizer options
+  // Note: This simply tests the UpdateOptions API; these options are already set as default.
+  {
+    const char* keys[] = {"add_special_tokens", "skip_special_tokens"};
+    const char* values[] = {"false", "true"};
+    tokenizer->UpdateOptions(keys, values, 2);
+  }
+
+  // Encode single decode single
+  {
+    const char* input_string = "She sells sea shells by the sea shore.";
+    auto input_sequences = OgaSequences::Create();
+    tokenizer->Encode(input_string, *input_sequences);
+
+    auto out_string = tokenizer->Decode(input_sequences->SequenceData(0), input_sequences->SequenceCount(0));
+    ASSERT_STREQ(input_string, out_string);
+  }
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  auto sequences = OgaSequences::Create();
+
+  // Encode all strings
+  {
+    for (auto& string : input_strings)
+      tokenizer->Encode(string, *sequences);
+  }
+
+  // Decode one at a time
+  for (size_t i = 0; i < sequences->Count(); i++) {
+    auto out_string = tokenizer->Decode(sequences->SequenceData(i), sequences->SequenceCount(i));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    if (strcmp(input_strings[i], out_string) != 0)
+      throw std::runtime_error("Token decoding mismatch");
+  }
+
+  // Stream Decode one at a time
+  for (size_t i = 0; i < sequences->Count(); i++) {
+    auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
+
+    auto* sequence = sequences->SequenceData(i);
+    std::string stream_result;
+    for (size_t j = 0; j < sequences->SequenceCount(i); j++) {
+      stream_result += tokenizer_stream->Decode(sequence[j]);
+    }
+    std::cout << "Stream decoded string:" << stream_result << std::endl;
+    if (strcmp(input_strings[i], stream_result.c_str()) != 0)
+      throw std::runtime_error("Stream token decoding mismatch");
+  }
+#endif
+}
+
 TEST(CAPITests, ChatTemplate) {
 #if TEST_PHI2
   // We load the phi-2 model just to get a tokenizer (phi-2 does not have a chat template)
@@ -1281,4 +1343,4 @@ TEST(CAPITests, SetGuidance) {

 #endif
 }
-#endif
+#endif
```
