Skip to content

Make Vocabulary's properties be initialized only once, on creation #1110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions LLama/Native/LLamaToken.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Diagnostics;
using System.Linq;

namespace LLama.Native;

Expand Down Expand Up @@ -98,10 +99,7 @@ public bool IsControl(SafeLlamaModelHandle model)
/// <returns></returns>
public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab)
{
unsafe
{
return LLamaVocabNative.llama_vocab_is_control(vocab.VocabNative, this);
}
return vocab.ControlTokens.Contains(this);
}

/// <summary>
Expand All @@ -121,10 +119,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model)
/// <returns></returns>
public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab)
{
unsafe
{
return LLamaVocabNative.llama_vocab_is_eog(vocab.VocabNative, this);
}
return vocab.EOGTokens.Contains(this);
}

/// <inheritdoc />
Expand Down
252 changes: 70 additions & 182 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using LLama.Exceptions;

Expand Down Expand Up @@ -631,34 +632,65 @@ public sealed class Vocabulary

internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model);

/// <summary>
/// Map of each token in this vocabulary to its string representation
/// </summary>
internal readonly IReadOnlyDictionary<LLamaToken, string> TokenToString;

/// <summary>
/// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc)
/// </summary>
internal readonly IReadOnlyList<LLamaToken> EOGTokens;

/// <summary>
/// Contains unique tokens that exist for inference control rather than text output
/// </summary>
internal readonly IReadOnlyList<LLamaToken> ControlTokens;

internal Vocabulary(SafeLlamaModelHandle model)
{
_model = model;
}
TokenToString = GetVocabCache();

private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
{
if (!token.HasValue)
return null;

// Try to convert using a fixed size buffer
const int buffSize = 32;
Span<byte> buff = stackalloc byte[buffSize];
var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);

// Negative indicates that there was no result
if (tokenLength <= 0)
return null;

// if the original buffer wasn't large enough, try again with one that's the right size
if (tokenLength > buffSize)
// Cache the various properties that llama.cpp API exposes about the vocab
unsafe
{
buff = stackalloc byte[(int)tokenLength];
_ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
var vocabNative = llama_model_get_vocab(_model);
Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative);
Type = LLamaVocabNative.llama_vocab_type(vocabNative);
BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative));
EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative));
Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative));
Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative));
SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative));
InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative));
InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative));
InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative));
InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative));
InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative));
InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative));
EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative));
DecoderStartToken = Normalize(llama_model_decoder_start_token(_model));
ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative);
ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative);

EOGTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).ToList();
ControlTokens = TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).ToList();
}
}

var slice = buff.Slice(0, (int)tokenLength);
return Encoding.UTF8.GetStringFromSpan(slice);
private Dictionary<LLamaToken, string> GetVocabCache()
{
var decoder = Encoding.UTF8.GetDecoder();
var (bytesArr, charsArr) = (new byte[1024], new char[1024]);
return Enumerable.Range(0, Count).ToDictionary(
keySelector: i => (LLamaToken) i,
elementSelector: i =>
{
decoder.Convert(bytesArr, 0, (int) _model.TokenToSpan(i, bytesArr), charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _);
return string.Join("", charsArr.Take(charsUsed));
}
);
}

private static LLamaToken? Normalize(LLamaToken token)
Expand All @@ -669,232 +701,88 @@ internal Vocabulary(SafeLlamaModelHandle model)
/// <summary>
/// Total number of tokens in this vocabulary
/// </summary>
public int Count
{
get
{
unsafe
{
return LLamaVocabNative.llama_vocab_n_tokens(VocabNative);
}
}
}
public int Count { get; }

/// <summary>
/// Get the the type of this vocabulary
/// </summary>
public LLamaVocabType Type
{
get
{
unsafe
{
return LLamaVocabNative.llama_vocab_type(VocabNative);
}
}
}
public LLamaVocabType Type { get; }

/// <summary>
/// Get the Beginning of Sentence token for this model
/// </summary>
public LLamaToken? BOS
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative));
}
}
}
public LLamaToken? BOS { get; }

/// <summary>
/// Get the End of Sentence token for this model
/// </summary>
public LLamaToken? EOS
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative));
}
}
}
public LLamaToken? EOS { get; }

/// <summary>
/// Get the newline token for this model
/// </summary>
public LLamaToken? Newline
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative));
}
}
}
public LLamaToken? Newline { get; }

/// <summary>
/// Get the padding token for this model
/// </summary>
public LLamaToken? Pad
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative));
}
}
}
public LLamaToken? Pad { get; }

/// <summary>
/// Get the sentence separator token for this model
/// </summary>
public LLamaToken? SEP
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative));
}
}
}
public LLamaToken? SEP { get; }

/// <summary>
/// Codellama beginning of infill prefix
/// </summary>
public LLamaToken? InfillPrefix
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative));
}
}
}
public LLamaToken? InfillPrefix { get; }

/// <summary>
/// Codellama beginning of infill middle
/// </summary>
public LLamaToken? InfillMiddle
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative));
}
}
}
public LLamaToken? InfillMiddle { get; }

/// <summary>
/// Codellama beginning of infill suffix
/// </summary>
public LLamaToken? InfillSuffix
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative));
}
}
}
public LLamaToken? InfillSuffix { get; }

/// <summary>
/// Codellama pad
/// </summary>
public LLamaToken? InfillPad
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative));
}
}
}
public LLamaToken? InfillPad { get; }

/// <summary>
/// Codellama rep
/// </summary>
public LLamaToken? InfillRep
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative));
}
}
}
public LLamaToken? InfillRep { get; }

/// <summary>
/// Codellama rep
/// </summary>
public LLamaToken? InfillSep
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative));
}
}
}
public LLamaToken? InfillSep { get; }

/// <summary>
/// end-of-turn token
/// </summary>
public LLamaToken? EOT
{
get
{
unsafe
{
return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative));
}
}
}
public LLamaToken? EOT { get; }

/// <summary>
/// For encoder-decoder models, this function returns id of the token that must be provided
/// to the decoder to start generating output sequence.
/// </summary>
public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model));
public LLamaToken? DecoderStartToken { get; }

/// <summary>
/// Check if the current model requires a BOS token added
/// </summary>
public bool ShouldAddBOS
{
get
{
unsafe
{
return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model));
}
}
}
public bool ShouldAddBOS { get; }

/// <summary>
/// Check if the current model requires a EOS token added
/// </summary>
public bool ShouldAddEOS
{
get
{
unsafe
{
return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model));
}
}
}
public bool ShouldAddEOS { get; }
}
}
}
Loading