Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
actor_events: RwLock<Option<mpsc::UnboundedSender<ActorEvent>>>,
pub(super) lifecycle_events: RwLock<Option<mpsc::Sender<LifecycleEvent>>>,
hibernated_connection_liveness_override: RwLock<Option<BTreeSet<(Vec<u8>, Vec<u8>)>>>,
pub(super) lifecycle_event_inbox_capacity: usize,

Check warning on line 146 in rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

View workflow job for this annotation

GitHub Actions / Build rivetkit-wasm

field `lifecycle_event_inbox_capacity` is never read
pub(super) metrics: ActorMetrics,
diagnostics: ActorDiagnostics,
actor_id: String,
Expand Down Expand Up @@ -619,6 +619,12 @@
future.await
}

pub fn keep_awake_region(&self) -> KeepAwakeRegion {
KeepAwakeRegion {
guard: Some(self.keep_awake_guard()),
}
}

pub async fn internal_keep_awake<F>(&self, future: F) -> F::Output
where
F: Future,
Expand Down Expand Up @@ -1313,7 +1319,7 @@
self.reset_sleep_timer();
}

pub(crate) fn sleep_config(&self) -> ActorConfig {

Check warning on line 1322 in rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

View workflow job for this annotation

GitHub Actions / Build rivetkit-wasm

method `sleep_config` is never used
self.sleep_state_config()
}

Expand All @@ -1327,7 +1333,7 @@

fn keep_awake_guard(&self) -> KeepAwakeGuard {
let region = self
.keep_awake_region()
.keep_awake_region_state()
.with_log_fields("keep_awake", Some(self.actor_id().to_owned()));
let guard = KeepAwakeGuard::new(self.clone(), region);
self.reset_sleep_timer();
Expand Down Expand Up @@ -1639,6 +1645,10 @@
guard: Option<WebSocketCallbackGuard>,
}

pub struct KeepAwakeRegion {
guard: Option<KeepAwakeGuard>,
}

impl WebSocketCallbackGuard {
fn new(ctx: ActorContext, kind: UserTaskKind, region: RegionGuard) -> Self {
Self {
Expand All @@ -1665,6 +1675,13 @@
}
}

impl Drop for KeepAwakeRegion {
fn drop(&mut self) {
// Take the guard explicitly to mirror WebSocketCallbackRegion.
self.guard.take();
}
}

impl std::fmt::Debug for ActorContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ActorContext")
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) mod work_registry;
pub use action::ActionDispatchError;
pub use config::{ActionDefinition, ActorConfig, ActorConfigOverrides, CanHibernateWebSocket};
pub use connection::ConnHandle;
pub use context::{ActorContext, WebSocketCallbackRegion};
pub use context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
pub use factory::{ActorEntryFn, ActorFactory};
pub use kv::Kv;
pub use lifecycle_hooks::{ActorEvents, ActorStart, Reply};
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@
/// wired, which in practice means test contexts. Production actors built
/// through the registry always have an `ActorTask` and never spawn this
/// detached timer.
pub(crate) fn reset_sleep_timer_state(&self) {

Check warning on line 280 in rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs

View workflow job for this annotation

GitHub Actions / Build rivetkit-wasm

method `reset_sleep_timer_state` is never used
self.cancel_sleep_timer();

#[cfg(not(feature = "wasm-runtime"))]
Expand Down Expand Up @@ -433,7 +433,7 @@
}
}

pub(crate) fn keep_awake_region(&self) -> RegionGuard {
pub(crate) fn keep_awake_region_state(&self) -> RegionGuard {
self.0.sleep.work.keep_awake_guard()
}

Expand Down
2 changes: 1 addition & 1 deletion rivetkit-rust/packages/rivetkit-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub use actor::config::{
ActionDefinition, ActorConfig, ActorConfigInput, ActorConfigOverrides, CanHibernateWebSocket,
};
pub use actor::connection::ConnHandle;
pub use actor::context::{ActorContext, WebSocketCallbackRegion};
pub use actor::context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
pub use actor::factory::{ActorEntryFn, ActorFactory};
pub use actor::kv::Kv;
pub use actor::lifecycle_hooks::{ActorEvents, ActorStart, Reply};
Expand Down
4 changes: 3 additions & 1 deletion rivetkit-typescript/packages/rivetkit-napi/index.d.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* tslint:disable */

Check failure on line 1 in rivetkit-typescript/packages/rivetkit-napi/index.d.ts

View workflow job for this annotation

GitHub Actions / RivetKit / Quality Check

format

Formatter would have printed the following content:
/* eslint-disable */

/* auto-generated by NAPI-RS */
Expand Down Expand Up @@ -228,6 +228,9 @@
aborted(): boolean
runHandlerActive(): boolean
restartRunHandler(): void
beginKeepAwake(): number
endKeepAwake(regionId: number): void
keepAwake(promise: Promise<any>): void
beginWebsocketCallback(): number
endWebsocketCallback(regionId: number): void
abortSignal(): AbortSignal
Expand All @@ -237,7 +240,6 @@
disconnectConns(predicate: (...args: any[]) => any): Promise<void>
broadcast(name: string, args: Buffer): void
waitUntil(promise: Promise<any>): void
keepAwake(promise: Promise<any>): Promise<any>
registerTask(promise: Promise<any>): void
runtimeState(): object
clearRuntimeState(): void
Expand Down
69 changes: 60 additions & 9 deletions rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
};
use napi::{Env, JsFunction, JsObject, Ref};
use napi_derive::napi;
use parking_lot::Mutex;

