Skip to content

Commit dda5ca4

Browse files
committed
New matlab engine rework
1 parent c83b5eb commit dda5ca4

File tree

6 files changed

+299
-53
lines changed

6 files changed

+299
-53
lines changed

RATapi/wrappers.py

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Wrappers for the interface between RATapi and MATLAB custom files."""
2-
2+
import os
33
import pathlib
4-
from contextlib import suppress
54
from typing import Callable
65

76
import numpy as np
@@ -10,22 +9,29 @@
109
import RATapi.rat_core
1110

1211

12+
def find_matlab():
13+
pass
14+
15+
16+
17+
1318
def start_matlab():
1419
"""Start MATLAB asynchronously and returns a future to retrieve the engine later.
1520
1621
Returns
1722
-------
18-
future : matlab.engine.futureresult.FutureResult
19-
A future used to get the actual matlab engine.
23+
future : RATapi.rat_core.MatlabEngine
24+
A custom matlab engine wrapper.
2025
2126
"""
22-
future = None
23-
with suppress(ImportError):
24-
import matlab.engine
2527

26-
future = matlab.engine.start_matlab(background=True)
28+
29+
os.environ["MATLAB_INSTALL_DIR"] += os.pathsep + "C:\\Program Files\\MATLAB\\R2023a\\bin\\win64"
30+
engine = RATapi.rat_core.MatlabEngine()
31+
engine.start()
32+
33+
return engine
2734

28-
return future
2935

3036

3137
class MatlabWrapper:
@@ -35,57 +41,31 @@ class MatlabWrapper:
3541
----------
3642
filename : string
3743
The path of the file containing MATLAB function
38-
3944
"""
40-
41-
loader = start_matlab()
42-
43-
def __init__(self, filename: str) -> None:
44-
if self.loader is None:
45-
raise ImportError("matlabengine is required to use MatlabWrapper") from None
46-
47-
self.engine = self.loader.result()
45+
engine = start_matlab()
46+
47+
def __init__(self, filename) -> None:
4848
path = pathlib.Path(filename)
49-
self.engine.cd(str(path.parent), nargout=0)
50-
self.function_name = path.stem
49+
self.engine.cd(str(path.parent))
50+
self.engine.setFunction(path.stem)
5151

5252
def getHandle(self) -> Callable[[ArrayLike, ArrayLike, ArrayLike, int, int], tuple[ArrayLike, float]]:
53-
"""Return a wrapper for the custom MATLAB function.
53+
"""Return a wrapper for the custom dynamic library function.
5454
5555
Returns
5656
-------
5757
wrapper : Callable[[ArrayLike, ArrayLike, ArrayLike, int, int], tuple[ArrayLike, float]]
58-
The wrapper function for the MATLAB callback
58+
The wrapper function for the dynamic library callback
5959
6060
"""
6161

6262
def handle(*args):
63-
if len(args) == 2:
64-
output = getattr(self.engine, self.function_name)(
65-
np.array(args[0], "float"), # xdata
66-
np.array(args[1], "float"), # params
67-
nargout=1,
68-
)
69-
return np.array(output, "float").tolist()
70-
else:
71-
matlab_args = [
72-
np.array(args[0], "float"), # params
73-
np.array(args[1], "float"), # bulk in
74-
np.array(args[2], "float"), # bulk out
75-
float(args[3] + 1), # contrast
76-
]
77-
if len(args) > 4:
78-
matlab_args.append(float(args[4] + 1)) # domain number
79-
80-
output, sub_rough = getattr(self.engine, self.function_name)(
81-
*matlab_args,
82-
nargout=2,
83-
)
84-
return np.array(output, "float").tolist(), float(sub_rough)
63+
return self.engine.invoke(*args)
8564

8665
return handle
8766

8867

68+
8969
class DylibWrapper:
9070
"""Creates a python callback for a function in dynamic library.
9171

cpp/matlab/matlabCaller.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "matlabCaller.h"
2+
3+
LIB_EXPORT void startMatlab()
4+
{
5+
MatlabCaller::get_instance()->startMatlab();;
6+
}
7+
8+
LIB_EXPORT void cd(std::string path)
9+
{
10+
MatlabCaller::get_instance()->cd(path);
11+
}
12+
13+
LIB_EXPORT void callFunction(std::string functionName, std::vector<double>& params, std::vector<double>& bulkIn,
14+
std::vector<double>& bulkOut, int contrast, int domain, std::vector<double>& output, double* outputSize, double* rough)
15+
{
16+
MatlabCaller::get_instance()->call(functionName, params, bulkIn, bulkOut, contrast, domain, output, outputSize, rough);
17+
}

