Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ logs/
runs/
outputs/
results/
experiments/


# Backup files
Expand Down Expand Up @@ -161,7 +160,6 @@ dmypy.json
/runs/
/outputs/
/results/
/experiments/

# Temporary files
*.tmp
Expand Down
2 changes: 1 addition & 1 deletion configs/prune_llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Configurations for generating results in the SCAR LLM pruning paper.

Run all experiments:
```bash
bash drafts/LLM_prune/paper/slurm/run_all_paper.sh
bash slurm_jobs/prune_llm/run_all_paper.sh
```

Run single model:
Expand Down
23 changes: 22 additions & 1 deletion configs/prune_llm/llama2_7b_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ llm:

evaluate_perplexity: true
evaluation_num_samples: 100
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
use_nvidia_fewshot: true
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
perplexity_protocol: "oats"
wikitext_subset: "wikitext-2-raw-v1"
perplexity_seq_len: 2048

evaluation_metrics:
- "perplexity"
Expand Down Expand Up @@ -137,6 +144,20 @@ supernode:
core_fraction: 0.01
follower_fraction: 0.10
halo_fraction: 0.10
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
# that lands on the top-K hidden dimensions most written-to by supernodes.
connectivity_topk: 256
# Optional post-processing for Conn (defaults keep current behavior)
connectivity_rank_normalize: false
connectivity_power: 1.0
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
# (used for paper mechanism plots; does NOT affect pruning decisions).
non_halo_sample_size: 256
non_halo_sample_seed: 0
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
protection_normalization: "rank_power"
protection_rank_power: 8.0
protection_floor: 0.2
protect_core: true
protect_core_metrics:
- "scar_loss_proxy" # SCAR-LP
Expand Down Expand Up @@ -232,7 +253,7 @@ pruning:
dependency_aware: true

sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]

algorithms:
- "rayleigh_quotient"
Expand Down
2 changes: 1 addition & 1 deletion configs/prune_llm/llama2_7b_unified.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ cascade_analysis:
pruning:
enabled: true
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]
distribution: "uniform"
min_per_layer: 0.0
max_per_layer: 0.95
Expand Down
28 changes: 27 additions & 1 deletion configs/prune_llm/llama3_8b_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ llm:

evaluate_perplexity: true
evaluation_num_samples: 100
# Use NVIDIA Minitron official few-shot settings for downstream tasks
# (MMLU 5-shot, HellaSwag 10-shot, ARC 25-shot, WinoGrande 5-shot, etc.).
use_nvidia_fewshot: true
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
perplexity_protocol: "oats"
wikitext_subset: "wikitext-2-raw-v1"
perplexity_seq_len: 2048

evaluation_metrics:
# Language modeling
Expand All @@ -99,6 +107,7 @@ llm:
- "accuracy_hellaswag"
- "accuracy_arc_easy"
- "accuracy_arc_challenge"
- "accuracy_openbookqa"

# Common Sense
- "accuracy_winogrande"
Expand Down Expand Up @@ -174,6 +183,21 @@ supernode:
core_fraction: 0.01
follower_fraction: 0.10
halo_fraction: 0.10
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
# that lands on the top-K hidden dimensions most written-to by supernodes.
# (Avoids the ~1/hidden_dim collapse of L1-normalized dot-product overlap for dense matrices.)
connectivity_topk: 256
# Optional post-processing for Conn (defaults keep current behavior)
connectivity_rank_normalize: false
connectivity_power: 1.0
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
# (used for paper mechanism plots; does NOT affect pruning decisions).
non_halo_sample_size: 256
non_halo_sample_seed: 0
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
protection_normalization: "rank_power"
protection_rank_power: 8.0
protection_floor: 0.2
protect_core: true
# Apply hard supernode protection only for the listed pruning metrics.
# If omitted, legacy behavior is to protect for *all* pruning metrics.
Expand Down Expand Up @@ -286,7 +310,9 @@ pruning:
dependency_aware: true

sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
# We only report (and run) the standard pruning direction: prune *low*-scoring channels.
# The "high" mode (prune highest scores) is a pathological control and is excluded from paper runs.
selection_modes: ["low"]

