Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions __tests__/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1314,5 +1314,61 @@ describe.each(testMatrix())(
server,
});
});

test('validate receives the connecting client id', async () => {
const requestSchema = Type.Object({});

interface ParsedMetadata {
seenFrom: string;
}

const clientTransport = getClientTransport(
'client',
createClientHandshakeOptions(requestSchema, () => ({})),
);
const serverTransport = getServerTransport(
'SERVER',
createServerHandshakeOptions(
requestSchema,
(_metadata, _prev, from) => ({
seenFrom: from ?? '<none>',
}),
),
);
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const ServiceSchema = createServiceSchema<
MaybeDisposable,
ParsedMetadata
>();
const services = {
test: ServiceSchema.define({
whoami: Procedure.rpc({
requestInit: Type.Object({}),
responseData: Type.Object({ seenFrom: Type.String() }),
handler: async ({ ctx }) => Ok({ seenFrom: ctx.metadata.seenFrom }),
}),
}),
};
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const result = await client.test.whoami.rpc({});
expect(result).toStrictEqual({
ok: true,
payload: { seenFrom: 'client' },
});

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});
},
);
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@replit/river",
"description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
"version": "0.217.1",
"version": "0.217.2",
"type": "module",
"exports": {
".": "./dist/router/index.js",
Expand Down
10 changes: 7 additions & 3 deletions protobuf/handshake.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import {
type ClientHandshakeOptions,
type ServerHandshakeOptions,
} from '../router/handshake';
import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message';
import {
HandshakeErrorCustomHandlerFatalResponseCodes,
type TransportClientId,
} from '../transport/message';
import { decodeMessageBytes, encodeMessageBytes } from './shared';
import { Uint8ArrayType } from '../customSchemas';

Expand All @@ -27,6 +30,7 @@ type ConstructHandshake<Schema extends DescMessage> = () =>
type ValidateHandshake<Schema extends DescMessage, ParsedMetadata> = (
metadata: MessageShape<Schema>,
previousParsedMetadata?: ParsedMetadata,
from?: TransportClientId,
) =>
| ParsedMetadata
| ProtobufHandshakeFailureCode
Expand Down Expand Up @@ -61,15 +65,15 @@ export function createServerHandshakeOptions<
): ServerHandshakeOptions<typeof HandshakeBytesSchema, ParsedMetadata> {
return createTransportServerHandshakeOptions(
HandshakeBytesSchema,
async (metadata, previousParsedMetadata) => {
async (metadata, previousParsedMetadata, from) => {
let decoded;
try {
decoded = decodeMessageBytes(schema, metadata);
} catch {
return 'REJECTED_BY_CUSTOM_HANDLER' as ProtobufHandshakeFailureCode;
}

return await validate(decoded, previousParsedMetadata);
return await validate(decoded, previousParsedMetadata, from);
},
);
}
21 changes: 13 additions & 8 deletions router/handshake.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import type { Static, TSchema } from 'typebox';
import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message';
import {
HandshakeErrorCustomHandlerFatalResponseCodes,
type TransportClientId,
} from '../transport/message';

type ConstructHandshake<T extends TSchema> = () =>
| Static<T>
Expand All @@ -8,6 +11,7 @@ type ConstructHandshake<T extends TSchema> = () =>
type ValidateHandshake<T extends TSchema, ParsedMetadata> = (
metadata: Static<T>,
previousParsedMetadata?: ParsedMetadata,
from?: TransportClientId,
) =>
| Static<typeof HandshakeErrorCustomHandlerFatalResponseCodes>
| ParsedMetadata
Expand Down Expand Up @@ -42,15 +46,16 @@ export interface ServerHandshakeOptions<
schema: MetadataSchema;

/**
* Parses the {@link HandshakeRequestMetadata} sent by the client, transforming
* it into {@link ParsedHandshakeMetadata}.
*
* May return `false` if the client should be rejected.
* Parses the metadata sent by the client during the handshake into the
* server-side {@link ParsedMetadata}, or returns a handshake failure code to
* reject the connection.
*
* @param metadata - The metadata sent by the client.
* @param session - The session that the client would be associated with.
* @param isReconnect - Whether the client is reconnecting to the session,
* or if this is a new session.
* @param previousParsedMetadata - The parsed metadata from the previous
* connection on this session, if any (e.g. on reconnect).
* @param from - The client id the peer presented in its handshake. Use it to
* confirm the presented id is the one the metadata authorizes before
* returning parsed metadata.
*/
validate: ValidateHandshake<MetadataSchema, ParsedMetadata>;
}
Expand Down
1 change: 1 addition & 0 deletions transport/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ export abstract class ServerTransport<
parsedMetadataOrFailureCode = await this.handshakeExtensions.validate(
msg.payload.metadata,
previousParsedMetadata,
msg.from,
);
} catch (err) {
this.rejectHandshakeRequest(
Expand Down
9 changes: 9 additions & 0 deletions transport/sessionStateMachine/SessionConnected.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ export class SessionConnected<

const parsedMsg = parsedMsgRes.value;

// messages must originate from this session's peer
if (parsedMsg.from !== this.to) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we even allow connections to specify the to/from? i think this is a vestigial feature that we never ended up using.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you're correct

this.listeners.onInvalidMessage(
`received message with 'from' (${parsedMsg.from}) that does not match the session peer (${this.to})`,
);

return;
}

// check message ordering here
if (parsedMsg.seq !== this.ack) {
if (parsedMsg.seq < this.ack) {
Expand Down
45 changes: 37 additions & 8 deletions transport/sessionStateMachine/stateMachine.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1967,11 +1967,17 @@ describe('session state machine', () => {
expect(onConnectionClosed).not.toHaveBeenCalled();
expect(onConnectionErrored).not.toHaveBeenCalled();

const encodeResult = session.encodeMsg(
payloadToTransportMessage('hello'),
// an incoming frame carries the peer's id in `from`
session.conn.emitData(
session.options.codec.toBuffer({
id: 'msgid',
from: session.to,
to: session.from,
seq: 0,
ack: 0,
...payloadToTransportMessage('hello'),
}),
);
assert(encodeResult.ok);
session.conn.emitData(encodeResult.value.data);

await waitFor(async () => {
expect(onMessage).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -2021,8 +2027,8 @@ describe('session state machine', () => {
conn.onData(
session.options.codec.toBuffer({
id: 'msgid',
to: 'SERVER',
from: 'client',
to: session.from,
from: session.to,
seq: 0,
ack: 0,
streamId: 'heartbeat',
Expand All @@ -2048,8 +2054,8 @@ describe('session state machine', () => {
conn.onData(
session.options.codec.toBuffer({
id: 'msgid',
to: 'SERVER',
from: 'client',
to: session.from,
from: session.to,
seq: 0,
ack: 0,
streamId: 'heartbeat',
Expand All @@ -2062,5 +2068,28 @@ describe('session state machine', () => {

expect(sessionHandle.onMessage).not.toHaveBeenCalled();
});

test('rejects a message whose from does not match the session peer', async () => {
const sessionHandle = await createSessionConnected();
const session = sessionHandle.session;
const conn = session.conn;

// a frame whose `from` isn't this session's peer is rejected
conn.onData(
session.options.codec.toBuffer({
id: 'msgid',
to: session.from,
from: 'someone-else',
seq: 0,
ack: 0,
streamId: 'stream',
controlFlags: 0,
payload: { type: 'ACK' },
}),
);

expect(sessionHandle.onInvalidMessage).toHaveBeenCalledTimes(1);
expect(sessionHandle.onMessage).not.toHaveBeenCalled();
});
});
});
2 changes: 2 additions & 0 deletions transport/transport.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,7 @@ describe.each(testMatrix())(
discarded: 'discarded',
},
undefined,
clientTransport.clientId,
);

const session = serverTransport.sessions.get(clientTransport.clientId);
Expand Down Expand Up @@ -1791,6 +1792,7 @@ describe.each(testMatrix())(
{
kept: 'kept',
},
clientTransport.clientId,
);

await testFinishesCleanly({
Expand Down
Loading