Merged

118 commits
c9fec32  Port ROCm changes from multi-backend-refactor branch (pnunna93, May 15, 2025)
d729c18  Update ops.py (MISHANMAURYA, May 20, 2025)
6459c2b  Update functional.py (MISHANMAURYA, May 20, 2025)
09249c8  Update ops.py (MISHANMAURYA, May 21, 2025)
4afa774  Update ops.py (MISHANMAURYA, May 21, 2025)
033d92c  Update ops.py (MISHANMAURYA, May 21, 2025)
4def959  Update ops.py (MISHANMAURYA, May 22, 2025)
0f31866  Update functional.py (MISHANMAURYA, May 22, 2025)
190faed  Update ops.py (MISHANMAURYA, May 22, 2025)
d7f413b  Update ops.py (MISHANMAURYA, May 22, 2025)
3b6e68a  Update ops.py (MISHANMAURYA, May 22, 2025)
06740b1  Update ops.py (MISHANMAURYA, May 22, 2025)
9fe67ef  Update functional.py (MISHANMAURYA, May 22, 2025)
d97fdce  Update functional.py (MISHANMAURYA, May 22, 2025)
f1fbe92  Update functional.py (MISHANMAURYA, May 24, 2025)
660c254  Update functional.py (MISHANMAURYA, May 24, 2025)
c692f4b  Update ops.py (MISHANMAURYA, May 27, 2025)
46f9800  Update ops.py (MISHANMAURYA, May 27, 2025)
7823bac  Update ops.py (MISHANMAURYA, May 28, 2025)
d0ed107  Update ops.py (MISHANMAURYA, May 28, 2025)
af3aaf6  Update ops.py (MISHANMAURYA, May 28, 2025)
d1e34a5  Update ops.py (MISHANMAURYA, May 28, 2025)
b2b4df6  Update ops.py (MISHANMAURYA, May 28, 2025)
8863d0e  Update ops.py (MISHANMAURYA, May 28, 2025)
d1a5e8d  Update ops.py (MISHANMAURYA, May 28, 2025)
843ea33  Update functional.py (MISHANMAURYA, May 28, 2025)
d6d2e5f  Update functional.py (MISHANMAURYA, May 28, 2025)
e3f9f21  Update functional.py (MISHANMAURYA, May 28, 2025)
bc0957d  Update test_ops.py (MISHANMAURYA, May 28, 2025)
b8247ab  Update test_functional.py (MISHANMAURYA, May 28, 2025)
531758a  Update test_ops.py (MISHANMAURYA, May 28, 2025)
6d7db8e  Update test_functional.py (MISHANMAURYA, May 28, 2025)
632e95b  Update test_functional.py (MISHANMAURYA, May 28, 2025)
90d9af2  Update functional.py (MISHANMAURYA, May 28, 2025)
80048d8  Update functional.py (MISHANMAURYA, May 28, 2025)
e448ebb  Update ops.py (MISHANMAURYA, May 28, 2025)
048faa8  Update ops.py (MISHANMAURYA, May 28, 2025)
c45e9d1  Update test_functional.py (MISHANMAURYA, May 28, 2025)
47a491f  Update test_functional.py (MISHANMAURYA, May 28, 2025)
86976bc  Update cextension.py (MISHANMAURYA, May 28, 2025)
98a142a  Update cuda_specs.py (MISHANMAURYA, May 28, 2025)
888fe46  Update cuda_specs.py (MISHANMAURYA, May 28, 2025)
c9c52b5  Update test_functional.py (MISHANMAURYA, May 29, 2025)
fc29586  Update test_linear4bit.py (MISHANMAURYA, May 30, 2025)
53b8b1c  Update test_cuda_setup_evaluator.py (MISHANMAURYA, May 30, 2025)
fe1fe7c  Update test_functional.py (MISHANMAURYA, May 30, 2025)
e198824  Update modules.py (MISHANMAURYA, May 30, 2025)
dd58310  Update modules.py (MISHANMAURYA, May 30, 2025)
931bd70  Update ops.py (MISHANMAURYA, May 30, 2025)
9e62d46  Update test_linear4bit.py (MISHANMAURYA, May 30, 2025)
1f71562  Update ops.py (MISHANMAURYA, Jun 2, 2025)
eac7632  Update ops.py (MISHANMAURYA, Jun 2, 2025)
66dcfc4  Update test_linear4bit.py (MISHANMAURYA, Jun 2, 2025)
b96905d  Update test_linear4bit.py (MISHANMAURYA, Jun 2, 2025)
ef31c36  Update python-package.yml (MISHANMAURYA, Jun 2, 2025)
e1435f0  Update python-package.yml (MISHANMAURYA, Jun 2, 2025)
da9a271  Update python-package.yml (MISHANMAURYA, Jun 2, 2025)
08848da  Update python-package.yml (MISHANMAURYA, Jun 2, 2025)
978cba3  Create build-rocm.sh (MISHANMAURYA, Jun 2, 2025)
79fc632  Merge pull request #65 from MISHANMAURYA/upstream_main_rocm_enabled (pnunna93, Jun 3, 2025)
4e31305  Merge remote-tracking branch 'origin/upstream_main_rocm_enabled' into… (MISHANMAURYA, Jun 3, 2025)
af6561a  Update cuda_specs.py (MISHANMAURYA, Jun 3, 2025)
405b484  Fix trailing whitespace (MISHANMAURYA, Jun 3, 2025)
93768d0  Remove conflicts.diff (MISHANMAURYA, Jun 3, 2025)
47ac97d  Merge pull request #70 from MISHANMAURYA/upstream_main_mm (pnunna93, Jun 3, 2025)
59ec4b9  Merge upstream/main into IFU-master-2025-06-04 (MISHANMAURYA, Jun 4, 2025)
e119ff7  update for hipblasVersionMajor >=3 (amcamd, Jun 5, 2025)
8dc297d  Update test_functional.py (MISHANMAURYA, Jun 6, 2025)
f7d8bf3  Update test_linear4bit.py (MISHANMAURYA, Jun 6, 2025)
fd0a4d0  Update test_ops.py (MISHANMAURYA, Jun 6, 2025)
75487d3  Update main.py (MISHANMAURYA, Jun 6, 2025)
539f01b  Merge pull request #76 from ROCm/upstream_fix (pnunna93, Jun 6, 2025)
3551457  Update test_functional.py (MISHANMAURYA, Jun 10, 2025)
90437b9  Update test_linear4bit.py (MISHANMAURYA, Jun 10, 2025)
a0bdc94  Update test_ops.py (MISHANMAURYA, Jun 10, 2025)
8a27346  Update test_linear4bit.py (MISHANMAURYA, Jun 10, 2025)
c945dbb  Lint (MISHANMAURYA, Jun 10, 2025)
58e989e  Lint (MISHANMAURYA, Jun 11, 2025)
2cce336  Update helpers.py (MISHANMAURYA, Jun 11, 2025)
5eb0316  Update test_functional.py (MISHANMAURYA, Jun 11, 2025)
dcdf2c5  Update test_linear4bit.py (MISHANMAURYA, Jun 11, 2025)
6bba740  Update test_ops.py (MISHANMAURYA, Jun 11, 2025)
bdd6754  Lint (MISHANMAURYA, Jun 11, 2025)
c2cfa7a  Merge pull request #75 from MISHANMAURYA/skip_cpu_test_upstream_main_… (pnunna93, Jun 11, 2025)
ad5794f  Merge branch 'origin/upstream_main_rocm_enabled' into IFU-master-2025… (MISHANMAURYA, Jun 17, 2025)
f9746dc  merge (MISHANMAURYA, Jun 18, 2025)
3db3196  Update pythonInterface.cpp (MISHANMAURYA, Jun 18, 2025)
75a654e  lint fix (MISHANMAURYA, Jun 18, 2025)
5624736  lint (MISHANMAURYA, Jun 18, 2025)
c75fdb7  Update pythonInterface.cpp (MISHANMAURYA, Jun 18, 2025)
3936ca4  revert permissions change (Jun 18, 2025)
648ecd2  Merge pull request #73 from MISHANMAURYA/IFU-master-2025-06-04 (pnunna93, Jun 18, 2025)
b4fd594  Fix indentation (pnunna93, Jun 18, 2025)
8934cb3  Merge branch 'main' into upstream_main_rocm_enabled (pnunna93, Jun 18, 2025)
ca04bc5  Merge branch 'main' into upstream_main_rocm_enabled (pnunna93, Jun 19, 2025)
3228ca8  Update kernels_hip.cuh (MISHANMAURYA, Jun 20, 2025)
94c1b77  Update kernels.hip (MISHANMAURYA, Jun 20, 2025)
cd3f0b7  Update ops.hip (MISHANMAURYA, Jun 20, 2025)
98bb06e  Update ops_hip.cuh (MISHANMAURYA, Jun 20, 2025)
3bad454  Update kernels_hip.cuh (MISHANMAURYA, Jun 20, 2025)
e0c766d  Update kernels.hip (MISHANMAURYA, Jun 20, 2025)
f35a063  Update kernels.hip (MISHANMAURYA, Jun 20, 2025)
fca01f3  Update ops.hip (MISHANMAURYA, Jun 20, 2025)
5569c2d  Update ops_hip.cuh (MISHANMAURYA, Jun 20, 2025)
7a17f2d  Update ops.hip (MISHANMAURYA, Jun 20, 2025)
6b8239e  Update CMakeLists.txt (MISHANMAURYA, Jun 20, 2025)
00ac146  Update functional.py (MISHANMAURYA, Jun 20, 2025)
77f4c77  Update cextension.py (MISHANMAURYA, Jun 20, 2025)
c9fe284  Update cextension.py (MISHANMAURYA, Jun 20, 2025)
e2ddda3  Merge pull request #78 from MISHANMAURYA/remove-estimate-quantiles-hi… (pnunna93, Jun 20, 2025)
2f49a0b  Merge pull request #80 from MISHANMAURYA/update_doc_string (pnunna93, Jun 20, 2025)
48a551f  Merge pull request #79 from MISHANMAURYA/remove_hip_version_check (pnunna93, Jun 20, 2025)
7d4854e  warpSize is being made non constexpr in ROCm 7.0 (sstamenk, Jul 25, 2025)
7524c09  Merge pull request #87 from sstamenk/rocm_enabled_warpsize_fix (pnunna93, Sep 23, 2025)
2e65b38  Merge pull request #90 from ROCm/IFU-rocm_enabled-09-23-2025 (pnunna93, Sep 24, 2025)
a4080e1  Merge branch 'main' into rocm_enabled (pnunna93, Sep 24, 2025)
fcbec79  Fix typo (pnunna93, Sep 24, 2025)
4fa939b  unskip test_4bit_quant (pnunna93, Sep 24, 2025)

1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -70,6 +70,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
message(FATAL_ERROR "XPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
else()

152 changes: 52 additions & 100 deletions csrc/kernels.hip
@@ -19,37 +19,42 @@
#define NUM 4
#define NUM_BLOCK 4096

__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
__device__ static float fp4_dequantization_lut[8] = {
0.0f, // 0b000
0.005208333333f, // 0b001
0.66666667f, // 0b010
1.0f, // 0b011
0.33333333f, // 0b100
0.5f, // 0b101
0.16666667f, // 0b110
0.25f // 0b111
};

__device__ static float nf4_dequantization_lut[16] = {
-1.0f, // 0b0000
-0.6961928009986877f, // 0b0001
-0.5250730514526367f, // 0b0010
-0.39491748809814453f, // 0b0011
-0.28444138169288635f, // 0b0100
-0.18477343022823334f, // 0b0101
-0.09105003625154495f, // 0b0110
0.0f, // 0b0111
0.07958029955625534f, // 0b1000
0.16093020141124725f, // 0b1001
0.24611230194568634f, // 0b1010
0.33791524171829224f, // 0b1011
0.44070982933044434f, // 0b1100
0.5626170039176941f, // 0b1101
0.7229568362236023f, // 0b1110
1.0f // 0b1111
};

// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// Luckily we have atomicmax and atomicmin in ROCm


__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 111
return 0.25000000f*absmax*sign; // 1111
else
return 0.16666667f*absmax*sign; // 1110
else
if((val & 0b0001) == 1) // 110
return 0.50000000f*absmax*sign; // 1101
else
return 0.33333333f*absmax*sign; // 1100
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 1.00000000f*absmax*sign; // 1011
else
return 0.66666667f*absmax*sign; // 1010
else
if((val & 0b0001) == 1) // 100
return 5.208333333e-03f*absmax*sign; // 1001
else
return 0.00000000f*absmax*sign; // 1000
__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
return fp4_dequantization_lut[val & 0b111] * sign;
}
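
The rewrite above splits the 4-bit code into a sign bit (bit 3) and a 3-bit index into fp4_dequantization_lut, with absmax now applied by the caller rather than inside the decode. A minimal host-side C++ sketch (illustration only, not part of this PR) that checks the lookup path against the magnitudes hard-coded in the removed branchy tree:

#include <cassert>
#include <cmath>

// Same table as fp4_dequantization_lut above, indexed by the low 3 bits.
static const float kFp4Lut[8] = {0.0f,        0.005208333333f, 0.66666667f, 1.0f,
                                 0.33333333f, 0.5f,            0.16666667f, 0.25f};

// Mirrors the new dDequantizeFP4Tree: bit 3 is the sign, bits 0-2 index the LUT.
static float decode_fp4(unsigned char val) {
    float sign = 1.0f - 2.0f * ((val & 0b1000) >> 3);  // bit 3 clear -> +1, set -> -1
    return kFp4Lut[val & 0b111] * sign;                // caller multiplies by absmax
}

int main() {
    // Magnitudes taken from the removed tree, ordered by the 3-bit code.
    const float expected[8] = {0.0f,        5.208333333e-03f, 0.66666667f, 1.0f,
                               0.33333333f, 0.5f,             0.16666667f, 0.25f};
    for (int code = 0; code < 16; ++code) {
        float want = expected[code & 0b111] * ((code & 0b1000) ? -1.0f : 1.0f);
        assert(std::fabs(decode_fp4(static_cast<unsigned char>(code)) - want) < 1e-9f);
    }
    return 0;
}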

__device__ unsigned char dQuantizeFP4(float x)
@@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x)
return 0b0000+sign;
}


__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
{

// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if((val & 0b1000) == 8)
if((val & 0b0100) == 4) // 1
if((val & 0b0010) == 2) // 11
if((val & 0b0001) == 1) // 111
return 1.0f;
else
return 0.7229568362236023f;
else
if((val & 0b0001) == 1) // 110
return 0.5626170039176941f;
else
return 0.44070982933044434f;
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 0.33791524171829224f;
else
return 0.24611230194568634f;
else
if((val & 0b0001) == 1) // 100
return 0.16093020141124725f;
else
return 0.07958029955625534f;

else
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 011
return 0.0f;
else
return -0.09105003625154495f;
else
if((val & 0b0001) == 1) // 010
return -0.18477343022823334f;
else
return -0.28444138169288635f;
else
if((val & 0b0010) == 2) //00
if((val & 0b0001) == 1) // 001
return -0.39491748809814453f;
else
return -0.5250730514526367f;
else
if((val & 0b0001) == 1) // 000
return -0.6961928009986877f;
else
return -1.0f;

}
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
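
With the same table approach, the 15-branch NF4 decode collapses to a single lookup on the low nibble. A quick host-side sanity check (not part of the PR) that the 16 entries form a valid codebook for this scheme, strictly increasing with code 0b0111 mapping to exact zero:

#include <cassert>

// Copy of nf4_dequantization_lut from above.
static const float kNf4Lut[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
    0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f};

int main() {
    assert(kNf4Lut[0b0111] == 0.0f);             // code 7 decodes to exact zero
    for (int i = 1; i < 16; ++i)
        assert(kNf4Lut[i] > kNf4Lut[i - 1]);     // values are strictly increasing
    return 0;
}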

__device__ unsigned char dQuantizeNF4(float x)
{
@@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}

unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case General8bit:
@@ -473,18 +423,16 @@
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
}
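
Each thread now packs the two 4-bit codes for a pair of values directly into the output byte, first value in the high nibble and second in the low nibble; the dequantize path below reads them back with >> 4 and & 0x0F. A small host-side sketch of the packing scheme (illustration only):

#include <cassert>

// Pack two 4-bit codes into one byte, first code in the high nibble.
static unsigned char pack4(unsigned char first, unsigned char second) {
    return static_cast<unsigned char>(((first & 0x0F) << 4) | (second & 0x0F));
}

int main() {
    unsigned char q = pack4(0b1011, 0b0010);  // two arbitrary FP4/NF4 codes
    assert((q >> 4) == 0b1011);               // first element, as unpacked by kDequantizeBlockwise
    assert((q & 0x0F) == 0b0010);             // second element
    return 0;
}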
@@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
@@ -2109,7 +2057,11 @@ __global__ void kdequant_mm_int32_fp16(
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#define WARP_SIZE warpSize
#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
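
warpSize stops being a compile-time constant in ROCm 7.0 (see commit 7d4854e above), so it can no longer back a #define that is used as a template argument; the wavefront width is instead selected at preprocessing time from the target architecture, 64 lanes on gfx9 and 32 otherwise. A plain C++ sketch of why the constant matters, using a hypothetical WarpReduceLike stand-in for hipcub::WarpReduce and a DEMO_WARP_SIZE macro mirroring WARP_SIZE:

// WarpReduceLike stands in for hipcub::WarpReduce, whose logical warp width is a
// template parameter and therefore must be a constant expression; a runtime
// warpSize value would not compile in this position.
#if defined(__GFX9__)  // gfx9 wavefronts are 64 lanes wide
#define DEMO_WARP_SIZE 64
#else
#define DEMO_WARP_SIZE 32
#endif

template <int LOGICAL_WARP_THREADS>
struct WarpReduceLike {
    static constexpr int lanes = LOGICAL_WARP_THREADS;
};

static_assert(WarpReduceLike<DEMO_WARP_SIZE>::lanes == DEMO_WARP_SIZE,
              "warp width must be usable in constant expressions");

int main() { return 0; }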
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
@@ -2503,7 +2455,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i

#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];

T local_A[2];
@@ -2708,13 +2660,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];

const int warp_idx = threadIdx.x / warpSize;
const int warp_lane = threadIdx.x % warpSize;
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
const int offset_B = ldb*row_B;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int warp_lane = threadIdx.x % WARP_SIZE;
const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit/2;
float local_C = 0.0f;

@@ -2732,7 +2684,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

// A: [1, K]
// B: [M, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
{
const int inner_idx_halved = inner_idx/2;

8 changes: 7 additions & 1 deletion csrc/ops.hip
@@ -20,6 +20,12 @@

#define ERR_NOT_IMPLEMENTED 100

#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

using namespace BinSearch;
using std::cout;
using std::endl;
@@ -692,7 +698,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
//warpsize - 32
int num_blocks = (m+3)/4;
//warpsize - 64
if (warpSize == 64) {
if (WARP_SIZE == 64) {
num_blocks = (m+1)/2;
}
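
The grid size follows from the kernel's row mapping, row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx: each block covers THREADS/WARP_SIZE rows of B, so with a 128-thread block (an assumption for illustration; the launch configuration is outside this hunk) that is ceil(m/4) blocks on 32-wide warps and ceil(m/2) on 64-wide wavefronts, matching (m+3)/4 and (m+1)/2 above. A small sketch of the same arithmetic:

#include <cassert>

// Ceiling division: blocks needed when each block covers rows_per_block rows.
static int num_blocks_for(int m, int rows_per_block) {
    return (m + rows_per_block - 1) / rows_per_block;
}

int main() {
    const int threads = 128;  // assumed block size for this sketch
    assert(num_blocks_for(10, threads / 32) == (10 + 3) / 4);  // warp size 32 -> 3 blocks
    assert(num_blocks_for(10, threads / 64) == (10 + 1) / 2);  // wavefront 64 -> 5 blocks
    return 0;
}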

8 changes: 3 additions & 5 deletions tests/test_functional.py
@@ -10,7 +10,7 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
from bitsandbytes.cextension import HIP_ENVIRONMENT
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
@@ -463,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1408,10 +1409,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.skipif(
HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
reason="this test is not supported on ROCm with gfx90a architecture yet",
)
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_gemv_eye_4bit(self, device, storage_type, dtype):
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch

import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
@@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit):
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")

1 change: 1 addition & 0 deletions tests/test_ops.py
@@ -211,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")