Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions milvus/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
dependencies {
api(project(":core"))
api(project(":linq4j"))

implementation("io.milvus:milvus-sdk-java:2.5.13")

testImplementation(platform("org.junit:junit-bom:5.10.0"))
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.testcontainers:testcontainers")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.milvus.convention;

import org.apache.calcite.adapter.milvus.factory.MilvusTranslatableTable;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;

import java.util.List;
import java.util.Map;

/**
* Relational expression that uses Milvus calling convention.
*/
public interface MilvusRel extends RelNode {
void implement(Implementor implementor);

Convention CONVENTION = new Convention.Impl("MILVUS", MilvusRel.class);

/**
* Implementor for Milvus relational expressions.
*/
class Implementor {
public final RexBuilder rexBuilder;

// scan
public RelOptTable table;
public MilvusTranslatableTable milvusTable;
public RelDataType rowType;
public Map<String, String> milvusOptions;

// filter
public RexNode filterCondition;

// project
public RelDataType projectRowType;
public List<RexNode> projects;

// vector search
public RexNode vectorDistanceExpr;
public Integer vectorDistanceFieldIndex;
public RelFieldCollation.Direction sortOrder;
public RexNode limit;

public Implementor(RexBuilder rexBuilder) {
this.rexBuilder = rexBuilder;
}

public void visitChild(int ordinal, RelNode input) {
assert ordinal == 0;

RelNode node = findMilvusRel(input);

if (!(node instanceof MilvusRel)) {
throw new IllegalStateException(
"Expected MilvusRel input but got "
+ (node == null ? "null" : node.getClass().getName())
+ " (original=" + input.getClass().getName() + ")");
}
((MilvusRel) node).implement(this);
}
}

static RelNode findMilvusRel(RelNode input) {
RelNode node = input;
if (node instanceof RelSubset) {
final RelSubset subset = (RelSubset) node;
RelNode best = subset.getBest();
if (best != null) {
node = best;
} else {
// find first MilvusRel in the subset
for (RelNode r : subset.getRelList()) {
if (r instanceof MilvusRel) {
node = r;
break;
}
}
}
}
return node;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.milvus.convention;

import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.milvus.factory.MilvusTranslatableTable;
import org.apache.calcite.adapter.milvus.operation.MilvusProjectExpression;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.plan.ConventionTraitDef;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterImpl;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Pair;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;

/**
* MilvusToEnumerableConverter converts a relational expression
* from Milvus calling convention to Enumerable calling convention.
*/
public class MilvusToEnumerableConverter
extends ConverterImpl
implements EnumerableRel {
protected MilvusToEnumerableConverter(
RelOptCluster cluster,
RelTraitSet traits,
RelNode input) {
super(cluster, ConventionTraitDef.INSTANCE, traits, input);
}

@Override public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
return new MilvusToEnumerableConverter(
getCluster(), traitSet, sole(inputs));
}

@Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
RelMetadataQuery mq) {
return super.computeSelfCost(planner, mq).multiplyBy(.1);
}

@Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
final BlockBuilder list = new BlockBuilder();
final MilvusRel.Implementor milvusImplementor =
new MilvusRel.Implementor(getCluster().getRexBuilder());
milvusImplementor.visitChild(0, getInput());

final Expression root = implementor.getRootExpression();
final Expression schema =
Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method);

// scan
final List<String> qualifiedTableName = milvusImplementor.table.getQualifiedName();
final Expression table = getScanInfo(qualifiedTableName, schema);
final Expression tableExpr =
Expressions.convert_(table, MilvusTranslatableTable.class);
// project
final RelDataType rowType = milvusImplementor.projectRowType != null
? milvusImplementor.projectRowType
: getRowType();

final PhysType physType =
PhysTypeImpl.of(
implementor.getTypeFactory(), rowType,
pref.prefer(JavaRowFormat.ARRAY));

final List<RexNode> projects = milvusImplementor.projects;
List<Pair<Integer, MilvusProjectExpression>>
projectInfo = getProjectInfo(projects, rowType, milvusImplementor.rowType, physType);
final Expression projectInfoExpr =
list.append("projectRowTypeMapForEnumerator", expressionForProjectPairs(projectInfo));

Expression enumerable =
list.append(
"enumerable", Expressions.call(tableExpr,
"scan",
Expressions.constant(""),
projectInfoExpr));

list.add(Expressions.return_(null, enumerable));
return implementor.result(physType, list.toBlock());
}

