Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions integration/test_batch_v4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import concurrent.futures
import uuid
from dataclasses import dataclass
Expand Down Expand Up @@ -819,38 +818,50 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None:
client.collections.delete(["target", "source"])


def test_ingest_one_hundred_thousand_data_objects(
    client_factory: ClientFactory,
) -> None:
    """Ingest 100,000 objects through ``data.ingest`` and check that all of them are stored.

    The test is skipped when the connected server reports a version lower than
    1.34.0, because server-side batching is not supported there.
    """
    import time

    client, name = client_factory()
    if client._connection._weaviate_version.is_lower_than(1, 34, 0):
        pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
    nr_objects = 100000

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    results = client.collections.use(name).data.ingest(
        {"name": "test" + str(i)} for i in range(nr_objects)
    )
    elapsed = time.perf_counter() - start
    print(f"Time taken to add {nr_objects} objects: {elapsed} seconds")

    # Check the errors first and attach their messages: a bare count assertion
    # ahead of this one would fail without showing what actually went wrong.
    assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]
    assert results.has_errors is False
    assert len(results.all_responses) == nr_objects
    assert len(results.uuids) == nr_objects
    assert len(client.collections.use(name)) == nr_objects
    client.collections.delete(name)


@pytest.mark.asyncio
async def test_ingest_one_hundred_thousand_data_objects_async(
    async_client_factory: AsyncClientFactory,
) -> None:
    """Ingest 100,000 objects through the async ``data.ingest`` and check that all are stored.

    The test is skipped when the connected server reports a version lower than
    1.34.0, because server-side batching is not supported there.
    """
    import time

    client, name = await async_client_factory()
    if client._connection._weaviate_version.is_lower_than(1, 34, 0):
        pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
    nr_objects = 100000

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    results = await client.collections.use(name).data.ingest(
        {"name": "test" + str(i)} for i in range(nr_objects)
    )
    elapsed = time.perf_counter() - start
    print(f"Time taken to add {nr_objects} objects: {elapsed} seconds")

    # Check the errors first and attach their messages: a bare count assertion
    # ahead of this one would fail without showing what actually went wrong.
    assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]
    assert results.has_errors is False
    assert len(results.all_responses) == nr_objects
    assert len(results.uuids) == nr_objects
    assert await client.collections.use(name).length() == nr_objects
    await client.collections.delete(name)


async def arange(count: int):
    """Asynchronously yield the integers ``0`` through ``count - 1``.

    After every yielded value the generator awaits ``asyncio.sleep(0.0)``.
    """
    for i in range(count):
        yield i
        # Zero-length sleep: a suspension point after each item, which lets the
        # event loop run other ready tasks before the next value is produced.
        await asyncio.sleep(0.0)
3 changes: 1 addition & 2 deletions weaviate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .auth import AuthCredentials
from .backup import _Backup, _BackupAsync
from .cluster import _Cluster, _ClusterAsync
from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
from .collections.batch.client import _BatchClientWrapper
from .collections.collections import _Collections, _CollectionsAsync
from .config import AdditionalConfig
from .connect import executor
Expand Down Expand Up @@ -76,7 +76,6 @@ def __init__(
)
self.alias = _AliasAsync(self._connection)
self.backup = _BackupAsync(self._connection)
self.batch = _BatchClientWrapperAsync(self._connection)
self.cluster = _ClusterAsync(self._connection)
self.collections = _CollectionsAsync(self._connection)
self.debug = _DebugAsync(self._connection)
Expand Down
3 changes: 1 addition & 2 deletions weaviate/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from weaviate.users.sync import _Users