cpp/matlab/matlabCaller.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "matlabCallerImpl.hpp"
2+
3+
#ifndef EVENT_MANAGER_H
4+
#define EVENT_MANAGER_H
5+
6+
#ifdef __cplusplus
7+
extern "C" {
8+
#endif
9+
10+
#if defined(_WIN32) || defined(_WIN64)
11+
#define LIB_EXPORT __declspec(dllexport)
12+
#else
13+
#define LIB_EXPORT
14+
#endif
15+
16+
17+
LIB_EXPORT void startMatlab();
18+
19+
LIB_EXPORT void cd(std::string path);
20+
21+
LIB_EXPORT void callFunction(std::string functionName, std::vector<double>& params, std::vector<double>& bulkIn,
22+
std::vector<double>& bulkOut, int contrast, int domain, std::vector<double>& output, double* outputSize, double* rough);
23+
24+
#ifdef __cplusplus
25+
}
26+
#endif
27+
28+
#endif // EVENT_MANAGER_H

cpp/matlab/matlabCallerImpl.hpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#ifndef MATLAB_CALLER_IMPL_HPP
2+
#define MATLAB_CALLER_IMPL_HPP
3+
4+
#include "engine.h"
5+
#include <vector>
6+
#include <iostream>
7+
#include <chrono>
8+
9+
using namespace std::chrono;
10+
11+
class MatlabCaller
12+
{
13+
14+
public:
15+
MatlabCaller(MatlabCaller const&) = delete;
16+
MatlabCaller& operator=(MatlabCaller const&) = delete;
17+
~MatlabCaller() {}
18+
19+
void setEngine(){
20+
if (!(matlabPtr = engOpen(""))) {
21+
throw("\nCan't start MATLAB engine\n");
22+
}
23+
};
24+
25+
void startMatlab(){
26+
// this->matlabFuture = startMATLABAsync();
27+
};
28+
29+
void cd(std::string path){
30+
this->currentDirectory = path;
31+
dirChanged = true;
32+
};
33+
34+
void call(std::string functionName, std::vector<double>& params, std::vector<double>& bulkIn,
35+
std::vector<double>& bulkOut, int contrast, int domain, std::vector<double>& output, double* outputSize, double* rough)
36+
{
37+
if (!this->matlabPtr)
38+
this->setEngine();
39+
if (dirChanged){
40+
std::string cdCmd = "cd('" + (this->currentDirectory + "')");
41+
engEvalString(this->matlabPtr, cdCmd.c_str());
42+
}
43+
//this->matlabPtr->feval(u"cd", factory.createCharArray(this->currentDirectory));
44+
dirChanged = false;
45+
mxArray *PARAMS = mxCreateDoubleMatrix(1,params.size(),mxREAL);
46+
memcpy(mxGetPr(PARAMS), &params[0], params.size()*sizeof(double));
47+
engPutVariable(this->matlabPtr, "params", PARAMS);
48+
mxArray *BULKIN = mxCreateDoubleMatrix(1,bulkIn.size(),mxREAL);
49+
memcpy((void *)mxGetPr(BULKIN), &bulkIn[0], bulkIn.size()*sizeof(double));
50+
engPutVariable(this->matlabPtr, "bulkIn", BULKIN);
51+
mxArray *BULKOUT = mxCreateDoubleMatrix(1,bulkOut.size(),mxREAL);
52+
memcpy((void *)mxGetPr(BULKOUT), &bulkOut[0], bulkOut.size()*sizeof(double));
53+
engPutVariable(this->matlabPtr, "bulkOut", BULKOUT);
54+
mxArray *CONTRAST = mxCreateDoubleScalar(contrast);
55+
// memcpy((void *)mxGetPr(CONTRAST), &contrast, 1*sizeof(double));
56+
engPutVariable(this->matlabPtr, "contrast", CONTRAST);
57+
// if (domain > 0)
58+
// args.push_back(factory.createScalar<int>(domain));
59+
std::string customCmd = "[output, subRough] = " + (functionName + "(params, bulkIn, bulkOut, contrast)");
60+
engPutVariable(this->matlabPtr, "myFunction", mxCreateString(customCmd.c_str()));
61+
engOutputBuffer(this->matlabPtr, NULL, 0);
62+
//auto start = high_resolution_clock::now();
63+
// std::vector<matlab::data::Array> results = this->matlabPtr->feval(functionName, 2, args);
64+
engEvalString(this->matlabPtr, "eval(myFunction)");
65+
//auto stop = high_resolution_clock::now();
66+
//auto duration = duration_cast<microseconds>(stop - start);
67+
//std::cout << duration.count() << "Usec" << std::endl;
68+
69+
mxArray *matOutput = engGetVariable(this->matlabPtr, "output");
70+
if (matOutput == NULL)
71+
{
72+
throw("FAILED!");
73+
}
74+
mxArray *subRough = engGetVariable(this->matlabPtr, "subRough");
75+
if (subRough == NULL)
76+
{
77+
throw("FAILED!");
78+
}
79+
*rough = (double)mxGetScalar(subRough);
80+
const mwSize* dims = mxGetDimensions(matOutput);
81+
outputSize[0] = (double) dims[0];
82+
outputSize[1] = (double) dims[1];
83+
// output.push_back((double) matOutput[i]);
84+
double* s = (double *)mxGetData(matOutput);
85+
for (int i=0; i < dims[0] * dims[1]; i++)
86+
output.push_back(s[i]);
87+
//std::memcpy(output, (double *)mxGetData(matOutput), mxGetNumberOfElements(matOutput)* mxGetElementSize(matOutput));
88+
};
89+
90+
static MatlabCaller* get_instance()
91+
{
92+
// Static local variable initialization is thread-safe
93+
// and will be initialized only once.
94+
static MatlabCaller instance{};
95+
return &instance;
96+
};
97+
98+
private:
99+
explicit MatlabCaller() {}
100+
Engine *matlabPtr;
101+
std::string currentDirectory;
102+
bool dirChanged = false;
103+
};
104+
105+
#endif // MATLAB_CALLER_IMPL_HPP

