Skip to content

Commit 1f97ad8

Browse files
authored
Merge pull request #333 from AsakusaRinne/master
feat: allow customized search path for native library loading.
2 parents 6dfda5e + ffc347a commit 1f97ad8

File tree

3 files changed

+100
-17
lines changed

3 files changed

+100
-17
lines changed

LLama/ChatSession.cs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,24 @@ public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams?
152152
foreach (var inputTransform in InputTransformPipeline)
153153
prompt = inputTransform.Transform(prompt);
154154

155-
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
156-
157-
if (_executor is InteractiveExecutor executor)
155+
// TODO: need to be refactored.
156+
if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
158157
{
159-
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
160-
prompt = state.IsPromptRun
161-
? HistoryTransform.HistoryToText(History)
162-
: prompt;
158+
History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
159+
var converted_prompt = HistoryTransform.HistoryToText(History);
160+
// Avoid missing anti-prompt.
161+
if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
162+
{
163+
prompt = converted_prompt.Trim();
164+
}
165+
else
166+
{
167+
prompt = converted_prompt;
168+
}
169+
}
170+
else
171+
{
172+
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
163173
}
164174

165175
StringBuilder sb = new();

LLama/Native/NativeApi.Load.cs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Generic;
55
using System.Diagnostics;
66
using System.IO;
7+
using System.Linq;
78
using System.Runtime.InteropServices;
89
using System.Text.Json;
910

@@ -258,6 +259,7 @@ private static IntPtr TryLoadLibrary()
258259
enableLogging = configuration.Logging;
259260
// We move the flag to avoid loading library when the variable is called else where.
260261
NativeLibraryConfig.LibraryHasLoaded = true;
262+
Log(configuration.ToString(), LogLevel.Information);
261263

262264
if (!string.IsNullOrEmpty(configuration.Path))
263265
{
@@ -273,26 +275,30 @@ private static IntPtr TryLoadLibrary()
273275

274276
var libraryTryLoadOrder = GetLibraryTryOrder(configuration);
275277

278+
string[] preferredPaths = configuration.SearchDirectories;
276279
string[] possiblePathPrefix = new string[] {
277280
System.AppDomain.CurrentDomain.BaseDirectory,
278281
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
279282
};
280283

281284
var tryFindPath = (string filename) =>
282285
{
283-
int i = 0;
284-
while (!File.Exists(filename))
286+
foreach(var path in preferredPaths)
285287
{
286-
if (i < possiblePathPrefix.Length)
288+
if (File.Exists(Path.Combine(path, filename)))
287289
{
288-
filename = Path.Combine(possiblePathPrefix[i], filename);
289-
i++;
290+
return Path.Combine(path, filename);
290291
}
291-
else
292+
}
293+
294+
foreach(var path in possiblePathPrefix)
295+
{
296+
if (File.Exists(Path.Combine(path, filename)))
292297
{
293-
break;
298+
return Path.Combine(path, filename);
294299
}
295300
}
301+
296302
return filename;
297303
};
298304

LLama/Native/NativeLibraryConfig.cs

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
24

35
namespace LLama.Native
46
{
@@ -27,6 +29,10 @@ public sealed class NativeLibraryConfig
2729
private bool _allowFallback = true;
2830
private bool _skipCheck = false;
2931
private bool _logging = false;
32+
/// <summary>
33+
/// search directory -> priority level, 0 is the lowest.
34+
/// </summary>
35+
private List<string> _searchDirectories = new List<string>();
3036

3137
private static void ThrowIfLoaded()
3238
{
@@ -120,13 +126,50 @@ public NativeLibraryConfig WithLogs(bool enable = true)
120126
return this;
121127
}
122128

129+
/// <summary>
130+
/// Add self-defined search directories. Note that the file stucture of the added
131+
/// directories must be the same as the default directory. Besides, the directory
132+
/// won't be used recursively.
133+
/// </summary>
134+
/// <param name="directories"></param>
135+
/// <returns></returns>
136+
public NativeLibraryConfig WithSearchDirectories(IEnumerable<string> directories)
137+
{
138+
ThrowIfLoaded();
139+
140+
_searchDirectories.AddRange(directories);
141+
return this;
142+
}
143+
144+
/// <summary>
145+
/// Add self-defined search directories. Note that the file stucture of the added
146+
/// directories must be the same as the default directory. Besides, the directory
147+
/// won't be used recursively.
148+
/// </summary>
149+
/// <param name="directory"></param>
150+
/// <returns></returns>
151+
public NativeLibraryConfig WithSearchDirectory(string directory)
152+
{
153+
ThrowIfLoaded();
154+
155+
_searchDirectories.Add(directory);
156+
return this;
157+
}
158+
123159
internal static Description CheckAndGatherDescription()
124160
{
125161
if (Instance._allowFallback && Instance._skipCheck)
126162
{
127163
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
128164
}
129-
return new Description(Instance._libraryPath, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, Instance._skipCheck, Instance._logging);
165+
return new Description(
166+
Instance._libraryPath,
167+
Instance._useCuda,
168+
Instance._avxLevel,
169+
Instance._allowFallback,
170+
Instance._skipCheck,
171+
Instance._logging,
172+
Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
130173
}
131174

132175
internal static string AvxLevelToString(AvxLevel level)
@@ -183,7 +226,31 @@ public enum AvxLevel
183226
Avx512,
184227
}
185228

186-
internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging);
229+
internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, string[] SearchDirectories)
230+
{
231+
public override string ToString()
232+
{
233+
string avxLevelString = AvxLevel switch
234+
{
235+
AvxLevel.None => "NoAVX",
236+
AvxLevel.Avx => "AVX",
237+
AvxLevel.Avx2 => "AVX2",
238+
AvxLevel.Avx512 => "AVX512",
239+
_ => "Unknown"
240+
};
241+
242+
string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }";
243+
244+
return $"NativeLibraryConfig Description:\n" +
245+
$"- Path: {Path}\n" +
246+
$"- PreferCuda: {UseCuda}\n" +
247+
$"- PreferredAvxLevel: {avxLevelString}\n" +
248+
$"- AllowFallback: {AllowFallback}\n" +
249+
$"- SkipCheck: {SkipCheck}\n" +
250+
$"- Logging: {Logging}\n" +
251+
$"- SearchDirectories and Priorities: {searchDirectoriesString}";
252+
}
253+
}
187254
}
188255
#endif
189-
}
256+
}

0 commit comments

Comments
 (0)