private static Expression getScanInfo(List<String> qualifiedName,
Expression schema) {
final String schemaName = qualifiedName.size() > 1 ? qualifiedName.get(0) : null;
final String tableName = qualifiedName.get(qualifiedName.size() - 1);

Expression current = schema;

if (schemaName != null) {
current =
Expressions.call(current, BuiltInMethod.SCHEMA_GET_SUB_SCHEMA.method,
Expressions.constant(schemaName));
current = Expressions.convert_(current, Schema.class);
}

return Expressions.call(current,
BuiltInMethod.SCHEMA_GET_TABLE.method,
Expressions.constant(tableName));
}

private static List<Pair<Integer, MilvusProjectExpression>> getProjectInfo(List<RexNode> projects,
RelDataType rowType, RelDataType inputRowType, PhysType physType) {
List<Pair<Integer, MilvusProjectExpression>> projectInfo = new ArrayList<>();
if (projects != null) {
for (int i = 0; i < projects.size(); i++) {
RexNode project = projects.get(i);
Class<?> fieldClass = physType.fieldClass(i);
MilvusProjectExpression expr;

if (project instanceof RexInputRef) {
int inputIndex = ((RexInputRef) project).getIndex();
String originalFieldName = inputRowType.getFieldNames().get(inputIndex);
expr = new MilvusProjectExpression.InputField(originalFieldName, fieldClass);
} else if (project instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) project;
expr = new MilvusProjectExpression.Constant(fieldClass, literal.getValue3());
} else {
throw new UnsupportedOperationException("Unsupported project type");
}
projectInfo.add(Pair.of(i, expr));
}
} else {
List<String> inputFields = rowType.getFieldNames();
for (int i = 0; i < inputFields.size(); i++) {
String fieldName = inputFields.get(i);
Class<?> fieldClass = physType.fieldClass(i);
projectInfo.add(
Pair.of(i,
new MilvusProjectExpression.InputField(fieldName, fieldClass)));
}
}
return projectInfo;
}

private Expression expressionForProjectExpression(MilvusProjectExpression expr) {
if (expr instanceof MilvusProjectExpression.InputField) {
String fieldName = ((MilvusProjectExpression.InputField) expr).getFieldName();
return Expressions.new_(MilvusProjectExpression.InputField.class,
Expressions.constant(fieldName),
Expressions.constant(expr.getClazz(), Class.class));
} else if (expr instanceof MilvusProjectExpression.Constant) {
Object value = ((MilvusProjectExpression.Constant) expr).getValue();
return Expressions.new_(MilvusProjectExpression.Constant.class,
Expressions.constant(expr.getClazz(), Class.class),
Expressions.constant(value));
} else if (expr instanceof MilvusProjectExpression.VectorScore) {
return Expressions.new_(MilvusProjectExpression.VectorScore.class,
Expressions.constant(expr.getClazz(), Class.class));
} else {
throw new AssertionError("Unknown expression type: " + expr);
}
}

private Expression expressionForProjectPairs(List<Pair<Integer, MilvusProjectExpression>> pairs) {
List<Expression> pairExpressions = new ArrayList<>();

for (Pair<Integer, MilvusProjectExpression> pair : pairs) {
Expression first = Expressions.constant(pair.left, Integer.class);
Expression second = expressionForProjectExpression(pair.right);
Type pairType = Types.of(Pair.class, Integer.class, MilvusProjectExpression.class);
Expression pairExpr = Expressions.new_(pairType, first, second);
pairExpressions.add(pairExpr);
}
return Expressions.call(BuiltInMethod.ARRAYS_AS_LIST.method,
Expressions.newArrayInit(Pair.class, pairExpressions));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.milvus.convention;

import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterRule;

/**
* Rule to convert a relational expression from
* {@link MilvusRel#CONVENTION Milvus calling convention} to
* {@link org.apache.calcite.adapter.enumerable.EnumerableConvention Enumerable calling convention}.
*/
public class MilvusToEnumerableConverterRule extends ConverterRule {
public static final MilvusToEnumerableConverterRule INSTANCE =
Config.INSTANCE
.withConversion(MilvusRel.class, MilvusRel.CONVENTION,
EnumerableConvention.INSTANCE, "MilvusToEnumerableConverterRule")
.withRuleFactory(MilvusToEnumerableConverterRule::new)
.toRule(MilvusToEnumerableConverterRule.class);

protected MilvusToEnumerableConverterRule(Config config) {
super(config);
}

@Override public RelNode convert(RelNode relNode) {
final RelTraitSet traitSet = relNode.getTraitSet()
.replace(getOutTrait());
return new MilvusToEnumerableConverter(relNode.getCluster(), traitSet,
relNode);
}

@Override public void onMatch(RelOptRuleCall call) {
final RelNode rel = call.rel(0);
if (rel.getConvention() == getOutTrait()) {
return;
}
final RelNode converted = convert(rel);
if (converted != null) {
call.transformTo(converted);
}
}
}
Loading
Loading