Skip to content

Commit 4657e98

Browse files
authored
Merge pull request #735 from martindevans/less_sampler_allocations
Less Sampler Allocations
2 parents 3582d82 + b80f043 commit 4657e98

File tree

5 files changed

+118
-75
lines changed

5 files changed

+118
-75
lines changed

LLama.Examples/Examples/BatchedExecutorGuidance.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LLama.Batched;
1+
using LLama.Batched;
22
using LLama.Common;
33
using LLama.Native;
44
using LLama.Sampling;
@@ -105,18 +105,19 @@ public override ISamplingPipeline Clone()
105105

106106
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
107107
{
108-
if (guidance == null)
109-
return;
110-
111-
// Get the logits generated by the guidance sequences
112-
var guidanceLogits = guidance.Sample();
113-
114-
// Use those logits to guide this sequence
115-
NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
116108
}
117109

118110
protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
119111
{
112+
if (guidance != null)
113+
{
114+
// Get the logits generated by the guidance sequences
115+
var guidanceLogits = guidance.Sample();
116+
117+
// Modify these logits based on the guidance logits
118+
candidates.Guidance(ctx, guidanceLogits, weight);
119+
}
120+
120121
candidates.Temperature(ctx, 0.8f);
121122
candidates.TopK(ctx, 25);
122123

LLama/LLamaContext.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ public uint BatchThreads
8989
/// Get the maximum batch size for this context
9090
/// </summary>
9191
public uint BatchSize => NativeHandle.BatchSize;
92+
93+
private LLamaTokenData[]? _samplingBuffer;
9294

9395
/// <summary>
9496
/// Create a new LLamaContext for the given LLamaWeights
@@ -496,7 +498,9 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
496498
var nl_logit = logits[(int?)nl_token ?? 0];
497499

498500
// Convert logits into token candidates
499-
var candidates_p = LLamaTokenDataArray.Create(logits);
501+
if (_samplingBuffer == null || _samplingBuffer.Length < logits.Length)
502+
_samplingBuffer = new LLamaTokenData[logits.Length];
503+
var candidates_p = LLamaTokenDataArray.Create(logits, _samplingBuffer);
500504

501505
// Extract most recently returned tokens
502506
var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount);
@@ -508,14 +512,14 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
508512
// Restore newline token logit value if necessary
509513
if (!penalizeNL && nl_token.HasValue)
510514
{
511-
var candidatesSpan = candidates_p.data.Span;
512-
for (var i = 0; i < candidates_p.data.Length; i++)
515+
var candidatesSpan = candidates_p.Data.Span;
516+
for (var i = 0; i < candidates_p.Data.Length; i++)
513517
{
514518
ref var item = ref candidatesSpan[i];
515519
if (item.id == nl_token)
516520
item.logit = nl_logit;
517521
}
518-
candidates_p.sorted = false;
522+
candidates_p.Sorted = false;
519523
}
520524

521525
return candidates_p;

0 commit comments

Comments (0)