Skip to content

Commit b027d05

Browse files
committed
Add documentation
1 parent 611716e commit b027d05

File tree

4 files changed

+229
-48
lines changed

4 files changed

+229
-48
lines changed

src/Experimental/RAGTools/generation.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ function build_context(contexter::ContextEnumerator,
6464
return context
6565
end
6666

67+
"""
68+
build_context(contexter::ContextEnumerator,
69+
index::AbstractManagedIndex, candidates::AbstractCandidateWithChunks;
70+
verbose::Bool = true,
71+
chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
72+
73+
build_context!(contexter::ContextEnumerator,
74+
index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...)
75+
76+
Dispatch for `AbstractManagedIndex` with `AbstractCandidateWithChunks`.
77+
"""
6778
function build_context(contexter::ContextEnumerator,
6879
index::AbstractManagedIndex,
6980
candidates::AbstractCandidateWithChunks;
@@ -124,7 +135,6 @@ function answer!(
124135
throw(ArgumentError("Answerer $(typeof(answerer)) not implemented"))
125136
end
126137

127-
# TODO: update docs signature
128138
"""
129139
answer!(
130140
answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
@@ -173,6 +183,17 @@ function answer!(
173183

174184
return result
175185
end
186+
187+
"""
188+
answer!(
189+
answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
190+
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
191+
template::Symbol = :RAGAnswerFromContext,
192+
cost_tracker = Threads.Atomic{Float64}(0.0),
193+
kwargs...)
194+
195+
Dispatch for `AbstractManagedIndex`.
196+
"""
176197
function answer!(
177198
answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult;
178199
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
@@ -228,7 +249,6 @@ function refine!(
228249
end
229250

230251

231-
# TODO: update docs signature
232252
"""
233253
refine!(
234254
refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
@@ -247,10 +267,9 @@ function refine!(
247267
end
248268

249269

250-
# TODO: update docs signature
251270
"""
252271
refine!(
253-
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
272+
refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
254273
verbose::Bool = true,
255274
model::AbstractString = PT.MODEL_CHAT,
256275
template::Symbol = :RAGAnswerRefiner,
@@ -303,10 +322,9 @@ function refine!(
303322
end
304323

305324

306-
# TODO: update docs signature
307325
"""
308326
refine!(
309-
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
327+
refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
310328
verbose::Bool = true,
311329
model::AbstractString = PT.MODEL_CHAT,
312330
include_answer::Bool = true,
@@ -458,10 +476,9 @@ It uses `ContextEnumerator`, `SimpleAnswerer`, `SimpleRefiner`, and `NoPostproce
458476
postprocessor::AbstractPostprocessor = NoPostprocessor()
459477
end
460478

461-
# TODO: update docs signature
462479
"""
463480
generate!(
464-
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
481+
generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult;
465482
verbose::Integer = 1,
466483
api_kwargs::NamedTuple = NamedTuple(),
467484
contexter::AbstractContextBuilder = generator.contexter,
@@ -591,8 +608,9 @@ function Base.show(io::IO, cfg::AbstractRAGConfig)
591608
dump(io, cfg; maxdepth = 2)
592609
end
593610

611+
# TODO: add example for Pinecone
594612
"""
595-
airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
613+
airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex};
596614
question::AbstractString,
597615
verbose::Integer = 1, return_all::Bool = false,
598616
api_kwargs::NamedTuple = NamedTuple(),

src/Experimental/RAGTools/preparation.jl

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,12 @@ end
145145
PineconeIndexer <: AbstractIndexBuilder
146146
147147
Pinecone index to be returned by `build_index`.
148+
149+
It uses `FileChunker`, `SimpleEmbedder` and `NoTagger` as default chunker, embedder and tagger.
148150
"""
149151
@kwdef mutable struct PineconeIndexer <: AbstractIndexBuilder
150152
chunker::AbstractChunker = FileChunker()
153+
# TODO: BatchEmbedder?
151154
embedder::AbstractEmbedder = SimpleEmbedder()
152155
tagger::AbstractTagger = NoTagger()
153156
end
@@ -726,26 +729,102 @@ function build_index(
726729
return index
727730
end
728731

732+
# TODO: where to put these?
729733
using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3, init_v3, Index, PineconeVector, upsert
730734
using UUIDs: UUIDs, uuid4
731-
# TODO: change docs
732735
"""
733736
build_index(
734-
indexer::PineconeIndexer;
735-
namespace::AbstractString,
737+
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
738+
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
739+
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
740+
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
741+
pinecone_namespace::AbstractString = "",
742+
upsert::Bool = true,
736743
verbose::Integer = 1,
737-
index_id = gensym("PTPineconeIndex"),
744+
index_id = gensym(pinecone_namespace),
745+
chunker::AbstractChunker = indexer.chunker,
746+
chunker_kwargs::NamedTuple = NamedTuple(),
747+
embedder::AbstractEmbedder = indexer.embedder,
748+
embedder_kwargs::NamedTuple = NamedTuple(),
749+
tagger::AbstractTagger = indexer.tagger,
750+
tagger_kwargs::NamedTuple = NamedTuple(),
751+
api_kwargs::NamedTuple = NamedTuple(),
738752
cost_tracker = Threads.Atomic{Float64}(0.0))
739753
740754
Builds a `PineconeIndex` containing a Pinecone context (API key, index and namespace).
755+
The index stores the document chunks and their embeddings (and potentially other information).
756+
757+
The function processes each file or document (depending on `chunker`), splits its content into chunks, embeds these chunks
758+
and then combines this information into a retrievable index. The chunks and embeddings are upsert to Pinecone using
759+
the provided Pinecone context (unless the `upsert` flag is set to `false`).
760+
761+
# Arguments
762+
- `indexer::PineconeIndexer`: The indexing logic for Pinecone operations.
763+
- `files_or_docs`: A vector of valid file paths to be indexed (chunked and embedded).
764+
- `metadata::Vector{Dict{String, Any}}`: A vector of metadata attributed to each docs file, given as dictionaries with `String` keys. Default is empty vector.
765+
- `pinecone_context::Pinecone.PineconeContextv3`: The Pinecone API key generated using Pinecone.jl. Must be specified.
766+
- `pinecone_index::Pinecone.PineconeIndexv3`: The Pinecone index generated using Pinecone.jl. Must be specified.
767+
- `pinecone_namespace::AbstractString`: The Pinecone namespace associated to `pinecone_index`.
768+
- `upsert::Bool = true`: A flag specifying whether to upsert the chunks and embeddings to Pinecone. Defaults to `true`.
769+
- `verbose`: An Integer specifying the verbosity of the logs. Default is `1` (high-level logging). `0` is disabled.
770+
- `index_id`: A unique identifier for the index. Default is a generated symbol.
771+
- `chunker`: The chunker logic to use for splitting the documents. Default is `TextChunker()`.
772+
- `chunker_kwargs`: Parameters to be provided to the `get_chunks` function. Useful to change the `separators` or `max_length`.
773+
- `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs`.
774+
- `embedder`: The embedder logic to use for embedding the chunks. Default is `BatchEmbedder()`.
775+
- `embedder_kwargs`: Parameters to be provided to the `get_embeddings` function. Useful to change the `target_batch_size_length` or reduce asyncmap tasks `ntasks`.
776+
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
777+
- `tagger`: The tagger logic to use for extracting tags from the chunks. Default is `NoTagger()`, ie, skip tag extraction. There are also `PassthroughTagger` and `OpenTagger`.
778+
- `tagger_kwargs`: Parameters to be provided to the `get_tags` function.
779+
- `model`: The model to use for tags extraction. Default is `PT.MODEL_CHAT`.
780+
- `template`: A template to be used for tags extraction. Default is `:RAGExtractMetadataShort`.
781+
- `tags`: A vector of vectors of strings directly providing the tags for each chunk. Applicable for `tagger::PasstroughTagger`.
782+
- `api_kwargs`: Parameters to be provided to the API endpoint. Shared across all API calls if provided.
783+
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
784+
785+
# Returns
786+
- `PineconeIndex`: An object containing the compiled index of chunks, embeddings, tags, vocabulary, sources and metadata, together with the Pinecone connection data.
787+
788+
See also: `PineconeIndex`, `get_chunks`, `get_embeddings`, `get_tags`, `CandidateWithChunks`, `find_closest`, `find_tags`, `rerank`, `retrieve`, `generate!`, `airag`
789+
790+
# Examples
791+
```julia
792+
using Pinecone
793+
794+
# Prepare the Pinecone connection data
795+
pinecone_context = Pinecone.init_v3(ENV["PINECONE_API_KEY"])
796+
pindex = ENV["PINECONE_INDEX"]
797+
pinecone_index = !isempty(pindex) ? Pinecone.Index(pinecone_context, pindex) : nothing
798+
namespace = "my-namespace"
799+
800+
# Add metadata about the sources in Pinecone
801+
metadata = [Dict{String, Any}("source" => doc_file) for doc_file in docs_files]
802+
803+
# Build the index. By default, the chunks and embeddings get upserted to Pinecone.
804+
const RT = PromptingTools.Experimental.RAGTools
805+
index_pinecone = RT.build_index(
806+
RT.PineconeIndexer(),
807+
docs_files;
808+
pinecone_context = pinecone_context,
809+
pinecone_index = pinecone_index,
810+
pinecone_namespace = namespace,
811+
metadata = metadata
812+
)
813+
814+
# Notes
815+
- If you get errors about exceeding embedding input sizes, first check the `max_length` in your chunks.
816+
If that does NOT resolve the issue, try changing the `embedding_kwargs`.
817+
In particular, reducing the `target_batch_size_length` parameter (eg, 10_000) and number of tasks `ntasks=1`.
818+
Some providers cannot handle large batch sizes (eg, Databricks).
819+
741820
"""
742821
function build_index(
743822
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
744823
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
745824
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
746825
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
747826
pinecone_namespace::AbstractString = "",
748-
upsert::Bool = false,
827+
upsert::Bool = true,
749828
verbose::Integer = 1,
750829
index_id = gensym(pinecone_namespace),
751830
chunker::AbstractChunker = indexer.chunker,

src/Experimental/RAGTools/retrieval.jl

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,37 @@ function find_closest(
241241
return CandidateChunks(indexid(index), positions, Float32.(scores))
242242
end
243243

244+
# Dispatch to find scores for multiple embeddings
245+
function find_closest(
246+
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
247+
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
248+
top_k::Int = 100, kwargs...)
249+
if isnothing(chunkdata(parent(index)))
250+
return CandidateChunks(; index_id = indexid(index))
251+
end
252+
## reduce top_k since we have more than one query
253+
top_k_ = top_k ÷ size(query_emb, 2)
254+
## simply vcat together (gets sorted from the highest similarity to the lowest)
255+
if isempty(query_tokens)
256+
mapreduce(
257+
c -> find_closest(finder, index, c; top_k = top_k_, kwargs...), vcat, eachcol(query_emb))
258+
else
259+
@assert length(query_tokens)==size(query_emb, 2) "Length of `query_tokens` must be equal to the number of columns in `query_emb`."
260+
mapreduce(
261+
(emb, tok) -> find_closest(finder, index, emb, tok; top_k = top_k_, kwargs...), vcat, eachcol(query_emb), query_tokens)
262+
end
263+
end
264+
265+
"""
266+
find_closest(
267+
finder::AbstractSimilarityFinder, index::PineconeIndex,
268+
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
269+
top_k::Int = 10, kwargs...)
270+
271+
Finds the indices of chunks that are closest to query embedding (`query_emb`) by querying Pinecone.
272+
273+
Returns only `top_k` closest indices.
274+
"""
244275
function find_closest(
245276
finder::AbstractSimilarityFinder, index::PineconeIndex,
246277
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
@@ -261,6 +292,7 @@ function find_closest(
261292
scores = [m.score for m in matches]
262293
chunks = [m.metadata.content for m in matches]
263294
metadata = [JSON3.read(JSON3.write(m.metadata), Dict{String, Any}) for m in matches]
295+
# TODO: metadata might not have `source`, change this
264296
sources = [m.metadata.source for m in matches]
265297

266298
return CandidateWithChunks(
@@ -272,6 +304,7 @@ function find_closest(
272304
sources = Vector{String}(sources))
273305
end
274306

307+
# Dispatch to find scores for multiple embeddings
275308
function find_closest(
276309
finder::AbstractSimilarityFinder, index::PineconeIndex,
277310
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
@@ -290,27 +323,6 @@ function find_closest(
290323
end
291324
end
292325

293-
# Dispatch to find scores for multiple embeddings
294-
function find_closest(
295-
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
296-
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
297-
top_k::Int = 100, kwargs...)
298-
if isnothing(chunkdata(parent(index)))
299-
return CandidateChunks(; index_id = indexid(index))
300-
end
301-
## reduce top_k since we have more than one query
302-
top_k_ = top_k ÷ size(query_emb, 2)
303-
## simply vcat together (gets sorted from the highest similarity to the lowest)
304-
if isempty(query_tokens)
305-
mapreduce(
306-
c -> find_closest(finder, index, c; top_k = top_k_, kwargs...), vcat, eachcol(query_emb))
307-
else
308-
@assert length(query_tokens)==size(query_emb, 2) "Length of `query_tokens` must be equal to the number of columns in `query_emb`."
309-
mapreduce(
310-
(emb, tok) -> find_closest(finder, index, emb, tok; top_k = top_k_, kwargs...), vcat, eachcol(query_emb), query_tokens)
311-
end
312-
end
313-
314326
### For MultiIndex
315327
function find_closest(
316328
finder::MultiFinder, index::AbstractMultiIndex,
@@ -612,20 +624,14 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
612624
end
613625

614626
"""
615-
find_tags(method::NoTagFilter, index::AbstractChunkIndex,
627+
find_tags(method::NoTagFilter, index::Union{AbstractChunkIndex, AbstractManagedIndex},
616628
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
617629
Union{
618630
AbstractString, Regex, Nothing}}
619631
tags; kwargs...)
620632
621633
Returns all chunks in the index, ie, no filtering, so we simply return `nothing` (easier for dispatch).
622634
"""
623-
# function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
624-
# tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
625-
# Union{
626-
# AbstractString, Regex, Nothing}}
627-
# return nothing
628-
# end
629635
function find_tags(
630636
method::NoTagFilter, index::Union{AbstractChunkIndex,
631637
AbstractManagedIndex},
@@ -748,8 +754,6 @@ function rerank(reranker::NoReranker,
748754
candidates::AbstractCandidateWithChunks;
749755
top_n::Integer = length(candidates),
750756
kwargs...)
751-
# Since this is almost a passthrough strategy, it returns the candidate_chunks unchanged
752-
# but it truncates to `top_n` if necessary
753757
return first(candidates, top_n)
754758
end
755759

@@ -1017,11 +1021,22 @@ end
10171021
PineconeRetriever <: AbstractRetriever
10181022
10191023
Dispatch for `retrieve` for Pinecone.
1024+
1025+
# Fields
1026+
- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` - uses `NoRephraser`
1027+
- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` (see Preparation Stage for more details) - uses `SimpleEmbedder`
1028+
- `processor::AbstractProcessor`: the processor method, dispatching `get_keywords` (see Preparation Stage for more details) - uses `NoProcessor`
1029+
- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` - uses `CosineSimilarity`
1030+
- `tagger::AbstractTagger`: the tag generating method, dispatching `get_tags` (see Preparation Stage for more details) - uses `NoTagger`
1031+
- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` - uses `NoTagFilter`
1032+
- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` - uses `NoReranker`
10201033
"""
10211034
@kwdef mutable struct PineconeRetriever <: AbstractRetriever
10221035
rephraser::AbstractRephraser = NoRephraser()
1036+
# TODO: BatchEmbedder?
10231037
embedder::AbstractEmbedder = SimpleEmbedder()
10241038
processor::AbstractProcessor = NoProcessor()
1039+
# TODO: actually do something with this; Pinecone allows choosing finder
10251040
finder::AbstractSimilarityFinder = CosineSimilarity()
10261041
tagger::AbstractTagger = NoTagger()
10271042
filter::AbstractTagFilter = NoTagFilter()
@@ -1242,6 +1257,33 @@ function retrieve(retriever::AbstractRetriever,
12421257
return result
12431258
end
12441259

1260+
"""
1261+
retrieve(retriever::PineconeRetriever,
1262+
index::PineconeIndex,
1263+
question::AbstractString;
1264+
verbose::Integer = 1,
1265+
top_k::Integer = 100,
1266+
top_n::Integer = 10,
1267+
api_kwargs::NamedTuple = NamedTuple(),
1268+
rephraser::AbstractRephraser = retriever.rephraser,
1269+
rephraser_kwargs::NamedTuple = NamedTuple(),
1270+
embedder::AbstractEmbedder = retriever.embedder,
1271+
embedder_kwargs::NamedTuple = NamedTuple(),
1272+
processor::AbstractProcessor = retriever.processor,
1273+
processor_kwargs::NamedTuple = NamedTuple(),
1274+
finder::AbstractSimilarityFinder = retriever.finder,
1275+
finder_kwargs::NamedTuple = NamedTuple(),
1276+
tagger::AbstractTagger = retriever.tagger,
1277+
tagger_kwargs::NamedTuple = NamedTuple(),
1278+
filter::AbstractTagFilter = retriever.filter,
1279+
filter_kwargs::NamedTuple = NamedTuple(),
1280+
reranker::AbstractReranker = retriever.reranker,
1281+
reranker_kwargs::NamedTuple = NamedTuple(),
1282+
cost_tracker = Threads.Atomic{Float64}(0.0),
1283+
kwargs...)
1284+
1285+
Dispatch method for `PineconeIndex`.
1286+
"""
12451287
function retrieve(retriever::PineconeRetriever,
12461288
index::PineconeIndex,
12471289
question::AbstractString;

0 commit comments

Comments
 (0)