Skip to content

Commit eb94d1a

Browse files
authored
Add TraceMessage for observability (#133)
1 parent 49f3f5b commit eb94d1a

File tree

13 files changed

+785
-8
lines changed

13 files changed

+785
-8
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88

99
### Added
10+
- Added a few new open-weights models hosted by Fireworks.ai to the registry (DBRX Instruct, Mixtral 8x22b Instruct, Qwen 72b). If you're curious about how well they work, try them!
11+
- Added basic support for observability downstream. Created custom callback infrastructure with `initialize_tracer` and `finalize_tracer` and dedicated types are `TracerMessage` and `TracerMessageLike`. See `?TracerMessage` for more information and the corresponding `aigenerate` docstring.
12+
13+
### Updated
14+
- Changed default model for `RAGTools.CohereReranker` to "cohere-rerank-english-v3.0".
1015

1116
### Fixed
1217

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.19.0"
66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
9+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
910
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
1011
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1112
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -31,6 +32,7 @@ RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra"]
3132
AbstractTrees = "0.4"
3233
Aqua = "0.7"
3334
Base64 = "<0.0.1, 1"
35+
Dates = "<0.0.1, 1"
3436
GoogleGenAI = "0.3"
3537
HTTP = "1"
3638
JSON3 = "1"

src/Experimental/RAGTools/retrieval.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ end
390390
verbose::Bool = false,
391391
api_key::AbstractString = PT.COHERE_API_KEY,
392392
top_n::Integer = length(candidates.scores),
393-
model::AbstractString = "rerank-english-v2.0",
393+
model::AbstractString = "rerank-english-v3.0",
394394
return_documents::Bool = false,
395395
cost_tracker = Threads.Atomic{Float64}(0.0),
396396
kwargs...)
@@ -404,10 +404,10 @@ Re-ranks a list of candidate chunks using the Cohere Rerank API. See https://coh
404404
- `question`: The query to be used for the search.
405405
- `candidates`: The candidate chunks to be re-ranked.
406406
- `top_n`: The number of most relevant documents to return. Default is `length(documents)`.
407-
- `model`: The model to use for reranking. Default is `rerank-english-v2.0`.
407+
- `model`: The model to use for reranking. Default is `rerank-english-v3.0`.
408408
- `return_documents`: A boolean flag indicating whether to return the reranked documents in the response. Default is `false`.
409409
- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.
410-
- `cost_tracker`: An atomic counter to track the cost of the retrieval. Default is `Threads.Atomic{Float64}(0.0)`. Not currently tracked (cost unclear).
410+
- `cost_tracker`: An atomic counter to track the cost of the retrieval. Not implemented /tracked (cost unclear). Provided for consistency.
411411
412412
"""
413413
function rerank(
@@ -416,7 +416,7 @@ function rerank(
416416
verbose::Bool = false,
417417
api_key::AbstractString = PT.COHERE_API_KEY,
418418
top_n::Integer = length(candidates.scores),
419-
model::AbstractString = "rerank-english-v2.0",
419+
model::AbstractString = "rerank-english-v3.0",
420420
return_documents::Bool = false,
421421
cost_tracker = Threads.Atomic{Float64}(0.0),
422422
kwargs...)

src/PromptingTools.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module PromptingTools
22

33
import AbstractTrees
44
using Base64: base64encode
5+
using Dates: now, DateTime
56
using Logging
67
using OpenAI
78
using JSON3
@@ -72,6 +73,7 @@ include("llm_ollama.jl")
7273
include("llm_google.jl")
7374
include("llm_anthropic.jl")
7475
include("llm_sharegpt.jl")
76+
include("llm_tracer.jl")
7577

7678
## Convenience utils
7779
export @ai_str, @aai_str, @ai!_str, @aai!_str

src/llm_interface.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,28 @@ Frequently used schema for finetuning LLMs. Conversations are recorded as a vect
285285
"""
286286
struct ShareGPTSchema <: AbstractShareGPTSchema end
287287

