-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
44 lines (40 loc) · 1.59 KB
/
train.py
File metadata and controls
44 lines (40 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os, numpy as np, torch, mlflow, mlflow.pytorch
from model import VAE
import torch.nn as nn
# Make sure the output directory exists at import time: train() later saves
# its checkpoint to 'models/vae_demo.pt'. exist_ok=True makes re-runs harmless.
os.makedirs('models', exist_ok=True)
def synthetic_dataset(n=1000, *, n_features=10, anomaly_rate=0.03):
    """Generate synthetic rows of Gaussian noise with occasional shifted outliers.

    Each row is drawn from a standard normal distribution. With probability
    ``anomaly_rate`` a second draw from N(4.0, 1.5) is added to the whole row,
    producing an anomalous sample.

    Randomness comes from NumPy's global ``np.random`` state, so calling
    ``np.random.seed(...)`` beforehand makes the output reproducible.

    Args:
        n: Number of rows to generate. Must be non-negative.
        n_features: Number of columns per row.
        anomaly_rate: Probability that a row receives the anomaly offset.

    Returns:
        A ``float32`` array of shape ``(n, n_features)``. For ``n == 0`` this
        is an empty ``(0, n_features)`` array.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Preallocate instead of collecting rows in a list and stacking them:
    # np.stack() raises ValueError on an empty list, so n == 0 used to crash.
    samples = np.empty((n, n_features), dtype=np.float32)
    for row in range(n):
        # Draw order (row, anomaly coin flip, optional offset) is kept as-is
        # so seeded runs produce the same data as before.
        base = np.random.normal(loc=0.0, scale=1.0, size=n_features)
        if np.random.rand() < anomaly_rate:
            base += np.random.normal(loc=4.0, scale=1.5, size=n_features)
        samples[row] = base  # assignment casts float64 -> float32
    return samples
def train(epochs=5, batch_size=64, lr=1e-3, latent_dim=4):
    """Train a VAE on synthetic data and record the run with MLflow.

    Generates 1000 synthetic rows, fits the model on CPU with Adam using a
    mean-squared reconstruction loss, logs the per-epoch average loss to the
    'mosul_vae_experiment' MLflow experiment, then saves the weights to
    'models/vae_demo.pt' and logs both the weights file and the model itself.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Number of rows per optimisation step. Must be positive.
        lr: Adam learning rate.
        latent_dim: Latent dimension passed to the VAE constructor.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    X = synthetic_dataset(1000)
    device = torch.device('cpu')

    # Take the feature count from the data so the model width cannot drift
    # from what the dataset generator produces.
    n_features = X.shape[1]
    model = VAE(n_features=n_features, latent_dim=latent_dim)
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()

    # Convert the whole array once; every batch below is just a slice of it,
    # instead of rebuilding a tensor for each batch of each epoch.
    data = torch.as_tensor(X, dtype=torch.float32).to(device)
    path = 'models/vae_demo.pt'

    mlflow.set_experiment('mosul_vae_experiment')
    with mlflow.start_run(run_name='demo_run'):
        for epoch in range(epochs):
            losses = []
            for start in range(0, len(data), batch_size):
                batch = data[start:start + batch_size]
                # NOTE(review): mu and logvar are returned but unused, so the
                # objective is reconstruction-only (no KL term) -- confirm
                # this is intended for the demo.
                recon, _z, _mu, _logvar = model(batch)
                loss = mse(recon, batch)
                opt.zero_grad()
                loss.backward()
                opt.step()
                losses.append(loss.item())
            avg = float(sum(losses) / len(losses))
            print(f'Epoch {epoch+1} avg_loss={avg:.6f}')
            mlflow.log_metric('avg_loss', avg, step=epoch + 1)

        torch.save(model.state_dict(), path)
        mlflow.log_artifact(path, artifact_path='models')
        mlflow.pytorch.log_model(model, artifact_path='pytorch_model')
        # Log the values actually used, rather than repeating literals that
        # could get out of sync with the model/optimizer configuration.
        mlflow.log_param('latent_dim', latent_dim)
        mlflow.log_param('epochs', epochs)
        mlflow.log_param('batch_size', batch_size)
        mlflow.log_param('lr', lr)
        print('Saved model to', path)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()