Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
<gson.version>2.13.2</gson.version>
<httpclient.version>5.5.1</httpclient.version>
<lang3.version>3.20.0</lang3.version>
<junit.version>5.13.4</junit.version>
<testcontainers.version>1.21.3</testcontainers.version>
<junit.version>4.13.2</junit.version>
<testcontainers.version>2.0.2</testcontainers.version>
<assertj-core.version>3.27.6</assertj-core.version>
<jparams.version>1.0.4</jparams.version>
<mockito.version>5.20.0</mockito.version>
Expand Down Expand Up @@ -134,13 +134,13 @@
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>weaviate</artifactId>
<artifactId>testcontainers-weaviate</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>minio</artifactId>
<artifactId>testcontainers-minio</artifactId>
<version>${testcontainers.version}</version>
<scope>test</scope>
</dependency>
Expand All @@ -150,6 +150,12 @@
<version>${assertj-core.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.jparams</groupId>
<artifactId>jparams-junit4</artifactId>
Expand Down
81 changes: 81 additions & 0 deletions src/it/java/io/weaviate/integration/VectorizersITest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.weaviate.integration;

import java.io.IOException;
import java.util.Map;

import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TestRule;

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.v1.api.WeaviateClient;
import io.weaviate.client6.v1.api.collections.Property;
import io.weaviate.client6.v1.api.collections.VectorConfig;
import io.weaviate.client6.v1.api.collections.WeaviateObject;
import io.weaviate.client6.v1.api.collections.query.FetchObjectById;
import io.weaviate.containers.Container;
import io.weaviate.containers.Model2Vec;
import io.weaviate.containers.Weaviate;

import static org.assertj.core.api.Assertions.assertThat;

public class VectorizersITest extends ConcurrentTest {
private static final Container.ContainerGroup compose = Container.compose(
Weaviate.custom()
.withModel2VecUrl(Model2Vec.URL)
.build(),
Container.MODEL2VEC);
@ClassRule // Bind containers to the lifetime of the test
public static final TestRule _rule = compose.asTestRule();
private static final WeaviateClient client = compose.getClient();

@Test
public void testVectorizerModel2VecPropeties() throws IOException {
var collectionName = ns("Model2Vec2NamedVectors");
client.collections.create(collectionName,
col -> col
.properties(Property.text("name"), Property.text("author"))
.vectorConfig(
VectorConfig.text2vecModel2Vec("name", v -> v.sourceProperties("name")),
VectorConfig.text2vecModel2Vec("author", v -> v.sourceProperties("author"))
)
);

var model2vec = client.collections.use(collectionName);
assertThat(model2vec).isNotNull();

String uuid1 = "00000000-0000-0000-0000-000000000001";
WeaviateObject<Map<String, Object>> obj1 = WeaviateObject.of(o ->
o.properties(Map.of("name", "Dune", "author", "Frank Herbert")).uuid(uuid1)
);
String uuid2 = "00000000-0000-0000-0000-000000000002";
WeaviateObject<Map<String, Object>> obj2 = WeaviateObject.of(o ->
o.properties(Map.of("name", "same content", "author", "same content")).uuid(uuid2)
);

var resp = model2vec.data.insertMany(obj1, obj2);
assertThat(resp).isNotNull().satisfies(s -> {
assertThat(s.errors()).isEmpty();
});

var o1 = model2vec.query.fetchObjectById(uuid1, FetchObjectById.Builder::includeVector);
// Assert that for object1 we have generated 2 different vectors
assertThat(o1).get()
.extracting(WeaviateObject::vectors)
.satisfies(v -> {
assertThat(v.getSingle("name")).isNotEmpty();
assertThat(v.getSingle("author")).isNotEmpty();
assertThat(v.getSingle("name")).isNotEqualTo(v.getSingle("author"));
});

var o2 = model2vec.query.fetchObjectById(uuid2, FetchObjectById.Builder::includeVector);
// Assert that for object2 we have generated same vectors
assertThat(o2).get()
.extracting(WeaviateObject::vectors)
.satisfies(v -> {
assertThat(v.getSingle("name")).isNotEmpty();
assertThat(v.getSingle("author")).isNotEmpty();
assertThat(v.getSingle("name")).isEqualTo(v.getSingle("author"));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ public static Map.Entry<String, VectorConfig> multi2vecCohere(String vectorName,
*
* @param location Geographic region the Google Cloud model runs in.
*/
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String location) {
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, location);
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String projectId, String location) {
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, projectId, location);
}

/**
Expand All @@ -364,9 +364,10 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(String location) {
* @param fn Lambda expression for optional parameters.
*/
public static Map.Entry<String, VectorConfig> multi2vecGoogle(
String projectId,
String location,
Function<Multi2VecGoogleVectorizer.Builder, ObjectBuilder<Multi2VecGoogleVectorizer>> fn) {
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, location, fn);
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, projectId, location, fn);
}

/**
Expand All @@ -375,8 +376,8 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(
* @param vectorName Vector name.
* @param location Geographic region the Google Cloud model runs in.
*/
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName, String location) {
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(location));
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName, String projectId, String location) {
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(projectId, location));
}

