Skip to content

Commit 651df08

Browse files
committed
Improved SQLModel tests [skip ci]
1 parent b350d6a commit 651df08

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

tests/test_sqlmodel.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ class Item(SQLModel, table=True):
3434

3535

3636
def create_items():
37-
session = Session(engine)
38-
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
39-
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
40-
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
41-
session.commit()
37+
with Session(engine) as session:
38+
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
39+
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
40+
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
41+
session.commit()
4242

4343

4444
class TestSqlmodel:
@@ -52,11 +52,11 @@ def test_orm(self):
5252
item2 = Item(embedding=[4, 5, 6])
5353
item3 = Item()
5454

55-
session = Session(engine)
56-
session.add(item)
57-
session.add(item2)
58-
session.add(item3)
59-
session.commit()
55+
with Session(engine) as session:
56+
session.add(item)
57+
session.add(item2)
58+
session.add(item3)
59+
session.commit()
6060

6161
stmt = select(Item)
6262
with Session(engine) as session:
@@ -71,11 +71,11 @@ def test_orm(self):
7171
assert items[2].embedding is None
7272

7373
def test_vector(self):
74-
session = Session(engine)
75-
session.add(Item(id=1, embedding=[1, 2, 3]))
76-
session.commit()
77-
item = session.get(Item, 1)
78-
assert item.embedding.tolist() == [1, 2, 3]
74+
with Session(engine) as session:
75+
session.add(Item(id=1, embedding=[1, 2, 3]))
76+
session.commit()
77+
item = session.get(Item, 1)
78+
assert item.embedding.tolist() == [1, 2, 3]
7979

8080
def test_vector_l2_distance(self):
8181
create_items()
@@ -102,11 +102,11 @@ def test_vector_l1_distance(self):
102102
assert [v.id for v in items] == [1, 3, 2]
103103

104104
def test_halfvec(self):
105-
session = Session(engine)
106-
session.add(Item(id=1, half_embedding=[1, 2, 3]))
107-
session.commit()
108-
item = session.get(Item, 1)
109-
assert item.half_embedding.to_list() == [1, 2, 3]
105+
with Session(engine) as session:
106+
session.add(Item(id=1, half_embedding=[1, 2, 3]))
107+
session.commit()
108+
item = session.get(Item, 1)
109+
assert item.half_embedding.to_list() == [1, 2, 3]
110110

111111
def test_halfvec_l2_distance(self):
112112
create_items()
@@ -133,11 +133,11 @@ def test_halfvec_l1_distance(self):
133133
assert [v.id for v in items] == [1, 3, 2]
134134

135135
def test_bit(self):
136-
session = Session(engine)
137-
session.add(Item(id=1, binary_embedding='101'))
138-
session.commit()
139-
item = session.get(Item, 1)
140-
assert item.binary_embedding == '101'
136+
with Session(engine) as session:
137+
session.add(Item(id=1, binary_embedding='101'))
138+
session.commit()
139+
item = session.get(Item, 1)
140+
assert item.binary_embedding == '101'
141141

142142
def test_bit_hamming_distance(self):
143143
create_items()
@@ -152,11 +152,11 @@ def test_bit_jaccard_distance(self):
152152
assert [v.id for v in items] == [2, 3, 1]
153153

154154
def test_sparsevec(self):
155-
session = Session(engine)
156-
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
157-
session.commit()
158-
item = session.get(Item, 1)
159-
assert item.sparse_embedding.to_list() == [1, 2, 3]
155+
with Session(engine) as session:
156+
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
157+
session.commit()
158+
item = session.get(Item, 1)
159+
assert item.sparse_embedding.to_list() == [1, 2, 3]
160160

161161
def test_sparsevec_l2_distance(self):
162162
create_items()
@@ -232,7 +232,7 @@ def test_halfvec_sum(self):
232232

233233
def test_bad_dimensions(self):
234234
item = Item(embedding=[1, 2])
235-
session = Session(engine)
236-
session.add(item)
237-
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
238-
session.commit()
235+
with Session(engine) as session:
236+
session.add(item)
237+
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
238+
session.commit()

0 commit comments

Comments
 (0)