using System.Diagnostics.CodeAnalysis;
using System.Text;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating replies to several independent prompts in parallel, using a single BatchedExecutor so all of the conversations are evaluated together in one batch
/// </summary>
public class BatchedExecutorSimple
{
    /// <summary>
    /// Maximum number of tokens to generate for each conversation
    /// </summary>
    private const int TokenCount = 72;

    public static async Task Run()
    {
        // Load model weights
        var parameters = new ModelParams(UserSettings.GetModelPath());
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // We'll need this to check whether a sampled token marks the end of generation
        var modelTokens = executor.Context.NativeHandle.ModelHandle.Tokens;

        // Print some info
        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

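        // A set of independent questions; each one becomes its own conversation in the batch.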
        var messages = new[]
        {
            "What's 2+2?",
            "Where is the coldest part of Texas?",
            "What's the capital of France?",
            "What's a one word name for a food item with ground beef patties on a bun?",
            "What are two toppings for a pizza?",
            "What American football play are you calling on a 3rd and 8 from our own 25?",
            "What liquor should I add to egg nog?",
            "I have two sons, Bert and Ernie. What should I name my daughter?",
            "What day comes after Friday?",
            "What color shoes should I wear with dark blue pants?",
        };

        var conversations = new List<ConversationData>();
        foreach (var message in messages)
        {
            // Apply the model's prompt template to our question and system prompt
            var template = new LLamaTemplate(model);
            template.Add("system", "I am a helpful bot that returns short and concise answers. I include a ten word description of my reasoning when I finish.");
            template.Add("user", message);
            template.AddAssistant = true;
            var templatedMessage = Encoding.UTF8.GetString(template.Apply());

            // Create a new conversation and prompt it. Include special tokens and BOS because we are using the template.
            var conversation = executor.Create();
            conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));

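            // Track everything needed to sample, decode and display this conversation's progress.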
            conversations.Add(new ConversationData {
                Prompt = message,
                Conversation = conversation,
                Sampler = new GreedySamplingPipeline(),
                Decoder = new StreamingTokenDecoder(executor.Context)
            });
        }

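        // Show a live-updating table in the console while tokens are generated.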
        var table = BuildTable(conversations);
        await AnsiConsole.Live(table).StartAsync(async ctx =>
        {
            for (var i = 0; i < TokenCount; i++)
            {
                // Run inference for all conversations in the batch which have pending tokens.
                var decodeResult = await executor.Infer();
                if (decodeResult == DecodeResult.NoKvSlot)
                    throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increasing the context size.");
                if (decodeResult == DecodeResult.Error)
                    throw new Exception("Unknown error occurred while inferring.");

                foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
                {
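                    // Skip conversations that have nothing to sample after this inference step.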
                    if (conversationData.Conversation.RequiresSampling == false) continue;

                    // Sample a single token for this conversation, using its sample index into the batch
                    var token = conversationData.Sampler.Sample(
                        executor.Context.NativeHandle,
                        conversationData.Conversation.GetSampleIndex());

                    if (modelTokens.IsEndOfGeneration(token))
                    {
                        conversationData.MarkComplete();
                    }
                    else
                    {
                        // It isn't the end of generation, so add this token to the decoder and append the decoded text to our tracked data
                        conversationData.Decoder.Add(token);
                        conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));

                        // Prompt the conversation with this token, so it is evaluated in the next inference batch
                        conversationData.Conversation.Prompt(token);
                    }
                }

                // Render the current state
                table = BuildTable(conversations);
                ctx.UpdateTarget(table);

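                // Stop early once every conversation has finished.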
                if (conversations.All(c => c.IsComplete))
                {
                    break;
                }
            }

            // If we ran out of tokens before every conversation finished, mark the rest as complete for rendering purposes.
            foreach (var data in conversations.Where(i => i.IsComplete == false))
            {
                data.MarkComplete();
            }

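            // Render the final state of the table.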
            table = BuildTable(conversations);
            ctx.UpdateTarget(table);
        });
    }

    /// <summary>
    /// Helper to build a table to display the conversations.
    /// </summary>
    private static Table BuildTable(List<ConversationData> conversations)
    {
        var table = new Table()
            .RoundedBorder()
            .AddColumns("Prompt", "Response");

        foreach (var data in conversations)
        {
            table.AddRow(data.Prompt.EscapeMarkup(), data.AnswerMarkdown);
        }

        return table;
    }
}

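/// <summary>
/// Per-conversation state: the original prompt, the live Conversation, its sampling pipeline, and a streaming decoder for its output.
/// </summary>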
public class ConversationData
{
    public required string Prompt { get; init; }
    public required Conversation Conversation { get; init; }
    public required BaseSamplingPipeline Sampler { get; init; }
    public required StreamingTokenDecoder Decoder { get; init; }

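    // Spectre.Console markup for the answer: green once complete; while streaming, the accumulated text is grey and the latest token is white.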
    public string AnswerMarkdown => IsComplete
        ? $"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
        : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";

    public bool IsComplete { get; private set; }

    // We are only keeping track of the answer in two parts to render them differently.
    private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty);

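    // Fold the previous latest token into the accumulated message and store the new text as the latest token.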
    public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText);

    public void MarkComplete()
    {
        IsComplete = true;
        if (Conversation.IsDisposed == false)
        {
            // Clean up the conversation and sampler to release memory for further inference.
            // Real-life usage would guard against these two being referenced after being disposed.
            Conversation.Dispose();
            Sampler.Dispose();
        }
    }
}