Skip to content

Commit

Permalink
feat: binary messages through ws + camera demo
Browse files Browse the repository at this point in the history
  • Loading branch information
drochetti committed Dec 12, 2023
1 parent 6f95c65 commit 7452001
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 7 deletions.
205 changes: 205 additions & 0 deletions apps/demo-nextjs-app-router/app/camera-turbo/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/* eslint-disable @next/next/no-img-element */
'use client';

import * as fal from '@fal-ai/serverless-client';
import { MutableRefObject, useEffect, useRef, useState } from 'react';

fal.config({
proxyUrl: '/api/fal/proxy',
});

const EMPTY_IMG =
'';

type WebcamOptions = {
videoRef: MutableRefObject<HTMLVideoElement | null>;
previewRef: MutableRefObject<HTMLCanvasElement | null>;
onFrameUpdate?: (data: Uint8Array) => void;
width?: number;
height?: number;
};
const useWebcam = ({
videoRef,
previewRef,
onFrameUpdate,
width = 512,
height = 512,
}: WebcamOptions) => {
useEffect(() => {
if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
navigator.mediaDevices.getUserMedia({ video: true }).then((stream) => {
if (videoRef.current !== null) {
videoRef.current.srcObject = stream;
videoRef.current.play();
}
});
}
}, [videoRef]);

const captureFrame = () => {
const canvas = previewRef.current;
const video = videoRef.current;
if (canvas === null || video === null) {
return;
}

// Calculate the aspect ratio and crop dimensions
const aspectRatio = video.videoWidth / video.videoHeight;
let sourceX, sourceY, sourceWidth, sourceHeight;

if (aspectRatio > 1) {
// If width is greater than height
sourceWidth = video.videoHeight;
sourceHeight = video.videoHeight;
sourceX = (video.videoWidth - video.videoHeight) / 2;
sourceY = 0;
} else {
// If height is greater than or equal to width
sourceWidth = video.videoWidth;
sourceHeight = video.videoWidth;
sourceX = 0;
sourceY = (video.videoHeight - video.videoWidth) / 2;
}

// Resize the canvas to the target dimensions
canvas.width = width;
canvas.height = height;

const context = canvas.getContext('2d');
if (context === null) {
return;
}

// Draw the image on the canvas (cropped and resized)
context.drawImage(
video,
sourceX,
sourceY,
sourceWidth,
sourceHeight,
0,
0,
width,
height
);

// Callback with frame data
if (onFrameUpdate) {
canvas.toBlob(
(blob) => {
blob?.arrayBuffer().then((buffer) => {
const frameData = new Uint8Array(buffer);
onFrameUpdate(frameData);
});
},
'image/jpeg',
0.7
);
}
};

useEffect(() => {
const interval = setInterval(() => {
captureFrame();
}, 16); // Adjust interval as needed

return () => clearInterval(interval);
});
};

type LCMInput = {
prompt: string;
image: Uint8Array;
strength?: number;
negative_prompt?: string;
seed?: number | null;
guidance_scale?: number;
num_inference_steps?: number;
enable_safety_checks?: boolean;
request_id?: string;
height?: number;
width?: number;
};

type LCMOutput = {
image: Uint8Array;
timings: Record<string, number>;
seed: number;
num_inference_steps: number;
request_id: string;
nsfw_content_detected: boolean[];
};

