Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions engine/packages/guard/src/routing/actor_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ pub enum ParsedActorPath {
Query(QueryActorPathInfo),
}

pub fn is_actor_gateway_path(path: &str) -> bool {
let (base_path, _) = split_path_and_query(path);

if base_path.contains("//") {
return false;
}

base_path
.split('/')
.filter(|segment| !segment.is_empty())
.next()
== Some("gateway")
}

/// Parsed rvt-* query parameters.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
Expand Down
14 changes: 9 additions & 5 deletions engine/packages/guard/src/routing/pegboard_gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::{
use crate::{
errors,
routing::{
actor_path::parse_actor_path,
actor_path::{is_actor_gateway_path, parse_actor_path},
pegboard_gateway::resolve_actor_query::ResolveQueryActorResult,
},
shared_state::SharedState,
Expand Down Expand Up @@ -56,14 +56,18 @@ pub async fn route_request_path_based_inner(
shared_state: &SharedState,
req_ctx: &mut RequestContext,
) -> Result<Option<RoutingOutput>> {
if req_ctx.method() == hyper::Method::OPTIONS {
if is_actor_gateway_path(req_ctx.path()) {
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
}

return Ok(None);
}

let Some(actor_path) = parse_actor_path(req_ctx.path())? else {
return Ok(None);
};

if req_ctx.method() == hyper::Method::OPTIONS {
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
}

tracing::debug!(?actor_path, "routing using path-based actor routing");

let (actor_id, token, stripped_path, bypass_connectable) = match actor_path {
Expand Down
67 changes: 57 additions & 10 deletions engine/packages/guard/tests/parse_actor_path.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Keep this test suite in sync with the TypeScript equivalent at
// rivetkit-typescript/packages/rivetkit/tests/parse-actor-path.test.ts
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rivet_guard::routing::actor_path::{ParsedActorPath, QueryActorQuery, parse_actor_path};
use rivet_guard::routing::actor_path::{
ParsedActorPath, QueryActorQuery, is_actor_gateway_path, parse_actor_path,
};

#[test]
fn parses_direct_actor_paths_with_existing_behavior() {
Expand Down Expand Up @@ -37,6 +39,7 @@ fn parses_query_actor_get_paths() {
"shard-2".to_string(),
"alpha@beta".to_string(),
],
bypass_connectable: false,
}
);
}
Expand Down Expand Up @@ -66,11 +69,12 @@ fn parses_query_actor_get_or_create_paths_with_input_and_region() {
QueryActorQuery::GetOrCreate {
namespace: "default".to_string(),
name: "worker".to_string(),
runner_name: "default".to_string(),
pool_name: "default".to_string(),
key: vec!["shard-1".to_string()],
input: Some(input_bytes),
region: Some("us-west-2".to_string()),
crash_policy: None,
bypass_connectable: false,
}
);
}
Expand All @@ -95,11 +99,12 @@ fn parses_query_actor_get_or_create_paths_with_multi_component_key() {
QueryActorQuery::GetOrCreate {
namespace: "default".to_string(),
name: "worker".to_string(),
runner_name: "default".to_string(),
pool_name: "default".to_string(),
key: vec!["tenant".to_string(), "job".to_string()],
input: Some(input_bytes),
region: None,
crash_policy: None,
bypass_connectable: false,
}
);
assert_eq!(path.stripped_path, "/socket");
Expand All @@ -121,6 +126,7 @@ fn parses_query_actor_get_paths_with_empty_key() {
namespace: "default".to_string(),
name: "lobby".to_string(),
key: Vec::new(),
bypass_connectable: false,
}
);
assert_eq!(path.stripped_path, "/");
Expand All @@ -141,11 +147,12 @@ fn omits_key_when_not_present() {
QueryActorQuery::GetOrCreate {
namespace: "default".to_string(),
name: "builder".to_string(),
runner_name: "default".to_string(),
pool_name: "default".to_string(),
key: Vec::new(),
input: None,
region: None,
crash_policy: None,
bypass_connectable: false,
}
);
assert_eq!(path.stripped_path, "/");
Expand All @@ -167,6 +174,7 @@ fn parses_simple_multi_component_keys() {
namespace: "default".to_string(),
name: "lobby".to_string(),
key: vec!["a".to_string(), "b".to_string(), "c".to_string()],
bypass_connectable: false,
}
);
}
Expand All @@ -186,18 +194,55 @@ fn parses_crash_policy_param() {
QueryActorQuery::GetOrCreate {
namespace: "default".to_string(),
name: "worker".to_string(),
runner_name: "default".to_string(),
pool_name: "default".to_string(),
key: Vec::new(),
input: None,
region: None,
crash_policy: Some(rivet_types::actors::CrashPolicy::Restart),
bypass_connectable: false,
}
);
}
ParsedActorPath::Direct(_) => panic!("expected query actor path"),
}
}

