-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhparam_search.py
More file actions
136 lines (110 loc) · 3.8 KB
/
hparam_search.py
File metadata and controls
136 lines (110 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Literal
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from framework.data_utils import prepare_dataset
from framework.fitness import calculate_composite_fitness
from models.cnn import CNNModel
from models.decision_tree import DecisionTreeModel
from models.factory import get_model_by_name
from models.knn import KNNModel
from search import RandomSearch
# Seed applied to every RNG (stdlib, NumPy, Torch) for reproducible searches.
RANDOM_SEED = 321
# CNN Specific
# NOTE(review): neither constant is referenced in this file's visible code —
# presumably consumed as defaults by CNNModel / the search; verify before removing.
DEFAULT_EPOCHS = 5
DEFAULT_PATIENCE = 2
# Repository root, resolved from this file's location.
REPO_ROOT = Path(__file__).resolve().parent
# Base directory for TensorBoard event files written by the search runs.
LOG_ROOT = REPO_ROOT / ".cache" / "tensorboard" / "search"
def set_seeds(seed: int):
    """Seed every random number generator the pipeline touches.

    Covers Python's ``random``, NumPy, and Torch; CUDA devices are seeded
    only when one is actually available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def evaluate_model(
    model_key: Literal["dt", "knn", "cnn"],
    params: Dict[str, Any],
    data: Dict[str, Any],
) -> Dict[str, float]:
    """Build, train, and score one model configuration on the validation split.

    Returns the validation metrics dict augmented with a
    ``"composite_fitness"`` entry used as the search objective.
    """
    # Reject unknown keys up front instead of at the end of an if/elif chain.
    if model_key not in {"dt", "knn", "cnn"}:
        raise ValueError(f"Unsupported model key: {model_key}")

    model = get_model_by_name(model_key)
    if model_key == "cnn":
        assert isinstance(model, CNNModel)
        model.create_model(**params)
        # The CNN trains on image tensors and monitors the validation split.
        model.train(
            data["train_images"],
            data["train_labels"],
            data["val_images"],
            data["val_labels"],
        )
        metrics = model.evaluate(data["val_images"], data["val_labels"])
    else:
        # Classical models ("dt" / "knn") consume flattened feature vectors.
        assert isinstance(model, (DecisionTreeModel, KNNModel))
        model.create_model(**params)
        model.train(data["train_flat"], data["train_labels"])
        metrics = model.evaluate(data["val_flat"], data["val_labels"])

    metrics["composite_fitness"] = calculate_composite_fitness(metrics)
    return metrics
def run_search(model_key: Literal["dt", "knn", "cnn"], trials: int) -> None:
    """Run a seeded random hyperparameter search and print the best trial."""
    set_seeds(RANDOM_SEED)

    print("Preparing dataset...")
    data = prepare_dataset()

    print("Preparing parameter space...")
    param_space = get_model_by_name(model_key).get_param_space()
    searcher = RandomSearch(
        param_space=param_space,
        evaluate_fn=lambda sampled: evaluate_model(model_key, sampled, data),
        metric_key="composite_fitness",
        seed=RANDOM_SEED,
    )

    log_dir = create_search_log_dir(model_key)
    print(f"Running search... (logging to {log_dir})")
    writer = SummaryWriter(log_dir=log_dir)
    try:
        result = searcher.run(trials, verbose=True, writer=writer)
    finally:
        # Flush TensorBoard events even if a trial raises mid-search.
        writer.close()

    # Human-readable summary of the winning configuration.
    best = result.best_metrics
    print("-" * 80)
    print(f"Model: {model_key}")
    print(f"Trials: {trials}")
    print(f"Best composite fitness: {best['composite_fitness']:.4f}")
    print("Best metrics:")
    for metric, score in best.items():
        if isinstance(score, float):
            print(f" {metric}: {score:.4f}")
    print("Best hyperparameters:")
    for param, chosen in result.best_params.items():
        print(f" {param}: {chosen}")
def create_search_log_dir(model_key: str) -> str:
    """Create a timestamped TensorBoard run directory under LOG_ROOT.

    The directory (and any missing parents) is created on disk; the path is
    returned as a string, the form SummaryWriter expects.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = LOG_ROOT.joinpath(model_key, f"run_{stamp}")
    run_dir.mkdir(parents=True, exist_ok=True)
    return str(run_dir)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the hyperparameter search script."""
    parser = argparse.ArgumentParser(
        description="Random hyperparameter search for CIFAR-10 models."
    )
    parser.add_argument(
        "--model", type=str, default="dt", choices=["dt", "knn", "cnn"],
        help="Model to optimize.",
    )
    parser.add_argument(
        "--trials", type=int, default=5,
        help="Number of random search trials.",
    )
    return parser.parse_args()
def main() -> None:
    """Script entry point: parse CLI arguments and launch the search."""
    cli = parse_args()
    run_search(cli.model, cli.trials)


if __name__ == "__main__":
    main()