22 commits
28565f6 Experimental Triton support for Q8_0 and Q4_K (blepping, Sep 7, 2025)
ab5f83e Add Q6_K Triton kernel (blepping, Sep 8, 2025)
4a96da2 Add Q5_K Triton kernel (blepping, Sep 8, 2025)
82be145 Actually enable the Q5_K kernel (blepping, Sep 8, 2025)
7f83f52 Triton dequant code cleanups (blepping, Sep 8, 2025)
a249da9 Use static_range in Triton kernels (blepping, Sep 9, 2025)
0735b87 Refactor Q4_K Triton kernel a bit to reduce code duplication (blepping, Sep 9, 2025)
8d88d53 Refactor/cleanup Triton support (blepping, Sep 11, 2025)
4c16086 Fix dequant_dtype handling (blepping, Sep 13, 2025)
03e0e54 Add an optimize parameter to the advanced loader (blepping, Sep 13, 2025)
c9922f6 Implement Q3_K Triton kernel (blepping, Sep 17, 2025)
58a758b Remove unnecessary to_uint32 helper function in dequant.py (blepping, Sep 18, 2025)
dedf338 Implement Q2_K, Q4_0, Q4_1, Q5_0 and Q5_1 Triton kernels (blepping, Sep 18, 2025)
9e6e5f7 Triton kernel cleanups and refactoring (blepping, Sep 22, 2025)
eab1b52 Recent PyTorch versions have native support for bfloat16 (blepping, Sep 23, 2025)
24408e5 Fix setting ggufconfig for CLIP loaders (blepping, Sep 23, 2025)
36af572 Do internal Triton dequant math in float32 by default (blepping, Sep 23, 2025)
ede8fff Fix broken bitcasting in Q3_K and Q5_0 (blepping, Sep 24, 2025)
8e69d3e Compatibility with Triton 3.3.1 (presumably 3.3.0 also) (blepping, Sep 25, 2025)
c903119 Fix compiling when Triton is enabled (blepping, Sep 27, 2025)
6c7a5bb sync (blepping, Nov 27, 2025)
2893b0b Update dequant_type handling (blepping, Jan 15, 2026)
dequant.py (64 changes: 49 additions & 15 deletions)
@@ -1,38 +1,72 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
from typing import Callable, Literal, NamedTuple, Optional, Union
Owner:
The repo generally doesn't have typing anywhere, so I'd say just remove it unless we plan to add it everywhere.

Contributor Author:
Well, typing is generally considered good, and if you'd be interested in adding it everywhere I can certainly do that. Also, having type annotations in some places means it would be less work if you wanted it everywhere later on. I can just remove all the type annotations if that's what you really want, though.

import gguf
import torch
from tqdm import tqdm

HAVE_BFLOAT16=hasattr(torch, "bfloat16")
Owner:
Wouldn't we need to check if the current device actually supports it? i.e. we'd want to test this on RTX 20XX, Volta and Pascal. I can test on Volta + Pascal sometime.

Contributor Author:
I'll make the bfloat16 changes a separate PR.
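For reference, a minimal sketch of the kind of device-capability gate being discussed, assuming a CUDA device; torch.cuda.is_bf16_supported() is the stock PyTorch check, not something this PR adds:

import torch

# Only treat bf16 as usable when both the torch build and the active CUDA
# device report support for it; otherwise keep using the existing fallback path.
HAVE_BFLOAT16 = (
    hasattr(torch, "bfloat16")
    and torch.cuda.is_available()
    and torch.cuda.is_bf16_supported()
)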

try:
    from . import dequant_triton
    triton_dequantize_functions=dequant_triton.dequantize_functions
    HAVE_TRITON=True
except Exception as exc:
    HAVE_TRITON=False
    print(f"\nGGUF: Failed to enable Triton: {exc}")
Owner:
logging.warning instead of print, that's what comfy uses. I should probably add my linter config to the repo.

Contributor Author:
These debug print statements will be removed before I mark this PR as ready. If you actually want some Triton status information logged to the console, let me know and I can add it.
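For illustration, a minimal sketch of the logging-based variant being requested; logging.warning is the standard-library call, and only the last line differs from the code above:

import logging

try:
    from . import dequant_triton
    triton_dequantize_functions = dequant_triton.dequantize_functions
    HAVE_TRITON = True
except Exception as exc:
    HAVE_TRITON = False
    triton_dequantize_functions = {}
    # Report through logging rather than print(), matching ComfyUI's convention.
    logging.warning("GGUF: Failed to enable Triton: %s", exc)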

    triton_dequantize_functions={}


TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16))
Owner:
frozenset? we can probably leave it as a normal set

Contributor Author:
I can change it if you want. Usually it's better to use immutable types for things that aren't supposed to change: it's a little easier to reason about (since you know the contents won't be modified out from under you), and it's usually a little more efficient as well.
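As a plain-Python illustration of the trade-off being discussed (no project code involved; the string values here are placeholders):

mutable_qtypes = {None, "F32", "F16"}
frozen_qtypes = frozenset(mutable_qtypes)

mutable_qtypes.add("BF16")       # a normal set can be modified anywhere later
try:
    frozen_qtypes.add("BF16")    # frozenset has no mutating methods
except AttributeError:
    pass

# Membership checks behave identically for both:
assert "F16" in frozen_qtypes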


DequantizeHandlersType = dict[gguf.GGMLQuantizationType, Callable]
DequantizeDtype = Optional[Union[torch.dtype, Literal["target"]]]

class GGUFConfig(NamedTuple):
    dequant_dtype: DequantizeDtype = None
    patch_dtype: DequantizeDtype = None
    patch_on_device: Optional[bool] = None
    optimize: str = "none"
    dequantize_function: Optional[Callable] = None
    dequantize_handlers: Optional[DequantizeHandlersType] = None

TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
DEFAULT_CONFIG = GGUFConfig()
Owner:
Can't we just set the default values to begin with? Then they're already the "default" when you create the instance with GGUFConfig().

Contributor Author:
Do you mean have it preloaded with the default PT functions and then maybe overwrite (some of) them later on?
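For reference, a small sketch of how the NamedTuple defaults behave, using the names from this diff; _replace is the standard NamedTuple API, and the override values are illustrative only:

# Every field has a default, so a bare construction already yields the defaults:
assert GGUFConfig() == DEFAULT_CONFIG

# Per-loader overrides can be derived without mutating the shared default:
custom = DEFAULT_CONFIG._replace(
    dequant_dtype="target",
    dequantize_handlers=triton_dequantize_functions,
)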


def is_torch_compatible(tensor):
    return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES

def is_quantized(tensor):
    return not is_torch_compatible(tensor)

def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None):
    config = config or DEFAULT_CONFIG
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)

def dequantize(data, qtype, oshape, dtype=None):
    if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16:
        return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype)
Owner:
Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16

(also, I don't super like the if/elif/else being flattened, the diff is harder to read with all the small changes)

Contributor Author:
I will remove these changes and make them a separate PR. I did test, though, and it seemed to work fine with a model that had a bunch of BF16 tensors, so I'm pretty sure it isn't a different layout.

> also, I don't super like the if/elif/else being flattened

You mean you don't like:

def blah():
  if condition:
    return 1
  elif other_condition:
    return 2
  else:
    return 3

compared to:

def blah():
  if condition:
    return 1
  if other_condition:
    return 2
  return 3

Linters will complain about the former version because the elif and else are redundant and it's usually considered a "code smell", but I can do it that way if you want. (This particular part isn't going to be relevant for the reviewable pull, but I can make sure I follow your style preference in other places.)

Contributor Author:
> Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16

>>> import numpy as np
>>> import torch
>>> def quantize_blocks(blocks: np.ndarray) -> np.ndarray:  # From gguf-py
...     n = blocks.view(np.uint32)
...     # force nan to quiet
...     n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
...     # round to nearest even
...     n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
...     return n.astype(np.uint16).view(np.uint8)
...
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x7f841bbf9650>
>>> x = torch.randn(1000, dtype=torch.bfloat16)
>>> xnp = x.to(dtype=torch.float32).numpy()
>>> xqnp = quantize_blocks(xnp)
>>> xq = torch.tensor(xqnp)
>>> xq.dtype
torch.uint8
>>> xdq_manual = (xq.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
>>> xdq_view = xq.view(dtype=torch.bfloat16).to(torch.float32)
>>> torch.equal(xdq_manual, xdq_view)
True

TL;DR: It's the same layout, and just viewing is safe as long as Torch has bf16 support.

I'm not sure the GPU even needs to have BF16 support here, since it's just a storage type and there's no math involved. Of course, if the user has the compute dtype set to bf16 and their GPU doesn't support it, then they're going to run into issues.

If you have access to a GPU without bf16 support, a simple test would be to temporarily manifest a bf16 tensor and see if it causes problems:

>>> import torch
>>> torch.arange(100, dtype=torch.uint8).view(torch.bfloat16).view(torch.uint8)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype=torch.uint8)

    if qtype in dequantize_functions:
        dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype
        dequantize_function = config.dequantize_function or dequantize
        return dequantize_function(
            tensor.data,
            qtype,
            oshape,
            dtype=dequant_dtype,
            dequantize_functions_override=config.dequantize_handlers,
        ).to(dtype)
    # this is incredibly slow
    tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
    new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
    return torch.from_numpy(new).to(tensor.device, dtype=dtype)
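For context, a hypothetical call-site sketch of the new config plumbing; ggml_tensor is a placeholder for a loaded GGUF tensor, and the field values are illustrative rather than taken from the loader code:

# Route dequantization through the Triton handlers when the import succeeded,
# and do the internal dequant math in the target dtype.
cfg = GGUFConfig(
    dequant_dtype="target",
    dequantize_handlers=triton_dequantize_functions if HAVE_TRITON else None,
)
weight = dequantize_tensor(ggml_tensor, dtype=torch.float16, config=cfg)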

def dequantize(data, qtype, oshape, dtype=None, dequantize_functions_override: Optional[DequantizeHandlersType]=None):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]
    dequantize_blocks = (dequantize_functions_override or dequantize_functions)[qtype]

    rows = data.reshape(
        (-1, data.shape[-1])
@@ -74,7 +108,7 @@ def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)
    qh = qh.contiguous().view(torch.int32)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
@@ -89,7 +123,7 @@ def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)
    qh = qh.contiguous().view(torch.int32)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
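For reference, a small check of why the bare view works here: four consecutive bytes reinterpreted as a 32-bit integer match the shift-and-or assembly that the removed to_uint32 helper performed (a sketch assuming a little-endian host):

import torch

qh_bytes = torch.tensor([[0x78, 0x56, 0x34, 0x12]], dtype=torch.uint8)

# Manual assembly: shift each byte into place and sum (little-endian order).
manual = (qh_bytes.to(torch.int32) << torch.tensor([0, 8, 16, 24])).sum(dim=-1)

# Bitcast: reinterpret the same four bytes directly as one 32-bit value.
viewed = qh_bytes.contiguous().view(torch.int32)

assert manual.item() == viewed.item() == 0x12345678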