110 changes: 96 additions & 14 deletions loader.py
@@ -5,12 +5,13 @@
import gguf
import re
import os
+from pathlib import Path

from .ops import GGMLTensor
from .dequant import is_quantized, dequantize_tensor

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
VIS_TYPE_LIST = {"clip-vision", "mmproj"}

def get_orig_shape(reader, tensor_name):
@@ -177,6 +178,22 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
"output.weight": "lm_head.weight",
}

+GEMMA3_SD_MAP = {
+    "blk.": "model.layers.",
+    "attn_norm": "input_layernorm",
+    "attn_q": "self_attn.q_proj",
+    "attn_k": "self_attn.k_proj",
+    "attn_v": "self_attn.v_proj",
+    "attn_output": "self_attn.o_proj",
+    "ffn_up": "mlp.up_proj",
+    "ffn_down": "mlp.down_proj",
+    "ffn_gate": "mlp.gate_proj",
+    "ffn_norm": "post_attention_layernorm",
+    "token_embd": "model.embed_tokens",
+    "output_norm": "model.norm",
+    "output.weight": "lm_head.weight",
+}

CLIP_VISION_SD_MAP = {
"mm.": "visual.merger.mlp.",
"v.post_ln.": "visual.merger.ln_q.",
@@ -217,18 +234,30 @@ def strip_quant_suffix(name):
name = name[:match.start()]
return name
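For reference, the suffix stripping above behaves roughly like this (assumed output; the exact regex is in the collapsed part of the hunk):

# quant tags like "q4_k_m" / "Q8_0" at the end of a file name are dropped,
# so quantized and unquantized variants compare equal
print(strip_quant_suffix("gemma-3-12b-it-q4_k_m"))  # -> "gemma-3-12b-it" (assumed)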

-def gguf_mmproj_loader(path):
+def gguf_mmproj_loader(path, arch: str = ""):
    # Reverse version of Qwen2VLVisionModel.modify_tensors
-    logging.info("Attenpting to find mmproj file for text encoder...")
+    logging.info("Attempting to find mmproj file for text encoder...")

# get name to match w/o quant suffix
tenc_fname = os.path.basename(path)
tenc = os.path.splitext(tenc_fname)[0].lower()
tenc = strip_quant_suffix(tenc)

# try and find matching mmproj
target = []
root = os.path.dirname(path)

+    # Look for the expected gemma3 mmproj file
+    if arch == "gemma3":
+        mmproj_name = next(
+            (f.name for f in Path(root).glob("*.gguf")
+             if "gemma" in f.name.lower() and "12b" in f.name.lower() and "mmproj" in f.name.lower()),
+            None
+        )
+        if mmproj_name:
+            # append the basename only; it is joined with root below
+            target.append(mmproj_name)
+
+    # Otherwise look for one sharing the text encoder's base name
for fname in os.listdir(root):
name, ext = os.path.splitext(fname)
if ext.lower() != ".gguf":
@@ -239,14 +268,14 @@ def gguf_mmproj_loader(path):
target.append(fname)

if len(target) == 0:
-        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
+        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Vision features will be broken!")
return {}
if len(target) > 1:
logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
target = os.path.join(root, target[0])
-    vsd = gguf_sd_loader(target, is_text_model=True)
+    vsd, _ = gguf_sd_loader(target, is_text_model=True)

# concat 4D to 5D
if "v.patch_embd.weight.1" in vsd:
Expand Down Expand Up @@ -374,8 +403,44 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
del reader
return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))

+def load_spiece_from_safetensor(gguf_path):
+    """Try to load spiece_model from a .safetensors file in the same directory."""
+    try:
+        from safetensors import safe_open
+    except ImportError:
+        logging.warning("safetensors not available, cannot load external spiece_model")
+        return None
+
+    directory = os.path.dirname(gguf_path) or "."
+
+    # Scan for .safetensors files that look like tokenizer/spiece sidecars
+    for safetensor_file in Path(directory).glob("*.safetensors"):
+        name_lower = safetensor_file.name.lower()
+        if 'tokenizer' not in name_lower and 'spiece' not in name_lower:
+            continue
+
+        try:
+            with safe_open(str(safetensor_file), framework="pt", device="cpu") as f:
+                if "spiece_model" in f.keys():
+                    logging.info(f"Loading spiece_model from {safetensor_file.name}")
+                    return f.get_tensor("spiece_model")
+        except Exception as e:
+            logging.warning(f"Failed to load spiece_model from {safetensor_file.name}: {e}")
+
+    return None
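A hypothetical usage sketch (the file name is illustrative): given a quantized Gemma-3 text encoder GGUF, the helper scans its directory for a tokenizer sidecar:

spiece = load_spiece_from_safetensor("models/gemma-3-12b-it-Q4_K_M.gguf")
if spiece is None:
    logging.warning("no external spiece_model found")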

def gguf_clip_loader(path):
-    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
+    sd, arch, metadata = gguf_sd_loader(path, return_arch=True, is_text_model=True)
if arch in {"t5", "t5encoder"}:
temb_key = "token_embd.weight"
if temb_key in sd and sd[temb_key].shape == (256384, 4096):
@@ -385,7 +450,7 @@ def gguf_clip_loader(path):
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
sd = sd_map_replace(sd, T5_SD_MAP)
-    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl"}:
+    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}:
# TODO: pass model_options["vocab_size"] to loader somehow
temb_key = "token_embd.weight"
if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
@@ -395,12 +460,29 @@
# See note above for T5.
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
-        sd = sd_map_replace(sd, LLAMA_SD_MAP)
-        if arch == "llama":
-            sd = llama_permute(sd, 32, 8) # L3 / Mistral
-        if arch == "qwen2vl":
-            vsd = gguf_mmproj_loader(path)
+
+        # Apply the appropriate key mapping
+        if arch == "gemma3":
+            sd = sd_map_replace(sd, GEMMA3_SD_MAP)
+            sd = llama_permute(sd, 16, 8) # per HF config: num_attention_heads=16, num_key_value_heads=8
+        else:
+            sd = sd_map_replace(sd, LLAMA_SD_MAP)
+            if arch == "llama":
+                sd = llama_permute(sd, 32, 8) # L3 / Mistral
+
+        # Load mmproj for vision models
+        if arch in {"qwen2vl", "gemma3"}:
+            vsd = gguf_mmproj_loader(path, arch)
            sd.update(vsd)
+
+        # Fall back to an external sidecar if spiece_model is missing
+        if arch == "gemma3" and "spiece_model" not in sd:
+            spiece_tensor = load_spiece_from_safetensor(path)
+            if spiece_tensor is not None:
+                sd["spiece_model"] = spiece_tensor
+            else:
+                logging.warning("spiece_model not found in GGUF or safetensor files. Tokenizer may not work correctly.")
else:
pass
    return sd
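End to end, loading a Gemma-3 text encoder would then look roughly like this (hypothetical file name; arch detection and mmproj discovery happen inside the loader):

sd = gguf_clip_loader("models/gemma-3-12b-it-Q4_K_M.gguf")
print("spiece_model" in sd)  # True when embedded in the GGUF or found via sidecar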