Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit dc7ef09

Browse files
committed
first contact
1 parent 9ec3005 commit dc7ef09

File tree

9 files changed

+164
-7
lines changed

9 files changed

+164
-7
lines changed

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,9 @@
1313
[submodule "third_party/xtl"]
1414
path = third_party/xtl
1515
url = https://github.com/xtensor-stack/xtl
16+
[submodule "third_party/mlir-extensions"]
17+
path = third_party/mlir-extensions
18+
url = https://github.com/intel/mlir-extensions
19+
[submodule "third_party/dpcomp"]
20+
path = third_party/dpcomp
21+
url = https://github.com/IntelPython/dpcomp.git

CMakeLists.txt

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
cmake_minimum_required(VERSION 3.18.2)
22
project(ddptensor VERSION 1.0)
33

4+
set(LLVM_PATH ${PROJECT_SOURCE_DIR}/mlir-llvm)
5+
if(DEFINED ENV{CONDA_PREFIX})
6+
set(CONDA_PREFIX $ENV{CONDA_PREFIX})
7+
else()
8+
set(CONDA_PREFIX UNSET)
9+
endif()
10+
11+
if(DEFINED ENV{MKLROOT})
12+
set(MKLROOT ENV{MKLROOT})
13+
else()
14+
set(MKLROOT ${CONDA_PREFIX})
15+
endif()
16+
if(DEFINED ENV{TBBROOT})
17+
set(TBBROOT ENV{TBBROOT})
18+
else()
19+
set(TBBROOT ${CONDA_PREFIX})
20+
endif()
21+
22+
if(MKLROOT STREQUAL UNSET OR TBBROOT STREQUAL UNSET)
23+
message(FATAL_ERROR "MKLROOT and TBBROOT not set (nor CONDA_PREFIX)")
24+
endif()
25+
426
# C++ standard
527
set(CMAKE_CXX_STANDARD 17)
628
set(CMAKE_C_EXTENSIONS OFF)
@@ -13,14 +35,23 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
1335
# ===============
1436
# Deps
1537
# ===============
38+
list(APPEND CMAKE_PREFIX_PATH ${PROJECT_SOURCE_DIR}/third_party/mlir-llvm)
1639

