Skip to content

Commit 611716e

Browse files
committed
Add SubManagedIndex and view
1 parent 1c4e0f5 commit 611716e

File tree

3 files changed

+117
-14
lines changed

3 files changed

+117
-14
lines changed

src/Experimental/RAGTools/preparation.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -742,12 +742,12 @@ Builds a `PineconeIndex` containing a Pinecone context (API key, index and names
742742
function build_index(
743743
indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString};
744744
metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(),
745-
context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
746-
index::Pinecone.PineconeIndexv3 = nothing,
747-
namespace::AbstractString = "",
745+
pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""),
746+
pinecone_index::Pinecone.PineconeIndexv3 = nothing,
747+
pinecone_namespace::AbstractString = "",
748748
upsert::Bool = false,
749749
verbose::Integer = 1,
750-
index_id = gensym(namespace),
750+
index_id = gensym(pinecone_namespace),
751751
chunker::AbstractChunker = indexer.chunker,
752752
chunker_kwargs::NamedTuple = NamedTuple(),
753753
embedder::AbstractEmbedder = indexer.embedder,
@@ -756,7 +756,7 @@ function build_index(
756756
tagger_kwargs::NamedTuple = NamedTuple(),
757757
api_kwargs::NamedTuple = NamedTuple(),
758758
cost_tracker = Threads.Atomic{Float64}(0.0))
759-
@assert !isempty(context.apikey) && !isnothing(index) "Pinecone context and index not set"
759+
@assert !isempty(pinecone_context.apikey) && !isnothing(pinecone_index) "Pinecone context and index not set"
760760

