Skip to content

Commit d98968e

Browse files
committed
feat: subscribe to events before connect
1 parent 18ea080 commit d98968e

File tree

13 files changed

+188
-11
lines changed

13 files changed

+188
-11
lines changed

packages/rivetkit/fixtures/driver-test-suite/counter.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ import { actor } from "rivetkit";
22

33
export const counter = actor({
44
state: { count: 0 },
5+
onConnect: (c, conn) => {
6+
c.broadcast("onconnect:broadcast", "Hello!");
7+
conn.send("onconnect:msg", "Welcome to the counter actor!");
8+
},
59
actions: {
610
increment: (c, x: number) => {
711
c.state.count += x;

packages/rivetkit/src/actor/instance.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
917917
state: CS,
918918
driverId: ConnectionDriver,
919919
driverState: unknown,
920+
subscriptions: string[],
920921
authData: unknown,
921922
): Promise<Conn<S, CP, CS, V, I, DB>> {
922923
this.#assertReady();
@@ -950,6 +951,11 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
950951
//
951952
// Do this immediately after adding connection & before any async logic in order to avoid race conditions with sleep timeouts
952953
this.#resetSleepTimer();
954+
if (subscriptions) {
955+
for (const sub of subscriptions) {
956+
this.#addSubscription(sub, conn, true);
957+
}
958+
}
953959

954960
// Add to persistence & save immediately
955961
this.#persist.connections.push(persist);
@@ -1017,6 +1023,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
10171023
return await this.executeAction(ctx, name, args);
10181024
},
10191025
onSubscribe: async (eventName, conn) => {
1026+
console.log("subscribing to event", { eventName, connId: conn.id });
10201027
this.inspector.emitter.emit("eventFired", {
10211028
type: "subscribe",
10221029
eventName,
@@ -1489,6 +1496,13 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
14891496
_broadcast<Args extends Array<unknown>>(name: string, ...args: Args) {
14901497
this.#assertReady();
14911498

1499+
console.log("broadcasting event", {
1500+
name,
1501+
args,
1502+
actorId: this.id,
1503+
subscriptions: this.#subscriptionIndex.size,
1504+
connections: this.conns.size,
1505+
});
14921506
this.inspector.emitter.emit("eventFired", {
14931507
type: "broadcast",
14941508
eventName: name,

packages/rivetkit/src/actor/protocol/serde.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export type InputData = string | Buffer | Blob | ArrayBufferLike | Uint8Array;
1313
export type OutputData = string | Uint8Array;
1414

1515
export const EncodingSchema = z.enum(["json", "cbor", "bare"]);
16+
export const SubscriptionsListSchema = z.array(z.string());
1617

1718
/**
1819
* Encoding used to communicate between the client & actor.

packages/rivetkit/src/actor/router-endpoints.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ export async function handleWebSocketConnect(
113113
encoding: Encoding,
114114
parameters: unknown,
115115
authData: unknown,
116+
subscriptions: string[],
116117
): Promise<UpgradeWebSocketArgs> {
117118
const exposeInternalError = req ? getRequestExposeInternalError(req) : false;
118119

@@ -182,6 +183,7 @@ export async function handleWebSocketConnect(
182183
connState,
183184
CONNECTION_DRIVER_WEBSOCKET,
184185
{ encoding } satisfies GenericWebSocketDriverState,
186+
subscriptions,
185187
authData,
186188
);
187189

@@ -332,6 +334,7 @@ export async function handleSseConnect(
332334
_runConfig: RunConfig,
333335
actorDriver: ActorDriver,
334336
actorId: string,
337+
subscriptions: string[],
335338
authData: unknown,
336339
) {
337340
const encoding = getRequestEncoding(c.req);
@@ -367,6 +370,7 @@ export async function handleSseConnect(
367370
connState,
368371
CONNECTION_DRIVER_SSE,
369372
{ encoding } satisfies GenericSseDriverState,
373+
subscriptions,
370374
authData,
371375
);
372376

@@ -463,6 +467,7 @@ export async function handleAction(
463467
connState,
464468
CONNECTION_DRIVER_HTTP,
465469
{} satisfies GenericHttpDriverState,
470+
[],
466471
authData,
467472
);
468473

@@ -655,6 +660,8 @@ export const HEADER_CONN_ID = "X-RivetKit-Conn";
655660

656661
export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token";
657662

663+
export const HEADER_CONN_SUBS = "X-RivetKit-Conn-Subs";
664+
658665
/**
659666
* Headers that publics can send from public clients.
660667
*
@@ -669,6 +676,7 @@ export const ALLOWED_PUBLIC_HEADERS = [
669676
HEADER_ACTOR_ID,
670677
HEADER_CONN_ID,
671678
HEADER_CONN_TOKEN,
679+
HEADER_CONN_SUBS,
672680
];
673681

674682
// Helper to get connection parameters for the request

packages/rivetkit/src/actor/router.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import { Hono, type Context as HonoContext } from "hono";
22
import { cors } from "hono/cors";
33
import invariant from "invariant";
4-
import { EncodingSchema } from "@/actor/protocol/serde";
4+
import {
5+
EncodingSchema,
6+
SubscriptionsListSchema,
7+
} from "@/actor/protocol/serde";
58
import {
69
type ActionOpts,
710
type ActionOutput,
@@ -13,6 +16,7 @@ import {
1316
HEADER_AUTH_DATA,
1417
HEADER_CONN_ID,
1518
HEADER_CONN_PARAMS,
19+
HEADER_CONN_SUBS,
1620
HEADER_CONN_TOKEN,
1721
HEADER_ENCODING,
1822
handleAction,
@@ -84,12 +88,16 @@ export function createActorRouter(
8488
const encodingRaw = c.req.header(HEADER_ENCODING);
8589
const connParamsRaw = c.req.header(HEADER_CONN_PARAMS);
8690
const authDataRaw = c.req.header(HEADER_AUTH_DATA);
91+
const subsRaw = c.req.header(HEADER_CONN_SUBS);
8792

8893
const encoding = EncodingSchema.parse(encodingRaw);
8994
const connParams = connParamsRaw
9095
? JSON.parse(connParamsRaw)
9196
: undefined;
9297
const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined;
98+
const subs = subsRaw
99+
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
100+
: [];
93101

94102
return await handleWebSocketConnect(
95103
c.req.raw,
@@ -98,6 +106,7 @@ export function createActorRouter(
98106
c.env.actorId,
99107
encoding,
100108
connParams,
109+
subs,
101110
authData,
102111
);
103112
})(c, noopNext());
@@ -115,8 +124,20 @@ export function createActorRouter(
115124
if (authDataRaw) {
116125
authData = JSON.parse(authDataRaw);
117126
}
127+
const subsRaw = c.req.header(HEADER_CONN_SUBS);
128+
129+
const subscriptions = subsRaw
130+
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
131+
: [];
118132

119-
return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData);
133+
return handleSseConnect(
134+
c,
135+
runConfig,
136+
actorDriver,
137+
c.env.actorId,
138+
subscriptions,
139+
authData,
140+
);
120141
});
121142

122143
router.post("/action/:action", async (c) => {

packages/rivetkit/src/client/actor-conn.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ export class ActorConnRaw {
113113
/**
114114
* Interval that keeps the NodeJS process alive if this is the only thing running.
115115
*
116-
* See ttps://github.com/nodejs/node/issues/22088
116+
* @see https://github.com/nodejs/node/issues/22088
117117
*/
118118
#keepNodeAliveInterval: NodeJS.Timeout;
119119

@@ -126,8 +126,6 @@ export class ActorConnRaw {
126126
#encoding: Encoding;
127127
#actorQuery: ActorQuery;
128128

129-
// TODO: ws message queue
130-
131129
/**
132130
* Do not call this directly.
133131
*
@@ -203,7 +201,6 @@ export class ActorConnRaw {
203201

204202
/**
205203
* Do not call this directly.
206-
enc
207204
* Establishes a connection to the server using the specified endpoint & encoding & driver.
208205
*
209206
* @protected
@@ -281,6 +278,7 @@ enc
281278
actorId,
282279
this.#encoding,
283280
this.#params,
281+
Array.from(this.#eventSubscriptions.keys()),
284282
);
285283
this.#transport = { websocket: ws };
286284
ws.addEventListener("open", () => {
@@ -863,3 +861,15 @@ enc
863861
*/
864862
export type ActorConn<AD extends AnyActorDefinition> = ActorConnRaw &
865863
ActorDefinitionActions<AD>;
864+
865+
/**
866+
* Connection to a actor. Allows calling actor's remote procedure calls with inferred types. See {@link ActorConnRaw} for underlying methods.
867+
* Needs to be established manually using #connect.
868+
*
869+
* @template AD The actor class that this connection is for.
870+
* @see {@link ActorConnRaw}
871+
* @see {@link ActorConn}
872+
*/
873+
export type ActorManualConn<AD extends AnyActorDefinition> = ActorConnRaw & {
874+
connect: () => void;
875+
} & ActorDefinitionActions<AD>;

packages/rivetkit/src/client/actor-handle.ts

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ import {
1818
} from "@/schemas/client-protocol/versioned";
1919
import { bufferToArrayBuffer } from "@/utils";
2020
import type { ActorDefinitionActions } from "./actor-common";
21-
import { type ActorConn, ActorConnRaw } from "./actor-conn";
21+
import {
22+
type ActorConn,
23+
ActorConnRaw,
24+
type ActorManualConn,
25+
} from "./actor-conn";
2226
import { queryActor } from "./actor-query";
2327
import { type ClientRaw, CREATE_ACTOR_CONN_PROXY } from "./client";
2428
import { ActorError } from "./errors";
@@ -160,6 +164,33 @@ export class ActorHandleRaw {
160164
) as ActorConn<AnyActorDefinition>;
161165
}
162166

167+
/**
168+
* Creates a new connection to the actor, that should be manually connected.
169+
* This is useful for creating connections that are not immediately connected,
170+
* such as when you want to set up event listeners before connecting.
171+
*
172+
* @param AD - The actor definition for the connection.
173+
* @returns {ActorConn<AD>} A connection to the actor.
174+
*/
175+
create(): ActorManualConn<AnyActorDefinition> {
176+
logger().debug({
177+
msg: "creating a connection from handle",
178+
query: this.#actorQuery,
179+
});
180+
181+
const conn = new ActorConnRaw(
182+
this.#client,
183+
this.#driver,
184+
this.#params,
185+
this.#encoding,
186+
this.#actorQuery,
187+
);
188+
189+
return this.#client[CREATE_ACTOR_CONN_PROXY](
190+
conn,
191+
) as ActorManualConn<AnyActorDefinition>;
192+
}
193+
163194
/**
164195
* Makes a raw HTTP request to the actor.
165196
*
@@ -259,10 +290,12 @@ export class ActorHandleRaw {
259290
*/
260291
export type ActorHandle<AD extends AnyActorDefinition> = Omit<
261292
ActorHandleRaw,
262-
"connect"
293+
"connect" | "create"
263294
> & {
264295
// Add typed version of ActorConn (instead of using AnyActorDefinition)
265296
connect(): ActorConn<AD>;
266297
// Resolve method returns the actor ID
267298
resolve(): Promise<string>;
299+
// Add typed version of create
300+
create(): ActorManualConn<AD>;
268301
} & ActorDefinitionActions<AD>;

packages/rivetkit/src/client/client.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import type { ActorActionFunction } from "./actor-common";
88
import {
99
type ActorConn,
1010
type ActorConnRaw,
11+
type ActorManualConn,
1112
CONNECT_SYMBOL,
1213
} from "./actor-conn";
1314
import { type ActorHandle, ActorHandleRaw } from "./actor-handle";
@@ -149,6 +150,7 @@ export interface Region {
149150

150151
export const ACTOR_CONNS_SYMBOL = Symbol("actorConns");
151152
export const CREATE_ACTOR_CONN_PROXY = Symbol("createActorConnProxy");
153+
export const CREATE_ACTOR_PROXY = Symbol("createActorProxy");
152154
export const TRANSPORT_SYMBOL = Symbol("transport");
153155

154156
/**
@@ -359,12 +361,34 @@ export class ClientRaw {
359361
// Save to connection list
360362
this[ACTOR_CONNS_SYMBOL].add(conn);
361363

364+
logger().debug({
365+
msg: "creating actor proxy for connection and connecting",
366+
conn,
367+
});
368+
362369
// Start connection
363370
conn[CONNECT_SYMBOL]();
364371

365372
return createActorProxy(conn) as ActorConn<AD>;
366373
}
367374

375+
[CREATE_ACTOR_PROXY]<AD extends AnyActorDefinition>(
376+
conn: ActorConnRaw,
377+
): ActorConn<AD> {
378+
// Save to connection list
379+
this[ACTOR_CONNS_SYMBOL].add(conn);
380+
381+
logger().debug({ msg: "creating actor proxy for connection", conn });
382+
383+
Object.assign(conn, {
384+
connect: () => {
385+
conn[CONNECT_SYMBOL]();
386+
},
387+
});
388+
389+
return createActorProxy(conn) as ActorManualConn<AD>;
390+
}
391+
368392
/**
369393
* Disconnects from all actors.
370394
*

packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { describe, expect, test } from "vitest";
1+
import { describe, expect, test, vi } from "vitest";
22
import type { DriverTestConfig } from "../mod";
33
import { FAKE_TIME, setupDriverTest, waitFor } from "../utils";
44

@@ -190,6 +190,33 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) {
190190
// Clean up
191191
await connection.dispose();
192192
});
193+
194+
test("should handle events sent during onConnect", async (c) => {
195+
const { client } = await setupDriverTest(c, driverTestConfig);
196+
197+
// Create actor with onConnect event
198+
const connection = client.counter
199+
.getOrCreate(["test-onconnect"])
200+
.create();
201+
202+
// Set up event listener for onConnect
203+
const onBroadcastFn = vi.fn();
204+
connection.on("onconnect:broadcast", onBroadcastFn);
205+
206+
// Set up event listener for onConnect message
207+
const onMsgFn = vi.fn();
208+
connection.on("onconnect:msg", onMsgFn);
209+
210+
connection.connect();
211+
212+
// Verify the onConnect event was received
213+
await vi.waitFor(() => {
214+
expect(onBroadcastFn).toHaveBeenCalled();
215+
expect(onMsgFn).toHaveBeenCalled();
216+
});
217+
// Clean up
218+
await connection.dispose();
219+
});
193220
});
194221

195222
describe("Connection Parameters", () => {

0 commit comments

Comments
 (0)