Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --relu | FileCheck %s
Expand Down Expand Up @@ -29,7 +30,6 @@
from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary
from lighthouse.utils.numpy import mlir_to_numpy_dtype
from lighthouse.ingress.mlir_gen import generate_gpu_matmul_payload, get_mlir_elem_type
from lighthouse.schedule.xegpu import xegpu_parameter_selector


def matmul_complexity(
Expand Down Expand Up @@ -345,6 +345,11 @@ def parse_cli_args(description):
"--json",
help="Read problem sizes and tile parameters from a JSON file.",
)
parser.add_argument(
"--target",
choices=["B70", "B50"],
help="Target GPU device, e.g., B70.",
)
parser.add_argument(
"--verbose",
"-v",
Expand All @@ -370,31 +375,15 @@ def parse_cli_args(description):

# Problem size
m, n, k = args.sizes if args.sizes else (4096, 4096, 4096)
# Get default parameters from the database
try:
params = xegpu_parameter_selector.get_matmul_parameters(m, n, k)
except ValueError:
# Initialize with a stub and assume the rest will be populated
params = {
"m": m,
"n": n,
"k": k,
"wg_m": None,
"wg_n": None,
"sg_m": None,
"sg_n": None,
"k_tile": None,
"load_a_m": None,
"load_a_k": None,
"load_b_k": None,
"load_b_n": None,
"prefetch_a_m": None,
"prefetch_a_k": None,
"prefetch_b_k": None,
"prefetch_b_n": None,
"prefetch_a_nb": None,
"prefetch_b_nb": None,
}
# Set required parameters
params = {
"m": m,
"n": n,
"k": k,
}
if args.target:
params["device"] = args.target

if args.json:
# Override parameters with values from JSON file if provided
with open(args.json, "r") as f:
Expand Down
12 changes: 10 additions & 2 deletions examples/xegpu/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
generate_gpu_mlp_payload,
get_mlir_elem_type,
)
from lighthouse.schedule.xegpu import xegpu_parameter_selector

from matmul import matmul_complexity

Expand Down Expand Up @@ -332,6 +331,11 @@ def parse_cli():
action="store_true",
help="Dump transform schedule.",
)
parser.add_argument(
"--target",
choices=["B70", "B50"],
help="Target GPU device, e.g., B70.",
)
parser.add_argument(
"--verbose",
"-v",
Expand Down Expand Up @@ -371,7 +375,11 @@ def parse_cli():
ab_type = wload.ab_type
acc_type = wload.acc_type

params = xegpu_parameter_selector.get_parameters_for_layers(matmuls)
# Initialize layer parameters
params = [{"m": M, "n": N, "k": K} for M, N, K in matmuls]
if args.target:
for layer_params in params:
layer_params["device"] = args.target

if args.dump_kernel or args.dump_schedule:
pipeline = TransformDriver(
Expand Down
59 changes: 28 additions & 31 deletions examples/xegpu/torch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from lighthouse import schedule as lh_schedule
from lighthouse.pipeline.driver import TransformDriver
from lighthouse.utils.mlir import get_mlir_library_path
from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary
from lighthouse.schedule.xegpu import (
mlp_schedule,
xegpu_to_binary,
)
from lighthouse.ingress.torch import gpu_backend, TargetDialect

import parameter_selector


class Model(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -139,9 +140,14 @@ def parse_cli_args(description):
help="Tile size for cooperative prefetching of subgroup B matrix",
)
parser.add_argument(
"--prefetch-nb",
"--prefetch-a-nb",
type=int,
help="Number of initial prefetches for A matrix.",
)
parser.add_argument(
"--prefetch-b-nb",
type=int,
help="Number of initial prefetches.",
help="Number of initial prefetches for B matrix.",
)
parser.add_argument(
"--check-result",
Expand All @@ -164,6 +170,11 @@ def parse_cli_args(description):
"--json",
help="Read problem sizes and tile parameters from a JSON file.",
)
parser.add_argument(
"--target",
choices=["B70", "B50"],
help="Target GPU device, e.g., B70.",
)
args = parser.parse_args()

return args
Expand All @@ -182,30 +193,14 @@ def parse_cli_args(description):

# Problem size
m, n, k = args.sizes if args.sizes else (4096, 4096, 4096)
# Get default parameters from the database
try:
params = parameter_selector.get_matmul_parameters(m, n, k)
except ValueError:
# Initialize with a stub and assume the rest will be populated
params = {
"m": m,
"n": n,
"k": k,
"wg_m": None,
"wg_n": None,
"sg_m": None,
"sg_n": None,
"k_tile": None,
"load_a_m": None,
"load_a_k": None,
"load_b_k": None,
"load_b_n": None,
"prefetch_a_m": None,
"prefetch_a_k": None,
"prefetch_b_k": None,
"prefetch_b_n": None,
"prefetch_nb": None,
}
# Set required parameters
params = {
"m": m,
"n": n,
"k": k,
}
Comment thread
tkarna marked this conversation as resolved.
if args.target:
params["device"] = args.target
if args.json:
# Override parameters with values from JSON file if provided
with open(args.json, "r") as f:
Expand All @@ -227,8 +222,10 @@ def parse_cli_args(description):
params["prefetch_a_m"], params["prefetch_a_k"] = args.prefetch_tile_a
if args.prefetch_tile_b:
params["prefetch_b_k"], params["prefetch_b_n"] = args.prefetch_tile_b
if args.prefetch_nb is not None:
params["prefetch_nb"] = args.prefetch_nb
if args.prefetch_a_nb is not None:
params["prefetch_a_nb"] = args.prefetch_a_nb
if args.prefetch_b_nb is not None:
params["prefetch_b_nb"] = args.prefetch_b_nb

for param_key, v in params.items():
if v is None:
Expand Down
Loading
Loading