Skip to content

Commit 9470cb5

Browse files
authored
Update to M.E.AI 9.3.0-preview.1.25161.3 (#1317)
1 parent a6a279b commit 9470cb5

File tree

5 files changed

+52
-68
lines changed

5 files changed

+52
-68
lines changed

nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@
1414
<tags>ONNX;ONNX Runtime;ONNX Runtime Gen AI;Machine Learning</tags>
1515
<dependencies>
1616
<group targetFramework="net8.0">
17-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25114.11" />
17+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25161.3" />
1818
</group>
1919
<group targetFramework="netstandard2.0">
20-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25114.11" />
20+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25161.3" />
2121
</group>
2222
<group targetFramework="net8.0-android31.0">
23-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25114.11" />
23+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25161.3" />
2424
</group>
2525
<group targetFramework="net8.0-ios15.4">
26-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25114.11" />
26+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25161.3" />
2727
</group>
2828
<group targetFramework="net8.0-maccatalyst14.0">
29-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25114.11" />
29+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.3.0-preview.1.25161.3" />
3030
</group>
3131
</dependencies>
3232
</metadata>

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
</ItemGroup>
123123

124124
<ItemGroup>
125-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25114.11" />
125+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25161.3" />
126126
</ItemGroup>
127127

128128
</Project>

src/csharp/OnnxRuntimeGenAIChatClient.cs

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
// Licensed under the MIT License.
33

44
using System;
5-
using System.Buffers;
65
using System.Collections.Generic;
76
using System.Runtime.CompilerServices;
7+
using System.Text;
88
using System.Threading;
99
using System.Threading.Tasks;
1010
using Microsoft.Extensions.AI;
@@ -16,8 +16,8 @@ namespace Microsoft.ML.OnnxRuntimeGenAI;
1616
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
1717
public sealed class OnnxRuntimeGenAIChatClient : IChatClient
1818
{
19-
/// <summary>The options used to configure the instance.</summary>
20-
private readonly OnnxRuntimeGenAIChatClientOptions _options;
19+
/// <summary>Options used to configure the instance's behavior.</summary>
20+
private readonly OnnxRuntimeGenAIChatClientOptions? _options;
2121
/// <summary>The wrapped <see cref="Model"/>.</summary>
2222
private readonly Model _model;
2323
/// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
@@ -32,58 +32,44 @@ public sealed class OnnxRuntimeGenAIChatClient : IChatClient
3232
private CachedGenerator? _cachedGenerator;
3333

3434
/// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
35-
/// <param name="options">Options used to configure the client instance.</param>
3635
/// <param name="modelPath">The file path to the model to load.</param>
37-
/// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
36+
/// <param name="options">Options used to configure the client instance.</param>
3837
/// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is <see langword="null"/>.</exception>
39-
public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, string modelPath)
38+
public OnnxRuntimeGenAIChatClient(string modelPath, OnnxRuntimeGenAIChatClientOptions? options = null)
4039
{
41-
if (options is null)
42-
{
43-
throw new ArgumentNullException(nameof(options));
44-
}
45-
4640
if (modelPath is null)
4741
{
4842
throw new ArgumentNullException(nameof(modelPath));
4943
}
5044

51-
_options = options;
52-
5345
_ownsModel = true;
5446
_model = new Model(modelPath);
5547
_tokenizer = new Tokenizer(_model);
48+
_options = options;
5649

5750
_metadata = new("onnx", new Uri($"file://{modelPath}"), modelPath);
5851
}
5952

6053
/// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
61-
/// <param name="options">Options used to configure the client instance.</param>
6254
/// <param name="model">The model to employ.</param>
6355
/// <param name="ownsModel">
6456
/// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
6557
/// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
6658
/// The default is <see langword="true"/>.
6759
/// </param>
68-
/// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
60+
/// <param name="options">Options used to configure the client instance.</param>
6961
/// <exception cref="ArgumentNullException"><paramref name="model"/> is <see langword="null"/>.</exception>
70-
public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, Model model, bool ownsModel = true)
62+
public OnnxRuntimeGenAIChatClient(Model model, bool ownsModel = true, OnnxRuntimeGenAIChatClientOptions? options = null)
7163
{
72-
if (options is null)
73-
{
74-
throw new ArgumentNullException(nameof(options));
75-
}
76-
7764
if (model is null)
7865
{
7966
throw new ArgumentNullException(nameof(model));
8067
}
8168

82-
_options = options;
83-
8469
_ownsModel = ownsModel;
8570
_model = model;
8671
_tokenizer = new Tokenizer(_model);
72+
_options = options;
8773

8874
_metadata = new("onnx");
8975
}
@@ -106,16 +92,16 @@ public void Dispose()
10692

