Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-dgraph-vector-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
dgraph-vector-tests:
if: github.event.pull_request.draft == false
runs-on: warp-ubuntu-latest-x64-4x
timeout-minutes: 30
timeout-minutes: 120
steps:
- uses: actions/checkout@v5
- name: Set up Go
Expand Down
4 changes: 2 additions & 2 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"--security",
"whitelist=0.0.0.0/0;"
],
"showLog": true
"showLog": false
},
{
"name": "Zero",
Expand All @@ -25,7 +25,7 @@
"program": "${workspaceRoot}/dgraph/",
"env": {},
"args": ["zero"],
"showLog": true
"showLog": false
},
{
"name": "AlphaACL",
Expand Down
284 changes: 284 additions & 0 deletions posting/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
"github.com/hypermodeinc/dgraph/v25/schema"
"github.com/hypermodeinc/dgraph/v25/tok"
"github.com/hypermodeinc/dgraph/v25/tok/hnsw"
tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
"github.com/hypermodeinc/dgraph/v25/tok/kmeans"

"github.com/hypermodeinc/dgraph/v25/types"
"github.com/hypermodeinc/dgraph/v25/x"
)
Expand Down Expand Up @@ -1412,6 +1415,284 @@
return prefixes, nil
}

func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
pk := x.ParsedKey{Attr: rb.Attr}

indexer, err := factorySpecs[0].CreateIndex(pk.Attr)
if err != nil {
return err
}

dimension := indexer.Dimension()
// If dimension is -1, it means that the dimension is not set through options in case of partitioned hnsw.
if dimension == -1 {
numVectorsToCheck := 100
lenFreq := make(map[int]int, numVectorsToCheck)
maxFreq := 0
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{

Check failure on line 1432 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `MemLayerInstance.IterateDisk` is not checked
Prefix: pk.DataPrefix(),
ReadTs: rb.StartTs,
AllVersions: false,
Reverse: false,
CheckInclusion: func(uid uint64) error {
return nil
},
Function: func(l *List, pk x.ParsedKey) error {
val, err := l.Value(rb.StartTs)
if err != nil {
return err
}
inVec := types.BytesAsFloatArray(val.Value.([]byte))
lenFreq[len(inVec)] += 1
if lenFreq[len(inVec)] > maxFreq {
maxFreq = lenFreq[len(inVec)]
dimension = len(inVec)
}
numVectorsToCheck -= 1
if numVectorsToCheck <= 0 {
return ErrStopIteration
}
return nil
},
StartKey: x.DataKey(rb.Attr, 0),
})

indexer.SetDimension(rb.CurrentSchema, dimension)
}

fmt.Println("Selecting vector dimension to be:", dimension)

norm := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
norm.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
val, err := pl.Value(rb.StartTs)
if err != nil {
return nil, err
}
if val.Tid == types.VFloatID {
return nil, nil
}

// Convert to VFloatID and persist as binary bytes.
sv, err := types.Convert(val, types.VFloatID)
if err != nil {
return nil, err
}
b := types.ValueForType(types.BinaryID)
if err = types.Marshal(sv, &b); err != nil {
return nil, err
}

edge := &pb.DirectedEdge{
Attr: rb.Attr,
Entity: uid,
Value: b.Value.([]byte),
ValueType: types.VFloatID.Enum(),
}
inKey := x.DataKey(edge.Attr, uid)
p, err := txn.Get(inKey)
if err != nil {
return []*pb.DirectedEdge{}, err
}

if err := p.addMutation(ctx, txn, edge); err != nil {
return []*pb.DirectedEdge{}, err
}
return nil, nil
}

if err := norm.RunWithoutTemp(ctx); err != nil {
return err
}

count := 0

if indexer.NumSeedVectors() > 0 {
err := MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
Prefix: pk.DataPrefix(),
ReadTs: rb.StartTs,
AllVersions: false,
Reverse: false,
CheckInclusion: func(uid uint64) error {
return nil
},
Function: func(l *List, pk x.ParsedKey) error {
val, err := l.Value(rb.StartTs)
if err != nil {
return err
}

if val.Tid != types.VFloatID {
// Here, we convert the defaultID type vector into vfloat.
sv, err := types.Convert(val, types.VFloatID)
if err != nil {
return err
}
b := types.ValueForType(types.BinaryID)
if err = types.Marshal(sv, &b); err != nil {
return err
}

val.Value = b.Value
val.Tid = types.VFloatID
}

inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension {
return fmt.Errorf("vector dimension mismatch expected dimension %d but got %d", dimension, len(inVec))
}
count += 1
indexer.AddSeedVector(inVec)
if count == indexer.NumSeedVectors() {
return ErrStopIteration
}
return nil
},
StartKey: x.DataKey(rb.Attr, 0),
})
if err != nil {
return err
}
}

