Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/selfish-bananas-clean.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"partyserver": patch
---

Check for hibernated websocket connections
57 changes: 53 additions & 4 deletions packages/partyserver/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ type ConnectionAttachments = {
__user?: unknown;
};

function tryGetPartyServerMeta(
ws: WebSocket
): ConnectionAttachments["__pk"] | null {
try {
// Avoid AttachmentCache.get() here: hibernated sockets accepted outside
// PartyServer can have an attachment without a __pk namespace.
const attachment = WebSocket.prototype.deserializeAttachment.call(
ws
) as unknown;
if (!attachment || typeof attachment !== "object") {
return null;
}
if (!("__pk" in attachment)) {
return null;
}
const pk = (attachment as ConnectionAttachments).__pk as unknown;
if (!pk || typeof pk !== "object") {
return null;
}
const { id, server } = pk as { id?: unknown; server?: unknown };
if (typeof id !== "string" || typeof server !== "string") {
return null;
}
return pk as ConnectionAttachments["__pk"];
} catch {
return null;
}
}

export function isPartyServerWebSocket(ws: WebSocket): boolean {
return tryGetPartyServerMeta(ws) !== null;
}

/**
* Cache websocket attachments to avoid having to rehydrate them on every property access.
*/
Expand Down Expand Up @@ -180,6 +213,12 @@ class HibernatingConnectionIterator<T> implements IterableIterator<
while ((socket = sockets[this.index++])) {
// only yield open sockets to match non-hibernating behaviour
if (socket.readyState === WebSocket.READY_STATE_OPEN) {
// Durable Objects hibernation APIs allow storing arbitrary sockets via
// `state.acceptWebSocket()`. Those sockets won't have PartyServer's
// `__pk` attachment namespace and must be ignored.
if (!isPartyServerWebSocket(socket)) {
continue;
}
const value = createLazyConnection(socket) as Connection<T>;
return { done: false, value };
}
Expand Down Expand Up @@ -263,15 +302,25 @@ export class HibernatingConnectionManager<TState> implements ConnectionManager {
constructor(private controller: DurableObjectState) {}

getCount() {
return Number(this.controller.getWebSockets().length);
// Only count sockets managed by PartyServer. Other hibernated sockets may
// exist on the same Durable Object via `state.acceptWebSocket()`.
let count = 0;
for (const ws of this.controller.getWebSockets()) {
if (isPartyServerWebSocket(ws)) count++;
}
return count;
}

getConnection<T = TState>(id: string) {
// TODO: Should we cache the connections?
const sockets = this.controller.getWebSockets(id);
if (sockets.length === 0) return undefined;
if (sockets.length === 1)
return createLazyConnection(sockets[0]) as Connection<T>;
const matching = sockets.filter((ws) => {
return tryGetPartyServerMeta(ws)?.id === id;
});

if (matching.length === 0) return undefined;
if (matching.length === 1)
return createLazyConnection(matching[0]) as Connection<T>;

throw new Error(
`More than one connection found for id ${id}. Did you mean to use getConnections(tag) instead?`
Expand Down
18 changes: 17 additions & 1 deletion packages/partyserver/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import { nanoid } from "nanoid";
import {
createLazyConnection,
HibernatingConnectionManager,
InMemoryConnectionManager
InMemoryConnectionManager,
isPartyServerWebSocket
} from "./connection";

import type { ConnectionManager } from "./connection";
Expand Down Expand Up @@ -422,6 +423,13 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam
return;
}

// Ignore websockets accepted outside PartyServer (e.g. via
// `state.acceptWebSocket()` in user code). These sockets won't have the
// `__pk` attachment namespace required to rehydrate a Connection.
if (!isPartyServerWebSocket(ws)) {
return;
}

const connection = createLazyConnection(ws);

// rehydrate the server name if it's woken up
Expand Down Expand Up @@ -449,6 +457,10 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam
return;
}

if (!isPartyServerWebSocket(ws)) {
return;
}

const connection = createLazyConnection(ws);

// rehydrate the server name if it's woken up
Expand All @@ -470,6 +482,10 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam
return;
}

if (!isPartyServerWebSocket(ws)) {
return;
}

const connection = createLazyConnection(ws);

// rehydrate the server name if it's woken up
Expand Down
41 changes: 41 additions & 0 deletions packages/partyserver/src/tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,47 @@ describe("Server", () => {
expect(response.headers.get("Location")).toBe("https://example3.com");
});

it("ignores foreign hibernated websockets when broadcasting", async () => {
const ctx = createExecutionContext();

// Create a websocket that is accepted via the DO hibernation API directly
// (no PartyServer `__pk` attachment).
const foreignReq = new Request(
"http://example.com/parties/mixed/room/foreign",
{
headers: { Upgrade: "websocket" }
}
);
const foreignRes = await worker.fetch(foreignReq, env, ctx);
const foreignWs = foreignRes.webSocket!;
foreignWs.accept();

// Now connect via PartyServer. onConnect() will call broadcast(), which must
// not crash due to the foreign socket.
const req = new Request("http://example.com/parties/mixed/room", {
headers: { Upgrade: "websocket" }
});
const res = await worker.fetch(req, env, ctx);
const ws = res.webSocket!;
ws.accept();

const { promise, resolve, reject } = Promise.withResolvers<void>();
ws.addEventListener("message", (message) => {
try {
// We should receive at least one message from the server.
expect(["hello", "connected"]).toContain(message.data);
resolve();
} catch (e) {
reject(e);
} finally {
ws.close();
foreignWs.close();
}
});

return promise;
});

// it("can be connected with a query parameter");
// it("can be connected with a header");

Expand Down
32 changes: 32 additions & 0 deletions packages/partyserver/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function assert(condition: unknown, message: string): asserts condition {
export type Env = {
Stateful: DurableObjectNamespace<Stateful>;
OnStartServer: DurableObjectNamespace<OnStartServer>;
Mixed: DurableObjectNamespace<Mixed>;
};

export class Stateful extends Server {
Expand Down Expand Up @@ -61,6 +62,37 @@ export class OnStartServer extends Server {
}
}

export class Mixed extends Server {
static options = {
hibernate: true
};

async fetch(request: Request): Promise<Response> {
const url = new URL(request.url);
if (url.pathname.endsWith("/foreign")) {
const room = request.headers.get("x-partykit-room");
if (room) {
await this.setName(room);
}

const pair = new WebSocketPair();
const [client, server] = Object.values(pair);
// Accept a hibernated websocket that PartyServer does not manage. This is
// equivalent to user code calling `this.ctx.acceptWebSocket()` directly.
this.ctx.acceptWebSocket(server, ["foreign"]);
return new Response(null, { status: 101, webSocket: client });
}

return super.fetch(request);
}

onConnect(connection: Connection): void {
// Trigger a broadcast while a foreign hibernated socket exists.
this.broadcast("hello");
connection.send("connected");
}
}

export default {
async fetch(request: Request, env: Env, _ctx: ExecutionContext) {
return (
Expand Down
6 changes: 5 additions & 1 deletion packages/partyserver/src/tests/wrangler.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
{
"name": "OnStartServer",
"class_name": "OnStartServer"
},
{
"name": "Mixed",
"class_name": "Mixed"
}
]
},
"migrations": [
{
"tag": "v1", // Should be unique for each entry
"new_classes": ["Stateful", "OnStartServer"]
"new_classes": ["Stateful", "OnStartServer", "Mixed"]
}
]
}
2 changes: 1 addition & 1 deletion packages/partysocket/src/tests/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
for (let i = 0; i < messageCount; i++) {
expect(receivedMessages[i]).toBe(`message-${i}`);
}
} catch (e) {
} catch (_e) {
// If we still have Blobs, messages aren't fully processed yet
return;
}
Expand Down Expand Up @@ -582,7 +582,7 @@
let wss: WebSocketServer;

beforeAll(() => {
wss = new WebSocketServer({ port: PORT + 4 });

Check failure on line 585 in packages/partysocket/src/tests/integration.test.ts

View workflow job for this annotation

GitHub Actions / check (ubuntu-24.04)

Unhandled error

Error: listen EADDRINUSE: address already in use :::50136 ❯ Server.setupListenHandle [as _listen2] node:net:1908:16 ❯ listenInCluster node:net:1965:12 ❯ Server.listen node:net:2067:7 ❯ new WebSocketServer ../../node_modules/ws/lib/websocket-server.js:102:20 ❯ src/tests/integration.test.ts:585:11 ❯ ../../node_modules/@vitest/runner/dist/chunk-hooks.js:1897:20 ❯ runWithTimeout ../../node_modules/@vitest/runner/dist/chunk-hooks.js:1863:10 ❯ runHook ../../node_modules/@vitest/runner/dist/chunk-hooks.js:1436:51 ❯ callSuiteHook ../../node_modules/@vitest/runner/dist/chunk-hooks.js:1442:25 ⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯ Serialized Error: { code: 'EADDRINUSE', errno: -98, syscall: 'listen', address: '::', port: 50136 } This error originated in "src/tests/integration.test.ts" test file. It doesn't mean the error was thrown inside the file itself, but while it was running.
});

afterAll(() => {
Expand Down
Loading
Loading