# run_tatli_new.py
import hydra
import numpy as np
from comet_ml import Experiment
from omegaconf import DictConfig
from scipy.stats import ortho_group
from kernel import Kernel
from utils import solve_for_K
def check_accuracy(varphi_triplets, y, K_hat):
    """Return the fraction of triplets whose predicted label equals ``y``.

    For every triplet ``(a, b, c)`` the predicted label is
    ``sign((2a - b - c)^T K_hat (c - b))``.

    Args:
        varphi_triplets: Iterable of triplets; each element unpacks into
            three feature vectors ``(varphi_a, varphi_b, varphi_c)``.
        y: Array of reference labels, compared element-wise against the
            predictions.
        K_hat: Matrix defining the bilinear form used for the prediction.

    Returns:
        Mean of the element-wise equality between predictions and ``y``.

    Note:
        ``np.sign`` yields 0 when the bilinear form is exactly zero, so such
        a tie only counts as correct if the reference label is 0 as well.
    """
    y_hat = np.array(
        [
            np.sign(
                (2 * varphi_a - varphi_b - varphi_c).T
                @ K_hat
                @ (varphi_c - varphi_b)
            )
            for varphi_a, varphi_b, varphi_c in varphi_triplets
        ]
    )
    return (y_hat == y).mean()
def create_triplets_from_items(K, G, rng, T, x, deterministic, noise_param):
    """Group consecutive items into triplets and label each one.

    Items ``3t``, ``3t + 1`` and ``3t + 2`` form the ``t``-th triplet
    ``(h, i, j)``. Its score is
    ``(2 K[h] - K[i] - K[j])^T G (K[j] - K[i])``.

    Args:
        K: Indexable collection of per-item feature vectors; rows
            ``0 .. 3T - 1`` are read.
        G: Matrix defining the bilinear form used for the score.
        rng: Random generator exposing ``uniform()``; it is consulted once
            per triplet, and only when ``deterministic`` is false.
        T: Number of triplets to build.
        x: Indexable collection of raw items, indexed like ``K``; these are
            what is stored in the returned triplets.
        deterministic: If true, the label is ``np.sign(score)`` (which is 0
            for an exactly zero score). Otherwise the label is -1 with
            probability ``1 / (1 + exp(noise_param * score))`` and +1
            otherwise.
        noise_param: Scale applied to the score inside the logistic noise
            model; unused when ``deterministic`` is true.

    Returns:
        A tuple ``(triplets, ys, dists)`` of NumPy arrays holding the raw
        item triplets, their labels and their scores, in triplet order.
    """
    triplets = []
    ys = []
    dists = []
    # Use a dedicated triplet index: the original code re-bound the loop
    # variable itself to the middle item index inside the loop body.
    for t in range(T):
        h, i, j = 3 * t, 3 * t + 1, 3 * t + 2
        dist_diff = (2 * K[h] - K[i] - K[j]).T @ G @ (K[j] - K[i])
        if deterministic:
            y_t = np.sign(dist_diff)
        else:
            # Logistic noise: a strongly negative score makes -1 near certain.
            p_t = 1 / (1 + np.exp(noise_param * dist_diff))
            y_t = -1 if rng.uniform() < p_t else 1
        dists.append(dist_diff)
        triplets.append([x[h], x[i], x[j]])
        ys.append(y_t)
    return np.array(triplets), np.array(ys), np.array(dists)
def create_K_from_kernel_and_z(kernel, x, z):
    """Return ``kernel.gram_matrix(x, z)`` unchanged.

    Args:
        kernel: Object exposing ``gram_matrix(x, z)``.
        x: First argument forwarded to ``gram_matrix``.
        z: Second argument forwarded to ``gram_matrix``.

    Returns:
        Whatever ``kernel.gram_matrix(x, z)`` returns; no post-processing is
        applied (a row-wise norm rescaling used to live here and is
        currently disabled).
    """
    return kernel.gram_matrix(x, z)
