Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions core/pioreactor/calibrations/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import typing as t
from collections.abc import Callable

from pioreactor import structs
from pioreactor.utils.timing import current_utc_datestamp
from pioreactor.utils.timing import current_utc_datetime


def pool_od_calibrations(
calibrations: list[structs.ODCalibration | structs.OD600Calibration],
fit: t.Literal["spline", "poly", "akima"] = "spline",
) -> structs.OD600Calibration | structs.ODCalibration:
"""
Merge raw recorded_data from multiple OD calibrations and refit a new curve.
"""
if not calibrations:
raise ValueError("No calibrations provided for pooling.")

if len(calibrations) == 1:
# Just copy and rename
cal = calibrations[0]
base_class = type(cal)
new_name = f"pooled-od{cal.angle}-from-1-unit-{current_utc_datestamp()}"

kwargs = {f: getattr(cal, f) for f in cal.__struct_fields__}
kwargs["calibration_name"] = new_name
kwargs["calibrated_on_pioreactor_unit"] = "$cluster"
kwargs["created_at"] = current_utc_datetime()
return base_class(**kwargs)

# Validation: must share angle and pd_channel
first_cal = calibrations[0]
angle = first_cal.angle
pd_channel = first_cal.pd_channel
ir_led_intensity = first_cal.ir_led_intensity

# Check compatibility
for cal in calibrations[1:]:
if cal.angle != angle:
raise ValueError(f"Incompatible angles: {angle} != {cal.angle}")
if cal.pd_channel != pd_channel:
raise ValueError(f"Incompatible pd_channels: {pd_channel} != {cal.pd_channel}")

# ir_led_intensity must be within 5%
if ir_led_intensity == 0:
if cal.ir_led_intensity != 0:
raise ValueError("Incompatible ir_led_intensity: 0 vs non-zero")
elif abs(cal.ir_led_intensity - ir_led_intensity) / ir_led_intensity > 0.05:
raise ValueError(
f"Incompatible ir_led_intensity: {ir_led_intensity} and {cal.ir_led_intensity} differ by > 5%"
)

# Merging
merged_x: list[float] = []
merged_y: list[float] = []
weights: list[float] = []

for cal in calibrations:
x_data = cal.recorded_data["x"]
y_data = cal.recorded_data["y"]
count = len(x_data)
if count == 0:
continue

merged_x.extend(x_data)
merged_y.extend(y_data)

# Equal weight for each point
weights.extend([1.0] * count)

if not merged_x:
raise ValueError("No recorded data found in any provided calibrations.")

# Refit
if fit == "poly":
from pioreactor.calibrations.utils import calculate_poly_curve_of_best_fit

curve_data = calculate_poly_curve_of_best_fit(merged_x, merged_y, degree=2, weights=weights)
elif fit == "spline":
from pioreactor.utils.splines import spline_fit

knots_count = min(4, len(set(merged_x)))
curve_data = spline_fit(merged_x, merged_y, knots=max(2, knots_count), weights=weights) # type: ignore
elif fit == "akima":
from pioreactor.utils.akimas import akima_fit

curve_data = akima_fit(merged_x, merged_y) # type: ignore
else:
raise ValueError(f"Unsupported fit type: {fit}")

new_name = f"pooled-od{angle}-{current_utc_datestamp()}"

kwargs = {f: getattr(first_cal, f) for f in first_cal.__struct_fields__}
kwargs["calibration_name"] = new_name
kwargs["calibrated_on_pioreactor_unit"] = "$cluster"
kwargs["created_at"] = current_utc_datetime()
kwargs["curve_data_"] = curve_data
kwargs["recorded_data"] = {"x": merged_x, "y": merged_y}

base_class = type(first_cal)
return base_class(**kwargs)


