feat: Add two-step generation API (forward_pass + relevance_pass) and forward-only-generation task #13

Open
neerajaryaai wants to merge 5 commits into dlb_v2 from forward_only_generation

neerajaryaai (Collaborator) commented Feb 4, 2026

Summary

This PR introduces a new forward-only-generation task type to run_task() that enables fast token generation without computing relevance scores. This is useful when users only need generated tokens and want to minimize memory usage and latency.

It also adds a two-step API (forward_pass() + relevance_pass()) that decouples multi-token generation from relevance computation, giving users full control over when and which tokens to explain.

Changes

New Feature: forward-only-generation Task

  • Added new task type "forward-only-generation" to run_task() method
  • Implements autoregressive token generation loop with:
    • Greedy decoding (default)
    • Temperature scaling for controlling output distribution
    • Top-k sampling for limiting token candidates
    • Top-p (nucleus) sampling for dynamic vocabulary truncation
  • Supports early stopping via eos_token_id parameter
  • Clears node_io between generation steps to reduce memory footprint
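The decoding options listed above can be illustrated with a small, framework-free sketch. This is not the PR's implementation: `sample_next_token`, `generate`, and `toy_next_logits` are hypothetical names, logits are plain Python lists rather than tensors, and treating "no sampling option set" as greedy decoding is an assumption made for the example.

```python
import math
import random


def sample_next_token(logits, temperature=None, top_k=None, top_p=None, rng=None):
    """Pick a token index from raw logits (a list of floats)."""
    # Assumption: with no sampling option set, fall back to greedy decoding.
    if temperature is None and top_k is None and top_p is None:
        return max(range(len(logits)), key=logits.__getitem__)
    if temperature is not None and temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    t = 1.0 if temperature is None else temperature
    scaled = [x / t for x in logits]  # temperature scaling
    # Rank candidates from most to least likely.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        order = order[:max(1, top_k)]  # top-k: keep the k best candidates
    # Softmax over the surviving candidates (subtract the max for stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    if top_p is not None:
        # Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
        cumulative, keep = 0.0, 0
        for p in probs:
            cumulative += p
            keep += 1
            if cumulative >= top_p:
                break
        order, probs = order[:keep], probs[:keep]
    return rng.choices(order, weights=probs, k=1)[0]


def generate(next_logits, prompt_ids, max_new_tokens, eos_token_id=None, **sampling):
    """Autoregressive loop: call the model, pick a token, stop early on EOS."""
    sequence = list(prompt_ids)
    generated = []
    for _ in range(max_new_tokens):
        token = sample_next_token(next_logits(sequence), **sampling)
        generated.append(token)
        sequence.append(token)
        if eos_token_id is not None and token == eos_token_id:
            break  # early stopping
    return {"generated_token_ids": generated}


def toy_next_logits(sequence, vocab_size=5):
    """Toy stand-in for a model: strongly prefers (last token + 1) mod vocab_size."""
    logits = [0.0] * vocab_size
    logits[(sequence[-1] + 1) % vocab_size] = 5.0
    return logits
```

With the toy model, `generate(toy_next_logits, [0], 10, eos_token_id=3)` stops as soon as token 3 is produced instead of running all ten steps.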

Improved Visualization API

  • Added show and inline_format parameters to visualize_dlbacktrace() method
  • Made visualization output path consistent between small and large graphs

New Feature: Two-Step forward_pass() + relevance_pass() API

  • forward_pass() — Runs autoregressive generation for N tokens, storing per-step node_io snapshots efficiently:
    • Clones only tensor data (input_values, output_values, layer_hyperparams); shallow-copies immutable graph metadata
    • Moves snapshot tensors to CPU by default to free GPU VRAM
    • Clears GPU memory + runs gc.collect() between steps
    • Supports greedy, temperature, top-k, and top-p sampling
    • Returns generated_token_ids, complete_sequence, and num_steps
  • relevance_pass() — Computes relevance for selected tokens from a prior forward_pass():
    • Reads from self.node_io_trace (populated by forward_pass())
    • Accepts token_indices to explain specific steps (e.g., [0, 4, 9]) or all steps (None)
    • Returns per-step {'step_index', 'token_id', 'relevance'} dicts
  • clear_traces() — Frees all stored snapshots, relevance data, and GPU cache
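The snapshot and selection behaviour described above can be sketched with a toy container. Everything here is an illustrative assumption rather than the PR's code: `TraceStore`, `record_step`, and `select_steps` are hypothetical names, plain lists stand in for tensors, and `copy.deepcopy` stands in for cloning a tensor and moving it to CPU. No relevance is computed; the sketch only shows how per-step snapshots stay independent of later steps and how `token_indices` picks steps.

```python
import copy


class TraceStore:
    """Hypothetical stand-in for the per-step snapshot storage."""

    # Keys whose values change every step and must be copied.
    TENSOR_KEYS = ("input_values", "output_values", "layer_hyperparams")

    def __init__(self):
        self.node_io_trace = []  # one snapshot per generated token
        self.token_ids = []      # token chosen at each step

    def record_step(self, node_io, token_id):
        """Store a snapshot of node_io for the step that produced token_id."""
        snapshot = {}
        for name, node in node_io.items():
            entry = dict(node)  # shallow copy: metadata values are shared
            for key in self.TENSOR_KEYS:
                if key in node:
                    # Independent copy of the per-step data only.
                    entry[key] = copy.deepcopy(node[key])
            snapshot[name] = entry
        self.node_io_trace.append(snapshot)
        self.token_ids.append(token_id)

    def select_steps(self, token_indices=None):
        """Return the stored steps to explain; None selects every step."""
        if token_indices is None:
            token_indices = range(len(self.node_io_trace))
        selected = []
        for i in token_indices:
            if not 0 <= i < len(self.node_io_trace):
                raise IndexError(f"step {i} was not recorded")
            selected.append({
                "step_index": i,
                "token_id": self.token_ids[i],
                "snapshot": self.node_io_trace[i],
            })
        return selected

    def clear_traces(self):
        """Drop all stored snapshots."""
        self.node_io_trace.clear()
        self.token_ids.clear()
```

Copying only the per-step values keeps each snapshot valid after the live `node_io` is overwritten by the next step, while sharing the unchanged metadata avoids duplicating it once per token.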

Usage Examples

# Fast forward-only generation (no relevance, no tokenizer needed)
results = dlb.run_task(
    task="forward-only-generation",
    inputs={'input_ids': input_ids, 'attention_mask': attention_mask},
    max_new_tokens=50,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
generated_tokens = results['generated_token_ids']  # List[int]

# Two-step API: generate first, explain later
# Step 1 — Forward pass (generate 10 tokens, store node I/O)
result = dlb.forward_pass(
    inputs={'input_ids': input_ids, 'attention_mask': attention_mask},
    max_new_tokens=10,
    temperature=0.8,
    top_k=50,
    debug=True,
)
print(result['generated_token_ids'])  # [tok1, tok2, ..., tok10]

# Step 2 — Relevance pass (explain selected tokens)
relevance_results = dlb.relevance_pass(
    token_indices=[0, 4, 9],  # explain 1st, 5th, and 10th tokens
    multiplier=100.0,
    debug=True,
)
for r in relevance_results:
    print(f"Token {r['token_id']} (step {r['step_index']})")

# Cleanup
dlb.clear_traces()

neerajaryaai changed the title from "feat: Add 'forward-only-generation' task for fast token generation without relevance tracing" to "feat: Add 'forward-only-generation' task for fast token generation without relevance tracing and two-step forward_pass() + relevance_pass() API for decoupled multi-token generation and explanation" on Feb 9, 2026
neerajaryaai changed the title from "feat: Add 'forward-only-generation' task for fast token generation without relevance tracing and two-step forward_pass() + relevance_pass() API for decoupled multi-token generation and explanation" to "feat: Add two-step generation API (forward_pass + relevance_pass) and forward-only-generation task" on Feb 9, 2026