cpp/rat.cpp

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,87 @@ namespace py = pybind11;
2727
const int DEFAULT_DOMAIN = -1;
2828
const int DEFAULT_NREPEATS = 1;
2929

30+
31+
class MatlabEngine
32+
{
33+
public:
34+
std::unique_ptr<dylib> library;
35+
std::string functionName;
36+
37+
MatlabEngine()
38+
{
39+
this->functionName = "";
40+
std::string filename = "matlabCaller" + std::string(dylib::extension);
41+
this->library = std::unique_ptr<dylib>(new dylib(std::getenv("RAT_PATH"), filename.c_str()));
42+
if (!library)
43+
{
44+
std::cerr << "The matlab caller dynamic library failed to load" << std::endl;
45+
return;
46+
}
47+
};
48+
49+
~MatlabEngine(){};
50+
51+
void cd(std::string path)
52+
{
53+
auto cdFunc = library->get_function<void(std::string)>("cd");
54+
cdFunc(path);
55+
};
56+
57+
void start()
58+
{
59+
auto startFunc = library->get_function<void(void)>("startMatlab");
60+
startFunc();
61+
};
62+
63+
void setFunction(std::string functionName)
64+
{
65+
this->functionName = functionName;
66+
};
67+
68+
py::list invoke(std::vector<double>& xdata, std::vector<double>& params)
69+
{
70+
// try{
71+
std::vector<double> output;
72+
73+
// auto func = library->get_function<void(std::vector<double>&, std::vector<double>&, std::vector<double>&)>(functionName);
74+
// func(xdata, params, output);
75+
76+
return py::cast(output);
77+
78+
// }catch (const dylib::symbol_error &) {
79+
// throw std::runtime_error("failed to get dynamic library symbol for " + functionName);
80+
// }
81+
};
82+
83+
py::tuple invoke(std::vector<double>& params, std::vector<double>& bulkIn, std::vector<double>& bulkOut, int contrast, int domain=DEFAULT_DOMAIN)
84+
{
85+
try{
86+
std::vector<double> tempOutput;
87+
double *outputSize = new double[2];
88+
double roughness = 0.0;
89+
auto func = library->get_function<void(std::string, std::vector<double>&, std::vector<double>&, std::vector<double>&,
90+
int, int, std::vector<double>&, double*, double*)>("callFunction");
91+
func(functionName, params, bulkIn, bulkOut, contrast + 1, domain + 1, tempOutput, outputSize, &roughness);
92+
93+
py::list output;
94+
for (int32_T idx1{0}; idx1 < outputSize[0]; idx1++)
95+
{
96+
py::list rows;
97+
for (int32_T idx2{0}; idx2 < outputSize[1]; idx2++)
98+
{
99+
rows.append(tempOutput[(int32_T)outputSize[1] * idx1 + idx2]);
100+
}
101+
output.append(rows);
102+
}
103+
return py::make_tuple(output, roughness);
104+
105+
}catch (const dylib::symbol_error &) {
106+
throw std::runtime_error("failed to get dynamic library symbol for " + functionName);
107+
}
108+
};
109+
};
110+
30111
class DylibEngine
31112
{
32113
public:
@@ -674,7 +755,20 @@ PYBIND11_MODULE(rat_core, m) {
674755
py::arg("domain") = DEFAULT_DOMAIN)
675756
.def("invoke", overload_cast_<std::vector<double>&,
676757
std::vector<double>&>()(&DylibEngine::invoke), py::arg("xdata"), py::arg("param"));
677-
758+
759+
py::class_<MatlabEngine>(m, "MatlabEngine")
760+
.def(py::init<>())
761+
.def("start", &MatlabEngine::start)
762+
.def("cd", &MatlabEngine::cd)
763+
.def("setFunction", &MatlabEngine::setFunction)
764+
.def("invoke", overload_cast_<std::vector<double>&, std::vector<double>&,
765+
std::vector<double>&, int, int>()(&MatlabEngine::invoke),
766+
py::arg("params"), py::arg("bulkIn"),
767+
py::arg("bulkOut"), py::arg("contrast"),
768+
py::arg("domain") = DEFAULT_DOMAIN)
769+
.def("invoke", overload_cast_<std::vector<double>&,
770+
std::vector<double>&>()(&MatlabEngine::invoke), py::arg("xdata"), py::arg("param"));
771+
678772
py::class_<PredictionIntervals>(m, "PredictionIntervals", docsPredictionIntervals.c_str())
679773
.def(py::init<>())
680774
.def_readwrite("reflectivity", &PredictionIntervals::reflectivity)

0 commit comments

Comments
 (0)