_POOLING_HANDLERS: dict[str, Callable] = {
"od": pool_od_calibrations,
"od600": pool_od_calibrations,
}
127 changes: 126 additions & 1 deletion core/pioreactor/web/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,11 +1753,136 @@ def get_all_active_estimators(pioreactor_unit: str) -> DelayedResponseReturnValu
def get_all_estimators(pioreactor_unit: str) -> DelayedResponseReturnValue:
if pioreactor_unit == UNIVERSAL_IDENTIFIER:
task = cache.cached_multicast_get(cache.ESTIMATORS, get_all_workers())
return create_task_response(task)


@api_bp.route("/cluster/calibrations/<device>/pool", methods=["POST"])
def pool_calibrations(device: str) -> ResponseReturnValue:
payload = request.get_json(silent=True) or {}
donor_units = payload.get("donor_units")

if donor_units:
task = tasks.multicast_get(f"/unit_api/calibrations/{device}/active", donor_units, return_raw=True)
else:
task = cache.cached_multicast_get(cache.ESTIMATORS, [pioreactor_unit])
task = fanout.broadcast_get_across_workers(f"/unit_api/calibrations/{device}/active", return_raw=True)

try:
results = task.get(blocking=True, timeout=15)
except (HueyException, TaskException):
abort_with(500, "Timed out fetching active calibrations from workers")

donors = []
skipped = []
calibrations = []

from pioreactor.structs import AllCalibrations
from pioreactor.utils import yaml_decode, yaml_encode

for worker, result in results.items():
if result is None:
skipped.append(worker)
continue
try:
cal = yaml_decode(result, type=AllCalibrations)
calibrations.append(cal)
donors.append(worker)
except Exception:
skipped.append(worker)

from pioreactor.calibrations.pooling import _POOLING_HANDLERS

if device.startswith("od"):
handler = _POOLING_HANDLERS.get("od")
else:
handler = _POOLING_HANDLERS.get(device)

if not handler:
abort_with(400, f"Pooling not supported for device {device}")

try:
pooled = handler(calibrations)
except Exception as e:
abort_with(400, f"Failed to pool calibrations: {e}")

return jsonify({
"calibration_data": yaml_encode(pooled).decode("utf-8"),
"calibration_name": pooled.calibration_name,
"donors": donors,
"skipped": skipped
})


@api_bp.route("/workers/<pioreactor_unit>/calibrations/<device>/apply", methods=["POST"])
def apply_calibration(pioreactor_unit: str, device: str) -> DelayedResponseReturnValue:
payload = request.get_json(silent=True) or {}
if "calibration_data" not in payload:
abort_with(400, "Missing calibration_data in payload")

payload["set_as_active"] = True

if pioreactor_unit == UNIVERSAL_IDENTIFIER:
task = tasks.multicast_post(f"/unit_api/calibrations/{device}", get_all_workers(), json=payload)
else:
task = tasks.multicast_post(f"/unit_api/calibrations/{device}", [pioreactor_unit], json=payload)

# Invalidate cache
if pioreactor_unit == UNIVERSAL_IDENTIFIER:
cache.cache.delete_memoized(cache.get_all_calibrations)
else:
cache.cache.delete_memoized(cache.get_all_calibrations, pioreactor_unit)

return create_task_response(task)


@api_bp.route("/workers/<source_unit>/calibrations/<device>/copy_to/<target_unit>", methods=["POST"])
def copy_calibration(source_unit: str, device: str, target_unit: str) -> DelayedResponseReturnValue:
payload = request.get_json(silent=True) or {}
calibration_name = payload.get("calibration_name")

if calibration_name:
task = tasks.multicast_get(f"/unit_api/calibrations/{device}/{calibration_name}", [source_unit], return_raw=True)
else:
task = tasks.multicast_get(f"/unit_api/calibrations/{device}/active", [source_unit], return_raw=True)

try:
result = task.get(blocking=True, timeout=10)
except (HueyException, TaskException):
abort_with(500, "Timed out fetching calibration from source unit")

source_result = result.get(source_unit)
if source_result is None:
abort_with(404, "Source unit did not return a calibration")

