Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions configs/examples/gemma2b_pruning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# ============================================================================
# GEMMA-2B COMPREHENSIVE PRUNING COMPARISON
# ============================================================================
# Reference: https://arxiv.org/abs/2403.08295 (Gemma)
#
# PURPOSE: Test SCAR on smaller efficient model (faster experiments)
#
# Gemma-2B specs:
# - 18 layers, 2048 hidden dim, 16384 intermediate dim
# - Uses GeGLU activation (similar to SwiGLU)
# - Efficient model good for rapid experimentation
#
# NOTE: Gemma uses slightly different layer naming than Llama
# ============================================================================

experiment:
name: "gemma2b_pruning"
type: "llm_alignment"
seed: 42
device: "cuda"
output_dir: "./results/gemma2b_pruning"
num_networks: 1

model:
name: "hf_causal_lm"
model_id: "google/gemma-2b"
dtype: "bfloat16"
device_map: "auto"

# Gemma uses same MLP structure as Llama (GeGLU, similar to SwiGLU)
# Layer naming: model.layers.*.mlp.{up_proj, gate_proj, down_proj}
tracked_layers:
- "model.model.layers.*.mlp.up_proj"
- "model.model.layers.*.mlp.gate_proj"
- "model.model.layers.*.mlp.down_proj"

dataset:
name: "wikitext"
batch_size: 1
num_workers: 0

# ============================================================================
# IMPORTANCE METRICS
# ============================================================================
metrics:
enabled:
- "rayleigh_quotient"
- "gaussian_mi_analytic"
- "average_redundancy"
- "activation_l2_norm"

num_samples: 64

rayleigh_quotient:
relative: true
regularization: 1.0e-6

# ============================================================================
# LLM-SPECIFIC SETTINGS
# ============================================================================
llm:
scar_metrics: true
scar_num_samples: 64
scar_max_length: 512

evaluate_perplexity: true
evaluation_num_samples: 200

use_nvidia_fewshot: true
use_chain_of_thought: true

evaluation_metrics:
- "perplexity"
- "loss"
- "accuracy_winogrande"
- "accuracy_arc_challenge"
- "accuracy_mmlu"
- "accuracy_hellaswag"
- "accuracy_arc_easy"
- "accuracy_piqa"
- "accuracy_boolq"

# ============================================================================
# SUPERNODE CONFIGURATION
# ============================================================================
supernode:
enabled: true
core_fraction: 0.01
follower_fraction: 0.10
score_metric: "activation_l2_norm"
protect_core: true
cross_layer_analysis: true
compare_by_connection: true

compute_metrics:
- "activation"
- "rayleigh_quotient"
- "mutual_information"
- "redundancy"

# ============================================================================
# SUPERNODE ROBUSTNESS ANALYSIS
# ============================================================================
supernode_robustness:
enabled: true
supernode_fraction: 0.01
num_bootstrap_samples: 5
batch_size: 32
max_samples: 128

metrics:
- "scar_activation_power"
- "scar_loss_proxy"
- "rayleigh_quotient"
- "activation_l2_norm"

target_layers: null

# ============================================================================
# SUPERNODE SUMMARY ANALYSIS
# ============================================================================
# Generates summary plots:
# 1. Halo vs Non-Halo metrics by layer (mean activation, RQ, MI, redundancy)
# 2. Supernode outlier z-scores by layer (how much of an outlier)
supernode_summary:
enabled: true
outlier_analysis: true

# ============================================================================
# PRUNING CONFIGURATION
# ============================================================================
pruning:
enabled: true

sparsity_levels: [0.25, 0.5, 0.75]

selection_modes: ["low", "high"]

distribution: "uniform"
structured: true
dependency_aware: true

algorithms:
# Alignment-based
- "rayleigh_quotient"
- "gaussian_mi_analytic"
- "average_redundancy"
# SCAR
- "scar_loss_proxy"
# Supernode-aware
- "supernode_protection_score"
- "supernode_connectivity_score"
# Baselines
- "activation_l2_norm"
- "wanda"
- "sparsegpt"

single_strategy: null

fine_tune:
enabled: false

# ============================================================================
# ADVANCED ANALYSIS FLAGS
# ============================================================================
do_directed_redundancy: true
do_connectivity_pruning: true

# ============================================================================
# ANALYSIS & VISUALIZATION
# ============================================================================
analysis:
save_scores: true
generate_plots: true

