
Commit 5a9e13c

Merge pull request #223 from martindevans/batch_decoding
New Binaries, Improved Sampling API, Batch Decoding Prototype
2 parents f8b2c5d + db8f398 · commit 5a9e13c


42 files changed: +814 -222 lines changed

LLama.Examples/LLama.Examples.csproj

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
     <Platforms>AnyCPU;x64</Platforms>
     <!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
     <IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>

   <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
LLama.Examples/NewVersion/BatchedDecoding.cs

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Examples.NewVersion;
+
+/// <summary>
+/// This demonstrates generating multiple replies to the same prompt, with a shared cache
+/// </summary>
+/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
+public class BatchedDecoding
+{
+    private const int n_parallel = 8;
+    private const int n_len = 32;
+
+    private const int top_k = 80;
+    private const float top_p = 0.8f;
+    private const float temp = 0.5f;
+
+    public static async Task Run()
+    {
+        Console.Write("Please input your model path: ");
+        var modelPath = Console.ReadLine();
+
+        Console.WriteLine("Prompt (leave blank to select automatically):");
+        var prompt = Console.ReadLine();
+        if (string.IsNullOrWhiteSpace(prompt))
+            prompt = "Not many people know that";
+
+        // Load model
+        var parameters = new ModelParams(modelPath);
+        using var model = LLamaWeights.LoadFromFile(parameters);
+
+        // Tokenize prompt
+        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
+        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
+
+        // Create a context
+        parameters.ContextSize = (uint)model.ContextSize;
+        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
+        using var context = model.CreateContext(parameters);
+
+        var n_ctx = context.ContextSize;
+
+        // make sure the KV cache is big enough to hold all the prompt and generated tokens
+        if (n_kv_req > n_ctx)
+        {
+            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
+            await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
+            return;
+        }
+
+        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+
+        // evaluate the initial prompt
+        for (var i = 0; i < prompt_tokens.Length; i++)
+            batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
+        Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);
+
+        // llama_decode will output logits only for the last token of the prompt
+        unsafe
+        {
+            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
+        }
+
+        if (context.NativeHandle.Decode(batch) != 0)
+        {
+            await Console.Error.WriteLineAsync("llama_decode failed");
+            return;
+        }
+
+        // assign the system KV cache to all parallel sequences
+        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+        for (var i = 1; i < n_parallel; ++i)
+        {
+            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+        }
+
+        if (n_parallel > 1)
+        {
+            Console.WriteLine();
+            Console.WriteLine($"generating {n_parallel} sequences...");
+        }
+
+        // remember the batch index of the last token for each parallel sequence
+        // we need this to determine which logits to sample from
+        List<int> i_batch = new();
+        for (var i = 0; i < n_parallel; i++)
+            i_batch.Add(batch.NativeBatch.n_tokens - 1);
+
+        var n_cur = batch.NativeBatch.n_tokens;
+        var n_decode = 0;
+
+        var streams = new List<int>[n_parallel];
+        for (var i = 0; i < n_parallel; i++)
+            streams[i] = new();
+
+        var eos = model.EndOfSentenceToken;
+        var nl = model.NewlineToken;
+
+        var timer = new Stopwatch();
+        timer.Start();
+        while (n_cur <= n_len)
+        {
+            batch.LLamaBatchClear();
+
+            for (var i = 0; i < n_parallel; i++)
+            {
+                // Skip completed streams
+                if (i_batch[i] < 0)
+                    continue;
+
+                var n_vocab = model.VocabCount;
+                LLamaTokenDataArray candidates;
+                unsafe
+                {
+                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
+                }
+
+                candidates.TopK(context.NativeHandle, top_k);
+                candidates.TopP(context.NativeHandle, top_p);
+                candidates.Temperature(context.NativeHandle, temp);
+                var new_token_id = candidates.SampleToken(context.NativeHandle);
+
+                if (new_token_id == eos || new_token_id == nl)
+                {
+                    i_batch[i] = -1;
+                    Console.WriteLine($"Completed Stream {i} early");
+                    continue;
+                }
+
+                streams[i].Add(new_token_id);
+
+                i_batch[i] = batch.NativeBatch.n_tokens;
+
+                // push this new token for next evaluation
+                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+
+                n_decode++;
+            }
+
+            // all streams are finished
+            if (batch.NativeBatch.n_tokens == 0)
+            {
+                break;
+            }
+
+            n_cur++;
+
+            // evaluate the current batch with the transformer model
+            if (context.NativeHandle.Decode(batch) != 0)
+            {
+                await Console.Error.WriteLineAsync("failed to eval");
+                return;
+            }
+        }
+
+        timer.Stop();
+        Console.ForegroundColor = ConsoleColor.Yellow;
+        Console.WriteLine();
+        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
+        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
+
+        var index = 0;
+        foreach (var stream in streams)
+        {
+            var text = context.DeTokenize(stream);
+
+            Console.ForegroundColor = ConsoleColor.Green;
+            Console.Write($"{index++}. {prompt}");
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine(text);
+        }
+    }
+}
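
Note: the TopK / TopP / Temperature / SampleToken calls in the loop above are the "Improved Sampling API" named in the commit title. The fragment below is not part of this commit; it is a minimal sketch (the helper name, signature and sampler values are illustrative) that isolates the same pipeline for a single logits position. It needs unsafe code, which the csproj change above enables.

using System;
using LLama;
using LLama.Native;

static class SamplingSketch
{
    // Hypothetical helper mirroring the loop above: wrap the logits for one batch
    // position in a LLamaTokenDataArray, apply the samplers in order, pick a token.
    public static unsafe int SampleNext(LLamaContext context, LLamaWeights model, int logitsIndex)
    {
        var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, logitsIndex);
        var candidates = LLamaTokenDataArray.Create(new Span<float>(logits, model.VocabCount));

        candidates.TopK(context.NativeHandle, 40);            // keep the 40 most likely tokens
        candidates.TopP(context.NativeHandle, 0.95f);         // nucleus sampling
        candidates.Temperature(context.NativeHandle, 0.8f);   // rescale the remaining candidates
        return candidates.SampleToken(context.NativeHandle);  // draw one token id
    }
}

Each sampler call filters or rescales the candidate array in place, so the order of the calls matters.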

LLama.Examples/NewVersion/SemanticKernelChat.cs

Lines changed: 1 addition & 4 deletions
@@ -14,10 +14,7 @@ public static async Task Run()
         var modelPath = Console.ReadLine();

         // Load weights into memory
-        var parameters = new ModelParams(modelPath)
-        {
-            Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
-        };
+        var parameters = new ModelParams(modelPath);
         using var model = LLamaWeights.LoadFromFile(parameters);
         using var context = model.CreateContext(parameters);
         var ex = new InteractiveExecutor(context);

LLama.Examples/NewVersion/SemanticKernelPrompt.cs

Lines changed: 1 addition & 4 deletions
@@ -16,10 +16,7 @@ public static async Task Run()
         var modelPath = Console.ReadLine();

         // Load weights into memory
-        var parameters = new ModelParams(modelPath)
-        {
-            Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
-        };
+        var parameters = new ModelParams(modelPath);
         using var model = LLamaWeights.LoadFromFile(parameters);
         var ex = new StatelessExecutor(model, parameters);


LLama.Examples/NewVersion/TalkToYourself.cs

Lines changed: 1 addition & 4 deletions
@@ -13,10 +13,7 @@ public static async Task Run()
         var modelPath = Console.ReadLine();

         // Load weights into memory
-        var @params = new ModelParams(modelPath)
-        {
-            Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
-        };
+        var @params = new ModelParams(modelPath);
         using var weights = LLamaWeights.LoadFromFile(@params);

         // Create 2 contexts sharing the same weights

LLama.Examples/NewVersion/TestRunner.cs

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@ public static async Task Run()
             Console.WriteLine("12: Semantic Kernel Chat.");
             Console.WriteLine("13: Semantic Kernel Memory.");
             Console.WriteLine("14: Coding Assistant.");
+            Console.WriteLine("15: Batch Decoding.");

             while (true)
             {
@@ -88,6 +89,10 @@ public static async Task Run()
                 {
                     await CodingAssistant.Run();
                 }
+                else if (choice == 15)
+                {
+                    await BatchedDecoding.Run();
+                }
                 else
                 {
                     Console.WriteLine("Cannot parse your choice. Please select again.");

LLama.Unittest/StatelessExecutorTest.cs

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,4 @@
+using System.Diagnostics;
 using LLama.Common;
 using Xunit.Abstractions;

@@ -34,10 +35,17 @@ public async Task Stateless()
         const string question = "Question. what is a cat?\nAnswer: ";
         var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

+        var timer = new Stopwatch();
+        timer.Start();
+
         var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
         var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

+        timer.Stop();
+        _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");
+
         _testOutputHelper.WriteLine(result1);
+        _testOutputHelper.WriteLine(result2);

         // Check that it produced the exact same result both times
         Assert.Equal(result1, result2);

LLama.Web/Common/InferenceOptions.cs

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ public class InferenceOptions
         /// <summary>
         /// Sequences where the model will stop generating further tokens.
        /// </summary>
-        public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+        public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
         /// <summary>
         /// path to file for saving/loading model eval state
         /// </summary>
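
The AntiPrompts property narrows from IEnumerable<string> to IReadOnlyList<string>. Call sites that assign an array keep compiling, because string[] implements IReadOnlyList<string>; a minimal illustration (not part of this diff, mirroring the InferenceParams usage in the test change above):

using LLama.Common;

// string[] satisfies IReadOnlyList<string>, so array initialisers still work unchanged
var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    AntiPrompts = new[] { "." }
};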

LLama.Web/Common/ModelOptions.cs

Lines changed: 2 additions & 2 deletions
@@ -111,12 +111,12 @@ public class ModelOptions
         /// <summary>
         /// RoPE base frequency
         /// </summary>
-        public float RopeFrequencyBase { get; set; } = 10000.0f;
+        public float? RopeFrequencyBase { get; set; }

         /// <summary>
         /// RoPE frequency scaling factor
         /// </summary>
-        public float RopeFrequencyScale { get; set; } = 1.0f;
+        public float? RopeFrequencyScale { get; set; }

         /// <summary>
         /// Use experimental mul_mat_q kernels

LLama/Abstractions/IContextParams.cs

Lines changed: 4 additions & 4 deletions
@@ -39,14 +39,14 @@ public interface IContextParams
     bool EmbeddingMode { get; set; }

     /// <summary>
-    /// RoPE base frequency
+    /// RoPE base frequency (null to fetch from the model)
     /// </summary>
-    float RopeFrequencyBase { get; set; }
+    float? RopeFrequencyBase { get; set; }

     /// <summary>
-    /// RoPE frequency scaling factor
+    /// RoPE frequency scaling factor (null to fetch from the model)
     /// </summary>
-    float RopeFrequencyScale { get; set; }
+    float? RopeFrequencyScale { get; set; }

     /// <summary>
     /// Use experimental mul_mat_q kernels
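
With RopeFrequencyBase and RopeFrequencyScale now nullable, leaving them unset means the values are read from the model file rather than the previous hard-coded defaults (10000.0 and 1.0). A minimal sketch, not part of this diff and assuming ModelParams implements IContextParams as elsewhere in LLamaSharp:

using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf")
{
    // Leave RopeFrequencyBase / RopeFrequencyScale as null to use the model's own values.
    // Set them only to deliberately override what the model file specifies, e.g.:
    // RopeFrequencyBase = 10000.0f,
    // RopeFrequencyScale = 0.5f,
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);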
