Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a8be14c
Changes for SGLang support
avnermay Mar 18, 2026
1b2af07
Small test script
avnermay Mar 18, 2026
b9aceb5
Changes
avnermay Mar 18, 2026
fb9546a
Runner helpers
avnermay Mar 18, 2026
e8f7292
Updates to small test, assert in loader.py
avnermay Mar 18, 2026
af8c8ac
Changes
avnermay Mar 18, 2026
ff11967
Refactor of runner_helpers for all send/receive commands to use same …
avnermay Mar 19, 2026
6795127
Switch some torch.empty calls back to torch.zeros for correctness
avnermay Mar 19, 2026
04439b1
Add PrefillRequest and SpeculationRequest objects in runner_helpers.py
avnermay Mar 19, 2026
a3d6cf0
NIT bug fix
avnermay Mar 20, 2026
0b8a6e5
Further refactor of PrefillRequest, SpeculationRequest, SpeculationRe…
avnermay Mar 20, 2026
6a36a14
Improvements to logging
avnermay Mar 21, 2026
4c127df
dist_utils needed for cross-node support
avnermay Mar 23, 2026
82ca79c
Fix bugs in how recovery_activations and eagle_activations are set an…
avnermay Mar 23, 2026
66b8b7b
FA4 support
avnermay Mar 28, 2026
65301a3
Add tests and tree_mask.py so that FA4 works
avnermay Mar 28, 2026
fc1130d
Remove debug loading of Eagle activations
avnermay Mar 28, 2026
aa50214
Merge branch 'avner/sglang' into avner/sglang-fa4
avnermay Mar 28, 2026
d1c9215
Update pyproject.toml to reflect flash-attn 4 dependency, and no more…
Mar 28, 2026
2463748
Fix FA4 import
avnermay Mar 28, 2026
d86d0fb
Add logging statement once draft process is waiting for target proces…
avnermay Mar 28, 2026
1425f32
Trust remote code fix
avnermay Mar 28, 2026
cb51158
Add logging for draft model warmup
avnermay Mar 28, 2026
bfcb931
Switch all attention calls to use FA4
avnermay Mar 29, 2026
cce45eb
Add tests for attention fa4
avnermay Mar 29, 2026
080c4a3
Upgrade transformers, pin FA4
avnermay Mar 29, 2026
eb5e612
DUMP_TENSORS=false fix
avnermay Mar 30, 2026
ff59fdf
Switch from ssh to https git dependency in pyproject.toml
avnermay Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions bench/small_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Small smoke-test script for the ssd speculative-decoding engine.

