Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions conf/task/gtm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# @package _global_
# Hydra task config for the "gtm" experiment.
# NOTE(review): nesting below was reconstructed — the source had its
# indentation stripped (flat keys), which YAML would parse as top-level
# siblings with `algebra:`/`model:`/... all null. Grouping follows the
# blank-line structure of the original; confirm against the consuming code.
name: gtm

# Clifford algebra signature Cl(p, q, r).
# p=3 positive, q=0 negative, r=1 degenerate basis vectors — presumably
# 3D projective geometric algebra; verify against the algebra constructor.
algebra:
  p: 3
  q: 0
  r: 1
  device: cuda

model:
  channels: 32
  num_steps: 12          # default reasoning steps per forward pass
  max_steps: 24          # hard cap when adaptive halting is active
  num_hypotheses: 8
  coord_scale: 1.0
  head_hidden: 128
  num_rule_slots: 8
  num_memory_channels: 4
  weight_share_steps: false

log_manifold:
  gate_init: -5.0        # sigmoid(-5) ≈ 0.007 — gates start nearly closed

search_plane:
  conviction_threshold: 0.9
  evolve_hidden: 64

info_geometry:
  halt_eps: 0.01         # halting tolerance — presumably on info-gain; confirm
  use_supervised_fim: true

action_engine:
  gate_init: 0.0         # sigmoid(0) = 0.5 — gates start half-open

attention:
  num_heads: 4
  head_dim: 8

dataset:
  data_dir: data/arc
  include_toy: true
  toy_n_examples: 20000
  toy_max_grid_size: 15
  num_demos: 3
  epoch_samples: 4000

training:
  epochs: 150
  lr: 0.0005
  batch_size: 16
  optimizer_type: riemannian_adam
  max_bivector_norm: 10.0
  grad_clip: 1.0

  # CUDA acceleration
  num_workers: 4
  pin_memory: true
  amp: true              # mixed-precision autocast
  compile: false         # torch.compile off by default
  cudnn_benchmark: true

  # Three-phase schedule (warmup + trim + act = 150 = epochs)
  warmup_epochs: 8
  trim_epochs: 72
  act_epochs: 70

  # Temperature schedule (annealed across the phases above)
  tau_start: 1.0
  tau_mid: 0.5
  tau_end: 0.2

  # Loss weights
  ortho_weight: 0.005
  gate_entropy_weight: 0.01
  info_gain_weight: 0.01
  eval_every: 5
8 changes: 4 additions & 4 deletions core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def induced_norm(algebra: CliffordAlgebra, A: torch.Tensor) -> torch.Tensor:
sq_norm = inner_product(algebra, A, A_rev)

# In mixed signatures, sq_norm can be negative.
# We return sqrt(|sq_norm|)
return torch.sqrt(torch.abs(sq_norm))
return torch.sqrt(torch.abs(sq_norm).clamp(min=1e-12))

def geometric_distance(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""Computes geometric distance.
Expand Down Expand Up @@ -238,7 +237,8 @@ def hermitian_norm(algebra: CliffordAlgebra, A: torch.Tensor) -> torch.Tensor:
Norm [..., 1]. Always >= 0.
"""
sq = hermitian_inner_product(algebra, A, A)
return torch.sqrt(torch.abs(sq))
# Clamp before sqrt to avoid inf gradient when sq ≈ 0 (e.g. null multivectors in PGA).
return torch.sqrt(torch.abs(sq).clamp(min=1e-12))


def hermitian_distance(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -277,7 +277,7 @@ def hermitian_angle(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor)
sq_b = (signs * B * B).sum(dim=-1, keepdim=True)
# Use sqrt(sq_a * sq_b) instead of sqrt(sq_a)*sqrt(sq_b) to avoid
# float32 precision loss from two separate sqrt operations.
denom = torch.sqrt(torch.abs(sq_a) * torch.abs(sq_b)).clamp(min=1e-6)
denom = torch.sqrt((torch.abs(sq_a) * torch.abs(sq_b)).clamp(min=1e-12)).clamp(min=1e-6)
cos_theta = ip / denom
cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
return torch.acos(cos_theta)
Expand Down
4 changes: 4 additions & 0 deletions datalib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .md17 import get_md17_loaders
from .deap import DEAPDataset, get_deap_loaders, get_group_sizes
from .lqa import CLUTRRDataset, HANSDataset, BoolQNegDataset, get_lqa_loaders
from .arc import ToyARCDataset, ARCDataset, get_arc_loaders

__all__ = [
"SRDataset",
Expand All @@ -27,4 +28,7 @@
"HANSDataset",
"BoolQNegDataset",
"get_lqa_loaders",
"ToyARCDataset",
"ARCDataset",
"get_arc_loaders",
]
Loading
Loading