Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions examples/end-to-end/KernelBench/test_kernel_bench.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# RUN: python %s --ci | FileCheck %s
# RUN: python %s --ci --torch-compile | FileCheck %s

# REQUIRES: torch
# REQUIRES: kernel_bench
Expand Down Expand Up @@ -115,6 +116,11 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
action=argparse.BooleanOptionalAction,
help="Enable CI mode (faster run, fewer kernels). Incompatible with --smoke-test.",
)
Parser.add_argument(
"--torch-compile",
action=argparse.BooleanOptionalAction,
help="Enable TorchScript compilation. Default is False.",
)
Parser.add_argument(
"--test",
type=str,
Expand Down Expand Up @@ -151,23 +157,40 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
command_line = [
str(kb_program),
str(kb_kernel),
"--input-shapes",
test["input_shapes"],
"--output-shape",
test["output_shape"],
"--pipeline",
test["pipeline"],
"--print-tensor=1",
"--seed=42",
]
# Benchmarks only if there's data to calculate FLOPS.
benchmark = args.benchmark and test.get("gflops") is not None
if benchmark:
command_line += ["--benchmark"]

# We allow torch.compile to pick its own shapes (unless it's CI).
if args.torch_compile:
command_line += ["--torch-compile"]

# TODO: Implement auto-shapes for non-compile mode as well.
if args.ci or not args.torch_compile:
command_line += [
"--input-shapes",
test["input_shapes"],
"--output-shape",
test["output_shape"],
]

# Smoke tests / CI don't print outputs.
if not args.smoke_test and not args.ci:
command_line += ["--print-output"]

# For debugging, prefer not to capture output.
if args.print_mlir_after_all:
command_line += ["--print-mlir-after-all"]

# Print out before we run the test.
if test.get("warning"):
print(f"WARNING: {test['warning']}")
print(f"Running command: {' '.join(command_line)}")
print(f"Running command: {' '.join(command_line)}", flush=True)

# While debugging kernels, it's useful to see the output as it comes.
# Note: GFLOPS can't be shown if the output is not captured.
Expand All @@ -193,29 +216,24 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
print("STDERR:")
print(result.stderr)

print(f"Return code: {result.returncode}")
print(f"Return code: {result.returncode}", flush=True)

# Only stop on failure on normal runs.
# Smoke tests try to run as much as possible.
if not args.smoke_test:
assert result.returncode == 0, "Execution failed"

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK: 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}}
# CHECK: 0.7201{{.*}} 0.9926{{.*}} 0.1208{{.*}} ... 0.1742{{.*}} 0.3485{{.*}} 0.6436{{.*}}
# CHECK: 1_Square_matrix_multiplication_.py
# CHECK: Success: The output of the compiled model matches the reference output.

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 249.78{{.*}} 260.13{{.*}} 249.36{{.*}} ... 261.10{{.*}} 260.49{{.*}} 257.09{{.*}}
# CHECK: 243.56{{.*}} 250.91{{.*}} 252.38{{.*}} ... 260.40{{.*}} 261.56{{.*}} 256.24{{.*}}
# CHECK: 2_Standard_matrix_multiplication_.py
# CHECK: Success: The output of the compiled model matches the reference output.

# CHECK: 3_Batched_matrix_multiplication.mlir
# CHECK: 5.2403{{.*}} 7.7905{{.*}} 6.0769{{.*}} ... 7.8579{{.*}} 6.8890{{.*}} 6.6193{{.*}}
# CHECK: 9.0407{{.*}} 6.3299{{.*}} 5.2003{{.*}} ... 6.2594{{.*}} 6.2980{{.*}} 5.9807{{.*}}
# CHECK: 3_Batched_matrix_multiplication.py
# CHECK: Success: The output of the compiled model matches the reference output.

# CHECK: 4_Matrix_vector_multiplication_.mlir
# CHECK: 264.86{{.*}}
# CHECK: 265.12{{.*}}
# CHECK: 4_Matrix_vector_multiplication_.py
# CHECK: Success: The output of the compiled model matches the reference output.