# ALL algorithms including SOTA baselines
algorithms:
Expand Down
87 changes: 87 additions & 0 deletions configs/prune_llm/llama3_8b_random_only.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
---
# ============================================================================
# LLAMA-3.1-8B RANDOM (CHANNEL) BASELINE
# ============================================================================
#
# Purpose:
#   - Fill the missing "Random (channel)" baseline row in paper tables.
#   - Run ONLY one pruning strategy (random) at 50% sparsity.
#   - Keep evaluation protocol consistent with the main paper run
#     (NVIDIA Minitron few-shot settings, OATS-style perplexity protocol).
#
# This is intentionally lightweight: we skip SCAR analyses/plots and only do:
#   - Baseline eval
#   - Random structured channel pruning @ 50%
#   - Post-prune eval
# ============================================================================

experiment:
  name: "llama3_8b_paper_results_random"
  type: "llm_alignment"
  output_dir: "./results/paper/llama3_8b_random"
  seed: 42
  device: "cuda"
  save_activations: false
  num_networks: 1

model:
  name: "hf_causal_lm"
  model_id: "meta-llama/Llama-3.1-8B"
  dtype: "bfloat16"
  device_map: "auto"
  trust_remote_code: true

dataset:
  name: "wikitext"
  batch_size: 1
  num_workers: 0

llm:
  evaluate_perplexity: true
  evaluation_num_samples: 100
  # Match the main paper runs: NVIDIA Minitron official few-shot settings
  # and the OATS WikiText-2 perplexity protocol (contiguous 2048-token blocks).
  use_nvidia_fewshot: true
  perplexity_protocol: "oats"
  wikitext_subset: "wikitext-2-raw-v1"
  perplexity_seq_len: 2048

  evaluation_metrics:
    - "perplexity"
    - "accuracy_openbookqa"
    - "accuracy_mmlu"
    - "accuracy_hellaswag"
    - "accuracy_piqa"
    - "accuracy_boolq"
    - "accuracy_winogrande"
    - "accuracy_arc_easy"
    - "accuracy_arc_challenge"

# Disable heavy analyses for this baseline-only run
analysis:
  generate_plots: false
  save_scores: false

  do_scar_metrics: false
  do_directed_redundancy: false
  do_connectivity_pruning: false
  do_halo_analysis: false
  do_generalized_importance: false

pruning:
  enabled: true
  target: "ffn"
  structured: true
  dependency_aware: true
  distribution: "uniform"
  min_per_layer: 0.0
  max_per_layer: 0.95

  # Single point needed for table_full_benchmarks_50
  sparsity_levels: [0.5]

  # Random structured pruning: selection done by mode="random"
  selection_modes: ["random"]

  # Only one strategy for this run; scores are generated in-code (deterministic).
  algorithms:
    - "random"

  single_strategy: "random"
4 changes: 2 additions & 2 deletions configs/prune_llm/llama3_8b_unified.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# - All experiment-specific settings in `extra:` section
# - Same pruning/evaluation/visualization structure
#
# Usage: python scripts/run_experiment.py --config configs/unified/llama3_8b_unified.yaml
# Usage: python scripts/run_experiment.py --config configs/prune_llm/llama3_8b_unified.yaml
# Estimated runtime: ~6-8 hours on 1x A100
# =============================================================================

Expand Down Expand Up @@ -156,7 +156,7 @@ cascade_analysis:
pruning:
enabled: true
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]
distribution: "uniform"
min_per_layer: 0.0
max_per_layer: 0.95
Expand Down
23 changes: 22 additions & 1 deletion configs/prune_llm/mistral_7b_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ llm:

evaluate_perplexity: true
evaluation_num_samples: 100
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
use_nvidia_fewshot: true
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
perplexity_protocol: "oats"
wikitext_subset: "wikitext-2-raw-v1"
perplexity_seq_len: 2048

evaluation_metrics:
- "perplexity"
Expand Down Expand Up @@ -136,6 +143,20 @@ supernode:
core_fraction: 0.01
follower_fraction: 0.10
halo_fraction: 0.10
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
# that lands on the top-K hidden dimensions most written-to by supernodes.
connectivity_topk: 256
# Optional post-processing for Conn (defaults keep current behavior)
connectivity_rank_normalize: false
connectivity_power: 1.0
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
# (used for paper mechanism plots; does NOT affect pruning decisions).
non_halo_sample_size: 256
non_halo_sample_seed: 0
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
protection_normalization: "rank_power"
protection_rank_power: 8.0
protection_floor: 0.2
protect_core: true
protect_core_metrics:
- "scar_loss_proxy" # SCAR-LP
Expand Down Expand Up @@ -231,7 +252,7 @@ pruning:
dependency_aware: true

sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]

algorithms:
- "rayleigh_quotient"
Expand Down
2 changes: 1 addition & 1 deletion configs/prune_llm/mistral_7b_unified.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ cascade_analysis:
pruning:
enabled: true
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]
distribution: "uniform"
min_per_layer: 0.0
max_per_layer: 0.95
Expand Down
23 changes: 22 additions & 1 deletion configs/prune_llm/qwen2_7b_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ llm:

evaluate_perplexity: true
evaluation_num_samples: 100
# Use NVIDIA Minitron official few-shot settings for downstream tasks.
use_nvidia_fewshot: true
# Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity:
# concatenate full test set and evaluate in contiguous 2048-token blocks (no padding).
perplexity_protocol: "oats"
wikitext_subset: "wikitext-2-raw-v1"
perplexity_seq_len: 2048

evaluation_metrics:
- "perplexity"
Expand Down Expand Up @@ -137,6 +144,20 @@ supernode:
core_fraction: 0.01
follower_fraction: 0.10
halo_fraction: 0.10
# Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass
# that lands on the top-K hidden dimensions most written-to by supernodes.
connectivity_topk: 256
# Optional post-processing for Conn (defaults keep current behavior)
connectivity_rank_normalize: false
connectivity_power: 1.0
# Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels
# (used for paper mechanism plots; does NOT affect pruning decisions).
non_halo_sample_size: 256
non_halo_sample_seed: 0
# Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma)
protection_normalization: "rank_power"
protection_rank_power: 8.0
protection_floor: 0.2
protect_core: true
protect_core_metrics:
- "scar_loss_proxy" # SCAR-LP
Expand Down Expand Up @@ -232,7 +253,7 @@ pruning:
dependency_aware: true

sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]

algorithms:
- "rayleigh_quotient"
Expand Down
2 changes: 1 addition & 1 deletion configs/prune_llm/qwen2_7b_unified.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ cascade_analysis:
pruning:
enabled: true
ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
selection_modes: ["low", "high"]
selection_modes: ["low"]
distribution: "uniform"
min_per_layer: 0.0
max_per_layer: 0.95
Expand Down
10 changes: 7 additions & 3 deletions configs/vision_prune/mobilenetv2_cifar10_unified.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ metrics:

rayleigh_quotient:
enabled: true
relative: true
relative: false # Standard Rayleigh quotient (no trace-normalization)
shrinkage: true

redundancy:
Expand Down Expand Up @@ -143,7 +143,7 @@ cascade_analysis:
# More sensitive to pruning - interesting to see which metrics matter
pruning:
enabled: true
distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity
distribution: "global_threshold" # uniform, global_threshold, adaptive_sensitivity
dependency_aware: true # MobileNet has inverted residuals
min_per_layer: 0.0
max_per_layer: 0.95
Expand All @@ -156,6 +156,7 @@ pruning:
# =========================================================================
- "random" # Random baseline
- "magnitude" # Standard magnitude pruning (prune low)
- "activation_mean" # Mean |activation| baseline
- "taylor" # Gradient-based importance
- "network_slimming" # Network Slimming (BN gamma) baseline
- "geometric_median" # FPGM-style geometric median baseline
Expand All @@ -167,6 +168,8 @@ pruning:
- "rq_low" # Prune low Rayleigh Quotient
- "redundancy_low" # Prune low redundancy (MI)
- "synergy_low" # Prune low synergy
- "redundancy_high" # Control: prune high redundancy
- "synergy_high" # Control: prune high synergy

# =========================================================================
# COMPOSITE COMBINATIONS
Expand All @@ -183,6 +186,7 @@ pruning:
# CLUSTER-AWARE
# =========================================================================
- "cluster_aware" # Original: protect critical, target redundant
- "cluster_aware_annealed" # Annealed mixing / constraints schedule
- "cluster_aware_protect_redundant" # Inverted: protect redundant

scoring_methods:
Expand All @@ -203,7 +207,7 @@ pruning:
epochs: 5
learning_rate: 0.0001
weight_decay: 0.00001
max_batches: 100
max_batches: 200

# -----------------------------------------------------------------------------
# EVALUATION (Enhanced for Vision)
Expand Down
Loading
Loading