Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ private QueryResults executeQuery(String expr, long offset, long limit, long ts,
QueryReq queryReq = QueryReq.builder()
.databaseName(queryIteratorReq.getDatabaseName())
.collectionName(queryIteratorReq.getCollectionName())
.clusterId(queryIteratorReq.getClusterId())
.partitionNames(queryIteratorReq.getPartitionNames())
.consistencyLevel(queryIteratorReq.getConsistencyLevel())
.outputFields(outputFields)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public class SearchIterator {
private Float filteredDistance = null;
private Map<String, Object> params;
private final RpcUtils rpcUtils;
private String clusterId = "";
private long sessionTs = 0;

public SearchIterator(SearchIteratorParam searchIteratorParam,
Expand Down Expand Up @@ -113,6 +114,7 @@ public SearchIterator(SearchIteratorReq searchIteratorReq,
this.expr = this.searchIteratorParam.getExpr();
this.topK = this.searchIteratorParam.getTopK();
this.rpcUtils = new RpcUtils();
this.clusterId = searchIteratorReq.getClusterId();

initParams();
checkForSpecialIndexParam();
Expand Down Expand Up @@ -292,6 +294,13 @@ private SearchResults executeSearch(Map<String, Object> params, String nextExpr,
.setKey(Constant.ITERATOR_FIELD)
.setValue(String.valueOf(Boolean.TRUE))
.build());
if (StringUtils.isNotEmpty(clusterId)) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.CLUSTER_ID)
.setValue(clusterId)
.build());
}
Comment thread
yhmo marked this conversation as resolved.

// pass the session ts to search interface
builder.setGuaranteeTimestamp(ts).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ private SearchResults executeSearch(int limit) {
.collectionName(searchIteratorReq.getCollectionName())
.partitionNames(searchIteratorReq.getPartitionNames())
.databaseName(searchIteratorReq.getDatabaseName())
.clusterId(searchIteratorReq.getClusterId())
.annsField(searchIteratorReq.getVectorFieldName())
.data(searchIteratorReq.getVectors())
.limit(limit)
Expand Down
1 change: 1 addition & 0 deletions sdk-core/src/main/java/io/milvus/param/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class Constant {
public static final String TIMEZONE = "timezone";
public static final String REDUCE_STOP_FOR_BEST = "reduce_stop_for_best";
public static final String ITERATOR_FIELD = "iterator";
public static final String CLUSTER_ID = "cluster_id";
public static final String GROUP_BY_FIELD = "group_by_field";
public static final String GROUP_SIZE = "group_size";
public static final String STRICT_GROUP_SIZE = "strict_group_size";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,13 @@ public String currentUsedDatabase() {
return dbName;
}

public MilvusClientV2Session session(String clusterId) {
if (StringUtils.isEmpty(clusterId)) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "clusterId cannot be null or empty");
}
return new MilvusClientV2Session(this, clusterId);
}


/////////////////////////////////////////////////////////////////////////////////////////////
// Database Operations
Expand Down
245 changes: 245 additions & 0 deletions sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2Session.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.milvus.v2.client;

import io.milvus.orm.iterator.QueryIterator;
import io.milvus.orm.iterator.SearchIterator;
import io.milvus.orm.iterator.SearchIteratorV2;
import io.milvus.v2.exception.ErrorCode;
import io.milvus.v2.exception.MilvusClientException;
import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.QueryResp;
import io.milvus.v2.service.vector.response.SearchResp;
import org.apache.commons.lang3.StringUtils;

public class MilvusClientV2Session {
private final MilvusClientV2 parent;
private final String clusterId;
private boolean closed = false;

MilvusClientV2Session(MilvusClientV2 parent, String clusterId) {
this.parent = parent;
this.clusterId = clusterId;
}

public SearchResp search(SearchReq request) {
ensureOpen();
return parent.search(copy(request));
}

public SearchResp hybridSearch(HybridSearchReq request) {
ensureOpen();
return parent.hybridSearch(copy(request));
}

public QueryResp query(QueryReq request) {
ensureOpen();
return parent.query(copy(request));
}

public QueryIterator queryIterator(QueryIteratorReq request) {
ensureOpen();
return parent.queryIterator(copy(request));
}

public SearchIterator searchIterator(SearchIteratorReq request) {
ensureOpen();
return parent.searchIterator(copy(request));
}

public SearchIteratorV2 searchIteratorV2(SearchIteratorReqV2 request) {
ensureOpen();
return parent.searchIteratorV2(copy(request));
}

public GetResp get(GetReq request) {
ensureOpen();
return parent.get(copy(request));
}

public void close() {
closed = true;
}

private void ensureOpen() {
if (closed) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "MilvusClient session is closed");
}
}

