
Commit 09344a1

Merge pull request #1190 from stephentoub/updatemeai
Update to stable Microsoft.Extensions.AI.Abstractions
2 parents 3ded20c + e89a936 · commit 09344a1

2 files changed: 32 additions, 29 deletions

LLama/Extensions/LLamaExecutorExtensions.cs

Lines changed: 31 additions & 28 deletions
@@ -71,7 +71,12 @@ public async Task<ChatResponse> GetResponseAsync(
                 text.Append(token);
             }
 
-            return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
+            var message = new ChatMessage(ChatRole.Assistant, text.ToString())
+            {
+                MessageId = Guid.NewGuid().ToString("N"),
+            };
+
+            return new(message)
             {
                 CreatedAt = DateTime.UtcNow,
             };
@@ -83,11 +88,13 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
         {
             var result = _executor.InferAsync(CreatePrompt(messages), CreateInferenceParams(options), cancellationToken);
 
+            string messageId = Guid.NewGuid().ToString("N");
             await foreach (var token in _outputTransform.TransformAsync(result))
             {
                 yield return new(ChatRole.Assistant, token)
                 {
                     CreatedAt = DateTime.UtcNow,
+                    MessageId = messageId,
                 };
             }
         }
@@ -124,37 +131,33 @@ private string CreatePrompt(IEnumerable<ChatMessage> messages)
         }
 
         /// <summary>Convert the chat options to inference parameters.</summary>
-        private static InferenceParams? CreateInferenceParams(ChatOptions? options)
+        private InferenceParams CreateInferenceParams(ChatOptions? options)
         {
-            List<string> antiPrompts = new(s_antiPrompts);
-            if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
-            {
-                antiPrompts.AddRange(anti);
-            }
+            InferenceParams ip = options?.RawRepresentationFactory?.Invoke(this) as InferenceParams ?? new();
 
-            return new()
+            ip.AntiPrompts = [.. s_antiPrompts, .. ip.AntiPrompts];
+            ip.MaxTokens = options?.MaxOutputTokens ?? 256; // arbitrary upper limit
+            ip.SamplingPipeline = new DefaultSamplingPipeline()
             {
-                AntiPrompts = antiPrompts,
-                TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
-                MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
-                SamplingPipeline = new DefaultSamplingPipeline()
-                {
-                    FrequencyPenalty = options?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty,
-                    PresencePenalty = options?.PresencePenalty ?? s_defaultPipeline.PresencePenalty,
-                    PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
-                    PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
-                    RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
-                    PenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.PenaltyCount,
-                    Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
-                    MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
-                    MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
-                    Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(),
-                    Temperature = options?.Temperature ?? s_defaultPipeline.Temperature,
-                    TopP = options?.TopP ?? s_defaultPipeline.TopP,
-                    TopK = options?.TopK ?? s_defaultPipeline.TopK,
-                    TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
-                },
+                FrequencyPenalty = options?.FrequencyPenalty ?? (ip.SamplingPipeline as DefaultSamplingPipeline)?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty,
+                PresencePenalty = options?.PresencePenalty ?? (ip.SamplingPipeline as DefaultSamplingPipeline)?.PresencePenalty ?? s_defaultPipeline.PresencePenalty,
+                PreventEOS = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PreventEOS ?? s_defaultPipeline.PreventEOS,
+                PenalizeNewline = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PenalizeNewline ?? s_defaultPipeline.PenalizeNewline,
+                RepeatPenalty = (ip.SamplingPipeline as DefaultSamplingPipeline)?.RepeatPenalty ?? s_defaultPipeline.RepeatPenalty,
+                PenaltyCount = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PenaltyCount ?? s_defaultPipeline.PenaltyCount,
+                Grammar = (ip.SamplingPipeline as DefaultSamplingPipeline)?.Grammar ?? s_defaultPipeline.Grammar,
+                GrammarOptimization = (ip.SamplingPipeline as DefaultSamplingPipeline)?.GrammarOptimization ?? s_defaultPipeline.GrammarOptimization,
+                LogitBias = (ip.SamplingPipeline as DefaultSamplingPipeline)?.LogitBias ?? s_defaultPipeline.LogitBias,
+                MinKeep = (ip.SamplingPipeline as DefaultSamplingPipeline)?.MinKeep ?? s_defaultPipeline.MinKeep,
+                MinP = (ip.SamplingPipeline as DefaultSamplingPipeline)?.MinP ?? s_defaultPipeline.MinP,
+                Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(),
+                Temperature = options?.Temperature ?? s_defaultPipeline.Temperature,
+                TopP = options?.TopP ?? s_defaultPipeline.TopP,
+                TopK = options?.TopK ?? s_defaultPipeline.TopK,
+                TypicalP = (ip.SamplingPipeline as DefaultSamplingPipeline)?.TypicalP ?? s_defaultPipeline.TypicalP,
             };
+
+            return ip;
         }
 
         /// <summary>A default transform that appends "Assistant: " to the end.</summary>
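
The revised CreateInferenceParams now seeds its parameters from ChatOptions.RawRepresentationFactory instead of the AdditionalProperties bag. The consumer-side sketch below is not part of the commit; it assumes `client` is an IChatClient wrapping a LLamaSharp executor (the adapter this file implements), and the AskAsync helper is hypothetical, for illustration only.

using System;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

// Hypothetical helper: `client` is assumed to be an IChatClient backed by a LLamaSharp executor.
static async Task AskAsync(IChatClient client)
{
    ChatOptions options = new()
    {
        MaxOutputTokens = 512,
        Temperature = 0.7f,
        // The InferenceParams returned here seeds CreateInferenceParams; the
        // standard anti-prompts and MaxTokens are then layered on top of it.
        RawRepresentationFactory = _ => new InferenceParams
        {
            SamplingPipeline = new DefaultSamplingPipeline { RepeatPenalty = 1.1f },
        },
    };

    ChatResponse response = await client.GetResponseAsync("Hello!", options);
    Console.WriteLine(response.Text);
}

On the streaming path, the commit generates one MessageId per call and stamps it on every ChatResponseUpdate, so a consumer can correlate all streamed updates belonging to a single assistant message.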

LLama/LLamaSharp.csproj

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
   <ItemGroup>
     <PackageReference Include="CommunityToolkit.HighPerformance" Version="8.4.0" />
     <PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="9.0.3" />
-    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.5.0-preview.1.25262.9" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.5.0" />
     <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.3" />
     <PackageReference Include="System.Numerics.Tensors" Version="9.0.3" />
   </ItemGroup>
