26 changes: 26 additions & 0 deletions bitsandbytes/_ops.py
@@ -15,6 +15,32 @@
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl

# Int8 mixed precision matmul + dequant + bias
torch.library.define(
"bitsandbytes::int8_mixed_scaled_mm",
"(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_mixed_scaled_mm")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
shapeC = (*CA.shape[:-1], CB.shape[0])

out = torch.empty(shapeC, device=A.device, dtype=A.dtype)

outlier_cols = torch.library.get_ctx().new_dynamic_size()
subA = A.new_empty(outlier_cols, dtype=torch.int64)

return out, subA


# Higher level op: int8 matmul + dequant + bias
torch.library.define(
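The new `int8_mixed_scaled_mm` op defined above fuses the LLM.int8() mixed-precision path into a single dispatchable call; its fake registration uses `torch.library.get_ctx().new_dynamic_size()` so the number of returned outlier columns is treated as a dynamic dimension during tracing. A minimal usage sketch, not part of the PR — the shapes, the `threshold` value, and producing `CA`/`SCA`/`outlier_cols` via `F.int8_vectorwise_quant` are illustrative assumptions based on the autograd changes below:

```python
import torch
import bitsandbytes.functional as F  # importing bitsandbytes registers the custom ops

A = torch.randn(4, 128, dtype=torch.float16, device="cuda")    # activations
B = torch.randn(256, 128, dtype=torch.float16, device="cuda")  # weights

# Row-wise int8 quantization; outlier_cols indexes columns of A exceeding the threshold.
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)

# Fused int8 matmul + dequant + bias, with outlier columns recomputed in A's dtype.
out, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
    A, CA, CB, SCA, SCB, outlier_cols=outlier_cols, bias=None
)
# out: shape (4, 256) in A's dtype; subA: the outlier columns of A (possibly empty).
```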
41 changes: 16 additions & 25 deletions bitsandbytes/autograd/_functions.py
@@ -210,37 +210,28 @@ def forward(
# 2. Quantize B
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

# Handle sparse decomposition. In some instances, we may have not found any
# outlier columns at all. In that case, we'll skip this part completely.
if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
# Handle sparse decomposition
if state.threshold > 0.0:
state.idx = outlier_cols

# Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
CAt[:, state.idx] = 0
Comment on lines -218 to -220 (Member Author):

This step is missing right now, but note that we're planning to deprecate/remove support for full int8 training.


# Extract the input outliers in original precision
subA = A[:, state.idx].contiguous()
# Mixed Int8 Matmul + Dequant + Bias
output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
A,
CA,
state.CB,
SCA,
state.SCB,
outlier_cols,
bias,
)

# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t()
else:
# To dequantize our weights associated with the input outliers,
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
else:
# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(
CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
)
subA = None

# 3. Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)

# 4. Mixed-precision decomposition matmul
if subA is not None and state.subB is not None:
output = output.addmm(subA, state.subB)

# 5. Save state
ctx.state = state

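For reference, the computation the fused call performs can be written out in plain PyTorch. This is a hedged sketch rather than the library's implementation: it assumes 2D inputs, row-wise absmax scales `SCA`/`SCB` with int8 values stored as `round(x * 127 / absmax)`, and that the outlier columns of `A` were already zeroed in `CA` during threshold-based quantization, so adding them back does not double-count them:

```python
import torch

def int8_mixed_scaled_mm_reference(A, CA, CB, SCA, SCB, outlier_cols=None, bias=None):
    # Int8 matmul accumulated in fp32, then dequantized with the row-wise scales
    # of A (SCA) and B (SCB); each int8 value represents x * 127 / absmax.
    acc = CA.float() @ CB.float().t()
    out = (acc * (SCA.view(-1, 1) / 127.0) * (SCB.view(1, -1) / 127.0)).to(A.dtype)
    if bias is not None:
        out = out + bias

    subA = None
    if outlier_cols is not None and outlier_cols.numel():
        # The outlier columns carry no signal in CA, so their contribution is
        # recomputed in the input precision and added back (LLM.int8() decomposition).
        subA = A[:, outlier_cols].contiguous()
        subB = (CB[:, outlier_cols].float() * (SCB.view(-1, 1) / 127.0)).to(A.dtype).t()
        out = out.addmm(subA, subB)
    return out, subA
```

The CUDA registration below follows the same structure, but delegates the matmul and dequantization to the existing `int8_scaled_mm` op.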
42 changes: 42 additions & 0 deletions bitsandbytes/backends/cuda/ops.py
@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl(A, B, out)


@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
subB = None

if outlier_cols is not None and outlier_cols.numel():
# Extract the inputs with outliers in original precision
subA = A[:, outlier_cols].contiguous()

# Dequantize the corresponding weight columns
subB = (
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
.to(A.dtype)
.t()
)

# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()

else:
# Needed for torch.compile when there are no outliers.
subA = torch.empty(0, device=A.device, dtype=A.dtype)

# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

if subB is not None:
# Add the outlier columns back to the output
output = output.addmm(subA, subB)

return output, subA


def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
A, B = B, A

@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):

if outliers.any():
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
else:
# Needed for torch.compile support.
outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

with _cuda_device_of(A):
lib.cint8_vector_quant(
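The new else-branch is what the `# Needed for torch.compile support.` comment refers to: returning `None` only when no outliers happen to be present would make the output structure data-dependent, which the fake-tensor tracing behind `torch.compile` cannot express. With an always-present (possibly empty) index tensor, callers can branch on `numel()` instead of `None`. A small illustrative sketch (variable names are assumptions, not code from this PR):

```python
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)

# outlier_cols is now always a tensor; an empty one means no sparse decomposition.
if outlier_cols.numel():
    subA = A[:, outlier_cols].contiguous()
else:
    subA = torch.empty(0, device=A.device, dtype=A.dtype)
```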