Merged
7 changes: 7 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
# Archive and draft directories
_arxiv/
# Ignore drafts by default, but keep the SCAR paper folder tracked (it has its own .gitignore).
drafts/
!drafts/LLM_prune/
!drafts/LLM_prune/**
checkpoints/
results/
logs/
@@ -174,6 +177,10 @@ Thumbs.db
*.swo
*~

# SLURM default output files (created when submitting scripts without explicit --output)
slurm-*.out
slurm-*.err

# OS
.DS_Store
.DS_Store?
157 changes: 97 additions & 60 deletions README.md
@@ -1,12 +1,17 @@
# Alignment Framework

Neural network alignment analysis and pruning framework.
Neural network analysis and structured pruning using alignment metrics and information theory.

## Overview

Tools for analyzing and pruning neural networks using alignment metrics, information theory, and structured pruning strategies.
This framework provides tools for analyzing and pruning neural networks through:

**Supported architectures**: MLPs, CNNs (ResNet, VGG), Transformers, LLMs (LLaMA, Mistral)
- **Alignment metrics**: Rayleigh quotient, activation-based importance
- **Information-theoretic analysis**: Mutual information, redundancy, synergy
- **Cluster-based analysis**: Functional type identification, cross-layer halo tracking
- **Structured pruning**: Channel/neuron removal with multiple scoring strategies

**Supported architectures**: MLPs, CNNs (ResNet, VGG, MobileNet), Transformers, LLMs (LLaMA, Mistral, Qwen)

## Installation

@@ -20,57 +25,27 @@ pip install -e .

## Quick Start

### Run Experiments

```bash
# Vision model analysis
python scripts/run_experiment.py --config configs/examples/mnist_basic.yaml

# CNN pruning
python scripts/run_experiment.py --config configs/examples/resnet_pruning.yaml

# LLM importance scoring
python scripts/run_experiment.py --config configs/examples/llm_alignment.yaml
```

### Programmatic Usage

```python
from alignment import ModelWrapper, get_metric

wrapper = ModelWrapper(model)
rq = get_metric('rayleigh_quotient')
# LLM analysis
python scripts/run_experiment.py --config configs/paper/llama3_8b_full.yaml

outputs, activations = wrapper.forward_with_activations(inputs)
weights = wrapper.get_layer_weights()
scores = rq.compute(activations['layer_input'], weights['layer'])
# Cluster-based analysis
python scripts/run_experiment.py --config configs/cluster_analysis/resnet18_cifar10_full.yaml
```

## Configuration

Experiments use YAML configuration files:

```yaml
model:
  name: "resnet18"
  pretrained: true

dataset:
  name: "cifar10"
  batch_size: 128
## Experiment Types

alignment_methods:
- "rayleigh_quotient"
- "pairwise_redundancy_gaussian"

pruning:
  enabled: true
  algorithms: ["alignment"]
  sparsity_levels: [0.3, 0.5, 0.7]
  structured: true
```

See `configs/template.yaml` for all parameters.
| Type | Description | Config Example |
|------|-------------|----------------|
| `alignment_analysis` | General alignment metrics | `mnist_basic.yaml` |
| `llm_alignment` | LLM supernode/SCAR analysis | `llama3_8b_full.yaml` |
| `cluster_analysis` | Metric-space clustering with halos | `resnet18_cifar10_full.yaml` |

## Metrics

@@ -80,42 +55,104 @@ See `configs/template.yaml` for all parameters.
| Alignment | `rayleigh_quotient`, `delta_alignment` |
| Information | `mutual_information_gaussian`, `pairwise_redundancy_gaussian`, `gaussian_pid_synergy_mmi` |
| SCAR (LLM) | `scar_activation_power`, `scar_taylor`, `scar_curvature`, `scar_loss_proxy` |
| Synergy | `synergy_continuous_target` (with logit margin) |

## Cluster-Based Analysis

The cluster analysis framework groups channels/neurons into functional types:

| Type | Characteristics | Pruning Implication |
|------|-----------------|---------------------|
| Critical | High RQ, Low Redundancy, High Synergy | Protect |
| Redundant | Moderate RQ, High Redundancy | Target for pruning |
| Synergistic | Moderate RQ, High Synergy | Preserve pairs |
| Background | Low on all metrics | Safe to remove |

Cross-layer halo analysis tracks downstream dependencies to predict cascade effects.
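
As an illustration of how the table above could be read, here is a minimal rule-based sketch in Python. It is not the framework's implementation (the README describes K-means clustering); the function name `label_unit`, the `PRUNING_IMPLICATION` mapping, the assumption that scores are normalized to [0, 1], and the `low`/`high` cutoffs are all hypothetical.

```python
# Hypothetical mapping taken from the "Pruning Implication" column of the table.
PRUNING_IMPLICATION = {
    "critical": "protect",
    "redundant": "target for pruning",
    "synergistic": "preserve pairs",
    "background": "safe to remove",
}


def label_unit(rq, redundancy, synergy, low=0.33, high=0.66):
    """Assign a functional type from three scores assumed to lie in [0, 1].

    The cutoffs `low` and `high` are illustrative assumptions, not values
    from the framework. Returns None when no table row matches.
    """
    if rq >= high and redundancy < low and synergy >= high:
        return "critical"
    if redundancy >= high:
        return "redundant"
    if synergy >= high:
        return "synergistic"
    if rq < low and redundancy < low and synergy < low:
        return "background"
    return None


if __name__ == "__main__":
    unit_type = label_unit(rq=0.5, redundancy=0.9, synergy=0.2)
    print(unit_type, "->", PRUNING_IMPLICATION[unit_type])
```

Checking the most specific row (Critical) first avoids labeling a high-synergy, high-RQ unit as merely Synergistic.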

## Pruning Strategies

| Strategy | Description |
|----------|-------------|
| `magnitude` | Prune by weight magnitude |
| `alignment` | Prune by alignment score |
| `hybrid` | Combine magnitude and alignment |
| `composite` | Combine multiple metrics |
| `cluster_aware` | Use cluster membership and halo analysis |
| `random` | Random baseline |
| `global` | Cross-layer pruning |
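
To make the `magnitude` row concrete, the sketch below shows one common form of structured magnitude pruning: score each output channel by the L2 norm of its weights and drop the lowest-scoring fraction. The helper names and the plain-list weight layout are assumptions for illustration; the framework's own strategy classes may differ.

```python
import math


def channel_norms(weight):
    """L2 norm of each output channel; `weight` holds one row per channel."""
    return [math.sqrt(sum(v * v for v in row)) for row in weight]


def select_channels_to_prune(scores, ratio):
    """Indices of the lowest-scoring channels, rounded down to a whole count."""
    n_prune = int(len(scores) * ratio)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(order[:n_prune])


def prune_channels(weight, ratio):
    """Return the weight rows that survive pruning at the given ratio."""
    drop = set(select_channels_to_prune(channel_norms(weight), ratio))
    return [row for i, row in enumerate(weight) if i not in drop]


if __name__ == "__main__":
    w = [[3.0, 4.0], [0.1, 0.0], [1.0, 1.0], [0.0, 0.2]]
    print(select_channels_to_prune(channel_norms(w), 0.5))
    print(prune_channels(w, 0.5))
```

Swapping `channel_norms` for another per-channel score (for example an alignment score) leaves the selection step unchanged, which is one way the other single-metric rows in the table could share the same code path.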

## Project Structure

```
alignment/
├── configs/ # YAML configuration files
│ ├── examples/ # Example experiments
│ └── template.yaml # Parameter reference
├── scripts/ # Entry points
│ ├── run_experiment.py
│ └── run_analysis.py
├── src/alignment/ # Main package
│ ├── analysis/ # Visualization
│ ├── experiments/ # Experiment classes
│ ├── metrics/ # Alignment metrics
│ ├── models/ # Model wrappers
│ └── pruning/ # Pruning strategies
├── tests/ # Unit tests
└── docs/ # Documentation
├── configs/
│ ├── cluster_analysis/ # Cluster-based analysis configs
│ ├── paper/ # Paper experiment configs
│ └── examples/ # Example configs
├── scripts/
│ ├── run_experiment.py # Main entry point
│ └── run_analysis.py # Post-hoc analysis
├── src/alignment/
│ ├── analysis/ # Visualization, clustering, cascade analysis
│ ├── experiments/ # Experiment classes
│ ├── metrics/ # Importance metrics
│ ├── models/ # Model wrappers
│ └── pruning/ # Pruning strategies
├── tests/ # Unit tests
└── docs/ # Documentation
```

## Key Modules

### Analysis
- `MetricSpaceClustering`: K-means clustering in (RQ, Redundancy, Synergy) space
- `CrossLayerHaloAnalysis`: Track downstream channel dependencies
- `CascadeAnalysis`: Validate importance via ablation
- `UnifiedVisualizer`: Generate analysis plots
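
For readers unfamiliar with clustering in a metric space, the following is a minimal, dependency-free K-means sketch over (RQ, Redundancy, Synergy) triples. It is not the `MetricSpaceClustering` class itself; the function name `kmeans`, the farthest-point initialization, and the sample points are assumptions chosen to keep the example deterministic.

```python
import math


def kmeans(points, k, n_iter=50):
    """Lloyd's algorithm on tuples of floats; returns (labels, centroids)."""
    points = [tuple(p) for p in points]
    # Farthest-point initialization: deterministic, unlike random seeding.
    centroids = [points[0]]
    while len(centroids) < k:
        centroids.append(
            max(points, key=lambda p: min(math.dist(p, c) for c in centroids))
        )
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: each point joins its nearest centroid.
        labels = [
            min(range(k), key=lambda j: math.dist(p, centroids[j])) for p in points
        ]
        # Update step: move each centroid to the mean of its members.
        new_centroids = []
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                new_centroids.append(
                    tuple(sum(axis) / len(members) for axis in zip(*members))
                )
            else:
                new_centroids.append(centroids[j])  # keep an empty cluster in place
        if new_centroids == centroids:
            break
        centroids = new_centroids
    return labels, centroids


if __name__ == "__main__":
    # Each tuple is (rq, redundancy, synergy) for one channel; values are made up.
    channels = [
        (0.90, 0.10, 0.90), (0.95, 0.05, 0.85),
        (0.50, 0.90, 0.10), (0.45, 0.95, 0.15),
        (0.40, 0.10, 0.90), (0.35, 0.15, 0.95),
        (0.10, 0.10, 0.10), (0.05, 0.05, 0.15),
    ]
    labels, centroids = kmeans(channels, k=4)
    print(labels)
```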

### Experiments
- `GeneralAlignmentExperiment`: Vision model analysis
- `LLMAlignmentExperiment`: LLM supernode and SCAR analysis
- `ClusterAnalysisExperiment`: Cluster-based analysis for any architecture

### Metrics
- `RayleighQuotient`: Input-weight alignment
- `PairwiseRedundancyGaussian`: Gaussian MI-based redundancy
- `SynergyContinuousTarget`: PID synergy with continuous target
- SCAR metrics for LLMs
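
The two non-SCAR quantities named above have standard textbook forms, sketched here in plain Python. This assumes the Rayleigh quotient is taken as w^T Σ w / w^T w with Σ the input covariance, and that Gaussian redundancy is the mutual information -0.5 ln(1 - ρ²) between two activations; the framework's exact definitions are not shown in this diff, and the function names below are hypothetical.

```python
import math


def covariance(samples):
    """Sample covariance of row-wise observations (list of equal-length lists)."""
    n, d = len(samples), len(samples[0])
    means = [sum(row[j] for row in samples) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in samples]
    return [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def rayleigh_quotient(weight, cov):
    """w^T C w / w^T w: input variance captured along the weight direction."""
    d = len(weight)
    num = sum(weight[i] * cov[i][j] * weight[j] for i in range(d) for j in range(d))
    den = sum(w * w for w in weight)
    return num / den


def gaussian_mutual_information(x, y):
    """Mutual information in nats between two jointly Gaussian scalars,
    estimated from samples through the squared Pearson correlation."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    rho_sq = min(sxy * sxy / (sxx * syy), 1.0)
    if rho_sq == 1.0:
        return math.inf  # perfectly correlated activations
    return -0.5 * math.log(1.0 - rho_sq)


if __name__ == "__main__":
    inputs = [[-2.0, 0.0], [-1.0, 0.1], [1.0, -0.1], [2.0, 0.0]]
    cov = covariance(inputs)
    print(rayleigh_quotient([1.0, 0.0], cov), rayleigh_quotient([0.0, 1.0], cov))
```

A weight vector aligned with the high-variance input direction receives a larger quotient, and the quotient is unchanged when the weight vector is rescaled.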

## Documentation

- [Usage Guide](docs/usage.md) - Running experiments and configuration
- [API Reference](docs/api_reference.md) - Core classes and functions
- [LLM Guide](docs/llm_guide.md) - LLM-specific analysis and pruning
- [LLM Guide](docs/llm_guide.md) - LLM-specific analysis
- [Metric Consistency](docs/METRIC_CONSISTENCY.md) - Theory-code verification

## Configuration

```yaml
experiment_type: cluster_analysis # or llm_alignment, alignment_analysis

model:
  name: resnet18
  pretrained: true

dataset:
  name: cifar10
  batch_size: 128

clustering:
  n_clusters: 4
  compute_stability: true

halo_analysis:
  percentile: 90.0

pruning:
  ratios: [0.3, 0.5, 0.7]
  methods: [magnitude, taylor, cluster_aware]
```

See `configs/template.yaml` for complete parameter reference.

## Testing