Check warning on line 18 in rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/rivet/rivet/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs
use rivetkit_core::types::ActorKeySegment;
use rivetkit_core::{
ActorContext as CoreActorContext, ConnHandle as CoreConnHandle, Request as CoreRequest,
RequestSaveOpts, StateDelta, WebSocketCallbackRegion,
KeepAwakeRegion, RequestSaveOpts, StateDelta, WebSocketCallbackRegion,
};
use scc::HashMap as SccHashMap;
use tokio::sync::mpsc::UnboundedSender;
Expand Down Expand Up @@ -59,7 +59,9 @@
task_sender: Mutex<Option<UnboundedSender<RegisteredTask>>>,
runtime_state: Mutex<Option<Ref<()>>>,
end_reason: Mutex<Option<EndReason>>,
keep_awake_regions: Mutex<BTreeMap<u32, KeepAwakeRegion>>,
websocket_callback_regions: Mutex<BTreeMap<u32, WebSocketCallbackRegion>>,
next_keep_awake_region_id: AtomicU32,
next_websocket_callback_region_id: AtomicU32,
}

Expand Down Expand Up @@ -464,6 +466,28 @@
self.shared.run_restart().map_err(napi_anyhow_error)
}

#[napi]
pub fn begin_keep_awake(&self) -> u32 {
self.shared.begin_keep_awake(self.inner.keep_awake_region())
}

#[napi]
pub fn end_keep_awake(&self, region_id: u32) {
self.shared.end_keep_awake(region_id);
}

#[napi]
pub fn keep_awake(&self, promise: Promise<serde_json::Value>) -> napi::Result<()> {
let region = self.inner.keep_awake_region();
self.inner.wait_until(async move {
let _region = region;
if let Err(error) = promise.await {
tracing::warn!(?error, "actor keep_awake promise rejected");
}
});
Ok(())
}

#[napi]
pub fn begin_websocket_callback(&self) -> u32 {
self.shared
Expand Down Expand Up @@ -583,14 +607,6 @@
Ok(())
}

#[napi]
pub async fn keep_awake(
&self,
promise: Promise<serde_json::Value>,
) -> napi::Result<serde_json::Value> {
self.inner.keep_awake(promise).await
}

#[napi]
pub fn register_task(&self, promise: Promise<serde_json::Value>) -> napi::Result<()> {
self.shared
Expand Down Expand Up @@ -708,6 +724,39 @@
id
}

fn begin_keep_awake(&self, region: KeepAwakeRegion) -> u32 {
let mut regions = self.keep_awake_regions.lock();
let Some(id) = self.allocate_keep_awake_region_id(&regions) else {
tracing::error!("failed to begin keep-awake region: no region ids available");
return 0;
};
regions.insert(id, region);
id
}

fn end_keep_awake(&self, region_id: u32) {
if region_id == 0 {
return;
}
self.keep_awake_regions.lock().remove(&region_id);
}

fn allocate_keep_awake_region_id(
&self,
regions: &BTreeMap<u32, KeepAwakeRegion>,
) -> Option<u32> {
for _ in 0..=u32::MAX {
let next = self
.next_keep_awake_region_id
.fetch_add(1, Ordering::SeqCst)
.wrapping_add(1);
if next != 0 && !regions.contains_key(&next) {
return Some(next);
}
}
None
}