private void checkClusterId(String requestClusterId) {
if (StringUtils.isNotEmpty(requestClusterId) && !clusterId.equals(requestClusterId)) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "clusterId conflicts with session clusterId");
}
}

private SearchReq copy(SearchReq request) {
checkClusterId(request.getClusterId());
return SearchReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.annsField(request.getAnnsField())
.metricType(request.getMetricType())
.filter(request.getFilter())
.outputFields(request.getOutputFields())
.data(request.getData())
.ids(request.getIds())
.offset(request.getOffset())
.limit(request.getLimit())
.roundDecimal(request.getRoundDecimal())
.searchParams(request.getSearchParams())
.guaranteeTimestamp(request.getGuaranteeTimestamp())
.gracefulTime(request.getGracefulTime())
.consistencyLevel(request.getConsistencyLevel())
.ignoreGrowing(request.isIgnoreGrowing())
.timezone(request.getTimezone())
.groupByFieldName(request.getGroupByFieldName())
.groupSize(request.getGroupSize())
.strictGroupSize(request.getStrictGroupSize())
.ranker(request.getRanker())
.functionScore(request.getFunctionScore())
.filterTemplateValues(request.getFilterTemplateValues())
.highlighter(request.getHighlighter())
.build();
}

private HybridSearchReq copy(HybridSearchReq request) {
checkClusterId(request.getClusterId());
return HybridSearchReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.searchRequests(request.getSearchRequests())
.ranker(request.getRanker())
.functionScore(request.getFunctionScore())
.limit(request.getLimit())
.outFields(request.getOutFields())
.offset(request.getOffset())
.roundDecimal(request.getRoundDecimal())
.consistencyLevel(request.getConsistencyLevel())
.groupByFieldName(request.getGroupByFieldName())
.groupSize(request.getGroupSize())
.strictGroupSize(request.getStrictGroupSize())
.build();
}

private QueryReq copy(QueryReq request) {
checkClusterId(request.getClusterId());
return QueryReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.outputFields(request.getOutputFields())
.ids(request.getIds())
.filter(request.getFilter())
.consistencyLevel(request.getConsistencyLevel())
.offset(request.getOffset())
.limit(request.getLimit())
.ignoreGrowing(request.isIgnoreGrowing())
.timezone(request.getTimezone())
.queryParams(request.getQueryParams())
.filterTemplateValues(request.getFilterTemplateValues())
.build();
}

private QueryIteratorReq copy(QueryIteratorReq request) {
checkClusterId(request.getClusterId());
return QueryIteratorReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.outputFields(request.getOutputFields())
.expr(request.getExpr())
.consistencyLevel(request.getConsistencyLevel())
.offset(request.getOffset())
.limit(request.getLimit())
.ignoreGrowing(request.isIgnoreGrowing())
.timezone(request.getTimezone())
.batchSize(request.getBatchSize())
.reduceStopForBest(request.isReduceStopForBest())
.filterTemplateValues(request.getFilterTemplateValues())
.build();
}

private SearchIteratorReq copy(SearchIteratorReq request) {
checkClusterId(request.getClusterId());
return SearchIteratorReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.metricType(request.getMetricType())
.vectorFieldName(request.getVectorFieldName())
.limit(request.getLimit())
.expr(request.getExpr())
.outputFields(request.getOutputFields())
.vectors(request.getVectors())
.roundDecimal(request.getRoundDecimal())
.params(request.getParams())
.consistencyLevel(request.getConsistencyLevel())
.ignoreGrowing(request.isIgnoreGrowing())
.groupByFieldName(request.getGroupByFieldName())
.batchSize(request.getBatchSize())
.build();
}

