142 changes: 89 additions & 53 deletions tools/convert.py
@@ -15,29 +15,27 @@
class ModelTemplate:
arch = "invalid" # string describing architecture
shape_fix = False # whether to reshape tensors
ndims_fix = False # whether to save fix file for tensors exceeding max dims
keys_detect = [] # list of lists to match in state dict
keys_banned = [] # list of keys that should mark model as invalid for conversion
keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
keys_ignore = [] # list of strings to ignore keys by when found

- def handle_nd_tensor(self, key, data):
-     raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")

class ModelFlux(ModelTemplate):
arch = "flux"
keys_detect = [
("transformer_blocks.0.attn.norm_added_k.weight",),
("single_transformer_blocks.0.attn.norm_k.weight",),
("double_blocks.0.img_attn.proj.weight",),
]
keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]
keys_banned = ["single_transformer_blocks.0.attn.norm_k.weight",]

class ModelSD3(ModelTemplate):
arch = "sd3"
keys_detect = [
("transformer_blocks.0.attn.add_q_proj.weight",),
("transformer_blocks.0.ff_context.net.0.proj.weight",),
("joint_blocks.0.x_block.attn.qkv.weight",),
]
keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]
keys_banned = ["transformer_blocks.0.ff_context.net.0.proj.weight",]

class ModelAura(ModelTemplate):
arch = "aura"
@@ -61,7 +59,7 @@ class ModelHiDream(ModelTemplate):
"img_emb.emb_pos"
]

- class CosmosPredict2(ModelTemplate):
+ class ModelCosmosPredict2(ModelTemplate):
arch = "cosmos"
keys_detect = [
(
@@ -72,26 +70,29 @@ class CosmosPredict2(ModelTemplate):
keys_hiprec = ["pos_embedder"]
keys_ignore = ["_extra_state", "accum_"]

class ModelQwenImage(ModelTemplate):
arch = "qwen_image"
keys_detect = [
(
"time_text_embed.timestep_embedder.linear_2.weight",
"transformer_blocks.0.attn.norm_added_q.weight",
"transformer_blocks.0.img_mlp.net.0.proj.weight",
)
]

class ModelHyVid(ModelTemplate):
arch = "hyvid"
ndims_fix = True
keys_detect = [
(
"double_blocks.0.img_attn_proj.weight",
"txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
)
]

- def handle_nd_tensor(self, key, data):
-     # hacky but don't have any better ideas
-     path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here??
-     if os.path.isfile(path):
-         raise RuntimeError(f"5D tensor fix file already exists! {path}")
-     fsd = {key: torch.from_numpy(data)}
-     tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
-     save_file(fsd, path)

- class ModelWan(ModelHyVid):
+ class ModelWan(ModelTemplate):
arch = "wan"
ndims_fix = True
keys_detect = [
(
"blocks.0.self_attn.norm_q.weight",
@@ -100,7 +101,11 @@ class ModelWan(ModelHyVid):
)
]
keys_hiprec = [
".modulation" # nn.parameter, can't load from BF16 ver
".modulation", # nn.parameter, can't load from BF16 ver
".encoder.padding_tokens", # nn.parameter, specific to S2V
"trainable_cond_mask", # used directly w/ .weight
"casual_audio_encoder.weights", # nn.parameter, specific to S2V
"casual_audio_encoder.encoder.conv", # CausalConv1d doesn't use ops.py for now
]

class ModelLTXV(ModelTemplate):
@@ -144,9 +149,17 @@ class ModelLumina2(ModelTemplate):
keys_detect = [
("cap_embedder.1.weight", "context_refiner.0.attention.qkv.weight")
]
keys_hiprec = [
# Z-Image specific
"x_pad_token",
"cap_pad_token",
]

- arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2,
-     ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1, ModelLumina2]
+ # The architectures are checked in order and the first successful match terminates the search.
+ arch_list = [
+     ModelFlux, ModelSD3, ModelAura, ModelHiDream, ModelCosmosPredict2, ModelQwenImage,
+     ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1, ModelLumina2
+ ]

def is_model_arch(model, state_dict):
# check if model is correct
@@ -157,7 +170,7 @@ def is_model_arch(model, state_dict):
matched = True
invalid = any(key in state_dict for key in model.keys_banned)
break
- assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
+ assert not invalid, f"Model architecture not allowed for conversion! (i.e. reference VS diffusers format) [arch:{model.arch}]"
return matched

def detect_arch(state_dict):
@@ -210,6 +223,24 @@ def strip_prefix(state_dict):

return sd

def find_main_dtype(state_dict, allow_fp32=False):
# detect most common dtype in input
dtypes = [x.dtype for x in state_dict.values()]
dtypes = {x:dtypes.count(x) for x in set(dtypes)}
main_dtype = max(dtypes, key=dtypes.get)

if main_dtype == torch.bfloat16:
ftype_name = "BF16"
ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
elif main_dtype == torch.float32 and allow_fp32:
ftype_name = "F32"
ftype_gguf = gguf.LlamaFileType.ALL_F32
else:
ftype_name = "F16"
ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

return ftype_name, ftype_gguf

def load_state_dict(path):
if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
state_dict = torch.load(path, map_location="cpu", weights_only=True)
@@ -224,7 +255,7 @@ def load_state_dict(path):

return strip_prefix(state_dict)

- def handle_tensors(writer, state_dict, model_arch):
+ def handle_tensors(writer, state_dict, model_arch, allow_fp32=False):
name_lengths = tuple(sorted(
((key, len(key)) for key in state_dict.keys()),
key=lambda item: item[1],
@@ -233,9 +264,13 @@ def handle_tensors(writer, state_dict, model_arch):
if not name_lengths:
return
max_name_len = name_lengths[0][1]

if max_name_len > MAX_TENSOR_NAME_LENGTH:
bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")

invalid_tensors = {}
quantized_tensors = {}
for key, data in tqdm(state_dict.items()):
old_dtype = data.dtype

@@ -255,14 +290,14 @@
data_shape = data.shape
if old_dtype == torch.bfloat16:
data_qtype = gguf.GGMLQuantizationType.BF16
- # elif old_dtype == torch.float32:
- #     data_qtype = gguf.GGMLQuantizationType.F32
+ elif old_dtype == torch.float32 and allow_fp32:
+     data_qtype = gguf.GGMLQuantizationType.F32
else:
data_qtype = gguf.GGMLQuantizationType.F16

# The max no. of dimensions that can be handled by the quantization code is 4
if len(data.shape) > MAX_TENSOR_DIMS:
- model_arch.handle_nd_tensor(key, data)
+ invalid_tensors[key] = data
+ continue # needs to be added back later

# get number of parameters (AKA elements) in this tensor
@@ -296,38 +331,27 @@

try:
data = gguf.quants.quantize(data, data_qtype)
quantized_tensors[key] = data_qtype
except (AttributeError, gguf.QuantError) as e:
tqdm.write(f"falling back to F16: {e}")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)

- new_name = key # do we need to rename?
+ quantized_tensors[key] = data_qtype

shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
- tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+ tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{key}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

- writer.add_tensor(new_name, data, raw_dtype=data_qtype)
+ writer.add_tensor(key, data, raw_dtype=data_qtype)

+ return quantized_tensors, invalid_tensors

- def convert_file(path, dst_path=None, interact=True, overwrite=False):
+ def convert_file(path, dst_path=None, interact=True, overwrite=False, allow_fp32=False):
# load & run model detection logic
state_dict = load_state_dict(path)
model_arch = detect_arch(state_dict)
logging.info(f"* Architecture detected from input: {model_arch.arch}")

# detect & set dtype for output file
- dtypes = [x.dtype for x in state_dict.values()]
- dtypes = {x:dtypes.count(x) for x in set(dtypes)}
- main_dtype = max(dtypes, key=dtypes.get)

- if main_dtype == torch.bfloat16:
-     ftype_name = "BF16"
-     ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
- # elif main_dtype == torch.float32:
- #     ftype_name = "F32"
- #     ftype_gguf = None
- else:
-     ftype_name = "F16"
-     ftype_gguf = gguf.LlamaFileType.MOSTLY_F16
+ ftype_name, ftype_gguf = find_main_dtype(state_dict, allow_fp32=allow_fp32)

if dst_path is None:
dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
@@ -346,20 +370,32 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False):
if ftype_gguf is not None:
writer.add_file_type(ftype_gguf)

- handle_tensors(writer, state_dict, model_arch)
+ quantized_tensors, invalid_tensors = handle_tensors(writer, state_dict, model_arch, allow_fp32=allow_fp32)
if len(invalid_tensors) > 0:
if not model_arch.ndims_fix: # only applies to 5D fix for now, possibly expand to cover more cases?
raise ValueError(f"Tensor(s) detected that exceeds dims supported by C++ code! ({invalid_tensors.keys()})")

fix_path = os.path.join(
os.path.dirname(dst_path),
f"fix_5d_tensors_{model_arch.arch}.safetensors"
)
if os.path.isfile(fix_path):
raise RuntimeError(f"Tensor fix file already exists! {path}")

invalid_tensors = {k:torch.from_numpy(v.copy()) for k,v in invalid_tensors.items()}
save_file(invalid_tensors, fix_path)
logging.warning(f"\n### Warning! Fix file found at '{fix_path}'")
logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")
else:
fix_path = None

writer.write_header_to_file(path=dst_path)
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
writer.close()

fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
if os.path.isfile(fix):
logging.warning(f"\n### Warning! Fix file found at '{fix}'")
logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

return dst_path, model_arch
return dst_path, model_arch, fix_path

if __name__ == "__main__":
args = parse_args()
convert_file(args.src, args.dst)
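Note: below is a minimal usage sketch of the updated convert_file() entry point, not part of the diff. The checkpoint name is hypothetical, and the import assumes the call is made from the tools/ directory; allow_fp32 and the returned fix_path are the additions introduced here.

from convert import convert_file

# hypothetical source checkpoint; allow_fp32=False keeps the default F16/BF16 output
dst_path, model_arch, fix_path = convert_file("wan2.2-s2v.safetensors", allow_fp32=False)
print(f"wrote {dst_path} ({model_arch.arch})")
if fix_path is not None:
    # 5D tensors were split out to a side file and must be merged back after quantization
    print(f"run fix_5d_tensors.py afterwards, pointing it at {fix_path}")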

23 changes: 13 additions & 10 deletions tools/fix_5d_tensors.py
@@ -30,23 +30,21 @@ def get_file_type(reader):
ft = int(field.parts[field.data[-1]])
return gguf.LlamaFileType(ft)

- if __name__ == "__main__":
-     args = get_args()

+ def apply_5d_fix(src, dst, fix=None, overwrite=False):
# read existing
- reader = gguf.GGUFReader(args.src)
+ reader = gguf.GGUFReader(src)
arch = get_arch_str(reader)
file_type = get_file_type(reader)
print(f"Detected arch: '{arch}' (ftype: {str(file_type)})")

# prep fix
- if args.fix is None:
-     args.fix = f"./fix_5d_tensors_{arch}.safetensors"
+ if fix is None:
+     fix = f"./fix_5d_tensors_{arch}.safetensors"

- if not os.path.isfile(args.fix):
-     raise OSError(f"No 5D tensor fix file: {args.fix}")
+ if not os.path.isfile(fix):
+     raise OSError(f"No 5D tensor fix file: {fix}")

- sd5d = load_file(args.fix)
+ sd5d = load_file(fix)
sd5d = {k:v.numpy() for k,v in sd5d.items()}
print("5D tensors:", sd5d.keys())

@@ -55,6 +53,7 @@ def get_file_type(reader):
writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
writer.add_file_type(file_type)

global added
added = []
def add_extra_key(writer, key, data):
global added
@@ -76,7 +75,11 @@ def add_extra_key(writer, key, data):
if key not in added:
add_extra_key(writer, key, data)

- writer.write_header_to_file(path=args.dst)
+ writer.write_header_to_file(path=dst)
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
writer.close()

if __name__ == "__main__":
args = get_args()
apply_5d_fix(args.src, args.dst, fix=args.fix, overwrite=args.overwrite)
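Note: a matching sketch for the refactored fix script, now callable as a function rather than only via the CLI. File names are hypothetical; the fix file is the one written next to the converted model during conversion.

from fix_5d_tensors import apply_5d_fix

# merge the 5D tensors saved during conversion back into the quantized model
apply_5d_fix(
    src="wan2.2-s2v-Q8_0.gguf",           # quantized file, 5D tensors missing
    dst="wan2.2-s2v-Q8_0-fixed.gguf",     # output with the tensors added back
    fix="fix_5d_tensors_wan.safetensors", # side file written during conversion
)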