Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed .github/workflows/benchmark.yml
Empty file.
Empty file removed .github/workflows/gpu-test.yml
Empty file.
7 changes: 3 additions & 4 deletions configs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ configs/
## Usage

```bash
python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml
python scripts/run_experiment.py --config configs/examples/llama3_pruning.yaml
python scripts/run_experiment.py --config configs/examples/llama3_comprehensive_pruning.yaml
python scripts/run_experiment.py --config configs/examples/vision_pruning_test.yaml
```

## Configuration Blocks

| Block | Purpose |
|-------|---------|
| `experiment` | Name, type (`alignment_analysis` or `llm_alignment`), seed, device |
| `model` | Architecture, pretrained, tracked_layers. For LLMs: model_id, torch_dtype |
| `model` | Architecture, pretrained, tracked_layers. For LLMs: model_id, dtype |
| `dataset` | Dataset name, batch_size, data_path |
| `metrics` | `enabled`: list of metrics. `num_samples`: calibration samples. `composite_weights`: for composite scoring |
| `training` | `enabled`, epochs, learning_rate, optimizer |
Expand Down
84 changes: 84 additions & 0 deletions configs/examples/alexnet_pruning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
---
# AlexNet Pruning Experiment on ImageNet
# AlexNet requires 224x224 input, so we use ImageNet

experiment:
  name: "alexnet_pruning_test"
  type: "alignment_analysis"
  seed: 42
  device: "auto"  # Auto-detect GPU/CPU
  output_dir: "./results/alexnet_pruning"
  num_networks: 1  # Single network

model:
  name: "alexnet"
  pretrained: true  # Uses pretrained weights (auto-disables training)
  tracked_layers: null  # Auto-detect conv/linear layers

# ImageNet dataset path on Kempner cluster
dataset:
  name: "imagenet"
  data_path: "/n/holylfs06/LABS/kempner_shared/Everyone/testbed/vision/imagenet_1k"
  batch_size: 128
  num_workers: 4

# Metrics to compute - these are also used for pruning
# Note: mutual_information_gaussian requires outputs, excluded for now
metrics:
  enabled:
    - "rayleigh_quotient"  # Alignment to input covariance
    - "activation_l2_norm"  # Activation magnitude
  num_samples: 256

# CNN activation preprocessing
# "unfold" (recommended): Best for RQ/covariance metrics
# "patchwise": Good for patch-level analysis
# "channel_variance": Faster but less accurate
cnn:
  mode: "unfold"

