Skip to content

Commit b5f2664

Browse files
authored
Refactor TaskError to hold Control and NonControl (#85)
* Refactor TaskError to hold Control and NonControl. This will allow us to reduce error type duplication in tensorzero/tensorzero by removing the NonControlToolError type.
* Run fmt
1 parent 03f2c5f commit b5f2664

5 files changed

Lines changed: 142 additions & 93 deletions

File tree

src/context.rs

Lines changed: 39 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ use uuid::Uuid;
77

88
use crate::Durable;
99
use crate::error::suspend_handle::SuspendMarker;
10-
use crate::error::{ControlFlow, TaskError, TaskResult};
10+
use crate::error::{ControlFlow, NonControlTaskError, TaskError, TaskResult};
1111
use std::sync::Arc;
1212

1313
use crate::heartbeat::{HeartbeatHandle, Heartbeater, StepState};
@@ -93,9 +93,9 @@ where
9393
/// Validate that a user-provided step name doesn't use reserved prefix.
9494
fn validate_user_name(name: &str) -> TaskResult<()> {
9595
if name.starts_with('$') {
96-
return Err(TaskError::Validation {
96+
return Err(TaskError::NonControl(NonControlTaskError::Validation {
9797
message: "Step names cannot start with '$' (reserved for internal use)".to_string(),
98-
});
98+
}));
9999
}
100100
Ok(())
101101
}
@@ -106,9 +106,9 @@ where
106106
{
107107
pub(crate) fn mark_suspended(&mut self) -> TaskResult<()> {
108108
if self.has_suspended {
109-
return Err(TaskError::Validation {
109+
return Err(TaskError::NonControl(NonControlTaskError::Validation {
110110
message: "Task has already been suspended during this execution".to_string(),
111-
});
111+
}));
112112
}
113113
self.has_suspended = true;
114114
Ok(())
@@ -242,9 +242,11 @@ where
242242
state: self.durable.state().clone(),
243243
heartbeater: Arc::new(self.heartbeat_handle.clone()),
244244
};
245-
let result = f(params, step_state).await.map_err(|e| TaskError::Step {
246-
base_name: base_name.to_string(),
247-
error: e,
245+
let result = f(params, step_state).await.map_err(|e| {
246+
TaskError::NonControl(NonControlTaskError::Step {
247+
base_name: base_name.to_string(),
248+
error: e,
249+
})
248250
})?;
249251

250252
// Persist checkpoint (also extends claim lease)
@@ -423,9 +425,9 @@ where
423425
// Check if we were woken by this event but it timed out (null payload)
424426
if self.task.wake_event.as_deref() == Some(event_name) && self.task.event_payload.is_none()
425427
{
426-
return Err(TaskError::Timeout {
428+
return Err(TaskError::NonControl(NonControlTaskError::Timeout {
427429
step_name: event_name.to_string(),
428-
});
430+
}));
429431
}
430432

431433
// Call await_event stored procedure
@@ -505,9 +507,11 @@ where
505507
self.durable
506508
.emit_event(event_name, payload, None)
507509
.await
508-
.map_err(|e| TaskError::EmitEventFailed {
509-
event_name: event_name.to_string(),
510-
error: e,
510+
.map_err(|e| {
511+
TaskError::NonControl(NonControlTaskError::EmitEventFailed {
512+
event_name: event_name.to_string(),
513+
error: e,
514+
})
511515
})
512516
}
513517

@@ -566,8 +570,10 @@ where
566570
if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) {
567571
let stored: String = serde_json::from_value(cached.clone())?;
568572
return Ok(DateTime::parse_from_rfc3339(&stored)
569-
.map_err(|e| TaskError::Validation {
570-
message: format!("Invalid stored time: {e}"),
573+
.map_err(|e| {
574+
TaskError::NonControl(NonControlTaskError::Validation {
575+
message: format!("Invalid stored time: {e}"),
576+
})
571577
})?
572578
.with_timezone(&Utc));
573579
}
@@ -716,9 +722,11 @@ where
716722
},
717723
)
718724
.await
719-
.map_err(|e| TaskError::SubtaskSpawnFailed {
720-
name: task_name.to_string(),
721-
error: e,
725+
.map_err(|e| {
726+
TaskError::NonControl(NonControlTaskError::SubtaskSpawnFailed {
727+
name: task_name.to_string(),
728+
error: e,
729+
})
722730
})?;
723731
// Checkpoint the spawn
724732
self.persist_checkpoint(&checkpoint_name, &spawned_task.task_id)
@@ -781,9 +789,9 @@ where
781789
// Check if we were woken by this event but it timed out (null payload)
782790
if self.task.wake_event.as_deref() == Some(&event_name) && self.task.event_payload.is_none()
783791
{
784-
return Err(TaskError::Timeout {
792+
return Err(TaskError::NonControl(NonControlTaskError::Timeout {
785793
step_name: step_name.to_string(),
786-
});
794+
}));
787795
}
788796

789797
// Call await_event stored procedure (no timeout for join - we wait indefinitely)
@@ -829,8 +837,10 @@ where
829837
) -> TaskResult<T> {
830838
match payload.status {
831839
ChildStatus::Completed => {
832-
let result = payload.result.ok_or_else(|| TaskError::Validation {
833-
message: "Child completed but no result available".to_string(),
840+
let result = payload.result.ok_or_else(|| {
841+
TaskError::NonControl(NonControlTaskError::Validation {
842+
message: "Child completed but no result available".to_string(),
843+
})
834844
})?;
835845
Ok(serde_json::from_value(result)?)
836846
}
@@ -839,14 +849,16 @@ where
839849
.error
840850
.and_then(|e| e.get("message").and_then(|m| m.as_str()).map(String::from))
841851
.unwrap_or_else(|| "Unknown error".to_string());
842-
Err(TaskError::ChildFailed {
852+
Err(TaskError::NonControl(NonControlTaskError::ChildFailed {
843853
step_name: step_name.to_string(),
844854
message,
845-
})
855+
}))
856+
}
857+
ChildStatus::Cancelled => {
858+
Err(TaskError::NonControl(NonControlTaskError::ChildCancelled {
859+
step_name: step_name.to_string(),
860+
}))
846861
}
847-
ChildStatus::Cancelled => Err(TaskError::ChildCancelled {
848-
step_name: step_name.to_string(),
849-
}),
850862
}
851863
}
852864
}

0 commit comments

Comments
 (0)