txns := make([]*Txn, indexer.NumThreads())
for i := range txns {
txns[i] = NewTxn(rb.StartTs)
}
caches := make([]tokIndex.CacheType, indexer.NumThreads())
for i := range caches {
caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
}

if count < indexer.NumSeedVectors() {
indexer.SetNumPasses(0)
}

for pass_idx := range indexer.NumBuildPasses() {
fmt.Println("Building pass", pass_idx)

indexer.StartBuild(caches)

builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
val, err := pl.Value(rb.StartTs)
if err != nil {
return []*pb.DirectedEdge{}, err
}

inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension {
return []*pb.DirectedEdge{}, nil
}
indexer.BuildInsert(ctx, uid, inVec)

Check failure on line 1586 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `indexer.BuildInsert` is not checked
return []*pb.DirectedEdge{}, nil
}

err := builder.RunWithoutTemp(ctx)
if err != nil {
return err
}

indexer.EndBuild()
}

centroids := indexer.GetCentroids()

if centroids != nil {
txn := NewTxn(rb.StartTs)

bCentroids, err := json.Marshal(centroids)
if err != nil {
return err
}

if err := addCentroidInDB(ctx, rb.Attr, bCentroids, txn); err != nil {
return err
}
txn.Update()
writer := NewTxnWriter(pstore)
if err := txn.CommitToDisk(writer, rb.StartTs); err != nil {
return err
}
}

numIndexPasses := indexer.NumIndexPasses()

if count < indexer.NumSeedVectors() {
numIndexPasses = 1
}

for pass_idx := range numIndexPasses {
fmt.Println("Indexing pass", pass_idx)

indexer.StartBuild(caches)

builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
val, err := pl.Value(rb.StartTs)
if err != nil {
return []*pb.DirectedEdge{}, err
}

inVec := types.BytesAsFloatArray(val.Value.([]byte))
if len(inVec) != dimension && centroids != nil {
if pass_idx == 0 {
glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec))
}
return []*pb.DirectedEdge{}, nil
}

indexer.BuildInsert(ctx, uid, inVec)

Check failure on line 1644 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `indexer.BuildInsert` is not checked

return []*pb.DirectedEdge{}, nil
}

err := builder.RunWithoutTemp(ctx)
if err != nil {
return err
}

for _, idx := range indexer.EndBuild() {
txns[idx].Update()
writer := NewTxnWriter(pstore)

x.ExponentialRetry(int(x.Config.MaxRetries),

Check failure on line 1658 in posting/index.go

View workflow job for this annotation

GitHub Actions / Trunk Check

golangci-lint2(errcheck)

[new] Error return value of `x.ExponentialRetry` is not checked
20*time.Millisecond, func() error {
err := txns[idx].CommitToDisk(writer, rb.StartTs)
if err == badger.ErrBannedKey {
glog.Errorf("Error while writing to banned namespace.")
return nil
}
return err
})

txns[idx].cache.plists = nil
txns[idx] = nil
}
}

return nil
}

func addCentroidInDB(ctx context.Context, attr string, vec []byte, txn *Txn) error {
indexCountAttr := hnsw.ConcatStrings(attr, kmeans.CentroidPrefix)
countKey := x.DataKey(indexCountAttr, 1)
pl, err := txn.Get(countKey)
if err != nil {
return err
}

edge := &pb.DirectedEdge{
Entity: 1,
Attr: indexCountAttr,
Value: vec,
ValueType: pb.Posting_ValType(12),
}
if err := pl.addMutation(ctx, txn, edge); err != nil {
return err
}
return nil
}

// rebuildTokIndex rebuilds index for a given attribute.
// We commit mutations with startTs and ignore the errors.
func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
Expand Down Expand Up @@ -1443,6 +1724,9 @@
}

runForVectors := (len(factorySpecs) != 0)
if runForVectors {
return rebuildVectorIndex(ctx, factorySpecs, rb)
}

pk := x.ParsedKey{Attr: rb.Attr}
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
Expand Down
Loading
Loading