@hydra.main(version_base=None, config_path=".", config_name="config_tatli_new")
def main(cfg: DictConfig):
    """Run one synthetic triplet experiment and log the results to Comet.

    The run draws random items, labels training and validation triplets
    under a random rank-``r`` ground-truth matrix, fits ``K_hat`` with
    ``solve_for_K`` on features produced by ``kernel_2``, and reports the
    train accuracy and the validation accuracy over 10 validation sets.

    Args:
        cfg: Hydra configuration. Fields read here: ``project_name``,
            ``eigen_cutoff``, ``eigen_threshold``, ``seed``, ``kernel_1``,
            ``kernel_2``, ``kernel_params``, ``d``, ``r``, ``T``, ``T_val``,
            ``noise``, ``noise_param``, ``solver``, ``loss_type``,
            ``constraint_type`` and ``verbose``.

    Environment:
        COMET_API_KEY: Comet API key. It is no longer hard-coded in the
            source; when unset, ``None`` is passed and comet_ml is expected
            to fall back to its own configuration lookup.
    """
    # Function-local import so the credential lookup is self-contained.
    import os

    # SECURITY: the API key used to be committed in this file. Read it from
    # the environment instead; the previously committed key should be revoked.
    experiment = Experiment(
        api_key=os.environ.get("COMET_API_KEY"),
        project_name=cfg.project_name,
        workspace="kitkatdafu",
    )
    experiment.log_parameters(cfg)

    eigen_cutoff = cfg.eigen_cutoff
    eigen_threshold = cfg.eigen_threshold
    seed = cfg.seed

    # define a random number generator
    rng = np.random.default_rng(seed=seed)

    # kernel_1 generates the labels; kernel_2 produces the features used for
    # fitting and receives the eigen cutoff/threshold settings.
    kernel_1 = Kernel(cfg.kernel_1, eigen_cutoff=None, **cfg.kernel_params)
    kernel_2 = Kernel(
        cfg.kernel_2,
        eigen_cutoff=eigen_cutoff,
        eigen_threshold=eigen_threshold,
        **cfg.kernel_params,
    )

    # dimension of the points in Euclidean space
    d = cfg.d
    # rank of linear functional (metric)
    r = cfg.r
    # number of triplets
    T = cfg.T
    T_val = cfg.T_val
    # number of items (three per triplet)
    n = 3 * T
    n_val = 3 * T_val

    # Reference points: the d x d identity matrix (a random Gaussian draw of
    # r points was an earlier alternative).
    z = np.eye(d)

    # generate a set of items, then 10 independent validation sets
    x = rng.multivariate_normal(np.zeros(d), 1 / d * np.eye(d), size=n)
    x_vals = [
        rng.multivariate_normal(np.zeros(d), 1 / d * np.eye(d), size=n_val)
        for _ in range(10)
    ]

    # Gram matrices of the items against the reference points z
    K = create_K_from_kernel_and_z(kernel_1, x, z)
    K_vals = [create_K_from_kernel_and_z(kernel_1, x_val, z) for x_val in x_vals]

    # Ground-truth matrix: the first r columns U of a random orthogonal
    # matrix, combined as (d / sqrt(r)) * U U^T.
    U = ortho_group.rvs(dim=d, random_state=rng)[:, :r]
    G = d / np.sqrt(r) * U @ U.T

    # calculate difference of distances of triplets
    triplets, y_train, dists = create_triplets_from_items(
        K, G, rng, T, x, deterministic=not cfg.noise, noise_param=cfg.noise_param
    )

    # Validation labels are always noise-free; the per-set scores are not
    # needed afterwards.
    triplets_vals, y_vals, _ = list(
        zip(
            *[
                create_triplets_from_items(
                    K_val,
                    G,
                    rng,
                    T_val,
                    x_val,
                    deterministic=True,
                    noise_param=cfg.noise_param,
                )
                for K_val, x_val in zip(K_vals, x_vals)
            ]
        )
    )

    varphi_triplets, varphi_triplets_vals = kernel_2.generate_projection_triplets(
        triplets, triplets_vals
    )

    problem, K_hat = solve_for_K(
        p=varphi_triplets.shape[2],
        triplets=varphi_triplets,
        y_true=y_train,
        # largest absolute training score
        gamma=np.abs(dists).max(),
        lambda_=d * np.sqrt(r),
        solver=cfg.solver,
        loss_type=cfg.loss_type,
        constraint_type=cfg.constraint_type,
        verbose=cfg.verbose,
    )

    # Named solve_time rather than time to avoid shadowing the stdlib module name.
    solve_time = problem.solver_stats.solve_time

    # accuracies are reported in percent
    val_acc_s = (
        np.array(
            [
                check_accuracy(varphi_triplets_val, y_val, K_hat)
                for varphi_triplets_val, y_val in zip(varphi_triplets_vals, y_vals)
            ]
        )
        * 100
    )
    train_acc = check_accuracy(varphi_triplets, y_train, K_hat) * 100

    experiment.log_metric("solve_time", solve_time)
    experiment.log_metric("train_accuracy", train_acc)
    for val_acc in val_acc_s:
        experiment.log_metric("validation_accuracy", val_acc)

    print(cfg)
    print(
        f"Train accuracy: {train_acc:.2f}; Validation accuracy: {val_acc_s.mean():.2f} +/- {val_acc_s.std():.2f}"
    )
# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()