Skip to content

Commit 91f5d34

Browse files
committed
Added test for vector[] type with SQLAlchemy and asyncpg [skip ci]
1 parent 257eb3b commit 91f5d34

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,26 @@ async def test_asyncpg_sparsevec(self):
606606
assert item.sparse_embedding.to_list() == embedding
607607

608608
await engine.dispose()
609+
610+
@pytest.mark.asyncio
611+
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
612+
async def test_asyncpg_vector_array(self):
613+
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
614+
async_session = async_sessionmaker(engine, expire_on_commit=False)
615+
616+
# TODO do not throw error when types are registered
617+
# @event.listens_for(engine.sync_engine, "connect")
618+
# def connect(dbapi_connection, connection_record):
619+
# from pgvector.asyncpg import register_vector
620+
# dbapi_connection.run_async(register_vector)
621+
622+
async with async_session() as session:
623+
async with session.begin():
624+
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
625+
626+
# this fails if the driver does not cast arrays
627+
item = await session.get(Item, 1)
628+
assert item.embeddings[0].tolist() == [1, 2, 3]
629+
assert item.embeddings[1].tolist() == [4, 5, 6]
630+
631+
await engine.dispose()

0 commit comments

Comments
 (0)