Skip to content

Commit aea97e1

Browse files
brilingliam-o-marsh
authored andcommitted
Refactor test
1 parent f3fec85 commit aea97e1

1 file changed

Lines changed: 10 additions & 9 deletions

File tree

tests/test_kernels.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,18 @@ def test_batched_local_kernels():
4949
Y = np.array([[0.09992856, 0.50806631, 0.20024754, 0.74415417], [0.192892 , 0.70084475, 0.29322811, 0.77447945]])
5050
K_L_good = np.array([[0.48938983, 0.58251676], [0.32374891, 0.31778924]])
5151

52-
X_huge = np.concatenate([X]*1_000, axis=1)
53-
X_huge = np.concatenate([X_huge]*1000, axis=0)
54-
Y_huge = np.concatenate([Y]*1_000, axis=1)
55-
Y_huge = np.concatenate([Y_huge]*50, axis=0)
56-
K_L_good_huge = np.concatenate([K_L_good]*1000, axis=0)
57-
K_L_good_huge = np.concatenate([K_L_good_huge]*50, axis=1)
52+
X_huge = np.tile(X, (1000,1000))
53+
Y_huge = np.tile(Y, (50,1000))
54+
K_L_good_huge = np.tile(K_L_good, (1000,50))
5855

5956
local_kernels.RAM_BATCHING_SIZE = 1024**2 * 50 # 50MiB
60-
for akernel in ['L_custom_c', 'L_custom_py', 'L', 'L_sklearn']:
61-
K = kernel.kernel(X_huge, Y_huge, akernel=akernel, sigma=2.0*1000)
62-
assert np.allclose(K, K_L_good_huge)
57+
58+
K = kernel.kernel(X_huge, Y_huge, akernel='L_custom_py', sigma=2.0*1000)
59+
assert np.allclose(K, K_L_good_huge)
60+
61+
K = kernel.kernel(X_huge.reshape((-1, 50, 80)), Y_huge.reshape((-1, 50, 80)), akernel='L_custom_py', sigma=2.0*1000)
62+
assert np.allclose(K, K_L_good_huge)
63+
6364

6465
if __name__ == '__main__':
6566
test_local_kernels()

0 commit comments

Comments
 (0)