export default function WebcamPage() {
const [enabled, setEnabled] = useState(false);
const processedImageRef = useRef<HTMLImageElement | null>(null);
const videoRef = useRef<HTMLVideoElement | null>(null);
const previewRef = useRef<HTMLCanvasElement | null>(null);

const { send } = fal.realtime.connect<LCMInput, LCMOutput>(
'110602490-sd-turbo-real-time-high-fps-msgpack',
{
connectionKey: 'camera-turbo-demo',
// not throttling the client, handling throttling of the camera itself
// and letting all requests through in real-time
throttleInterval: 0,
onResult(result) {
if (processedImageRef.current && result.image) {
const blob = new Blob([result.image], { type: 'image/jpeg' });
const url = URL.createObjectURL(blob);
processedImageRef.current.src = url;
}
},
}
);

const onFrameUpdate = (data: Uint8Array) => {
if (!enabled) {
return;
}
send({
prompt: 'a picture of leonardo di caprio, elegant, in a suit, 8k, uhd',
image: data,
num_inference_steps: 3,
strength: 0.44,
guidance_scale: 1,
seed: 6252023,
});
};

useWebcam({
videoRef,
previewRef,
onFrameUpdate,
});

return (
<main className="flex-col px-32 mx-auto my-20">
<h1 className="text-4xl font-mono mb-8 text-current text-center">
fal<code className="font-light text-pink-600">camera</code>
</h1>
<video ref={videoRef} style={{ display: 'none' }}></video>
<div className="py-12 flex items-center justify-center">
<button
className="py-3 px-4 bg-indigo-700 text-white text-lg rounded"
onClick={() => {
setEnabled(!enabled);
}}
>
{enabled ? 'Stop' : 'Start'}
</button>
</div>
<div className="flex flex-col lg:flex-row space-y-4 lg:space-y-0 lg:space-x-4 justify-between">
<canvas ref={previewRef} width="512" height="512"></canvas>
<img
ref={processedImageRef}
src={EMPTY_IMG}
width={512}
height={512}
className="min-w-[512px] min-h-[512px]"
alt="generated"
/>
</div>
</main>
);
}
3 changes: 2 additions & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.7.0",
"version": "0.7.1-alpha.0",
"license": "MIT",
"repository": {
"type": "git",
Expand All @@ -16,6 +16,7 @@
"ml"
],
"dependencies": {
"msgpackr": "^1.10.0",
"robot3": "^0.4.1"
},
"engines": {
Expand Down
61 changes: 55 additions & 6 deletions libs/client/src/realtime.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { pack, unpack } from 'msgpackr';
import {
ContextFunction,
createMachine,
Expand Down Expand Up @@ -75,7 +76,14 @@ function closeConnection(context: Context): Context {

function sendMessage(context: Context, event: SendEvent): Context {
if (context.websocket && context.websocket.readyState === WebSocket.OPEN) {
context.websocket.send(JSON.stringify(event.message));
if (event.message instanceof Uint8Array) {
context.websocket.send(event.message);
} else if (shouldSendBinary(event.message)) {
context.websocket.send(pack(event.message));
} else {
context.websocket.send(JSON.stringify(event.message));
}

return {
...context,
enqueuedMessage: undefined,
Expand Down Expand Up @@ -260,6 +268,16 @@ function buildRealtimeUrl(
const TOKEN_EXPIRATION_SECONDS = 120;
const DEFAULT_THROTTLE_INTERVAL = 128;

function shouldSendBinary(message: any): boolean {
return Object.values(message).some(
(value) =>
value instanceof Buffer ||
value instanceof Blob ||
value instanceof ArrayBuffer ||
value instanceof Uint8Array
);
}

/**
* Get a token to connect to the realtime endpoint.
*/
Expand Down Expand Up @@ -452,7 +470,33 @@ export const realtimeImpl: RealtimeClient = {
onError(new ApiError({ message: 'Unknown error', status: 500 }));
};
ws.onmessage = (event) => {
const { onResult } = getCallbacks();

// Handle binary messages as msgpack messages
if (event.data instanceof ArrayBuffer) {
const result = unpack(new Uint8Array(event.data));
onResult(result);
return;
}
if (
event.data instanceof Buffer ||
event.data instanceof Uint8Array
) {
const result = unpack(event.data);
onResult(result);
return;
}
if (event.data instanceof Blob) {
event.data.arrayBuffer().then((buffer) => {
const result = unpack(buffer as Buffer);
onResult(result);
});
return;
}

// Otherwise handle strings as plain JSON messages
const data = JSON.parse(event.data);

// Drop messages that are not related to the actual result.
// In the future, we might want to handle other types of messages.
// TODO: specify the fal ws protocol format
Expand All @@ -461,7 +505,6 @@ export const realtimeImpl: RealtimeClient = {
return;
}
if (isSuccessfulResult(data)) {
const { onResult } = getCallbacks();
onResult(data);
return;
}
Expand All @@ -485,12 +528,18 @@ export const realtimeImpl: RealtimeClient = {

const send = (input: Input & Partial<WithRequestId>) => {
// Use throttled send to avoid sending too many messages

const message =
input instanceof Uint8Array
? input
: {
...input,
request_id: input['request_id'] ?? crypto.randomUUID(),
};

stateMachine.throttledSend({
type: 'send',
message: {
...input,
request_id: input['request_id'] ?? crypto.randomUUID(),
},
message,
});
};

Expand Down
Loading

0 comments on commit 7452001

Please sign in to comment.