#[test]
fn parses_bypass_connectable_query_bool_strings() {
let path = "/gateway/worker/request/bypass?rvt-namespace=default&rvt-method=getOrCreate&rvt-runner=default&rvt-bypass_connectable=true";
let result = parse_actor_path(path).unwrap().unwrap();

match result {
ParsedActorPath::Query(path) => {
assert_eq!(
path.query,
QueryActorQuery::GetOrCreate {
namespace: "default".to_string(),
name: "worker".to_string(),
pool_name: "default".to_string(),
key: Vec::new(),
input: None,
region: None,
crash_policy: None,
bypass_connectable: true,
}
);
assert_eq!(path.stripped_path, "/request/bypass");
}
ParsedActorPath::Direct(_) => panic!("expected query actor path"),
}
}

#[test]
fn identifies_gateway_paths_without_parsing_query_params() {
assert!(is_actor_gateway_path(
"/gateway/worker/request/bypass?rvt-bypass_connectable=true"
));
assert!(is_actor_gateway_path("/gateway/actor-id"));
assert!(!is_actor_gateway_path("/request/bypass"));
assert!(!is_actor_gateway_path("/gateway//worker"));
}

#[test]
fn strips_rvt_params_from_remaining_path() {
let path = "/gateway/lobby/api/v1?rvt-namespace=prod&rvt-method=get&foo=bar&baz=qux";
Expand Down Expand Up @@ -272,6 +317,7 @@ fn handles_interleaved_rvt_and_actor_params() {
namespace: "default".to_string(),
name: "lobby".to_string(),
key: Vec::new(),
bypass_connectable: false,
}
);
}
Expand All @@ -295,6 +341,7 @@ fn decodes_plus_as_space_in_rvt_values() {
namespace: "my ns".to_string(),
name: "lobby".to_string(),
key: vec!["hello world".to_string()],
bypass_connectable: false,
}
);
// Actor param + is preserved literally.
Expand Down Expand Up @@ -421,7 +468,7 @@ fn rejects_input_for_get_queries() {
.unwrap_err()
.to_string();
assert!(err.contains(
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
));
}

Expand All @@ -433,7 +480,7 @@ fn rejects_region_for_get_queries() {
.unwrap_err()
.to_string();
assert!(err.contains(
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
));
}

Expand All @@ -445,7 +492,7 @@ fn rejects_crash_policy_for_get_queries() {
.unwrap_err()
.to_string();
assert!(err.contains(
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
));
}

Expand All @@ -456,7 +503,7 @@ fn rejects_runner_for_get_queries() {
.unwrap_err()
.to_string();
assert!(err.contains(
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
));
}

Expand All @@ -465,7 +512,7 @@ fn rejects_missing_runner_for_get_or_create_queries() {
let err = parse_actor_path("/gateway/lobby?rvt-namespace=default&rvt-method=getOrCreate")
.unwrap_err()
.to_string();
assert!(err.contains("query gateway method=getOrCreate requires rvt-runner param"));
assert!(err.contains("query gateway method=getOrCreate requires rvt-pool param"));
}

#[test]
Expand Down
Loading