Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ requires-python = ">=3.11,<3.13"
dependencies = [
"torch==2.9.1",
"triton",
"transformers==4.57.1",
"transformers>=5.3.0",
"xxhash",
"numpy",
"safetensors",
"tqdm",
"flashinfer-python==0.6.6",
"sgl-kernel==0.3.21",
"nvidia-cutlass-dsl>=4.3.4",
"wandb==0.22.0",
"hf_transfer",
"tiktoken",
# Install from source for now, for latest support on Hopper
"flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@5301a359f59ef8fa10f211618d9f7a69716a8898#subdirectory=flash_attn/cute",
]

[project.urls]
Expand Down
280 changes: 43 additions & 237 deletions ssd/engine/helpers/cudagraph_helpers.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ssd/engine/helpers/runner_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _dump_ts():
print(f"[{_ts()}] BANANA: Dumping tensors to {DUMP_TENSORS_DIR}")
os.makedirs(DUMP_TENSORS_DIR, exist_ok=True)
DUMP_TENSORS = True
else:
DUMP_TENSORS = False

def list_to_str(lst: list[float] | list[list[float]], num_decimals: int = 4) -> str:
assert len(lst) > 0
Expand Down
138 changes: 35 additions & 103 deletions ssd/engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from multiprocessing.shared_memory import SharedMemory
from transformers import AutoTokenizer, AutoConfig
import os
import flashinfer
from ssd.config import Config
from ssd.engine.sequence import Sequence
from ssd.models.qwen3 import Qwen3ForCausalLM
Expand All @@ -35,7 +34,6 @@
capture_fi_tree_decode_cudagraph,
capture_glue_decode_cudagraph,
)
from ssd.engine.helpers.mask_helpers import get_custom_mask

NCCL_LOG = os.environ.get("SSD_NCCL_LOG", "0") == "1"

Expand All @@ -61,7 +59,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra
self.hf_config = config.hf_config if not is_draft else config.draft_hf_config
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path if config.tokenizer_path else config.model, use_fast=True)
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path if config.tokenizer_path else config.model, use_fast=True, trust_remote_code=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static Code Analysis Risk: Together python huggingface trust remote code

trust_remote_code=True downloads and executes arbitrary Python code from the model repository without sandboxing (OWASP LLM03:2025 Supply Chain). A malicious or compromised model repo can achieve RCE on every host that loads the model (CWE-94). Pin to a verified commit hash and audit remote code before use, or use models that don't require trust_remote_code.

Severity: High 🚨
Status: Open 🔴

References:

  1. https://cwe.mitre.org/data/definitions/94
  2. https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
  3. https://genai.owasp.org/llmrisk/llm032025-supply-chain/
  4. https://hiddenlayer.com/research/weaponizing-machine-learning-models-with-ransomware/

Suggested reviewers 🧐: @avnermay

More details:

🌻 View in Arnica

If you see an issue, please contact Shasheen in the #security-engineering Slack channel.


Take action by replying with an [arnica] command 💬

Actions

Use [arnica] or [a] to interact with the Arnica bot to acknowledge or dismiss code risks.

To acknowledge the finding as a valid code risk: [arnica] ack <acknowledge additional details>

To dismiss the risk with a reason: [arnica] dismiss <fp|accept|capacity> <dismissal reason>

Examples

  • [arnica] ack This is a valid risk and I'm looking into it

  • [arnica] dismiss fp Dismissed - Risk Not Accurate: (i.e. False Positive)

  • [arnica] dismiss accept Dismiss - Risk Accepted: Allow the risk to exist in the system

  • [arnica] dismiss capacity Dismiss - No Capacity: This will need to wait for a future sprint

self.max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size

assert self.hf_config is not None, "ERROR in ModelRunner: hf_config is None" # this implies boundedness to the end
Expand Down Expand Up @@ -98,11 +96,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra
self.device = torch.device(f'cuda:{self.rank}')
self._cmd = torch.empty(1, dtype=torch.int64, device=self.device)


# cudagraph logic for FlashInfer kernels, need diff wrapper for each batch size we make a graph for
if is_draft and config.draft_async:
self._init_flashinfer_wrappers()


if self.verbose: print(f'INSIDE MODEL RUNNER INIT, DRAFT={is_draft}', flush=True)
self.tp_pg = None

Expand Down Expand Up @@ -167,56 +161,6 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra

if self.verbose: print(f'-----{model_type}MODEL RUNNER INITIALIZED----', flush=True)

def _init_flashinfer_wrappers(self):
    """Initialize FlashInfer prefill wrappers for draft async mode.

    Two configurations:
      * eager mode  -> a single re-plannable ``only_prefill_wrapper``
        (planned per step in ``eager_tree_decode_plan``).
      * graph mode  -> one CUDA-graph-capable wrapper per batch size in
        ``graph_bs_list``, each aliasing slices of shared pre-allocated
        buffers so that graph capture/replay sees stable device pointers.

    Side effects: sets ``self.workspace_buffer`` and either
    ``self.only_prefill_wrapper`` or ``self.prefill_wrappers``.
    """
    # Shared scratch space for all FlashInfer kernels (768 MiB on this rank's GPU).
    self.workspace_buffer = torch.zeros(
        768 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{self.rank}")

    if self.config.enforce_eager:
        # Eager path: one wrapper, re-planned every decode step ("NHD" = [num_tokens, heads, dim] layout).
        self.only_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
    else:
        # Graph path: cap the captured batch size and size buffers for the worst case.
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size

        # FlashInfer kernel tensors
        # pages_for_max_len = (self.config.max_model_len + self.block_size - 1) // self.block_size
        # NOTE(review): last_page_len_max_len is computed here but never used below — dead code? confirm.
        last_page_len_max_len = self.config.max_model_len % self.block_size
        last_page_len_max_len = self.block_size if last_page_len_max_len == 0 else last_page_len_max_len
        # Queries per sequence in tree decode: fan_out tokens for each of the K+1 speculative positions.
        MQ_LEN = self.config.async_fan_out * (self.config.speculate_k + 1)

        # Backing buffers; each per-bs wrapper below aliases a prefix slice of these
        # so the device pointers stay fixed across CUDA graph capture and replay.
        cu_seqlens_q = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device)
        kv_indptr = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device)
        kv_indices = torch.empty(max_bs * max_num_blocks, dtype=torch.int32, device=self.device)
        kv_last_page_len = torch.empty(max_bs, dtype=torch.int32, device=self.device)
        custom_mask_buf = torch.empty(max_bs * MQ_LEN * self.config.max_model_len, dtype=torch.uint8, device=self.device)
        mask_indptr_buf = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device)

        # Create graph_bs_list to match what will be used in cudagraph_helpers.py
        # (must stay in sync with that module's batch-size schedule: 1, 2, 4, 8, then multiples of 16).
        graph_bs_list = [1]
        for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)):
            if bs <= max_bs:
                graph_bs_list.append(bs)
        if max_bs not in graph_bs_list:
            graph_bs_list.append(max_bs)
        graph_bs_list.sort()

        # Create a dict of wrappers, one for each bs we will touch in cudagraph_helpers.py
        self.prefill_wrappers = {}
        print(f'[model_runner about to wrapper.init()] graph_bs_list={graph_bs_list}', flush=True)
        for bs in graph_bs_list:
            self.prefill_wrappers[bs] = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD",
                use_cuda_graph=True,
                qo_indptr_buf=cu_seqlens_q[:bs + 1],
                paged_kv_indptr_buf=kv_indptr[:bs + 1],
                paged_kv_indices_buf=kv_indices[:bs * max_num_blocks],
                paged_kv_last_page_len_buf=kv_last_page_len[:bs],
                custom_mask_buf=custom_mask_buf[:bs * MQ_LEN * self.config.max_model_len],
                mask_indptr_buf=mask_indptr_buf[:bs + 1],
            )
            print(f'wrapper backend is {self.prefill_wrappers[bs]._backend}', flush=True)

def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoConfig, init_q=None, is_draft=False):
# cudagraphs
self.graph_vars = {}
Expand Down Expand Up @@ -268,6 +212,12 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC

if config.draft_async: # move this here so we don't get a timeout waiting for draft rank while load_model happens?
if config.async_nccl_port is not None:
print(
f'[model_runner] Waiting for target server at '
f'{config.async_nccl_host}:{config.async_nccl_port} '
f'to form NCCL process group...',
flush=True,
)
from torch.distributed import TCPStore
from ssd.utils.dist_utils import init_custom_process_group
store = TCPStore(config.async_nccl_host, port=config.async_nccl_port,
Expand All @@ -293,15 +243,14 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
assert sum(config.fan_out_list) == sum(config.fan_out_list_miss) == config.async_fan_out * (config.speculate_k + 1), "ERROR in ModelRunner: fancy sampling only supported for constant fan out for now."

self.sampler = Sampler(sampler_x=config.sampler_x, async_fan_out=config.async_fan_out)
if self.verbose:
print(f'-----WARMING UP {model_type}MODEL----', flush=True)
print(f'[model_runner] Warming up {model_type}model...', flush=True)
self.warmup_model()
if self.verbose:
print(f'-----ALLOCATING {model_type}KV CACHE----', flush=True)
print(f'[model_runner] Allocating {model_type}KV cache...', flush=True)
self.allocate_kv_cache()

if not self.enforce_eager:
# if not self.is_draft or (self.is_draft and self.config.draft_async and self.config.speculate):
print(f'[model_runner] Capturing CUDA graphs for {model_type}model...', flush=True)
# if not self.is_draft or (self.is_draft and self.config.draft_async and self.config.speculate):
decode_graph_vars, decode_graph_pool, decode_graphs, decode_graph_bs_list = capture_cudagraph(self) # decode cudagraph, draft needs in spec and target in normal
self.graph_vars["decode"] = decode_graph_vars
self.graph_pools["decode"] = decode_graph_pool
Expand All @@ -326,6 +275,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
self.graphs["glue_decode"] = glue_graphs
self.graph_bs_list["glue_decode"] = glue_bs_list

print(f'[model_runner] {model_type}model initialization complete.', flush=True)
if init_q is not None:
# Signal the scheduler that we're fully initialized (model loaded,
# KV cache allocated, CUDA graphs captured). Must happen after
Expand Down Expand Up @@ -543,15 +493,21 @@ def allocate_kv_cache(self):
)

print(f"allocate_kv_cache(): kv_cache shape = {self.kv_cache.shape}", flush=True)

# Create tree_score_mod once (shared across all attention layers)
tree_score_mod = None
if self.is_draft and self.draft_async:
from ssd.layers.tree_mask import create_tree_score_mod
tree_score_mod = create_tree_score_mod(config.max_model_len)

layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
if self.is_draft and self.draft_async and not self.enforce_eager:
module.prefill_wrappers = self.prefill_wrappers
elif self.is_draft and self.draft_async and self.enforce_eager:
module.only_prefill_wrapper = self.only_prefill_wrapper # this will make it not None so it can be used on fwd
if self.is_draft and self.draft_async:
module.max_seqlen_k = config.max_model_len
module.tree_score_mod = tree_score_mod
layer_id += 1


Expand Down Expand Up @@ -602,45 +558,21 @@ def prepare_sample(self, seqs: list[Sequence]):
return temperatures

def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits):
"""Plan FlashInfer for tree decode in eager mode"""
"""Set up context metadata for FA4 tree decode in eager mode."""
assert self.is_draft and self.config.draft_async, "ERROR in eager_tree_decode_plan: not a draft async model"
from ssd.layers.tree_mask import build_tree_mask_bias
context = get_context()

K, F = self.config.speculate_k, self.config.async_fan_out
# MQ_LEN = F * (K+1)
K = self.config.speculate_k
MQ_LEN = self.config.MQ_LEN
flat_batch_size = input_ids.size(0)
B = flat_batch_size // MQ_LEN # [N] tokens = B * sum(fan_out_list)

# Convert block_tables to FlashInfer format
block_tables = context.block_tables # [B, M]
context_lens = context.context_lens # [B]

counts = (context_lens + self.block_size - 1) // self.block_size # [B]
kv_indptr = torch.cat([torch.tensor([0], device=block_tables.device),
counts.cumsum(dim=0)]).to(torch.int32)
mask = torch.arange(block_tables.size(1), device=block_tables.device)[None, :] < counts[:, None]
kv_indices = block_tables[mask] # flattened page ids

# Last-page actual token count per request
kv_last_page_len = (context_lens % self.block_size)
kv_last_page_len[kv_last_page_len == 0] = self.block_size
kv_last_page_len = kv_last_page_len.to(torch.int32)
cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN # assumes same MQ_LEN across batch dimension
custom_mask = get_custom_mask(self.config, context_lens, step, K, F, B, device=self.device, cache_hits=cache_hits)