# CHECK: 5_Matrix_scalar_multiplication.mlir
# CHECK: 0.1750{{.*}} 0.4442{{.*}} 0.3420{{.*}} ... 0.1389{{.*}} 0.4319{{.*}} 0.4538{{.*}}
# CHECK: 0.3365{{.*}} 0.4638{{.*}} 0.0564{{.*}} ... 0.0814{{.*}} 0.1628{{.*}} 0.3007{{.*}}
# CHECK: 5_Matrix_scalar_multiplication.py
# CHECK: Success: The output of the compiled model matches the reference output.
3 changes: 2 additions & 1 deletion lighthouse/ingress/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Provides functions to convert PyTorch models to MLIR."""

from .importer import import_from_file, import_from_model
from .importer import import_from_file, import_from_model, import_model
from .compile import cpu_backend
from .compile import gpu_backend
from .compile import TargetDialect
Expand All @@ -11,4 +11,5 @@
"gpu_backend",
"import_from_file",
"import_from_model",
"import_model",
]
231 changes: 142 additions & 89 deletions lighthouse/ingress/torch/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,134 @@
from mlir import ir


def import_model(
filepath: str | Path,
model_class_name: str = "Model",
init_args_fn_name: str | None = "get_init_inputs",
init_kwargs_fn_name: str | None = None,
model_init_args: Iterable | None = None,
sample_args_fn_name: str = "get_inputs",
sample_kwargs_fn_name: str | None = None,
sample_args: Iterable | None = None,
state_path: str | Path | None = None,
**kwargs,
) -> str | ir.Module:
"""Load a PyTorch nn.Module from a file.

The function takes a `filepath` to a Python file containing the model definition,
along with the names of functions to get model init arguments and sample inputs.
The function imports the model class on its own and instantiates it.

Args:
filepath (str | Path): Path to the Python file containing the model definition.
model_class_name (str, optional): The name of the model class in the file.
Defaults to "Model".
init_args_fn_name (str | None, optional): The name of the function in the file
that returns the arguments for initializing the model. If None, the model
is initialized without arguments. Defaults to "get_init_inputs".
init_kwargs_fn_name (str | None, optional): The name of the function in the file
that returns the keyword arguments for initializing the model. If None, the model
is initialized without keyword arguments.
model_init_args (Iterable | None, optional): If provided, these are used directly as
initialization arguments instead of calling ``init_args_fn_name`` from the file.
Useful for overriding hard-coded sizes in the model file. Defaults to None.
sample_args_fn_name (str, optional): The name of the function in the file that
returns the sample input arguments for the model. Defaults to "get_inputs".
sample_kwargs_fn_name (str, optional): The name of the function in the file that
returns the sample keyword input arguments for the model. Defaults to None.
sample_args (Iterable | None, optional): If provided, these are used directly as
sample inputs instead of calling ``sample_args_fn_name`` from the file.
Useful for overriding hard-coded sizes in the model file. Defaults to None.
state_path (str | Path | None, optional): Optional path to a file containing
the model's ``state_dict``. Defaults to None.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.

Returns:
torch.nn.Module: The imported PyTorch model.
sample_args: The sample input arguments for the model.
sample_kwargs: The sample keyword input arguments for the model.

Examples:
Given a file `path/to/model_file.py` with the following content:
```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)

def forward(self, x):
return self.fc(x)


def get_inputs():
return (torch.randn(1, 10),)
```

