From 134920d872a2c6f0939db7bbf58d2e7cf30d03fb Mon Sep 17 00:00:00 2001 From: Daniel Holmgren Date: Thu, 2 Feb 2023 15:04:38 -0600 Subject: [PATCH] Database notify/listen (#514) * setup notify/listen for db * end pool to fix hanging test * small comment in test * switch to event emitter * reconnect on listener error * rename notifyClient * remove payload on channels * pr feedback * poor guy cant spell --- packages/common/src/async.ts | 24 +++++++ packages/common/src/index.ts | 1 + packages/pds/package.json | 1 + packages/pds/src/db/index.ts | 102 +++++++++++++++++++++++---- packages/pds/tests/db-notify.test.ts | 71 +++++++++++++++++++ yarn.lock | 2 +- 6 files changed, 186 insertions(+), 15 deletions(-) create mode 100644 packages/common/src/async.ts create mode 100644 packages/pds/tests/db-notify.test.ts diff --git a/packages/common/src/async.ts b/packages/common/src/async.ts new file mode 100644 index 00000000000..0da04247f98 --- /dev/null +++ b/packages/common/src/async.ts @@ -0,0 +1,24 @@ +export type Deferrable = { + resolve: () => void + complete: Promise +} + +export const createDeferrable = (): Deferrable => { + let resolve + const promise: Promise = new Promise((res) => { + resolve = () => res() + }) + return { resolve, complete: promise } +} + +export const createDeferrables = (count: number): Deferrable[] => { + const list: Deferrable[] = [] + for (let i = 0; i < count; i++) { + list.push(createDeferrable()) + } + return list +} + +export const allComplete = async (deferrables: Deferrable[]): Promise => { + await Promise.all(deferrables.map((d) => d.complete)) +} diff --git a/packages/common/src/index.ts b/packages/common/src/index.ts index 0a16fa1d9d5..8a8230914eb 100644 --- a/packages/common/src/index.ts +++ b/packages/common/src/index.ts @@ -1,6 +1,7 @@ export * as check from './check' export * as util from './util' +export * from './async' export * from './util' export * from './tid' export * from './ipld' diff --git a/packages/pds/package.json b/packages/pds/package.json index 374405d0910..a28d1db414b 100644 --- a/packages/pds/package.json +++ b/packages/pds/package.json @@ -54,6 +54,7 @@ "pino": "^8.6.1", "pino-http": "^8.2.1", "sharp": "^0.31.2", + "typed-emitter": "^2.1.0", "uint8arrays": "3.0.0" }, "devDependencies": { diff --git a/packages/pds/src/db/index.ts b/packages/pds/src/db/index.ts index 02c51aaadd7..528846d13a0 100644 --- a/packages/pds/src/db/index.ts +++ b/packages/pds/src/db/index.ts @@ -1,23 +1,31 @@ import assert from 'assert' import { Kysely, SqliteDialect, PostgresDialect, Migrator } from 'kysely' import SqliteDB from 'better-sqlite3' -import { Pool as PgPool, types as pgTypes } from 'pg' +import { + Pool as PgPool, + PoolClient as PgPoolClient, + types as pgTypes, +} from 'pg' +import EventEmitter from 'events' +import TypedEmitter from 'typed-emitter' import DatabaseSchema, { DatabaseSchemaType } from './database-schema' import { dummyDialect } from './util' import * as migrations from './migrations' import { CtxMigrationProvider } from './migrations/provider' +import { dbLogger as log } from '../logger' export class Database { + channels: Channels = { + repo_seq: new EventEmitter() as ChannelEmitter, + } migrator: Migrator - constructor( - public db: DatabaseSchema, - public dialect: Dialect, - public schema?: string, - ) { + private channelClient: PgPoolClient | null = null + + constructor(public db: DatabaseSchema, public cfg: DialectConfig) { this.migrator = new Migrator({ db, - migrationTableSchema: schema, - provider: new CtxMigrationProvider(migrations, dialect), + migrationTableSchema: cfg.dialect === 'pg' ? cfg.schema : undefined, + provider: new CtxMigrationProvider(migrations, cfg.dialect), }) } @@ -27,7 +35,7 @@ export class Database { database: new SqliteDB(location), }), }) - return new Database(db, 'sqlite') + return new Database(db, { dialect: 'sqlite' }) } static postgres(opts: PgOptions): Database { @@ -45,30 +53,69 @@ export class Database { `Postgres schema must only contain [A-Za-z_]: ${schema}`, ) } - pool.on('connect', (client) => + pool.on('connect', (client) => { // Shared objects such as extensions will go in the public schema - client.query(`SET search_path TO "${schema}",public`), - ) + client.query(`SET search_path TO "${schema}",public`) + }) } const db = new Kysely({ dialect: new PostgresDialect({ pool }), }) - return new Database(db, 'pg', schema) + return new Database(db, { dialect: 'pg', pool, schema }) } static memory(): Database { return Database.sqlite(':memory:') } + async startListeningToChannels() { + if (this.cfg.dialect !== 'pg') return + this.channelClient = await this.cfg.pool.connect() + await this.channelClient.query(`LISTEN repo_seq`) + this.channelClient.on('notification', (msg) => { + const channel = this.channels[msg.channel] + if (channel) { + channel.emit('message') + } + }) + this.channelClient.on('error', (err) => { + log.error({ err }, 'postgres listener errored, reconnecting') + this.channelClient?.removeAllListeners() + this.startListeningToChannels() + }) + } + + notify(channel: keyof Channels) { + if (channel !== 'repo_seq') { + throw new Error(`attempted sending on unavailable channel: ${channel}`) + } + if (this.cfg.dialect === 'pg') { + this.cfg.pool.query(`NOTIFY ${channel}`) + } else { + const emitter = this.channels[channel] + if (emitter) { + emitter.emit('message') + } + } + } + async transaction(fn: (db: Database) => Promise): Promise { return await this.db.transaction().execute((txn) => { - const dbTxn = new Database(txn, this.dialect, this.schema) + const dbTxn = new Database(txn, this.cfg) return fn(dbTxn) }) } + get schema(): string | undefined { + return this.cfg.dialect === 'pg' ? this.cfg.schema : undefined + } + + get dialect(): Dialect { + return this.cfg.dialect + } + get isTransaction() { return this.db.isTransaction } @@ -78,6 +125,11 @@ export class Database { } async close(): Promise { + this.channelClient?.removeAllListeners() + this.channelClient?.release() + if (this.cfg.dialect === 'pg') { + await this.cfg.pool.end() + } await this.db.destroy() } @@ -114,9 +166,31 @@ export default Database export type Dialect = 'pg' | 'sqlite' +export type DialectConfig = PgConfig | SqliteConfig + +export type PgConfig = { + dialect: 'pg' + pool: PgPool + schema?: string +} + +export type SqliteConfig = { + dialect: 'sqlite' +} + // Can use with typeof to get types for partial queries export const dbType = new Kysely({ dialect: dummyDialect }) type PgOptions = | { url: string; schema?: string } | { pool: PgPool; schema?: string } + +type ChannelEvents = { + message: () => void +} + +type ChannelEmitter = TypedEmitter + +type Channels = { + repo_seq: ChannelEmitter +} diff --git a/packages/pds/tests/db-notify.test.ts b/packages/pds/tests/db-notify.test.ts new file mode 100644 index 00000000000..56e4756b709 --- /dev/null +++ b/packages/pds/tests/db-notify.test.ts @@ -0,0 +1,71 @@ +import { allComplete, createDeferrables } from '@atproto/common' +import { Database } from '../src' + +describe('db', () => { + let dbOne: Database + let dbTwo: Database + + beforeAll(async () => { + if (process.env.DB_POSTGRES_URL) { + dbOne = Database.postgres({ + url: process.env.DB_POSTGRES_URL, + schema: 'db_notify', + }) + dbTwo = Database.postgres({ + url: process.env.DB_POSTGRES_URL, + schema: 'db_notify', + }) + await dbOne.startListeningToChannels() + await dbTwo.startListeningToChannels() + } else { + // in the sqlite case, we just use two references to the same db + dbOne = Database.memory() + dbTwo = dbOne + } + }) + + afterAll(async () => { + await dbOne.close() + await dbTwo.close() + }) + + it('notifies', async () => { + const sendCount = 5 + const deferrables = createDeferrables(sendCount) + let receivedCount = 0 + dbOne.channels.repo_seq.addListener('message', () => { + deferrables[receivedCount]?.resolve() + receivedCount++ + }) + + for (let i = 0; i < sendCount; i++) { + dbTwo.notify('repo_seq') + } + + await allComplete(deferrables) + expect(receivedCount).toBe(sendCount) + }) + + it('can notifies multiple listeners', async () => { + const sendCount = 5 + const deferrables = createDeferrables(sendCount * 2) + let receivedOne = 0 + let receivedTwo = 0 + dbOne.channels.repo_seq.addListener('message', () => { + deferrables[receivedOne]?.resolve() + receivedOne++ + }) + dbOne.channels.repo_seq.addListener('message', () => { + deferrables[receivedTwo + sendCount]?.resolve() + receivedTwo++ + }) + + for (let i = 0; i < sendCount; i++) { + dbTwo.notify('repo_seq') + } + + await allComplete(deferrables) + expect(receivedOne).toBe(sendCount) + expect(receivedTwo).toBe(sendCount) + }) +}) diff --git a/yarn.lock b/yarn.lock index 19106623ce2..a60aeb2c196 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10519,7 +10519,7 @@ type-is@~1.6.18: typed-emitter@^2.1.0: version "2.1.0" - resolved "https://registry.npmjs.org/typed-emitter/-/typed-emitter-2.1.0.tgz" + resolved "https://registry.yarnpkg.com/typed-emitter/-/typed-emitter-2.1.0.tgz#ca78e3d8ef1476f228f548d62e04e3d4d3fd77fb" integrity sha512-g/KzbYKbH5C2vPkaXGu8DJlHrGKHLsM25Zg9WuC9pMGfuvT+X25tZQWo5fK1BjBm8+UrVE9LDCvaY0CQk+fXDA== optionalDependencies: rxjs "^7.5.2"