fn end_websocket_callback(&self, region_id: u32) {
self.websocket_callback_regions.lock().remove(&region_id);
}
Expand All @@ -734,7 +783,9 @@
std::mem::forget(old);
}
*self.end_reason.lock() = None;
*self.keep_awake_regions.lock() = BTreeMap::new();
*self.websocket_callback_regions.lock() = BTreeMap::new();
self.next_keep_awake_region_id.store(0, Ordering::SeqCst);
self.next_websocket_callback_region_id
.store(0, Ordering::SeqCst);
}
Expand Down
57 changes: 51 additions & 6 deletions rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
use rivet_error::{MacroMarker, RivetError as RivetTransportError, RivetErrorSchema};
use rivetkit_core::error::public_error_status_code;
use rivetkit_core::inspector::InspectorAuth;
use rivetkit_core::{

Check warning on line 13 in rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/rivet/rivet/rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs
ActorConfig, ActorConfigInput, ActorEvent, ActorFactory as CoreActorFactory, ActorStart,
BindParam, ColumnValue, CoreRegistry as NativeCoreRegistry, CoreServerlessRuntime,
EnqueueAndWaitOpts, ListOpts, QueueMessage, QueueNextBatchOpts, QueueSendResult,
QueueSendStatus, QueueTryNextBatchOpts, QueueWaitOpts, Request, RequestSaveOpts, Response,
RuntimeSpawner, SerializeStateReason, ServeConfig, ServerlessRequest, StateDelta, WebSocket,
WebSocketCallbackRegion, WsMessage,
KeepAwakeRegion, WebSocketCallbackRegion, WsMessage,
};
use scc::HashMap as SccHashMap;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -1058,7 +1058,9 @@
inner: rivetkit_core::ActorContext,
callbacks: WasmCallbacks,
runtime_state: JsValue,
keep_awake_regions: Rc<RefCell<HashMap<u32, KeepAwakeRegion>>>,
websocket_callback_regions: Rc<RefCell<HashMap<u32, WebSocketCallbackRegion>>>,
next_keep_awake_region_id: Rc<Cell<u32>>,
next_websocket_callback_region_id: Rc<Cell<u32>>,
}

Expand All @@ -1068,7 +1070,9 @@
inner,
callbacks,
runtime_state: Object::new().into(),
keep_awake_regions: Rc::new(RefCell::new(HashMap::new())),
websocket_callback_regions: Rc::new(RefCell::new(HashMap::new())),
next_keep_awake_region_id: Rc::new(Cell::new(0)),
next_websocket_callback_region_id: Rc::new(Cell::new(0)),
}
}
Expand Down Expand Up @@ -1121,6 +1125,20 @@
}
}

fn allocate_keep_awake_region_id(
&self,
regions: &HashMap<u32, KeepAwakeRegion>,
) -> Option<u32> {
for _ in 0..=u32::MAX {
let next = self.next_keep_awake_region_id.get().wrapping_add(1);
self.next_keep_awake_region_id.set(next);
if next != 0 && !regions.contains_key(&next) {
return Some(next);
}
}
None
}

#[wasm_bindgen]
pub fn kv(&self) -> WasmKv {
WasmKv {
Expand Down Expand Up @@ -1359,11 +1377,19 @@
}

#[wasm_bindgen(js_name = keepAwake)]
pub async fn keep_awake(&self, promise: Promise) -> Result<JsValue, JsValue> {
self.inner
.keep_awake(JsFuture::from(promise))
.await
.map_err(|error| error)
pub fn keep_awake(&self, promise: Promise) {
console_error("keepAwake binding is deprecated; use beginKeepAwake/endKeepAwake");
let region = self.inner.keep_awake_region();
let actor_id = self.inner.actor_id().to_owned();
self.inner.register_task(async move {
let _region = region;
if let Err(error) = JsFuture::from(promise).await {
console_error(&format!(
"actor keepAwake promise rejected for actor {actor_id}: {}",
js_value_to_anyhow(error)
));
}
});
}

#[wasm_bindgen(js_name = registerTask)]
Expand All @@ -1384,6 +1410,25 @@
start_run_handler(&self.callbacks, self);
}

#[wasm_bindgen(js_name = beginKeepAwake)]
pub fn begin_keep_awake(&self) -> u32 {
let mut regions = self.keep_awake_regions.borrow_mut();
let Some(region_id) = self.allocate_keep_awake_region_id(&regions) else {
console_error("failed to begin keep-awake region: no region ids available");
return 0;
};
regions.insert(region_id, self.inner.keep_awake_region());
region_id
}

#[wasm_bindgen(js_name = endKeepAwake)]
pub fn end_keep_awake(&self, region_id: u32) {
if region_id == 0 {
return;
}
self.keep_awake_regions.borrow_mut().remove(&region_id);
}

#[wasm_bindgen(js_name = beginWebsocketCallback)]
pub fn begin_websocket_callback(&self) -> u32 {
let mut regions = self.websocket_callback_regions.borrow_mut();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ import {
sleepWaitUntilState,
sleepWithRawWs,
sleepWsActiveDbExceedsGrace,
sleepKeepAwakeUntilIdle,
} from "./sleep-db";
import { saveStateActor, saveStateObserver } from "./save-state";
import { lifecycleObserver, startStopRaceActor } from "./start-stop-race";
Expand Down Expand Up @@ -221,6 +222,7 @@ export const registry = setup({
sleepWsMessageExceedsGrace,
sleepWsConcurrentDbExceedsGrace,
sleepWsActiveDbExceedsGrace,
sleepKeepAwakeUntilIdle,
saveStateActor,
saveStateObserver,
// From error-handling.ts
Expand Down
Loading
Loading