The import script would look like:
>>> from lighthouse.ingress.torch_import import import_model
>>> # option 1: get MLIR module as a string
>>> model: nn.Module = import_model(
... "path/to/model_file.py",
... model_class_name="MyModel",
... init_args_fn_name=None,
... )
"""
if isinstance(filepath, str):
filepath = Path(filepath)
module_name = filepath.stem

spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

model = getattr(module, model_class_name, None)
if model is None:
raise ValueError(f"Model class '{model_class_name}' not found in {filepath}")

model_init_args = (
maybe_load_and_run_callable(
module,
init_args_fn_name,
default=tuple(),
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
)
if model_init_args is None
else model_init_args
)
model_init_kwargs = maybe_load_and_run_callable(
module,
init_kwargs_fn_name,
default={},
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
)
sample_args = (
load_and_run_callable(
module,
sample_args_fn_name,
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
)
if sample_args is None
else sample_args
)
sample_kwargs = maybe_load_and_run_callable(
module,
sample_kwargs_fn_name,
default={},
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
)

nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)
if state_path is not None:
state_dict = torch.load(state_path)
nn_model.load_state_dict(state_dict)

return nn_model, sample_args, sample_kwargs


def import_from_model(
model: nn.Module,
sample_args: Iterable,
Expand Down Expand Up @@ -118,10 +246,8 @@ def import_from_file(
) -> str | ir.Module:
"""Load a PyTorch nn.Module from a file and import it into MLIR.

The function takes a `filepath` to a Python file containing the model definition,
along with the names of functions to get model init arguments and sample inputs.
The function imports the model class on its own, instantiates it, and passes
it to ``torch_mlir`` to get a MLIR module in the specified `dialect`.
The function calls ``import_model`` to load the model from the given file
and then calls ``import_from_model`` to convert it into an MLIR module.

Args:
filepath (str | Path): Path to the Python file containing the model definition.
Expand Down Expand Up @@ -156,94 +282,21 @@ def import_from_file(
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.

Examples:
Given a file `path/to/model_file.py` with the following content:
```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)

def forward(self, x):
return self.fc(x)


def get_inputs():
return (torch.randn(1, 10),)
```

The import script would look like:
>>> from lighthouse.ingress.torch_import import import_from_file
>>> # option 1: get MLIR module as a string
>>> mlir_module: str = import_from_file(
... "path/to/model_file.py",
... model_class_name="MyModel",
... init_args_fn_name=None,
... dialect="linalg-on-tensors",
... )
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
>>> # option 2: get MLIR module as an ir.Module
>>> ir_context = ir.Context()
>>> mlir_module_ir: ir.Module = import_from_file(
... "path/to/model_file.py",
... model_class_name="MyModel",
... init_args_fn_name=None,
... dialect="linalg-on-tensors",
... ir_context=ir_context,
... )
See ``import_model`` and ``import_from_model`` for examples
of the expected content of the model file and how to call this function.
"""
if isinstance(filepath, str):
filepath = Path(filepath)
module_name = filepath.stem

spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

model = getattr(module, model_class_name, None)
if model is None:
raise ValueError(f"Model class '{model_class_name}' not found in {filepath}")

model_init_args = (
maybe_load_and_run_callable(
module,
init_args_fn_name,
default=tuple(),
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
)
if model_init_args is None
else model_init_args
)
model_init_kwargs = maybe_load_and_run_callable(
module,
init_kwargs_fn_name,
default={},
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
)
sample_args = (
load_and_run_callable(
module,
sample_args_fn_name,
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
)
if sample_args is None
else sample_args
)
sample_kwargs = maybe_load_and_run_callable(
module,
sample_kwargs_fn_name,
default={},
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
nn_model, sample_args, sample_kwargs = import_model(
filepath=filepath,
model_class_name=model_class_name,
init_args_fn_name=init_args_fn_name,
init_kwargs_fn_name=init_kwargs_fn_name,
model_init_args=model_init_args,
sample_args_fn_name=sample_args_fn_name,
sample_kwargs_fn_name=sample_kwargs_fn_name,
sample_args=sample_args,
state_path=state_path,
)

nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)
if state_path is not None:
state_dict = torch.load(state_path)
nn_model.load_state_dict(state_dict)

return import_from_model(
nn_model,
sample_args=sample_args,
Expand Down
Loading
Loading