Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
import comfy.model_management
from .dequant import dequantize_tensor, is_quantized

CUBLAS_IS_AVAILABLE = False
try:
import cublas_ops
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass

def chained_hasattr(obj, chained_attr):
probe = obj
for attr in chained_attr.split('.'):
Expand Down Expand Up @@ -238,10 +245,30 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None
self.out_features = out_features
self.weight = None
self.bias = None
self._cublas_linear = None

def forward_ggml_cast_weights(self, input):
weight, bias = self.cast_bias_weight(input)
return torch.nn.functional.linear(input, weight, bias)

# Create float16 copies for cublas operation
input_half = input.half() if input.dtype != torch.float16 else input
weight_half = weight.half() if weight.dtype != torch.float16 else weight
bias_half = bias.half() if bias is not None and bias.dtype != torch.float16 else bias

if CUBLAS_IS_AVAILABLE and weight.is_cuda:
# Use cublas_half_matmul directly with dequantized weights
result_half = cublas_ops.cublas_half_matmul(
input_half,
weight_half,
bias_half,
epilogue_str="NONE",
has_bias=bias is not None
)

# Convert result back to original input dtype
return result_half.to(input.dtype) if result_half.dtype != input.dtype else result_half
else:
return torch.nn.functional.linear(input, weight, bias)

class Conv2d(GGMLLayer, comfy.ops.manual_cast.Conv2d):
def forward_ggml_cast_weights(self, input):
Expand Down