Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 49 additions & 31 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,17 @@ def __init__(self, kernel_group, reason=None):
self.header.writeline(" return p;")
self.header.writeline("}")
self.header.writeline("void __wrap_free(void *ptr) { return; }")
self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad")
self.apply_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="apply")
self.mask_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="mask")
self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter")
self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init")
self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec")
self.const_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="const")
self.alloc_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="alloc")
self.indexed_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="indexed_op")
self.map_cse = common.CSE("#", self.suffix, name_prefix="map")
self.reduction_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
self.spad_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="spad")
self.apply_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="apply")
self.mask_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="mask")
self.iterator_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="iter")
self.init_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="init")
self.init_vec_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="init_vec")
self.const_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="const")
self.alloc_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="alloc")
self.indexed_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="indexed_op")
self.map_cse = mlir_common.MLIRCSE("#", self.suffix, name_prefix="map")
self.global_vars_dict = dict()
self.reduction_vars = dict()
self.consts = dict()
Expand Down Expand Up @@ -549,7 +549,12 @@ def load(self, name: str, index: sympy.Expr):
else:
# FIXME. Any good idea?
out = sram_var
self.register_var_info(out, [compute_vec_size, mlir_dtype])
# `out` is the spad memref reference (an MLIRCSEVariable from
# spad_cse.generate). Annotate it with the load's compute-vec
# size and torch dtype so downstream attribute reads (vec_size,
# mlir_dtype) reflect the load shape.
out.vec_size = compute_vec_size
out.dtype = dtype
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
return out

Expand Down Expand Up @@ -593,11 +598,11 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs)
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index)
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
# Generate vector store instruction
_, operand_type = self.var_info[value]
_, operand_type = value.vec_size, value.mlir_dtype
if mlir_dtype != operand_type:
value = ops.to_dtype(value, mlir_dtype)

if compute_vec_size < self.var_info[value][0]:
if compute_vec_size < value.vec_size:
with self.override_buffer_cse(buffer=self.stores):
value = ops.extract_strided_slice(value, compute_vec_size)

Expand Down Expand Up @@ -644,14 +649,20 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
init = self.get_const_cse(reduction_init(reduction_type, dtype), type_name)
init_vec = init if vec_len == 1 else ops.broadcast(init, vec_len)

# The outermost acc (reduction_depth == 0) carries the final reduced
# shape; inner accumulators stay at default vec_size=1 until lowered.
outer_reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel()
acc_var_list = []
iter_var_list = []
for reduction_depth in range(self.get_nr_rdim()):
# Create reduction key
reduction_key = src_dtype, reduction_type, value, reduction_depth
acc_init_var = init_vec if reduction_depth == 0 else iter_var_list[-1]

acc = self.reduction_cse.generate(self.loads, f"reduction {reduction_key}", write=False)
acc = self.reduction_cse.generate(
self.loads, f"reduction {reduction_key}",
write=False, dtype=dtype, vec_size=outer_reduction_size,
)
iterator = self.iterator_cse.generate(self.loads, f"reduction {reduction_key}", write=False)
acc_var_list.append(acc)
iter_var_list.append(iterator)
Expand All @@ -664,8 +675,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
# Note: reduction body is inner most loop body. So it doesn't need reduction depth.
body_key = src_dtype, reduction_type, value
body_acc = self.reduction_cse.generate(self.compute, f"reduction {body_key}body_acc", write=False)
body_iter_arg = self.iterator_cse.generate(self.compute, f"reduction {body_key}body_iter_arg", write=False)
self.register_var_info(body_iter_arg, [vec_len, type_name])
body_iter_arg = self.iterator_cse.generate(
self.compute, f"reduction {body_key}body_iter_arg",
write=False, dtype=dtype, vec_size=vec_len,
)
acc_var_list.append(body_acc)

# Reduction body codegen
Expand All @@ -683,9 +696,8 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
self.affine_yield[acc] = reduced_shape, reduction_depth

# Final reduction
reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel()
acc = acc_var_list[0] # Set outermost acc var
self.register_var_info(acc, [reduction_size, type_name])
reduction_size = outer_reduction_size # already attached to acc_var_list[0]
acc = acc_var_list[0] # outermost acc var (already typed at creation)
assert(vec_len % reduction_size==0)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here


# Prepare init value
Expand Down Expand Up @@ -794,7 +806,7 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):

with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse):
vlane_offset = ops.vlane_offset(vlane_vec, vlane_vec, attributes={"vlane_offset": offset}, comment="vlane offset")
if compute_vec_size < self.var_info[vlane_offset][0]:
if compute_vec_size < vlane_offset.vec_size:
vlane_offset = ops.extract_strided_slice(vlane_offset, compute_vec_size)
vlane_offset = ops.index_cast(vlane_offset, "index")
dim = ops.add(dim, vlane_offset)
Expand Down Expand Up @@ -874,7 +886,7 @@ def index_expr(self, index, dtype):

# Initialize base vector
if not self.base_vector_initialized:
init_iter = self.register_var_cse("init_iter", 1, "index")
init_iter = self.cse.namedvar("init_iter", dtype=mlir_common.INDEX_DTYPE)
parallel_map = f"affine.parallel (%{init_iter}) = ({0}) to ({compute_vec_size}) {{ // Base vector initializer"
self.spad_buffer.writeline(parallel_map)
with self.spad_buffer.indent():
Expand Down Expand Up @@ -1479,8 +1491,12 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable:
value = int(value)
key = str(value)+dtype
if key not in self.consts:
self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}")
self.register_var_info(self.consts[key], [1, dtype])
# MLIR_TO_DTYPE maps "index" -> INDEX_DTYPE sentinel (not
# torch.int64, which would make mlir_dtype derive to "i64").
self.consts[key] = self.const_cse.generate(
self.const_buffer, f"arith.constant {value} : {dtype}",
dtype=mlir_common.MLIR_TO_DTYPE.get(dtype),
)
return self.consts[key]

def get_tag_cse(self, value=None, shape="memref<1xi32>"):
Expand Down Expand Up @@ -1531,15 +1547,17 @@ def convert_indirect_indexing(self, index :sympy.Expr):
if target_dim in self.spad_buffer_dict:
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
else:
# FIXME.
var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0]
dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]
# Issue #238: read torch dtype directly from the csevar's attribute
# rather than round-tripping the MLIR string through MLIR_TO_DTYPE
# (which silently downcasts bool/uint8 to int8).
csevar = self.cse.varname_map[target_dim]
dtype = csevar.dtype

local_tile_desc = self.kernel_group.tile_desc
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
tile_shape = local_tile_desc.get_mlir_shape(var_info[1])
tile_shape = local_tile_desc.get_mlir_shape(csevar.mlir_dtype)
tile_vec = local_tile_desc.get_compute_vec_size()
vshape = f"vector<{var_info[0]}x{var_info[1]}>"
vshape = f"vector<{csevar.vec_size}x{csevar.mlir_dtype}>"
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim)
self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]

Expand All @@ -1559,7 +1577,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
if "tmp" not in str(arg):
continue
if arg.is_Mul and arg.args[0].is_number:
coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1]
coeff_dtype = spad_vars[str(arg.args[1])].mlir_dtype
coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype)
spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff)
index = index.replace(arg, 0)
Expand All @@ -1577,7 +1595,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute...

# Conversion
mlir_dtype = self.var_info[spad_vars[first_dim]][1]
mlir_dtype = spad_vars[first_dim].mlir_dtype
with self.override_buffer_cse(buffer=target_dma_buffers):
out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape)
if mlir_dtype != "index":
Expand Down
Loading
Loading