private SearchIteratorReqV2 copy(SearchIteratorReqV2 request) {
checkClusterId(request.getClusterId());
return SearchIteratorReqV2.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionNames(request.getPartitionNames())
.metricType(request.getMetricType())
.vectorFieldName(request.getVectorFieldName())
.limit(request.getLimit())
.filter(request.getFilter())
.outputFields(request.getOutputFields())
.vectors(request.getVectors())
.roundDecimal(request.getRoundDecimal())
.searchParams(request.getSearchParams())
.consistencyLevel(request.getConsistencyLevel())
.ignoreGrowing(request.isIgnoreGrowing())
.timezone(request.getTimezone())
.groupByFieldName(request.getGroupByFieldName())
.batchSize(request.getBatchSize())
.externalFilterFunc(request.getExternalFilterFunc())
.filterTemplateValues(request.getFilterTemplateValues())
.build();
}

private GetReq copy(GetReq request) {
checkClusterId(request.getClusterId());
return GetReq.builder()
.databaseName(request.getDatabaseName())
.collectionName(request.getCollectionName())
.clusterId(clusterId)
.partitionName(request.getPartitionName())
.ids(request.getIds())
.outputFields(request.getOutputFields())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -350,11 +351,15 @@ public GetResp get(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, Get
String collectionName = request.getCollectionName();
String title = String.format("Get entities of collection: '%s' in database: '%s'", collectionName, dbName);
logger.debug(title);
QueryReq queryReq = QueryReq.builder()
QueryReq.QueryReqBuilder queryReqBuilder = QueryReq.builder()
.databaseName(dbName)
.collectionName(collectionName)
.ids(request.getIds())
.build();
.clusterId(request.getClusterId())
.ids(request.getIds());
if (StringUtils.isNotEmpty(request.getPartitionName())) {
queryReqBuilder.partitionNames(Collections.singletonList(request.getPartitionName()));
}
QueryReq queryReq = queryReqBuilder.build();
if (request.getOutputFields() != null) {
queryReq.setOutputFields(request.getOutputFields());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
public class GetReq {
private String databaseName;
private String collectionName;
private String clusterId;
private String partitionName = "";
private List<Object> ids;
private List<String> outputFields;

private GetReq(GetReqBuilder builder) {
this.databaseName = builder.databaseName;
this.collectionName = builder.collectionName;
this.clusterId = builder.clusterId;
this.partitionName = builder.partitionName;
this.ids = builder.ids;
this.outputFields = builder.outputFields;
Expand All @@ -56,6 +58,14 @@ public void setCollectionName(String collectionName) {
this.collectionName = collectionName;
}

public String getClusterId() {
return clusterId;
}

public void setClusterId(String clusterId) {
this.clusterId = clusterId;
}

public String getPartitionName() {
return partitionName;
}
Expand Down Expand Up @@ -85,6 +95,7 @@ public String toString() {
return "GetReq{" +
"databaseName='" + databaseName + '\'' +
", collectionName='" + collectionName + '\'' +
", clusterId='" + clusterId + '\'' +
", partitionName='" + partitionName + '\'' +
", ids=" + ids +
", outputFields=" + outputFields +
Expand All @@ -94,6 +105,7 @@ public String toString() {
public static class GetReqBuilder {
private String databaseName;
private String collectionName;
private String clusterId;
private String partitionName = "";
private List<Object> ids;
private List<String> outputFields;
Expand All @@ -108,6 +120,11 @@ public GetReqBuilder collectionName(String collectionName) {
return this;
}

public GetReqBuilder clusterId(String clusterId) {
this.clusterId = clusterId;
return this;
}

public GetReqBuilder partitionName(String partitionName) {
this.partitionName = partitionName;
return this;
Expand Down
Loading
Loading