Adaptive Chain-of-Thought Framework

A research-oriented framework for Adaptive Chain-of-Thought (CoT) with self-consistency for parallel test-time scaling. The framework dynamically determines the number of reasoning branches based on prefill-stage analysis and uses true parallel generation for efficient inference.

🎯 Overview

This framework implements an innovative approach to Chain-of-Thought reasoning that:

  • Adaptively allocates computational resources based on problem difficulty
  • Uses prefill-stage analysis to extract difficulty signals (entropy, KL divergence, confidence)
  • Generates multiple reasoning paths in parallel using num_return_sequences
  • Applies self-consistency through majority voting to aggregate answers
  • Supports multiple backends (HuggingFace Transformers, vLLM)

🏗️ Architecture

Core Components

┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Prefill       │    │   Branch         │    │   Generation    │
│   Analyzer      │───▶│   Allocator      │───▶│   Engine        │
│                 │    │                  │    │                 │
│ • Entropy       │    │ • Difficulty     │    │ • Parallel      │
│ • KL Divergence │    │ • Branch Count   │    │ • num_return_   │
│ • Confidence    │    │ • Strategy       │    │   sequences     │
└─────────────────┘    └──────────────────┘    └─────────────────┘
         │                                               │
         ▼                                               ▼
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Research      │    │   Self-          │    │   Evaluation    │
│   Logger        │    │   Consistency    │    │   Metrics       │
│                 │    │                  │    │                 │
│ • Data Logging  │    │ • Majority Vote  │    │ • Accuracy      │
│ • Signal Track  │    │ • Confidence     │    │ • Efficiency    │
│ • Performance   │    │ • Consensus      │    │ • Consistency   │
└─────────────────┘    └──────────────────┘    └─────────────────┘

Key Innovation: Two-Prefill Process

Instead of using a fixed number of branches, the framework uses a sophisticated two-prefill approach:

  1. First Prefill: Analyze problem difficulty → extract signals
  2. Branch Allocation: Determine optimal number of branches based on signals
  3. Second Prefill: Generate multiple reasoning paths with num_return_sequences
  4. Self-Consistency: Apply majority voting to get final answer

📁 Project Structure

adaptive_cot_framework/
├── src/
│   ├── adaptive/                    # Core adaptive CoT implementation
│   │   ├── adaptive_cot.py             # Main AdaptiveCoT class
│   │   ├── prefill_analyzer.py         # Prefill signal analysis
│   │   └── branch_allocator.py         # Branch allocation logic
│   ├── models/                      # Model implementations
│   │   ├── base_model.py               # Abstract base class
│   │   ├── deepseek_model.py           # DeepSeek model wrapper
│   │   ├── vllm_model.py               # vLLM model wrapper
│   │   ├── generic_model.py            # Generic HuggingFace model
│   │   └── model_factory.py            # Model factory
│   ├── benchmarks/                  # Benchmark datasets
│   │   ├── math_benchmarks.py          # Math datasets (GSM8K, AIME, etc.)
│   │   └── benchmark_factory.py        # Benchmark factory
│   ├── evaluation/                  # Evaluation framework
│   │   ├── evaluator.py                # Main evaluator
│   │   ├── metrics.py                  # Evaluation metrics
│   │   └── lighteval_integration.py    # LightEval integration
│   ├── experiments/                 # Experiment runners
│   │   └── experiment_runner.py        # Systematic experiments
│   └── utils/                       # Utilities
│       ├── research_logger.py          # Research data logging
│       ├── memory_monitor.py           # Memory monitoring
│       └── visualization.py            # Visualization tools
├── configs/
│   └── model_config.yaml            # Configuration file
├── run_experiment.py                # Main research experiment runner
├── run_quick_test.py                # Quick test script
├── test_adaptive_cot.py             # Individual test script
├── requirements.txt                 # Dependencies
└── README.md                        # This file

🚀 Quick Start

Installation (Linux, CUDA)

# Clone the repository
git clone <repository-url>
cd adaptive_cot_framework

# Create & activate a virtual environment (recommended)
python3 -m venv .venv
source .venv/bin/activate

# Upgrade pip
pip install --upgrade pip

# Install dependencies
pip install -r requirements.txt

# Optional: install vLLM for high-throughput inference
pip install vllm

# Install the package (editable)
pip install -e .

# Verify CUDA visibility
nvidia-smi | head -n 10 | cat

Basic Usage

from src.models.model_factory import ModelFactory
from src.adaptive.adaptive_cot import AdaptiveCoT

# Create model
model = ModelFactory.create_model("deepseek", "/path/to/model", config)
model.load_model()

# Create adaptive CoT
cot = AdaptiveCoT(model, config)

# Solve problem
result = cot.solve_problem("Sarah has 12 apples. She gives 3 to her friend...")
print(f"Answer: {result['answer']}")
print(f"Branches used: {result['num_branches']}")

Command Line Usage

Run math benchmarks with static or adaptive branching

# vLLM static sweep (e.g., 16/32/64 branches)
bash run_static_sweep_vllm.sh 0 Qwen/Qwen3-14B -1 "aime_2024 aime_2025"

# Prefill-only dump (extract signals without generation)
bash run_prefill_dump.sh 0 Qwen/Qwen3-14B "aime_2024 aime_2025" -1 50 0 2
# args: GPU(s), model, datasets, sample limit (-1 = all), top-k, seed, content-steps

# T2B calibration (join latest prefill with static-bK)
bash run_t2b_calibrate.sh aime_2024 Qwen__Qwen3-14B 32
bash run_t2b_calibrate.sh aime_2025 Qwen__Qwen3-14B 32

Direct Python Testing

# Test with 1 branch, 100 samples
python test_gsm8k_full.py --branches 1 --samples 100

# Test with 8 branches, 1000 samples
python test_gsm8k_full.py --branches 8 --samples 1000 --output results/my_test.json

# Test all samples with custom output
python test_gsm8k_full.py --branches 5 --samples 1319 --output results/full_gsm8k_5branch.json

Quick Testing

# Quick test with 5 samples
python test_5_samples.py

# Test with 10 samples
python test_10_samples.py

# Test single sample for debugging
python test_single_sample.py

Individual Testing

# Test adaptive branching
python test_adaptive_cot.py --problem "Sarah has 12 apples..." --adaptive

# Test static branching
python test_adaptive_cot.py --problem "What is 2+2?" --static --branches 3

# Test with custom model
python test_adaptive_cot.py --problem "Find the area..." --model-path "/path/to/model" --adaptive

Verification and Validation

Identical Behavior Verification

We have thoroughly verified that our framework produces identical results to direct generation when using the same parameters:

Zero-Shot Verification

  • Identical reasoning paths: 100% text similarity between our framework and direct generation
  • Identical answers: All extracted answers match exactly
  • Identical accuracy: Same performance on test samples
  • Deterministic generation: Using temperature=0.0 and do_sample=False

Test Results

# Verification test (5 samples)
python test_5_samples.py

# Results show:
# ✅ Reasoning Identical: 5/5 (100.0%)
# ✅ Accuracy: 0.600 (3/5) - identical for both methods
# ✅ All answers match exactly

Key Fixes Applied

  1. Whitespace handling: Both methods now use identical text processing
  2. Random seed management: Deterministic generation with proper seed setting
  3. Stop sequence processing: Consistent application across both methods
  4. Answer extraction: Synchronized extraction logic

This verification ensures that any performance differences observed in research are due to the adaptive branching strategy itself, not implementation differences.

🔬 Research Features

Prefill Analysis Signals

The framework extracts signals from prefill and next-token distributions:

  • Sequence-level (averaged over prefill positions): entropy_seq, kl_div_seq (to uniform), confidence_seq (avg top‑1 prob)
  • Next-token (final prefill position): entropy_next, kl_to_uniform_next, tvd_uniform_next
  • Decode-matched TVD: tvd_decode_next and decode_set_size computed on the actual candidate set (top‑k then nucleus top‑p)
  • Distribution shape: top1_prob, top2_prob, margin_top2, entropy_norm
  • First content token features (optional): *_content counterparts computed at the first nontrivial token after prefill
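
A minimal sketch of how sequence-level signals like these can be derived from prefill logits (illustrative only; the function name and exact normalization are assumptions, not the prefill_analyzer.py API):

import torch
import torch.nn.functional as F

def prefill_signals_from_logits(logits: torch.Tensor) -> dict:
    """Illustrative sequence-level prefill signals from [seq_len, vocab] logits."""
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    vocab_size = logits.size(-1)

    # Per-position entropy; averaging over prefill positions gives entropy_seq
    entropy = -(probs * log_probs).sum(dim=-1)
    # KL(p || uniform) = log|V| - H(p)
    kl_to_uniform = torch.log(torch.tensor(float(vocab_size))) - entropy
    # Confidence: top-1 probability at each position
    top1 = probs.max(dim=-1).values

    return {
        "entropy_seq": entropy.mean().item(),
        "kl_div_seq": kl_to_uniform.mean().item(),
        "confidence_seq": top1.mean().item(),
        # Next-token variants use only the final prefill position
        "entropy_next": entropy[-1].item(),
        "top1_prob": top1[-1].item(),
    }

# Example with random logits for a 12-token prompt over a 32k vocabulary
print(prefill_signals_from_logits(torch.randn(12, 32000)))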

Adaptive Branch Allocation

Branch count can be determined by either a heuristic mapping or a learned allocator:

# Example learned approach (high level):
# 1) Predict p_hat from features via a booster; 2) map to N via Hoeffding with shrinkage.
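
As a rough illustration of the second step, the Hoeffding bound for majority voting gives one way to turn a predicted per-branch accuracy p̂ into a branch budget N. This is a hedged sketch of that mapping, not the calibrated allocator used in the experiments:

import math

def branches_from_p_hat(p_hat: float, delta: float = 0.05,
                        n_min: int = 1, n_max: int = 64,
                        shrinkage: float = 0.9) -> int:
    """Hoeffding: if each branch is independently correct with probability p > 0.5,
    majority voting fails with probability <= delta once
        N >= ln(1/delta) / (2 * (p - 0.5)^2).
    Shrinking p_hat toward 0.5 guards against an over-confident predictor."""
    p = 0.5 + shrinkage * (p_hat - 0.5)
    if p <= 0.5:
        return n_max  # no usable signal: spend the full budget
    n = math.ceil(math.log(1.0 / delta) / (2.0 * (p - 0.5) ** 2))
    return max(n_min, min(n_max, n))

# Easier problems (high p_hat) get small budgets; borderline ones saturate at n_max
print(branches_from_p_hat(0.9), branches_from_p_hat(0.6))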

Self-Consistency Metrics

  • Consensus Confidence: Fraction of branches agreeing on the answer
  • Answer Distribution: Count of different answers across branches
  • Branch Diversity: Measures of reasoning path diversity
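
For concreteness, a small sketch of how these metrics could be computed from the extracted branch answers (metric names mirror the list above; this is not the framework's metrics module):

from collections import Counter

def consistency_metrics(answers: list[str]) -> dict:
    """Consensus confidence, answer distribution, and a simple diversity measure."""
    counts = Counter(answers)
    top_count = counts.most_common(1)[0][1]
    return {
        "consensus_confidence": top_count / len(answers),  # fraction agreeing with the winner
        "answer_distribution": dict(counts),               # answer -> branch count
        "branch_diversity": len(counts) / len(answers),    # distinct answers / branches
    }

# Example: 5 branches, 3 of which agree on "18"
print(consistency_metrics(["18", "18", "20", "18", "17"]))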

🏭 Backend Support

HuggingFace Transformers

# Uses num_return_sequences for efficient batch generation
generated_texts = model.generate(
    prompt,
    num_return_sequences=num_branches,
    temperature=temperature,
    do_sample=True,  # Always True for self-consistency
)

vLLM

# Uses vLLM's built-in batch generation capabilities
generated_texts = model.generate(
    prompt,
    num_return_sequences=num_branches,
    temperature=temperature,
    do_sample=True,
    # Prefix caching on, proper stop sequences, and exact tokenizer-based token counting
)
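
For reference, the wrapper call above maps roughly onto vLLM's native API, where SamplingParams(n=...) plays the role of num_return_sequences. A generic sketch (not the framework's vllm_model.py; model name and sampling values are illustrative):

from vllm import LLM, SamplingParams

prompt = "Please solve the following problem step by step.\nProblem: What is 2+2?\nSolution:"
num_branches = 8

llm = LLM(model="Qwen/Qwen3-14B", enable_prefix_caching=True)
params = SamplingParams(
    n=num_branches,      # one prompt, num_branches sampled completions
    temperature=0.6,
    top_p=0.95,
    max_tokens=2048,
)

outputs = llm.generate([prompt], params)
generated_texts = [completion.text for completion in outputs[0].outputs]

Because all n completions share a single prompt, the prefill computation is shared within the request, which is one source of the efficiency gains described below.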

📊 Performance Characteristics

Efficiency Gains

  • 2-3x faster generation through num_return_sequences and prefix caching
  • Memory efficient through shared computation
  • True parallel processing with GPU batching

Adaptive Benefits

  • Resource optimization: More branches for difficult problems
  • Quality improvement: Better consensus through adaptive allocation
  • Research insights: Understanding problem difficulty patterns

🧪 Testing and Evaluation

Eval helpers and outputs

Outputs are organized under iclr_results/ per model/dataset/method. Prefill-only runs are stored in timestamped subdirectories under iclr_results/prefill_analysis/ with a *_latest symlink and LAST_OUTPUT_DIR.txt pointer.

Python Script Usage

# Test with 1 branch, 100 samples
python test_gsm8k_full.py --branches 1 --samples 100

# Test with 8 branches, 1000 samples
python test_gsm8k_full.py --branches 8 --samples 1000 --output results/my_test.json

# Test all samples with custom output
python test_gsm8k_full.py --branches 5 --samples 1319 --output results/full_gsm8k_5branch.json

Sample Size Specification

You can specify the number of samples to evaluate:

  • Small tests: 5-50 samples for quick validation
  • Medium tests: 100-500 samples for development
  • Large tests: 1000+ samples for research
  • Full dataset: 1319 samples (complete GSM8K)

Examples:

# Quick validation (5 samples)
./run_gsm8k_evaluation.sh 0 1 5

# Development testing (100 samples)
./run_gsm8k_evaluation.sh 0 8 100

# Research evaluation (1000 samples)
./run_gsm8k_evaluation.sh 0 8 1000

# Full dataset (all 1319 samples)
./run_gsm8k_evaluation.sh 0 1 1319

Output Format

Results are saved in JSON format with comprehensive metrics:

{
  "config": {
    "adaptive_branching": false,
    "min_branches": 1,
    "max_branches": 1,
    "default_branches": 1,
    "num_fewshot": 0,
    "temperature": 0.0,
    "top_p": 1.0,
    "max_tokens": 512
  },
  "dataset_info": {
    "name": "gsm8k",
    "total_samples": 1319,
    "evaluated_samples": 100
  },
  "results": [
    {
      "problem_id": 1,
      "question": "Janet's ducks lay 16 eggs per day...",
      "ground_truth": "18",
      "our_answer": "18",
      "our_reasoning": "## Step 1: Calculate...",
      "correct": true,
      "confidence": 1.0,
      "num_branches": 1,
      "duration": 2.5
    }
  ],
  "metrics": {
    "accuracy": 0.85,
    "correct": 85,
    "total": 100,
    "duration": 250.5,
    "avg_duration_per_problem": 2.5,
    "branch_count": 1
  },
  "timestamp": "2024-12-15 14:30:25"
}
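
A short sketch for post-processing such a file, recomputing accuracy and average duration from the per-problem records (field names follow the example above):

import json

with open("results/my_test.json") as f:
    data = json.load(f)

records = data["results"]
accuracy = sum(r["correct"] for r in records) / len(records)
avg_duration = sum(r["duration"] for r in records) / len(records)

print(f"Accuracy: {accuracy:.3f} over {len(records)} problems")
print(f"Avg duration per problem: {avg_duration:.2f}s")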

Quick Testing Scripts

# 5-sample test (zero-shot)
python test_5_samples.py

# 10-sample test (zero-shot)
python test_10_samples.py

# Single sample debugging
python test_single_sample.py

# Individual problem testing
python test_adaptive_cot.py --problem "Your math problem here" --adaptive
python test_adaptive_cot.py --problem "Your math problem here" --static --branches 5

Benchmark Support

  • GSM8K: Grade school math problems (1319 samples)
  • AIME: American Invitational Mathematics Examination
  • MATH: Mathematical reasoning dataset
  • Olympiad: Math competition problems

Evaluation Metrics

  • Accuracy: Correctness of final answers
  • Consensus Confidence: Agreement across branches
  • Efficiency: Time and memory usage
  • Adaptive Effectiveness: Correlation between difficulty signals and allocated branch count (see the sketch after this list)
  • Per-Problem Duration: Individual problem solving time
  • Branch Utilization: How many branches were used
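
The adaptive-effectiveness metric, for instance, can be estimated as a simple correlation between a difficulty signal and the allocated branch count. A sketch, assuming each per-problem record carries an entropy signal and the branch count (field names illustrative):

import numpy as np

def adaptive_effectiveness(records: list[dict]) -> float:
    """Pearson correlation between a difficulty proxy and the branches used."""
    difficulty = np.array([r["entropy_next"] for r in records])
    branches = np.array([r["num_branches"] for r in records])
    return float(np.corrcoef(difficulty, branches)[0, 1])

# Toy example: harder (higher-entropy) problems received more branches
rows = [
    {"entropy_next": 0.4, "num_branches": 1},
    {"entropy_next": 1.5, "num_branches": 4},
    {"entropy_next": 2.1, "num_branches": 8},
]
print(adaptive_effectiveness(rows))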

⚙️ Configuration

Model Configuration

models:
  deepseek_r1_distill_qwen:
    model_name: "/path/to/model"
    model_type: "reasoning"
    generation_params:
      max_new_tokens: 2048
      temperature: 0.6
      top_p: 0.95
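
A hedged sketch of wiring this config into the Basic Usage example above (the YAML loading step and key lookups are assumptions about how the pieces fit together):

import yaml

from src.models.model_factory import ModelFactory

with open("configs/model_config.yaml") as f:
    config = yaml.safe_load(f)

model_cfg = config["models"]["deepseek_r1_distill_qwen"]
model = ModelFactory.create_model("deepseek", model_cfg["model_name"], model_cfg)
model.load_model()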

Adaptive Branching

adaptive_branching:
  enabled: true
  min_branches: 1
  max_branches: 10
  default_branches: 3
  prefill_analysis:
    entropy_threshold: 0.8
    kl_divergence_threshold: 0.5
    confidence_threshold: 0.7
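
A minimal sketch of how a heuristic allocator might consume these thresholds; the actual branch_allocator.py logic, and the direction each signal is compared in, may differ:

def allocate_branches(signals: dict, cfg: dict) -> int:
    """Count how many prefill signals indicate a hard problem and scale the budget.

    Assumed directions: high entropy, low KL-to-uniform, and low confidence
    all suggest an uncertain (harder) problem."""
    t = cfg["prefill_analysis"]
    hard_votes = sum([
        signals["entropy_seq"] > t["entropy_threshold"],
        signals["kl_div_seq"] < t["kl_divergence_threshold"],
        signals["confidence_seq"] < t["confidence_threshold"],
    ])
    if hard_votes == 0:
        return cfg["min_branches"]
    if hard_votes == 3:
        return cfg["max_branches"]
    return cfg["default_branches"]

# Example with the thresholds above: an uncertain problem gets the full budget
cfg = {"min_branches": 1, "max_branches": 10, "default_branches": 3,
       "prefill_analysis": {"entropy_threshold": 0.8, "kl_divergence_threshold": 0.5,
                            "confidence_threshold": 0.7}}
print(allocate_branches({"entropy_seq": 1.2, "kl_div_seq": 0.3, "confidence_seq": 0.4}, cfg))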

🔧 Technical Implementation

Two-Prefill Process

def solve_problem(self, problem: str) -> Dict[str, Any]:
    # Step 1: First prefill - analyze problem difficulty
    prefill_signals = self._analyze_problem_difficulty(problem)

    # Step 2: Determine branch count based on signals
    num_branches = self._determine_branch_count(prefill_signals)

    # Step 3: Second prefill - generate multiple reasoning paths
    reasoning_paths = self._generate_reasoning_paths(problem, num_branches, prefill_signals)

    # Step 4: Extract an answer from each path, then apply self-consistency
    # (answer-extraction helper name is illustrative)
    answers = [self._extract_answer(path) for path in reasoning_paths]
    final_answer, consensus_info = self._apply_self_consistency(answers)

    return {
        "answer": final_answer,
        "num_branches": num_branches,
        "consensus": consensus_info,
        "prefill_signals": prefill_signals,
    }

Prefill Analysis

def _analyze_problem_difficulty(self, problem: str) -> Dict[str, float]:
    """First prefill: analyze problem difficulty to extract signals."""
    analysis_prompt = f"Problem: {problem}\nSolution:"
    prefill_signals = self.model.get_prefill_analysis(analysis_prompt)
    return prefill_signals

Parallel Generation

def _generate_reasoning_paths(self, problem: str, num_branches: int, prefill_signals: Dict[str, float]) -> List[str]:
    """Second prefill: Generate multiple reasoning paths using num_return_sequences."""
    cot_prompt = f"Please solve the following problem step by step...\nProblem: {problem}\nSolution:"
    
    generated_texts = self.model.generate(
        cot_prompt,
        num_return_sequences=num_branches,
        temperature=temperature,
        do_sample=True,  # Always True for self-consistency
    )
    
    return generated_texts

Self-Consistency

from collections import Counter

def _apply_self_consistency(self, answers: List[str]) -> Tuple[str, Dict[str, Any]]:
    """Apply self-consistency to get the final answer."""
    # Clean answers for comparison
    cleaned_answers = [self._clean_answer(answer) for answer in answers]

    # Count answers across branches
    answer_counts = Counter(cleaned_answers)

    # Majority vote: most common answer wins; confidence = fraction that agrees
    most_common_answer, most_common_count = answer_counts.most_common(1)[0]
    confidence = most_common_count / len(cleaned_answers)

    consensus_info = {
        "confidence": confidence,
        "answer_counts": dict(answer_counts),
    }
    return most_common_answer, consensus_info

🎯 Research Applications

Parallel Test-Time Scaling

  • Self-Consistency: Multiple reasoning paths with majority voting
  • Adaptive Branching: Dynamic resource allocation based on difficulty
  • Efficient Generation: True parallel processing with num_return_sequences

Prefill Analysis Research

  • Difficulty Estimation: Using entropy, KL divergence, confidence
  • Resource Allocation: Optimal branch count for different problem types
  • Efficiency Optimization: Parallel generation and memory management

Benchmark Evaluation

  • Math Reasoning: GSM8K, AIME, MATH, Olympiad datasets
  • General Q&A: Extensible to other reasoning tasks
  • Performance Analysis: Speed, accuracy, and efficiency metrics

🎯 Current Status

✅ Recent updates

  • Tokenizer-accurate token counting; correct vLLM num_return_sequences and max_parallel_branches plumbing
  • Prefill-only dump mode; decode-matched TVD (tvd_decode_next) and decode_set_size
  • First content-token feature extraction (*_content)
  • T2B calibration/join and visualization scripts
  • Organized prefill outputs into per-dataset/model timestamped folders with *_latest symlink
  • Regex-based decode-time stopping for GSM8K final-answer patterns (HF backend)

🚀 Ready for Research

  • Static Branching: Single-branch and multi-branch evaluation ready
  • Deterministic Generation: Proper random seed management for reproducible results
  • Comprehensive Metrics: Accuracy, confidence, duration, and efficiency tracking
  • Flexible Testing: Support for any number of samples (5 to full dataset)

📊 Next Steps

  1. Train a boosting-based allocator (predict p̂ or N) on train splits and deploy it at inference (a minimal training sketch follows this list)
  2. End-to-end evaluation with learned N on held-out AIME_2024/2025; report accuracy/savings
  3. Extend calibrations (e.g., b64) and ablations (single-feature, isotonic, booster)
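
A minimal sketch of what item 1 could look like with scikit-learn's gradient boosting (feature and label names are assumptions about the joined prefill/static-bK data, not the framework's actual training code):

from sklearn.ensemble import GradientBoostingRegressor

FEATURES = ["entropy_next", "kl_to_uniform_next", "tvd_decode_next",
            "top1_prob", "margin_top2"]  # illustrative subset of the prefill signals

def fit_p_hat_booster(train_rows: list[dict]) -> GradientBoostingRegressor:
    """Fit a booster that predicts per-branch correctness p_hat from prefill signals.

    Each row is assumed to hold the features above plus 'p_emp', the empirical
    per-branch accuracy measured from a static-bK sweep on the train split."""
    X = [[row[f] for f in FEATURES] for row in train_rows]
    y = [row["p_emp"] for row in train_rows]
    booster = GradientBoostingRegressor(n_estimators=200, max_depth=3)
    booster.fit(X, y)
    return booster

The predicted p̂ can then be mapped to N with the Hoeffding-style rule sketched in the Adaptive Branch Allocation section.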

🔮 Future Work

Immediate Improvements

  1. Advanced Prefill Analysis: More sophisticated difficulty signals
  2. Dynamic Branching: Real-time branch count adjustment
  3. Multi-Model Support: Different models for different problem types

Research Directions

  1. Advanced Consensus: Weighted voting based on confidence
  2. Memory Optimization: Better KV cache management
  3. Scalability: Support for larger models and datasets

Performance Optimization

  1. Speed Optimization: Faster prefill analysis
  2. Memory Efficiency: Better memory management
  3. Backend Integration: Enhanced vLLM support

📚 References

  • Self-Consistency: Wang et al., "Self-Consistency Improves Chain of Thought Reasoning in Language Models"
  • Adaptive Branching: Research into dynamic resource allocation
  • Prefill Analysis: Using early model signals for difficulty estimation
  • Parallel Generation: Efficient batch processing with num_return_sequences

🤝 Contributing

This is a research framework. Key areas for contribution:

  1. Advanced Prefill Analysis: More sophisticated difficulty signals
  2. Backend Support: Additional model backends
  3. Benchmark Integration: More evaluation datasets
  4. Performance Optimization: Speed and memory improvements

📄 License

This project is for research purposes. Please cite appropriately if used in research.


Last Updated: September 2025
Version: 1.1.0
Status: Active Development
