Skip to content

Commit

Permalink
Showing 5 changed files with 415 additions and 114 deletions.
5 changes: 5 additions & 0 deletions .changeset/tough-camels-clean.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@atproto/pds": patch
---

Prevent server crash when catchall proxy requests are cancelled
1 change: 1 addition & 0 deletions packages/pds/src/index.ts
Original file line number Diff line number Diff line change
@@ -127,6 +127,7 @@ export class PDS {
await this.ctx.backgroundQueue.destroy()
await this.ctx.accountManager.close()
await this.ctx.redisScratch?.quit()
await this.ctx.proxyAgent.destroy()
clearInterval(this.dbStatsInterval)
clearInterval(this.sequencerStatsInterval)
}
234 changes: 133 additions & 101 deletions packages/pds/src/pipethrough.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import express from 'express'
import { IncomingHttpHeaders } from 'node:http'
import { Duplex, Readable } from 'node:stream'
import { IncomingHttpHeaders, ServerResponse } from 'node:http'
import { PassThrough, Readable } from 'node:stream'
import { Dispatcher } from 'undici'

import {
@@ -89,6 +89,10 @@ export const proxyHandler = (ctx: AppContext): CatchallHandler => {
res.setHeader(name, val)
}

// Note that we should not need to manually handle errors here (e.g. by
// destroying the response), as the http server will handle them for us.
res.on('error', logResponseError)

// Tell undici to write the upstream response directly to the response
return res
})
@@ -135,7 +139,11 @@ export async function pipethrough(
ctx: AppContext,
req: express.Request,
options?: PipethroughOptions,
): Promise<HandlerPipeThroughStream> {
): Promise<{
stream: Readable
headers: Record<string, string>
encoding: string
}> {
if (req.method !== 'GET' && req.method !== 'HEAD') {
// pipethrough() is used from within xrpcServer handlers, which means that
// the request body either has been parsed or is a readable stream that has
@@ -192,7 +200,7 @@ export async function pipethrough(
headers: Object.fromEntries(responseHeaders(upstream.headers)),
encoding:
safeString(upstream.headers['content-type']) ?? 'application/json',
}
} satisfies HandlerPipeThroughStream
}

// Request setup/formatting
@@ -267,39 +275,48 @@ async function pipethroughStream(
dispatchOptions: Dispatcher.RequestOptions,
successStreamFactory: Dispatcher.StreamFactory,
): Promise<void> {
await ctx.proxyAgent
.stream(dispatchOptions, (upstream) => {
// Upstream resulted in an error, create a writable stream for undici
// that will decode & parse the error message and construct an XRPCError
if (upstream.statusCode !== ResponseType.Success) {
return Duplex.from(async function (
res: AsyncGenerator<Buffer, void, unknown>,
): Promise<void> {
return handleUpstreamResponseError(dispatchOptions, upstream, res)
})
}
return new Promise<void>((resolve, reject) => {
void ctx.proxyAgent
.stream(dispatchOptions, (upstream) => {
if (upstream.statusCode >= 400) {
const passThrough = new PassThrough()

void tryParsingError(upstream.headers, passThrough).then((parsed) => {
const xrpcError = new XRPCClientError(
upstream.statusCode === 500
? ResponseType.UpstreamFailure
: upstream.statusCode,
parsed.error,
parsed.message,
Object.fromEntries(responseHeaders(upstream.headers, false)),
{ cause: dispatchOptions },
)

reject(xrpcError)
}, reject)

return passThrough
}

try {
return successStreamFactory(upstream)
} catch (err) {
// Assume any error thrown from successStreamFactory() is due to an
// unsupported or invalid value in "upstream" (statusCode or headers).
// This will allow us to distinguish request errors below.
return handleUpstreamRequestError(
err,
dispatchOptions,
'unable to process upstream response',
)
}
})
.catch((err) => {
if (err instanceof XRPCServerError) throw err
if (err instanceof XRPCClientError) throw err

// Any other error here was caused by undici, the network or the writable
// stream returned by the function above (e.g. decoding error).
return handleUpstreamRequestError(err, dispatchOptions)
})
const writable = successStreamFactory(upstream)

// As soon as the control was passed to the writable stream (i.e. by
// returning the writable hereafter), pipethroughStream() is considered
// to have succeeded. Any error occurring while writing upstream data to
// the writable stream should be handled through the stream's error
// state (i.e. successStreamFactory() must ensure that error events on
// the returned writable will be handled).
resolve()

return writable
})
// The following catch block will be triggered by either network errors
// or writable stream errors. In the latter case, the promise will already
// be resolved, and reject()ing it thereafter will have no effect. Those
// errors would still be logged by the successStreamFactory() function.
.catch(handleUpstreamRequestError)
.catch(reject)
})
}