plots:
histograms: true
scatter_plots: true
pruning_curves: true
redundancy_heatmaps: true

scatter_pairs:
- ["activation_l2_norm", "rayleigh_quotient"]
- ["activation_l2_norm", "gaussian_mi_analytic"]
- ["scar_activation_power", "scar_loss_proxy"]
- ["rayleigh_quotient", "scar_loss_proxy"]
- ["average_redundancy", "rayleigh_quotient"]

visualization:
format: "png"
dpi: 300
145 changes: 145 additions & 0 deletions configs/examples/gpt2_fast_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# ============================================================================
# GPT-2 FAST PRUNING TEST
# ============================================================================
# PURPOSE: Quick validation of pruning pipeline on small model
#
# GPT-2 (124M) specs:
# - 12 layers, 768 hidden dim, 3072 intermediate dim
# - Uses standard MLP (fc1 -> GELU -> fc2), NOT SwiGLU
# - Very fast for testing (~10-30 min on H100)
#
# NOTE: GPT-2 uses different MLP naming than Llama/Mistral:
# - h.*.mlp.c_fc (up projection)
# - h.*.mlp.c_proj (down projection)
# - No gate_proj (standard MLP, not gated)
#
# Use this config to quickly test pipeline before running larger models
# ============================================================================

experiment:
name: "gpt2_fast_test"
type: "llm_alignment"
seed: 42
device: "cuda"
output_dir: "./results/gpt2_fast_test"
num_networks: 1

model:
name: "hf_causal_lm"
model_id: "gpt2" # 124M params - very fast
dtype: "float16" # GPT-2 works better with fp16 than bf16
device_map: "auto"

# GPT-2 uses standard MLP (not SwiGLU)
# c_fc is the up projection, c_proj is the down projection
tracked_layers:
- "model.transformer.h.*.mlp.c_fc"
- "model.transformer.h.*.mlp.c_proj"

dataset:
name: "wikitext"
batch_size: 1
num_workers: 0

# ============================================================================
# IMPORTANCE METRICS (reduced for fast test)
# ============================================================================
metrics:
enabled:
- "rayleigh_quotient"
- "activation_l2_norm"

num_samples: 16 # Fewer samples for speed

rayleigh_quotient:
relative: true
regularization: 1.0e-6

# ============================================================================
# LLM-SPECIFIC SETTINGS (reduced for speed)
# ============================================================================
llm:
scar_metrics: true
scar_num_samples: 16 # Fewer samples
scar_max_length: 256 # Shorter sequences

evaluate_perplexity: true
evaluation_num_samples: 50 # Fewer eval samples

use_nvidia_fewshot: false # Skip few-shot for speed

# Minimal benchmarks for quick test
evaluation_metrics:
- "perplexity"
- "loss"
- "accuracy_hellaswag" # Just one benchmark

# ============================================================================
# SUPERNODE CONFIGURATION
# ============================================================================
supernode:
enabled: true
core_fraction: 0.01
follower_fraction: 0.10
score_metric: "activation_l2_norm"
protect_core: true
cross_layer_analysis: false # Skip for speed
compare_by_connection: false

compute_metrics:
- "activation"

# ============================================================================
# SUPERNODE ROBUSTNESS (disabled for fast test)
# ============================================================================
supernode_robustness:
enabled: false

# ============================================================================
# PRUNING CONFIGURATION (reduced for speed)
# ============================================================================
pruning:
enabled: true

# Just two sparsity levels for quick test
sparsity_levels: [0.25, 0.5]

selection_modes: ["low"] # Just one mode for speed

distribution: "uniform"
structured: true
dependency_aware: true

# Reduced algorithm set
algorithms:
- "rayleigh_quotient"
- "activation_l2_norm"
- "wanda"

single_strategy: null

fine_tune:
enabled: false

# ============================================================================
# ADVANCED ANALYSIS (disabled for speed)
# ============================================================================
do_directed_redundancy: false
do_connectivity_pruning: false

# ============================================================================
# ANALYSIS & VISUALIZATION
# ============================================================================
analysis:
save_scores: true
generate_plots: true

plots:
histograms: true
scatter_plots: false # Skip for speed
pruning_curves: true
redundancy_heatmaps: false

visualization:
format: "png"
dpi: 150 # Lower DPI for speed
Loading
Loading