Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ export interface ActorGatewayOptions {
bypassConnectable?: boolean;
}

export type ResolvedActorGatewayOptions = Required<ActorGatewayOptions>;

export function resolveActorGatewayOptions(
defaults: ActorGatewayOptions = {},
overrides?: ActorGatewayOptions,
): ResolvedActorGatewayOptions {
return {
bypassConnectable:
overrides?.bypassConnectable ?? defaults.bypassConnectable ?? false,
};
}

export interface ActorActionOptions {
gateway?: ActorGatewayOptions;
signal?: AbortSignal;
}

export interface ActorConnectOptions {
gateway?: ActorGatewayOptions;
}

export interface ActorFetchInit extends RequestInit {
gateway?: ActorGatewayOptions;
}
Expand Down
32 changes: 31 additions & 1 deletion rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ import type {
ActorDefinitionActions,
ActorDefinitionEventSubscriptions,
ActorDefinitionQueueSend,
ActorGatewayOptions,
ResolvedActorGatewayOptions,
} from "./actor-common";
import { resolveActorGatewayOptions } from "./actor-common";
import {
type ActorResolutionState,
checkForSchedulingError,
Expand All @@ -53,6 +56,7 @@ import {
type QueueSendResult,
type QueueSendWaitOptions,
} from "./queue";
import { resolveGatewayTarget } from "./resolve-gateway-target";
import {
type WebSocketMessage as ConnMessage,
messageLength,
Expand Down Expand Up @@ -186,6 +190,7 @@ export class ActorConnRaw {
#getParams?: () => Promise<unknown>;
#encoding: Encoding;
#actorResolutionState: ActorResolutionState;
#gatewayOptions: ResolvedActorGatewayOptions;

// TODO: ws message queue

Expand All @@ -203,13 +208,15 @@ export class ActorConnRaw {
getParams: (() => Promise<unknown>) | undefined,
encoding: Encoding,
actorResolutionState: ActorResolutionState,
gatewayOptions: ActorGatewayOptions = {},
) {
this.#client = client;
this.#driver = driver;
this.#params = params;
this.#getParams = getParams;
this.#encoding = encoding;
this.#actorResolutionState = actorResolutionState;
this.#gatewayOptions = resolveActorGatewayOptions(gatewayOptions);
this.#readyPromise = promiseWithResolvers((reason) =>
logger().warn({
msg: "unhandled ready promise rejection",
Expand All @@ -225,6 +232,7 @@ export class ActorConnRaw {
return await this.#driver.sendRequest(
getGatewayTarget(this.#actorResolutionState),
request,
this.#gatewayOptions,
);
},
});
Expand Down Expand Up @@ -570,12 +578,15 @@ export class ActorConnRaw {

async #connectWebSocket() {
const params = await this.#resolveConnectionParams();
const target = getGatewayTarget(this.#actorResolutionState);
const target = this.#gatewayOptions.bypassConnectable
? await this.#resolveGatewayTargetForBypass()
: getGatewayTarget(this.#actorResolutionState);
const ws = await this.#driver.openWebSocket(
PATH_CONNECT,
target,
this.#encoding,
params,
this.#gatewayOptions,
);
invariant(ws, "websocket should have been created");
logger().debug({
Expand Down Expand Up @@ -623,6 +634,25 @@ export class ActorConnRaw {
});
}

async #resolveGatewayTargetForBypass() {
if ("getForId" in this.#actorResolutionState) {
return {
directId: this.#actorResolutionState.getForId.actorId,
} as const;
}

if (this.#actorId) {
return { directId: this.#actorId } as const;
}

return {
directId: await resolveGatewayTarget(
this.#driver,
this.#actorResolutionState,
),
} as const;
}

/** Called by the onopen event from drivers. */
#handleOnOpen() {
// Connection was disposed before Init message arrived - close the websocket to avoid leak
Expand Down
58 changes: 44 additions & 14 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ import type { EngineControlClient } from "@/engine-client/driver";
import { decodeCborCompat, deserializeWithEncoding, encodeCborCompat } from "@/serde";
import { bufferToArrayBuffer } from "@/utils";
import type {
ActorActionOptions,
ActorConnectOptions,
ActorDefinitionActions,
ActorFetchInit,
ActorDefinitionQueueSend,
ActorGatewayOptions,
ActorWebSocketOptions,
} from "./actor-common";
import { resolveActorGatewayOptions } from "./actor-common";
import { type ActorConn, ActorConnRaw } from "./actor-conn";
import {
type ActorResolutionState,
Expand Down Expand Up @@ -65,6 +69,7 @@ export class ActorHandleRaw {
#driver: EngineControlClient;
#encoding: Encoding;
#actorResolutionState: ActorResolutionState;
#gatewayOptions: ActorGatewayOptions;
#params: unknown;
#getParams?: () => Promise<unknown>;
#resolvedActorId?: string;
Expand All @@ -85,11 +90,13 @@ export class ActorHandleRaw {
getParams: (() => Promise<unknown>) | undefined,
encoding: Encoding,
actorResolutionState: ActorResolutionState,
gatewayOptions: ActorGatewayOptions = {},
) {
this.#client = client;
this.#driver = driver;
this.#encoding = encoding;
this.#actorResolutionState = actorResolutionState;
this.#gatewayOptions = gatewayOptions;
this.#params = params;
this.#getParams = getParams;
}
Expand Down Expand Up @@ -139,7 +146,13 @@ export class ActorHandleRaw {
encoding: this.#encoding,
params: this.#params,
customFetch: async (request: Request) => {
return await this.#driver.sendRequest(target, request);
return await this.#driver.sendRequest(
target,
request,
resolveActorGatewayOptions(
this.#gatewayOptions,
),
);
},
}).send(name, body, options as any);
} catch (err) {
Expand Down Expand Up @@ -224,8 +237,7 @@ export class ActorHandleRaw {
>(opts: {
name: string;
args: Args;
signal?: AbortSignal;
}): Promise<Response> {
} & ActorActionOptions): Promise<Response> {
if (
typeof opts === "string" ||
typeof opts !== "object" ||
Expand All @@ -247,10 +259,13 @@ export class ActorHandleRaw {
async #sendActionNow(opts: {
name: string;
args: unknown[];
signal?: AbortSignal;
}): Promise<unknown> {
} & ActorActionOptions): Promise<unknown> {
const maxAttempts = this.#getDynamicQueryMaxAttempts();
let useQueryTarget = false;
const gatewayOptions = resolveActorGatewayOptions(
this.#gatewayOptions,
opts.gateway,
);

for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
Expand Down Expand Up @@ -294,10 +309,12 @@ export class ActorHandleRaw {
},
body: opts.args,
encoding: this.#encoding,
customFetch: this.#driver.sendRequest.bind(
this.#driver,
target,
),
customFetch: async (request) =>
await this.#driver.sendRequest(
target,
request,
gatewayOptions,
),
signal: opts?.signal,
requestVersion: CLIENT_PROTOCOL_CURRENT_VERSION,
requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED,
Expand Down Expand Up @@ -550,7 +567,10 @@ export class ActorHandleRaw {
* @template AD The actor class that this connection is for.
* @returns {ActorConn<AD>} A connection to the actor.
*/
connect(params?: unknown): ActorConn<AnyActorDefinition> {
connect(
params?: unknown,
options: ActorConnectOptions = {},
): ActorConn<AnyActorDefinition> {
logger().debug({
msg: "establishing connection from handle",
query: this.#actorResolutionState,
Expand All @@ -566,6 +586,7 @@ export class ActorHandleRaw {
getParams,
this.#encoding,
this.#actorResolutionState,
resolveActorGatewayOptions(this.#gatewayOptions, options.gateway),
);

return this.#client[CREATE_ACTOR_CONN_PROXY](
Expand All @@ -588,6 +609,10 @@ export class ActorHandleRaw {
const maxAttempts = this.#getDynamicQueryMaxAttempts();
let useQueryTarget = false;
const { gateway, ...requestInit } = init ?? {};
const gatewayOptions = resolveActorGatewayOptions(
this.#gatewayOptions,
gateway,
);

for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
Expand All @@ -600,7 +625,7 @@ export class ActorHandleRaw {
this.#params,
input,
requestInit,
gateway,
gatewayOptions,
);
const retry = await this.#shouldRetryRawFetchResponse(
response,
Expand Down Expand Up @@ -793,7 +818,11 @@ export class ActorHandleRaw {
options: ActorWebSocketOptions = {},
) {
const params = await this.#resolveConnectionParams();
const target = options.gateway?.bypassConnectable
const gatewayOptions = resolveActorGatewayOptions(
this.#gatewayOptions,
options.gateway,
);
const target = gatewayOptions.bypassConnectable
? await this.#resolveActionTarget(false)
: getGatewayTarget(this.#actorResolutionState);
return await rawWebSocket(
Expand All @@ -802,7 +831,7 @@ export class ActorHandleRaw {
params,
path,
protocols,
options.gateway,
gatewayOptions,
);
}

Expand All @@ -828,6 +857,7 @@ export class ActorHandleRaw {
async getGatewayUrl(): Promise<string> {
return await this.#driver.buildGatewayUrl(
getGatewayTarget(this.#actorResolutionState),
this.#gatewayOptions,
);
}

Expand Down Expand Up @@ -870,7 +900,7 @@ export type ActorHandle<AD extends AnyActorDefinition> = Omit<
"connect" | "send"
> & {
// Add typed version of ActorConn (instead of using AnyActorDefinition)
connect(params?: unknown): ActorConn<AD>;
connect(params?: unknown, options?: ActorConnectOptions): ActorConn<AD>;
// Resolve method returns the actor ID
resolve(): Promise<string>;
} & ActorDefinitionQueueSend<AD> &
Expand Down
12 changes: 8 additions & 4 deletions rivetkit-typescript/packages/rivetkit/src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import type { AnyActorDefinition } from "@/actor/definition";
import type { ActorQuery } from "@/client/query";
import type { Encoding } from "@/common/encoding";
import type { EngineControlClient } from "@/engine-client/driver";
import type { ActorQuery } from "@/client/query";
import type { Registry } from "@/registry";
import type { ActorActionFunction } from "./actor-common";
import type { ActorActionFunction, ActorGatewayOptions } from "./actor-common";
import {
type ActorConn,
type ActorConnRaw,
Expand Down Expand Up @@ -178,17 +178,20 @@ export class ClientRaw {

#driver: EngineControlClient;
#encodingKind: Encoding;
#gatewayOptions: ActorGatewayOptions;

/**
* Creates an instance of Client.
*/
public constructor(
driver: EngineControlClient,
encoding: Encoding | undefined,
gatewayOptions: ActorGatewayOptions = {},
) {
this.#driver = driver;

this.#encodingKind = encoding ?? "bare";
this.#gatewayOptions = gatewayOptions;
}

/**
Expand Down Expand Up @@ -382,6 +385,7 @@ export class ClientRaw {
getParams,
this.#encodingKind,
actorQuery,
this.#gatewayOptions,
);
}

Expand Down Expand Up @@ -438,9 +442,9 @@ export type AnyClient = Client<Registry<any>>;

export function createClientWithDriver<A extends Registry<any>>(
driver: EngineControlClient,
config: { encoding?: Encoding } = {},
config: { encoding?: Encoding; gateway?: ActorGatewayOptions } = {},
): Client<A> {
const client = new ClientRaw(driver, config.encoding);
const client = new ClientRaw(driver, config.encoding, config.gateway);

// Create proxy for accessing actors by name
return new Proxy(client, {
Expand Down
8 changes: 8 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/client/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ export const ClientConfigSchemaBase = z.object({
.optional()
.default(() => ({})),

gateway: z
.object({
bypassConnectable: z.boolean().optional().default(false),
})
.optional()
.default(() => ({ bypassConnectable: false })),

// See RunConfig.getUpgradeWebSocket
//
// This is required in the client config in order to support
Expand Down Expand Up @@ -147,6 +154,7 @@ export function convertRegistryConfigToClientConfig(
namespace: config.namespace,
poolName: config.envoy.poolName,
headers: config.headers,
gateway: { bypassConnectable: false },
encoding: "bare",
getUpgradeWebSocket: undefined,
// We don't need health checks for internal clients
Expand Down
Loading
Loading