10793
/// <inheritdoc/>
10894
public Task<ChatResponse> GetResponseAsync(
109-
IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
110-
GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(cancellationToken: cancellationToken);
95+
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
96+
GetStreamingResponseAsync(messages, options, cancellationToken).ToChatResponseAsync(cancellationToken: cancellationToken);
11197

11298
/// <inheritdoc/>
11399
public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
114-
IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
100+
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
115101
{
116-
if (chatMessages is null)
102+
if (messages is null)
117103
{
118-
throw new ArgumentNullException(nameof(chatMessages));
104+
throw new ArgumentNullException(nameof(messages));
119105
}
120106

121107
// Check to see whether there's a cached generator. If there is, and if its id matches what we got from the client,
@@ -133,12 +119,18 @@ generator.ChatThreadId is null ||
133119
}
134120

135121
// If caching is enabled, generate a new ID to represent the state of the generator when we finish this response.
136-
generator.ChatThreadId = _options.EnableCaching ? Guid.NewGuid().ToString("N") : null;
122+
generator.ChatThreadId = _options?.EnableCaching is true ? Guid.NewGuid().ToString("N") : null;
137123

138124
// Format and tokenize the message.
139-
using Sequences tokens = _tokenizer.Encode(_options.PromptFormatter(chatMessages, options));
125+
string formattedPrompt = _options?.PromptFormatter is { } formatter ?
126+
formatter(messages, options) :
127+
FormatPromptDefault(messages, options);
128+
129+
using Sequences tokens = _tokenizer.Encode(formattedPrompt);
140130
try
141131
{
132+
string responseId = Guid.NewGuid().ToString("N");
133+
142134
generator.Generator.AppendTokenSequences(tokens);
143135
int inputTokens = tokens[0].Length, outputTokens = 0;
144136

@@ -172,11 +164,10 @@ generator.ChatThreadId is null ||
172164

173165
// Yield the next token in the stream.
174166
outputTokens++;
175-
yield return new()
167+
yield return new(ChatRole.Assistant, next)
176168
{
177169
CreatedAt = DateTimeOffset.UtcNow,
178-
Role = ChatRole.Assistant,
179-
Text = next,
170+
ResponseId = responseId,
180171
};
181172
}
182173

@@ -193,7 +184,7 @@ generator.ChatThreadId is null ||
193184
CreatedAt = DateTimeOffset.UtcNow,
194185
FinishReason = options is not null && options.MaxOutputTokens <= outputTokens ? ChatFinishReason.Length : ChatFinishReason.Stop,
195186
ModelId = _metadata.ModelId,
196-
ResponseId = Guid.NewGuid().ToString(),
187+
ResponseId = responseId,
197188
Role = ChatRole.Assistant,
198189
};
199190
}
@@ -228,7 +219,19 @@ generator.ChatThreadId is null ||
228219
/// <summary>Gets whether the specified token is a stop sequence.</summary>
229220
private bool IsStop(string token, ChatOptions? options) =>
230221
options?.StopSequences?.Contains(token) is true ||
231-
_options.StopSequences.Contains(token);
222+
_options?.StopSequences?.Contains(token) is true;
223+
224+
/// <summary>Formats messages into a prompt using a default format.</summary>
225+
private static string FormatPromptDefault(IEnumerable<ChatMessage> messages, ChatOptions? options)
226+
{
227+
StringBuilder sb = new();
228+
foreach (var message in messages)
229+
{
230+
sb.Append(message).AppendLine();
231+
}
232+
233+
return sb.ToString();
234+
}
232235

