Skip to content

Commit 58625e1

Browse files
Added support to load model config from Metadata. (#399)
* Implement GGUF metadata extraction function: added a function to extract metadata from GGUF files.
* Updated the GGUF model loading and patching classes to include metadata handling.
* Cleaned up the return logic for extra metadata. This should be more future-proof in case we need to return other attributes later. Possible breaking change for anyone using `gguf_sd_loader` directly either way, though.
---------
Co-authored-by: City <125218114+city96@users.noreply.github.com>
1 parent 795e451 commit 58625e1

2 files changed

Lines changed: 32 additions & 8 deletions

File tree

loader.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,26 @@ def get_list_field(reader, field_name, field_type):
4848
else:
4949
raise TypeError(f"Unknown field type {field_type}")
5050

51-
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False):
51+
def get_gguf_metadata(reader):
    """Extract all simple scalar metadata fields, similar to safetensors metadata.

    Only single-value (scalar) fields are collected; array and nested
    fields are skipped, as are fields of types not listed below.

    Args:
        reader: An open GGUF reader (presumably ``gguf.GGUFReader`` —
            it must expose ``fields`` and ``get_field``; TODO confirm
            against caller).

    Returns:
        dict: Mapping of field name -> plain Python value
        (str/int/float/bool). Fields that fail to decode are skipped.
    """
    # Converters for the scalar GGUF value types we expose. Each field
    # stores its payload in `parts`, indexed by the last entry of `data`.
    converters = {
        gguf.GGUFValueType.STRING: lambda f: str(f.parts[f.data[-1]], "utf-8"),
        gguf.GGUFValueType.INT32: lambda f: int(f.parts[f.data[-1]]),
        gguf.GGUFValueType.F32: lambda f: float(f.parts[f.data[-1]]),
        gguf.GGUFValueType.BOOL: lambda f: bool(f.parts[f.data[-1]]),
    }
    metadata = {}
    for field_name in reader.fields:
        try:
            field = reader.get_field(field_name)
            # Simple scalar fields only (exactly one type entry).
            if len(field.types) == 1:
                convert = converters.get(field.types[0])
                if convert is not None:
                    metadata[field_name] = convert(field)
        # Was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; narrowed to Exception while
        # keeping the deliberate best-effort skip of bad fields.
        except Exception:
            continue
    return metadata
69+
70+
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False):
5271
"""
5372
Read state dict as fake tensors
5473
"""
@@ -136,9 +155,12 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
136155
max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
137156
state_dict[max_key].is_largest_weight = True
138157

139-
if return_arch:
140-
return (state_dict, arch_str)
141-
return state_dict
158+
# extra info to return
159+
extra = {
160+
"arch_str": arch_str,
161+
"metadata": get_gguf_metadata(reader)
162+
}
163+
return (state_dict, extra)
142164

143165
# for remapping llama.cpp -> original key names
144166
T5_SD_MAP = {
@@ -246,7 +268,7 @@ def gguf_mmproj_loader(path):
246268

247269
logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
248270
target = os.path.join(root, target[0])
249-
vsd = gguf_sd_loader(target, is_text_model=True)
271+
vsd, _ = gguf_sd_loader(target, is_text_model=True)
250272

251273
# concat 4D to 5D
252274
if "v.patch_embd.weight.1" in vsd:
@@ -375,7 +397,8 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
375397
return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
376398

377399
def gguf_clip_loader(path):
378-
sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
400+
sd, extra = gguf_sd_loader(path, is_text_model=True)
401+
arch = extra.get("arch_str", None)
379402
if arch in {"t5", "t5encoder"}:
380403
temb_key = "token_embd.weight"
381404
if temb_key in sd and sd[temb_key].shape == (256384, 4096):

nodes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de
165165

166166
# init model
167167
unet_path = folder_paths.get_full_path("unet", unet_name)
168-
sd = gguf_sd_loader(unet_path)
168+
sd, extra = gguf_sd_loader(unet_path)
169169
model = comfy.sd.load_diffusion_model_state_dict(
170-
sd, model_options={"custom_operations": ops}
170+
sd, model_options={"custom_operations": ops}, metadata=extra.get("metadata", {})
171171
)
172172
if model is None:
173173
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
@@ -319,3 +319,4 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable
319319
"QuadrupleCLIPLoaderGGUF": QuadrupleCLIPLoaderGGUF,
320320
"UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced,
321321
}
322+

0 commit comments

Comments
 (0)