@@ -34,11 +34,11 @@ class Item(SQLModel, table=True):
3434
3535
3636def create_items ():
37- session = Session (engine )
38- session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
39- session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
40- session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
41- session .commit ()
37+ with Session (engine ) as session :
38+ session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
39+ session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
40+ session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
41+ session .commit ()
4242
4343
4444class TestSqlmodel :
@@ -52,11 +52,11 @@ def test_orm(self):
5252 item2 = Item (embedding = [4 , 5 , 6 ])
5353 item3 = Item ()
5454
55- session = Session (engine )
56- session .add (item )
57- session .add (item2 )
58- session .add (item3 )
59- session .commit ()
55+ with Session (engine ) as session :
56+ session .add (item )
57+ session .add (item2 )
58+ session .add (item3 )
59+ session .commit ()
6060
6161 stmt = select (Item )
6262 with Session (engine ) as session :
@@ -71,11 +71,11 @@ def test_orm(self):
7171 assert items [2 ].embedding is None
7272
7373 def test_vector (self ):
74- session = Session (engine )
75- session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
76- session .commit ()
77- item = session .get (Item , 1 )
78- assert item .embedding .tolist () == [1 , 2 , 3 ]
74+ with Session (engine ) as session :
75+ session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
76+ session .commit ()
77+ item = session .get (Item , 1 )
78+ assert item .embedding .tolist () == [1 , 2 , 3 ]
7979
8080 def test_vector_l2_distance (self ):
8181 create_items ()
@@ -102,11 +102,11 @@ def test_vector_l1_distance(self):
102102 assert [v .id for v in items ] == [1 , 3 , 2 ]
103103
104104 def test_halfvec (self ):
105- session = Session (engine )
106- session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
107- session .commit ()
108- item = session .get (Item , 1 )
109- assert item .half_embedding .to_list () == [1 , 2 , 3 ]
105+ with Session (engine ) as session :
106+ session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
107+ session .commit ()
108+ item = session .get (Item , 1 )
109+ assert item .half_embedding .to_list () == [1 , 2 , 3 ]
110110
111111 def test_halfvec_l2_distance (self ):
112112 create_items ()
@@ -133,11 +133,11 @@ def test_halfvec_l1_distance(self):
133133 assert [v .id for v in items ] == [1 , 3 , 2 ]
134134
135135 def test_bit (self ):
136- session = Session (engine )
137- session .add (Item (id = 1 , binary_embedding = '101' ))
138- session .commit ()
139- item = session .get (Item , 1 )
140- assert item .binary_embedding == '101'
136+ with Session (engine ) as session :
137+ session .add (Item (id = 1 , binary_embedding = '101' ))
138+ session .commit ()
139+ item = session .get (Item , 1 )
140+ assert item .binary_embedding == '101'
141141
142142 def test_bit_hamming_distance (self ):
143143 create_items ()
@@ -152,11 +152,11 @@ def test_bit_jaccard_distance(self):
152152 assert [v .id for v in items ] == [2 , 3 , 1 ]
153153
154154 def test_sparsevec (self ):
155- session = Session (engine )
156- session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
157- session .commit ()
158- item = session .get (Item , 1 )
159- assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
155+ with Session (engine ) as session :
156+ session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
157+ session .commit ()
158+ item = session .get (Item , 1 )
159+ assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
160160
161161 def test_sparsevec_l2_distance (self ):
162162 create_items ()
@@ -232,7 +232,7 @@ def test_halfvec_sum(self):
232232
233233 def test_bad_dimensions (self ):
234234 item = Item (embedding = [1 , 2 ])
235- session = Session (engine )
236- session .add (item )
237- with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
238- session .commit ()
235+ with Session (engine ) as session :
236+ session .add (item )
237+ with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
238+ session .commit ()
0 commit comments