Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<phase>deploy</phase>
<goals>
<goal>sign</goal>
</goals>
Expand Down
26 changes: 25 additions & 1 deletion src/it/java/io/weaviate/integration/CollectionsITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.weaviate.client6.v1.api.collections.DataType;
import io.weaviate.client6.v1.api.collections.InvertedIndex;
import io.weaviate.client6.v1.api.collections.Property;
import io.weaviate.client6.v1.api.collections.Quantization;
import io.weaviate.client6.v1.api.collections.ReferenceProperty;
import io.weaviate.client6.v1.api.collections.Replication;
import io.weaviate.client6.v1.api.collections.VectorConfig;
Expand Down Expand Up @@ -194,7 +195,7 @@ public void testInvalidCollectionName() throws IOException {
}

@Test
public void testNestedProperties() throws IOException, Exception {
public void testNestedProperties() throws IOException {
var nsBuildings = ns("Buildings");

client.collections.create(
Expand Down Expand Up @@ -227,4 +228,27 @@ public void testNestedProperties() throws IOException, Exception {
.extracting(Property::dataTypes).extracting(types -> types.get(0))
.containsExactly(DataType.INT, DataType.NUMBER);
}

@Test
public void test_updateQuantization() throws IOException {
// Arrange
var nsThings = ns("Things");

var things = client.collections.create(nsThings,
c -> c.vectorConfig(VectorConfig.selfProvided(
self -> self.quantization(Quantization.uncompressed()))));

// Act
things.config.update(
c -> c.vectorConfig(VectorConfig.selfProvided(
self -> self.quantization(Quantization.bq()))));

// Assert
var config = things.config.get();
Assertions.assertThat(config).get()
.extracting(CollectionConfig::vectors)
.extracting("default", InstanceOfAssertFactories.type(VectorConfig.class))
.extracting(VectorConfig::quantization)
.returns(Quantization.Kind.BQ, Quantization::_kind);
}
}
2 changes: 1 addition & 1 deletion src/it/java/io/weaviate/integration/DataITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import io.weaviate.client6.v1.api.collections.data.BatchReference;
import io.weaviate.client6.v1.api.collections.data.DeleteManyResponse;
import io.weaviate.client6.v1.api.collections.data.Reference;
import io.weaviate.client6.v1.api.collections.query.Filter;
import io.weaviate.client6.v1.api.collections.query.Metadata;
import io.weaviate.client6.v1.api.collections.query.Metadata.MetadataField;
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
import io.weaviate.client6.v1.api.collections.query.QueryReference;
import io.weaviate.client6.v1.api.collections.query.Filter;
import io.weaviate.containers.Container;

public class DataITest extends ConcurrentTest {
Expand Down
34 changes: 32 additions & 2 deletions src/it/java/io/weaviate/integration/SearchITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import io.weaviate.ConcurrentTest;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.api.WeaviateClient;
import io.weaviate.client6.v1.api.collections.Generative;
import io.weaviate.client6.v1.api.collections.ObjectMetadata;
import io.weaviate.client6.v1.api.collections.Property;
import io.weaviate.client6.v1.api.collections.ReferenceProperty;
import io.weaviate.client6.v1.api.collections.Reranker;
import io.weaviate.client6.v1.api.collections.VectorConfig;
import io.weaviate.client6.v1.api.collections.Vectors;
import io.weaviate.client6.v1.api.collections.WeaviateMetadata;
Expand All @@ -37,8 +39,10 @@
import io.weaviate.client6.v1.api.collections.query.Metadata;
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
import io.weaviate.client6.v1.api.collections.query.Rerank;
import io.weaviate.client6.v1.api.collections.query.SortBy;
import io.weaviate.client6.v1.api.collections.query.Target;
import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker;
import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw;
import io.weaviate.client6.v1.api.collections.vectorindex.MultiVector;
import io.weaviate.containers.Container;
Expand All @@ -52,7 +56,7 @@ public class SearchITest extends ConcurrentTest {
Weaviate.custom()
.withModel2VecUrl(Model2Vec.URL)
.withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE)
.addModules("generative-dummy")
.addModules(Generative.Kind.DUMMY.jsonValue(), Reranker.Kind.DUMMY.jsonValue())
.build(),
Container.IMG2VEC_NEURAL,
Container.MODEL2VEC);
Expand Down Expand Up @@ -741,7 +745,7 @@ public void teset_filterPropertyLength() throws IOException {
// Assertions
Assertions.assertThat(got.objects()).hasSize(2);
}

/**
* Ensure the client respects server's configuration for max gRPC size:
* we create a server with 1-byte message size and try to send a large payload
Expand All @@ -768,4 +772,30 @@ public void test_maxGrpcMessageSize() throws Exception {
}).isInstanceOf(io.grpc.StatusRuntimeException.class);
}
}

@Test
public void test_rerankQueries() throws IOException {
// Arrange
var nsThigns = ns("Things");

var things = client.collections.create(nsThigns,
c -> c
.properties(Property.text("title"), Property.integer("price"))
.vectorConfig(VectorConfig.text2vecModel2Vec(
t2v -> t2v.sourceProperties("title", "price")))
.rerankerModules(new DummyReranker()));

things.data.insertMany(
Map.of("title", "Ergonomic chair", "price", 269),
Map.of("title", "Height-adjustable desk", "price", 349));

// Act
var got = things.query.nearText(
"office supplies",
nt -> nt.rerank(Rerank.by("price",
rank -> rank.query("cheaper first"))));

// Assert: ranking not important really, just that the request was valid.
Assertions.assertThat(got.objects()).hasSize(2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -276,19 +276,23 @@ public void write(JsonWriter out, CollectionConfig value) throws IOException {
// Reranker and Generative module configs belong to the "moduleConfig".
var rerankerModules = jsonObject.remove("rerankerModules").getAsJsonArray();
var generativeModule = jsonObject.remove("generativeModule");
if (!rerankerModules.isEmpty() || !generativeModule.isJsonNull()) {
var modules = new JsonObject();

var modules = new JsonObject();
if (!rerankerModules.isEmpty()) {
// Copy configuration for each reranker module.
rerankerModules.forEach(reranker -> {
reranker.getAsJsonObject().entrySet()
.stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue()));
});
}

if (!generativeModule.isJsonNull()) {
// Copy configuration for each generative module.
generativeModule.getAsJsonObject().entrySet()
.stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue()));
}

if (!modules.isEmpty()) {
jsonObject.add("moduleConfig", modules);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import io.weaviate.client6.v1.api.collections.quantizers.SQ;
import io.weaviate.client6.v1.api.collections.quantizers.Uncompressed;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.TaggedUnion;
import io.weaviate.client6.v1.internal.json.JsonEnum;

public interface Quantization {
public interface Quantization extends TaggedUnion<Quantization.Kind, Object> {

public enum Kind implements JsonEnum<Kind> {
UNCOMPRESSED("skipDefaultQuantization"),
Expand Down Expand Up @@ -112,6 +113,46 @@ public static Quantization rq(Function<RQ.Builder, ObjectBuilder<RQ>> fn) {
return RQ.of(fn);
}

default BQ asBQ() {
return _as(Quantization.Kind.BQ);
}

default RQ asRQ() {
return _as(Quantization.Kind.RQ);
}

default PQ asPQ() {
return _as(Quantization.Kind.PQ);
}

default SQ asSQ() {
return _as(Quantization.Kind.SQ);
}

default Uncompressed asUncompressed() {
return _as(Quantization.Kind.UNCOMPRESSED);
}

default boolean isBQ() {
return _is(Quantization.Kind.BQ);
}

default boolean isRQ() {
return _is(Quantization.Kind.RQ);
}

default boolean isPQ() {
return _is(Quantization.Kind.PQ);
}

default boolean isSQ() {
return _is(Quantization.Kind.SQ);
}

default boolean isUncompressed() {
return _is(Quantization.Kind.UNCOMPRESSED);
}

public static enum CustomTypeAdapterFactory implements TypeAdapterFactory {
INSTANCE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.google.gson.stream.JsonWriter;

import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker;
import io.weaviate.client6.v1.api.collections.rerankers.DummyReranker;
import io.weaviate.client6.v1.api.collections.rerankers.JinaAiReranker;
import io.weaviate.client6.v1.api.collections.rerankers.NvidiaReranker;
import io.weaviate.client6.v1.api.collections.rerankers.TransformersReranker;
Expand All @@ -24,6 +25,7 @@

public interface Reranker extends TaggedUnion<Reranker.Kind, Object> {
public enum Kind implements JsonEnum<Kind> {
DUMMY("reranker-dummy"),
JINAAI("reranker-jinaai"),
VOYAGEAI("reranker-voyageai"),
NVIDIA("reranker-nvidia"),
Expand Down Expand Up @@ -120,6 +122,11 @@ private final void addAdapter(Gson gson, Reranker.Kind kind, Class<? extends Rer

private final void init(Gson gson) {
addAdapter(gson, Reranker.Kind.COHERE, CohereReranker.class);
addAdapter(gson, Reranker.Kind.JINAAI, JinaAiReranker.class);
addAdapter(gson, Reranker.Kind.NVIDIA, NvidiaReranker.class);
addAdapter(gson, Reranker.Kind.TRANSFORMERS, TransformersReranker.class);
addAdapter(gson, Reranker.Kind.VOYAGEAI, VoyageAiReranker.class);
addAdapter(gson, Reranker.Kind.DUMMY, DummyReranker.class);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1606,15 +1606,18 @@ public VectorConfig read(JsonReader in) throws IOException {
var vectorIndexConfig = jsonObject.get("vectorIndexConfig").getAsJsonObject();

String quantizationKind = null;
if (vectorIndexConfig.has(Quantization.Kind.BQ.jsonValue())) {
quantizationKind = Quantization.Kind.BQ.jsonValue();
} else if (vectorIndexConfig.has(Quantization.Kind.PQ.jsonValue())) {
quantizationKind = Quantization.Kind.PQ.jsonValue();
} else if (vectorIndexConfig.has(Quantization.Kind.SQ.jsonValue())) {
quantizationKind = Quantization.Kind.SQ.jsonValue();
} else if (vectorIndexConfig.has(Quantization.Kind.RQ.jsonValue())) {
quantizationKind = Quantization.Kind.RQ.jsonValue();
} else {
for (var kind : new String[] {
Quantization.Kind.BQ.jsonValue(),
Quantization.Kind.PQ.jsonValue(),
Quantization.Kind.SQ.jsonValue(),
Quantization.Kind.RQ.jsonValue() }) {
if (vectorIndexConfig.has(kind)
&& vectorIndexConfig.get(kind).getAsJsonObject().get("enabled").getAsBoolean()) {
quantizationKind = kind;
}
}
if (quantizationKind == null && vectorIndexConfig.has(Quantization.Kind.UNCOMPRESSED.jsonValue())
&& vectorIndexConfig.get(Quantization.Kind.UNCOMPRESSED.jsonValue()).getAsBoolean()) {
quantizationKind = Quantization.Kind.UNCOMPRESSED.jsonValue();
}

Expand Down Expand Up @@ -1649,7 +1652,7 @@ public VectorConfig read(JsonReader in) throws IOException {
// Each individual vectorizer has a `Quantization quantization` field.
// We need to specify the kind in order for
// Quantization.CustomTypeAdapterFactory to be able to find the right adapter.
if (vectorIndexConfig.has(quantizationKind)) {
if (quantizationKind != null && vectorIndexConfig.has(quantizationKind)) {
JsonObject quantization = new JsonObject();
quantization.add(quantizationKind, vectorIndexConfig.get(quantizationKind));
concreteVectorizer.add("quantization", quantization);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import java.util.List;
import java.util.function.Function;

import io.weaviate.client6.v1.api.collections.query.Filter;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase;

public record Aggregation(
AggregateObjectFilter filter,
Filter whereFilter,
Integer objectLimit,
boolean includeTotalCount,
List<PropertyAggregation> returnMetrics) {
Expand All @@ -29,6 +32,7 @@ public static Aggregation of(AggregateObjectFilter objectFilter, Function<Builde
public Aggregation(Builder builder) {
this(
builder.objectFilter,
builder.whereFilter,
builder.objectLimit,
builder.includeTotalCount,
builder.metrics);
Expand All @@ -41,6 +45,7 @@ public Builder(AggregateObjectFilter objectFilter) {
this.objectFilter = objectFilter;
}

private Filter whereFilter;
private List<PropertyAggregation> metrics = new ArrayList<>();
private Integer objectLimit;
private boolean includeTotalCount = false;
Expand All @@ -55,6 +60,24 @@ public final Builder includeTotalCount(boolean include) {
return this;
}

/**
* Filter result set using traditional filtering operators: {@code eq},
* {@code gte}, {@code like}, etc.
* Subsequent calls to {@link #filter} aggregate with an AND operator.
*/
public final Builder filters(Filter filter) {
this.whereFilter = this.whereFilter == null
? filter
: Filter.and(this.whereFilter, filter);
return this;
}

/** Combine several conditions using with an AND operator. */
public final Builder filters(Filter... filters) {
Arrays.stream(filters).map(this::filters);
return this;
}

@SafeVarargs
public final Builder metrics(PropertyAggregation... metrics) {
this.metrics = Arrays.asList(metrics);
Expand All @@ -80,6 +103,12 @@ public void appendTo(WeaviateProtoAggregate.AggregateRequest.Builder req) {
req.setObjectLimit(objectLimit);
}

if (whereFilter != null) {
var protoFilters = WeaviateProtoBase.Filters.newBuilder();
whereFilter.appendTo(protoFilters);
req.setFilters(protoFilters);
}

for (final var metric : returnMetrics) {
var aggregation = WeaviateProtoAggregate.AggregateRequest.Aggregation.newBuilder();
metric.appendTo(aggregation);
Expand Down
Loading