From b65537f91aa805583ee00b8e910e083f5fd762c6 Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Tue, 28 Nov 2023 19:43:35 -0800 Subject: [PATCH 1/8] fix: connection state handling --- .../app/realtime/page.tsx | 17 +++++++++- libs/client/package.json | 2 +- libs/client/src/realtime.ts | 33 +++++++++++++------ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/apps/demo-nextjs-app-router/app/realtime/page.tsx b/apps/demo-nextjs-app-router/app/realtime/page.tsx index 152f094..5621401 100644 --- a/apps/demo-nextjs-app-router/app/realtime/page.tsx +++ b/apps/demo-nextjs-app-router/app/realtime/page.tsx @@ -12,17 +12,32 @@ fal.config({ const PROMPT = 'a moon in a starry night sky'; export default function RealtimePage() { + // const [prompt, setPrompt] = useState(PROMPT); + // const [rerender, setRerender] = useState(0); const [image, setImage] = useState(null); - const { send } = fal.realtime.connect('110602490-shared-lcm-test', { + const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', { connectionKey: 'realtime-demo', onResult(result) { + console.log('onResult!!!', result); if (result.images && result.images[0]) { setImage(result.images[0].url); } }, }); + // useEffect(() => { + // setTimeout(() => { + // setRerender((v) => v + 1); + // }, 10); + // }, []); + + // useEffect(() => { + // setTimeout(() => { + // setRerender((v) => v + 1); + // }, 50); + // }, []); + return (
diff --git a/libs/client/package.json b/libs/client/package.json index 4c50a30..6e5a851 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.6.0", + "version": "0.6.1-alpha.0", "license": "MIT", "repository": { "type": "git", diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index fd19d3d..f849d03 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -98,7 +98,7 @@ async function getToken(app: string): Promise { `https://${getRestApiUrl()}/tokens/`, { allowed_apps: [appAlias.join('-')], - token_expiration: 120, + token_expiration: 40, } ); // keep this in case the response was wrapped (old versions of the proxy do that) @@ -109,6 +109,11 @@ async function getToken(app: string): Promise { return token; } +function isUnauthorizedError(message: any): boolean { + // TODO we need better protocol definition with error codes + return message['status'] === 'error' && message['error'] === 'Unauthorized'; +} + /** * See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 */ @@ -219,7 +224,6 @@ export const realtimeImpl: RealtimeClient = { } else { enqueueMessages.push(input); if (!reconnecting) { - reconnecting = true; reconnect(); } } @@ -229,8 +233,13 @@ export const realtimeImpl: RealtimeClient = { const reconnect = () => { if (ws && ws.readyState === WebSocket.OPEN) { + reconnecting = false; + return; + } + if (reconnecting) { return; } + reconnecting = true; getConnection(app, connectionKey) .then((connection) => { ws = connection; @@ -254,13 +263,7 @@ export const realtimeImpl: RealtimeClient = { ws = null; }; ws.onerror = (event) => { - // TODO handle errors once server specify them - // if error 401, refresh token and retry - // if error 403, refresh token and retry - connectionManager.expireToken(app); - connectionManager.remove(connectionKey); - ws = null; - // if any of those are failed again, call onError + // TODO specify error protocol for identified errors onError(new ApiError({ message: 'Unknown error', status: 500 })); }; ws.onmessage = (event) => { @@ -268,6 +271,13 @@ export const realtimeImpl: RealtimeClient = { // Drop messages that are not related to the actual result. // In the future, we might want to handle other types of messages. // TODO: specify the fal ws protocol format + if (isUnauthorizedError(data)) { + connectionManager.expireToken(app); + connectionManager.remove(connectionKey); + connectionManager.expireToken(app); + ws = null; + return; + } if (data.status !== 'error' && data.type !== 'x-fal-message') { onResult(data); } @@ -275,7 +285,10 @@ export const realtimeImpl: RealtimeClient = { }) .catch((error) => { onError( - new ApiError({ message: 'Error opening connection', status: 500 }) + new ApiError({ + message: `Error opening connection: ${error.message}`, + status: 500, + }) ); }); }; From cec7ce13cda6e05c77726630b8e1c186daaff12b Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Tue, 28 Nov 2023 19:56:35 -0800 Subject: [PATCH 2/8] chore: reset token expiration --- libs/client/src/realtime.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index f849d03..e2daf63 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -98,7 +98,7 @@ async function getToken(app: string): Promise { `https://${getRestApiUrl()}/tokens/`, { allowed_apps: [appAlias.join('-')], - token_expiration: 40, + token_expiration: TOKEN_EXPIRATION_SECONDS, } ); // keep this in case the response was wrapped (old versions of the proxy do that) @@ -140,7 +140,7 @@ const connectionManager = (() => { // We should make it more robust in the future. setTimeout(() => { tokens.delete(app); - }, TOKEN_EXPIRATION_SECONDS * 0.9 * 1000); + }, Math.round(TOKEN_EXPIRATION_SECONDS * 0.9 * 1000)); return token; }, has(connectionKey: string): boolean { From ffaecf3c3d743d321e8a047ab243f03b7b660262 Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Wed, 29 Nov 2023 05:18:06 -0800 Subject: [PATCH 3/8] feat: state machine experiment --- .../app/realtime/page.tsx | 15 +- libs/client/package.json | 7 +- libs/client/src/realtime.ts | 312 ++++++++++++------ package-lock.json | 6 + package.json | 1 + 5 files changed, 216 insertions(+), 125 deletions(-) diff --git a/apps/demo-nextjs-app-router/app/realtime/page.tsx b/apps/demo-nextjs-app-router/app/realtime/page.tsx index 5621401..ab46068 100644 --- a/apps/demo-nextjs-app-router/app/realtime/page.tsx +++ b/apps/demo-nextjs-app-router/app/realtime/page.tsx @@ -12,12 +12,11 @@ fal.config({ const PROMPT = 'a moon in a starry night sky'; export default function RealtimePage() { - // const [prompt, setPrompt] = useState(PROMPT); - // const [rerender, setRerender] = useState(0); const [image, setImage] = useState(null); const { send } = fal.realtime.connect('110602490-lcm-sd15-i2i', { connectionKey: 'realtime-demo', + throttleInterval: 128, onResult(result) { console.log('onResult!!!', result); if (result.images && result.images[0]) { @@ -26,18 +25,6 @@ export default function RealtimePage() { }, }); - // useEffect(() => { - // setTimeout(() => { - // setRerender((v) => v + 1); - // }, 10); - // }, []); - - // useEffect(() => { - // setTimeout(() => { - // setRerender((v) => v + 1); - // }, 50); - // }, []); - return (
diff --git a/libs/client/package.json b/libs/client/package.json index 6e5a851..4b8725d 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.6.1-alpha.0", + "version": "0.6.1-alpha.4", "license": "MIT", "repository": { "type": "git", @@ -14,5 +14,8 @@ "client", "ai", "ml" - ] + ], + "dependencies": { + "robot3": "^0.4.1" + } } diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index e2daf63..c8f8470 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -1,9 +1,152 @@ +import { + createMachine, + state, + transition, + interpret, + reduce, + ContextFunction, + guard, + immediate, + Service, + InterpretOnChangeFunction, +} from 'robot3'; import { getConfig, getRestApiUrl } from './config'; import { dispatchRequest } from './request'; import { ApiError } from './response'; import { isBrowser } from './runtime'; import { isReact, throttle } from './utils'; +// Define the context +interface Context { + token?: string; + enqueuedMessage?: any; + websocket?: WebSocket; + error?: Error; +} + +const initialState: ContextFunction = () => ({ + enqueuedMessage: undefined, +}); + +type SendEvent = { type: 'send'; message: any }; +type AuthenticatedEvent = { type: 'authenticated'; token: string }; +type InitiateAuthEvent = { type: 'initiateAuth' }; +type UnauthorizedEvent = { type: 'unauthorized'; error: Error }; +type ConnectedEvent = { type: 'connected'; websocket: WebSocket }; +type ConnectionClosedEvent = { + type: 'connectionClosed'; + code: number; + reason: string; +}; + +type Event = + | SendEvent + | AuthenticatedEvent + | InitiateAuthEvent + | UnauthorizedEvent + | ConnectedEvent + | ConnectionClosedEvent; + +function hasToken(context: Context): boolean { + return context.token !== undefined; +} + +function noToken(context: Context): boolean { + return !hasToken(context); +} + +function enqueueMessage(context: Context, event: SendEvent): Context { + return { + ...context, + enqueuedMessage: event.message, + }; +} + +function closeConnection(context: Context): Context { + if (context.websocket && context.websocket.readyState === WebSocket.OPEN) { + context.websocket.close(); + } + return { + ...context, + websocket: undefined, + }; +} + +function sendMessage(context: Context, event: SendEvent): Context { + if (context.websocket && context.websocket.readyState === WebSocket.OPEN) { + context.websocket.send(JSON.stringify(event.message)); + return { + ...context, + enqueuedMessage: undefined, + }; + } + return enqueueMessage(context, event); +} + +function expireToken(context: Context): Context { + return { + ...context, + token: undefined, + }; +} + +function setToken(context: Context, event: AuthenticatedEvent): Context { + return { + ...context, + token: event.token, + }; +} + +function connectionEstablished( + context: Context, + event: ConnectedEvent +): Context { + return { + ...context, + websocket: event.websocket, + }; +} + +// State machine +const connectionStateMachine = createMachine( + 'idle', + { + idle: state( + transition('send', 'connecting', reduce(enqueueMessage)), + transition('expireToken', 'idle', reduce(expireToken)) + ), + connecting: state( + transition('connecting', 'connecting'), + transition('connected', 'active', reduce(connectionEstablished)), + transition('connectionClosed', 'idle', reduce(closeConnection)), + transition('send', 'connecting', reduce(enqueueMessage)), + + immediate('authRequired', guard(noToken)) + ), + authRequired: state( + transition('initiateAuth', 'authInProgress'), + transition('send', 'authRequired', reduce(enqueueMessage)) + ), + authInProgress: state( + transition('authenticated', 'connecting', reduce(setToken)), + transition( + 'unauthorized', + 'failed', + reduce(expireToken), + reduce(closeConnection) + ), + transition('send', 'authInProgress', reduce(enqueueMessage)) + ), + active: state( + transition('send', 'active', reduce(sendMessage)), + transition('unauthorized', 'idle', reduce(expireToken)), + transition('connectionClosed', 'idle', reduce(closeConnection)) + ), + failed: state(transition('send', 'failed')), + }, + initialState +); + /** * A connection object that allows you to `send` request payloads to a * realtime endpoint. @@ -122,55 +265,18 @@ const WebSocketErrorCodes = { GOING_AWAY: 1001, }; -const connectionManager = (() => { - const connections = new Map(); - const tokens = new Map(); - - return { - token(app: string) { - return tokens.get(app); - }, - expireToken(app: string) { - tokens.delete(app); - }, - async refreshToken(app: string) { - const token = await getToken(app); - tokens.set(app, token); - // Very simple token expiration mechanism. - // We should make it more robust in the future. - setTimeout(() => { - tokens.delete(app); - }, Math.round(TOKEN_EXPIRATION_SECONDS * 0.9 * 1000)); - return token; - }, - has(connectionKey: string): boolean { - return connections.has(connectionKey); - }, - get(connectionKey: string): WebSocket | undefined { - return connections.get(connectionKey); - }, - set(connectionKey: string, ws: WebSocket) { - connections.set(connectionKey, ws); - }, - remove(connectionKey: string) { - connections.delete(connectionKey); - }, - }; -})(); +type ConnectionStateMachine = Service; -async function getConnection(app: string, key: string): Promise { - const url = buildRealtimeUrl(app); +type ConnectionOnChange = InterpretOnChangeFunction< + typeof connectionStateMachine +>; - if (connectionManager.has(key)) { - return connectionManager.get(key) as WebSocket; +const connections = new Map(); +function reuseInterpreter(key: string, onChange: ConnectionOnChange) { + if (!connections.has(key)) { + connections.set(key, interpret(connectionStateMachine, onChange)); } - let token = connectionManager.token(app); - if (!token) { - token = await connectionManager.refreshToken(app); - } - const ws = new WebSocket(`${url}?fal_jwt_token=${token}`); - connectionManager.set(key, ws); - return ws; + return connections.get(key) as ConnectionStateMachine; } const noop = () => { @@ -204,54 +310,46 @@ export const realtimeImpl: RealtimeClient = { onError = noop, onResult, } = handler; - if (clientOnly && typeof window === 'undefined') { + if (clientOnly && !isBrowser()) { return NoOpConnection; } - const enqueueMessages: Input[] = []; - - let reconnecting = false; - let ws: WebSocket | null = null; - const _send = (input: Input) => { - const requestId = crypto.randomUUID(); - if (ws && ws.readyState === WebSocket.OPEN) { - ws.send( - JSON.stringify({ - request_id: requestId, - ...input, - }) - ); - } else { - enqueueMessages.push(input); - if (!reconnecting) { - reconnect(); + let previousState: string | undefined; + const stateMachine = reuseInterpreter( + connectionKey, + ({ context, machine, send }) => { + const { enqueuedMessage, token } = context; + if (machine.current === 'active' && enqueuedMessage) { + send({ type: 'send', message: enqueuedMessage }); } - } - }; - const send = - throttleInterval > 0 ? throttle(_send, throttleInterval) : _send; - - const reconnect = () => { - if (ws && ws.readyState === WebSocket.OPEN) { - reconnecting = false; - return; - } - if (reconnecting) { - return; - } - reconnecting = true; - getConnection(app, connectionKey) - .then((connection) => { - ws = connection; + if ( + machine.current === 'authRequired' && + token === undefined && + previousState !== machine.current + ) { + send({ type: 'initiateAuth' }); + getToken(app) + .then((token) => { + send({ type: 'authenticated', token }); + const tokenExpirationTimeout = Math.round( + TOKEN_EXPIRATION_SECONDS * 0.9 * 1000 + ); + setTimeout(() => { + send({ type: 'expireToken' }); + }, tokenExpirationTimeout); + }) + .catch((error) => { + send({ type: 'unauthorized', error }); + }); + } + if (machine.current === 'connecting' && token !== undefined) { + const ws = new WebSocket( + `${buildRealtimeUrl(app)}?fal_jwt_token=${token}` + ); ws.onopen = () => { - reconnecting = false; - if (enqueueMessages.length > 0) { - enqueueMessages.forEach((input) => send(input)); - enqueueMessages.length = 0; - } + send({ type: 'connected', websocket: ws }); }; ws.onclose = (event) => { - connectionManager.remove(connectionKey); if (event.code !== WebSocketErrorCodes.NORMAL_CLOSURE) { onError( new ApiError({ @@ -260,7 +358,7 @@ export const realtimeImpl: RealtimeClient = { }) ); } - ws = null; + send({ type: 'connectionClosed', code: event.code }); }; ws.onerror = (event) => { // TODO specify error protocol for identified errors @@ -272,37 +370,33 @@ export const realtimeImpl: RealtimeClient = { // In the future, we might want to handle other types of messages. // TODO: specify the fal ws protocol format if (isUnauthorizedError(data)) { - connectionManager.expireToken(app); - connectionManager.remove(connectionKey); - connectionManager.expireToken(app); - ws = null; + send({ type: 'unauthorized', error: new Error('Unauthorized') }); return; } if (data.status !== 'error' && data.type !== 'x-fal-message') { onResult(data); } }; - }) - .catch((error) => { - onError( - new ApiError({ - message: `Error opening connection: ${error.message}`, - status: 500, - }) - ); - }); + } + previousState = machine.current; + } + ); + + const sendMessage = (input: Input) => { + stateMachine.send({ type: 'send', message: input }); + }; + const send = + throttleInterval > 0 + ? throttle(sendMessage, throttleInterval) + : sendMessage; + + const close = () => { + stateMachine.send({ type: 'close' }); }; return { send, - close() { - if (ws && ws.readyState === WebSocket.CLOSED) { - ws.close( - WebSocketErrorCodes.GOING_AWAY, - 'Client manually closed the connection.' - ); - } - }, + close, }; }, }; diff --git a/package-lock.json b/package-lock.json index 337cd7b..b4837b8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -28,6 +28,7 @@ "react": "^18.2.0", "react-dom": "^18.2.0", "regenerator-runtime": "0.13.7", + "robot3": "^0.4.1", "ts-morph": "^17.0.1", "tslib": "^2.3.0" }, @@ -24588,6 +24589,11 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/robot3": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/robot3/-/robot3-0.4.1.tgz", + "integrity": "sha512-hzjy826lrxzx8eRgv80idkf8ua1JAepRc9Efdtj03N3KNJuznQCPlyCJ7gnUmDFwZCLQjxy567mQVKmdv2BsXQ==" + }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", diff --git a/package.json b/package.json index 286ddbd..f6af7af 100644 --- a/package.json +++ b/package.json @@ -44,6 +44,7 @@ "react": "^18.2.0", "react-dom": "^18.2.0", "regenerator-runtime": "0.13.7", + "robot3": "^0.4.1", "ts-morph": "^17.0.1", "tslib": "^2.3.0" }, From c66dfea5c097b401a503f4abd9f7ecfe953cd8fb Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Fri, 1 Dec 2023 09:08:32 -0800 Subject: [PATCH 4/8] feat: new realtime state machine impl --- .../app/realtime/page.tsx | 3 +- libs/client/package.json | 2 +- libs/client/src/config.ts | 2 +- libs/client/src/function.ts | 2 +- libs/client/src/realtime.ts | 112 ++++++++++++------ libs/client/src/utils.ts | 5 +- libs/proxy/src/nextjs.ts | 2 +- 7 files changed, 86 insertions(+), 42 deletions(-) diff --git a/apps/demo-nextjs-app-router/app/realtime/page.tsx b/apps/demo-nextjs-app-router/app/realtime/page.tsx index ab46068..a9d9229 100644 --- a/apps/demo-nextjs-app-router/app/realtime/page.tsx +++ b/apps/demo-nextjs-app-router/app/realtime/page.tsx @@ -2,8 +2,8 @@ /* eslint-disable @next/next/no-img-element */ import * as fal from '@fal-ai/serverless-client'; -import { DrawingCanvas } from '../../components/drawing'; import { useState } from 'react'; +import { DrawingCanvas } from '../../components/drawing'; fal.config({ proxyUrl: '/api/fal/proxy', @@ -18,7 +18,6 @@ export default function RealtimePage() { connectionKey: 'realtime-demo', throttleInterval: 128, onResult(result) { - console.log('onResult!!!', result); if (result.images && result.images[0]) { setImage(result.images[0].url); } diff --git a/libs/client/package.json b/libs/client/package.json index 8aaa7be..f258c33 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.7.0-alpha.0", + "version": "0.7.0-alpha.5", "license": "MIT", "repository": { "type": "git", diff --git a/libs/client/src/config.ts b/libs/client/src/config.ts index a9ba9b8..9bfb6e8 100644 --- a/libs/client/src/config.ts +++ b/libs/client/src/config.ts @@ -1,7 +1,7 @@ import { + withMiddleware, withProxy, type RequestMiddleware, - withMiddleware, } from './middleware'; import type { ResponseHandler } from './response'; import { defaultResponseHandler } from './response'; diff --git a/libs/client/src/function.ts b/libs/client/src/function.ts index c10a810..9c42f6e 100644 --- a/libs/client/src/function.ts +++ b/libs/client/src/function.ts @@ -1,6 +1,6 @@ import { getConfig } from './config'; -import { storageImpl } from './storage'; import { dispatchRequest } from './request'; +import { storageImpl } from './storage'; import { EnqueueResult, QueueStatus } from './types'; import { isUUIDv4, isValidUrl } from './utils'; diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index c8f8470..161ab05 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -1,14 +1,15 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { - createMachine, - state, - transition, - interpret, - reduce, ContextFunction, + createMachine, guard, immediate, - Service, + interpret, InterpretOnChangeFunction, + reduce, + Service, + state, + transition, } from 'robot3'; import { getConfig, getRestApiUrl } from './config'; import { dispatchRequest } from './request'; @@ -80,7 +81,10 @@ function sendMessage(context: Context, event: SendEvent): Context { enqueuedMessage: undefined, }; } - return enqueueMessage(context, event); + return { + ...context, + enqueuedMessage: event.message, + }; } function expireToken(context: Context): Context { @@ -131,7 +135,7 @@ const connectionStateMachine = createMachine( transition('authenticated', 'connecting', reduce(setToken)), transition( 'unauthorized', - 'failed', + 'idle', reduce(expireToken), reduce(closeConnection) ), @@ -147,20 +151,20 @@ const connectionStateMachine = createMachine( initialState ); +type WithRequestId = { + request_id: string; +}; + /** * A connection object that allows you to `send` request payloads to a * realtime endpoint. */ export interface RealtimeConnection { - send(input: Input): void; + send(input: Input & Partial): void; close(): void; } -type ResultWithRequestId = { - request_id: string; -}; - /** * Options for connecting to the realtime endpoint. */ @@ -189,9 +193,9 @@ export interface RealtimeConnectionHandler { /** * The throtle duration in milliseconds. This is used to throtle the * calls to the `send` function. Realtime apps usually react to user - * input, which can be very frequesnt (e.g. fast typing or mouse/drag movements). + * input, which can be very frequent (e.g. fast typing or mouse/drag movements). * - * The default value is `64` milliseconds. + * The default value is `128` milliseconds. */ throttleInterval?: number; @@ -199,13 +203,12 @@ export interface RealtimeConnectionHandler { * Callback function that is called when a result is received. * @param result - The result of the request. */ - onResult(result: Output & ResultWithRequestId): void; + onResult(result: Output & WithRequestId): void; /** * Callback function that is called when an error occurs. * @param error - The error that occurred. */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any onError?(error: ApiError): void; } @@ -217,7 +220,6 @@ export interface RealtimeClient { * @param app the app alias or identifier. * @param handler the connection handler. */ - // eslint-disable-next-line @typescript-eslint/no-explicit-any connect( app: string, handler: RealtimeConnectionHandler @@ -230,6 +232,7 @@ function buildRealtimeUrl(app: string): string { } const TOKEN_EXPIRATION_SECONDS = 120; +const DEFAULT_THROTTLE_INTERVAL = 128; /** * Get a token to connect to the realtime endpoint. @@ -265,18 +268,40 @@ const WebSocketErrorCodes = { GOING_AWAY: 1001, }; -type ConnectionStateMachine = Service; +type ConnectionStateMachine = Service & { + throttledSend: ( + event: Event, + payload?: any + ) => void | Promise | undefined; +}; type ConnectionOnChange = InterpretOnChangeFunction< typeof connectionStateMachine >; -const connections = new Map(); -function reuseInterpreter(key: string, onChange: ConnectionOnChange) { - if (!connections.has(key)) { - connections.set(key, interpret(connectionStateMachine, onChange)); +type RealtimeConnectionCallback = Pick< + RealtimeConnectionHandler, + 'onResult' | 'onError' +>; + +const connectionCache = new Map(); +const connectionCallbacks = new Map(); +function reuseInterpreter( + key: string, + throttleInterval: number, + onChange: ConnectionOnChange +) { + if (!connectionCache.has(key)) { + const machine = interpret(connectionStateMachine, onChange); + connectionCache.set(key, { + ...machine, + throttledSend: + throttleInterval > 0 + ? throttle(machine.send, throttleInterval, true) + : machine.send, + }); } - return connections.get(key) as ConnectionStateMachine; + return connectionCache.get(key) as ConnectionStateMachine; } const noop = () => { @@ -306,17 +331,26 @@ export const realtimeImpl: RealtimeClient = { // if running on React in the server, set clientOnly to true by default clientOnly = isReact() && !isBrowser(), connectionKey = crypto.randomUUID(), - throttleInterval = 64, - onError = noop, - onResult, + throttleInterval = DEFAULT_THROTTLE_INTERVAL, } = handler; if (clientOnly && !isBrowser()) { return NoOpConnection; } let previousState: string | undefined; + + // Although the state machine is cached so we don't open multiple connections, + // we still need to update the callbacks so we can call the correct references + // when the state machine is reused. This is needed because the callbacks + // are passed as part of the handler object, which can be different across + // different calls to `connect`. + connectionCallbacks.set(connectionKey, { + onError: handler.onError, + onResult: handler.onResult, + }); const stateMachine = reuseInterpreter( connectionKey, + throttleInterval, ({ context, machine, send }) => { const { enqueuedMessage, token } = context; if (machine.current === 'active' && enqueuedMessage) { @@ -342,7 +376,11 @@ export const realtimeImpl: RealtimeClient = { send({ type: 'unauthorized', error }); }); } - if (machine.current === 'connecting' && token !== undefined) { + if ( + machine.current === 'connecting' && + previousState !== machine.current && + token !== undefined + ) { const ws = new WebSocket( `${buildRealtimeUrl(app)}?fal_jwt_token=${token}` ); @@ -351,6 +389,7 @@ export const realtimeImpl: RealtimeClient = { }; ws.onclose = (event) => { if (event.code !== WebSocketErrorCodes.NORMAL_CLOSURE) { + const { onError = noop } = connectionCallbacks.get(connectionKey); onError( new ApiError({ message: `Error closing the connection: ${event.reason}`, @@ -362,6 +401,7 @@ export const realtimeImpl: RealtimeClient = { }; ws.onerror = (event) => { // TODO specify error protocol for identified errors + const { onError = noop } = connectionCallbacks.get(connectionKey); onError(new ApiError({ message: 'Unknown error', status: 500 })); }; ws.onmessage = (event) => { @@ -374,6 +414,7 @@ export const realtimeImpl: RealtimeClient = { return; } if (data.status !== 'error' && data.type !== 'x-fal-message') { + const { onResult } = connectionCallbacks.get(connectionKey); onResult(data); } }; @@ -382,13 +423,16 @@ export const realtimeImpl: RealtimeClient = { } ); - const sendMessage = (input: Input) => { - stateMachine.send({ type: 'send', message: input }); + const send = (input: Input) => { + // Use throttled send to avoid sending too many messages + stateMachine.throttledSend({ + type: 'send', + message: { + ...input, + request_id: input['request_id'] ?? crypto.randomUUID(), + }, + }); }; - const send = - throttleInterval > 0 - ? throttle(sendMessage, throttleInterval) - : sendMessage; const close = () => { stateMachine.send({ type: 'close' }); diff --git a/libs/client/src/utils.ts b/libs/client/src/utils.ts index a06c785..cb73b89 100644 --- a/libs/client/src/utils.ts +++ b/libs/client/src/utils.ts @@ -19,13 +19,14 @@ export function isValidUrl(url: string) { // eslint-disable-next-line @typescript-eslint/no-explicit-any export function throttle any>( func: T, - limit: number + limit: number, + leading = false ): (...funcArgs: Parameters) => ReturnType | void { let lastFunc: NodeJS.Timeout | null; let lastRan: number; return (...args: Parameters): ReturnType | void => { - if (!lastRan) { + if (!lastRan && leading) { func(...args); lastRan = Date.now(); } else { diff --git a/libs/proxy/src/nextjs.ts b/libs/proxy/src/nextjs.ts index 74680ae..be6a6d8 100644 --- a/libs/proxy/src/nextjs.ts +++ b/libs/proxy/src/nextjs.ts @@ -1,5 +1,5 @@ +import { NextResponse, type NextRequest } from 'next/server'; import type { NextApiHandler } from 'next/types'; -import { type NextRequest, NextResponse } from 'next/server'; import { DEFAULT_PROXY_ROUTE, handleRequest } from './index'; /** From 76c94461747762e1a371927fc3beed99b71458da Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Fri, 1 Dec 2023 09:09:19 -0800 Subject: [PATCH 5/8] chore: update client to 0.7.0 before release --- libs/client/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/client/package.json b/libs/client/package.json index f258c33..580d3d9 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -1,7 +1,7 @@ { "name": "@fal-ai/serverless-client", "description": "The fal serverless JS/TS client", - "version": "0.7.0-alpha.5", + "version": "0.7.0", "license": "MIT", "repository": { "type": "git", From 731ae5cbad08737eeec073213892a5e53e88bb98 Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Fri, 1 Dec 2023 11:41:18 -0800 Subject: [PATCH 6/8] fix: error handling x-fal-error --- libs/client/src/realtime.ts | 43 ++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index 161ab05..8046e76 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -319,6 +319,24 @@ const NoOpConnection: RealtimeConnection = { close: noop, }; +function isSuccessfulResult(data: any): boolean { + return ( + data.status !== 'error' && + data.type !== 'x-fal-message' && + !isFalErrorResult(data) + ); +} + +type FalErrorResult = { + type: 'x-fal-error'; + error: string; + reason: string; +}; + +function isFalErrorResult(data: any): data is FalErrorResult { + return data.type === 'x-fal-error'; +} + /** * The default implementation of the realtime client. */ @@ -348,6 +366,8 @@ export const realtimeImpl: RealtimeClient = { onError: handler.onError, onResult: handler.onResult, }); + const getCallbacks = () => + connectionCallbacks.get(connectionKey) as RealtimeConnectionCallback; const stateMachine = reuseInterpreter( connectionKey, throttleInterval, @@ -389,7 +409,7 @@ export const realtimeImpl: RealtimeClient = { }; ws.onclose = (event) => { if (event.code !== WebSocketErrorCodes.NORMAL_CLOSURE) { - const { onError = noop } = connectionCallbacks.get(connectionKey); + const { onError = noop } = getCallbacks(); onError( new ApiError({ message: `Error closing the connection: ${event.reason}`, @@ -401,7 +421,7 @@ export const realtimeImpl: RealtimeClient = { }; ws.onerror = (event) => { // TODO specify error protocol for identified errors - const { onError = noop } = connectionCallbacks.get(connectionKey); + const { onError = noop } = getCallbacks(); onError(new ApiError({ message: 'Unknown error', status: 500 })); }; ws.onmessage = (event) => { @@ -413,9 +433,22 @@ export const realtimeImpl: RealtimeClient = { send({ type: 'unauthorized', error: new Error('Unauthorized') }); return; } - if (data.status !== 'error' && data.type !== 'x-fal-message') { - const { onResult } = connectionCallbacks.get(connectionKey); + if (isSuccessfulResult(data)) { + const { onResult } = getCallbacks(); onResult(data); + return; + } + if (isFalErrorResult(data)) { + const { onError = noop } = getCallbacks(); + onError( + new ApiError({ + message: `${data.error}: ${data.reason}`, + // TODO better error status code + status: 400, + body: data, + }) + ); + return; } }; } @@ -423,7 +456,7 @@ export const realtimeImpl: RealtimeClient = { } ); - const send = (input: Input) => { + const send = (input: Input & Partial) => { // Use throttled send to avoid sending too many messages stateMachine.throttledSend({ type: 'send', From 18704f60a4ab6ca3d01c836b2bbb45627d75084b Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Mon, 4 Dec 2023 14:33:33 -0800 Subject: [PATCH 7/8] chore(client): release v0.7.0 --- apps/demo-express-app/src/main.ts | 11 ++++++++++ libs/client/package.json | 3 +++ libs/client/src/realtime.ts | 35 ++++++++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/apps/demo-express-app/src/main.ts b/apps/demo-express-app/src/main.ts index f875e7d..7f11332 100644 --- a/apps/demo-express-app/src/main.ts +++ b/apps/demo-express-app/src/main.ts @@ -3,6 +3,7 @@ * This is only a minimal backend to get started. */ +import * as fal from '@fal-ai/serverless-client'; import * as falProxy from '@fal-ai/serverless-proxy/express'; import cors from 'cors'; import { configDotenv } from 'dotenv'; @@ -25,6 +26,16 @@ app.get('/api', (req, res) => { res.send({ message: 'Welcome to demo-express-app!' }); }); +app.get('/fal-on-server', async (req, res) => { + const result = await fal.run('110602490-lcm', { + input: { + prompt: + 'a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k', + }, + }); + res.send(result); +}); + const port = process.env.PORT || 3333; const server = app.listen(port, () => { console.log(`Listening at http://localhost:${port}/api`); diff --git a/libs/client/package.json b/libs/client/package.json index 580d3d9..b7e983b 100644 --- a/libs/client/package.json +++ b/libs/client/package.json @@ -17,5 +17,8 @@ ], "dependencies": { "robot3": "^0.4.1" + }, + "engines": { + "node": ">=18.0.0" } } diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index 8046e76..433836a 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -199,6 +199,15 @@ export interface RealtimeConnectionHandler { */ throttleInterval?: number; + /** + * Configures the maximum amount of frames to store in memory before starting to drop + * old ones for in favor of the newer ones. It must be between `1` and `60`. + * + * The recommended is `2`. The default is `undefined` so it can be determined + * by the app (normally is set to the recommended setting). + */ + maxBuffering?: number; + /** * Callback function that is called when a result is received. * @param result - The result of the request. @@ -226,9 +235,28 @@ export interface RealtimeClient { ): RealtimeConnection; } -function buildRealtimeUrl(app: string): string { +type RealtimeUrlParams = { + token: string; + maxBuffering?: number; +}; + +function buildRealtimeUrl( + app: string, + { token, maxBuffering }: RealtimeUrlParams +): string { const { host } = getConfig(); - return `wss://${app}.${host}/ws`; + if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) { + throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)'); + } + const maxBufferingParam = + maxBuffering !== undefined + ? { max_buffering: maxBuffering.toFixed(0) } + : {}; + const queryParams = new URLSearchParams({ + fal_jwt_token: token, + ...maxBufferingParam, + }); + return `wss://${app}.${host}/ws?${queryParams.toString()}`; } const TOKEN_EXPIRATION_SECONDS = 120; @@ -349,6 +377,7 @@ export const realtimeImpl: RealtimeClient = { // if running on React in the server, set clientOnly to true by default clientOnly = isReact() && !isBrowser(), connectionKey = crypto.randomUUID(), + maxBuffering, throttleInterval = DEFAULT_THROTTLE_INTERVAL, } = handler; if (clientOnly && !isBrowser()) { @@ -402,7 +431,7 @@ export const realtimeImpl: RealtimeClient = { token !== undefined ) { const ws = new WebSocket( - `${buildRealtimeUrl(app)}?fal_jwt_token=${token}` + buildRealtimeUrl(app, { token, maxBuffering }) ); ws.onopen = () => { send({ type: 'connected', websocket: ws }); From 668e2aa4cc6eb3d014c59d2d1ce4d64837ce5ace Mon Sep 17 00:00:00 2001 From: Daniel Rochetti Date: Mon, 4 Dec 2023 14:40:23 -0800 Subject: [PATCH 8/8] fix(client): strict type check error --- libs/client/src/realtime.ts | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/libs/client/src/realtime.ts b/libs/client/src/realtime.ts index 433836a..32c12fb 100644 --- a/libs/client/src/realtime.ts +++ b/libs/client/src/realtime.ts @@ -248,14 +248,12 @@ function buildRealtimeUrl( if (maxBuffering !== undefined && (maxBuffering < 1 || maxBuffering > 60)) { throw new Error('The `maxBuffering` must be between 1 and 60 (inclusive)'); } - const maxBufferingParam = - maxBuffering !== undefined - ? { max_buffering: maxBuffering.toFixed(0) } - : {}; const queryParams = new URLSearchParams({ fal_jwt_token: token, - ...maxBufferingParam, }); + if (maxBuffering !== undefined) { + queryParams.set('max_buffering', maxBuffering.toFixed(0)); + } return `wss://${app}.${host}/ws?${queryParams.toString()}`; }