Triton dequantization/config framework #336
base: main
@@ -1,38 +1,72 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
from typing import Callable, Literal, NamedTuple, Optional, Union

import gguf
import torch
from tqdm import tqdm

HAVE_BFLOAT16=hasattr(torch, "bfloat16")
Owner: Wouldn't we need to check if the current device actually supports it? i.e. we'd want to test this on RTX 20XX, Volta and Pascal. I can test on Volta + Pascal sometime.

Contributor (Author): I'll make the bfloat16 changes a separate PR.
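For reference, a stricter check could probe the active device as well as the torch build. A minimal sketch (assuming CUDA; note that `torch.cuda.is_bf16_supported()` may report emulated rather than native support depending on the PyTorch version):

```python
import torch

# Sketch: bf16 exists as a torch dtype AND the active CUDA device reports support.
HAVE_BFLOAT16 = hasattr(torch, "bfloat16")
if HAVE_BFLOAT16 and torch.cuda.is_available():
    # On older cards (Pascal/Volta/Turing) this may reflect emulated rather
    # than native bf16 support, depending on the PyTorch version.
    HAVE_BFLOAT16 = torch.cuda.is_bf16_supported()
```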
try:
    from . import dequant_triton
    triton_dequantize_functions=dequant_triton.dequantize_functions
    HAVE_TRITON=True
except Exception as exc:
    HAVE_TRITON=False
    print(f"\nGGUF: Failed to enable Triton: {exc}")
Contributor (Author): These debug print statements will be removed before I mark this PR as ready. If you actually want some Triton status information logged to the console, let me know and I can add it.
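If Triton status logging is wanted later, a minimal sketch using the standard `logging` module (the logger name here is just a placeholder, not something from the diff):

```python
import logging

logger = logging.getLogger("GGUF")  # placeholder name

if HAVE_TRITON:
    logger.info("Triton dequantization enabled with %d handlers", len(triton_dequantize_functions))
else:
    logger.warning("Triton unavailable; using PyTorch dequantization only")
```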
    triton_dequantize_functions={}


TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16))
Contributor (Author): I can change it if you want. Usually it's better to use non-mutable types for stuff that isn't supposed to be mutable. It's a little easier to reason about (since you know stuff won't randomly be changing) and it's usually a little more efficient also.
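A tiny illustration of the trade-off being discussed (not part of the diff): both containers support `in`, but the frozenset hashes its lookups and cannot be mutated.

```python
import gguf

qtypes_tuple = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
qtypes_frozen = frozenset(qtypes_tuple)

# Identical membership semantics; the frozenset just makes the intent
# (a fixed set of types) explicit and does hashed lookups.
assert gguf.GGMLQuantizationType.F16 in qtypes_tuple
assert gguf.GGMLQuantizationType.F16 in qtypes_frozen
```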
DequantizeHandlersType = dict[gguf.GGMLQuantizationType, Callable]
DequantizeDtype = Optional[Union[torch.dtype, Literal["target"]]]


class GGUFConfig(NamedTuple):
    dequant_dtype: DequantizeDtype = None
    patch_dtype: DequantizeDtype = None
    patch_on_device: Optional[bool] = None
    optimize: str = "none"
    dequantize_function: Optional[Callable] = None
    dequantize_handlers: Optional[DequantizeHandlersType] = None


TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
DEFAULT_CONFIG = GGUFConfig()
Owner: Can't we just set the default values to begin with, and then they're already the "default" when you create the instance with `GGUFConfig()`?

Contributor (Author): Do you mean have it preloaded with the default PT functions and then maybe overwrite (some of) them later on?
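Since every field of the NamedTuple already carries a default, the two views are compatible: `GGUFConfig()` is the default configuration, and callers can derive variants from it. A sketch using only the names from this diff:

```python
config = GGUFConfig()  # all defaults; same values as DEFAULT_CONFIG

# Override individual fields without touching the rest (_replace is the
# standard NamedTuple way to build a modified copy).
custom = config._replace(dequant_dtype="target", patch_on_device=True)
```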
def is_torch_compatible(tensor):
    return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor):
    return not is_torch_compatible(tensor)


def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None):
    config = config or DEFAULT_CONFIG
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(data, qtype, oshape, dtype=None):
    if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16:
        return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype)
Owner: Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16 (also, I don't super like the `elif`/`else`-after-`return` style).

Contributor (Author): I will remove these changes and make it a separate PR. I did test, though, and it seemed to work fine with a model that had a bunch of BF16 tensors, so I am pretty sure it isn't a different layout.

You mean you don't like:

    def blah():
        if condition:
            return 1
        elif other_condition:
            return 2
        else:
            return 3

Compared to:

    def blah():
        if condition:
            return 1
        if other_condition:
            return 2
        return 3

Linters will complain about the former version because the `elif`/`else` after a `return` is redundant.

Contributor (Author): TL;DR: It's the same layout, and just viewing is safe as long as Torch has bf16 support. I'm not sure the GPU even needs to have bf16 support here, since it's just a storage type and there's no math involved. Of course, if the user has the compute dtype set to bf16 and their GPU doesn't support it, then they're going to run into issues. If you have access to a GPU without bf16 support, a simple test would just be to temporarily manifest a bf16 tensor and see if it causes problems:
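A minimal sketch of such a smoke test (hypothetical, not the snippet referenced above; assumes a CUDA device is present):

```python
import torch

# Hypothetical check: store a tensor as bf16 on the GPU and read it back.
# If bf16 can't even be used as a storage type, this is where it would fail.
x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
y = x.to(torch.bfloat16)           # bf16 storage
z = y.to(torch.float32)            # cast back for comparison
print(torch.allclose(x, z, rtol=1e-2, atol=1e-2))  # bf16 keeps ~8 bits of mantissa
```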
    if qtype in dequantize_functions:
        dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype
        dequantize_function = config.dequantize_function or dequantize
        return dequantize_function(
            tensor.data,
            qtype,
            oshape,
            dtype=dequant_dtype,
            dequantize_functions_override=config.dequantize_handlers,
        ).to(dtype)
    # this is incredibly slow
    tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
    new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
    return torch.from_numpy(new).to(tensor.device, dtype=dtype)
def dequantize(data, qtype, oshape, dtype=None, dequantize_functions_override: Optional[DequantizeHandlersType]=None):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]
    dequantize_blocks = (dequantize_functions_override or dequantize_functions)[qtype]

    rows = data.reshape(
        (-1, data.shape[-1])
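As a side note on the `GGML_QUANT_SIZES` lookup above: each quant type maps to a `(block_size, type_size)` pair, i.e. how many weights one block encodes and how many bytes that block occupies. For example, Q4_0 packs 32 weights into 18 bytes (a float16 scale plus 32 four-bit values):

```python
import gguf

block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]
print(block_size, type_size)  # 32 elements per block, 18 bytes per block
```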
@@ -74,7 +108,7 @@ def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)
    qh = qh.contiguous().view(torch.int32)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)

@@ -89,7 +123,7 @@ def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)
    qh = qh.contiguous().view(torch.int32)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
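Putting the pieces together, a downstream call site could opt into the Triton handlers through the config. A sketch (the tensor variable is a placeholder, and this assumes the Triton handler dict uses the same qtype keys as `dequantize_functions`):

```python
# Hypothetical call site for the new config plumbing.
config = GGUFConfig(
    dequant_dtype="target",
    dequantize_handlers=triton_dequantize_functions if HAVE_TRITON else None,
)
weight = dequantize_tensor(ggml_tensor, dtype=torch.float16, config=config)  # ggml_tensor: placeholder
```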
Owner: The repo generally doesn't have typing anywhere, so I'd say just remove it unless we plan to add it everywhere.

Contributor (Author): Well, typing is generally considered a good thing, and if you'd be interested in adding it everywhere I can certainly do that. Also, having type annotations in some places means it would be less work if you wanted it everywhere later on. I can just remove all the type annotations if that's what you really want, though.