/**
@@ -315,60 +332,35 @@ async function pipethroughRequest(

const upstream = await ctx.proxyAgent
.request(dispatchOptions)
.catch((err) => handleUpstreamRequestError(err, dispatchOptions))

if (upstream.statusCode !== ResponseType.Success) {
return handleUpstreamResponseError(dispatchOptions, upstream)
.catch(handleUpstreamRequestError)

if (upstream.statusCode >= 400) {
const parsed = await tryParsingError(upstream.headers, upstream.body)

// Note "XRPCClientError" is used instead of "XRPCServerError" in order to
// allow users of this function to capture & handle these errors (namely in
// "app.bsky.feed.getPostThread").
throw new XRPCClientError(
upstream.statusCode === 500
? ResponseType.UpstreamFailure
: upstream.statusCode,
parsed.error,
parsed.message,
Object.fromEntries(responseHeaders(upstream.headers, false)),
{ cause: dispatchOptions },
)
}

return upstream
}

/**
 * Turns a non-successful upstream response into a thrown `XRPCClientError`.
 *
 * The response body is buffered (honouring the upstream `content-encoding`
 * header), parsed as JSON on a best-effort basis, and its `error` / `message`
 * string fields are copied onto the thrown error together with the upstream
 * status code and the forwarded response headers.
 *
 * @param dispatchOptions - the request options, attached as the error `cause`
 * @param data - upstream response data; its own `body` is used when no
 * explicit `body` argument is given
 * @param body - body stream to read when `data` does not carry one
 * @throws TypeError when no body can be found on either argument
 * @throws XRPCClientError in every other case (this function never returns)
 */
async function handleUpstreamResponseError(
  dispatchOptions: Dispatcher.RequestOptions,
  data: Dispatcher.ResponseData,
): Promise<never>
async function handleUpstreamResponseError(
  dispatchOptions: Dispatcher.RequestOptions,
  data: Dispatcher.StreamFactoryData,
  body: Readable | AsyncGenerator<Buffer, void, unknown>,
): Promise<never>
async function handleUpstreamResponseError(
  dispatchOptions: Dispatcher.RequestOptions,
  data: Dispatcher.StreamFactoryData | Dispatcher.ResponseData,
  body?: Readable | AsyncGenerator<Buffer, void, unknown>,
): Promise<never> {
  // Prefer the explicitly provided body, fall back to the one carried by
  // the response data (if any).
  const payload = body ?? ('body' in data ? data.body : undefined)

  // Type-safety, should never happen
  if (!payload) throw new TypeError('body is required')

  const { headers, statusCode } = data

  const raw = await bufferUpstreamResponse(payload, headers['content-encoding'])
  const errInfo = safeParseJson(raw.toString('utf8'))

  const errorName = safeString(errInfo?.['error'])
  const errorMessage = safeString(errInfo?.['message'])
  const forwardedHeaders = Object.fromEntries(responseHeaders(headers, false))

  // Throwing here will cause the promise returned by stream() to
  // reject. This will cause the `.catch` block below to be triggered.
  throw new XRPCClientError(
    statusCode,
    errorName,
    errorMessage,
    forwardedHeaders,
    { cause: dispatchOptions },
  )
}

function handleUpstreamRequestError(
err: unknown,
dispatchOptions: Dispatcher.RequestOptions,
message = 'pipethrough network error',
): never {
httpLogger.warn({ err }, message)
throw new XRPCServerError(ResponseType.UpstreamFailure, message, undefined, {
cause: [err, dispatchOptions],
cause: err,
})
}

@@ -447,21 +439,64 @@ function isNonNullable<T>(val: T): val is NonNullable<T> {
return val != null
}

