4 changes: 3 additions & 1 deletion .github/workflows/main.yaml
@@ -26,6 +26,7 @@ env:
WEAVIATE_132: 1.32.16
WEAVIATE_133: 1.33.4
WEAVIATE_134: 1.34.0
+WEAVIATE_135: 1.35.0-dev-8d38bb2.amd64

jobs:
lint-and-format:
@@ -304,7 +305,8 @@ jobs:
$WEAVIATE_131,
$WEAVIATE_132,
$WEAVIATE_133,
-$WEAVIATE_134
+$WEAVIATE_134,
+$WEAVIATE_135
]
steps:
- name: Checkout
4 changes: 4 additions & 0 deletions profiling/conftest.py
@@ -61,6 +61,10 @@ def _factory(
headers=headers,
additional_config=AdditionalConfig(timeout=(60, 120)), # for image tests
)
+# client_fixture = weaviate.connect_to_weaviate_cloud(
+#     cluster_url="flnyoj61teuw1mxfwf1fsa.c0.europe-west3.gcp.weaviate.cloud",
+#     auth_credentials=weaviate.auth.Auth.api_key("QnVtdnlnM2RYeUh3NVlFNF82V3pqVEtoYnloMlo0MHV2R2hYMU9BUFFsR3cvUUlkUG9CTFRiQXNjam1nPV92MjAw"),
+# )
client_fixture.collections.delete(name_fixture)
if integration_config is not None:
client_fixture.integrations.configure(integration_config)
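The commented-out block above pins a live cluster URL and API key directly in the source. For reference, a minimal sketch of the same cloud connection with the credentials read from the environment instead; the `WCS_URL` and `WCS_API_KEY` variable names are illustrative assumptions, not part of this change:

```python
# Hypothetical variant of the commented-out connection above, with the
# cluster URL and API key pulled from environment variables instead of
# being hard-coded. WCS_URL and WCS_API_KEY are assumed names.
import os

import weaviate

client = weaviate.connect_to_weaviate_cloud(
    cluster_url=os.environ["WCS_URL"],
    auth_credentials=weaviate.auth.Auth.api_key(os.environ["WCS_API_KEY"]),
)
try:
    print(client.is_ready())  # sanity-check the connection
finally:
    client.close()
```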
4 changes: 2 additions & 2 deletions profiling/test_sphere.py
@@ -28,7 +28,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
start = time.time()

import_objects = 1000000
-with collection.batch.dynamic() as batch:
+with collection.batch.experimental() as batch:
with open(sphere_file) as jsonl_file:
for i, jsonl in enumerate(jsonl_file):
if i == import_objects or batch.number_errors > 10:
@@ -46,7 +46,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
vector=json_parsed["vector"],
)
if i % 1000 == 0:
print(f"Imported {len(collection)} objects")
print(f"Imported {len(collection)} objects after processing {i} lines")
assert len(collection.batch.failed_objects) == 0
assert len(collection) == import_objects
print(f"Imported {import_objects} objects in {time.time() - start}")
69 changes: 40 additions & 29 deletions weaviate/collections/batch/base.py
@@ -95,7 +95,7 @@ def prepend(self, item: List[TBatchInput]) -> None:
self._lock.release()


-Ref = TypeVar("Ref", bound=Union[_BatchReference, batch_pb2.BatchReference])
+Ref = TypeVar("Ref", bound=BatchReference)


class ReferencesBatchRequest(BatchRequest[Ref, BatchReferenceReturn]):
@@ -111,8 +111,9 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
i = 0
self._lock.acquire()
while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items):
-if self._items[i].from_uuid not in uuid_lookup and (
-    self._items[i].to_uuid is None or self._items[i].to_uuid not in uuid_lookup
+if self._items[i].from_object_uuid not in uuid_lookup and (
+    self._items[i].to_object_uuid is None
+    or self._items[i].to_object_uuid not in uuid_lookup
):
ret.append(self._items.pop(i))
else:
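With the rename to `from_object_uuid`/`to_object_uuid`, the intent of this loop is easier to see: a reference is only popped for sending once neither of its endpoint objects is still waiting to be imported. A simplified, self-contained model of that deferral logic (the dataclass below is a stand-in for the client's reference type, not its real definition):

```python
# Simplified model of pop_items above: skip references whose endpoints are
# still pending in uuid_lookup, pop the rest. The Ref dataclass is a
# stand-in for the client's BatchReference, not its actual shape.
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class Ref:
    from_object_uuid: str
    to_object_uuid: Optional[str] = None


def pop_ready(items: List[Ref], pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
    ret: List[Ref] = []
    i = 0
    while len(ret) < pop_amount and i < len(items):
        ref = items[i]
        if ref.from_object_uuid not in uuid_lookup and (
            ref.to_object_uuid is None or ref.to_object_uuid not in uuid_lookup
        ):
            ret.append(items.pop(i))  # both endpoints imported: safe to send
        else:
            i += 1  # an endpoint is still in flight: leave it for a later pass
    return ret
```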
@@ -132,7 +133,7 @@ def head(self) -> Optional[Ref]:
return item


-Obj = TypeVar("Obj", bound=Union[_BatchObject, batch_pb2.BatchObject])
+Obj = TypeVar("Obj", bound=BatchObject)


class ObjectsBatchRequest(Generic[Obj], BatchRequest[Obj, BatchObjectReturn]):
@@ -843,11 +844,11 @@ def __init__(
batch_mode: _BatchMode,
executor: ThreadPoolExecutor,
vectorizer_batching: bool,
-objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None,
-references: Optional[ReferencesBatchRequest] = None,
+objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
+references: Optional[ReferencesBatchRequest[BatchReference]] = None,
) -> None:
-self.__batch_objects = objects or ObjectsBatchRequest[batch_pb2.BatchObject]()
-self.__batch_references = references or ReferencesBatchRequest[batch_pb2.BatchReference]()
+self.__batch_objects = objects or ObjectsBatchRequest[BatchObject]()
+self.__batch_references = references or ReferencesBatchRequest[BatchReference]()

self.__connection = connection
self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
@@ -879,6 +880,10 @@ def __init__(
self.__objs_cache: dict[str, BatchObject] = {}
self.__refs_cache: dict[str, BatchReference] = {}

+self.__acks_lock = threading.Lock()
+self.__inflight_objs: set[str] = set()
+self.__inflight_refs: set[str] = set()
+
# maxsize=1 so that __batch_send does not run faster than generator for __batch_recv
# thereby using too much buffer in case of server-side shutdown
self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1)
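The new `__inflight_objs`/`__inflight_refs` sets, guarded by `__acks_lock`, are the client half of an ack-based flow-control loop: the add paths block while too many items are unacknowledged, and the receive thread shrinks the sets as server acks arrive. A standalone sketch of the mechanism, with illustrative names only:

```python
# Standalone sketch of the ack-based flow control added in this PR: the
# producer blocks while too many items are in flight; the receiver drains
# the set as acks arrive. All names here are illustrative.
import threading
import time


class InflightTracker:
    def __init__(self, batch_size: int) -> None:
        self._lock = threading.Lock()
        self._inflight: set[str] = set()
        self._batch_size = batch_size

    def mark_sent(self, uuid: str) -> None:
        with self._lock:
            self._inflight.add(uuid)

    def mark_acked(self, uuids: list[str]) -> None:
        # mirrors __inflight_objs.difference_update(message.acks.uuids)
        with self._lock:
            self._inflight.difference_update(uuids)

    def wait_for_capacity(self) -> None:
        # mirrors `while len(self.__inflight_objs) >= self.__batch_size * 2`
        while len(self._inflight) >= self._batch_size * 2:
            time.sleep(0.01)
```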
@@ -1005,10 +1010,13 @@ def __batch_send(self) -> None:
return
time.sleep(refresh_time)

+def __beacon(self, ref: batch_pb2.BatchReference) -> str:
+    return f"weaviate://localhost/{ref.from_collection}{f'#{ref.tenant}' if ref.tenant != '' else ''}/{ref.from_uuid}#{ref.name}->/{ref.to_collection}/{ref.to_uuid}"
+
def __generate_stream_requests(
self,
-objs: List[batch_pb2.BatchObject],
-refs: List[batch_pb2.BatchReference],
+objects: List[BatchObject],
+references: List[BatchReference],
) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
per_object_overhead = 4 # extra overhead bytes per object in the request

@@ -1018,7 +1026,8 @@ def request_maker():
request = request_maker()
total_size = request.ByteSize()

-for obj in objs:
+for object_ in objects:
+    obj = self.__batch_grpc.grpc_object(object_._to_internal())
obj_size = obj.ByteSize() + per_object_overhead

if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
@@ -1028,8 +1037,12 @@ def request_maker():

request.data.objects.values.append(obj)
total_size += obj_size
+if self.__connection._weaviate_version.is_at_least(1, 35, 0):
+    with self.__acks_lock:
+        self.__inflight_objs.add(obj.uuid)

for ref in refs:
for reference in references:
ref = self.__batch_grpc.grpc_reference(reference._to_internal())
ref_size = ref.ByteSize() + per_object_overhead

if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
@@ -1039,6 +1052,9 @@ def request_maker():

request.data.references.values.append(ref)
total_size += ref_size
+if self.__connection._weaviate_version.is_at_least(1, 35, 0):
+    with self.__acks_lock:
+        self.__inflight_refs.add(reference._to_beacon())

if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
yield request
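The generator now converts client-side models to protobuf lazily and packs them into stream requests until the next item would push the message past `grpc_max_msg_size`. The same chunking shape in isolation (byte strings stand in for protobuf messages, and the size constants are placeholders):

```python
# Simplified shape of __generate_stream_requests: pack items until adding
# the next one would exceed the size cap, then flush and start a new
# request. Byte strings stand in for protobuf messages here.
from typing import Generator, Iterable, List

MAX_MSG_SIZE = 10 * 1024 * 1024  # placeholder for grpc_max_msg_size
PER_ITEM_OVERHEAD = 4  # extra framing bytes per item, as in the diff


def chunk_requests(items: Iterable[bytes]) -> Generator[List[bytes], None, None]:
    request: List[bytes] = []
    total_size = 0
    for item in items:
        item_size = len(item) + PER_ITEM_OVERHEAD
        if total_size + item_size >= MAX_MSG_SIZE and request:
            yield request  # flush the full request
            request, total_size = [], 0
        request.append(item)
        total_size += item_size
    if request:
        yield request  # final, partially filled request
```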
@@ -1091,6 +1107,10 @@ def __batch_recv(self) -> None:
logger.warning(
f"Updated batch size to {self.__batch_size} as per server request"
)
if message.HasField("acks"):
with self.__acks_lock:
self.__inflight_objs.difference_update(message.acks.uuids)
self.__inflight_refs.difference_update(message.acks.beacons)
if message.HasField("results"):
result_objs = BatchObjectReturn()
result_refs = BatchReferenceReturn()
@@ -1241,19 +1261,9 @@ def batch_recv_wrapper() -> None:
logger.warning(
f"Re-adding {len(self.__objs_cache)} cached objects to the batch"
)
-self.__batch_objects.prepend(
-    [
-        self.__batch_grpc.grpc_object(o._to_internal())
-        for o in self.__objs_cache.values()
-    ]
-)
+self.__batch_objects.prepend(list(self.__objs_cache.values()))
with self.__refs_cache_lock:
-self.__batch_references.prepend(
-    [
-        self.__batch_grpc.grpc_reference(o._to_internal())
-        for o in self.__refs_cache.values()
-    ]
-)
+self.__batch_references.prepend(list(self.__refs_cache.values()))
# start a new stream with a newly reconnected channel
return batch_recv_wrapper()

@@ -1307,14 +1317,14 @@ def _add_object(
uuid = str(batch_object.uuid)
with self.__uuid_lookup_lock:
self.__uuid_lookup.add(uuid)
-self.__batch_objects.add(self.__batch_grpc.grpc_object(batch_object._to_internal()))
+self.__batch_objects.add(batch_object)
with self.__objs_cache_lock:
self.__objs_cache[uuid] = batch_object
self.__objs_count += 1

# block if the queue gets too long or weaviate is overloaded - reading files is faster than sending them, so we do
# not need a long queue
-while len(self.__batch_objects) >= self.__batch_size * 2:
+while len(self.__inflight_objs) >= self.__batch_size * 2:
self.__check_bg_threads_alive()
time.sleep(0.01)

@@ -1352,12 +1362,13 @@ def _add_reference(
)
except ValidationError as e:
raise WeaviateBatchValidationError(repr(e))
-self.__batch_references.add(
-    self.__batch_grpc.grpc_reference(batch_reference._to_internal())
-)
+self.__batch_references.add(batch_reference)
with self.__refs_cache_lock:
self.__refs_cache[batch_reference._to_beacon()] = batch_reference
self.__refs_count += 1
+while len(self.__inflight_refs) >= self.__batch_size * 2:
+    self.__check_bg_threads_alive()
+    time.sleep(0.01)

def __check_bg_threads_alive(self) -> None:
if self.__any_threads_alive():