Skip to content

Commit fcb9bb9

Browse files
author
miranov25
committed
add FormulaLinearModel.py used for the dEdx and distortion calibration
1 parent b8e241e commit fcb9bb9

File tree

1 file changed

+165
-0
lines changed

1 file changed

+165
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
2+
""" FormulaLinearModel.py
3+
import sys,os; sys.path.insert(1, os.environ[f"O2DPG"]+"/UTILS/dfextensions");
4+
from FormulaLinearModel import *
5+
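Formula-based linear regression model (FormulaLinearModel) with export of the fitted expression to C++, pandas and JavaScript code, plus helpers generating element-wise C++ wrappers (std::vector and Arrow).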
6+
"""
7+
8+
import ast
9+
import numpy as np
10+
from sklearn.linear_model import LinearRegression
11+
12+
import ast
13+
import numpy as np
14+
from sklearn.linear_model import LinearRegression
15+
16+
class FormulaLinearModel:
    """Linear regression on formula-defined features, with source-code export.

    Every feature is a string expression evaluated with ``DataFrame.eval``.
    After fitting, the linear combination of those expressions can be rendered
    as a C++ function, a pandas expression string or a JavaScript function.
    """

    def __init__(self, name, formulas, target, precision=4, weight_formula=None, var_list=None):
        """
        Formula-based linear regression model supporting code export.

        :param name: name of the model (used for function naming)
        :param formulas: dict of {name: formula_string}, e.g., {'x1': 'v0*var00', 'x2': 'w1*var10'}
        :param target: string expression for target variable, e.g., 'log(y)' or 'y'
        :param precision: number of significant digits in code export (default: 4)
        :param weight_formula: optional string formula for sample weights
        :param var_list: optional list of variable names to fix the argument order for C++/JS export
        :raises ValueError: if var_list is given but lacks a variable used in the formulas

        Example usage:

        >>> formulas = {'x1': 'v0*var00', 'x2': 'w1*var10'}
        >>> model = FormulaLinearModel("myModel", formulas, target='y')
        >>> model.fit(df)
        >>> df['y_pred'] = model.predict(df)
        >>> print(model.to_cpp())
        >>> print(model.to_pandas())
        >>> print(model.to_javascript())
        """
        self.name = name
        self.formulas = formulas
        self.target = target
        self.precision = precision
        self.weight_formula = weight_formula
        self.model = LinearRegression()
        self.feature_names = list(formulas.keys())

        # Only the feature formulas define the arguments of the exported
        # functions; target and weight expressions are not part of them.
        extracted_vars = self._extract_variables(from_formulas_only=True)
        if var_list:
            missing = set(extracted_vars) - set(var_list)
            if missing:
                raise ValueError(f"Provided var_list is missing variables: {missing}")
            self.variables = var_list
        else:
            # Sorted so that the exported argument order is deterministic.
            self.variables = sorted(extracted_vars)

    def _extract_variables(self, debug=False, from_formulas_only=False):
        """Return the set of variable names referenced by the model expressions.

        The expressions are parsed with ``ast``; every name that is used as a
        directly called function (e.g. ``log`` in ``log(y)``) is excluded.

        :param debug: if True, print the detected variable and function names
        :param from_formulas_only: if True, inspect only the feature formulas;
            otherwise also the weight formula (when set) and the target (when
            it is a string)
        :return: set of variable names
        """
        class VarExtractor(ast.NodeVisitor):
            def __init__(self):
                self.vars = set()
                self.funcs = set()

            def visit_Name(self, node):
                self.vars.add(node.id)

            def visit_Call(self, node):
                if isinstance(node.func, ast.Name):
                    self.funcs.add(node.func.id)
                # Keep descending so names used in the arguments are collected.
                self.generic_visit(node)

        all_exprs = list(self.formulas.values())
        if not from_formulas_only:
            if self.weight_formula:
                all_exprs.append(self.weight_formula)
            if isinstance(self.target, str):
                all_exprs.append(self.target)

        extractor = VarExtractor()
        for expr in all_exprs:
            extractor.visit(ast.parse(expr, mode='eval'))

        if debug:
            print("Detected variables:", extractor.vars)
            print("Detected functions:", extractor.funcs)

        # visit_Name also records called function names; remove them here.
        return extractor.vars - extractor.funcs

    def _design_matrix(self, df):
        """Evaluate every feature formula on ``df`` and stack the results as columns."""
        return np.column_stack([df.eval(expr) for expr in self.formulas.values()])

    def _linear_expression(self):
        """Render the fitted model as ``(c1)*(f1) + ... + (intercept)``.

        Coefficients and intercept are formatted with ``self.precision``
        significant digits (``g`` format).
        """
        fmt = f"{{0:.{self.precision}g}}"
        coefs, intercept = self.coef_dict()
        terms = [f"({fmt.format(coef)})*({self.formulas[name]})" for name, coef in coefs.items()]
        return " + ".join(terms) + f" + ({fmt.format(intercept)})"

    def fit(self, df):
        """Fit the underlying regression on the rows of ``df``.

        Features, target and (optionally) sample weights are obtained by
        evaluating the configured expressions on ``df``. Unlike ``predict``,
        no NaN filtering is applied here.
        NOTE(review): the caller is presumably expected to pass NaN-free data - confirm.

        :param df: DataFrame providing the columns used by the expressions
        """
        X = self._design_matrix(df)
        y = df.eval(self.target) if isinstance(self.target, str) else df[self.target]
        if self.weight_formula:
            sample_weight = df.eval(self.weight_formula).values
            self.model.fit(X, y, sample_weight=sample_weight)
        else:
            self.model.fit(X, y)

    def predict(self, df):
        """Return model predictions for ``df`` as a float array of ``len(df)``.

        Rows for which any feature evaluates to NaN get NaN as prediction.
        When no row is valid (all rows contain NaN, or ``df`` is empty) the
        estimator is not called at all, so it never receives a zero-row matrix.

        :param df: DataFrame providing the columns used by the feature formulas
        :return: numpy array of predictions, NaN where features are invalid
        """
        X = self._design_matrix(df)
        mask_valid = ~np.isnan(X).any(axis=1)
        y_pred = np.full(len(df), np.nan)
        if mask_valid.any():
            y_pred[mask_valid] = self.model.predict(X[mask_valid])
        return y_pred

    def coef_dict(self):
        """Return ``({feature_name: coefficient}, intercept)`` of the fitted model."""
        return dict(zip(self.feature_names, self.model.coef_)), self.model.intercept_

    def to_cpp(self):
        """Return the fitted model as the source of a C++ function taking ``float`` arguments."""
        expr = self._linear_expression()
        args = ", ".join([f"float {var}" for var in self.variables])
        return f"float {self.name}({args}) {{ return {expr}; }}"

    def to_pandas(self):
        """Return the fitted model as an expression string in terms of the feature formulas."""
        return self._linear_expression()

    def to_javascript(self):
        """Return the fitted model as the source of a JavaScript function."""
        expr = self._linear_expression()
        args = ", ".join(self.variables)
        return f"function {self.name}({args}) {{ return {expr}; }}"
129+
130+
def to_cppstd(name, variables, expression, precision=6):
    """Generate C++ source applying ``expression`` element-wise to std::vector inputs.

    The generated function takes the element count ``n`` followed by one
    ``const std::vector<float>&`` per variable and returns a
    ``std::vector<float>`` holding the expression evaluated at every index.

    Variable names are substituted as whole identifiers, so a variable ``x``
    does not alter ``exp`` or a longer variable such as ``x1``.

    :param name: name of the generated C++ function
    :param variables: iterable of variable names appearing in ``expression``
    :param expression: scalar expression text written in terms of ``variables``
    :param precision: unused; kept for interface compatibility
    :return: the C++ source code as a string
    """
    import re  # function-scope import: the module header does not import re

    variables = list(variables)
    known = set(variables)

    def _per_element(match):
        # Rename only identifiers that are declared variables.
        ident = match.group(0)
        return f"{ident}_i" if ident in known else ident

    # Building the parameter list by join avoids a dangling comma when there
    # are no variables.
    params = ["size_t n"] + [f"const std::vector<float>& {v}" for v in variables]
    output = [f"std::vector<float> {name}({', '.join(params)}) {{"]
    output.append(f" std::vector<float> result(n);")
    output.append(f" for (size_t i = 0; i < n; ++i) {{")
    for v in variables:
        output.append(f" float {v}_i = {v}[i];")
    # Single pass over whole identifiers; the leading \b keeps the exponent of
    # numeric literals such as 1e5 from being treated as an identifier.
    expr_cpp = re.sub(r"\b[A-Za-z_]\w*", _per_element, expression)
    output.append(f" result[i] = {expr_cpp};")
    output.append(" }")
    output.append(" return result;")
    output.append("}")
    return "\n".join(output)
145+
146+
147+
def to_cpparrow(name, variables, expression, precision=6):
    """Generate C++ source applying ``expression`` element-wise to Arrow float arrays.

    The generated function takes the element count ``n``, one
    ``const arrow::FloatArray&`` per variable and an ``arrow::MemoryPool*``,
    and returns a ``std::shared_ptr<arrow::FloatArray>`` built from the
    expression evaluated at every index.

    Variable names are substituted as whole identifiers, so a variable ``x``
    does not alter ``exp`` or a longer variable such as ``x1``.

    NOTE(review): the generated code does not check the results of
    ``builder.Reserve`` / ``builder.Finish``.

    :param name: name of the generated C++ function
    :param variables: iterable of variable names appearing in ``expression``
    :param expression: scalar expression text written in terms of ``variables``
    :param precision: unused; kept for interface compatibility
    :return: the C++ source code as a string
    """
    import re  # function-scope import: the module header does not import re

    variables = list(variables)
    known = set(variables)

    def _per_element(match):
        # Rename only identifiers that are declared variables.
        ident = match.group(0)
        return f"{ident}_i" if ident in known else ident

    # Building the parameter list by join avoids an empty slot between the
    # commas when there are no variables.
    params = ["int64_t n"]
    params.extend(f"const arrow::FloatArray& {v}" for v in variables)
    params.append("arrow::MemoryPool* pool")
    output = [f"std::shared_ptr<arrow::FloatArray> {name}({', '.join(params)}) {{"]
    output.append(f" arrow::FloatBuilder builder(pool);")
    output.append(f" builder.Reserve(n);")
    output.append(f" for (int64_t i = 0; i < n; ++i) {{")
    for v in variables:
        output.append(f" float {v}_i = {v}.Value(i);")
    # Single pass over whole identifiers; the leading \b keeps the exponent of
    # numeric literals such as 1e5 from being treated as an identifier.
    expr_cpp = re.sub(r"\b[A-Za-z_]\w*", _per_element, expression)
    output.append(f" builder.UnsafeAppend({expr_cpp});")
    output.append(" }")
    output.append(" std::shared_ptr<arrow::FloatArray> result;")
    output.append(" builder.Finish(&result);")
    output.append(" return result;")
    output.append("}")
    return "\n".join(output)
164+
165+

0 commit comments

Comments
 (0)