89 changes: 78 additions & 11 deletions nodes.py
@@ -2,7 +2,7 @@
import torch
import logging
import collections

import random
import nodes
import comfy.sd
import comfy.lora
@@ -116,21 +116,90 @@ def clone(self, *args, **kwargs):
return n

class UnetLoaderGGUF:
@classmethod
def IS_CHANGED(cls, **kwargs):
return float("NaN")

@classmethod
def INPUT_TYPES(s):
unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")]
return {
"required": {
"unet_name": (unet_names,),
}
"unet_name": (unet_names, {"default": unet_names[0] if unet_names else ""}),
"mode": (["fixed", "increment", "decrement", "random"], {"default": "fixed"}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}

RETURN_TYPES = ("MODEL",)
RETURN_TYPES = ("MODEL", "STRING")
FUNCTION = "load_unet"
CATEGORY = "bootleg"
TITLE = "Unet Loader (GGUF)"

def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None):
def __init__(self):
self.base_model = None
self.current_model = None

def load_unet(self, unet_name, mode, dequant_dtype=None, patch_dtype=None, patch_on_device=None, unique_id=None):
# Get current list of available models
unet_names = folder_paths.get_filename_list("unet_gguf")

# Initialize state if needed
if self.base_model is None:
self.base_model = unet_name if unet_name in unet_names else (unet_names[0] if unet_names else None)
if self.current_model is None:
self.current_model = self.base_model

selected_model = None

# Handle modes
if mode == "fixed":
selected_model = unet_name
self.base_model = selected_model
self.current_model = selected_model

elif mode == "random":
if unet_names:
selected_model = random.choice(unet_names)
self.base_model = selected_model
self.current_model = selected_model

elif mode in ["increment", "decrement"]:
# Validate current state
valid_current = self.current_model in unet_names
valid_base = self.base_model in unet_names

# Determine starting point
start_point = None
if valid_current:
start_point = self.current_model
elif valid_base:
start_point = self.base_model
elif unet_names:
start_point = unet_names[0]

# Process movement if valid start point
if start_point and unet_names:
idx = unet_names.index(start_point)
if mode == "increment":
new_idx = (idx + 1) % len(unet_names)
else: # decrement
new_idx = (idx - 1) % len(unet_names)

selected_model = unet_names[new_idx]
self.current_model = selected_model

# Fallback if no selection made
if selected_model is None:
selected_model = unet_names[0] if unet_names else unet_name

# Resolve the path for the selected model (use selected_model, not the raw unet_name input)
unet_path = folder_paths.get_full_path("unet", selected_model)
sd = gguf_sd_loader(unet_path)

# Configure operations
ops = GGMLOps()

if dequant_dtype in ("default", None):
@@ -147,18 +216,16 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None):
else:
ops.Linear.patch_dtype = getattr(torch, patch_dtype)

# init model
unet_path = folder_paths.get_full_path("unet", unet_name)
sd = gguf_sd_loader(unet_path)
# Load model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}
)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
raise RuntimeError(f"Unsupported UNET: {unet_path}")

model = GGUFModelPatcher.clone(model)
model.patch_on_device = patch_on_device
return (model,)
return (model, selected_model)  # also return the name of the model that was actually loaded

class UnetLoaderGGUFAdvanced(UnetLoaderGGUF):
@classmethod
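
For reviewers, a minimal standalone sketch (not part of this PR) of the wraparound selection logic introduced above. The filenames and the `select` helper are hypothetical placeholders used only to illustrate how increment/decrement cycling and random selection behave:

```python
# Illustration of the mode-based selection added to UnetLoaderGGUF.
# Filenames are made up; in the node they come from folder_paths.get_filename_list("unet_gguf").
import random

unet_names = ["model-Q4_K_S.gguf", "model-Q5_K_M.gguf", "model-Q8_0.gguf"]

def select(current, mode):
    if mode == "fixed":
        return current
    if mode == "random":
        return random.choice(unet_names)
    # increment / decrement: step through the list, wrapping at both ends
    idx = unet_names.index(current) if current in unet_names else 0
    step = 1 if mode == "increment" else -1
    return unet_names[(idx + step) % len(unet_names)]

current = unet_names[0]
for _ in range(4):
    current = select(current, "increment")
    print(current)  # cycles Q5_K_M -> Q8_0 -> Q4_K_S -> Q5_K_M
```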