self.only_prefill_wrapper.plan(
cu_seqlens_q,
kv_indptr,
kv_indices,
kv_last_page_len,
self.hf_config.num_attention_heads,
self.hf_config.num_key_value_heads,
self.hf_config.head_dim,
self.block_size,
custom_mask=custom_mask,
q_data_type=self.hf_config.torch_dtype,
kv_data_type=self.hf_config.torch_dtype,
B = input_ids.size(0) // MQ_LEN
context.tree_cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN
context.tree_mask_bias = build_tree_mask_bias(
context.context_lens, step=step, K=K, MQ_LEN=MQ_LEN,
fan_out_list=self.config.fan_out_list,
fan_out_list_miss=self.config.fan_out_list_miss,
cache_hits=cache_hits,
max_kv_stride=self.config.max_model_len,
device=self.device,
)

@torch.inference_mode()
Expand Down
61 changes: 39 additions & 22 deletions ssd/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import triton
import triton.language as tl

from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from flash_attn.cute.interface import flash_attn_varlen_func as fa4_varlen_func
from ssd.layers.tree_mask import create_tree_score_mod
from ssd.utils.context import get_context


Expand Down Expand Up @@ -65,10 +66,10 @@ def __init__(
self.speculate = speculate
self.draft_async = draft_async
self.use_eagle = use_eagle
self.prefill_wrappers = {}
self.F = F # async_fan_out
self.K = K # speculate_k
self.only_prefill_wrapper = None
self.max_seqlen_k = 0 # set during KV cache allocation to config.max_model_len
self.tree_score_mod = None # set during KV cache allocation

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
o: torch.Tensor
Expand All @@ -87,7 +88,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
k, v = k_cache, v_cache

k, v = k.view(-1, self.num_kv_heads, self.head_dim), v.view(-1, self.num_kv_heads, self.head_dim)
o = flash_attn_varlen_func(q, k, v,
o, _ = fa4_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True)
Expand All @@ -104,29 +105,45 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

if verify_or_glue:
assert context.context_lens is not None
o = flash_attn_with_kvcache(q, k_cache, v_cache,
cache_seqlens=context.context_lens, page_table=context.block_tables,
o, _ = fa4_varlen_func(q, k_cache, v_cache,
cu_seqlens_q=context.cu_seqlens_q,
cu_seqlens_k=None,
max_seqlen_q=context.max_seqlen_q,
max_seqlen_k=self.max_seqlen_k,
seqused_k=context.context_lens,
page_table=context.block_tables,
softmax_scale=self.scale, causal=True,
cu_seqlens_q=context.cu_seqlens_q, max_seqlen_q=context.max_seqlen_q,
)

elif tree_decode:
if self.only_prefill_wrapper is not None:
prefill_wrapper = self.only_prefill_wrapper
else:
mq_len = self.F * (self.K+1)
bs = q.shape[0] // mq_len
wrapper_bs = None
for available_bs in sorted(self.prefill_wrappers.keys()):
if available_bs >= bs:
wrapper_bs = available_bs
break
prefill_wrapper = self.prefill_wrappers[wrapper_bs]
o = prefill_wrapper.run(q, (self.k_cache, self.v_cache))
score_mod_kwargs = {}
if self.tree_score_mod is not None and context.tree_mask_bias is not None:
score_mod_kwargs["score_mod"] = self.tree_score_mod
score_mod_kwargs["aux_tensors"] = [context.tree_mask_bias]
o, _ = fa4_varlen_func(
q,
self.k_cache,
self.v_cache,
cu_seqlens_q=context.tree_cu_seqlens_q,
cu_seqlens_k=None,
max_seqlen_q=self.F * (self.K + 1),
max_seqlen_k=self.max_seqlen_k,
seqused_k=context.context_lens,
page_table=context.block_tables,
softmax_scale=self.scale,
causal=False,
**score_mod_kwargs,
)
else: # single query decode
q = q.unsqueeze(1)
o = flash_attn_with_kvcache(q, k_cache, v_cache,
cache_seqlens=context.context_lens, page_table=context.block_tables,
batch_size = context.context_lens.shape[0]
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=q.device)
o, _ = fa4_varlen_func(q, k_cache, v_cache,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=None,
max_seqlen_q=1,
max_seqlen_k=self.max_seqlen_k,
seqused_k=context.context_lens,
page_table=context.block_tables,
softmax_scale=self.scale, causal=True,
)

Expand Down
Loading