using System.Diagnostics;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this currently uses the low-level API directly; future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
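    // n_parallel: how many completions to generate from the shared prompt.
    // n_len: total sequence length (prompt + generated tokens) per completion.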
    private const int n_parallel = 8;
    private const int n_len = 32;

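    // Sampling settings: keep the top_k most likely tokens, reduce to the smallest set whose
    // cumulative probability reaches top_p, then apply temperature before sampling.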
    private const int top_k = 80;
    private const float top_p = 0.8f;
    private const float temp = 0.5f;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Load model
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        // Tokenize prompt
        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
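
        // KV cache budget: the prompt is stored once and shared by every sequence, then each
        // of the n_parallel sequences generates up to (n_len - prompt length) further tokens.
        // For example, a 6-token prompt with n_len = 32 and n_parallel = 8 needs 6 + 26 * 8 = 214 cells.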
        var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

        // Create a context
        parameters.ContextSize = (uint)model.ContextSize;
        parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
        using var context = model.CreateContext(parameters);

        var n_ctx = context.ContextSize;

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if (n_kv_req > n_ctx)
        {
            await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough");
            await Console.Error.WriteLineAsync("either reduce n_parallel or increase n_ctx");
            return;
        }

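        // Allocate a native batch big enough for the largest single submission: the whole prompt
        // on the first decode, or one token per parallel sequence afterwards. The trailing
        // arguments follow llama_batch_init's convention: 0 = token ids rather than embeddings,
        // and 1 sequence id per token.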
        using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

        // evaluate the initial prompt
        for (var i = 0; i < prompt_tokens.Length; i++)
            batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
        Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

        // llama_decode will output logits only for the last token of the prompt
        unsafe
        {
            batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
        }

        if (context.NativeHandle.Decode(batch) != 0)
        {
            await Console.Error.WriteLineAsync("llama_decode failed");
            return;
        }

        // assign the system KV cache to all parallel sequences
        // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
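        // (sequence 0 already holds the prompt from the decode above, so copying starts at 1)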
        for (var i = 1; i < n_parallel; ++i)
        {
            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
        }

        if (n_parallel > 1)
        {
            Console.WriteLine();
            Console.WriteLine($"generating {n_parallel} sequences...");
        }

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        List<int> i_batch = new();
        for (var i = 0; i < n_parallel; i++)
            i_batch.Add(batch.NativeBatch.n_tokens - 1);

        var n_cur = batch.NativeBatch.n_tokens;
        var n_decode = 0;

        var streams = new List<int>[n_parallel];
        for (var i = 0; i < n_parallel; i++)
            streams[i] = new();

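        // A stream ends early when it samples either the end-of-sequence or the newline token.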
        var eos = model.EndOfSentenceToken;
        var nl = model.NewlineToken;

        var timer = new Stopwatch();
        timer.Start();
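
        // Main decode loop: each iteration samples one new token for every live stream, gathers
        // them into a single batch, and decodes all sequences in one call to the model.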
        while (n_cur <= n_len)
        {
            batch.LLamaBatchClear();

            for (var i = 0; i < n_parallel; i++)
            {
                // Skip completed streams
                if (i_batch[i] < 0)
                    continue;

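                // Build a candidate array from the logits at this stream's position in the previous batch.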
                var n_vocab = model.VocabCount;
                LLamaTokenDataArray candidates;
                unsafe
                {
                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
                }

                candidates.TopK(context.NativeHandle, top_k);
                candidates.TopP(context.NativeHandle, top_p);
                candidates.Temperature(context.NativeHandle, temp);
                var new_token_id = candidates.SampleToken(context.NativeHandle);

                if (new_token_id == eos || new_token_id == nl)
                {
                    i_batch[i] = -1;
                    Console.WriteLine($"Completed Stream {i} early");
                    continue;
                }

                streams[i].Add(new_token_id);

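                // Record where this stream's token will sit in the new batch, so the next
                // iteration samples from the correct logits.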
                i_batch[i] = batch.NativeBatch.n_tokens;

                // push this new token for next evaluation
                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

                n_decode++;
            }

            // all streams are finished
            if (batch.NativeBatch.n_tokens == 0)
            {
                break;
            }

            n_cur++;

            // evaluate the current batch with the transformer model
            if (context.NativeHandle.Decode(batch) != 0)
            {
                await Console.Error.WriteLineAsync("failed to eval");
                return;
            }
        }

        timer.Stop();
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine();
        Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
        Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

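        // Print each stream: the shared prompt in green, then that stream's generated continuation in red.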
        var index = 0;
        foreach (var stream in streams)
        {
            var text = context.DeTokenize(stream);

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write($"{index++}. {prompt}");
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine(text);
        }
    }
}