from pioreactor.structs import AllCalibrations
from pioreactor.utils import yaml_decode, yaml_encode
from pioreactor.utils.timing import current_utc_datestamp, current_utc_datetime

try:
cal = yaml_decode(source_result, type=AllCalibrations)
except Exception as e:
abort_with(500, f"Failed to decode calibration from source unit: {e}")

kwargs = {f: getattr(cal, f) for f in cal.__struct_fields__}
kwargs["calibration_name"] = f"copy-from-{source_unit}-{current_utc_datestamp()}"
kwargs["calibrated_on_pioreactor_unit"] = "$cluster"
kwargs["created_at"] = current_utc_datetime()
new_cal = type(cal)(**kwargs) # type: ignore

apply_payload = {
"calibration_data": yaml_encode(new_cal).decode("utf-8"),
"set_as_active": True
}

if target_unit == UNIVERSAL_IDENTIFIER:
post_task = tasks.multicast_post(f"/unit_api/calibrations/{device}", get_all_workers(), json=apply_payload)
cache.cache.delete_memoized(cache.get_all_calibrations)
else:
post_task = tasks.multicast_post(f"/unit_api/calibrations/{device}", [target_unit], json=apply_payload)
cache.cache.delete_memoized(cache.get_all_calibrations, target_unit)

return create_task_response(post_task)


@api_bp.route("/workers/<pioreactor_unit>/zipped_calibrations", methods=["GET"])
def get_zipped_calibrations(pioreactor_unit: str) -> ResponseReturnValue:
if pioreactor_unit == UNIVERSAL_IDENTIFIER:
Expand Down
40 changes: 38 additions & 2 deletions core/pioreactor/web/unit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,12 +1314,15 @@ def create_calibration(device: str) -> ResponseReturnValue:
remediation="Check file permissions and server logs.",
)

activated = False
if set_as_active:
with local_persistent_storage("active_calibrations") as c:
c[device] = calibration_name
if device not in c:
c[device] = calibration_name
activated = True

# Respond with success and the created calibration details
response = jsonify({"msg": "Calibration created successfully.", "path": str(path)})
response = jsonify({"msg": "Calibration created successfully.", "path": str(path), "activated": activated})
response.status_code = 201
return response

Expand Down Expand Up @@ -1372,6 +1375,39 @@ def delete_calibration(device: str, calibration_name: str) -> ResponseReturnValu
)


@unit_api_bp.route("/calibrations/<device>/active", methods=["GET"])
def get_active_calibration(device: str) -> ResponseReturnValue:
with local_persistent_storage("active_calibrations") as c:
if device not in c:
abort_with(404, description=f"No active calibration for {device}.")
calibration_name = str(c[device])

calibration_path = CALIBRATION_PATH / device / f"{calibration_name}.yaml"
if not calibration_path.exists():
abort_with(404, description=f"Active calibration file for {device} missing.")

try:
raw_yaml = calibration_path.read_text()
return attach_cache_control(Response(response=raw_yaml, status=200, mimetype="application/yaml"))
except Exception as e:
publish_to_error_log(f"Error reading active calibration: {e}", "get_active_calibration")
abort_with(500, description="Failed to read active calibration.")


@unit_api_bp.route("/calibrations/<device>/<calibration_name>", methods=["GET"])
def get_calibration(device: str, calibration_name: str) -> ResponseReturnValue:
calibration_path = CALIBRATION_PATH / device / f"{calibration_name}.yaml"
if not calibration_path.exists():
abort_with(404, description=f"Calibration file for {device} missing.")

try:
raw_yaml = calibration_path.read_text()
return attach_cache_control(Response(response=raw_yaml, status=200, mimetype="application/yaml"))
except Exception as e:
publish_to_error_log(f"Error reading calibration: {e}", "get_calibration")
abort_with(500, description="Failed to read calibration.")


@unit_api_bp.route("/calibrations", methods=["GET"])
def get_all_calibrations() -> ResponseReturnValue:
calibration_dir = CALIBRATION_PATH
Expand Down
Loading