from .backup import _Backup, _BackupAsync
from .cluster import _Cluster, _ClusterAsync
from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
from .collections.batch.client import _BatchClientWrapper
from .debug import _Debug, _DebugAsync
from .rbac import _Roles, _RolesAsync
from .types import NUMBER
Expand All @@ -29,7 +29,6 @@ class WeaviateAsyncClient(_WeaviateClientExecutor[ConnectionAsync]):
_connection: ConnectionAsync
alias: _AliasAsync
backup: _BackupAsync
batch: _BatchClientWrapperAsync
collections: _CollectionsAsync
cluster: _ClusterAsync
debug: _DebugAsync
Expand Down
122 changes: 45 additions & 77 deletions weaviate/collections/batch/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import uuid as uuid_package
from typing import (
AsyncGenerator,
Awaitable,
Callable,
List,
Optional,
Set,
Expand All @@ -15,8 +17,6 @@
ObjectsBatchRequest,
ReferencesBatchRequest,
_BatchDataWrapper,
_BatchMode,
_ServerSideBatching,
)
from weaviate.collections.batch.grpc_batch import _BatchGRPC
from weaviate.collections.classes.batch import (
Expand All @@ -36,6 +36,7 @@
from weaviate.collections.classes.types import WeaviateProperties
from weaviate.connect.v4 import ConnectionAsync
from weaviate.exceptions import (
WeaviateBatchStreamError,
WeaviateBatchValidationError,
WeaviateGRPCUnavailableError,
WeaviateStartUpError,
Expand All @@ -57,7 +58,6 @@ def __init__(
connection: ConnectionAsync,
consistency_level: Optional[ConsistencyLevel],
results: _BatchDataWrapper,
batch_mode: _BatchMode,
objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None,
references: Optional[ReferencesBatchRequest] = None,
) -> None:
Expand Down Expand Up @@ -95,21 +95,57 @@ def __init__(

self.__stop = False

self.__batch_mode = batch_mode

@property
def number_errors(self) -> int:
    """Total count of failed objects plus failed references recorded so far."""
    recorded = self.__results_for_wrapper
    failed_object_count = len(recorded.failed_objects)
    failed_reference_count = len(recorded.failed_references)
    return failed_object_count + failed_reference_count

async def __wrap(self, fn: Callable[[], Awaitable[None]]):
    """Run a background batch coroutine, restarting it when the stream hangs up.

    A ``WeaviateBatchStreamError`` whose message contains ``"Socket closed"`` or
    ``"context canceled"`` is treated as a hang-up: the connection is re-established,
    the cached objects and references are put back at the front of their queues and
    ``fn`` is started again. Any other exception is logged and stored on
    ``self.__bg_thread_exception`` instead of propagating.

    Args:
        fn: Zero-argument coroutine function to run.

    Returns:
        ``None``.
    """
    try:
        await fn()
        return
    except Exception as e:
        socket_hung_up = isinstance(e, WeaviateBatchStreamError) and (
            "Socket closed" in e.message or "context canceled" in e.message
        )
        if not socket_hung_up:
            logger.error(e)
            logger.error(type(e))
            self.__bg_thread_exception = e
            return

    # this happens during ungraceful shutdown of the coordinator
    # lets restart the stream and add the cached objects again
    logger.warning("Stream closed unexpectedly, restarting...")
    await self.__reconnect()
    # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now
    self.__is_shutting_down.clear()

    # Snapshot the caches while holding the locks, but await outside them: these
    # are synchronous `with` locks, and keeping one held across a suspension point
    # would block the event loop if another coroutine tried to take it meanwhile.
    with self.__objs_cache_lock:
        logger.warning(f"Re-adding {len(self.__objs_cache)} cached objects to the batch")
        cached_objects = [
            self.__batch_grpc.grpc_object(o._to_internal()) for o in self.__objs_cache.values()
        ]
    await self.__batch_objects.aprepend(cached_objects)

    with self.__refs_cache_lock:
        cached_references = [
            self.__batch_grpc.grpc_reference(o._to_internal())
            for o in self.__refs_cache.values()
        ]
    await self.__batch_references.aprepend(cached_references)

    # Restart through the wrapper itself rather than calling fn() bare, so that a
    # further hang-up is handled the same way and any other failure is still
    # recorded instead of escaping the background task.
    return await self.__wrap(fn)

async def _start(self):
    """Launch the background send and receive loops as asyncio tasks.

    Each loop is run through ``__wrap`` rather than scheduled directly.

    Returns:
        A ``_BgTasks`` holding the ``send`` and ``recv`` tasks.
    """
    return _BgTasks(
        send=asyncio.create_task(self.__wrap(self.__send)),
        recv=asyncio.create_task(self.__wrap(self.__recv)),
    )

async def _shutdown(self) -> None:
Expand Down Expand Up @@ -332,74 +368,6 @@ async def __reconnect(self, retry: int = 0) -> None:
logger.error("Failed to reconnect after 5 attempts")
self.__bg_thread_exception = e

# def __start_bg_threads(self) -> _BgThreads:
# """Create a background thread that periodically checks how congested the batch queue is."""
# self.__shut_background_thread_down = threading.Event()

# def batch_send_wrapper() -> None:
# try:
# self.__batch_send()
# logger.warning("exited batch send thread")
# except Exception as e:
# logger.error(e)
# self.__bg_thread_exception = e

# def batch_recv_wrapper() -> None:
# socket_hung_up = False
# try:
# self.__batch_recv()
# logger.warning("exited batch receive thread")
# except Exception as e:
# if isinstance(e, WeaviateBatchStreamError) and (
# "Socket closed" in e.message or "context canceled" in e.message
# ):
# socket_hung_up = True
# else:
# logger.error(e)
# logger.error(type(e))
# self.__bg_thread_exception = e
# if socket_hung_up:
# # this happens during ungraceful shutdown of the coordinator
# # lets restart the stream and add the cached objects again
# logger.warning("Stream closed unexpectedly, restarting...")
# self.__reconnect()
# # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now
# self.__is_shutting_down.clear()
# with self.__objs_cache_lock:
# logger.warning(
# f"Re-adding {len(self.__objs_cache)} cached objects to the batch"
# )
# self.__batch_objects.prepend(
# [
# self.__batch_grpc.grpc_object(o._to_internal())
# for o in self.__objs_cache.values()
# ]
# )
# with self.__refs_cache_lock:
# self.__batch_references.prepend(
# [
# self.__batch_grpc.grpc_reference(o._to_internal())
# for o in self.__refs_cache.values()
# ]
# )
# # start a new stream with a newly reconnected channel
# return batch_recv_wrapper()

# threads = _BgThreads(
# send=threading.Thread(
# target=batch_send_wrapper,
# daemon=True,
# name="BgBatchSend",
# ),
# recv=threading.Thread(
# target=batch_recv_wrapper,
# daemon=True,
# name="BgBatchRecv",
# ),
# )
# threads.start_recv()
# return threads

async def flush(self) -> None:
"""Flush the batch queue and wait for all requests to be finished."""
# bg thread is sending objs+refs automatically, so simply wait for everything to be done
Expand Down
Loading