761761
## Split into chunks
762762
chunks, sources = get_chunks(chunker, files_or_docs;
@@ -788,12 +788,12 @@ function build_index(
788788
embeddings_arr = [embeddings[:,i] for i in axes(embeddings,2)]
789789
for (idx, emb) in enumerate(embeddings_arr)
790790
pinevector = Pinecone.PineconeVector(string(UUIDs.uuid4()), emb, metadata[idx])
791-
Pinecone.upsert(context, index, [pinevector], namespace)
791+
Pinecone.upsert(pinecone_context, pinecone_index, [pinevector], pinecone_namespace)
792792
@info "Upsert #$idx complete"
793793
end
794794
end
795795

796-
index = PineconeIndex(; id = index_id, context, index, namespace, chunks, embeddings, tags, tags_vocab, metadata, sources)
796+
index = PineconeIndex(; id = index_id, pinecone_context, pinecone_index, pinecone_namespace, chunks, embeddings, tags, tags_vocab, metadata, sources)
797797

798798
(verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))"
799799

src/Experimental/RAGTools/retrieval.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ function find_closest(
246246
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
247247
top_n::Int = 10, kwargs...)
248248
# get Pinecone info
249-
pinecone_context = index.context
250-
pinecone_index = index.index
251-
pinecone_namespace = index.namespace
249+
pinecone_context = index.pinecone_context
250+
pinecone_index = index.pinecone_index
251+
pinecone_namespace = index.pinecone_namespace
252252

253253
# query candidates
254254
pinecone_results = Pinecone.query(pinecone_context, pinecone_index,

src/Experimental/RAGTools/types.jl

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ chunkdata(index::ChunkEmbeddingsIndex) = embeddings(index)
137137
const ChunkIndex = ChunkEmbeddingsIndex
138138

139139
"Return the identifier stored in the `id` field of a managed index."
function indexid(index::AbstractManagedIndex)
    return index.id
end
140+
"Return the `chunks` field of a managed index."
function chunks(index::AbstractManagedIndex)
    return index.chunks
end
141+
"Return the `sources` field of a managed index."
function sources(index::AbstractManagedIndex)
    return index.sources
end
140142

141143
using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3
142144
@kwdef struct PineconeIndex{
@@ -145,9 +147,9 @@ using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3
145147
T3 <: Union{Nothing, AbstractMatrix{<:Bool}}
146148
} <: AbstractManagedIndex
147149
id::Symbol # namespace
148-
context::Pinecone.PineconeContextv3
149-
index::Pinecone.PineconeIndexv3
150-
namespace::String
150+
pinecone_context::Pinecone.PineconeContextv3
151+
pinecone_index::Pinecone.PineconeIndexv3
152+
pinecone_namespace::String
151153
# underlying document chunks / snippets
152154
chunks::Vector{T1} = nothing
153155
# for semantic search
@@ -546,6 +548,84 @@ Base.@propagate_inbounds function translate_positions_to_parent(
546548
return sub_positions[pos]
547549
end
548550

551+
552+
"""
    SubManagedIndex{T <: AbstractManagedIndex}

A lightweight view over a managed index: it stores a reference to the `parent`
index together with the `positions` selected from it.

# Fields
- `parent::T`: the managed index this view refers to.
- `positions::Vector{Int}`: positions within the parent that belong to the view.
"""
Base.@kwdef struct SubManagedIndex{T <: AbstractManagedIndex} <: AbstractManagedIndex
    # the index being viewed
    parent::T
    # selected positions within `parent`
    positions::Vector{Int}
end
556+
557+
"Return the id of the parent index (a view reports its parent's id)."
indexid(index::SubManagedIndex) = indexid(parent(index))
558+
"Return the positions (within the parent index) selected by this view."
function positions(index::SubManagedIndex)
    return index.positions
end
559+
"Return the index that this view was created from."
function Base.parent(index::SubManagedIndex)
    return index.parent
end
560+
"Delegate the `HasEmbeddings` query to the parent index."
function HasEmbeddings(index::SubManagedIndex)
    return HasEmbeddings(parent(index))
end
561+
"Delegate the `HasKeywords` query to the parent index."
function HasKeywords(index::SubManagedIndex)
    return HasKeywords(parent(index))
end
562+
563+
"Return a view of the parent's chunks restricted to the positions of this sub-index."
Base.@propagate_inbounds function chunks(index::SubManagedIndex)
    parent_chunks = chunks(parent(index))
    return view(parent_chunks, positions(index))
end
566+
"Return a view of the parent's sources restricted to the positions of this sub-index."
Base.@propagate_inbounds function sources(index::SubManagedIndex)
    parent_sources = sources(parent(index))
    return view(parent_sources, positions(index))
end
569+
"Return the parent's chunkdata for all positions selected by this sub-index."
Base.@propagate_inbounds function chunkdata(index::SubManagedIndex)
    selected = positions(index)
    return chunkdata(parent(index), selected)
end
572+
"""
    chunkdata(index::SubManagedIndex, chunk_idx::AbstractVector{<:Integer})

Access chunkdata for a subset of chunks; `chunk_idx` is a vector of chunk indices
in the sub-index. The indices are first translated to parent positions and then
forwarded to the parent's `chunkdata` accessor.
"""
Base.@propagate_inbounds function chunkdata(
        index::SubManagedIndex, chunk_idx::AbstractVector{<:Integer})
    ## A dedicated accessor is needed because different chunk indices can store
    ## their chunks along different dimensions.
    parent_idx = translate_positions_to_parent(index, chunk_idx)
    selected = intersect(positions(index), parent_idx)
    return chunkdata(parent(index), selected)
end
580+
"""
    embeddings(index::SubManagedIndex)

Return a view of the parent's embeddings with columns restricted to the positions
of this sub-index. Throws an `ArgumentError` when `HasEmbeddings(index)` is false.
"""
function embeddings(index::SubManagedIndex)
    # Guard clause: bail out early when the index carries no embeddings
    HasEmbeddings(index) ||
        throw(ArgumentError("`embeddings` not implemented for $(typeof(index))"))
    return view(embeddings(parent(index)), :, positions(index))
end
587+
"""
    tags(index::SubManagedIndex)

Return a view of the parent's tags with rows restricted to the positions of this
sub-index, or `nothing` when the parent has no tags.
"""
function tags(index::SubManagedIndex)
    parent_tags = tags(parent(index))
    if isnothing(parent_tags)
        return nothing
    end
    return view(parent_tags, positions(index), :)
end
592+
"Return the tag vocabulary of the parent index (not subset by positions)."
tags_vocab(index::SubManagedIndex) = tags_vocab(parent(index))
595+
"""
    extras(index::SubManagedIndex)

Return a view of the parent's extras restricted to the positions of this
sub-index, or `nothing` when the parent has no extras.
"""
function extras(index::SubManagedIndex)
    parent_extras = extras(parent(index))
    if isnothing(parent_extras)
        return nothing
    end
    return view(parent_extras, positions(index))
end
600+
"""
    Base.vcat(i1::SubManagedIndex, i2::SubManagedIndex)

Fallback method for two sub-indices: always throws an `ArgumentError` reporting
the two argument types.
"""
function Base.vcat(i1::SubManagedIndex, i2::SubManagedIndex)
    msg = "vcat not implemented for type $(typeof(i1)) and $(typeof(i2))"
    throw(ArgumentError(msg))
end
603+
"""
    Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}

Concatenate two views of the same concrete type into a single `SubManagedIndex`
over the shared parent. Positions of `i1` come first, followed by the positions
of `i2`; duplicates are kept (use `unique` to drop them).

Throws an `ArgumentError` if the ids of the two parents differ.
"""
function Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}
    ## Check if can be merged
    if indexid(parent(i1)) != indexid(parent(i2))
        throw(ArgumentError("Parent indices must be the same (provided: $(indexid(parent(i1))) and $(indexid(parent(i2))))"))
    end
    ## Must build a `SubManagedIndex` (not a `SubChunkIndex`): the parent is a managed index
    return SubManagedIndex(parent(i1), vcat(positions(i1), positions(i2)))
end
610+
"""
    Base.unique(index::SubManagedIndex)

Return a new `SubManagedIndex` over the same parent with duplicate positions
removed.
"""
function Base.unique(index::SubManagedIndex)
    ## Must build a `SubManagedIndex` (not a `SubChunkIndex`): the parent is a managed index
    return SubManagedIndex(parent(index), unique(positions(index)))
end
613+
"Return the number of positions selected by this sub-index."
Base.length(index::SubManagedIndex) = length(positions(index))
616+
"Return `true` when this sub-index selects no positions."
Base.isempty(index::SubManagedIndex) = isempty(positions(index))
619+
"Print a one-line summary: the parent's type name, the parent's id, and the number of chunks in the view."
function Base.show(io::IO, index::SubManagedIndex)
    parent_index = parent(index)
    parent_name = nameof(typeof(parent_index))
    print(io, "A view of ", parent_name, " (id: ", indexid(parent_index),
        ") with ", length(index), " chunks")
end
623+
"""
    translate_positions_to_parent(index::SubManagedIndex, pos::AbstractVector{<:Integer})

Map positions expressed relative to the sub-index onto positions in the parent
index by indexing into the stored `positions` vector.
"""
Base.@propagate_inbounds function translate_positions_to_parent(
        index::SubManagedIndex, pos::AbstractVector{<:Integer})
    return positions(index)[pos]
end
628+
549629
# # CandidateChunks for Retrieval
550630

551631
"""
@@ -864,7 +944,18 @@ Base.@propagate_inbounds function Base.view(index::SubChunkIndex, cc::MultiCandi
864944
end
865945
# TODO: proper `view` -- `SubManagedIndex`?
866946
Base.@propagate_inbounds function Base.view(index::AbstractManagedIndex, cc::CandidateWithChunks)
867-
return cc
947+
@boundscheck let chk_vector = chunks(parent(index))
948+
if !checkbounds(Bool, axes(chk_vector, 1), positions(cc))
949+
## Avoid printing huge position arrays, show the extremas of the attempted range
950+
max_pos = extrema(positions(cc))
951+
throw(BoundsError(chk_vector, max_pos))
952+
end
953+
end
954+
pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
955+
return SubManagedIndex(parent(index), pos)
956+
end
957+
"Create a narrower view of an existing sub-index from the candidates in `cc`; forwards to the `SubManagedIndex(index, cc)` constructor method."
Base.@propagate_inbounds function Base.view(
        index::SubManagedIndex, cc::CandidateWithChunks)
    return SubManagedIndex(index, cc)
end
869960
Base.@propagate_inbounds function SubChunkIndex(index::SubChunkIndex, cc::CandidateChunks)
870961
pos = indexid(index) == indexid(cc) ? positions(cc) : Int[]
@@ -892,6 +983,18 @@ Base.@propagate_inbounds function SubChunkIndex(
892983
end
893984
return SubChunkIndex(parent(index), intersect_pos)
894985
end
986+
"""
    SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)

Build a new view over the parent of `index` that keeps only the candidate
positions from `cc` that are also present in `index`. When the ids of `index`
and `cc` differ, no candidate positions are used and the resulting view is empty.

Unless bounds checking is elided, throws a `BoundsError` if any kept position
falls outside the first axis of the parent's chunks.
"""
Base.@propagate_inbounds function SubManagedIndex(
        index::SubManagedIndex, cc::CandidateWithChunks)
    # Candidates only apply when they refer to the same index id
    candidate_pos = if indexid(index) == indexid(cc)
        positions(cc)
    else
        Int[]
    end
    kept = intersect(candidate_pos, positions(index))
    @boundscheck let parent_chunks = chunks(parent(index))
        if !checkbounds(Bool, axes(parent_chunks, 1), kept)
            ## Avoid printing huge position arrays, report only the extrema of the attempted range
            throw(BoundsError(parent_chunks, extrema(kept)))
        end
    end
    return SubManagedIndex(parent(index), kept)
end
895998

896999
## Getindex
8971000

0 commit comments

Comments
 (0)