@@ -12,13 +12,15 @@ namespace LLama.Examples.Examples;
 /// </summary>
 public class BatchedExecutorGuidance
 {
-    private const int n_len = 32;
+    /// <summary>
+    /// Set how many tokens should be generated
+    /// </summary>
+    private const int TokenCount = 32;
 
     public static async Task Run()
     {
-        string modelPath = UserSettings.GetModelPath();
-
-        var parameters = new ModelParams(modelPath);
+        // Load model weights
+        var parameters = new ModelParams(UserSettings.GetModelPath());
         using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 
         var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
@@ -29,7 +31,7 @@ public static async Task Run()
         using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
-        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Load the two prompts into two conversations
@@ -48,30 +50,32 @@ await AnsiConsole
         using var unguided = guided.Fork();
 
         // Run inference loop
-        var unguidedSampler = new GuidedSampler(null, weight);
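+        // The unguided conversation samples with the stock pipeline, so any difference
+        // between the two output streams below comes from the guidance step alone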
+        var unguidedSampler = new DefaultSamplingPipeline();
         var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
         var guidedSampler = new GuidedSampler(guidance, weight);
         var guidedDecoder = new StreamingTokenDecoder(executor.Context);
         await AnsiConsole
             .Progress()
             .StartAsync(async progress =>
             {
-                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
+                var reporter = progress.AddTask("Running Inference", maxValue: TokenCount);
 
-                for (var i = 0; i < n_len; i++)
+                for (var i = 0; i < TokenCount; i++)
                 {
                     if (i != 0)
                         await executor.Infer();
 
                     // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                     // guidance. This serves as a comparison to show the effect of guidance.
-                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
+                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), []);
                     unguidedDecoder.Add(u);
                     unguided.Prompt(u);
 
                     // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
-                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
-                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
+                    // to steer the conversation. See how this is done in GuidedSampler.ProcessTokenDataArray (bottom of this file).
+                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), []);
                     guidedDecoder.Add(g);
 
                     // Use this token to advance both guided _and_ guidance. Keeping them in sync (except for the initial prompt).
@@ -91,37 +93,37 @@ await AnsiConsole
         AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
     }
 
-    private class GuidedSampler(Conversation? guidance, float weight)
+    private class GuidedSampler(Conversation guidance, float weight)
         : BaseSamplingPipeline
     {
+        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            // Get the logits generated by the guidance sequences
+            var guidanceLogits = guidance.Sample();
+
+            // Modify these logits based on the guidance logits
+            candidates.Guidance(ctx, guidanceLogits, weight);
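+            // A larger weight steers sampling more strongly towards the guidance conversation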
+
+            // Basic sampling
+            candidates.Temperature(ctx, 0.8f);
+            candidates.TopK(ctx, 25);
+            return candidates.SampleToken(ctx);
+        }
+
         public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
         {
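+            // Nothing to do - this pipeline keeps no per-token sampling state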
         }
-
+
         public override ISamplingPipeline Clone()
         {
             throw new NotSupportedException();
         }
-
+
         protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
         {
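+            // Nothing to do - guidance is applied in ProcessTokenDataArray instead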
         }
-
-        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
-        {
-            if (guidance != null)
-            {
-                // Get the logits generated by the guidance sequences
-                var guidanceLogits = guidance.Sample();
-
-                // Modify these logits based on the guidance logits
-                candidates.Guidance(ctx, guidanceLogits, weight);
-            }
-
-            candidates.Temperature(ctx, 0.8f);
-            candidates.TopK(ctx, 25);
-
-            return candidates.SampleToken(ctx);
-        }
     }
 }