Skip to content

Commit

Permalink
feat(session): encrypt data and fix renewal (stackblitz#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrvidal authored Aug 19, 2024
1 parent b939a0a commit 44226db
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 51 deletions.
93 changes: 73 additions & 20 deletions packages/bolt/app/lib/.server/sessions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,59 @@ import { CLIENT_ID, CLIENT_ORIGIN } from '~/lib/constants';
import { request as doRequest } from '~/lib/fetch';
import { logger } from '~/utils/logger';
import type { Identity } from '~/lib/analytics';
import { decrypt, encrypt } from '~/lib/crypto';

const DEV_SESSION_SECRET = import.meta.env.DEV ? 'LZQMrERo3Ewn/AbpSYJ9aw==' : undefined;
const DEV_PAYLOAD_SECRET = import.meta.env.DEV ? '2zAyrhjcdFeXk0YEDzilMXbdrGAiR+8ACIUgFNfjLaI=' : undefined;

const TOKEN_KEY = 't';
const EXPIRES_KEY = 'e';
const USER_ID_KEY = 'u';
const SEGMENT_KEY = 's';

interface SessionData {
refresh: string;
expiresAt: number;
userId: string | null;
segmentWriteKey: string | null;
[TOKEN_KEY]: string;
[EXPIRES_KEY]: number;
[USER_ID_KEY]?: string;
[SEGMENT_KEY]?: string;
}

export async function isAuthenticated(request: Request, env: Env) {
const { session, sessionStorage } = await getSession(request, env);
const token = session.get('refresh');

const sessionData: SessionData | null = await decryptSessionData(env, session.get('d'));

const header = async (cookie: Promise<string>) => ({ headers: { 'Set-Cookie': await cookie } });
const destroy = () => header(sessionStorage.destroySession(session));

if (token == null) {
if (sessionData?.[TOKEN_KEY] == null) {
return { authenticated: false as const, response: await destroy() };
}

const expiresAt = session.get('expiresAt') ?? 0;
const expiresAt = sessionData[EXPIRES_KEY] ?? 0;

if (Date.now() < expiresAt) {
return { authenticated: true as const };
}

logger.debug('Renewing token');

let data: Awaited<ReturnType<typeof refreshToken>> | null = null;

try {
data = await refreshToken(token);
} catch {
data = await refreshToken(sessionData[TOKEN_KEY]);
} catch (error) {
// we can ignore the error here because it's handled below
logger.error(error);
}

if (data != null) {
const expiresAt = cookieExpiration(data.expires_in, data.created_at);
session.set('expiresAt', expiresAt);

const newSessionData = { ...sessionData, [EXPIRES_KEY]: expiresAt };
const encryptedData = await encryptSessionData(env, newSessionData);

session.set('d', encryptedData);

return { authenticated: true as const, response: await header(sessionStorage.commitSession(session)) };
} else {
Expand All @@ -59,13 +74,15 @@ export async function createUserSession(

const expiresAt = cookieExpiration(tokens.expires_in, tokens.created_at);

session.set('refresh', tokens.refresh);
session.set('expiresAt', expiresAt);
const sessionData: SessionData = {
[TOKEN_KEY]: tokens.refresh,
[EXPIRES_KEY]: expiresAt,
[USER_ID_KEY]: identity?.userId ?? undefined,
[SEGMENT_KEY]: identity?.segmentWriteKey ?? undefined,
};

if (identity) {
session.set('userId', identity.userId ?? null);
session.set('segmentWriteKey', identity.segmentWriteKey ?? null);
}
const encryptedData = await encryptSessionData(env, sessionData);
session.set('d', encryptedData);

return {
headers: {
Expand All @@ -77,7 +94,7 @@ export async function createUserSession(
}

function getSessionStorage(cloudflareEnv: Env) {
return createCookieSessionStorage<SessionData>({
return createCookieSessionStorage<{ d: string }>({
cookie: {
name: '__session',
httpOnly: true,
Expand All @@ -91,7 +108,11 @@ function getSessionStorage(cloudflareEnv: Env) {
export async function logout(request: Request, env: Env) {
const { session, sessionStorage } = await getSession(request, env);

revokeToken(session.get('refresh'));
const sessionData = await decryptSessionData(env, session.get('d'));

if (sessionData) {
revokeToken(sessionData[TOKEN_KEY]);
}

return redirect('/login', {
headers: {
Expand All @@ -106,7 +127,18 @@ export function validateAccessToken(access: string) {
return jwtPayload.bolt === true;
}

export async function getSession(request: Request, env: Env) {
export async function getSessionData(request: Request, env: Env) {
const { session } = await getSession(request, env);

const decrypted = await decryptSessionData(env, session.get('d'));

return {
userId: decrypted?.[USER_ID_KEY],
segmentWriteKey: decrypted?.[SEGMENT_KEY],
};
}

async function getSession(request: Request, env: Env) {
const sessionStorage = getSessionStorage(env);
const cookie = request.headers.get('Cookie');

Expand All @@ -117,12 +149,15 @@ async function refreshToken(refresh: string): Promise<{ expires_in: number; crea
const response = await doRequest(`${CLIENT_ORIGIN}/oauth/token`, {
method: 'POST',
body: urlParams({ grant_type: 'refresh_token', client_id: CLIENT_ID, refresh_token: refresh }),
headers: {
'content-type': 'application/x-www-form-urlencoded',
},
});

const body = await response.json();

if (!response.ok) {
throw new Error(`Unable to refresh token\n${JSON.stringify(body)}`);
throw new Error(`Unable to refresh token\n${response.status} ${JSON.stringify(body)}`);
}

const { access_token: access } = body;
Expand Down Expand Up @@ -151,6 +186,9 @@ async function revokeToken(refresh?: string) {
token_type_hint: 'refresh_token',
client_id: CLIENT_ID,
}),
headers: {
'content-type': 'application/x-www-form-urlencoded',
},
});

if (!response.ok) {
Expand All @@ -171,3 +209,18 @@ function urlParams(data: Record<string, string>) {

return encoded;
}

async function decryptSessionData(env: Env, encryptedData?: string) {
const decryptedData = encryptedData ? await decrypt(payloadSecret(env), encryptedData) : undefined;
const sessionData: SessionData | null = JSON.parse(decryptedData ?? 'null');

return sessionData;
}

async function encryptSessionData(env: Env, sessionData: SessionData) {
return await encrypt(payloadSecret(env), JSON.stringify(sessionData));
}

function payloadSecret(env: Env) {
return DEV_PAYLOAD_SECRET || env.PAYLOAD_SECRET;
}
58 changes: 58 additions & 0 deletions packages/bolt/app/lib/crypto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
const encoder = new TextEncoder();
const decoder = new TextDecoder();
const IV_LENGTH = 16;

export async function encrypt(key: string, data: string) {
const iv = crypto.getRandomValues(new Uint8Array(IV_LENGTH));
const cryptoKey = await getKey(key);

const ciphertext = await crypto.subtle.encrypt(
{
name: 'AES-CBC',
iv,
},
cryptoKey,
encoder.encode(data),
);

const bundle = new Uint8Array(IV_LENGTH + ciphertext.byteLength);

bundle.set(new Uint8Array(ciphertext));
bundle.set(iv, ciphertext.byteLength);

return decodeBase64(bundle);
}

export async function decrypt(key: string, payload: string) {
const bundle = encodeBase64(payload);

const iv = new Uint8Array(bundle.buffer, bundle.byteLength - IV_LENGTH);
const ciphertext = new Uint8Array(bundle.buffer, 0, bundle.byteLength - IV_LENGTH);

const cryptoKey = await getKey(key);

const plaintext = await crypto.subtle.decrypt(
{
name: 'AES-CBC',
iv,
},
cryptoKey,
ciphertext,
);

return decoder.decode(plaintext);
}

async function getKey(key: string) {
return await crypto.subtle.importKey('raw', encodeBase64(key), { name: 'AES-CBC' }, false, ['encrypt', 'decrypt']);
}

function decodeBase64(encoded: Uint8Array) {
const byteChars = Array.from(encoded, (byte) => String.fromCodePoint(byte));

return btoa(byteChars.join(''));
}

function encodeBase64(data: string) {
return Uint8Array.from(atob(data), (ch) => ch.codePointAt(0)!);
}
6 changes: 3 additions & 3 deletions packages/bolt/app/routes/api.analytics.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { json, type ActionFunctionArgs } from '@remix-run/cloudflare';
import { handleWithAuth } from '~/lib/.server/login';
import { getSession } from '~/lib/.server/sessions';
import { getSessionData } from '~/lib/.server/sessions';
import { sendEventInternal, type AnalyticsEvent } from '~/lib/analytics';

async function analyticsAction({ request, context }: ActionFunctionArgs) {
const event: AnalyticsEvent = await request.json();
const { session } = await getSession(request, context.cloudflare.env);
const { success, error } = await sendEventInternal(session.data, event);
const sessionData = await getSessionData(request, context.cloudflare.env);
const { success, error } = await sendEventInternal(sessionData, event);

if (!success) {
return json({ error }, { status: 500 });
Expand Down
6 changes: 3 additions & 3 deletions packages/bolt/app/routes/api.chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
import { handleWithAuth } from '~/lib/.server/login';
import { getSession } from '~/lib/.server/sessions';
import { getSessionData } from '~/lib/.server/sessions';
import { AnalyticsAction, AnalyticsTrackEvent, sendEventInternal } from '~/lib/analytics';

export async function action(args: ActionFunctionArgs) {
Expand All @@ -21,9 +21,9 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
toolChoice: 'none',
onFinish: async ({ text: content, finishReason, usage }) => {
if (finishReason !== 'length') {
const { session } = await getSession(request, context.cloudflare.env);
const sessionData = await getSessionData(request, context.cloudflare.env);

await sendEventInternal(session.data, {
await sendEventInternal(sessionData, {
action: AnalyticsAction.Track,
payload: {
event: AnalyticsTrackEvent.MessageComplete,
Expand Down
57 changes: 33 additions & 24 deletions packages/bolt/app/utils/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ interface Logger {

let currentLevel: DebugLevel = import.meta.env.VITE_LOG_LEVEL ?? import.meta.env.DEV ? 'debug' : 'info';

const isWorker = 'HTMLRewriter' in globalThis;
const supportsColor = !isWorker;

export const logger: Logger = {
trace: (...messages: any[]) => log('trace', undefined, messages),
debug: (...messages: any[]) => log('debug', undefined, messages),
Expand Down Expand Up @@ -44,35 +47,41 @@ function setLevel(level: DebugLevel) {
function log(level: DebugLevel, scope: string | undefined, messages: any[]) {
const levelOrder: DebugLevel[] = ['trace', 'debug', 'info', 'warn', 'error'];

if (levelOrder.indexOf(level) >= levelOrder.indexOf(currentLevel)) {
const labelBackgroundColor = getColorForLevel(level);
const labelTextColor = level === 'warn' ? 'black' : 'white';

const labelStyles = getLabelStyles(labelBackgroundColor, labelTextColor);
const scopeStyles = getLabelStyles('#77828D', 'white');
if (levelOrder.indexOf(level) < levelOrder.indexOf(currentLevel)) {
return;
}

const styles = [labelStyles];
const allMessages = messages.reduce((acc, current) => {
if (acc.endsWith('\n')) {
return acc + current;
}

if (typeof scope === 'string') {
styles.push('', scopeStyles);
if (!acc) {
return current;
}

console.log(
`%c${level.toUpperCase()}${scope ? `%c %c${scope}` : ''}`,
...styles,
messages.reduce((acc, current) => {
if (acc.endsWith('\n')) {
return acc + current;
}

if (!acc) {
return current;
}

return `${acc} ${current}`;
}, ''),
);
return `${acc} ${current}`;
}, '');

if (!supportsColor) {
console.log(`[${level.toUpperCase()}]`, allMessages);

return;
}

const labelBackgroundColor = getColorForLevel(level);
const labelTextColor = level === 'warn' ? 'black' : 'white';

const labelStyles = getLabelStyles(labelBackgroundColor, labelTextColor);
const scopeStyles = getLabelStyles('#77828D', 'white');

const styles = [labelStyles];

if (typeof scope === 'string') {
styles.push('', scopeStyles);
}

console.log(`%c${level.toUpperCase()}${scope ? `%c %c${scope}` : ''}`, ...styles, allMessages);
}

function getLabelStyles(color: string, textColor: string) {
Expand Down
2 changes: 1 addition & 1 deletion packages/bolt/worker-configuration.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
interface Env {
ANTHROPIC_API_KEY: string;
SESSION_SECRET: string;
LOGIN_PASSWORD: string;
PAYLOAD_SECRET: string;
}

0 comments on commit 44226db

Please sign in to comment.