export async function bufferUpstreamResponse(
stream: Readable | AsyncIterable<Uint8Array>,
contentEncoding?: string | string[],
): Promise<Buffer> {
// Needed for type-safety (should never happen irl)
if (Array.isArray(contentEncoding)) {
throw new XRPCServerError(
ResponseType.UpstreamFailure,
'upstream service returned multiple content-encoding headers',
/**
 * Matches `application/json` and any `application/<subtype>+json` media type
 * (structured syntax suffix), optionally followed by parameters.
 *
 * The pattern is anchored so that look-alike types such as
 * `application/jsonl` or `application/json-seq` are not mistaken for JSON,
 * and the subtype prefix accepts the full set of media-type name characters
 * (including `.` and `-`) so that e.g. `application/vnd.api+json` is
 * recognised.
 */
const JSON_CONTENT_TYPE_PATTERN =
  /^\s*application\/(?:[\w!#$&^.+-]+\+)?json\s*(?:;|$)/i

/**
 * Tells whether a `content-type` header value denotes a JSON payload.
 *
 * @param contentType - raw header value, if the header was present
 * @returns `undefined` when no content type is available (the caller cannot
 * know either way), `true` for JSON media types, `false` otherwise
 */
export function isJsonContentType(contentType?: string): boolean | undefined {
  if (contentType == null) return undefined
  return JSON_CONTENT_TYPE_PATTERN.test(contentType)
}

/**
 * Best-effort extraction of the `error` and `message` fields from an upstream
 * error response body.
 *
 * This function never rejects: any failure to read, decode, buffer or parse
 * the body results in an empty object.
 *
 * @param headers - upstream response headers (`content-type` decides whether
 * the body is read at all, `content-encoding` is passed to the buffering step)
 * @param readable - upstream response body; it is destroyed without being
 * read when the content type is present and is not JSON
 * @returns the `error` / `message` strings found in the body, when available
 */
async function tryParsingError(
  headers: IncomingHttpHeaders,
  readable: Readable,
): Promise<{ error?: string; message?: string }> {
  if (isJsonContentType(headers['content-type']) === false) {
    // We don't know how to parse non JSON content types so we can discard the
    // whole response.
    //
    // @NOTE we could also simply "drain" the stream here. This would prevent
    // the upstream HTTP/1.1 connection from getting destroyed (closed). This
    // would however imply to read the whole upstream response, which would be
    // costly in terms of bandwidth and I/O processing. It is recommended to use
    // HTTP/2 to avoid this issue (be able to destroy a single response stream
    // without resetting the whole connection). This is not expected to happen
    // too much as 4xx and 5xx responses are expected to be JSON.
    readable.destroy()

    return {}
  }

  try {
    const buffer = await bufferUpstreamResponse(
      readable,
      headers['content-encoding'],
    )

    const parsed: unknown = JSON.parse(buffer.toString('utf8'))

    // JSON.parse() may yield any JSON value (null, number, string, array...).
    // Narrow to an object before indexing so that the property lookups below
    // are type-checked instead of relying on implicit `any`.
    const errInfo: Record<string, unknown> | undefined =
      typeof parsed === 'object' && parsed !== null
        ? (parsed as Record<string, unknown>)
        : undefined

    return {
      error: safeString(errInfo?.['error']),
      message: safeString(errInfo?.['message']),
    }
  } catch {
    // Failed to read, decode, buffer or parse. No big deal.
    return {}
  }
}

export async function bufferUpstreamResponse(
readable: Readable,
contentEncoding?: string | string[],
): Promise<Buffer> {
try {
return streamToNodeBuffer(decodeStream(stream, contentEncoding))
// Needed for type-safety (should never happen irl)
if (Array.isArray(contentEncoding)) {
throw new TypeError(
'upstream service returned multiple content-encoding headers',
)
}

return await streamToNodeBuffer(decodeStream(readable, contentEncoding))
} catch (err) {
if (!readable.destroyed) readable.destroy()

throw new XRPCServerError(
ResponseType.UpstreamFailure,
err instanceof TypeError ? err.message : 'unable to decode request body',
@@ -487,12 +522,7 @@ export async function asPipeThroughBuffer(
// Response parsing/forwarding
// -------------------

const RES_HEADERS_TO_FORWARD = [
'content-type',
'content-language',
'atproto-repo-rev',
'atproto-content-labelers',
]
const RES_HEADERS_TO_FORWARD = ['atproto-repo-rev', 'atproto-content-labelers']

function* responseHeaders(
headers: IncomingHttpHeaders,
@@ -504,6 +534,12 @@ function* responseHeaders(

const encoding = headers['content-encoding']
if (encoding) yield ['content-encoding', encoding]

const type = headers['content-type']
if (type) yield ['content-type', type]

const language = headers['content-language']
if (language) yield ['content-language', language]
}

for (let i = 0; i < RES_HEADERS_TO_FORWARD.length; i++) {
@@ -587,10 +623,6 @@ const safeString = (str: unknown): string | undefined => {
return typeof str === 'string' ? str : undefined
}

export const safeParseJson = (json: string): unknown => {
try {
return JSON.parse(json)
} catch {
return null
}
/**
 * `error` event listener for a server response: records the error as a
 * warning. It only logs; it does not touch the response itself.
 */
function logResponseError(this: ServerResponse, err: unknown): void {
  const details = { err }
  httpLogger.warn(details, 'error forwarding upstream response')
}
34 changes: 21 additions & 13 deletions packages/pds/src/read-after-write/util.ts
Original file line number Diff line number Diff line change
@@ -10,7 +10,11 @@ import express from 'express'
import AppContext from '../context'
import { lexicons } from '../lexicon/lexicons'
import { readStickyLogger as log } from '../logger'
import { asPipeThroughBuffer, pipethrough, safeParseJson } from '../pipethrough'
import {
asPipeThroughBuffer,
isJsonContentType,
pipethrough,
} from '../pipethrough'
import { HandlerResponse, LocalRecords, MungeFn } from './types'
import { getRecordsSinceRev } from './viewer'

@@ -39,27 +43,31 @@ export const pipethroughReadAfterWrite = async <T>(
const { req, auth } = reqCtx
const requester = auth.credentials.did

const upstreamRes = await pipethrough(ctx, req, { iss: requester })
const streamRes = await pipethrough(ctx, req, { iss: requester })

const rev = getRepoRev(streamRes.headers)
if (!rev) return streamRes

const rev = upstreamRes.headers && getRepoRev(upstreamRes.headers)
if (!rev) return upstreamRes
if (isJsonContentType(streamRes.headers['content-type']) === false) {
// content-type is present but not JSON, we can't munge this
return streamRes
}

// if the munging fails, we can't return the original response because the
// stream will already have been read. If we end-up buffering the response,
// we'll return the buffered response in case of an error.
let bufferedRes: HandlerPipeThroughBuffer | undefined
let bufferRes: HandlerPipeThroughBuffer | undefined

try {
const lxm = parseReqNsid(req)

return await ctx.actorStore.read(requester, async (store) => {
const local = await getRecordsSinceRev(store, rev)
if (local.count === 0) return upstreamRes

const lxm = parseReqNsid(req)
if (local.count === 0) return streamRes

bufferedRes = await asPipeThroughBuffer(upstreamRes)
const { buffer } = (bufferRes = await asPipeThroughBuffer(streamRes))

const value = safeParseJson(bufferedRes!.buffer.toString('utf8'))
const lex = value && jsonToLex(value)
const lex = jsonToLex(JSON.parse(buffer.toString('utf8')))

const parsedRes = lexicons.assertValidXrpcOutput(lxm, lex) as T

@@ -70,10 +78,10 @@ export const pipethroughReadAfterWrite = async <T>(
})
} catch (err) {
// The error occurred while reading the stream, this is non-recoverable
if (!bufferedRes && !upstreamRes.stream.readable) throw err
if (!bufferRes && !streamRes.stream.readable) throw err

log.warn({ err, requester }, 'error in read after write munge')
return bufferedRes ?? upstreamRes
return bufferRes ?? streamRes
}
}

255 changes: 255 additions & 0 deletions packages/pds/tests/proxied/proxy-catchall.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
import AtpAgent from '@atproto/api'
import { Keypair } from '@atproto/crypto'
import { TestNetworkNoAppView } from '@atproto/dev-env'
import { LexiconDoc } from '@atproto/lexicon'
import * as plc from '@did-plc/lib'
import express from 'express'
import getPort from 'get-port'
import { once } from 'node:events'
import http from 'node:http'
import { setTimeout as sleep } from 'node:timers/promises'

// NSIDs of the query methods served by the test upstream service.
const QUERY_METHOD_IDS = [
  'com.example.ok',
  'com.example.slow',
  'com.example.abort',
  'com.example.error',
] as const

/**
 * Builds the lexicon document for one test query method. All test methods
 * share the same shape: a parameter-less query returning a JSON object with
 * an optional string property `foo`.
 */
function buildQueryLexicon<Id extends string>(id: Id) {
  return {
    lexicon: 1,
    id,
    defs: {
      main: {
        type: 'query',
        output: {
          encoding: 'application/json',
          schema: { type: 'object', properties: { foo: { type: 'string' } } },
        },
      },
    },
  } as const satisfies LexiconDoc
}

// One lexicon document per test method, in the order listed above.
const lexicons = QUERY_METHOD_IDS.map((id) => buildQueryLexicon(id))

// End-to-end checks of the catch-all proxy: requests made by a client
// configured with a proxy service are forwarded by the PDS to a local test
// upstream (see ProxyServer), and the outcome is observed from the client.
describe('proxy header', () => {
  let testNetwork: TestNetworkNoAppView
  let aliceAgent: AtpAgent
  let upstreamServer: ProxyServer

  beforeAll(async () => {
    testNetwork = await TestNetworkNoAppView.create({
      dbPostgresSchema: 'proxy_catchall',
    })

    const targetServiceId = 'proxy_test'
    const { ctx } = testNetwork.pds

    upstreamServer = await ProxyServer.create(
      ctx.plcClient,
      ctx.plcRotationKey,
      targetServiceId,
    )

    aliceAgent = testNetwork.pds
      .getClient()
      .withProxy(targetServiceId, upstreamServer.did)

    // The client needs the test lexicons to be able to call the test methods
    for (const doc of lexicons) aliceAgent.lex.add(doc)

    await aliceAgent.createAccount({
      email: 'alice@test.com',
      handle: 'alice.test',
      password: 'alice-pass',
    })
    await testNetwork.processAll()
  })

  afterAll(async () => {
    // Either may be undefined if beforeAll() failed part-way through
    await upstreamServer?.close()
    await testNetwork?.close()
  })

  it('rejects when upstream unavailable', async () => {
    const offlineServiceId = 'foo_bar'
    const { ctx } = testNetwork.pds

    const offlineServer = await ProxyServer.create(
      ctx.plcClient,
      ctx.plcRotationKey,
      offlineServiceId,
    )

    // Make sure the service is not available
    await offlineServer.close()

    const offlineClient = aliceAgent.withProxy(
      offlineServiceId,
      offlineServer.did,
    )
    for (const doc of lexicons) offlineClient.lex.add(doc)

    const outcome = offlineClient.call('com.example.ok')
    await expect(outcome).rejects.toThrow('pipethrough network error')
  })

  it('successfully proxies requests', async () => {
    const outcome = aliceAgent.call('com.example.ok')
    await expect(outcome).resolves.toMatchObject({
      data: { foo: 'ok' },
      success: true,
    })
  })

  it('handles cancelled upstream requests', async () => {
    const outcome = aliceAgent.call('com.example.abort')
    await expect(outcome).rejects.toThrow('terminated')
  })

  it('handles failing upstream requests', async () => {
    const expectedError = expect.objectContaining({
      status: 502,
      error: 'FooBar',
      message: 'My message',
    })

    const outcome = aliceAgent.call('com.example.error')
    await expect(outcome).rejects.toThrowError(expectedError)
  })

  it('handles cancelled downstream requests', async () => {
    const abortController = new AbortController()

    // Cancel the client request shortly after it was sent
    setTimeout(() => abortController.abort(), 20)

    const cancelled = aliceAgent.call('com.example.slow', {}, undefined, {
      signal: abortController.signal,
    })
    await expect(cancelled).rejects.toThrow('This operation was aborted')

    // A subsequent, non-cancelled, request must still go through
    const followUp = aliceAgent.call('com.example.slow')
    await expect(followUp).resolves.toMatchObject({
      data: { foo: 'slow' },
      success: true,
    })
  })
})

/**
 * Minimal upstream XRPC service used as a proxy target by the tests.
 *
 * It serves four `GET /xrpc/com.example.*` routes (ok, slow, abort, error)
 * and is registered in the PLC directory under a freshly created DID so that
 * it can be addressed through a service id.
 */
class ProxyServer {
  constructor(
    private readonly server: http.Server,
    public did: string,
  ) {}

  /**
   * Starts the HTTP server on a free local port and publishes a PLC operation
   * exposing it as the service `serviceId`.
   *
   * @param plcClient - client used to send the PLC operation
   * @param keypair - key used both as rotation key and to sign the operation
   * @param serviceId - id under which the server endpoint is listed
   * @returns the running server together with the DID it was registered under
   */
  static async create(
    plcClient: plc.Client,
    keypair: Keypair,
    serviceId: string,
  ): Promise<ProxyServer> {
    const app = express()

    // Immediate, complete JSON response
    app.get('/xrpc/com.example.ok', (req, res) => {
      res.status(200)
      res.setHeader('content-type', 'application/json')
      res.send('{"foo":"ok"}')
    })

    // Delayed headers, then a body trickled one character at a time
    app.get('/xrpc/com.example.slow', async (req, res) => {
      // Sleeps for `ms`, waking up early if the response gets closed.
      const wait = async (ms: number): Promise<void> => {
        if (res.destroyed) return
        const ac = new AbortController()
        const abort = () => ac.abort()
        res.on('close', abort)
        try {
          await sleep(ms, undefined, { signal: ac.signal })
        } catch (err) {
          // An aborted sleep rejects. When the abort is the one triggered
          // above (response closed), this is an expected wake-up, not a
          // failure. Letting it escape would reject this async handler, and
          // nothing awaits that promise here, so it would surface as an
          // unhandled rejection.
          // NOTE(review): assumes the router does not handle rejected handler
          // promises (Express 4 behavior); confirm against the installed
          // version.
          if (!ac.signal.aborted) throw err
        } finally {
          res.off('close', abort)
        }
      }

      await wait(50)
      if (res.destroyed) return // client went away, nothing left to do

      res.status(200)
      res.setHeader('content-type', 'application/json')
      res.flushHeaders()

      await wait(50)

      for (const char of '{"foo":"slow"}') {
        if (res.destroyed) return
        res.write(char)
        await wait(10)
      }

      res.end()
    })

    // Starts a JSON body, then kills the connection before completing it
    app.get('/xrpc/com.example.abort', async (req, res) => {
      res.status(200)
      res.setHeader('content-type', 'application/json')
      res.write('{"foo"')
      await sleep(50)
      res.destroy(new Error('abort'))
    })

    // Well-formed XRPC error payload with a 500 status
    app.get('/xrpc/com.example.error', async (req, res) => {
      res.status(500).json({ error: 'FooBar', message: 'My message' })
    })

    const port = await getPort()
    const server = app.listen(port)
    server.keepAliveTimeout = 30 * 1000
    server.headersTimeout = 35 * 1000
    await once(server, 'listening')

    const plcOp = await plc.signOperation(
      {
        type: 'plc_operation',
        rotationKeys: [keypair.did()],
        alsoKnownAs: [],
        verificationMethods: {},
        services: {
          [serviceId]: {
            type: 'TestAtprotoService',
            endpoint: `http://localhost:${port}`,
          },
        },
        prev: null,
      },
      keypair,
    )
    const did = await plc.didForCreateOp(plcOp)
    await plcClient.sendOperation(did, plcOp)
    return new ProxyServer(server, did)
  }

  /**
   * Stops the HTTP server. Resolves once `server.close()` reports completion
   * and rejects with the error it reports otherwise.
   */
  async close() {
    await new Promise<void>((resolve, reject) => {
      this.server.close((err) => {
        if (err) reject(err)
        else resolve()
      })
    })
  }
}

0 comments on commit b298bfd

Please sign in to comment.