22// Licensed under the MIT License.
33
44using System ;
5- using System . Buffers ;
65using System . Collections . Generic ;
76using System . Runtime . CompilerServices ;
7+ using System . Text ;
88using System . Threading ;
99using System . Threading . Tasks ;
1010using Microsoft . Extensions . AI ;
@@ -16,8 +16,8 @@ namespace Microsoft.ML.OnnxRuntimeGenAI;
1616/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
1717public sealed class OnnxRuntimeGenAIChatClient : IChatClient
1818{
19- /// <summary>The options used to configure the instance.</summary>
20- private readonly OnnxRuntimeGenAIChatClientOptions _options ;
19+ /// <summary>Options used to configure the instance's behavior .</summary>
20+ private readonly OnnxRuntimeGenAIChatClientOptions ? _options ;
2121 /// <summary>The wrapped <see cref="Model"/>.</summary>
2222 private readonly Model _model ;
2323 /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
@@ -32,58 +32,44 @@ public sealed class OnnxRuntimeGenAIChatClient : IChatClient
3232 private CachedGenerator ? _cachedGenerator ;
3333
3434 /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
35- /// <param name="options">Options used to configure the client instance.</param>
3635 /// <param name="modelPath">The file path to the model to load.</param>
37- /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/> .</exception >
36+ /// <param name="options">Options used to configure the client instance .</param >
3837 /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is <see langword="null"/>.</exception>
39- public OnnxRuntimeGenAIChatClient ( OnnxRuntimeGenAIChatClientOptions options , string modelPath )
38+ public OnnxRuntimeGenAIChatClient ( string modelPath , OnnxRuntimeGenAIChatClientOptions ? options = null )
4039 {
41- if ( options is null )
42- {
43- throw new ArgumentNullException ( nameof ( options ) ) ;
44- }
45-
4640 if ( modelPath is null )
4741 {
4842 throw new ArgumentNullException ( nameof ( modelPath ) ) ;
4943 }
5044
51- _options = options ;
52-
5345 _ownsModel = true ;
5446 _model = new Model ( modelPath ) ;
5547 _tokenizer = new Tokenizer ( _model ) ;
48+ _options = options ;
5649
5750 _metadata = new ( "onnx" , new Uri ( $ "file://{ modelPath } ") , modelPath ) ;
5851 }
5952
6053 /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
61- /// <param name="options">Options used to configure the client instance.</param>
6254 /// <param name="model">The model to employ.</param>
6355 /// <param name="ownsModel">
6456 /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
6557 /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
6658 /// The default is <see langword="true"/>.
6759 /// </param>
68- /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/> .</exception >
60+ /// <param name="options">Options used to configure the client instance .</param >
6961 /// <exception cref="ArgumentNullException"><paramref name="model"/> is <see langword="null"/>.</exception>
70- public OnnxRuntimeGenAIChatClient ( OnnxRuntimeGenAIChatClientOptions options , Model model , bool ownsModel = true )
62+ public OnnxRuntimeGenAIChatClient ( Model model , bool ownsModel = true , OnnxRuntimeGenAIChatClientOptions ? options = null )
7163 {
72- if ( options is null )
73- {
74- throw new ArgumentNullException ( nameof ( options ) ) ;
75- }
76-
7764 if ( model is null )
7865 {
7966 throw new ArgumentNullException ( nameof ( model ) ) ;
8067 }
8168
82- _options = options ;
83-
8469 _ownsModel = ownsModel ;
8570 _model = model ;
8671 _tokenizer = new Tokenizer ( _model ) ;
72+ _options = options ;
8773
8874 _metadata = new ( "onnx" ) ;
8975 }
@@ -106,16 +92,16 @@ public void Dispose()
10692
10793 /// <inheritdoc/>
10894 public Task < ChatResponse > GetResponseAsync (
109- IList < ChatMessage > chatMessages , ChatOptions ? options = null , CancellationToken cancellationToken = default ) =>
110- GetStreamingResponseAsync ( chatMessages , options , cancellationToken ) . ToChatResponseAsync ( cancellationToken : cancellationToken ) ;
95+ IEnumerable < ChatMessage > messages , ChatOptions ? options = null , CancellationToken cancellationToken = default ) =>
96+ GetStreamingResponseAsync ( messages , options , cancellationToken ) . ToChatResponseAsync ( cancellationToken : cancellationToken ) ;
11197
11298 /// <inheritdoc/>
11399 public async IAsyncEnumerable < ChatResponseUpdate > GetStreamingResponseAsync (
114- IList < ChatMessage > chatMessages , ChatOptions ? options = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
100+ IEnumerable < ChatMessage > messages , ChatOptions ? options = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
115101 {
116- if ( chatMessages is null )
102+ if ( messages is null )
117103 {
118- throw new ArgumentNullException ( nameof ( chatMessages ) ) ;
104+ throw new ArgumentNullException ( nameof ( messages ) ) ;
119105 }
120106
121107 // Check to see whether there's a cached generator. If there is, and if its id matches what we got from the client,
@@ -133,12 +119,18 @@ generator.ChatThreadId is null ||
133119 }
134120
135121 // If caching is enabled, generate a new ID to represent the state of the generator when we finish this response.
136- generator . ChatThreadId = _options . EnableCaching ? Guid . NewGuid ( ) . ToString ( "N" ) : null ;
122+ generator . ChatThreadId = _options ? . EnableCaching is true ? Guid . NewGuid ( ) . ToString ( "N" ) : null ;
137123
138124 // Format and tokenize the message.
139- using Sequences tokens = _tokenizer . Encode ( _options . PromptFormatter ( chatMessages , options ) ) ;
125+ string formattedPrompt = _options ? . PromptFormatter is { } formatter ?
126+ formatter ( messages , options ) :
127+ FormatPromptDefault ( messages , options ) ;
128+
129+ using Sequences tokens = _tokenizer . Encode ( formattedPrompt ) ;
140130 try
141131 {
132+ string responseId = Guid . NewGuid ( ) . ToString ( "N" ) ;
133+
142134 generator . Generator . AppendTokenSequences ( tokens ) ;
143135 int inputTokens = tokens [ 0 ] . Length , outputTokens = 0 ;
144136
@@ -172,11 +164,10 @@ generator.ChatThreadId is null ||
172164
173165 // Yield the next token in the stream.
174166 outputTokens ++ ;
175- yield return new ( )
167+ yield return new ( ChatRole . Assistant , next )
176168 {
177169 CreatedAt = DateTimeOffset . UtcNow ,
178- Role = ChatRole . Assistant ,
179- Text = next ,
170+ ResponseId = responseId ,
180171 } ;
181172 }
182173
@@ -193,7 +184,7 @@ generator.ChatThreadId is null ||
193184 CreatedAt = DateTimeOffset . UtcNow ,
194185 FinishReason = options is not null && options . MaxOutputTokens <= outputTokens ? ChatFinishReason . Length : ChatFinishReason . Stop ,
195186 ModelId = _metadata . ModelId ,
196- ResponseId = Guid . NewGuid ( ) . ToString ( ) ,
187+ ResponseId = responseId ,
197188 Role = ChatRole . Assistant ,
198189 } ;
199190 }
@@ -228,7 +219,19 @@ generator.ChatThreadId is null ||
228219 /// <summary>Gets whether the specified token is a stop sequence.</summary>
229220 private bool IsStop ( string token , ChatOptions ? options ) =>
230221 options ? . StopSequences ? . Contains ( token ) is true ||
231- _options . StopSequences . Contains ( token ) ;
222+ _options ? . StopSequences ? . Contains ( token ) is true ;
223+
224+ /// <summary>Formats messages into a prompt using a default format.</summary>
225+ private static string FormatPromptDefault ( IEnumerable < ChatMessage > messages , ChatOptions ? options )
226+ {
227+ StringBuilder sb = new ( ) ;
228+ foreach ( var message in messages )
229+ {
230+ sb . Append ( message ) . AppendLine ( ) ;
231+ }
232+
233+ return sb . ToString ( ) ;
234+ }
232235
233236 /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
234237 private static void UpdateGeneratorParamsFromOptions ( GeneratorParams generatorParams , ChatOptions ? options )
0 commit comments