Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions __tests__/unserializable.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import { beforeEach, describe, expect, test } from 'vitest';
import { Type } from '@sinclair/typebox';
import {
Procedure,
createServiceSchema,
Ok,
createClient,
createServer,
UNEXPECTED_DISCONNECT_CODE,
} from '../router';
import { testMatrix } from '../testUtil/fixtures/matrix';
import {
advanceFakeTimersBySessionGrace,
cleanupTransports,
createPostTestCleanups,
} from '../testUtil/fixtures/cleanup';
import { TestSetupHelpers } from '../testUtil/fixtures/transports';
import { readNextResult } from '../testUtil';

const ServiceSchema = createServiceSchema();

const UnserializableServiceSchema = ServiceSchema.define({
returnSymbol: Procedure.rpc({
requestInit: Type.Object({}),
responseData: Type.Object({ id: Type.String() }),
async handler() {
return Ok({ id: 'test', extra: Symbol('unserializable') });
},
}),
streamSymbol: Procedure.subscription({
requestInit: Type.Object({}),
responseData: Type.Object({ id: Type.String() }),
async handler({ resWritable }) {
resWritable.write(Ok({ id: 'test', extra: Symbol('unserializable') }));
resWritable.close();
},
}),
});

describe('unserializable values in procedure handlers', () => {
// binary codec (msgpack) throws on Symbol, causing encode failure
// which kills the session -- only test with ws transport since mock
// transport's setImmediate chains conflict with fake timer flushing
describe.each(testMatrix(['ws', 'binary']))(
'binary codec ($transport.name transport)',
({ transport, codec }) => {
const opts = { codec: codec.codec };
const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];

beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

test('rpc handler returning symbol causes client disconnect', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const resultPromise = client.svc.returnSymbol.rpc({});
await advanceFakeTimersBySessionGrace();

const result = await resultPromise;
expect(result).toMatchObject({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
},
});
});

test('client-side encode failure cleans up listeners', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const messageListenersBefore =
clientTransport.eventDispatcher.numberOfListeners('message');
const sessionStatusListenersBefore =
clientTransport.eventDispatcher.numberOfListeners('sessionStatus');

// sending a Symbol as init payload will fail encoding on the client side
expect(() =>
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any
client.svc.returnSymbol.rpc({ extra: Symbol('x') } as any),
).toThrow();

// listeners should not leak after the failed send
expect(
clientTransport.eventDispatcher.numberOfListeners('message'),
).toEqual(messageListenersBefore);
expect(
clientTransport.eventDispatcher.numberOfListeners('sessionStatus'),
).toEqual(sessionStatusListenersBefore);
});

test('subscription handler writing symbol causes client disconnect', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const { resReadable } = client.svc.streamSymbol.subscribe({});
await advanceFakeTimersBySessionGrace();

const result = await readNextResult(resReadable);
expect(result).toMatchObject({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
},
});
});
},
);