/**
Expand All @@ -387,9 +388,9 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName,
* @param fn Lambda expression for optional parameters.
*/
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName,
String location,
String projectId, String location,
Function<Multi2VecGoogleVectorizer.Builder, ObjectBuilder<Multi2VecGoogleVectorizer>> fn) {
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(location, fn));
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(projectId, location, fn));
}

/** Create a vector index with an {@code multi2vec-jinaai} vectorizer. */
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.weaviate.client6.v1.api.collections.vectorizers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
Expand Down Expand Up @@ -46,7 +45,7 @@ public static class Builder implements ObjectBuilder<Img2VecNeuralVectorizer> {
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
private Quantization quantization;

private List<String> imageFields = new ArrayList<>();
private List<String> imageFields;

/** Add BLOB properties to include in the embedding. */
public Builder imageFields(List<String> fields) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.weaviate.client6.v1.api.collections.vectorizers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
Expand Down Expand Up @@ -52,12 +51,12 @@ public static class Builder implements ObjectBuilder<Multi2MultiVecJinaAiVectori
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
private Quantization quantization;

private final List<String> imageFields = new ArrayList<>();
private final List<String> textFields = new ArrayList<>();
private List<String> imageFields;
private List<String> textFields;

/** Add BLOB properties to include in the embedding. */
public Builder imageFields(List<String> fields) {
imageFields.addAll(fields);
imageFields = fields;
return this;
}

Expand All @@ -68,7 +67,7 @@ public Builder imageFields(String... fields) {

/** Add TEXT properties to include in the embedding. */
public Builder textFields(List<String> fields) {
textFields.addAll(fields);
textFields = fields;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package io.weaviate.client6.v1.api.collections.vectorizers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import com.google.gson.annotations.SerializedName;
Expand Down Expand Up @@ -64,11 +63,9 @@ public Multi2VecAwsVectorizer(Builder builder) {
builder.model,
builder.dimensions,
builder.region,
builder.imageFields.keySet().stream().toList(),
builder.textFields.keySet().stream().toList(),
new Weights(
builder.imageFields.values().stream().toList(),
builder.textFields.values().stream().toList()),
builder.imageFields,
builder.textFields,
builder.getWeights(),
builder.vectorIndex,
builder.quantization);
}
Expand All @@ -77,8 +74,10 @@ public static class Builder implements ObjectBuilder<Multi2VecAwsVectorizer> {
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
private Quantization quantization;

private Map<String, Float> imageFields = new LinkedHashMap<>();
private Map<String, Float> textFields = new LinkedHashMap<>();
private List<String> imageFields;
private List<Float> imageWeights;
private List<String> textFields;
private List<Float> textWeights;

private String model;
private Integer dimensions;
Expand All @@ -101,7 +100,7 @@ public Builder region(String region) {

/** Add BLOB properties to include in the embedding. */
public Builder imageFields(List<String> fields) {
fields.forEach(field -> imageFields.put(field, null));
this.imageFields = fields;
return this;
}

Expand All @@ -117,13 +116,20 @@ public Builder imageFields(String... fields) {
* @param weight Custom weight between 0.0 and 1.0.
*/
public Builder imageField(String field, float weight) {
imageFields.put(field, weight);
if (this.imageFields == null) {
this.imageFields = new ArrayList<>();
}
if (this.imageWeights == null) {
this.imageWeights = new ArrayList<>();
}
this.imageFields.add(field);
this.imageWeights.add(weight);
return this;
}

/** Add TEXT properties to include in the embedding. */
public Builder textFields(List<String> fields) {
fields.forEach(field -> textFields.put(field, null));
this.textFields = fields;
return this;
}

Expand All @@ -139,10 +145,24 @@ public Builder textFields(String... fields) {
* @param weight Custom weight between 0.0 and 1.0.
*/
public Builder textField(String field, float weight) {
textFields.put(field, weight);
if (this.textFields == null) {
this.textFields = new ArrayList<>();
}
if (this.textWeights == null) {
this.textWeights = new ArrayList<>();
}
this.textFields.add(field);
this.textWeights.add(weight);
return this;
}

protected Weights getWeights() {
if (this.textWeights != null || this.imageWeights != null) {
return new Weights(this.imageWeights, this.textWeights);
}
return null;
}

/**
* Override default vector index configuration.
*
Expand Down
Loading
Loading