Skip to content

Commit fe0a45e

Browse files
authored
Makes the arguments from RAT_main and events data pickleable (#74)
1 parent b3452f3 commit fe0a45e

File tree

3 files changed

+314
-10
lines changed

3 files changed

+314
-10
lines changed

cpp/rat.cpp

Lines changed: 273 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,12 +1295,54 @@ PYBIND11_MODULE(rat_core, m) {
12951295
.def_readwrite("resample", &PlotEventData::resample)
12961296
.def_readwrite("dataPresent", &PlotEventData::dataPresent)
12971297
.def_readwrite("modelType", &PlotEventData::modelType)
1298-
.def_readwrite("contrastNames", &PlotEventData::contrastNames);
1298+
.def_readwrite("contrastNames", &PlotEventData::contrastNames)
1299+
.def(py::pickle(
1300+
[](const PlotEventData &evt) { // __getstate__
1301+
/* Return a tuple that fully encodes the state of the object */
1302+
return py::make_tuple(evt.reflectivity, evt.shiftedData, evt.sldProfiles, evt.resampledLayers, evt.subRoughs, evt.resample,
1303+
evt.dataPresent, evt.modelType, evt.contrastNames);
1304+
},
1305+
[](py::tuple t) { // __setstate__
1306+
if (t.size() != 9)
1307+
throw std::runtime_error("Encountered invalid state unpickling PlotEventData object!");
1308+
1309+
/* Create a new C++ instance */
1310+
PlotEventData evt;
1311+
1312+
evt.reflectivity = t[0].cast<py::list>();
1313+
evt.shiftedData = t[1].cast<py::list>();
1314+
evt.sldProfiles = t[2].cast<py::list>();
1315+
evt.resampledLayers = t[3].cast<py::list>();
1316+
evt.subRoughs = t[4].cast<py::array_t<double>>();
1317+
evt.resample = t[5].cast<py::array_t<double>>();
1318+
evt.dataPresent = t[6].cast<py::array_t<double>>();
1319+
evt.modelType = t[7].cast<std::string>();
1320+
evt.contrastNames = t[8].cast<py::list>();
1321+
1322+
return evt;
1323+
}));
12991324

13001325
py::class_<ProgressEventData>(m, "ProgressEventData")
13011326
.def(py::init<>())
13021327
.def_readwrite("message", &ProgressEventData::message)
1303-
.def_readwrite("percent", &ProgressEventData::percent);
1328+
.def_readwrite("percent", &ProgressEventData::percent)
1329+
.def(py::pickle(
1330+
[](const ProgressEventData &evt) { // __getstate__
1331+
/* Return a tuple that fully encodes the state of the object */
1332+
return py::make_tuple(evt.message, evt.percent);
1333+
},
1334+
[](py::tuple t) { // __setstate__
1335+
if (t.size() != 2)
1336+
throw std::runtime_error("Encountered invalid state unpickling ProgressEventData object!");
1337+
1338+
/* Create a new C++ instance */
1339+
ProgressEventData evt;
1340+
1341+
evt.message = t[0].cast<std::string>();
1342+
evt.percent = t[1].cast<double>();
1343+
1344+
return evt;
1345+
}));
13041346

13051347
py::class_<ConfidenceIntervals>(m, "ConfidenceIntervals")
13061348
.def(py::init<>())
@@ -1393,7 +1435,31 @@ PYBIND11_MODULE(rat_core, m) {
13931435
.def_readwrite("fitBulkIn", &Checks::fitBulkIn)
13941436
.def_readwrite("fitBulkOut", &Checks::fitBulkOut)
13951437
.def_readwrite("fitResolutionParam", &Checks::fitResolutionParam)
1396-
.def_readwrite("fitDomainRatio", &Checks::fitDomainRatio);
1438+
.def_readwrite("fitDomainRatio", &Checks::fitDomainRatio)
1439+
.def(py::pickle(
1440+
[](const Checks &chk) { // __getstate__
1441+
/* Return a tuple that fully encodes the state of the object */
1442+
return py::make_tuple(chk.fitParam, chk.fitBackgroundParam, chk.fitQzshift, chk.fitScalefactor, chk.fitBulkIn, chk.fitBulkOut,
1443+
chk.fitResolutionParam, chk.fitDomainRatio);
1444+
},
1445+
[](py::tuple t) { // __setstate__
1446+
if (t.size() != 8)
1447+
throw std::runtime_error("Encountered invalid state unpickling Checks object!");
1448+
1449+
/* Create a new C++ instance */
1450+
Checks chk;
1451+
1452+
chk.fitParam = t[0].cast<py::array_t<real_T>>();
1453+
chk.fitBackgroundParam = t[1].cast<py::array_t<real_T>>();
1454+
chk.fitQzshift = t[2].cast<py::array_t<real_T>>();
1455+
chk.fitScalefactor = t[3].cast<py::array_t<real_T>>();
1456+
chk.fitBulkIn = t[4].cast<py::array_t<real_T>>();
1457+
chk.fitBulkOut = t[5].cast<py::array_t<real_T>>();
1458+
chk.fitResolutionParam = t[6].cast<py::array_t<real_T>>();
1459+
chk.fitDomainRatio = t[7].cast<py::array_t<real_T>>();
1460+
1461+
return chk;
1462+
}));
13971463

13981464
py::class_<Limits>(m, "Limits")
13991465
.def(py::init<>())
@@ -1404,7 +1470,31 @@ PYBIND11_MODULE(rat_core, m) {
14041470
.def_readwrite("bulkIn", &Limits::bulkIn)
14051471
.def_readwrite("bulkOut", &Limits::bulkOut)
14061472
.def_readwrite("resolutionParam", &Limits::resolutionParam)
1407-
.def_readwrite("domainRatio", &Limits::domainRatio);
1473+
.def_readwrite("domainRatio", &Limits::domainRatio)
1474+
.def(py::pickle(
1475+
[](const Limits &lim) { // __getstate__
1476+
/* Return a tuple that fully encodes the state of the object */
1477+
return py::make_tuple(lim.param, lim.backgroundParam, lim.qzshift, lim.scalefactor, lim.bulkIn, lim.bulkOut,
1478+
lim.resolutionParam, lim.domainRatio);
1479+
},
1480+
[](py::tuple t) { // __setstate__
1481+
if (t.size() != 8)
1482+
throw std::runtime_error("Encountered invalid state unpickling Limits object!");
1483+
1484+
/* Create a new C++ instance */
1485+
Limits lim;
1486+
1487+
lim.param = t[0].cast<py::array_t<real_T>>();
1488+
lim.backgroundParam = t[1].cast<py::array_t<real_T>>();
1489+
lim.qzshift = t[2].cast<py::array_t<real_T>>();
1490+
lim.scalefactor = t[3].cast<py::array_t<real_T>>();
1491+
lim.bulkIn = t[4].cast<py::array_t<real_T>>();
1492+
lim.bulkOut = t[5].cast<py::array_t<real_T>>();
1493+
lim.resolutionParam = t[6].cast<py::array_t<real_T>>();
1494+
lim.domainRatio = t[7].cast<py::array_t<real_T>>();
1495+
1496+
return lim;
1497+
}));
14081498

14091499
py::class_<Priors>(m, "Priors")
14101500
.def(py::init<>())
@@ -1417,8 +1507,34 @@ PYBIND11_MODULE(rat_core, m) {
14171507
.def_readwrite("resolutionParam", &Priors::resolutionParam)
14181508
.def_readwrite("domainRatio", &Priors::domainRatio)
14191509
.def_readwrite("priorNames", &Priors::priorNames)
1420-
.def_readwrite("priorValues", &Priors::priorValues);
1421-
1510+
.def_readwrite("priorValues", &Priors::priorValues)
1511+
.def(py::pickle(
1512+
[](const Priors &prior) { // __getstate__
1513+
/* Return a tuple that fully encodes the state of the object */
1514+
return py::make_tuple(prior.param, prior.backgroundParam, prior.qzshift, prior.scalefactor, prior.bulkIn,
1515+
prior.bulkOut, prior.resolutionParam, prior.domainRatio, prior.priorNames, prior.priorValues);
1516+
},
1517+
[](py::tuple t) { // __setstate__
1518+
if (t.size() != 10)
1519+
throw std::runtime_error("Encountered invalid state unpickling Limits object!");
1520+
1521+
/* Create a new C++ instance */
1522+
Priors prior;
1523+
1524+
prior.param = t[0].cast<py::list>();
1525+
prior.backgroundParam = t[1].cast<py::list>();
1526+
prior.qzshift = t[2].cast<py::list>();
1527+
prior.scalefactor = t[3].cast<py::list>();
1528+
prior.bulkIn = t[4].cast<py::list>();
1529+
prior.bulkOut = t[5].cast<py::list>();
1530+
prior.resolutionParam = t[6].cast<py::list>();
1531+
prior.domainRatio = t[7].cast<py::list>();
1532+
prior.priorNames = t[8].cast<py::list>();
1533+
prior.priorValues = t[9].cast<py::array_t<real_T>>();
1534+
1535+
return prior;
1536+
}));
1537+
14221538
py::class_<Cells>(m, "Cells")
14231539
.def(py::init<>())
14241540
.def_readwrite("f1", &Cells::f1)
@@ -1441,7 +1557,44 @@ PYBIND11_MODULE(rat_core, m) {
14411557
.def_readwrite("f18", &Cells::f18)
14421558
.def_readwrite("f19", &Cells::f19)
14431559
.def_readwrite("f20", &Cells::f20)
1444-
.def_readwrite("f21", &Cells::f21);
1560+
.def_readwrite("f21", &Cells::f21)
1561+
.def(py::pickle(
1562+
[](const Cells &cell) { // __getstate__
1563+
/* Return a tuple that fully encodes the state of the object */
1564+
return py::make_tuple(cell.f1, cell.f2, cell.f3, cell.f4, cell.f5, cell.f6, cell.f7, cell.f8, cell.f9, cell.f10, cell.f11,
1565+
cell.f12, cell.f13, cell.f14, cell.f15, cell.f16, cell.f17, cell.f18, cell.f19, cell.f20, cell.f21);
1566+
},
1567+
[](py::tuple t) { // __setstate__
1568+
if (t.size() != 21)
1569+
throw std::runtime_error("Encountered invalid state unpickling Cells object!");
1570+
1571+
/* Create a new C++ instance */
1572+
Cells cell;
1573+
1574+
cell.f1 = t[0].cast<py::list>();
1575+
cell.f2 = t[1].cast<py::list>();
1576+
cell.f3 = t[2].cast<py::list>();
1577+
cell.f4 = t[3].cast<py::list>();
1578+
cell.f5 = t[4].cast<py::list>();
1579+
cell.f6 = t[5].cast<py::list>();
1580+
cell.f7 = t[6].cast<py::list>();
1581+
cell.f8 = t[7].cast<py::list>();
1582+
cell.f9 = t[8].cast<py::list>();
1583+
cell.f10 = t[9].cast<py::list>();
1584+
cell.f11 = t[10].cast<py::list>();
1585+
cell.f12 = t[11].cast<py::list>();
1586+
cell.f13 = t[12].cast<py::list>();
1587+
cell.f14 = t[13].cast<py::list>();
1588+
cell.f15 = t[14].cast<py::list>();
1589+
cell.f16 = t[15].cast<py::list>();
1590+
cell.f17 = t[16].cast<py::list>();
1591+
cell.f18 = t[17].cast<py::list>();
1592+
cell.f19 = t[18].cast<py::list>();
1593+
cell.f20 = t[19].cast<py::list>();
1594+
cell.f21 = t[20].cast<py::list>();
1595+
1596+
return cell;
1597+
}));
14451598

14461599
py::class_<Control>(m, "Control")
14471600
.def(py::init<>())
@@ -1473,8 +1626,67 @@ PYBIND11_MODULE(rat_core, m) {
14731626
.def_readwrite("boundHandling", &Control::boundHandling)
14741627
.def_readwrite("adaptPCR", &Control::adaptPCR)
14751628
.def_readwrite("checks", &Control::checks)
1476-
.def_readwrite("IPCFilePath", &Control::IPCFilePath);
1477-
1629+
.def_readwrite("IPCFilePath", &Control::IPCFilePath)
1630+
.def(py::pickle(
1631+
[](const Control &ctrl) { // __getstate__
1632+
/* Return a tuple that fully encodes the state of the object */
1633+
return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance,
1634+
ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability,
1635+
ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale,
1636+
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq,
1637+
ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR,
1638+
ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift,
1639+
ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
1640+
ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio);
1641+
},
1642+
[](py::tuple t) { // __setstate__
1643+
if (t.size() != 36)
1644+
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");
1645+
1646+
/* Create a new C++ instance */
1647+
Control ctrl;
1648+
1649+
ctrl.parallel = t[0].cast<std::string>();
1650+
ctrl.procedure = t[1].cast<std::string>();
1651+
ctrl.display = t[2].cast<std::string>();
1652+
ctrl.xTolerance = t[3].cast<real_T>();
1653+
ctrl.funcTolerance = t[4].cast<real_T>();
1654+
ctrl.maxFuncEvals = t[5].cast<real_T>();
1655+
ctrl.maxIterations = t[6].cast<real_T>();
1656+
ctrl.populationSize = t[7].cast<real_T>();
1657+
ctrl.fWeight = t[8].cast<real_T>();
1658+
ctrl.crossoverProbability = t[9].cast<real_T>();
1659+
ctrl.targetValue = t[10].cast<real_T>();
1660+
ctrl.numGenerations = t[11].cast<real_T>();
1661+
ctrl.strategy = t[12].cast<real_T>();
1662+
ctrl.nLive = t[13].cast<real_T>();
1663+
ctrl.nMCMC = t[14].cast<real_T>();
1664+
ctrl.propScale = t[15].cast<real_T>();
1665+
ctrl.nsTolerance = t[16].cast<real_T>();
1666+
ctrl.calcSldDuringFit = t[17].cast<boolean_T>();
1667+
ctrl.resampleParams = t[18].cast<py::array_t<real_T>>();
1668+
ctrl.updateFreq = t[19].cast<real_T>();
1669+
ctrl.updatePlotFreq = t[20].cast<real_T>();
1670+
ctrl.nSamples = t[21].cast<real_T>();
1671+
ctrl.nChains = t[22].cast<real_T>();
1672+
ctrl.jumpProbability = t[23].cast<real_T>();
1673+
ctrl.pUnitGamma = t[24].cast<real_T>();
1674+
ctrl.boundHandling = t[25].cast<std::string>();
1675+
ctrl.adaptPCR = t[26].cast<boolean_T>();
1676+
ctrl.IPCFilePath = t[27].cast<std::string>();
1677+
1678+
ctrl.checks.fitParam = t[28].cast<py::array_t<real_T>>();
1679+
ctrl.checks.fitBackgroundParam = t[29].cast<py::array_t<real_T>>();
1680+
ctrl.checks.fitQzshift = t[30].cast<py::array_t<real_T>>();
1681+
ctrl.checks.fitScalefactor = t[31].cast<py::array_t<real_T>>();
1682+
ctrl.checks.fitBulkIn = t[32].cast<py::array_t<real_T>>();
1683+
ctrl.checks.fitBulkOut = t[33].cast<py::array_t<real_T>>();
1684+
ctrl.checks.fitResolutionParam = t[34].cast<py::array_t<real_T>>();
1685+
ctrl.checks.fitDomainRatio = t[35].cast<py::array_t<real_T>>();
1686+
1687+
return ctrl;
1688+
}));
1689+
14781690
py::class_<ProblemDefinition>(m, "ProblemDefinition")
14791691
.def(py::init<>())
14801692
.def_readwrite("contrastBackgroundParams", &ProblemDefinition::contrastBackgroundParams)
@@ -1507,7 +1719,58 @@ PYBIND11_MODULE(rat_core, m) {
15071719
.def_readwrite("fitParams", &ProblemDefinition::fitParams)
15081720
.def_readwrite("otherParams", &ProblemDefinition::otherParams)
15091721
.def_readwrite("fitLimits", &ProblemDefinition::fitLimits)
1510-
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits);
1722+
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits)
1723+
.def(py::pickle(
1724+
[](const ProblemDefinition &p) { // __getstate__
1725+
/* Return a tuple that fully encodes the state of the object */
1726+
return py::make_tuple(p.contrastBackgroundParams, p.contrastBackgroundActions, p.TF, p.resample, p.dataPresent, p.oilChiDataPresent,
1727+
p.numberOfContrasts, p.geometry, p.useImaginary, p.contrastQzshifts, p.contrastScalefactors,
1728+
p.contrastBulkIns, p.contrastBulkOuts, p.contrastResolutionParams, p.backgroundParams,
1729+
p.qzshifts, p.scalefactors, p.bulkIn, p.bulkOut, p.resolutionParams, p.params,
1730+
p.numberOfLayers, p.modelType, p.contrastCustomFiles, p.contrastDomainRatios,
1731+
p.domainRatio, p.numberOfDomainContrasts, p.fitParams, p.otherParams, p.fitLimits, p.otherLimits);
1732+
},
1733+
[](py::tuple t) { // __setstate__
1734+
if (t.size() != 31)
1735+
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");
1736+
1737+
/* Create a new C++ instance */
1738+
ProblemDefinition p;
1739+
1740+
p.contrastBackgroundParams = t[0].cast<py::array_t<real_T>>();
1741+
p.contrastBackgroundActions = t[1].cast<py::array_t<real_T>>();
1742+
p.TF = t[2].cast<std::string>();
1743+
p.resample = t[3].cast<py::array_t<real_T>>();
1744+
p.dataPresent = t[4].cast<py::array_t<real_T>>();
1745+
p.oilChiDataPresent = t[5].cast<py::array_t<real_T>>();
1746+
p.numberOfContrasts = t[6].cast<real_T>();
1747+
p.geometry = t[7].cast<std::string>();
1748+
p.useImaginary = t[8].cast<bool>();
1749+
p.contrastQzshifts = t[9].cast<py::array_t<real_T>>();
1750+
p.contrastScalefactors = t[10].cast<py::array_t<real_T>>();
1751+
p.contrastBulkIns = t[11].cast<py::array_t<real_T>>();
1752+
p.contrastBulkOuts = t[12].cast<py::array_t<real_T>>();
1753+
p.contrastResolutionParams = t[13].cast<py::array_t<real_T>>();
1754+
p.backgroundParams = t[14].cast<py::array_t<real_T>>();
1755+
p.qzshifts = t[15].cast<py::array_t<real_T>>();
1756+
p.scalefactors = t[16].cast<py::array_t<real_T>>();
1757+
p.bulkIn= t[17].cast<py::array_t<real_T>>();
1758+
p.bulkOut= t[18].cast<py::array_t<real_T>>();
1759+
p.resolutionParams= t[19].cast<py::array_t<real_T>>();
1760+
p.params = t[20].cast<py::array_t<real_T>>(),
1761+
p.numberOfLayers = t[21].cast<real_T>();
1762+
p.modelType = t[22].cast<std::string>();
1763+
p.contrastCustomFiles = t[23].cast<py::array_t<real_T>>();
1764+
p.contrastDomainRatios = t[24].cast<py::array_t<real_T>>(),
1765+
p.domainRatio = t[25].cast<py::array_t<real_T>>();
1766+
p.numberOfDomainContrasts = t[26].cast<real_T>();
1767+
p.fitParams = t[27].cast<py::array_t<real_T>>();
1768+
p.otherParams = t[28].cast<py::array_t<real_T>>();
1769+
p.fitLimits = t[29].cast<py::array_t<real_T>>();
1770+
p.otherLimits = t[30].cast<py::array_t<real_T>>();
1771+
1772+
return p;
1773+
}));
15111774

15121775
m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");
15131776

tests/test_events.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import pickle
12
from unittest import mock
23

4+
import numpy as np
35
import pytest
46

57
import RATapi.events
@@ -64,3 +66,36 @@ def test_event_notify() -> None:
6466
assert first_callback.call_count == 1
6567
assert second_callback.call_count == 1
6668
assert third_callback.call_count == 1
69+
70+
71+
def test_event_data_pickle():
72+
data = RATapi.events.ProgressEventData()
73+
data.message = "Hello"
74+
data.percent = 0.5
75+
pickled_data = pickle.loads(pickle.dumps(data))
76+
assert pickled_data.message == data.message
77+
assert pickled_data.percent == data.percent
78+
79+
data = RATapi.events.PlotEventData()
80+
data.modelType = "custom layers"
81+
data.dataPresent = np.ones(2)
82+
data.subRoughs = np.ones((20, 2))
83+
data.resample = np.ones(2)
84+
data.resampledLayers = [np.ones((20, 2)), np.ones((20, 2))]
85+
data.reflectivity = [np.ones((20, 2)), np.ones((20, 2))]
86+
data.shiftedData = [np.ones((20, 2)), np.ones((20, 2))]
87+
data.sldProfiles = [np.ones((20, 2)), np.ones((20, 2))]
88+
data.contrastNames = ["D2O", "SMW"]
89+
90+
pickled_data = pickle.loads(pickle.dumps(data))
91+
92+
assert pickled_data.modelType == data.modelType
93+
assert (pickled_data.dataPresent == data.dataPresent).all()
94+
assert (pickled_data.subRoughs == data.subRoughs).all()
95+
assert (pickled_data.resample == data.resample).all()
96+
for i in range(2):
97+
assert (pickled_data.resampledLayers[i] == data.resampledLayers[i]).all()
98+
assert (pickled_data.reflectivity[i] == data.reflectivity[i]).all()
99+
assert (pickled_data.shiftedData[i] == data.shiftedData[i]).all()
100+
assert (pickled_data.sldProfiles[i] == data.sldProfiles[i]).all()
101+
assert pickled_data.contrastNames == data.contrastNames

0 commit comments

Comments
 (0)