
Commit fcb8b89 (1 parent: b14fc7c)

Adding simple batch example

File tree: 2 files changed, +178 −0 lines

LLama.Examples/ExampleRunner.cs

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ public class ExampleRunner
         { "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
         { "Semantic Kernel: Chat", SemanticKernelChat.Run },
         { "Semantic Kernel: Store", SemanticKernelMemory.Run },
+        { "Batched Executor: Simple", BatchedExecutorSimple.Run },
         { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
LLama.Examples/Examples/BatchedExecutorSimple.cs (new file)

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating replies to several different prompts in parallel, batched together in one executor with a shared cache
/// </summary>
public class BatchedExecutorSimple
{
    /// <summary>
    /// Set total length of the sequence to generate
    /// </summary>
    private const int TokenCount = 72;

    public static async Task Run()
    {
        // Load model weights
        var parameters = new ModelParams(UserSettings.GetModelPath());
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // We'll need this to check whether a sampled token marks the end of generation
        var modelTokens = executor.Context.NativeHandle.ModelHandle.Tokens;

        // Print some info
        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        var messages = new[]
        {
            "What's 2+2?",
            "Where is the coldest part of Texas?",
            "What's the capital of France?",
            "What's a one word name for a food item with ground beef patties on a bun?",
            "What are two toppings for a pizza?",
            "What American football play are you calling on a 3rd and 8 from our own 25?",
            "What liquor should I add to egg nog?",
            "I have two sons, Bert and Ernie. What should I name my daughter?",
            "What day comes after Friday?",
            "What color shoes should I wear with dark blue pants?",
        };

        var conversations = new List<ConversationData>();
        foreach (var message in messages)
        {
            // Apply the model's prompt template to our question and system prompt
            var template = new LLamaTemplate(model);
            template.Add("system", "I am a helpful bot that returns short and concise answers. I include a ten word description of my reasoning when I finish.");
            template.Add("user", message);
            template.AddAssistant = true;
            var templatedMessage = Encoding.UTF8.GetString(template.Apply());

            // Create a new conversation and prompt it. Include special and BOS tokens because we are using the template.
            var conversation = executor.Create();
            conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));

            conversations.Add(new ConversationData {
                Prompt = message,
                Conversation = conversation,
                Sampler = new GreedySamplingPipeline(),
                Decoder = new StreamingTokenDecoder(executor.Context)
            });
        }

        var table = BuildTable(conversations);
        await AnsiConsole.Live(table).StartAsync(async ctx =>
        {
            for (var i = 0; i < TokenCount; i++)
            {
                // Run inference for all conversations in the batch which have pending tokens.
                var decodeResult = await executor.Infer();
                if (decodeResult == DecodeResult.NoKvSlot)
                    throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increasing the context size.");
                if (decodeResult == DecodeResult.Error)
                    throw new Exception("Unknown error occurred while inferring.");

                foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
                {
                    if (conversationData.Conversation.RequiresSampling == false) continue;

                    // Sample a single token for this conversation, passing its sample index within the shared batch
                    var token = conversationData.Sampler.Sample(
                        executor.Context.NativeHandle,
                        conversationData.Conversation.GetSampleIndex());

                    if (modelTokens.IsEndOfGeneration(token))
                    {
                        conversationData.MarkComplete();
                    }
                    else
                    {
                        // It isn't the end of generation, so add this token to the decoder and then add the decoded text to our tracked data
                        conversationData.Decoder.Add(token);
                        conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));

                        // Add the token to the conversation
                        conversationData.Conversation.Prompt(token);
                    }
                }

                // Render the current state
                table = BuildTable(conversations);
                ctx.UpdateTarget(table);

                if (conversations.All(c => c.IsComplete))
                {
                    break;
                }
            }

            // If we ran out of tokens before completing, just mark the remaining conversations as complete for rendering purposes.
            foreach (var data in conversations.Where(i => i.IsComplete == false))
            {
                data.MarkComplete();
            }

            table = BuildTable(conversations);
            ctx.UpdateTarget(table);
        });
    }

    /// <summary>
    /// Helper to build a table to display the conversations.
    /// </summary>
    private static Table BuildTable(List<ConversationData> conversations)
    {
        var table = new Table()
            .RoundedBorder()
            .AddColumns("Prompt", "Response");

        foreach (var data in conversations)
        {
            table.AddRow(data.Prompt.EscapeMarkup(), data.AnswerMarkdown);
        }

        return table;
    }
}

public class ConversationData
{
    public required string Prompt { get; init; }
    public required Conversation Conversation { get; init; }
    public required BaseSamplingPipeline Sampler { get; init; }
    public required StreamingTokenDecoder Decoder { get; init; }

    public string AnswerMarkdown => IsComplete
        ? $"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
        : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";

    public bool IsComplete { get; private set; }

    // We only track the answer in two parts so the latest token can be rendered differently from the rest.
    private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty);

    public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText);

    public void MarkComplete()
    {
        IsComplete = true;
        if (Conversation.IsDisposed == false)
        {
            // Clean up the conversation and sampler to release memory for further inference.
            // Real-life usage would protect against these two being referenced after being disposed.
            Conversation.Dispose();
            Sampler.Dispose();
        }
    }
}
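
The core pattern in the file above is the per-step loop: a single Infer() call decodes pending tokens for every conversation in the batch, then each conversation with logits ready is sampled at its own sample index, and the sampled token is fed back with Prompt(token) so it is decoded on the next step. A condensed sketch of just that loop, using only the calls that appear in the example (table rendering and error handling omitted), might look like this:

    // Condensed sketch of the decode/sample/prompt loop from BatchedExecutorSimple above.
    for (var i = 0; i < TokenCount; i++)
    {
        // One Infer() call advances every conversation in the batch that has pending tokens.
        var decodeResult = await executor.Infer();
        // (check decodeResult for DecodeResult.NoKvSlot / DecodeResult.Error as in the full example)

        foreach (var data in conversations.Where(c => !c.IsComplete && c.Conversation.RequiresSampling))
        {
            // Each conversation samples from its own logits, located via its sample index in the shared batch.
            var token = data.Sampler.Sample(executor.Context.NativeHandle, data.Conversation.GetSampleIndex());

            if (modelTokens.IsEndOfGeneration(token))
                data.MarkComplete();
            else
                data.Conversation.Prompt(token); // feed the token back so it is decoded on the next Infer()
        }
    }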