# Pruning experiments
pruning:
  enabled: true

  # Sparsity levels to test (fraction of neurons to remove)
  sparsity_levels: [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

  # Selection modes: WHICH neurons to prune based on their scores
  # "low" - prune neurons with LOWEST scores (remove least important)
  # "high" - prune neurons with HIGHEST scores (ablation study)
  # "random" - prune randomly (baseline comparison)
  selection_modes: ["low", "high", "random"]

  # Distribution: HOW to allocate sparsity across layers
  # "uniform" - same % per layer (simple, default)
  # "global_threshold" - global score threshold
  # "adaptive_sensitivity" - prune robust layers more
  # "cascading" - sequential layer-by-layer
  distribution: "uniform"

  # Structured: prune entire neurons/channels (vs individual weights)
  structured: true

  # Dependency-aware: handle layer dependencies
  dependency_aware: true

  # NOTE(review): fine_tune placed under `pruning` (post-pruning recovery);
  # original indentation was lost in extraction — confirm against the schema
  fine_tune:
    enabled: true
    epochs: 10
    learning_rate: 0.0001

# Performance (all optimizations enabled by default)
performance:
  eval_batches: null  # null = all batches

analysis:
  save_scores: true
  generate_plots: true
  plots:
    histograms: true
    scatter_plots: true
    pruning_curves: true

visualization:
  format: "png"
  dpi: 300
105 changes: 105 additions & 0 deletions configs/examples/cnn2p2_pruning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
---
# CNN2P2 Pruning Experiment on CIFAR-10
# A simpler test case before using AlexNet - 2 conv layers + 2 pool layers
# This model is faster to train and easier to debug pruning behavior

experiment:
  name: "cnn2p2_pruning_test"
  type: "alignment_analysis"
  seed: 42
  device: "auto"  # Auto-detect GPU/CPU
  output_dir: "./results/cnn2p2_pruning"
  num_networks: 1  # Single network (can increase for statistical analysis)

model:
  name: "cnn2p2"
  pretrained: false  # Train from scratch
  tracked_layers: null  # Auto-detect conv/linear layers

  # CNN2P2 specific parameters (for CIFAR-10)
  cnn2p2_params:
    in_channels: 3  # RGB images
    output_dim: 10  # CIFAR-10 classes
    conv_channels: [32, 64]  # Channels per conv layer
    kernel_sizes: [5, 5]
    strides: [1, 1]
    paddings: [2, 2]
    pool_kernel_size: 2
    pool_stride: 2
    hidden_fc_dim: 128
    dropout_rate: 0.5
    example_input_hw: [32, 32]  # CIFAR-10 image size

# CIFAR-10 dataset
dataset:
  name: "cifar10"
  data_path: null  # Uses torchvision default download
  batch_size: 128
  num_workers: 4

# Training configuration (since we're not using pretrained)
training:
  enabled: true
  epochs: 50  # Enough to converge on CIFAR-10
  learning_rate: 0.01
  optimizer: "sgd"
  momentum: 0.9
  weight_decay: 0.0001
  scheduler: "cosine"
  scheduler_config:
    T_max: 50
    eta_min: 0.0001

# Metrics to compute - these are also used for pruning
metrics:
  enabled:
    - "rayleigh_quotient"  # Alignment to input covariance
    - "activation_l2_norm"  # Activation magnitude
  num_samples: 256

# CNN activation preprocessing
cnn:
  mode: "unfold"  # Best for RQ/covariance metrics

# Pruning experiments
pruning:
  enabled: true

  # Sparsity levels to test (fraction of neurons to remove)
  sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

  # Selection modes: WHICH neurons to prune based on their scores
  # "low" - prune neurons with LOWEST scores (remove least important)
  # "high" - prune neurons with HIGHEST scores (ablation study)
  # "random" - prune randomly (baseline comparison)
  selection_modes: ["low", "high", "random"]

  # Distribution: HOW to allocate sparsity across layers
  distribution: "uniform"

  # Structured: prune entire neurons/channels (vs individual weights)
  structured: true

  # Dependency-aware: handle layer dependencies
  dependency_aware: true

  # NOTE(review): fine_tune placed under `pruning` (post-pruning recovery);
  # original indentation was lost in extraction — confirm against the schema
  fine_tune:
    enabled: true
    epochs: 10
    learning_rate: 0.001

# Performance
performance:
  eval_batches: null  # null = all batches

analysis:
  save_scores: true
  generate_plots: true
  plots:
    histograms: true
    scatter_plots: true
    pruning_curves: true

visualization:
  format: "png"
  dpi: 300

104 changes: 104 additions & 0 deletions configs/examples/cnn2p2_pruning_cifar10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
---
# CNN2P2 Pruning Experiment on CIFAR-10
# More challenging task to better differentiate pruning strategies
# Tests composite metrics for optimal pruning

experiment:
  name: "cnn2p2_pruning_cifar10"
  type: "alignment_analysis"
  seed: 42
  device: "auto"
  output_dir: "./results/cnn2p2_pruning_cifar10"
  num_networks: 1

model:
  name: "cnn2p2"
  pretrained: false
  tracked_layers: null

  # CNN2P2 for CIFAR-10 (3 channels, moderate size)
  cnn2p2_params:
    in_channels: 3  # RGB images
    output_dim: 10  # CIFAR-10 classes
    conv_channels: [16, 32]  # Moderate size
    kernel_sizes: [3, 3]
    strides: [1, 1]
    paddings: [1, 1]
    pool_kernel_size: 2
    pool_stride: 2
    hidden_fc_dim: 64  # Moderate FC layer
    dropout_rate: 0.3
    example_input_hw: [32, 32]  # CIFAR-10 image size

# CIFAR-10 dataset (more challenging than MNIST)
dataset:
  name: "cifar10"
  data_path: null
  batch_size: 128
  num_workers: 4

# Training configuration
training:
  enabled: true
  epochs: 30  # CIFAR-10 needs more epochs
  learning_rate: 0.01
  optimizer: "sgd"
  momentum: 0.9
  weight_decay: 0.0001
  scheduler: "cosine"
  scheduler_config:
    T_max: 30
    eta_min: 0.0001

# Metrics - comprehensive set with analytic MI and composite metrics
metrics:
  enabled:
    # Core alignment metrics
    - "rayleigh_quotient"
    - "conditional_rayleigh_quotient"
    # Analytic MI (similar to RQ, computed analytically)
    - "gaussian_mi_analytic"
    - "mutual_information_gaussian"
    # Activation-based
    - "activation_l2_norm"
    # Redundancy metrics
    - "average_redundancy"
    - "pairwise_redundancy_gaussian"
    # Class-based metrics
    - "mi_about_class"
    # Composite metrics for optimal pruning
    - "alignment_minus_redundancy"  # RQ - Redundancy
  num_samples: 256

cnn:
  mode: "unfold"

# Pruning experiments
pruning:
  enabled: true
  sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  selection_modes: ["low", "high", "random"]
  distribution: "uniform"
  structured: true
  dependency_aware: true

  # NOTE(review): fine_tune placed under `pruning` (post-pruning recovery);
  # original indentation was lost in extraction — confirm against the schema
  fine_tune:
    enabled: true
    epochs: 10
    learning_rate: 0.001

performance:
  eval_batches: null

analysis:
  save_scores: true
  generate_plots: true
  plots:
    histograms: true
    scatter_plots: true
    pruning_curves: true
    redundancy_heatmaps: true

visualization:
  format: "png"
  dpi: 300

Loading
Loading