233236
/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
234237
private static void UpdateGeneratorParamsFromOptions(GeneratorParams generatorParams, ChatOptions? options)

src/csharp/OnnxRuntimeGenAIChatClientOptions.cs

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
using System;
55
using System.Collections.Generic;
6-
using System.Text;
76
using Microsoft.Extensions.AI;
87

98
#nullable enable
@@ -36,17 +35,6 @@ public sealed class OnnxRuntimeGenAIChatClientOptions
3635
{
3736
private IList<string> _stopSequences = [];
3837

39-
private Func<IEnumerable<ChatMessage>, ChatOptions?, string> _promptFormatter = static (messages, _) =>
40-
{
41-
StringBuilder sb = new();
42-
foreach (var message in messages)
43-
{
44-
sb.Append(message).AppendLine();
45-
}
46-
47-
return sb.ToString();
48-
};
49-
5038
/// <summary>Initializes a new instance of the <see cref="OnnxRuntimeGenAIChatClientOptions"/> class.</summary>
5139
/// <param name="stopSequences">The stop sequences used by the model.</param>
5240
/// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
@@ -62,25 +50,18 @@ public OnnxRuntimeGenAIChatClientOptions()
6250
/// <remarks>
6351
/// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>
6452
/// provided to the <see cref="IChatClient.GetResponseAsync"/> and <see cref="IChatClient.GetStreamingResponseAsync"/>
65-
/// methods.
53+
/// methods. If <see langword="null"/>, this will not contribute any additional stop sequences.
6654
/// </remarks>
67-
public IList<string> StopSequences
68-
{
69-
get => _stopSequences;
70-
set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
71-
}
55+
public IList<string>? StopSequences { get; set; }
7256

7357
/// <summary>Gets or sets a delegate that formats a prompt string from a list of chat messages.</summary>
7458
/// <remarks>
7559
/// Each time <see cref="IChatClient.GetResponseAsync"/> or <see cref="IChatClient.GetStreamingResponseAsync"/>
7660
/// is invoked, this delegate will be invoked with the supplied list of messages to produce a string that
77-
/// will be tokenized and provided to the underlying <see cref="Generator"/>.
61+
/// will be tokenized and provided to the underlying <see cref="Generator"/>. If <see langword="null"/>,
62+
/// the <see cref="OnnxRuntimeGenAIChatClient"/> will choose a default prompt formatter to employ.
7863
/// </remarks>
79-
public Func<IEnumerable<ChatMessage>, ChatOptions?, string> PromptFormatter
80-
{
81-
get => _promptFormatter;
82-
set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
83-
}
64+
public Func<IEnumerable<ChatMessage>, ChatOptions?, string>? PromptFormatter { get; set; }
8465

8566
/// <summary>Gets or sets whether to cache the most recent conversation.</summary>
8667
/// <remarks>

test/csharp/TestOnnxRuntimeGenAIAPI.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ public void TestTopKTopPSearch()
360360
[IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
361361
public async Task TestChatClient()
362362
{
363-
OnnxRuntimeGenAIChatClientOptions config = new()
363+
OnnxRuntimeGenAIChatClientOptions options = new()
364364
{
365365
StopSequences = ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
366366
PromptFormatter = static (messages, options) =>
@@ -375,7 +375,7 @@ public async Task TestChatClient()
375375
},
376376
};
377377

378-
using var client = new OnnxRuntimeGenAIChatClient(config, _phi2Path);
378+
using var client = new OnnxRuntimeGenAIChatClient(_phi2Path, options);
379379

380380
var completion = await client.GetResponseAsync("The quick brown fox jumps over the lazy dog.", new()
381381
{
@@ -384,7 +384,7 @@ public async Task TestChatClient()
384384
StopSequences = ["."],
385385
});
386386

387-
Assert.NotEmpty(completion.Message.Text);
387+
Assert.NotEmpty(completion.Text);
388388
}
389389

390390
[IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]

0 commit comments

Comments
 (0)