
Commit e1f067a

Integrate feedback; change structs and general workflow
1 parent f6c01c8 commit e1f067a

7 files changed: +288, -113 lines

Project.toml

Lines changed: 1 addition & 3 deletions
@@ -7,7 +7,6 @@ version = "0.45.0"
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
-DotEnv = "4dc1fcf4-5e3b-5448-94ab-0c38ec0385c1"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -18,7 +17,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [weakdeps]
@@ -34,7 +32,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
 FlashRankPromptingToolsExt = ["FlashRank"]
 GoogleGenAIPromptingToolsExt = ["GoogleGenAI"]
 MarkdownPromptingToolsExt = ["Markdown"]
-RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode"]
+RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode", "Pinecone"]
 SnowballPromptingToolsExt = ["Snowball"]
 
 [compat]
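
With `Pinecone` added to the trigger list, the `RAGToolsExperimentalExt` extension loads only once all of its weak dependencies are imported. A minimal sketch of how a user session would activate it after this commit (package names are taken from the `[weakdeps]`/extension mapping above; the extension contents themselves are not shown in this diff):

```julia
using PromptingTools
const RT = PromptingTools.Experimental.RAGTools

# Importing all trigger packages is what loads RAGToolsExperimentalExt;
# after this commit, Pinecone is part of that trigger set.
using SparseArrays, LinearAlgebra, Unicode
using Pinecone
```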

src/Experimental/RAGTools/RAGTools.jl

Lines changed: 2 additions & 2 deletions
@@ -32,12 +32,12 @@ include("api_services.jl")
 
 include("rag_interface.jl")
 
-export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, PTPineconeIndex, CandidateChunks, RAGResult
+export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, PineconeIndex, CandidateChunks, CandidateWithChunks, RAGResult
 export MultiIndex, SubChunkIndex, MultiCandidateChunks
 include("types.jl")
 
 export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIndexer,
-    KeywordsIndexer, PTPineconeIndexer
+    KeywordsIndexer, PineconeIndexer
 include("preparation.jl")
 
 include("rank_gpt.jl")

src/Experimental/RAGTools/generation.jl

Lines changed: 42 additions & 29 deletions
@@ -37,7 +37,8 @@ context = build_context(ContextEnumerator(), index, candidates; chunks_window_ma
 ```
 """
 function build_context(contexter::ContextEnumerator,
-        index::AbstractDocumentIndex, candidates::AbstractCandidateChunks;
+        index::AbstractManagedIndex,
+        candidates::AbstractCandidateWithChunks;
         verbose::Bool = true,
         chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
     ## Checks
@@ -63,27 +64,31 @@ function build_context(contexter::ContextEnumerator,
     return context
 end
 
-using Pinecone: Pinecone, query
-using JSON3: JSON3, read
-"""
-    build_context(contexter::ContextEnumerator,
-        index::AbstractPTPineconeIndex;
-        verbose::Bool = true,
-        top_k::Int = 10,
-        kwargs...)
-
-Build context strings by querying Pinecone.
-```
-"""
 function build_context(contexter::ContextEnumerator,
-        index::AbstractPTPineconeIndex;
+        index::AbstractManagedIndex,
+        candidates::AbstractCandidateWithChunks;
         verbose::Bool = true,
-        top_k::Int = 10,
-        kwargs...)
-    pinecone_results = Pinecone.query(index.pinecone_context, index.pinecone_index, index.embedding, top_k, index.namespace, false, true)
-    results_json = JSON3.read(pinecone_results)
-    context = results_json.matches[1].metadata.content
+        chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
+    ## Checks
+    @assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative"
+
+    context = String[]
+    for (i, position) in enumerate(positions(candidates))
+        ## select the right index
+        id = candidates isa MultiCandidateChunks ? candidates.index_ids[i] :
+             candidates.index_id
+        index_ = index isa AbstractChunkIndex ? index : index[id]
+        isnothing(index_) && continue
 
+        chunks_ = chunks(candidates)[
+            max(1, position - chunks_window_margin[1]):min(end,
+            position + chunks_window_margin[2])]
+        ## Check if surrounding chunks are from the same source
+        is_same_source = sources(candidates)[
+            max(1, position - chunks_window_margin[1]):min(end,
+            position + chunks_window_margin[2])] .== sources(candidates)[position]
+        push!(context, "$(i). $(join(chunks_[is_same_source], "\n"))")
+    end
     return context
 end
 
@@ -94,7 +99,7 @@ end
 
 # Mutating version that dispatches on the result to the underlying implementation
 function build_context!(contexter::ContextEnumerator,
-        index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
+        index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...)
     result.context = build_context(contexter, index, result.reranked_candidates; kwargs...)
     return result
 end
@@ -114,6 +119,7 @@ function answer!(
     throw(ArgumentError("Answerer $(typeof(answerer)) not implemented"))
 end
 
+# TODO: update docs signature
 """
     answer!(
         answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
@@ -138,7 +144,7 @@ Generates an answer using the `aigenerate` function with the provided `result.co
 
 """
 function answer!(
-        answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        answerer::SimpleAnswerer, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
        template::Symbol = :RAGAnswerFromContext,
        cost_tracker = Threads.Atomic{Float64}(0.0),
@@ -186,11 +192,13 @@ Refines the answer by executing a web search using the Tavily API. This method a
 struct TavilySearchRefiner <: AbstractRefiner end
 
 function refine!(
-        refiner::AbstractRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        refiner::AbstractRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        kwargs...)
     throw(ArgumentError("Refiner $(typeof(refiner)) not implemented"))
 end
 
+
+# TODO: update docs signature
 """
     refine!(
         refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
@@ -199,7 +207,7 @@ end
 Simple no-op function for `refine!`. It simply copies the `result.answer` and `result.conversations[:answer]` without any changes.
 """
 function refine!(
-        refiner::NoRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        refiner::NoRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        kwargs...)
     result.final_answer = result.answer
     if haskey(result.conversations, :answer)
@@ -208,6 +216,8 @@ function refine!(
     return result
 end
 
+
+# TODO: update docs signature
 """
     refine!(
         refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
@@ -234,7 +244,7 @@ This method uses the same context as the original answer, however, it can be mod
 - `cost_tracker`: An atomic counter to track the cost of the operation.
 """
 function refine!(
-        refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        verbose::Bool = true,
        model::AbstractString = PT.MODEL_CHAT,
        template::Symbol = :RAGAnswerRefiner,
@@ -262,6 +272,8 @@ function refine!(
     return result
 end
 
+
+# TODO: update docs signature
 """
     refine!(
         refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
@@ -312,7 +324,7 @@ pprint(result)
 ```
 """
 function refine!(
-        refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        verbose::Bool = true,
        model::AbstractString = PT.MODEL_CHAT,
        include_answer::Bool = true,
@@ -377,13 +389,13 @@ Overload this method to add custom postprocessing steps, eg, logging, saving con
 """
 struct NoPostprocessor <: AbstractPostprocessor end
 
-function postprocess!(postprocessor::AbstractPostprocessor, index::AbstractDocumentIndex,
+function postprocess!(postprocessor::AbstractPostprocessor, index::Union{AbstractDocumentIndex, AbstractManagedIndex},
        result::AbstractRAGResult; kwargs...)
     throw(ArgumentError("Postprocessor $(typeof(postprocessor)) not implemented"))
 end
 
 function postprocess!(
-        ::NoPostprocessor, index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
+        ::NoPostprocessor, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...)
     return result
 end
 
@@ -416,6 +428,7 @@ It uses `ContextEnumerator`, `SimpleAnswerer`, `SimpleRefiner`, and `NoPostproce
     postprocessor::AbstractPostprocessor = NoPostprocessor()
 end
 
+# TODO: update docs signature
 """
     generate!(
         generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
@@ -483,7 +496,7 @@ result = generate!(index, result)
 ```
 """
 function generate!(
-        generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
+        generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
        verbose::Integer = 1,
        api_kwargs::NamedTuple = NamedTuple(),
        contexter::AbstractContextBuilder = generator.contexter,
@@ -672,7 +685,7 @@ PT.pprint(result)
 
 For easier manipulation of nested kwargs, see utilities `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
 """
-function airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
+function airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex};
        question::AbstractString,
        verbose::Integer = 1, return_all::Bool = false,
        api_kwargs::NamedTuple = NamedTuple(),
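
The recurring change in this file widens the `index` argument from `AbstractDocumentIndex` to `Union{AbstractDocumentIndex, AbstractManagedIndex}` across `build_context!`, `answer!`, `refine!`, `postprocess!`, `generate!`, and `airag`, so a managed (e.g. Pinecone-backed) index can flow through the same generation pipeline as a local index. A hedged usage sketch (the `pinecone_index` variable is hypothetical; how such an index is built is shown in `preparation.jl` below):

```julia
using PromptingTools
using PromptingTools.Experimental.RAGTools

# Assumption: `pinecone_index` is a PineconeIndex built via `build_index(PineconeIndexer(), ...)`.
cfg = RAGConfig()
result = airag(cfg, pinecone_index;
    question = "What does build_context do?",
    return_all = true)  # return the full RAGResult instead of only the final answer
```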

src/Experimental/RAGTools/preparation.jl

Lines changed: 29 additions & 28 deletions
@@ -20,13 +20,6 @@ Chunker when you provide text to `get_chunks` functions. Inputs are directly chu
 """
 struct TextChunker <: AbstractChunker end
 
-"""
-    NoChunker <: AbstractChunker
-
-
-"""
-struct NoChunker <: AbstractChunker end
-
 ### Embedding Types
 """
     NoEmbedder <: AbstractEmbedder
@@ -35,6 +28,13 @@ No-op embedder for `get_embeddings` functions. It returns `nothing`.
 """
 struct NoEmbedder <: AbstractEmbedder end
 
+"""
+    SimpleEmbedder <: AbstractEmbedder
+
+Simply passes the input to `aiembed`.
+"""
+struct SimpleEmbedder <: AbstractEmbedder end
+
 """
     BatchEmbedder <: AbstractEmbedder
 
@@ -142,15 +142,13 @@ It uses `TextChunker`, `KeywordsProcessor`, and `NoTagger` as default chunker, p
 end
 
 """
-    PTPineconeIndexer <: AbstractIndexBuilder
+    PineconeIndexer <: AbstractIndexBuilder
 
 Pinecone index to be returned by `build_index`.
-
-It uses `NoChunker`, `NoEmbedder`, and `NoTagger` as default chunker, embedder, and tagger.
 """
-@kwdef mutable struct PTPineconeIndexer <: AbstractIndexBuilder
-    chunker::AbstractChunker = NoChunker()
-    embedder::AbstractEmbedder = NoEmbedder()
+@kwdef mutable struct PineconeIndexer <: AbstractIndexBuilder
+    chunker::AbstractChunker = TextChunker()
+    embedder::AbstractEmbedder = SimpleEmbedder()
     tagger::AbstractTagger = NoTagger()
 end
 
@@ -186,10 +184,6 @@ function load_text(chunker::TextChunker, input::AbstractString;
     @assert length(source)<=512 "Each `source` should be less than 512 characters long. Detected: $(length(source)) characters. You must provide sources for each text when using `TextChunker`"
     return input, source
 end
-function load_text(chunker::NoChunker, input::AbstractString = "";
-        source::AbstractString = input, kwargs...)
-    return input, source
-end
 
 """
     get_chunks(chunker::AbstractChunker,
@@ -251,6 +245,13 @@ function get_embeddings(
     return nothing
 end
 
+function get_embeddings(
+        embedder::SimpleEmbedder, docs::AbstractVector{<:AbstractString};
+        model::AbstractString = PT.MODEL_EMBEDDING,
+        kwargs...)
+    return hcat([Vector{Float32}(aiembed(doc; model).content) for doc in docs]...)
+end
+
 """
     get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:AbstractString};
        verbose::Bool = true,
@@ -719,31 +720,31 @@ function build_index(
     return index
 end
 
-using Pinecone: Pinecone, init_v3, Index
+using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3, init_v3, Index
+# TODO: change docs
 """
     build_index(
-        indexer::PTPineconeIndexer;
+        indexer::PineconeIndexer;
        namespace::AbstractString,
-        schema::AbstractPromptSchema = OpenAISchema();
        verbose::Integer = 1,
        index_id = gensym("PTPineconeIndex"),
        cost_tracker = Threads.Atomic{Float64}(0.0))
 
-Builds a `PTPineconeIndex` containing a Pinecone context (API key, index and namespace).
+Builds a `PineconeIndex` containing a Pinecone context (API key, index and namespace).
 """
 function build_index(
-        indexer::PTPineconeIndexer,
+        indexer::PineconeIndexer,
+        context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
+        index::Pinecone.PineconeIndexv3 = "",
        namespace::AbstractString,
-        schema::PromptingTools.AbstractPromptSchema = PromptingTools.OpenAISchema();
        verbose::Integer = 1,
-        index_id = gensym("PTPineconeIndex"),
+        index_id = gensym(namespace),
        cost_tracker = Threads.Atomic{Float64}(0.0))
+    @assert !isempty(context.api_key) && !isempty(index) "Pinecone context and index not set"
 
-    pinecone_context = Pinecone.init_v3(ENV["PINECONE_API_KEY"])
-    pindex = ENV["PINECONE_INDEX"]
-    pinecone_index = pinecone_index = !isempty(pindex) ? Pinecone.Index(pinecone_context, pindex) : nothing
+    # TODO: add chunking, embedding, tags?
 
-    index = PTPineconeIndex(; id = index_id, pinecone_context, pinecone_index, namespace, schema)
+    index = PineconeIndex(; id = index_id, context, index, namespace)
 
     (verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))"
 
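With the new `build_index` method, the caller supplies the Pinecone context and index handle explicitly instead of the function reading `ENV["PINECONE_API_KEY"]`/`ENV["PINECONE_INDEX"]` internally. A hedged call sketch, assuming the positional order shown in the diff (the environment-variable handling and namespace name here are illustrative, mirroring the deleted code):

```julia
using Pinecone
using PromptingTools.Experimental.RAGTools

# The caller now owns the Pinecone handles (illustrative setup):
context = Pinecone.init_v3(ENV["PINECONE_API_KEY"])
pindex = Pinecone.Index(context, ENV["PINECONE_INDEX"])

indexer = PineconeIndexer()  # defaults: TextChunker + SimpleEmbedder + NoTagger
index = build_index(indexer, context, pindex, "my-namespace"; verbose = 1)
```

Note that chunking, embedding, and tagging inside this method are still marked as a TODO in the diff, so the returned index is essentially a handle to the remote namespace.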
src/Experimental/RAGTools/rag_interface.jl

Lines changed: 2 additions & 7 deletions
@@ -161,13 +161,6 @@ Main abstract type for storing document chunks and their embeddings. It also sto
 """
 abstract type AbstractChunkIndex <: AbstractDocumentIndex end
 
-"""
-    AbstractPTPineconeIndex <: AbstractDocumentIndex
-
-Abstract type for working with Pinecone. For now, just an empty index.
-"""
-abstract type AbstractPTPineconeIndex <: AbstractDocumentIndex end
-
 # ## Retrieval stage
 
 """
@@ -184,6 +177,8 @@ Return type from `find_closest` and `find_tags` functions.
 """
 abstract type AbstractCandidateChunks end
 
+abstract type AbstractCandidateWithChunks end
+
 # Main supertype for retrieval customizations
 abstract type AbstractRetrievalMethod end
 
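The Pinecone-specific abstract index type is removed here; the generation code above instead dispatches on `AbstractManagedIndex`, and retrieval candidates that carry their own chunks get the new `AbstractCandidateWithChunks` supertype. A sketch of the presumed hierarchy after this commit (only `AbstractCandidateWithChunks` appears in this hunk; `AbstractManagedIndex` and the concrete subtypes are assumed to be defined in files not shown in this diff, e.g. `types.jl`):

```julia
# Presumed layout, not verbatim from the diff:
abstract type AbstractManagedIndex end               # remote/managed stores such as Pinecone
abstract type AbstractCandidateWithChunks end        # candidates carrying chunks + sources inline

# Assumed concrete subtypes (matching the exports in RAGTools.jl above):
#   PineconeIndex        <: AbstractManagedIndex
#   CandidateWithChunks  <: AbstractCandidateWithChunks
```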