Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions experimental/builder/include/ck_tile/builder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
{ t.data_type } -> std::convertible_to<DataType>; // Default data type
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
};
```

**Properties:**
- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
- **`direction`**: Operation type (optional, defaults to FORWARD)
- **`direction`**: Operation type (Optional, defaults to FORWARD)
- `FORWARD`: Standard forward convolution
- `BACKWARD_DATA`: Gradient computation w.r.t. input
- `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors)
- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors)
- **`accumulation_data_type`**: Type used for internal accumulation

#### 2. Tensor Level
Expand All @@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) {

A tensor descriptor encapsulates:
- **Configuration**: Layout and data type information
- **Operation** (optional): Fused elementwise operations on this tensor
- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor)

#### 3. Tensor Configuration

Expand All @@ -126,7 +128,7 @@ Describes the memory layout and data types:
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<ConvLayout>;
{ t.data_type } -> std::convertible_to<DataType>; // Optional override
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
};
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ concept ConvOutputLayout3D =
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);

namespace detail {
template <typename T>
concept HasDataType = requires(T t) {
{ t.data_type };
Expand All @@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) {
};
};

} // namespace detail
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<TensorLayout>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
};

template <typename T>
Expand All @@ -116,7 +118,6 @@ template <typename T, std::size_t N>
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
{
};
} // namespace detail

template <typename T>
concept ConvertibleToArrayOfTensorConfigs =
Expand All @@ -128,18 +129,21 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
{ t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
};
};
} // namespace detail

template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};

template <typename T>
concept HasTensorOp = requires(T t) {
{ t.operation };
};

namespace detail {

template <typename T>
concept HasConvolutionDirection = requires(T t) {
{ t.direction };
Expand All @@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
};
};

} // namespace detail

// Concept for the convolution tensor
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
{ t.config } -> TensorConfigDescriptor;
requires ElementwiseOpWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};

template <typename T>
Expand All @@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) {
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::ConvolutionDirectionWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};

// Concept to validate a convolution signature's values.
Expand Down