101 changes: 71 additions & 30 deletions benchmarking/inference_benchmark.py
@@ -21,6 +21,9 @@
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
--iterations ITERATIONS
--warmup-runs WARMUP_RUNS
--output-length OUTPUT_LENGTH
"""

import argparse
@@ -30,6 +33,9 @@
from optimum_benchmark.logging_utils import setup_logging
import torch

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8

WEIGHTS_CONFIGS = {
@@ -73,9 +79,8 @@
},
}

if __name__ == "__main__":
setup_logging(level="INFO")

def parse_args():
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")

parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
@@ -98,37 +103,73 @@

parser.add_argument("--out-dir", type=str, default="reports")

args = parser.parse_args()
parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run")
parser.add_argument(
"--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement"
)
parser.add_argument(
"--output-length",
type=int,
default=64,
help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.",
)

return parser.parse_args()


def run_benchmark(args, config, batch_size):
launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
iterations=args.iterations,
warmup_runs=args.warmup_runs,
# set duration to 0 to disable the duration-based stopping criterion
# this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks
duration=0,
# for consistent results, set a fixed min and max for output tokens
generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
)

backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)

test_name = (
f"benchmark-{config}"
f"-bsz-{batch_size}"
f"-isz-{args.input_length}"
f"-osz-{args.output_length}"
f"-iter-{args.iterations}"
f"-wrmup-{args.warmup_runs}"
)
benchmark_config = BenchmarkConfig(
name=test_name,
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)

out_path = out_dir / (test_name + ".json")
print(f"[{test_name}] Starting:")
benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.save_json(out_path)


if __name__ == "__main__":
setup_logging(level="INFO")
args = parse_args()

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)

out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"

benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
run_benchmark(args, config, batch_size)
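
A note on the measurement settings above: with duration=0 the duration-based stopping criterion is disabled, so every configuration performs a fixed number of runs, and pinning min_new_tokens == max_new_tokens fixes the work per call. A minimal Python sketch of that fixed-work timing loop, for illustration only — optimum-benchmark implements this internally, and run_once here is a hypothetical stand-in for a single generate() call:

    import time

    def measure(run_once, warmup_runs=10, iterations=10):
        """Fixed-work timing loop: discard warmup runs, then time a fixed
        number of iterations. There is no wall-clock stopping criterion, so
        every configuration runs the same number of operations regardless of
        hardware speed or bottlenecks."""
        for _ in range(warmup_runs):
            run_once()  # warm up caches, allocator, and kernel selection
        latencies = []
        for _ in range(iterations):
            start = time.perf_counter()
            run_once()  # fixed min/max new tokens => identical work per call
            latencies.append(time.perf_counter() - start)
        return latencies
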
121 changes: 35 additions & 86 deletions csrc/kernels.cu
@@ -21,23 +21,34 @@
#define NUM 4
#define NUM_BLOCK 4096

__device__ static float nf4_data[16] = {
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0
};
__device__ static float fp4_dequantization_lut[8] = {
0.0f, // 0b000
0.005208333333f, // 0b001
0.66666667f, // 0b010
1.0f, // 0b011
0.33333333f, // 0b100
0.5f, // 0b101
0.16666667f, // 0b110
0.25f // 0b111
};

__device__ static float nf4_dequantization_lut[16] = {
-1.0f, // 0b0000
-0.6961928009986877f, // 0b0001
-0.5250730514526367f, // 0b0010
-0.39491748809814453f, // 0b0011
-0.28444138169288635f, // 0b0100
-0.18477343022823334f, // 0b0101
-0.09105003625154495f, // 0b0110
0.0f, // 0b0111
0.07958029955625534f, // 0b1000
0.16093020141124725f, // 0b1001
0.24611230194568634f, // 0b1010
0.33791524171829224f, // 0b1011
0.44070982933044434f, // 0b1100
0.5626170039176941f, // 0b1101
0.7229568362236023f, // 0b1110
1.0f // 0b1111
};

// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
@@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) {
return __int_as_float(old);
}

__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
if ((val & 0b0100) == 4) // 0
if ((val & 0b0010) == 2) // 01
if ((val & 0b0001) == 1) // 111
return 0.25000000f * absmax * sign; // 1111
else
return 0.16666667f * absmax * sign; // 1110
else if ((val & 0b0001) == 1) // 110
return 0.50000000f * absmax * sign; // 1101
else
return 0.33333333f * absmax * sign; // 1100
else if ((val & 0b0010) == 2) // 10
if ((val & 0b0001) == 1) // 101
return 1.00000000f * absmax * sign; // 1011
else
return 0.66666667f * absmax * sign; // 1010
else if ((val & 0b0001) == 1) // 100
return 5.208333333e-03f * absmax * sign; // 1001
else
return 0.00000000f * absmax * sign; // 1000
__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
return fp4_dequantization_lut[val & 0b111] * sign;
}

__device__ unsigned char dQuantizeFP4(float x) {
@@ -118,51 +111,7 @@
return 0b0000 + sign;
}

__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {

// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if ((val & 0b1000) == 8)
if ((val & 0b0100) == 4) // 1
if ((val & 0b0010) == 2) // 11
if ((val & 0b0001) == 1) // 111
return 1.0f;
else
return 0.7229568362236023f;
else if ((val & 0b0001) == 1) // 110
return 0.5626170039176941f;
else
return 0.44070982933044434f;
else if ((val & 0b0010) == 2) // 10
if ((val & 0b0001) == 1) // 101
return 0.33791524171829224f;
else
return 0.24611230194568634f;
else if ((val & 0b0001) == 1) // 100
return 0.16093020141124725f;
else
return 0.07958029955625534f;

else if ((val & 0b0100) == 4) // 0
if ((val & 0b0010) == 2) // 01
if ((val & 0b0001) == 1) // 011
return 0.0f;
else
return -0.09105003625154495f;
else if ((val & 0b0001) == 1) // 010
return -0.18477343022823334f;
else
return -0.28444138169288635f;
else if ((val & 0b0010) == 2) // 00
if ((val & 0b0001) == 1) // 001
return -0.39491748809814453f;
else
return -0.5250730514526367f;
else if ((val & 0b0001) == 1) // 000
return -0.6961928009986877f;
else
return -1.0f;
}
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
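
As a quick sanity check (not part of this PR), the branchless LUT lookups can be compared against the deleted decision trees in plain Python. The tables are copied from the diff above; the function names are ours:

    # FP4: fraction LUT indexed by the low 3 bits; bit 3 carries the sign.
    FP4_LUT = [0.0, 0.005208333333, 0.66666667, 1.0,
               0.33333333, 0.5, 0.16666667, 0.25]

    def dequantize_fp4_lut(val: int) -> float:
        sign = 1.0 - 2 * ((val & 0b1000) >> 3)  # maps bit 3 {0,1} -> {+1.0, -1.0}
        return FP4_LUT[val & 0b111] * sign

    def dequantize_fp4_tree(val: int) -> float:
        # Port of the deleted if/else tree, with absmax factored out.
        sign = -1.0 if val & 0b1000 else 1.0
        if val & 0b0100:
            if val & 0b0010:
                frac = 0.25 if val & 0b0001 else 0.16666667
            else:
                frac = 0.5 if val & 0b0001 else 0.33333333
        elif val & 0b0010:
            frac = 1.0 if val & 0b0001 else 0.66666667
        else:
            frac = 5.208333333e-03 if val & 0b0001 else 0.0
        return frac * sign

    # The LUT and the tree agree on all 16 FP4 codes.
    for code in range(16):
        assert abs(dequantize_fp4_lut(code) - dequantize_fp4_tree(code)) < 1e-9, code

    # NF4: the 16-entry LUT ascends monotonically from -1.0 to 1.0, which is
    # the ordering the binary-search-style dQuantizeNF4 tree relies on.
    NF4_LUT = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
               -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
               0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
               0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
               0.7229568362236023, 1.0]
    assert NF4_LUT == sorted(NF4_LUT)
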

__device__ unsigned char dQuantizeNF4(float x) {

@@ -510,8 +459,8 @@ __global__ void
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
@@ -2352,7 +2301,7 @@ __global__ void kgemm_4bit_inference(

#pragma unroll 16
for (int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];

T local_A[2];