Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions cmd/migrate_from_qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"os/signal"
"runtime"
"sort"
"strings"
"sync"
"syscall"
"time"

Expand Down Expand Up @@ -151,24 +149,38 @@ func (r *MigrateFromQdrantCmd) prepareTargetCollection(ctx context.Context, sour
fmt.Print("\n")
pterm.Info.Printfln("Target collection '%s' already exists. Skipping creation.", targetCollection)
} else {
params := sourceCollectionInfo.Config.GetParams()
config := sourceCollectionInfo.GetConfig()
params := config.GetParams()
if err := targetClient.CreateCollection(ctx, &qdrant.CreateCollection{
CollectionName: targetCollection,
HnswConfig: sourceCollectionInfo.Config.GetHnswConfig(),
WalConfig: sourceCollectionInfo.Config.GetWalConfig(),
OptimizersConfig: sourceCollectionInfo.Config.GetOptimizerConfig(),
ShardNumber: &params.ShardNumber,
OnDiskPayload: &params.OnDiskPayload,
VectorsConfig: params.VectorsConfig,
ReplicationFactor: params.ReplicationFactor,
WriteConsistencyFactor: params.WriteConsistencyFactor,
QuantizationConfig: sourceCollectionInfo.Config.GetQuantizationConfig(),
ShardingMethod: params.ShardingMethod,
SparseVectorsConfig: params.SparseVectorsConfig,
StrictModeConfig: sourceCollectionInfo.Config.GetStrictModeConfig(),
HnswConfig: config.GetHnswConfig(),
WalConfig: config.GetWalConfig(),
OptimizersConfig: config.GetOptimizerConfig(),
ShardNumber: qdrant.PtrOf(params.GetShardNumber()),
OnDiskPayload: qdrant.PtrOf(params.GetOnDiskPayload()),
VectorsConfig: params.GetVectorsConfig(),
ReplicationFactor: qdrant.PtrOf(params.GetReplicationFactor()),
WriteConsistencyFactor: qdrant.PtrOf(params.GetWriteConsistencyFactor()),
QuantizationConfig: config.GetQuantizationConfig(),
ShardingMethod: params.GetShardingMethod().Enum(),
SparseVectorsConfig: params.GetSparseVectorsConfig(),
StrictModeConfig: config.GetStrictModeConfig(),
Metadata: config.GetMetadata(),
}); err != nil {
return fmt.Errorf("failed to create target collection: %w", err)
}

if params.GetShardingMethod() == qdrant.ShardingMethod_Custom {
shardKeys, err := sourceClient.ListShardKeys(ctx, sourceCollection)
if err != nil {
return fmt.Errorf("failed to list shard keys: %w", err)
}
for _, shardKey := range shardKeys {
if err := targetClient.CreateShardKey(ctx, targetCollection, &qdrant.CreateShardKey{ShardKey: shardKey.GetKey()}); err != nil {
return fmt.Errorf("failed to create shard key: %w", err)
}
}
}
}
}

Expand Down Expand Up @@ -322,7 +334,7 @@ func (r *MigrateFromQdrantCmd) samplePointIDs(ctx context.Context, client *qdran

// processBatch handles the upserting of a batch of points to the target collection.
// It groups points by shard key and routes each group with a matching shard key selector
// (shard keys are expected to already exist on the target collection), and retries on transient errors.
func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdrant.RetrievedPoint, targetClient *qdrant.Client, targetCollection string, shardKeys *sync.Map, wait bool) error {
func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdrant.RetrievedPoint, targetClient *qdrant.Client, targetCollection string, wait bool) error {
// Group points by their shard key.
byShardKey := make(map[string][]*qdrant.PointStruct)
shardKeyObjs := make(map[string]*qdrant.ShardKey)
Expand Down Expand Up @@ -353,14 +365,6 @@ func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdran
Wait: qdrant.PtrOf(wait),
}
if key != "" {
// If the shard key is new, create it on the target collection.
if _, ok := shardKeys.Load(key); !ok {
err := targetClient.CreateShardKey(ctx, targetCollection, &qdrant.CreateShardKey{ShardKey: shardKeyObjs[key]})
if err != nil && !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("failed to create shard key %s: %w", key, err)
}
shardKeys.Store(key, true)
}
// Specify the shard key for the upsert request.
req.ShardKeySelector = &qdrant.ShardKeySelector{ShardKeys: []*qdrant.ShardKey{shardKeyObjs[key]}}
}
Expand Down Expand Up @@ -395,7 +399,6 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source

bar, _ := pterm.DefaultProgressbar.WithTotal(int(sourcePointCount)).Start()
displayMigrationProgress(bar, count)
shardKeys := &sync.Map{}

for {
// Scroll through points from the source collection in batches.
Expand All @@ -411,7 +414,7 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source
}

points := resp.GetResult()
if err := r.processBatch(ctx, points, targetClient, targetCollection, shardKeys, true); err != nil {
if err := r.processBatch(ctx, points, targetClient, targetCollection, true); err != nil {
return err
}

Expand Down Expand Up @@ -478,15 +481,14 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl
displayMigrationProgress(bar, totalProcessed)

// Use a semaphore to limit the number of concurrent workers.
shardKeys := &sync.Map{}
errs := make(chan error, len(ranges))
sem := make(chan struct{}, r.NumWorkers)

// Start a goroutine for each range.
for _, rg := range ranges {
sem <- struct{}{}
go func(rg rangeSpec) {
errs <- r.migrateRange(ctx, sourceCollection, targetCollection, sourceClient, targetClient, rg, shardKeys, bar)
errs <- r.migrateRange(ctx, sourceCollection, targetCollection, sourceClient, targetClient, rg, bar)
<-sem
}(rg)
}
Expand All @@ -507,7 +509,7 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl

// migrateRange is the function executed by each worker in parallel migration.
// It scrolls through a specific range of points and upserts them to the target.
func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollection, targetCollection string, sourceClient, targetClient *qdrant.Client, rg rangeSpec, shardKeys *sync.Map, bar *pterm.ProgressbarPrinter) error {
func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollection, targetCollection string, sourceClient, targetClient *qdrant.Client, rg rangeSpec, bar *pterm.ProgressbarPrinter) error {
offsetKey := fmt.Sprintf("%s-workers-%d-range-%d", sourceCollection, r.NumWorkers, rg.id)
offset := rg.start
var count uint64
Expand Down Expand Up @@ -542,7 +544,7 @@ func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollectio
}
}

if err := r.processBatch(ctx, points, targetClient, targetCollection, shardKeys, false); err != nil {
if err := r.processBatch(ctx, points, targetClient, targetCollection, false); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
//nolint:unparam
func qdrantContainer(ctx context.Context, t *testing.T, apiKey string) testcontainers.Container {
req := testcontainers.ContainerRequest{
Image: "qdrant/qdrant:v1.16.1",
Image: "qdrant/qdrant:v1.17.0",
ExposedPorts: []string{"6333/tcp", "6334/tcp"},
Env: map[string]string{
"QDRANT__CLUSTER__ENABLED": "true",
Expand Down