Builds an LLM with (optionally) EAGLE-3 speculation enabled, generates a
single response to a fixed prompt, and prints it.  Intended as a quick
manual sanity check, not an automated test.
"""
import argparse
import os

from transformers import AutoTokenizer
from ssd import LLM, SamplingParams

if __name__ == '__main__':

    # Local HF snapshot paths used as defaults on the dev machine.
    llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6'
    llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b'
    eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd'

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=llama_1b_path)
    parser.add_argument("--draft", type=str, default=llama_1b_path)
    parser.add_argument("--eagle", action="store_true")
    parser.add_argument("--k", type=int, default=7)
    parser.add_argument("--jit-speculate", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=2)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--chat-template", action="store_true")
    parser.add_argument("--communicate-logits", action="store_true")
    parser.add_argument("--communicate-cache-hits", action="store_true")
    parser.add_argument("--mary", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # --eagle overrides several flags to a known-good 70B + EAGLE-3 setup.
    if args.eagle:
        args.draft = eagle_path
        args.model = llama_70b_path
        args.num_gpus = 5
        args.jit_speculate = True
        args.chat_template = True

    # Validate only the paths that will actually be loaded, so the script
    # still works on machines without the /scratch default snapshots when
    # --model/--draft are supplied explicitly.
    for path in {args.model, args.draft}:
        if not os.path.isdir(path):
            parser.error(f"model path is not a directory: {path}")

    llm = LLM(
        model=args.model,
        draft=args.draft,
        use_eagle=args.eagle,
        speculate_k=args.k,
        speculate=True,
        draft_async=True,
        num_gpus=args.num_gpus,
        jit_speculate=args.jit_speculate,
        verbose=args.verbose,
        communicate_logits=args.communicate_logits,
        communicate_cache_hits=args.communicate_cache_hits,
    )
    sampling_params = [SamplingParams(temperature=0.0, max_new_tokens=64, ignore_eos=args.ignore_eos)]

    # Pick the prompt: a longer repetition-heavy one (--mary) or a short one.
    if args.mary:
        text = "Can you please tell me the lyrics to Mary had a little lamb, and can you repeat it 10 times?"
    else:
        text = "What is the capital city of France?"

    if args.chat_template:
        # Tokenize via the model's chat template and generate from token ids.
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        tokens = tokenizer.apply_chat_template(
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
            add_generation_prompt=True,
        )
        token_str = tokenizer.decode(tokens)
        print(f"Generating response to prompt: '{token_str}'")
        print("=============================================================")
        outputs, _ = llm.generate([tokens], sampling_params)
    else:
        # Plain-text prompt; the engine tokenizes internally.
        outputs, _ = llm.generate([text], sampling_params)

    print(outputs[0]["text"])
27 changes: 11 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,20 @@ readme = "README.md"
description = "Async tree-based speculative decoding research engine"
requires-python = ">=3.11,<3.13"
dependencies = [
"torch==2.8.0",
"triton==3.4.0",
"transformers==4.57.1",
"xxhash==3.5.0",
"numpy==2.3.3",
"safetensors==0.6.2",
"tqdm==4.67.1",
"flashinfer-python==0.5.2",
"sgl-kernel==0.3.17.post1",
"nvidia-cutlass-dsl==4.2.1",
"torch==2.9.1",
"triton",
"transformers>=5.3.0",
"xxhash",
"numpy",
"safetensors",
"tqdm",
"sgl-kernel==0.3.21",
"nvidia-cutlass-dsl>=4.3.4",
"wandb==0.22.0",
"hf_transfer",
"tiktoken",
]

[project.optional-dependencies]
scripts = [
"datasets",
"huggingface_hub",
# Install from source for now, for latest support on Hopper
"flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@5301a359f59ef8fa10f211618d9f7a69716a8898#subdirectory=flash_attn/cute",
]

[project.urls]
Expand Down
4 changes: 3 additions & 1 deletion ssd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@
prepare_decode_tensors_from_seqs,
prepare_block_tables_from_seqs,
prepare_prefill_tensors_from_seqs,
prepare_prefill_payload,
PrefillRequest,
SpeculationRequest,
SpeculationResponse,
)
50 changes: 36 additions & 14 deletions ssd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import torch
from ssd.paths import DEFAULT_TARGET, DEFAULT_DRAFT


@dataclass
class Config:
model: str = DEFAULT_TARGET
max_num_batched_tokens: int = 16384
max_num_seqs: int = 1
max_model_len: int = 4096
max_num_seqs: int = 1
max_model_len: int = 4096
gpu_memory_utilization: float = 0.7
num_gpus: int = 1
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 256
kvcache_block_size: int = 1
num_kvcache_blocks: int = -1
device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Expand All @@ -25,13 +26,17 @@ class Config:
draft: str = DEFAULT_DRAFT
speculate_k: int = 1
draft_async: bool = False

# async spec only
async_fan_out: int = 3
fan_out_list: list[int] | None = None
fan_out_list_miss: list[int] | None = None
sampler_x: float | None = None
jit_speculate: bool = False
jit_speculate: bool = False
async_nccl_port: int | None = None
async_nccl_host: str = "127.0.0.1"
communicate_logits: bool = False
communicate_cache_hits: bool = False

# eagle3
use_eagle: bool = False
Expand All @@ -49,26 +54,35 @@ def max_blocks(self):
return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size

def __post_init__(self):
model = self.model
model = self.model
assert os.path.isdir(model)

assert 1 <= self.num_gpus <= 8 # this codebase only works on one node
self.hf_config = AutoConfig.from_pretrained(model)
self.max_model_len = min(
self.max_model_len, self.hf_config.max_position_embeddings)
if self.speculate:

if not self.speculate:
if self.max_model_len:
self.max_model_len = min(
self.max_model_len, self.hf_config.max_position_embeddings)
else:
self.max_model_len = self.hf_config.max_position_embeddings
else:
draft = self.draft
self.draft_hf_config = AutoConfig.from_pretrained(draft)
self.max_model_len = min(
self.max_model_len, self.draft_hf_config.max_position_embeddings)
if self.max_model_len:
self.max_model_len = min(
self.max_model_len, self.draft_hf_config.max_position_embeddings)
else:
self.max_model_len = self.draft_hf_config.max_position_embeddings

if self.draft_async:
if self.fan_out_list is None:
self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1)
self.MQ_LEN = sum(self.fan_out_list)
if self.fan_out_list_miss is None:
self.fan_out_list_miss = self.fan_out_list
assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list"

if self.use_eagle:
if self.eagle_layers is None:
L = self.hf_config.num_hidden_layers
Expand All @@ -90,5 +104,13 @@ def __post_init__(self):
if target_max_pos != draft_max_pos:
print(f'[Config] Overriding eagle draft max_position_embeddings: {draft_max_pos} -> {target_max_pos}', flush=True)
self.draft_hf_config.max_position_embeddings = target_max_pos

assert self.max_num_batched_tokens >= self.max_model_len

if self.sampler_x is not None and not self.communicate_cache_hits:
self.communicate_cache_hits = True
print(f'[Config] Setting communicate_cache_hits to True because sampler_x is not None', flush=True)

# assert self.max_num_batched_tokens >= self.max_model_len
if self.max_num_batched_tokens < self.max_model_len:
print(f'[Config] Warning: max_num_batched_tokens ({self.max_num_batched_tokens}) is less than max_model_len ({self.max_model_len})', flush=True)
print(f'[Config] Setting max_num_batched_tokens to max_model_len', flush=True)
self.max_num_batched_tokens = self.max_model_len
5 changes: 5 additions & 0 deletions ssd/engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def _deallocate_n_blocks(self, block_ids: list[int]): # we need to separate wher

def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0

if self.blocks[block_id].hash != -1: # if block was finalized, remove from hash_to_block_id checkme
if self.hash_to_block_id.get(self.blocks[block_id].hash) == block_id:
del self.hash_to_block_id[self.blocks[block_id].hash]

self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)

Expand Down
Loading