288+
abstract type AbstractTracerSchema <: AbstractPromptSchema end
289+
290+
"""
291+
TracerSchema <: AbstractTracerSchema
292+
293+
A schema designed to wrap another schema, enabling pre- and post-execution callbacks for tracing and additional functionalities. This type is specifically utilized within the `TracerMessage` type to trace the execution flow, facilitating observability and debugging in complex conversational AI systems.
294+
295+
The `TracerSchema` acts as a middleware, allowing developers to insert custom logic before and after the execution of the primary schema's functionality. This can include logging, performance measurement, or any other form of tracing required to understand or improve the execution flow.
296+
297+
# Usage
298+
```julia
299+
wrap_schema = TracerSchema(OpenAISchema())
300+
msg = aigenerate(wrap_schema, "Say hi!"; model="gpt-4")
301+
# output type should be TracerMessage
302+
msg isa TracerMessage
303+
```
304+
You can define your own tracer schema and the corresponding methods: `initialize_tracer`, `finalize_tracer`. See `src/llm_tracer.jl`
305+
"""
306+
struct TracerSchema <: AbstractTracerSchema
307+
schema::AbstractPromptSchema
308+
end
309+
288310
## Dispatch into a default schema (can be set by Preferences.jl)
289311
# Since we load it as strings, we need to convert it to a symbol and instantiate it
290312
global PROMPT_SCHEMA::AbstractPromptSchema = @load_preference("PROMPT_SCHEMA",

src/llm_tracer.jl

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Tracing infrastructure for logging and other callbacks
2+
# - Define your own schema that is subtype of AbstractTracerSchema and wraps the underlying LLM provider schema
3+
# - Customize initialize_tracer and finalize_tracer with your custom callback
4+
# - Call your ai* function with the tracer schema as usual
5+
6+
# Simple passthrough, do nothing
7+
"""
8+
render(tracer_schema::AbstractTracerSchema,
9+
conv::AbstractVector{<:AbstractMessage}; kwargs...)
10+
11+
Passthrough. No changes.
12+
"""
13+
function render(tracer_schema::AbstractTracerSchema,
14+
conv::AbstractVector{<:AbstractMessage}; kwargs...)
15+
return conv
16+
end
17+
18+
"""
19+
initialize_tracer(
20+
tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...)
21+
22+
Initializes `tracer`/callback (if necessary). Can provide any keyword arguments in `tracer_kwargs` (eg, `parent_id`, `thread_id`, `run_id`).
23+
Is executed prior to the `ai*` calls.
24+
25+
In the default implementation, we just collect the necessary data to build the tracer object in `finalize_tracer`.
26+
"""
27+
function initialize_tracer(
28+
tracer_schema::AbstractTracerSchema; model = "", tracer_kwargs = NamedTuple(), kwargs...)
29+
return (; time_sent = now(), model, tracer_kwargs...)
30+
end
31+
32+
"""
33+
finalize_tracer(
34+
tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...)
35+
36+
Finalizes the calltracer of whatever is nedeed after the `ai*` calls. Use `tracer_kwargs` to provide any information necessary (eg, `parent_id`, `thread_id`, `run_id`).
37+
38+
In the default implementation, we convert all non-tracer messages into `TracerMessage`.
39+
"""
40+
function finalize_tracer(
41+
tracer_schema::AbstractTracerSchema, tracer, msg_or_conv; tracer_kwargs = NamedTuple(), model = "", kwargs...)
42+
# We already captured all kwargs, they are already in `tracer`, we can ignore them in this implementation
43+
time_received = now()
44+
# work with arrays for unified processing
45+
is_vector = msg_or_conv isa AbstractVector
46+
conv = msg_or_conv isa AbstractVector{<:AbstractMessage} ?
47+
convert(Vector{AbstractMessage}, msg_or_conv) :
48+
AbstractMessage[msg_or_conv]
49+
# all msg non-traced, set times
50+
for i in eachindex(conv)
51+
msg = conv[i]
52+
# change into TracerMessage if not already, use the current kwargs
53+
if !istracermessage(msg)
54+
# we saved our data for `tracer`
55+
conv[i] = TracerMessage(; object = msg, tracer..., time_received)
56+
end
57+
end
58+
return is_vector ? conv : first(conv)
59+
end
60+
61+
"""
62+
aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
63+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
64+
65+
Wraps the normal `aigenerate` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
66+
67+
Logic:
68+
- calls `initialize_tracer`
69+
- calls `aigenerate` (with the `tracer_schema.schema`)
70+
- calls `finalize_tracer`
71+
72+
# Example
73+
```julia
74+
wrap_schema = PT.TracerSchema(PT.OpenAISchema())
75+
msg = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t")
76+
msg isa TracerMessage # true
77+
msg.content # access content like if it was the message
78+
PT.pprint(msg) # pretty-print the message
79+
```
80+
81+
It works on a vector of messages and converts only the non-tracer ones, eg,
82+
```julia
83+
wrap_schema = PT.TracerSchema(PT.OpenAISchema())
84+
conv = aigenerate(wrap_schema, "Say hi!"; model = "gpt4t", return_all = true)
85+
all(PT.istracermessage, conv) #true
86+
```
87+
"""
88+
function aigenerate(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
89+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
90+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs, kwargs...)
91+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
92+
msg_or_conv = aigenerate(tracer_schema.schema, prompt; merged_kwargs...)
93+
return finalize_tracer(
94+
tracer_schema, tracer, msg_or_conv; model, tracer_kwargs, kwargs...)
95+
end
96+
97+
"""
98+
aiembed(tracer_schema::AbstractTracerSchema,
99+
doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity;
100+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
101+
102+
Wraps the normal `aiembed` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
103+
104+
Logic:
105+
- calls `initialize_tracer`
106+
- calls `aiembed` (with the `tracer_schema.schema`)
107+
- calls `finalize_tracer`
108+
"""
109+
function aiembed(tracer_schema::AbstractTracerSchema,
110+
doc_or_docs::Union{AbstractString, AbstractVector{<:AbstractString}}, postprocess::Function = identity;
111+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
112+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...)
113+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
114+
embed_or_conv = aiembed(
115+
tracer_schema.schema, doc_or_docs, postprocess; merged_kwargs...)
116+
return finalize_tracer(
117+
tracer_schema, tracer, embed_or_conv; model, tracer_kwargs..., kwargs...)
118+
end
119+
120+
"""
121+
aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
122+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
123+
124+
Wraps the normal `aiclassify` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
125+
126+
Logic:
127+
- calls `initialize_tracer`
128+
- calls `aiclassify` (with the `tracer_schema.schema`)
129+
- calls `finalize_tracer`
130+
"""
131+
function aiclassify(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
132+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
133+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...)
134+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
135+
classify_or_conv = aiclassify(tracer_schema.schema, prompt; merged_kwargs...)
136+
return finalize_tracer(
137+
tracer_schema, tracer, classify_or_conv; model, tracer_kwargs..., kwargs...)
138+
end
139+
140+
"""
141+
aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
142+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
143+
144+
Wraps the normal `aiextract` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
145+
146+
Logic:
147+
- calls `initialize_tracer`
148+
- calls `aiextract` (with the `tracer_schema.schema`)
149+
- calls `finalize_tracer`
150+
"""
151+
function aiextract(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
152+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
153+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...)
154+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
155+
extract_or_conv = aiextract(tracer_schema.schema, prompt; merged_kwargs...)
156+
return finalize_tracer(
157+
tracer_schema, tracer, extract_or_conv; model, tracer_kwargs..., kwargs...)
158+
end
159+
160+
"""
161+
aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
162+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
163+
164+
Wraps the normal `aiscan` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
165+
166+
Logic:
167+
- calls `initialize_tracer`
168+
- calls `aiscan` (with the `tracer_schema.schema`)
169+
- calls `finalize_tracer`
170+
"""
171+
function aiscan(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
172+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
173+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...)
174+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
175+
scan_or_conv = aiscan(tracer_schema.schema, prompt; merged_kwargs...)
176+
return finalize_tracer(
177+
tracer_schema, tracer, scan_or_conv; model, tracer_kwargs..., kwargs...)
178+
end
179+
180+
"""
181+
aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
182+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
183+
184+
Wraps the normal `aiimage` call in a tracing/callback system. Use `tracer_kwargs` to provide any information necessary to the tracer/callback system only (eg, `parent_id`, `thread_id`, `run_id`).
185+
186+
Logic:
187+
- calls `initialize_tracer`
188+
- calls `aiimage` (with the `tracer_schema.schema`)
189+
- calls `finalize_tracer`
190+
"""
191+
function aiimage(tracer_schema::AbstractTracerSchema, prompt::ALLOWED_PROMPT_TYPE;
192+
tracer_kwargs = NamedTuple(), model = "", kwargs...)
193+
tracer = initialize_tracer(tracer_schema; model, tracer_kwargs..., kwargs...)
194+
merged_kwargs = isempty(model) ? kwargs : (; model, kwargs...) # to not override default model for each schema if not provided
195+
image_or_conv = aiimage(tracer_schema.schema, prompt; merged_kwargs...)
196+
return finalize_tracer(
197+
tracer_schema, tracer, image_or_conv; model, tracer_kwargs..., kwargs...)
198+
end

0 commit comments

Comments
 (0)