Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions config/llm_config_qwen3_8b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"arch": "qwen",
"log_name": "qwen3-8b",
"name": "Qwen/Qwen3-8B-Instruct",
"load_in_8bit": true,
"device_map": "auto",
"max_tokens": 3500,
"max_ctx_length": 8192
}
9 changes: 7 additions & 2 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import datetime
from tqdm import tqdm
from benchmark.longbench import LongBench
from promptcache.model import Llama2, Falcon, Mpt
from promptcache.model import Llama2, Falcon, Mpt, Qwen
from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
GenerationEngine, GenerationParameters

Expand All @@ -30,7 +30,7 @@ def __init__(self, llm_config_path, dataset, enable_cache, use_cpu_for_inference
self.enable_cache = enable_cache
self.use_cpu_for_inference = use_cpu_for_inference

self.model_name = self.llm_config["name"]
self.model_name = self.llm_config["name"].lower()
if "llama" in self.model_name:
self.model_name = "llama"
self.lm_for_caching = Llama2(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
Expand All @@ -40,6 +40,9 @@ def __init__(self, llm_config_path, dataset, enable_cache, use_cpu_for_inference
elif "mpt" in self.model_name:
self.model_name = "mpt"
self.lm_for_caching = Mpt(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
elif "qwen" in self.model_name:
self.model_name = "qwen"
self.lm_for_caching = Qwen(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
else:
raise ValueError("Invalid model name")

Expand All @@ -50,6 +53,8 @@ def __init__(self, llm_config_path, dataset, enable_cache, use_cpu_for_inference
self.lm = Falcon(name=self.llm_config['name'], device_map=None)
elif "mpt" in self.model_name:
self.lm = Mpt(name=self.llm_config['name'], device_map=None)
elif "qwen" in self.model_name:
self.lm = Qwen(name=self.llm_config['name'], device_map=None)
else:
self.lm = self.lm_for_caching

Expand Down
4 changes: 3 additions & 1 deletion eval_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import datetime
from tqdm import tqdm
from benchmark.longbench import LongBench
from promptcache.model import Llama2, Falcon, Mpt
from promptcache.model import Llama2, Falcon, Mpt, Qwen
from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
GenerationEngine, GenerationParameters

Expand Down Expand Up @@ -81,6 +81,8 @@ def __init__(self, gpu_id, llm_config_path, dataset_list, enable_cache):
self.lm = Falcon(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
elif self.model_arch == "mpt":
self.lm = Mpt(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
elif self.model_arch == "qwen":
self.lm = Qwen(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
else:
raise ValueError("Invalid model name")

Expand Down
6 changes: 5 additions & 1 deletion eval_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datetime
from tqdm import tqdm
from benchmark.longbench import LongBench
from promptcache.model import Llama2, Falcon, Mpt
from promptcache.model import Llama2, Falcon, Mpt, Qwen
from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
GenerationEngine, GenerationParameters

Expand Down Expand Up @@ -39,6 +39,8 @@ def __init__(self, memo, llm_config_path, use_cpu_for_inference=False):
self.lm_for_caching = Falcon(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
elif self.model_arch == "mpt":
self.lm_for_caching = Mpt(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
elif self.model_arch == "qwen":
self.lm_for_caching = Qwen(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
else:
raise ValueError("Invalid model name")

Expand All @@ -49,6 +51,8 @@ def __init__(self, memo, llm_config_path, use_cpu_for_inference=False):
self.lm = Falcon(name=self.model_name, device_map=None)
elif self.model_arch == "mpt":
self.lm = Mpt(name=self.model_name, device_map=None)
elif self.model_arch == "qwen":
self.lm = Qwen(name=self.model_name, device_map=None)
else:
self.lm = self.lm_for_caching

Expand Down
68 changes: 42 additions & 26 deletions promptcache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools

import torch

import termcolor
from .model import LanguageModel
from .prompt import Prompt, ModuleRef
from .schema import Parameter, TokenSequence, UnionModule, Schema, Path, Module
Expand Down Expand Up @@ -92,18 +92,22 @@ class PromptCache:
max_ctx_length: int
num_head: int
head_dim: int
target_device: torch.device
device_cache: KVCache

# hidden_dim is usually num_head * head_dim
def __init__(self, max_ctx_length: int, num_layers: int, num_head: int, head_dim: int, target_device: torch.device):
def __init__(self, max_ctx_length: int, num_layers: int, num_head: int, head_dim: int, target_device: torch.device,
dtype: torch.dtype = torch.float16):

self.max_ctx_length = max_ctx_length
self.num_head = num_head
self.head_dim = head_dim
self.target_device = target_device
self.dtype = dtype

self.device_cache = [
(torch.empty(num_head, max_ctx_length, head_dim, device=target_device, dtype=torch.half), # key
torch.empty(num_head, max_ctx_length, head_dim, device=target_device, dtype=torch.half)) for _ in
(torch.empty(num_head, max_ctx_length, head_dim, device=target_device, dtype=self.dtype), # key
torch.empty(num_head, max_ctx_length, head_dim, device=target_device, dtype=self.dtype)) for _ in
range(num_layers)]

# print(num_head, max_ctx_length, head_dim)
Expand All @@ -117,8 +121,7 @@ def update(self, modules: List[TokenSequenceCache]):

# TODO: adopt in-place sorting to reduce redundant host-device memory copies

# cache rearrangement -> becomes new layout
modules_ordered = sorted(modules, key=lambda e: e.usage_counter, reverse=True)
modules_ordered = sorted(modules, key=lambda e: e.token_sequence.offset)

retained = []

Expand All @@ -131,28 +134,33 @@ def update(self, modules: List[TokenSequenceCache]):
offset = sum(map(len, retained))
updates = modules_ordered[len(retained):]

# update the cache
for m in updates:
if len(updates) > 0:
for m in updates:
m.upload(self.target_device)

update_len = sum(map(len, updates))
st = offset
ed = st + len(m)
ed = st + update_len
update_caches = [m.cache for m in updates]

for i in range(len(self.device_cache)):
k_cache_tgt, v_cache_tgt = self.device_cache[i]
k_cache_src, v_cache_src = m.cache[i]

# print('k_src', k_cache_src.shape)
# print('v_src', v_cache_src.shape)
# print('k_tgt', k_cache_tgt.shape)
# print('v_tgt', v_cache_tgt.shape)

k_cache_tgt[:, st:ed, :].copy_(k_cache_src, non_blocking=True)
v_cache_tgt[:, st:ed, :].copy_(v_cache_src, non_blocking=True)

offset += len(m)
k_chunks = [cache_i[i][0] for cache_i in update_caches]
v_chunks = [cache_i[i][1] for cache_i in update_caches]
if len(k_chunks) == 1:
k_merged = k_chunks[0]
v_merged = v_chunks[0]
else:
k_merged = torch.cat(k_chunks, dim=1)
v_merged = torch.cat(v_chunks, dim=1)
k_cache_tgt[:, st:ed, :].copy_(k_merged, non_blocking=True)
v_cache_tgt[:, st:ed, :].copy_(v_merged, non_blocking=True)

offset = ed

# re-organize the cache

self.staged = modules
self.staged = modules_ordered
self.length = offset

def __len__(self):
Expand Down Expand Up @@ -243,7 +251,7 @@ def _process(self, batch_size: int = 1):
d_output = self.lm(
input_ids=torch.tensor(batch_token_ids_padded, device=self.lm.device, dtype=torch.long),
position_ids=torch.tensor(batch_position_ids_padded, device=self.lm.device, dtype=torch.long),
attention_mask=torch.tensor(attn_mask, device=self.lm.device, dtype=torch.float16),
attention_mask=torch.tensor(attn_mask, device=self.lm.device, dtype=torch.bool),
use_cache=True
)

Expand Down Expand Up @@ -341,13 +349,17 @@ def __init__(self, max_ctx_length: int, lm: LanguageModel, target_device=None):
self.target_device = lm.device if target_device is None else target_device

num_layers, num_head, head_dim = lm.get_cache_shape()
cache_dtype = getattr(lm.hf_model, "dtype", torch.float16)
if not isinstance(cache_dtype, torch.dtype) or not cache_dtype.is_floating_point:
cache_dtype = torch.float16

self.prompt_cache = PromptCache(
max_ctx_length=max_ctx_length,
num_layers=num_layers,
num_head=num_head,
head_dim=head_dim,
target_device=self.target_device
target_device=self.target_device,
dtype=cache_dtype
)

def add_schema(self, schema: Union[str, Schema],
Expand Down Expand Up @@ -472,6 +484,11 @@ def process(self, prompt: Prompt, no_cache: bool = False, return_full_position_i

input_ids = list(itertools.chain(*argument_ids_list))
position_ids = list(itertools.chain(*argument_pos_ids_list))
if len(position_ids) > 0:
sorted_pairs = sorted(zip(position_ids, input_ids))
position_ids, input_ids = zip(*sorted_pairs)
position_ids = list(position_ids)
input_ids = list(input_ids)

if no_cache:
orig_input_ids = list(itertools.chain(*orig_ids_list))
Expand All @@ -486,8 +503,7 @@ def process(self, prompt: Prompt, no_cache: bool = False, return_full_position_i
torch.cuda.synchronize()
cache_time = start.elapsed_time(end)

# print(f'Cache overhead: {cache_time:.2f} ms')

print(termcolor.colored(f'Cache overhead: {cache_time:.2f} ms', 'yellow'))
vv = list(range(len(orig_position_ids)))

return orig_input_ids, vv, cache_time, None
Expand All @@ -512,7 +528,7 @@ def process(self, prompt: Prompt, no_cache: bool = False, return_full_position_i
for i in range(len(cache)):
cache[i] = (self.lm.read_k_hook(cache[i][0]), self.lm.read_v_hook(cache[i][1]))

# print(f'Cache overhead: {cache_time:.2f} ms')
print(termcolor.colored(f'Cache overhead: {cache_time:.2f} ms', 'yellow'))

if return_full_position_ids:
orig_position_ids = list(itertools.chain(*orig_pos_ids_list))
Expand Down
13 changes: 7 additions & 6 deletions promptcache/generation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def generate(self,

device = self.lm.device

position_offset = max(position_ids) + 1
cache_seq_len = cache[0][0].shape[1] if cache is not None and len(cache) > 0 else 0
position_offset = max(max(position_ids) + 1, cache_seq_len + len(token_ids))
past_key_values = None
new_token_id = 0

Expand All @@ -99,7 +100,7 @@ def generate(self,

# add redundant batch dim
if cache is not None:
cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache]
cache = tuple((k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
Expand All @@ -115,8 +116,8 @@ def generate(self,
response_time = inference_time
# print(f'Response time: {inference_time:.2f} ms')
# pretty print using termcolor
print(termcolor.colored(f'Prefill latency: {inference_time:.2f} ms', 'yellow'))

print(termcolor.colored(f'Prefill latency: {response_time:.2f} ms', 'yellow'))
logits = out.logits
past_key_values = out.past_key_values

Expand Down Expand Up @@ -187,13 +188,13 @@ def generate(self,
partially_stopped = False

for each_stop in params.stop_str:
pos = new_output.rfind(each_stop, 0)
pos = new_output.find(each_stop)
if pos != -1:
new_output = new_output[:pos]
stopped = True
break
else:
partially_stopped = is_partial_stop(output, each_stop)
partially_stopped = is_partial_stop(new_output, each_stop)
if partially_stopped:
break

Expand Down
71 changes: 70 additions & 1 deletion promptcache/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, \
PretrainedConfig, PreTrainedModel, CodeLlamaTokenizer
try:
from transformers.cache_utils import Cache, DynamicCache
except ImportError:
Cache = None
DynamicCache = None

from promptcache.model.falcon import FalconForCausalLM
from promptcache.model.llama2 import LlamaForCausalLM
Expand Down Expand Up @@ -272,7 +277,7 @@ def __init__(self, name="mosaicml/mpt-7b-chat", **kwargs):
assistant=("", "<|im_end|>\n"))

self.formatter = conv
self.use_full_position_ids = True
self.use_full_position_ids = False

stop_token_ids = [50278, 0]
stop_str = []
Expand All @@ -293,3 +298,67 @@ def get_cache_shape(self) -> Tuple[int, int, int]:
#
# def read_k_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
# return v_cache.transpose(1, 2)


class Qwen(LanguageModel):
    """Wrapper for Qwen causal language models loaded through ``transformers``.

    Prompts are formatted with ``<|im_start|>`` / ``<|im_end|>`` role markers,
    and generation stops on the tokenizer's EOS token or on ``<|im_end|>``.
    ``__call__`` adapts tuple-style key/value caches to ``DynamicCache`` when
    that class is importable, before delegating to the base class.
    """

    def __init__(self, name="Qwen/Qwen2.5-7B-Instruct", **kwargs):
        """Load the tokenizer and model identified by ``name``.

        Args:
            name: Model identifier or path passed to ``from_pretrained``.
            **kwargs: Forwarded unchanged to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, **kwargs)

        self.formatter = FormatConversation(
            system=("<|im_start|>system\n", "<|im_end|>\n", ""),
            user=("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
            assistant=("", "<|im_end|>\n")
        )
        self.use_full_position_ids = True

        # A tokenizer without an EOS token reports None; keep None out of the
        # stop list so it only ever contains real token ids.
        stop_token_ids = []
        if tokenizer.eos_token_id is not None:
            stop_token_ids.append(tokenizer.eos_token_id)

        if hasattr(tokenizer, "convert_tokens_to_ids"):
            im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
            if isinstance(im_end_id, int) and im_end_id >= 0 and im_end_id not in stop_token_ids:
                stop_token_ids.append(im_end_id)

        stop_str = ["<|im_end|>"]

        super().__init__(name, model, tokenizer, stop_token_ids, stop_str)

    def get_formatter(self) -> Callable[[str], str]:
        """Return the conversation formatter built in ``__init__``."""
        return self.formatter

    def get_cache_shape(self) -> Tuple[int, int, int]:
        """Return ``(num_hidden_layers, num_kv_heads, head_dim)`` from the model config.

        The number of key/value heads falls back to ``num_attention_heads``
        when the config has no (or a ``None``) ``num_key_value_heads``.
        """
        config = self.hf_model.config
        num_attention_heads = config.num_attention_heads

        num_head = getattr(config, "num_key_value_heads", None)
        if num_head is None:
            num_head = num_attention_heads

        # Prefer an explicit per-head dimension: some configs (e.g. Qwen3)
        # define ``head_dim`` independently of hidden_size / num_attention_heads,
        # so the quotient is only a fallback.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // num_attention_heads

        return config.num_hidden_layers, num_head, head_dim

    @staticmethod
    def _is_legacy_kv_cache(past_key_values) -> bool:
        """Return True if ``past_key_values`` is a list/tuple of 2-item pairs.

        An empty list/tuple is also treated as a legacy cache.
        """
        if not isinstance(past_key_values, (list, tuple)):
            return False
        if len(past_key_values) == 0:
            return True
        first = past_key_values[0]
        return isinstance(first, (list, tuple)) and len(first) == 2

    def __call__(self, **kwargs):
        """Run the model after normalising cache and position arguments.

        * A list/tuple ``past_key_values`` is converted to a ``DynamicCache``
          when that class could be imported.
        * ``position_ids`` longer or shorter than ``input_ids`` are cut to the
          trailing ``input_ids.shape[1]`` entries.
        * When a cache is supplied, ``cache_position`` is set to the first row
          of ``position_ids``.

        All keyword arguments are then forwarded to the base class.
        """
        past_key_values = kwargs.get("past_key_values", None)
        if past_key_values is not None and DynamicCache is not None:
            is_cache_obj = Cache is not None and isinstance(past_key_values, Cache)
            if not is_cache_obj and self._is_legacy_kv_cache(past_key_values):
                if hasattr(DynamicCache, "from_legacy_cache"):
                    new_cache = DynamicCache.from_legacy_cache(tuple(past_key_values))
                else:
                    # No bulk converter on this DynamicCache: fill layer by layer.
                    new_cache = DynamicCache()
                    for layer_idx, (key, value) in enumerate(past_key_values):
                        new_cache.update(key, value, layer_idx=layer_idx)
                kwargs["past_key_values"] = new_cache

        position_ids = kwargs.get("position_ids")
        input_ids = kwargs.get("input_ids")
        # Callers may pass position_ids=None explicitly; only touch real tensors.
        if position_ids is not None and input_ids is not None:
            input_len = input_ids.shape[1]
            if position_ids.shape[1] != input_len:
                position_ids = position_ids[:, -input_len:]
                kwargs["position_ids"] = position_ids
            if kwargs.get("past_key_values") is not None:
                # NOTE(review): this reuses position ids as cache positions, which
                # assumes they coincide with cache slot indices - confirm for
                # prompts whose cached modules have non-contiguous positions.
                kwargs["cache_position"] = position_ids[0]

        return super().__call__(**kwargs)