-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path03_train.py
More file actions
104 lines (86 loc) · 3.56 KB
/
03_train.py
File metadata and controls
104 lines (86 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Step 3: Train YOLOv11n on the MTG dataset using Apple Silicon.

Trains a nano-sized YOLO model for fast local training.
Uses transfer learning from COCO pretrained weights.

Note: Training uses CPU because PyTorch 2.10 MPS has known tensor corruption
bugs on macOS 26 (clamp_() and TAL assigner indexing). CPU training on Apple
Silicon is still fast thanks to high-bandwidth unified memory and NEON SIMD.
MPS is still used for inference (predict/validate), which works correctly.

Usage:
    uv run python scripts/03_train.py
"""
import shutil
from pathlib import Path
import torch
from ultralytics import YOLO
# Repository root: two levels up from this file (the script is run as
# scripts/03_train.py, so its parent directory sits directly under the root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Dataset description file handed to the trainer.
DATA_YAML = PROJECT_ROOT.joinpath("data", "mtg-detection", "data.yaml")
# Directory that receives the exported best weights.
MODELS_DIR = PROJECT_ROOT.joinpath("models")
# Root directory under which training run artifacts are written.
RUNS_DIR = PROJECT_ROOT.joinpath("runs")
def main():
    """Fine-tune the COCO-pretrained YOLO11 nano model on the MTG dataset.

    Steps:
      1. Verify that the dataset YAML exists and create the models directory.
      2. Load ``yolo11n.pt`` and train on CPU with the settings below.
      3. Copy the run's ``best.pt`` to ``MODELS_DIR / "mtg-detect-best.pt"``.
      4. Print the final mAP50 / mAP50-95 metrics.

    Every hyperparameter is defined once in ``train_args``; the start-up
    banner is rendered from that same mapping so the printed values cannot
    drift from what is actually passed to ``model.train``.

    Raises:
        SystemExit: If ``DATA_YAML`` does not exist (the message is written
            to stderr and the process exits with status 1).
    """
    if not DATA_YAML.exists():
        # SystemExit with a string prints it to stderr and exits with code 1.
        raise SystemExit(
            f"Error: {DATA_YAML} not found. Run 01_setup_dataset.py first."
        )

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # Use CPU for training (MPS has known bugs with PyTorch 2.10 on macOS 26).
    # MPS works fine for inference — only training loss computation is affected.
    device = "cpu"
    if torch.backends.mps.is_available():
        print("Apple Silicon detected (MPS available for inference)")
    print("Training device: CPU (stable on all platforms)")

    # Load pretrained YOLOv11 nano (COCO weights for transfer learning).
    print("\nLoading yolo11n.pt (pretrained on COCO)...")
    model = YOLO("yolo11n.pt")

    # Single source of truth for the run: used for the banner, the train
    # call, and for locating the resulting weights afterwards.
    run_name = "mtg-detect"
    train_args = dict(
        data=str(DATA_YAML),
        epochs=100,
        batch=16,
        imgsz=640,
        device=device,
        patience=20,
        optimizer="AdamW",
        lr0=0.001,
        cos_lr=True,
        augment=True,
        plots=True,
        save_period=25,
        project=str(RUNS_DIR),
        name=run_name,
        exist_ok=True,  # Reuse the same run directory on every invocation
        workers=8,  # Parallel data loading (feed CPU faster)
        # --- Geometric augmentation (webcam robustness) ---
        degrees=15.0,  # Cards held at angles up to ~15°
        perspective=0.001,  # Simulates keystoning from angled holding
        shear=2.0,  # Mild perspective distortion
        # --- Multi-scale training (resolution invariance) ---
        multi_scale=0.5,  # Train at 320-960px for any-resolution support
        # --- Background false positive reduction ---
        mixup=0.05,  # 5% image blending — reduces overconfidence on backgrounds
    )

    # Pixel range shown in the banner. Assumes multi_scale is a +/- fraction
    # of imgsz, as the original banner stated — confirm against the installed
    # Ultralytics version.
    imgsz = train_args["imgsz"]
    scale = train_args["multi_scale"]
    min_px = round(imgsz * (1 - scale))
    max_px = round(imgsz * (1 + scale))

    print("\nStarting training...")
    print(f" Dataset: {DATA_YAML}")
    print(f" Device: {device}")
    print(
        f" Epochs: {train_args['epochs']}"
        f" (early stopping patience={train_args['patience']})"
    )
    print(f" Batch size: {train_args['batch']}")
    print(f" Image size: {imgsz}")
    print(
        f" Augmentation: geometric (degrees={train_args['degrees']:g},"
        f" perspective={train_args['perspective']:g},"
        f" shear={train_args['shear']:g})"
    )
    print(f" Multi-scale: {scale:g} (trains at {min_px}-{max_px}px)")
    print(f" Mixup: {train_args['mixup']:g} (background false positive reduction)")
    print()

    results = model.train(**train_args)

    # Copy best weights to models/. The path is derived from the same
    # project/name values given to the trainer above.
    best_pt = RUNS_DIR / run_name / "weights" / "best.pt"
    dest = MODELS_DIR / "mtg-detect-best.pt"
    if best_pt.exists():
        shutil.copy2(best_pt, dest)
        print(f"\nBest weights saved to: {dest}")
    else:
        print(f"\nWarning: {best_pt} not found")

    # Print final metrics. model.train() may return None, so fall back to an
    # empty mapping rather than failing after a completed training run.
    metrics = getattr(results, "results_dict", None) or {}
    map50 = metrics.get("metrics/mAP50(B)", "N/A")
    map50_95 = metrics.get("metrics/mAP50-95(B)", "N/A")
    print("\nFinal metrics:")
    print(f" mAP50: {map50}")
    print(f" mAP50-95: {map50_95}")
    print("\nDone! Run scripts/04_validate.py next.")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()