23 changes: 23 additions & 0 deletions .gitignore
@@ -4,3 +4,26 @@
.neptune
*__pycache__
tools/*/build
.environment
*.csv
*.png
mlir-venv/
iql_results/
llvm-project/
*.json
*.log
checkpoints/
cache/
data/
results/
offline_iql_adv_norm_gradclip_cosine_scheduler/
offline_dataset/
tmp/*

iql_online_fine_tuning/
offline_iql_results_1/
online_finetuning_iql/
online_iql/
online_ppo/

viz/
149 changes: 59 additions & 90 deletions README.md
100644 → 100755
@@ -1,90 +1,59 @@
## Getting Started
To get a local copy of the project up and running, follow the steps below.
### Prerequisites:
###### Required
1) [CMake](https://cmake.org/): version 3.20 or greater.
2) [Ninja](https://ninja-build.org/).
3) [GCC](https://gcc.gnu.org/): version 13.2.
4) [G++](https://gcc.gnu.org/): version 13.2.
5) [LLD](https://lld.llvm.org/).
6) [Python](https://www.python.org/downloads/): version 3.11 or greater.
### Setup
#### 1. Building MLIR:
```sh
git clone --depth 1 -b release/19.x https://github.com/llvm/llvm-project.git
cd llvm-project
cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
-DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON \
-DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DLLVM_ENABLE_LLD=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON

cmake --build build --target check-mlir
```
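Once the build finishes, you can optionally smoke-test the MLIR Python bindings enabled by `-DMLIR_ENABLE_BINDINGS_PYTHON=ON`. A minimal sketch, assuming the default build layout (`build/tools/mlir/python_packages/mlir_core`):
```python
import sys

# Assumed default location of the Python bindings inside the MLIR build tree.
sys.path.append("llvm-project/build/tools/mlir/python_packages/mlir_core")

from mlir.ir import Context, Module  # noqa: E402

with Context():
    # Parsing an empty module is enough to confirm the bindings load correctly.
    print(Module.parse("module {}"))
```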
#### 2. Install Python requirements:
```sh
pip install -r requirements.txt
```
#### 3. Set up environment variables:
Set the following variables in an environment file, adjusting the LLVM-related paths to match your llvm-project folder.
```env
NEPTUNE_PROJECT=<NEPTUNE_PROJECT_URL>
NEPTUNE_TOKEN=<NEPTUNE_API_TOKEN>
LLVM_BUILD_PATH=llvm-project/build
MLIR_SHARED_LIBS=llvm-project/build/lib/libomp.so,llvm-project/build/lib/libmlir_c_runner_utils.so,llvm-project/build/lib/libmlir_runner_utils.so
AST_DUMPER_BIN_PATH=tools/ast_dumper/build/bin/AstDumper
VECTORIZER_BIN_PATH=tools/vectorizer/build/bin/Vectorizer
```
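These variables are loaded from an environment file at runtime; `create_dataset.py` further down in this diff does so with `python-dotenv`. A minimal sketch, assuming the variables above are stored in a `.env` file at the project root (adjust the filename if your setup differs):
```python
import os

from dotenv import load_dotenv  # pip install python-dotenv

# Load variables from a local .env file into os.environ.
load_dotenv(override=True)

llvm_build_path = os.environ["LLVM_BUILD_PATH"]
mlir_shared_libs = os.environ["MLIR_SHARED_LIBS"].split(",")
print(llvm_build_path, mlir_shared_libs)
```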
### Documentation
#### 1. Jobs
Example Slurm scripts for running jobs are provided in the `scripts/` folder.
#### 2. Configuration
The model is configured for a specific use case through a JSON config file containing all required settings. Example configuration files are provided in the `config/` folder.
The following JSON content is an example of a config file:
```json
{
    "max_num_stores_loads": 7,
    "max_num_loops": 7,
    "max_num_load_store_dim": 7,
    "num_tile_sizes": 7,
    "num_transformations": 6,
    "vect_size_limit": 2048,
    "use_bindings": false,
    "use_vectorizer": false,
    "data_format": "json",
    "optimization_mode": "last",
    "benchmarks_folder_path": "",
    "len_trajectory": 64,
    "ppo_batch_size": 64,
    "nb_iterations": 10000,
    "ppo_epochs": 4,
    "entropy_coef": 0.01,
    "lr": 0.001,
    "truncate": 5,
    "json_file": "data/nn/train_operations.json",
    "tags": ["nn"],
    "logging": true
}
```
The following list describes every required setting in a configuration file; a short loading sketch follows the list.
- `max_num_stores_loads (int)`: The maximum number of load/store operations in the nested loops.
- `max_num_loops (int)`: The max number of nested loops.
- `max_num_load_store_dim (int)`: The max number of dimensions in load/store buffers.
- `num_tile_sizes (int)`: The number of possible tile sizes for a loop.
- `num_transformations (int)`: The number of transformations.
- `vect_size_limit (int)`: Vectorization size limit, used to prevent vectorization of overly large sizes.
- `use_bindings (bool)`: Flag to enable the Python bindings for execution; if false, execution is done through the command line. Default is false.
- `use_vectorizer (bool)`: Flag to enable the vectorizer C++ program for vectorization; if false, vectorization is done directly through the transform dialect. Default is false.
- `data_format (Literal["json", "mlir"])`: The format of the data, either "json" or "mlir". "json" mode reads JSON files containing benchmark features; "mlir" mode reads MLIR code files directly and extracts features from them using the AST dumper. Default is "json".
- `optimization_mode (Literal["last", "all"])`: The optimization mode to use: "last" optimizes only the last operation, while "all" optimizes all operations in the code. Default is "last".
- `benchmarks_folder_path (str)`: Path to the benchmarks folder. Can be empty if data format is set to "json".
- `len_trajectory (int)`: Length of the trajectory used for PPO.
- `ppo_batch_size (int)`: Batch size for PPO.
- `nb_iterations (int)`: Number of training iterations.
- `ppo_epochs (int)`: Number of epochs for PPO.
- `entropy_coef (float)`: Entropy coefficient.
- `lr (float)`: Learning rate.
- `truncate (int)`: Maximum number of steps of a schedule for an operation.
- `json_file (str)`: Path to the JSON file containing the benchmarks' code and features when the data format is "json". Otherwise, it should contain the original execution time of every benchmark in the benchmarks folder.
- `tags (list[str])`: List of tags to attach to the Neptune experiment.
- `logging (bool)`: Flag to enable logging to Neptune.
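The project reads these settings through its `Config` helper (`utils/config.py`, imported by `create_dataset.py` further down in this diff). Purely as a sketch of the file's shape, and assuming a hypothetical path `config/example.json`, the settings can also be inspected with plain `json`:
```python
import json
from pathlib import Path

# Hypothetical config path; substitute your own file.
config = json.loads(Path("config/example.json").read_text())

assert config["data_format"] in ("json", "mlir")
assert config["optimization_mode"] in ("last", "all")
print(f"{config['nb_iterations']} iterations, lr={config['lr']}, tags={config['tags']}")
```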
# IQL for MLIR-RL

Example `config.json` for IQL training:
```json
{
    "max_num_stores_loads": 7,
    "max_num_loops": 7,
    "max_num_load_store_dim": 7,
    "num_tile_sizes": 7,
    "vect_size_limit": 2048,
    "order": [["I"], ["TP"], ["T"], ["V", "NT"]],
    "interchange_mode": "pointers",
    "exploration": ["entropy"],
    "init_epsilon": 0.1,
    "new_architecture": false,
    "normalize_bounds": "max",
    "normalize_adv": "standard",
    "sparse_reward": true,
    "split_ops": true,
    "reuse_experience": "none",
    "activation": "relu",
    "benchmarks_folder_path": "data/matmul/code/",
    "bench_count": 8,
    "replay_count": 10,
    "nb_iterations": 1200,
    "ppo_epochs": 4,
    "ppo_batch_size": 32,
    "value_epochs": 4,
    "value_batch_size": 32,
    "value_coef": 0.5,
    "value_clip": true,
    "entropy_coef": 0.01,
    "lr": 3e-4,
    "truncate": 10,
    "json_file": "data/matmul/train_operations.json",
    "eval_json_file": "data/matmul/eval_operations.json",
    "tags": ["matmul"],
    "debug": false,
    "main_exec_data_file": "cache/execution.json",
    "results_dir": "offline_iql_adv_norm_gradclip_cosine_scheduler",
    "run_name": "offline_iql_adv_norm_gradclip_cosine_scheduler",
    "collect_offline_data": false,
    "offline_data_save_dir": "offline_dataset",
    "offline_data_file": "offline_dataset_online_ppo.npz",

    "gamma": 0.99,
    "tau": 0.9,
    "inverse_temperature": 3.0,
    "alpha": 0.005,
    "batch_size": 256,
    "learning_rate": {
        "value": 3e-4,
        "q": 3e-4,
        "policy": 1e-4
    },
    "max_steps": 1000000,
    "target_update_freq": 1
}
```
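The last group of settings (`gamma`, `tau`, `inverse_temperature`, `alpha`, `batch_size`, `learning_rate`, `target_update_freq`) parameterizes the IQL learner. For orientation, the sketch below shows where these hyperparameters appear in a textbook Implicit Q-Learning update (expectile value regression, TD learning of Q against V, and advantage-weighted policy extraction). It is a generic PyTorch sketch with placeholder `q_net`/`v_net`/`policy` modules, not necessarily this repository's implementation.
```python
import torch
import torch.nn.functional as F


def expectile_loss(diff: torch.Tensor, tau: float) -> torch.Tensor:
    # Asymmetric L2 loss |tau - 1(diff < 0)| * diff^2 (IQL's expectile regression).
    weight = torch.abs(tau - (diff < 0).float())
    return (weight * diff.pow(2)).mean()


def iql_losses(batch, q_net, target_q_net, v_net, policy,
               gamma=0.99, tau=0.9, inverse_temperature=3.0):
    s, a, r, s_next, done = batch  # tensors sampled from the offline dataset

    # 1) Value loss: regress V(s) toward Q_target(s, a) with the expectile tau.
    with torch.no_grad():
        q_sa = target_q_net(s, a)
    v_loss = expectile_loss(q_sa - v_net(s), tau)

    # 2) Q loss: one-step TD target that bootstraps from V(s') instead of max_a' Q.
    with torch.no_grad():
        td_target = r + gamma * (1.0 - done) * v_net(s_next)
    q_loss = F.mse_loss(q_net(s, a), td_target)

    # 3) Policy loss: advantage-weighted behavioural cloning, weight = exp(beta * A).
    with torch.no_grad():
        weights = torch.exp(inverse_temperature * (q_sa - v_net(s))).clamp(max=100.0)
    log_prob = policy.log_prob(s, a)  # placeholder policy API
    policy_loss = -(weights * log_prob).mean()

    return v_loss, q_loss, policy_loss
```

In this standard formulation, `alpha` would be the Polyak averaging rate for the target Q-network, `target_update_freq` how often that soft update runs, and `batch_size` plus the per-network entries of `learning_rate` would configure the optimizers.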
8 changes: 4 additions & 4 deletions config/.gitignore
100644 → 100755
@@ -1,5 +1,5 @@
# Ignore everything in this directory
*
# Except these files
!.gitignore
!example.json
37 changes: 0 additions & 37 deletions config/example.json

This file was deleted.

48 changes: 48 additions & 0 deletions create_dataset.py
@@ -0,0 +1,48 @@
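# Collect baseline execution times: run every benchmark in the configured folder
# with an empty transformation schedule and append the timings to the training
# JSON file referenced by the config.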
import os
from dotenv import load_dotenv
import json
import pathlib
from rl_autoschedular.execution import Execution
from utils.config import Config

load_dotenv(override=True)

config = Config()
cache_file = "cache/execution.json"
exec = Execution(exec_data_file=cache_file)

train_operations = {}
# eval_operations = {}

for benchmark in os.listdir(config.benchmarks_folder_path):
    benchmark_name = benchmark.split('.')[0]
    mlir_code_path = f"data/matmul/online_data/{benchmark}"
    mlir_code = pathlib.Path(mlir_code_path).read_text()
    time_ns, success, cache_miss = exec.execute_code(mlir_code, benchmark_name, seq=[])

    train_operations[benchmark_name] = time_ns
    # eval_operations[benchmark_name] = time_ns

    print(f"Benchmark: {benchmark_name}")
    print(f"Execution time: {time_ns} ns")
    print(f"Success: {success}, Cache miss: {cache_miss}")
    print("-" * 40)

# --- helper function to append safely ---
def append_json(file_path, new_data):
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = {}
    else:
        data = {}
    # update old with new
    data.update(new_data)
    with open(file_path, 'w') as f:
        json.dump(data, f, indent=4)

# append instead of overwrite
append_json(config.json_file, train_operations)
# append_json(config.eval_json_file, eval_operations)
2 changes: 0 additions & 2 deletions dask-logs/.gitignore

This file was deleted.

2 changes: 0 additions & 2 deletions data/all/.gitignore

This file was deleted.

2 changes: 0 additions & 2 deletions data/debug/.gitignore

This file was deleted.

2 changes: 0 additions & 2 deletions data/features/.gitignore

This file was deleted.

4 changes: 2 additions & 2 deletions data/lqcd/bench/.gitignore
100644 → 100755
@@ -1,2 +1,2 @@
*
!.gitignore
4 changes: 2 additions & 2 deletions data/lqcd/control/execution_times.json
100644 → 100755
@@ -1,3 +1,3 @@
{
    "test_AB_1": 2620914018
}
90 changes: 45 additions & 45 deletions data/lqcd/control/test_AB_1.mlir
100644 → 100755
@@ -1,45 +1,45 @@
func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
func.func @main(%B_28: memref<1024x1024xf64>, %A_30: memref<1024x1024xf64>, %output_24: memref<1024x1024xf64>) -> i64 attributes { llvm.emit_c_interface } {
%t0 = func.call @nanoTime() : () -> i64
%7 = memref.alloc() : memref<1024x1024x1xf64>
linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%7: memref<1024x1024x1xf64>) {
^bb0(%8: f64):
%1 = arith.constant 0.0 : f64
linalg.yield %1 : f64
}
%9 = memref.alloc() : memref<1024xf64>
linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%9, %A_30, %B_28, %A_30, %B_28: memref<1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>) outs(%7: memref<1024x1024x1xf64>) {
^bb0(%10: f64, %13: f64, %37: f64, %30: f64, %39: f64, %11: f64):
%4 = linalg.index 0 : index
%5 = linalg.index 1 : index
%6 = linalg.index 2 : index
%12 = linalg.index 3 : index
%27 = arith.constant 1 : index
%32 = arith.minsi %6, %27 : index
%17 = arith.constant 0 : index
%33 = arith.maxsi %32, %17 : index
%26 = arith.mulf %13, %37 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
%22 = arith.constant 0.0 : f64
%18 = arith.subf %26, %22 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
%29 = arith.constant 0.0 : f64
%24 = arith.mulf %30, %29 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
%15 = arith.constant 0.0 : f64
%14 = arith.mulf %15, %39 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
%19 = arith.addf %24, %14 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
%36 = arith.constant 0 : index
%21 = arith.cmpi eq, %33, %36 : index
%38 = arith.select %21, %18, %19 : f64
%25 = arith.addf %11, %38 fastmath<nnan, ninf, nsz, reassoc, contract, afn> : f64
linalg.yield %25 : f64
}
%41 = memref.collapse_shape %7 [[0], [1, 2]] : memref<1024x1024x1xf64> into memref<1024x1024xf64>
linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>, affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>], iterator_types = ["parallel", "parallel"]} ins(%41: memref<1024x1024xf64>) outs(%output_24: memref<1024x1024xf64>) {
^bb0(%43: f64, %42: f64):
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
linalg.yield %43 : f64
}
%t1 = func.call @nanoTime() : () -> (i64)
%t2 = arith.subi %t1, %t0 : i64
return %t2 : i64
}
Empty file modified data/lqcd/control/test_AB_1.mlir.npy
100644 → 100755
Empty file.
Empty file modified data/lqcd/control/test_AB_1.mlir.npz
100644 → 100755
Empty file.
2 changes: 0 additions & 2 deletions data/multi/.gitignore

This file was deleted.