// json codec silently drops Symbol values via JSON.stringify
describe.each(testMatrix(['all', 'naive']))(
'json codec ($transport.name transport)',
({ transport, codec }) => {
const opts = { codec: codec.codec };
const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];

beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

test('rpc handler returning symbol silently drops the value', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const result = await client.svc.returnSymbol.rpc({});
// JSON.stringify silently drops Symbol values, so the
// response arrives with the extra symbol field missing
expect(result).toStrictEqual({
ok: true,
payload: { id: 'test' },
});

await server.close();
});
},
);
});
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@replit/river",
"description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
"version": "0.214.0",
"version": "0.215.0",
"type": "module",
"exports": {
".": {
Expand Down
25 changes: 15 additions & 10 deletions router/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -508,16 +508,21 @@ function handleProc(
transport.addEventListener('message', onMessage);
transport.addEventListener('sessionStatus', onSessionStatus);

sessionScopedSend({
streamId,
serviceName,
procedureName,
tracing: getPropagationContext(ctx),
payload: init,
controlFlags: procClosesWithInit
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
: ControlFlags.StreamOpenBit,
});
try {
sessionScopedSend({
streamId,
serviceName,
procedureName,
tracing: getPropagationContext(ctx),
payload: init,
controlFlags: procClosesWithInit
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
: ControlFlags.StreamOpenBit,
});
} catch (e) {
cleanup();
throw e;
}

if (procClosesWithInit) {
reqWritable.close();
Expand Down
18 changes: 10 additions & 8 deletions testUtil/fixtures/cleanup.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { expect, vi } from 'vitest';
import { assert, expect, vi } from 'vitest';
import {
ClientTransport,
Connection,
OpaqueTransportMessage,
ServerTransport,
Transport,
} from '../../transport';
Expand Down Expand Up @@ -68,14 +67,17 @@ export async function ensureTransportBuffersAreEventuallyEmpty(
[...t.sessions]
.map(([client, sess]) => {
// get all messages that are not heartbeats
const buff = sess.sendBuffer.filter((msg) => {
return !Value.Check(ControlMessageAckSchema, msg.payload);
const buff = sess.sendBuffer.filter((encodedMsg) => {
const decoded = sess.codec.fromBuffer(encodedMsg.data);
assert(decoded.ok);

return !Value.Check(
ControlMessageAckSchema,
decoded.value.payload,
);
});

return [client, buff] as [
string,
ReadonlyArray<OpaqueTransportMessage>,
];
return [client, buff] as const;
})
.filter((entry) => entry[1].length > 0),
),
Expand Down
3 changes: 3 additions & 0 deletions testUtil/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ export function dummySession() {
onSessionGracePeriodElapsed: () => {
/* noop */
},
onMessageSendFailure: () => {
/* noop */
},
},
testingSessionOptions,
currentProtocolVersion,
Expand Down
48 changes: 48 additions & 0 deletions transport/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ export abstract class ClientTransport<
onSessionGracePeriodElapsed: () => {
this.onSessionGracePeriodElapsed(session);
},
onMessageSendFailure: (msg, reason) => {
this.log?.error(`failed to send message: ${reason}`, {
...session.loggingMetadata,
transportMessage: msg,
});

this.protocolError({
type: ProtocolError.MessageSendFailure,
message: reason,
});
this.deleteSession(session, { unhealthy: true });
},
},
this.options,
currentProtocolVersion,
Expand Down Expand Up @@ -186,6 +198,18 @@ export abstract class ClientTransport<
onSessionGracePeriodElapsed: () => {
this.onSessionGracePeriodElapsed(handshakingSession);
},
onMessageSendFailure: (msg, reason) => {
this.log?.error(`failed to send message: ${reason}`, {
...handshakingSession.loggingMetadata,
transportMessage: msg,
});

this.protocolError({
type: ProtocolError.MessageSendFailure,
message: reason,
});
this.deleteSession(handshakingSession, { unhealthy: true });
},
},
);

Expand Down Expand Up @@ -395,6 +419,18 @@ export abstract class ClientTransport<
onSessionGracePeriodElapsed: () => {
this.onSessionGracePeriodElapsed(backingOffSession);
},
onMessageSendFailure: (msg, reason) => {
this.log?.error(`failed to send message: ${reason}`, {
...backingOffSession.loggingMetadata,
transportMessage: msg,
});

this.protocolError({
type: ProtocolError.MessageSendFailure,
message: reason,
});
this.deleteSession(backingOffSession, { unhealthy: true });
},
},
);

Expand Down Expand Up @@ -470,6 +506,18 @@ export abstract class ClientTransport<
onSessionGracePeriodElapsed: () => {
this.onSessionGracePeriodElapsed(connectingSession);
},
onMessageSendFailure: (msg, reason) => {
this.log?.error(`failed to send message: ${reason}`, {
...connectingSession.loggingMetadata,
transportMessage: msg,
});

this.protocolError({
type: ProtocolError.MessageSendFailure,
message: reason,
});
this.deleteSession(connectingSession, { unhealthy: true });
},
},
);

Expand Down
12 changes: 12 additions & 0 deletions transport/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,18 @@ export function cancelMessage(
export type OpaqueTransportMessage = TransportMessage;
export type TransportClientId = string;

/**
* An encoded message that is ready to be sent over the transport.
* The seq number is kept to track which messages have been
* acked by the peer and can be dropped from the send buffer.
*/
export interface EncodedTransportMessage {
id: string;
seq: number;
msg: PartialTransportMessage;
data: Uint8Array;
}

/**
* Checks if the given control flag (usually found in msg.controlFlag) is an ack message.
* @param controlFlag - The control flag to check.
Expand Down
Loading
Loading