1740
# Find Python3 and NumPy
1841
find_package(Python3 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
1942
find_package(pybind11 CONFIG)
2043
find_package(MPI REQUIRED)
44+
find_package(LLVM REQUIRED CONFIG)
45+
find_package(MLIR REQUIRED CONFIG)
46+
47+
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
48+
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
49+
include(AddLLVM)
50+
include(AddMLIR)
51+
2152
#find_package(OpenMP)
2253

23-
set(MKL_LIBRARIES -L$ENV{MKLROOT}/lib -lmkl_intel_lp64 -lmkl_tbb_thread -lmkl_core -ltbb -lpthread -lrt -ldl -lm)
54+
set(MKL_LIBRARIES -L${MKLROOT}/lib -lmkl_intel_lp64 -lmkl_tbb_thread -lmkl_core -ltbb -lpthread -lrt -ldl -lm)
2455
#set(CMAKE_INSTALL_RPATH $ENV{MKLROOT}/lib)
2556
# Use -fPIC even if statically compiled
2657
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@@ -38,7 +69,7 @@ add_custom_command(
3869
# Target
3970
# ============
4071
FILE(GLOB MyCppSources ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/include/ddptensor/*.hpp)
41-
set(MyCppSources ${MyCppSources} ${P2C_HPP})
72+
set(MyCppSources ${MyCppSources} ${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp ${P2C_HPP})
4273

4374
pybind11_add_module(_ddptensor MODULE ${MyCppSources})
4475

@@ -50,7 +81,27 @@ target_include_directories(_ddptensor PRIVATE
5081
${PROJECT_SOURCE_DIR}/third_party/xtensor-blas/include
5182
${PROJECT_SOURCE_DIR}/third_party/xtensor/include
5283
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
53-
${MPI_INCLUDE_PATH} $ENV{MKLROOT}/include
54-
${pybind11_INCLUDE_DIRS})
84+
${MPI_INCLUDE_PATH}
85+
$ENV{MKLROOT}/include
86+
$ENV{TBBROOT}/include
87+
${pybind11_INCLUDE_DIRS}
88+
${MLIR_INCLUDE_DIRS})
89+
5590
#target_compile_options(_ddptensor PRIVATE -fopenmp)
56-
target_link_libraries(_ddptensor PRIVATE ${MPI_C_LIBRARIES} ${MKL_LIBRARIES})
91+
target_link_libraries(_ddptensor PRIVATE
92+
${MPI_C_LIBRARIES}
93+
${MKL_LIBRARIES}
94+
LLVM${LLVM_NATIVE_ARCH}CodeGen
95+
LLVM${LLVM_NATIVE_ARCH}Desc
96+
LLVMTarget
97+
MLIRIR
98+
MLIRLLVMIR
99+
MLIRLLVMToLLVMIRTranslation
100+
MLIRTransforms
101+
MLIRFuncTransforms
102+
MLIRLinalgTransforms
103+
MLIRLinalgToLLVM
104+
MLIRMathToLLVM
105+
MLIRMathToLibm
106+
MLIRTensorTransforms
107+
MLIRReconcileUnrealizedCasts)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def build_cmake(self, ext):
3838
# example of build args
3939
build_args = [
4040
'--config', config,
41-
#'--', '-j8'
41+
'--', '-j8'
4242
]
4343

4444
os.chdir(str(build_temp))

src/Deferred.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
13
#include <oneapi/tbb/concurrent_queue.h>
24
#include "include/ddptensor/Deferred.hpp"
35
#include "include/ddptensor/Transceiver.hpp"

src/EWBinOp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
13
#include <xtensor/xview.hpp>
24
using namespace xt::placeholders;
35
#include "ddptensor/EWBinOp.hpp"

src/ddptensor.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ using namespace pybind11::literals; // to bring _a
3737
#include "ddptensor/Factory.hpp"
3838
#include "ddptensor/IO.hpp"
3939

40+
extern void ttt();
41+
4042
// #########################################################################
4143
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
4244

@@ -134,7 +136,8 @@ PYBIND11_MODULE(_ddptensor, m) {
134136
.def("_get_slice", &GetItem::get_slice)
135137
.def("_get_local", &GetItem::get_local)
136138
.def("_gather", &GetItem::gather)
137-
.def("to_numpy", &IO::to_numpy);
139+
.def("to_numpy", &IO::to_numpy)
140+
.def("ttt", &ttt);
138141

139142
py::class_<Creator>(m, "Creator")
140143
.def("create_from_shape", &Creator::create_from_shape)

src/jit/mlir.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
3+
#include "mlir/IR/MLIRContext.h"
4+
#include "mlir/InitAllDialects.h"
5+
6+
static mlir::Type makeSignlessType(mlir::Type type)
7+
{
8+
if (auto shaped = type.dyn_cast<mlir::ShapedType>()) {
9+
auto origElemType = shaped.getElementType();
10+
auto signlessElemType = makeSignlessType(origElemType);
11+
return shaped.clone(signlessElemType);
12+
} else if (auto intType = type.dyn_cast<mlir::IntegerType>()) {
13+
if (!intType.isSignless())
14+
return mlir::IntegerType::get(intType.getContext(), intType.getWidth());
15+
}
16+
return type;
17+
}
18+
19+
auto getInt(const mlir::Location & loc, mlir::OpBuilder & builder, int64_t val)
20+
{
21+
auto attr = builder.getI64IntegerAttr(val);
22+
return builder.create<mlir::arith::ConstantOp>(loc, attr);
23+
// auto intType = builder.getIntegerType(64, true);
24+
// return builder.create<plier::SignCastOp>(loc, intType, res);
25+
}
26+
27+
void ttt()
28+
{
29+
std::vector<int> shape = {16, 16};
30+
std::string fname("ttt_mlir");
31+
32+
mlir::MLIRContext context;
33+
context.getOrLoadDialect<mlir::arith::ArithmeticDialect>();
34+
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
35+
mlir::OpBuilder builder(&context);
36+
auto theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
37+
auto loc = builder.getUnknownLoc();
38+
39+
// Create a func prototype
40+
llvm::SmallVector<mlir::Type, 4> argTypes(0);
41+
auto funcType = builder.getFunctionType(argTypes, llvm::None);
42+
auto fproto = mlir::FuncOp::create(loc, fname, funcType);
43+
44+
// Create an MLIR function for the given prototype.
45+
mlir::FuncOp function(fproto);
46+
assert(function);
47+
48+
// Let's start the body of the function now!
49+
// In MLIR the entry block of the function is special: it must have the same
50+
// argument list as the function itself.
51+
auto &entryBlock = *function.addEntryBlock();
52+
53+
// Set the insertion point in the builder to the beginning of the function
54+
// body, it will be used throughout the codegen to create operations in this
55+
// function.
56+
builder.setInsertionPointToStart(&entryBlock);
57+
58+
auto elemType = builder.getF64Type();
59+
auto signlessElemType = makeSignlessType(elemType);
60+
auto indexType = builder.getIndexType();
61+
auto count = shape.size();
62+
llvm::SmallVector<mlir::Value> shapeVal(count);
63+
llvm::SmallVector<int64_t> staticShape(count); // mlir::ShapedType::kDynamicSize);
64+
65+
for(auto it : llvm::enumerate(shape)) {
66+
auto i = it.index();
67+
auto elem = it.value();
68+
auto elemVal = getInt(loc, builder, elem);
69+
staticShape[i] = elem;
70+
shapeVal[i] = elemVal;
71+
}
72+
73+
mlir::Value init;
74+
if(true) { //initVal.is_none()) {
75+
init = builder.create<mlir::linalg::InitTensorOp>(loc, shapeVal, signlessElemType);
76+
}// else {
77+
// auto val = doCast(builder, loc, ctx.context.unwrapVal(loc, builder, initVal), signlessElemType);
78+
// llvm::SmallVector<int64_t> shape(count, mlir::ShapedType::kDynamicSize);
79+
// auto type = mlir::RankedTensorType::get(shape, signlessElemType);
80+
// auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) {
81+
// builder.create<mlir::tensor::YieldOp>(loc, val);
82+
// };
83+
// init = builder.create<mlir::tensor::GenerateOp>(loc, type, shapeVal, body);
84+
// }
85+
if (llvm::any_of(staticShape, [](auto val) { return val >= 0; })) {
86+
auto newType = mlir::RankedTensorType::get(staticShape, signlessElemType);
87+
init = builder.create<mlir::tensor::CastOp>(loc, newType, init);
88+
}
89+
auto resTensorTypeSigness = init.getType().cast<mlir::RankedTensorType>();
90+
auto resTensorType = mlir::RankedTensorType::get(resTensorTypeSigness.getShape(), elemType, resTensorTypeSigness.getEncoding());
91+
}

third_party/dpcomp

Submodule dpcomp added at d3e9b99

third_party/mlir-extensions

Submodule mlir-extensions added at 1637cca

0 commit comments

Comments
 (0)