10 changes: 9 additions & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -12,6 +12,8 @@
 
 logger = logging.getLogger(__name__)
 
+_has_avx512 = torch.backends.cpu.get_cpu_capability() == "AVX512"
+
 # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
 # However, we can overflow if we use this without AVX512_VNNI support.
 # This is fixed in torch 2.6+, so we set this as the minimum to be safe.
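The `_has_avx512` probe added here is what gates the fallbacks in the next hunk. As context for the overflow comment above, here is a minimal sketch of the kind of version gating it describes for `torch._int_mm`; the `int8_mm` helper is hypothetical and not part of this file, while `torch.backends.cpu.get_cpu_capability()` and `torch._int_mm` are real PyTorch APIs:

```python
import torch

# Detect AVX512 the same way the diff does; get_cpu_capability() returns a
# string such as "AVX512", "AVX2", or "DEFAULT".
_has_avx512 = torch.backends.cpu.get_cpu_capability() == "AVX512"

def int8_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """s8 @ s8 -> s32 matmul that avoids the pre-2.6 overflow pitfall."""
    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
    if (major, minor) >= (2, 6):
        # Per the comment above, torch 2.6+ fixed the accumulation overflow
        # that could occur on CPUs without AVX512_VNNI.
        return torch._int_mm(a, b)
    # Conservative fallback: widen to int32 before the matmul so the
    # accumulation cannot overflow an int8-based kernel.
    return a.to(torch.int32) @ b.to(torch.int32)
```

The widened-matmul branch trades speed for correctness and would only matter on builds older than the 2.6 floor this file sets.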
@@ -134,8 +136,14 @@ def _(
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
 
+    # Fallback, as the AVX512 implementation has accuracy issues with fp16/fp32 and blocksize >= 2048.
+    # Note: this is not a common use case.
+    avx512_fallback = _has_avx512 and blocksize >= 2048 and dtype != torch.bfloat16
+
     # Odd shape is not supported by this kernel; fall back to the generic implementation.
-    if shape[-1] % 2 != 0:
+    shape_fallback = shape[-1] % 2 != 0
+
+    if avx512_fallback or shape_fallback:
         from ..default.ops import _dequantize_4bit_impl
 
         return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
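Taken together, the merged hunk makes the dequantization dispatch read roughly as below. This is a hedged reconstruction for illustration only: the registered op in the file is anonymous (`def _(` under a dispatcher decorator), so the function name is assumed, and the optimized-kernel tail is elided. The `_dequantize_4bit_impl` import and both fallback conditions come directly from the diff above:

```python
import torch

def _dequantize_4bit_cpu(A, absmax, blocksize, quant_type, shape, dtype):
    # The AVX512 kernel loses accuracy for fp16/fp32 once blocksize >= 2048,
    # so those (uncommon) cases route to the generic implementation.
    avx512_fallback = _has_avx512 and blocksize >= 2048 and dtype != torch.bfloat16

    # The optimized kernel also cannot handle an odd trailing dimension.
    shape_fallback = shape[-1] % 2 != 0

    if avx512_fallback or shape_fallback:
        from bitsandbytes.backends.default.ops import _dequantize_4bit_impl
        return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)

    ...  # otherwise the optimized CPU kernel handles the request
```

Computing both flags before a single branch keeps the fast path unconditional for the common case (bf16, small block sizes, even trailing dimension) while funneling every problematic input through the one generic entry point.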