From a8d6c112359f5c4c0cfbe2df63443ed275f2a646 Mon Sep 17 00:00:00 2001 From: Matthieu Sieben Date: Tue, 18 Jun 2024 21:11:37 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20OAuth2=20-=20Authorization=20Ser?= =?UTF-8?q?ver=20(#2482)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore(deps): update zod * chore(deps): update pino to match entryway version * chore(tsconfig): remove truncation of types through noErrorTruncation * add support for DPoP token type when logging * fix(bsky): JSON.parse does not return value of type JSON * fix(pds): add res property to ReqCtx * fix(pds): properly type getPreferences return value * chore(tsconfig): disable noFallthroughCasesInSwitch * refactor(pds): move tracer config in own file * feat(dev-env): start with "pnpm dev" * feat(oauth): add oauth provider & client libs * feat(pds): add oauth provider * chore: changeset * feat: various fixes and improvements * chore(deps): update better-sqlite3 to version 10.0.0 for node 22 compatibility * chore(deps): drop unused tslib * fix(did): normalize service IDs before looking for duplicates * fix(did): avoid minor type casting * fix(did): improve argument validation * fix(fetch): explicit use of negation around number comparison * fix(oauth-provider): improve argument validation * feat(did): add ATPROTO specific "isAtprotoDidWeb" method * feat(rollup-plugin-bundle-manifest): add readme * feat(lint): add eqeqeq rule (only allow == and != with null) * fix(oauth-client-browser): typo in gitignore * fix(oauth-provider): properly name error class file * fix(oauth-provider): remove un-necessary useMemo * fix(did-resolver): properly build did:web document url * fix(did-resolver): remove unused types * fix(fetch): remove unused utils * fix(pds): remove unused script and dependency * fix(oauth-provider): simplify isSubPath util * fix(oauth-provider): add InvalidRedirectUriError static constructor * fix(jwk): improve JWT validation to provide better error messages and distinguish between signed and unsigned tokens * fix(pds): use "debug" log level for fetch method * fix(pds): allow access tokens to contain an unknown "typ" claim (with the exception of "dpop+jwt") * fix(jwk): remove un-necessary code * fix(pds): account for whitespace chars when checking JSON * fix(pds): remove oauth specific config * fix(pds): run all write queries through transaction or executeWithRetry fix(pds): remove outdated comments fix(pds): rename used_refresh_token columns & added primary key fix(pds): run cleanup task through backgroundQueue fix(pds): add device.id foreign key to device_account fix(pds): add comment on cleanup of used_refresh_token fix(pds): add primary key on device_account * fix(oauth-provider:time): simplify constantTime util * fix(pds): rename disableSsrf into disableSsrfProtection * fix(oauth-client-react-native): remove incomplete package * refactor(pds): remove status & active from ActorAccount * fix(pds): invalidate all oauth tokens on takedown * fix(oauth-provider): enforce token expiry * fix(pds): properly support deactivated accounts * perf(pds:db): allow transaction function to be sync * refactor(psq:account-manager): expose only query builders & data transformations utils from helpers * fix(oauth-provider): imports from self * fix(ci): add nested packages to build artifacts * style(fetch): rename TODO into @TODO * style(rollup-plugin-bundle-manifest): remove "TODO" from comment * style(oauth-client): rename TODO into @TODO * style(oauth-provider): rename TODO into @TODO * refactor(oauth-client): remove "OAuth" prefix from types * fix(oauth-client-browser): better type SessionListener * style(oauth): rename TODO into @TODO * fix(oauth-provider): enforce provider max session age * fix(oauth-provider): check authentication parameters against all client metadata * fix(api): tests * fix(pds): remove .js from imports for tests * fix(pds): change account status to match tests * chore(deps): make all packages depend on the same zod version * fix(common-web): remove un-necessary binding of Checkable to "zod" * refactor(jwk): infer jwt schema from refinement definition * fix(handle-resolver): allow resolution errors to propagate docs(handle-resolver): better handling of DNS resolution errors fix(handle-resolver): properly handle DOH responses * fix(did): service endpoint arrays must contain "one or more" element * refactor(pipe): simplify implementation * fix(pds): add missing DB indexes * feat(oauth): Resolve Authorization Server URI through Protected Resource Metadata * style:(oauth-client): import order * docs(oauth-provider:redirect-uri): add reference url * feat(oauth): implement "OAuth Client ID Metadata Document" from draft-parecki-oauth-client-id-metadata-document-latest internet draft * feat(oauth-client): backport changes from feat-oauth-client * docs(simple-store): improve comments * feat(lexicons): add iterable capabilities * fix(pds): type error in dev mode * feat(oauth-provider): improved error reporting * fix(oauth-types): allow insecure issuer during tests * fix(xrpc-server): allow upload of empty files * fix: lint * feat(fetch): keep request reference in errors feat(fetch): utilities improvements * fix(pds): allow more than one session token per user * feat(ozone): improve env validation error messages * fix(oauth-client): account for DPoP when checking for invalid_token errors * fixup! feat(fetch): keep request reference in errors feat(fetch): utilities improvements * fixup! feat(fetch): keep request reference in errors feat(fetch): utilities improvements * fix(oauth): various validation fixes feat(oauth): share client_id validation and parsing utilities between client & provider * feat(dev-env): fix ozone port number * fix(fetch-node): prevent fetch against invalid domain names * fix(oauth-provider): add typings for psl dep * feat(jwk): make type def compatible with TS 4.x * fix(oauth): fixed various spec compliance fix(oauth): return "sub" in refresh token response fix(oauth): limit token validity for third party clients fix(oauth): hide client image when not trusted * fix(oauth): lint * pds: switch changeset to patch, no breaking changes * changeset and config for new oauth deps --------- Co-authored-by: Devin Ivy --- .changeset/clever-monkeys-sparkle.md | 23 + .eslintrc | 1 + .github/workflows/repo.yaml | 4 +- .gitignore | 4 +- package.json | 6 +- packages/README.md | 1 + packages/api/tests/agent.test.ts | 2 +- packages/bsky/package.json | 2 +- packages/bsky/src/hydration/util.ts | 4 +- packages/bsky/src/logger.ts | 119 +- packages/common-web/package.json | 2 +- packages/common-web/src/check.ts | 5 +- packages/common/package.json | 2 +- packages/dev-env/package.json | 3 +- packages/dev-env/src/bin.ts | 1 + packages/dev-env/src/pds.ts | 10 + packages/did/package.json | 36 + packages/did/src/did-document.ts | 151 ++ packages/did/src/did-error.ts | 49 + packages/did/src/did.ts | 258 ++ packages/did/src/index.ts | 4 + packages/did/src/methods.ts | 2 + packages/did/src/methods/plc.ts | 40 + packages/did/src/methods/web.ts | 78 + packages/did/tsconfig.build.json | 8 + packages/did/tsconfig.json | 4 + packages/internal/did-resolver/package.json | 40 + .../did-resolver/src/did-cache-memory.ts | 25 + .../internal/did-resolver/src/did-cache.ts | 31 + .../internal/did-resolver/src/did-method.ts | 17 + .../did-resolver/src/did-resolver-base.ts | 66 + .../did-resolver/src/did-resolver-common.ts | 20 + .../internal/did-resolver/src/did-resolver.ts | 15 + packages/internal/did-resolver/src/index.ts | 9 + packages/internal/did-resolver/src/methods.ts | 2 + .../internal/did-resolver/src/methods/plc.ts | 56 + .../internal/did-resolver/src/methods/web.ts | 58 + packages/internal/did-resolver/src/util.ts | 1 + .../internal/did-resolver/tsconfig.build.json | 8 + packages/internal/did-resolver/tsconfig.json | 4 + packages/internal/fetch-node/package.json | 40 + packages/internal/fetch-node/src/index.ts | 4 + packages/internal/fetch-node/src/safe.ts | 78 + packages/internal/fetch-node/src/ssrf.ts | 214 ++ .../internal/fetch-node/tsconfig.build.json | 8 + packages/internal/fetch-node/tsconfig.json | 4 + packages/internal/fetch/package.json | 37 + packages/internal/fetch/src/fetch-error.ts | 59 + packages/internal/fetch/src/fetch-request.ts | 118 + packages/internal/fetch/src/fetch-response.ts | 278 ++ packages/internal/fetch/src/fetch-wrap.ts | 122 + packages/internal/fetch/src/fetch.ts | 39 + packages/internal/fetch/src/index.ts | 6 + .../fetch/src/transformed-response.ts | 36 + packages/internal/fetch/src/util.ts | 169 ++ packages/internal/fetch/tsconfig.build.json | 8 + packages/internal/fetch/tsconfig.json | 4 + .../handle-resolver-node/package.json | 39 + .../src/atproto-handle-resolver-node.ts | 51 + .../handle-resolver-node/src/index.ts | 4 + .../src/node-resolve-txt-factory.ts | 100 + .../handle-resolver-node/tsconfig.build.json | 8 + .../handle-resolver-node/tsconfig.json | 4 + packages/internal/handle-resolver/README.md | 148 + .../internal/handle-resolver/package.json | 42 + .../src/app-view-handle-resolver.ts | 95 + .../src/atproto-doh-handle-resolver.ts | 123 + .../src/atproto-handle-resolver.ts | 75 + .../src/cached-handle-resolver.ts | 32 + .../internal/handle-resolver/src/index.ts | 9 + .../internal-resolvers/dns-handle-resolver.ts | 38 + .../well-known-handler-resolver.ts | 56 + .../internal/handle-resolver/src/types.ts | 33 + .../handle-resolver/tsconfig.build.json | 8 + .../internal/handle-resolver/tsconfig.json | 4 + .../internal/identity-resolver/package.json | 38 + .../src/identity-resolver.ts | 78 + .../internal/identity-resolver/src/index.ts | 1 + .../identity-resolver/tsconfig.build.json | 8 + .../internal/identity-resolver/tsconfig.json | 4 + packages/internal/pipe/package.json | 32 + packages/internal/pipe/src/index.ts | 2 + packages/internal/pipe/src/pipe.ts | 63 + packages/internal/pipe/src/transformer.ts | 1 + packages/internal/pipe/tsconfig.build.json | 8 + packages/internal/pipe/tsconfig.json | 4 + .../rollup-plugin-bundle-manifest/README.md | 99 + .../package.json | 39 + .../src/index.ts | 76 + .../tsconfig.build.json | 8 + .../tsconfig.json | 4 + .../internal/simple-store-memory/package.json | 36 + .../internal/simple-store-memory/src/index.ts | 99 + .../internal/simple-store-memory/src/util.ts | 77 + .../simple-store-memory/tsconfig.build.json | 8 + .../simple-store-memory/tsconfig.json | 4 + packages/internal/simple-store/package.json | 32 + .../simple-store/src/cached-getter.ts | 160 ++ packages/internal/simple-store/src/index.ts | 2 + .../internal/simple-store/src/simple-store.ts | 16 + .../internal/simple-store/tsconfig.build.json | 8 + packages/internal/simple-store/tsconfig.json | 4 + packages/lex-cli/package.json | 2 +- packages/lexicon/package.json | 2 +- packages/lexicon/src/lexicons.ts | 21 +- packages/oauth/jwk-jose/package.json | 36 + packages/oauth/jwk-jose/src/index.ts | 1 + packages/oauth/jwk-jose/src/jose-key.ts | 194 ++ packages/oauth/jwk-jose/src/util.ts | 9 + packages/oauth/jwk-jose/tsconfig.build.json | 8 + packages/oauth/jwk-jose/tsconfig.json | 4 + packages/oauth/jwk-webcrypto/package.json | 36 + packages/oauth/jwk-webcrypto/src/index.ts | 1 + packages/oauth/jwk-webcrypto/src/util.ts | 122 + .../oauth/jwk-webcrypto/src/webcrypto-key.ts | 69 + .../oauth/jwk-webcrypto/tsconfig.build.json | 8 + packages/oauth/jwk-webcrypto/tsconfig.json | 4 + packages/oauth/jwk/package.json | 38 + packages/oauth/jwk/src/alg.ts | 98 + packages/oauth/jwk/src/errors.ts | 56 + packages/oauth/jwk/src/index.ts | 10 + packages/oauth/jwk/src/jwk.ts | 141 + packages/oauth/jwk/src/jwks.ts | 15 + packages/oauth/jwk/src/jwt-decode.ts | 27 + packages/oauth/jwk/src/jwt-verify.ts | 22 + packages/oauth/jwk/src/jwt.ts | 173 ++ packages/oauth/jwk/src/key.ts | 93 + packages/oauth/jwk/src/keyset.ts | 240 ++ packages/oauth/jwk/src/util.ts | 181 ++ packages/oauth/jwk/tsconfig.build.json | 8 + packages/oauth/jwk/tsconfig.json | 4 + .../oauth-client-browser/example/.gitignore | 1 + .../example/.postcssrc.yml | 3 + .../oauth-client-browser/example/package.json | 3 + .../example/rollup.config.js | 91 + .../oauth-client-browser/example/src/app.tsx | 65 + .../example/src/index.css | 3 + .../example/src/login-form.tsx | 103 + .../oauth-client-browser/example/src/main.tsx | 12 + .../oauth-client-browser/example/src/oauth.ts | 119 + .../example/tailwind.config.js | 8 + .../example/tsconfig.build.json | 11 + .../example/tsconfig.json | 7 + .../example/tsconfig.tools.json | 8 + .../oauth/oauth-client-browser/package.json | 72 + .../src/browser-oauth-client.ts | 514 ++++ .../src/browser-oauth-database.ts | 267 ++ .../src/browser-runtime-implementation.ts | 65 + .../src/disposable-polyfill/README.md | 9 + .../src/disposable-polyfill/index.ts | 10 + .../oauth/oauth-client-browser/src/errors.ts | 10 + .../oauth/oauth-client-browser/src/index.ts | 16 + .../src/indexed-db-store.ts | 79 + .../src/indexed-db/README.md | 8 + .../src/indexed-db/db-index.ts | 44 + .../src/indexed-db/db-object-store.ts | 47 + .../src/indexed-db/db-transaction.ts | 52 + .../oauth-client-browser/src/indexed-db/db.ts | 114 + .../src/indexed-db/index.ts | 6 + .../src/indexed-db/schema.ts | 2 + .../src/indexed-db/util.ts | 20 + .../oauth/oauth-client-browser/src/util.ts | 33 + .../oauth-client-browser/tsconfig.build.json | 8 + .../oauth/oauth-client-browser/tsconfig.json | 4 + packages/oauth/oauth-client/README.md | 124 + packages/oauth/oauth-client/package.json | 46 + packages/oauth/oauth-client/src/constants.ts | 4 + packages/oauth/oauth-client/src/fetch-dpop.ts | 235 ++ packages/oauth/oauth-client/src/index.ts | 18 + packages/oauth/oauth-client/src/lock.ts | 34 + .../oauth/oauth-client/src/oauth-agent.ts | 150 + ...-authorization-server-metadata-resolver.ts | 98 + .../oauth-client/src/oauth-callback-error.ts | 16 + .../oauth/oauth-client/src/oauth-client.ts | 440 +++ ...th-protected-resource-metadata-resolver.ts | 102 + .../oauth-client/src/oauth-resolver-error.ts | 12 + .../oauth/oauth-client/src/oauth-resolver.ts | 111 + .../oauth-client/src/oauth-response-error.ts | 31 + .../oauth-client/src/oauth-server-agent.ts | 275 ++ .../oauth-client/src/oauth-server-factory.ts | 41 + .../oauth/oauth-client/src/refresh-error.ts | 9 + .../src/runtime-implementation.ts | 17 + packages/oauth/oauth-client/src/runtime.ts | 211 ++ .../oauth/oauth-client/src/session-getter.ts | 182 ++ packages/oauth/oauth-client/src/types.ts | 26 + packages/oauth/oauth-client/src/util.ts | 51 + .../src/validate-client-metadata.ts | 61 + .../oauth/oauth-client/tsconfig.build.json | 8 + packages/oauth/oauth-client/tsconfig.json | 4 + packages/oauth/oauth-provider/.postcssrc.yml | 3 + packages/oauth/oauth-provider/package.json | 83 + .../oauth/oauth-provider/rollup.config.js | 55 + .../src/access-token/access-token-type.ts | 5 + .../src/account/account-manager.ts | 55 + .../src/account/account-store.ts | 74 + .../oauth-provider/src/account/account.ts | 10 + .../oauth-provider/src/assets/app/app.tsx | 28 + .../src/assets/app/backend-data.ts | 65 + .../src/assets/app/components/accept-form.tsx | 112 + .../app/components/account-identifier.tsx | 18 + .../assets/app/components/account-picker.tsx | 108 + .../app/components/client-identifier.tsx | 32 + .../src/assets/app/components/client-name.tsx | 30 + .../src/assets/app/components/error-card.tsx | 41 + .../src/assets/app/components/help-card.tsx | 42 + .../app/components/layout-title-page.tsx | 43 + .../assets/app/components/layout-welcome.tsx | 58 + .../assets/app/components/sign-in-form.tsx | 290 ++ .../app/components/sign-up-account-form.tsx | 210 ++ .../app/components/sign-up-disclaimer.tsx | 44 + .../src/assets/app/components/url-viewer.tsx | 70 + .../oauth-provider/src/assets/app/cookies.ts | 11 + .../src/assets/app/hooks/use-api.ts | 104 + .../assets/app/hooks/use-bound-dispatch.ts | 5 + .../src/assets/app/hooks/use-csrf-token.ts | 5 + .../oauth-provider/src/assets/app/lib/api.ts | 64 + .../oauth-provider/src/assets/app/lib/clsx.ts | 4 + .../oauth-provider/src/assets/app/lib/util.ts | 10 + .../oauth-provider/src/assets/app/main.css | 11 + .../oauth-provider/src/assets/app/main.tsx | 28 + .../src/assets/app/views/accept-view.tsx | 51 + .../src/assets/app/views/authorize-view.tsx | 101 + .../src/assets/app/views/error-view.tsx | 27 + .../src/assets/app/views/sign-in-view.tsx | 121 + .../src/assets/app/views/sign-up-view.tsx | 93 + .../src/assets/app/views/welcome-view.tsx | 61 + .../oauth/oauth-provider/src/assets/asset.ts | 8 + .../src/assets/assets-middleware.ts | 32 + .../oauth/oauth-provider/src/assets/index.ts | 74 + .../oauth-provider/src/client/client-auth.ts | 45 + .../oauth-provider/src/client/client-data.ts | 9 + .../oauth-provider/src/client/client-id.ts | 4 + .../oauth-provider/src/client/client-info.ts | 13 + .../src/client/client-manager.ts | 818 ++++++ .../oauth-provider/src/client/client-store.ts | 38 + .../oauth-provider/src/client/client-utils.ts | 43 + .../oauth/oauth-provider/src/client/client.ts | 231 ++ .../oauth/oauth-provider/src/constants.ts | 69 + .../oauth-provider/src/device/device-data.ts | 11 + .../src/device/device-details.ts | 43 + .../oauth-provider/src/device/device-id.ts | 23 + .../src/device/device-manager.ts | 287 ++ .../oauth-provider/src/device/device-store.ts | 35 + .../oauth-provider/src/device/session-id.ts | 22 + .../oauth-provider/src/dpop/dpop-manager.ts | 147 + .../oauth-provider/src/dpop/dpop-nonce.ts | 104 + .../src/errors/access-denied-error.ts | 26 + .../account-selection-required-error.ts | 12 + .../src/errors/consent-required-error.ts | 12 + .../invalid-authorization-details-error.ts | 22 + .../src/errors/invalid-client-error.ts | 20 + .../src/errors/invalid-client-id-error.ts | 20 + .../errors/invalid-client-metadata-error.ts | 19 + .../errors/invalid-dpop-key-binding-error.ts | 21 + .../src/errors/invalid-dpop-proof-error.ts | 13 + .../src/errors/invalid-grant-error.ts | 16 + .../src/errors/invalid-parameters-error.ts | 12 + .../src/errors/invalid-redirect-uri-error.ts | 17 + .../src/errors/invalid-request-error.ts | 30 + .../src/errors/invalid-token-error.ts | 59 + .../src/errors/login-required-error.ts | 12 + .../oauth-provider/src/errors/oauth-error.ts | 28 + .../src/errors/unauthorized-client-error.ts | 20 + .../src/errors/use-dpop-nonce-error.ts | 32 + .../src/errors/www-authenticate-error.ts | 65 + packages/oauth/oauth-provider/src/index.ts | 15 + .../oauth-provider/src/lib/html/README.md | 9 + .../src/lib/html/build-document.ts | 98 + .../oauth-provider/src/lib/html/escapers.ts | 66 + .../oauth/oauth-provider/src/lib/html/html.ts | 61 + .../oauth-provider/src/lib/html/index.ts | 5 + .../oauth/oauth-provider/src/lib/html/tags.ts | 58 + .../oauth/oauth-provider/src/lib/html/util.ts | 21 + .../oauth-provider/src/lib/http/README.md | 11 + .../oauth-provider/src/lib/http/accept.ts | 91 + .../oauth-provider/src/lib/http/context.ts | 11 + .../oauth-provider/src/lib/http/index.ts | 9 + .../oauth-provider/src/lib/http/method.ts | 18 + .../oauth-provider/src/lib/http/middleware.ts | 183 ++ .../oauth-provider/src/lib/http/parser.ts | 64 + .../oauth/oauth-provider/src/lib/http/path.ts | 82 + .../oauth-provider/src/lib/http/request.ts | 141 + .../oauth-provider/src/lib/http/response.ts | 133 + .../oauth-provider/src/lib/http/route.ts | 56 + .../oauth-provider/src/lib/http/router.ts | 118 + .../oauth-provider/src/lib/http/stream.ts | 78 + .../oauth-provider/src/lib/http/types.ts | 22 + .../oauth/oauth-provider/src/lib/http/url.ts | 23 + .../oauth/oauth-provider/src/lib/redis.ts | 23 + .../src/lib/util/authorization-header.ts | 26 + .../oauth/oauth-provider/src/lib/util/cast.ts | 4 + .../oauth-provider/src/lib/util/crypto.ts | 27 + .../oauth/oauth-provider/src/lib/util/date.ts | 7 + .../oauth-provider/src/lib/util/hostname.ts | 19 + .../src/lib/util/redirect-uri.ts | 46 + .../oauth/oauth-provider/src/lib/util/time.ts | 33 + .../oauth/oauth-provider/src/lib/util/type.ts | 4 + .../oauth-provider/src/lib/util/well-known.ts | 8 + .../src/metadata/build-metadata.ts | 165 ++ .../oauth/oauth-provider/src/oauth-client.ts | 3 + .../oauth/oauth-provider/src/oauth-dpop.ts | 2 + .../oauth/oauth-provider/src/oauth-errors.ts | 21 + .../oauth/oauth-provider/src/oauth-hooks.ts | 66 + .../oauth-provider/src/oauth-provider.ts | 1409 ++++++++++ .../oauth/oauth-provider/src/oauth-store.ts | 11 + .../oauth-provider/src/oauth-verifier.ts | 219 ++ .../oauth/oauth-provider/src/oidc/claims.ts | 35 + packages/oauth/oauth-provider/src/oidc/sub.ts | 4 + .../oauth/oauth-provider/src/oidc/userinfo.ts | 11 + .../src/output/build-error-payload.ts | 143 + .../src/output/customization.ts | 96 + .../src/output/send-authorize-page.ts | 111 + .../src/output/send-authorize-redirect.ts | 130 + .../src/output/send-error-page.ts | 41 + .../src/output/send-web-page.ts | 66 + .../src/parameters/claims-requested.ts | 106 + .../src/parameters/oidc-payload.ts | 28 + .../src/replay/replay-manager.ts | 38 + .../src/replay/replay-store-memory.ts | 36 + .../src/replay/replay-store-redis.ts | 31 + .../oauth-provider/src/replay/replay-store.ts | 44 + .../oauth/oauth-provider/src/request/code.ts | 24 + .../src/request/request-data.ts | 26 + .../oauth-provider/src/request/request-id.ts | 23 + .../src/request/request-info.ts | 12 + .../src/request/request-manager.ts | 479 ++++ .../src/request/request-store-memory.ts | 39 + .../src/request/request-store-redis.ts | 71 + .../src/request/request-store.ts | 54 + .../oauth-provider/src/request/request-uri.ts | 29 + .../oauth/oauth-provider/src/request/types.ts | 48 + .../src/signer/signed-token-payload.ts | 35 + .../oauth/oauth-provider/src/signer/signer.ts | 165 ++ .../oauth-provider/src/token/refresh-token.ts | 31 + .../oauth-provider/src/token/token-claims.ts | 31 + .../oauth-provider/src/token/token-data.ts | 33 + .../oauth-provider/src/token/token-id.ts | 26 + .../oauth-provider/src/token/token-manager.ts | 591 ++++ .../oauth-provider/src/token/token-store.ts | 78 + .../oauth/oauth-provider/src/token/types.ts | 86 + .../src/token/verify-token-claims.ts | 65 + .../oauth/oauth-provider/tailwind.config.js | 13 + .../oauth-provider/tsconfig.backend.json | 9 + .../oauth-provider/tsconfig.frontend.json | 11 + packages/oauth/oauth-provider/tsconfig.json | 8 + .../oauth/oauth-provider/tsconfig.tools.json | 8 + packages/oauth/oauth-types/README.md | 3 + packages/oauth/oauth-types/package.json | 37 + .../oauth/oauth-types/src/access-token.ts | 4 + .../src/atproto-loopback-client-metadata.ts | 30 + packages/oauth/oauth-types/src/constants.ts | 9 + packages/oauth/oauth-types/src/index.ts | 27 + ...oauth-authentication-request-parameters.ts | 104 + .../src/oauth-authorization-details.ts | 28 + .../oauth-authorization-server-metadata.ts | 106 + .../src/oauth-client-credentials.ts | 34 + .../src/oauth-client-id-discoverable.ts | 66 + .../src/oauth-client-id-loopback.ts | 58 + .../oauth-types/src/oauth-client-id-url.ts | 25 + .../oauth/oauth-types/src/oauth-client-id.ts | 4 + .../src/oauth-client-identification.ts | 14 + .../oauth-types/src/oauth-client-metadata.ts | 75 + .../src/oauth-endpoint-auth-method.ts | 13 + .../oauth-types/src/oauth-endpoint-name.ts | 5 + .../oauth/oauth-types/src/oauth-grant-type.ts | 13 + .../src/oauth-issuer-identifier.ts | 61 + .../oauth-types/src/oauth-par-response.ts | 7 + .../src/oauth-protected-resource-metadata.ts | 85 + .../oauth-types/src/oauth-response-mode.ts | 9 + .../oauth-types/src/oauth-response-type.ts | 17 + .../oauth-types/src/oauth-token-response.ts | 29 + .../oauth/oauth-types/src/oauth-token-type.ts | 15 + .../oauth-types/src/oidc-claims-parameter.ts | 40 + .../oauth-types/src/oidc-claims-properties.ts | 11 + .../oauth/oauth-types/src/oidc-entity-type.ts | 5 + packages/oauth/oauth-types/src/util.ts | 20 + .../oauth/oauth-types/tsconfig.build.json | 8 + packages/oauth/oauth-types/tsconfig.json | 4 + packages/ozone/src/config/config.ts | 17 +- packages/pds/example.env | 24 +- packages/pds/package.json | 8 +- .../db/migrations/004-oauth.ts | 122 + .../account-manager/db/migrations/index.ts | 2 + .../db/schema/authorization-request.ts | 26 + .../db/schema/device-account.ts | 15 + .../src/account-manager/db/schema/device.ts | 18 + .../src/account-manager/db/schema/index.ts | 15 + .../src/account-manager/db/schema/token.ts | 34 + .../db/schema/used-refresh-token.ts | 13 + .../src/account-manager/helpers/account.ts | 37 +- .../helpers/authorization-request.ts | 82 + .../account-manager/helpers/device-account.ts | 135 + .../pds/src/account-manager/helpers/device.ts | 45 + .../pds/src/account-manager/helpers/token.ts | 185 ++ .../helpers/used-refresh-token.ts | 30 + packages/pds/src/account-manager/index.ts | 345 ++- .../pds/src/actor-store/preference/reader.ts | 2 +- .../api/com/atproto/server/createSession.ts | 52 +- .../api/com/atproto/server/deleteSession.ts | 34 +- .../src/api/com/atproto/server/getSession.ts | 9 +- .../api/com/atproto/server/refreshSession.ts | 8 +- .../src/api/com/atproto/sync/getRepoStatus.ts | 4 +- .../pds/src/api/com/atproto/sync/listRepos.ts | 2 +- packages/pds/src/api/proxy.ts | 20 +- packages/pds/src/auth-routes.ts | 27 + packages/pds/src/auth-verifier.ts | 404 ++- packages/pds/src/config/config.ts | 66 + packages/pds/src/config/env.ts | 24 + packages/pds/src/config/secrets.ts | 2 + packages/pds/src/context.ts | 94 +- packages/pds/src/db/cast.ts | 59 + packages/pds/src/db/db.ts | 27 +- packages/pds/src/db/index.ts | 1 + packages/pds/src/error.ts | 7 + packages/pds/src/index.ts | 2 + packages/pds/src/logger.ts | 121 +- .../pds/src/oauth/detailed-account-store.ts | 96 + packages/pds/src/oauth/provider.ts | 77 + packages/pds/src/pipethrough.ts | 5 +- packages/repo/package.json | 2 +- packages/syntax/src/aturi_validation.ts | 8 +- packages/syntax/src/did.ts | 2 +- packages/syntax/src/handle.ts | 2 +- packages/syntax/src/nsid.ts | 4 +- packages/syntax/src/recordkey.ts | 2 +- packages/syntax/src/tid.ts | 2 +- packages/syntax/tests/aturi.test.ts | 2 +- packages/syntax/tests/datetime.test.ts | 6 +- packages/syntax/tests/did.test.ts | 4 +- packages/syntax/tests/handle.test.ts | 4 +- packages/syntax/tests/nsid.test.ts | 4 +- packages/syntax/tests/recordkey.test.ts | 4 +- packages/syntax/tests/tid.test.ts | 4 +- packages/xrpc-server/package.json | 2 +- packages/xrpc-server/src/util.ts | 21 +- packages/xrpc-server/tests/bodies.test.ts | 33 +- packages/xrpc/package.json | 2 +- pnpm-lock.yaml | 2412 +++++++++++++++-- pnpm-workspace.yaml | 2 + services/pds/.gitignore | 3 +- services/pds/Dockerfile | 2 +- services/pds/index.js | 41 - services/pds/tracer.js | 40 + tsconfig.json | 17 + tsconfig/base.json | 3 +- tsconfig/browser.json | 8 + tsconfig/bundler.json | 10 + tsconfig/nodenext.json | 11 + 448 files changed, 26273 insertions(+), 629 deletions(-) create mode 100644 .changeset/clever-monkeys-sparkle.md create mode 100644 packages/did/package.json create mode 100644 packages/did/src/did-document.ts create mode 100644 packages/did/src/did-error.ts create mode 100644 packages/did/src/did.ts create mode 100644 packages/did/src/index.ts create mode 100644 packages/did/src/methods.ts create mode 100644 packages/did/src/methods/plc.ts create mode 100644 packages/did/src/methods/web.ts create mode 100644 packages/did/tsconfig.build.json create mode 100644 packages/did/tsconfig.json create mode 100644 packages/internal/did-resolver/package.json create mode 100644 packages/internal/did-resolver/src/did-cache-memory.ts create mode 100644 packages/internal/did-resolver/src/did-cache.ts create mode 100644 packages/internal/did-resolver/src/did-method.ts create mode 100644 packages/internal/did-resolver/src/did-resolver-base.ts create mode 100644 packages/internal/did-resolver/src/did-resolver-common.ts create mode 100644 packages/internal/did-resolver/src/did-resolver.ts create mode 100644 packages/internal/did-resolver/src/index.ts create mode 100644 packages/internal/did-resolver/src/methods.ts create mode 100644 packages/internal/did-resolver/src/methods/plc.ts create mode 100644 packages/internal/did-resolver/src/methods/web.ts create mode 100644 packages/internal/did-resolver/src/util.ts create mode 100644 packages/internal/did-resolver/tsconfig.build.json create mode 100644 packages/internal/did-resolver/tsconfig.json create mode 100644 packages/internal/fetch-node/package.json create mode 100644 packages/internal/fetch-node/src/index.ts create mode 100644 packages/internal/fetch-node/src/safe.ts create mode 100644 packages/internal/fetch-node/src/ssrf.ts create mode 100644 packages/internal/fetch-node/tsconfig.build.json create mode 100644 packages/internal/fetch-node/tsconfig.json create mode 100644 packages/internal/fetch/package.json create mode 100644 packages/internal/fetch/src/fetch-error.ts create mode 100644 packages/internal/fetch/src/fetch-request.ts create mode 100644 packages/internal/fetch/src/fetch-response.ts create mode 100644 packages/internal/fetch/src/fetch-wrap.ts create mode 100644 packages/internal/fetch/src/fetch.ts create mode 100644 packages/internal/fetch/src/index.ts create mode 100644 packages/internal/fetch/src/transformed-response.ts create mode 100644 packages/internal/fetch/src/util.ts create mode 100644 packages/internal/fetch/tsconfig.build.json create mode 100644 packages/internal/fetch/tsconfig.json create mode 100644 packages/internal/handle-resolver-node/package.json create mode 100644 packages/internal/handle-resolver-node/src/atproto-handle-resolver-node.ts create mode 100644 packages/internal/handle-resolver-node/src/index.ts create mode 100644 packages/internal/handle-resolver-node/src/node-resolve-txt-factory.ts create mode 100644 packages/internal/handle-resolver-node/tsconfig.build.json create mode 100644 packages/internal/handle-resolver-node/tsconfig.json create mode 100644 packages/internal/handle-resolver/README.md create mode 100644 packages/internal/handle-resolver/package.json create mode 100644 packages/internal/handle-resolver/src/app-view-handle-resolver.ts create mode 100644 packages/internal/handle-resolver/src/atproto-doh-handle-resolver.ts create mode 100644 packages/internal/handle-resolver/src/atproto-handle-resolver.ts create mode 100644 packages/internal/handle-resolver/src/cached-handle-resolver.ts create mode 100644 packages/internal/handle-resolver/src/index.ts create mode 100644 packages/internal/handle-resolver/src/internal-resolvers/dns-handle-resolver.ts create mode 100644 packages/internal/handle-resolver/src/internal-resolvers/well-known-handler-resolver.ts create mode 100644 packages/internal/handle-resolver/src/types.ts create mode 100644 packages/internal/handle-resolver/tsconfig.build.json create mode 100644 packages/internal/handle-resolver/tsconfig.json create mode 100644 packages/internal/identity-resolver/package.json create mode 100644 packages/internal/identity-resolver/src/identity-resolver.ts create mode 100644 packages/internal/identity-resolver/src/index.ts create mode 100644 packages/internal/identity-resolver/tsconfig.build.json create mode 100644 packages/internal/identity-resolver/tsconfig.json create mode 100644 packages/internal/pipe/package.json create mode 100644 packages/internal/pipe/src/index.ts create mode 100644 packages/internal/pipe/src/pipe.ts create mode 100644 packages/internal/pipe/src/transformer.ts create mode 100644 packages/internal/pipe/tsconfig.build.json create mode 100644 packages/internal/pipe/tsconfig.json create mode 100644 packages/internal/rollup-plugin-bundle-manifest/README.md create mode 100644 packages/internal/rollup-plugin-bundle-manifest/package.json create mode 100644 packages/internal/rollup-plugin-bundle-manifest/src/index.ts create mode 100644 packages/internal/rollup-plugin-bundle-manifest/tsconfig.build.json create mode 100644 packages/internal/rollup-plugin-bundle-manifest/tsconfig.json create mode 100644 packages/internal/simple-store-memory/package.json create mode 100644 packages/internal/simple-store-memory/src/index.ts create mode 100644 packages/internal/simple-store-memory/src/util.ts create mode 100644 packages/internal/simple-store-memory/tsconfig.build.json create mode 100644 packages/internal/simple-store-memory/tsconfig.json create mode 100644 packages/internal/simple-store/package.json create mode 100644 packages/internal/simple-store/src/cached-getter.ts create mode 100644 packages/internal/simple-store/src/index.ts create mode 100644 packages/internal/simple-store/src/simple-store.ts create mode 100644 packages/internal/simple-store/tsconfig.build.json create mode 100644 packages/internal/simple-store/tsconfig.json create mode 100644 packages/oauth/jwk-jose/package.json create mode 100644 packages/oauth/jwk-jose/src/index.ts create mode 100644 packages/oauth/jwk-jose/src/jose-key.ts create mode 100644 packages/oauth/jwk-jose/src/util.ts create mode 100644 packages/oauth/jwk-jose/tsconfig.build.json create mode 100644 packages/oauth/jwk-jose/tsconfig.json create mode 100644 packages/oauth/jwk-webcrypto/package.json create mode 100644 packages/oauth/jwk-webcrypto/src/index.ts create mode 100644 packages/oauth/jwk-webcrypto/src/util.ts create mode 100644 packages/oauth/jwk-webcrypto/src/webcrypto-key.ts create mode 100644 packages/oauth/jwk-webcrypto/tsconfig.build.json create mode 100644 packages/oauth/jwk-webcrypto/tsconfig.json create mode 100644 packages/oauth/jwk/package.json create mode 100644 packages/oauth/jwk/src/alg.ts create mode 100644 packages/oauth/jwk/src/errors.ts create mode 100644 packages/oauth/jwk/src/index.ts create mode 100644 packages/oauth/jwk/src/jwk.ts create mode 100644 packages/oauth/jwk/src/jwks.ts create mode 100644 packages/oauth/jwk/src/jwt-decode.ts create mode 100644 packages/oauth/jwk/src/jwt-verify.ts create mode 100644 packages/oauth/jwk/src/jwt.ts create mode 100644 packages/oauth/jwk/src/key.ts create mode 100644 packages/oauth/jwk/src/keyset.ts create mode 100644 packages/oauth/jwk/src/util.ts create mode 100644 packages/oauth/jwk/tsconfig.build.json create mode 100644 packages/oauth/jwk/tsconfig.json create mode 100644 packages/oauth/oauth-client-browser/example/.gitignore create mode 100644 packages/oauth/oauth-client-browser/example/.postcssrc.yml create mode 100644 packages/oauth/oauth-client-browser/example/package.json create mode 100644 packages/oauth/oauth-client-browser/example/rollup.config.js create mode 100644 packages/oauth/oauth-client-browser/example/src/app.tsx create mode 100644 packages/oauth/oauth-client-browser/example/src/index.css create mode 100644 packages/oauth/oauth-client-browser/example/src/login-form.tsx create mode 100644 packages/oauth/oauth-client-browser/example/src/main.tsx create mode 100644 packages/oauth/oauth-client-browser/example/src/oauth.ts create mode 100644 packages/oauth/oauth-client-browser/example/tailwind.config.js create mode 100644 packages/oauth/oauth-client-browser/example/tsconfig.build.json create mode 100644 packages/oauth/oauth-client-browser/example/tsconfig.json create mode 100644 packages/oauth/oauth-client-browser/example/tsconfig.tools.json create mode 100644 packages/oauth/oauth-client-browser/package.json create mode 100644 packages/oauth/oauth-client-browser/src/browser-oauth-client.ts create mode 100644 packages/oauth/oauth-client-browser/src/browser-oauth-database.ts create mode 100644 packages/oauth/oauth-client-browser/src/browser-runtime-implementation.ts create mode 100644 packages/oauth/oauth-client-browser/src/disposable-polyfill/README.md create mode 100644 packages/oauth/oauth-client-browser/src/disposable-polyfill/index.ts create mode 100644 packages/oauth/oauth-client-browser/src/errors.ts create mode 100644 packages/oauth/oauth-client-browser/src/index.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db-store.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/README.md create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/db-index.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/db-object-store.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/db-transaction.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/db.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/index.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/schema.ts create mode 100644 packages/oauth/oauth-client-browser/src/indexed-db/util.ts create mode 100644 packages/oauth/oauth-client-browser/src/util.ts create mode 100644 packages/oauth/oauth-client-browser/tsconfig.build.json create mode 100644 packages/oauth/oauth-client-browser/tsconfig.json create mode 100644 packages/oauth/oauth-client/README.md create mode 100644 packages/oauth/oauth-client/package.json create mode 100644 packages/oauth/oauth-client/src/constants.ts create mode 100644 packages/oauth/oauth-client/src/fetch-dpop.ts create mode 100644 packages/oauth/oauth-client/src/index.ts create mode 100644 packages/oauth/oauth-client/src/lock.ts create mode 100644 packages/oauth/oauth-client/src/oauth-agent.ts create mode 100644 packages/oauth/oauth-client/src/oauth-authorization-server-metadata-resolver.ts create mode 100644 packages/oauth/oauth-client/src/oauth-callback-error.ts create mode 100644 packages/oauth/oauth-client/src/oauth-client.ts create mode 100644 packages/oauth/oauth-client/src/oauth-protected-resource-metadata-resolver.ts create mode 100644 packages/oauth/oauth-client/src/oauth-resolver-error.ts create mode 100644 packages/oauth/oauth-client/src/oauth-resolver.ts create mode 100644 packages/oauth/oauth-client/src/oauth-response-error.ts create mode 100644 packages/oauth/oauth-client/src/oauth-server-agent.ts create mode 100644 packages/oauth/oauth-client/src/oauth-server-factory.ts create mode 100644 packages/oauth/oauth-client/src/refresh-error.ts create mode 100644 packages/oauth/oauth-client/src/runtime-implementation.ts create mode 100644 packages/oauth/oauth-client/src/runtime.ts create mode 100644 packages/oauth/oauth-client/src/session-getter.ts create mode 100644 packages/oauth/oauth-client/src/types.ts create mode 100644 packages/oauth/oauth-client/src/util.ts create mode 100644 packages/oauth/oauth-client/src/validate-client-metadata.ts create mode 100644 packages/oauth/oauth-client/tsconfig.build.json create mode 100644 packages/oauth/oauth-client/tsconfig.json create mode 100644 packages/oauth/oauth-provider/.postcssrc.yml create mode 100644 packages/oauth/oauth-provider/package.json create mode 100644 packages/oauth/oauth-provider/rollup.config.js create mode 100644 packages/oauth/oauth-provider/src/access-token/access-token-type.ts create mode 100644 packages/oauth/oauth-provider/src/account/account-manager.ts create mode 100644 packages/oauth/oauth-provider/src/account/account-store.ts create mode 100644 packages/oauth/oauth-provider/src/account/account.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/app.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/backend-data.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/accept-form.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/account-identifier.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/account-picker.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/client-identifier.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/client-name.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/error-card.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/help-card.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/layout-title-page.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/layout-welcome.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/sign-in-form.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/sign-up-account-form.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/sign-up-disclaimer.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/components/url-viewer.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/cookies.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/hooks/use-api.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/hooks/use-bound-dispatch.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/hooks/use-csrf-token.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/lib/api.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/lib/clsx.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/lib/util.ts create mode 100644 packages/oauth/oauth-provider/src/assets/app/main.css create mode 100644 packages/oauth/oauth-provider/src/assets/app/main.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/accept-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/authorize-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/error-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/sign-in-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/sign-up-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/app/views/welcome-view.tsx create mode 100644 packages/oauth/oauth-provider/src/assets/asset.ts create mode 100644 packages/oauth/oauth-provider/src/assets/assets-middleware.ts create mode 100644 packages/oauth/oauth-provider/src/assets/index.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-auth.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-data.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-id.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-info.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-manager.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-store.ts create mode 100644 packages/oauth/oauth-provider/src/client/client-utils.ts create mode 100644 packages/oauth/oauth-provider/src/client/client.ts create mode 100644 packages/oauth/oauth-provider/src/constants.ts create mode 100644 packages/oauth/oauth-provider/src/device/device-data.ts create mode 100644 packages/oauth/oauth-provider/src/device/device-details.ts create mode 100644 packages/oauth/oauth-provider/src/device/device-id.ts create mode 100644 packages/oauth/oauth-provider/src/device/device-manager.ts create mode 100644 packages/oauth/oauth-provider/src/device/device-store.ts create mode 100644 packages/oauth/oauth-provider/src/device/session-id.ts create mode 100644 packages/oauth/oauth-provider/src/dpop/dpop-manager.ts create mode 100644 packages/oauth/oauth-provider/src/dpop/dpop-nonce.ts create mode 100644 packages/oauth/oauth-provider/src/errors/access-denied-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/account-selection-required-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/consent-required-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-authorization-details-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-client-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-client-id-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-client-metadata-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-dpop-key-binding-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-dpop-proof-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-grant-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-parameters-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-redirect-uri-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-request-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/invalid-token-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/login-required-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/oauth-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/unauthorized-client-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/use-dpop-nonce-error.ts create mode 100644 packages/oauth/oauth-provider/src/errors/www-authenticate-error.ts create mode 100644 packages/oauth/oauth-provider/src/index.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/README.md create mode 100644 packages/oauth/oauth-provider/src/lib/html/build-document.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/escapers.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/html.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/index.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/tags.ts create mode 100644 packages/oauth/oauth-provider/src/lib/html/util.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/README.md create mode 100644 packages/oauth/oauth-provider/src/lib/http/accept.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/context.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/index.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/method.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/middleware.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/parser.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/path.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/request.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/response.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/route.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/router.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/stream.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/types.ts create mode 100644 packages/oauth/oauth-provider/src/lib/http/url.ts create mode 100644 packages/oauth/oauth-provider/src/lib/redis.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/authorization-header.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/cast.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/crypto.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/date.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/hostname.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/redirect-uri.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/time.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/type.ts create mode 100644 packages/oauth/oauth-provider/src/lib/util/well-known.ts create mode 100644 packages/oauth/oauth-provider/src/metadata/build-metadata.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-client.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-dpop.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-errors.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-hooks.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-provider.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-store.ts create mode 100644 packages/oauth/oauth-provider/src/oauth-verifier.ts create mode 100644 packages/oauth/oauth-provider/src/oidc/claims.ts create mode 100644 packages/oauth/oauth-provider/src/oidc/sub.ts create mode 100644 packages/oauth/oauth-provider/src/oidc/userinfo.ts create mode 100644 packages/oauth/oauth-provider/src/output/build-error-payload.ts create mode 100644 packages/oauth/oauth-provider/src/output/customization.ts create mode 100644 packages/oauth/oauth-provider/src/output/send-authorize-page.ts create mode 100644 packages/oauth/oauth-provider/src/output/send-authorize-redirect.ts create mode 100644 packages/oauth/oauth-provider/src/output/send-error-page.ts create mode 100644 packages/oauth/oauth-provider/src/output/send-web-page.ts create mode 100644 packages/oauth/oauth-provider/src/parameters/claims-requested.ts create mode 100644 packages/oauth/oauth-provider/src/parameters/oidc-payload.ts create mode 100644 packages/oauth/oauth-provider/src/replay/replay-manager.ts create mode 100644 packages/oauth/oauth-provider/src/replay/replay-store-memory.ts create mode 100644 packages/oauth/oauth-provider/src/replay/replay-store-redis.ts create mode 100644 packages/oauth/oauth-provider/src/replay/replay-store.ts create mode 100644 packages/oauth/oauth-provider/src/request/code.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-data.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-id.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-info.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-manager.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-store-memory.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-store-redis.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-store.ts create mode 100644 packages/oauth/oauth-provider/src/request/request-uri.ts create mode 100644 packages/oauth/oauth-provider/src/request/types.ts create mode 100644 packages/oauth/oauth-provider/src/signer/signed-token-payload.ts create mode 100644 packages/oauth/oauth-provider/src/signer/signer.ts create mode 100644 packages/oauth/oauth-provider/src/token/refresh-token.ts create mode 100644 packages/oauth/oauth-provider/src/token/token-claims.ts create mode 100644 packages/oauth/oauth-provider/src/token/token-data.ts create mode 100644 packages/oauth/oauth-provider/src/token/token-id.ts create mode 100644 packages/oauth/oauth-provider/src/token/token-manager.ts create mode 100644 packages/oauth/oauth-provider/src/token/token-store.ts create mode 100644 packages/oauth/oauth-provider/src/token/types.ts create mode 100644 packages/oauth/oauth-provider/src/token/verify-token-claims.ts create mode 100644 packages/oauth/oauth-provider/tailwind.config.js create mode 100644 packages/oauth/oauth-provider/tsconfig.backend.json create mode 100644 packages/oauth/oauth-provider/tsconfig.frontend.json create mode 100644 packages/oauth/oauth-provider/tsconfig.json create mode 100644 packages/oauth/oauth-provider/tsconfig.tools.json create mode 100644 packages/oauth/oauth-types/README.md create mode 100644 packages/oauth/oauth-types/package.json create mode 100644 packages/oauth/oauth-types/src/access-token.ts create mode 100644 packages/oauth/oauth-types/src/atproto-loopback-client-metadata.ts create mode 100644 packages/oauth/oauth-types/src/constants.ts create mode 100644 packages/oauth/oauth-types/src/index.ts create mode 100644 packages/oauth/oauth-types/src/oauth-authentication-request-parameters.ts create mode 100644 packages/oauth/oauth-types/src/oauth-authorization-details.ts create mode 100644 packages/oauth/oauth-types/src/oauth-authorization-server-metadata.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-credentials.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-id-discoverable.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-id-loopback.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-id-url.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-id.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-identification.ts create mode 100644 packages/oauth/oauth-types/src/oauth-client-metadata.ts create mode 100644 packages/oauth/oauth-types/src/oauth-endpoint-auth-method.ts create mode 100644 packages/oauth/oauth-types/src/oauth-endpoint-name.ts create mode 100644 packages/oauth/oauth-types/src/oauth-grant-type.ts create mode 100644 packages/oauth/oauth-types/src/oauth-issuer-identifier.ts create mode 100644 packages/oauth/oauth-types/src/oauth-par-response.ts create mode 100644 packages/oauth/oauth-types/src/oauth-protected-resource-metadata.ts create mode 100644 packages/oauth/oauth-types/src/oauth-response-mode.ts create mode 100644 packages/oauth/oauth-types/src/oauth-response-type.ts create mode 100644 packages/oauth/oauth-types/src/oauth-token-response.ts create mode 100644 packages/oauth/oauth-types/src/oauth-token-type.ts create mode 100644 packages/oauth/oauth-types/src/oidc-claims-parameter.ts create mode 100644 packages/oauth/oauth-types/src/oidc-claims-properties.ts create mode 100644 packages/oauth/oauth-types/src/oidc-entity-type.ts create mode 100644 packages/oauth/oauth-types/src/util.ts create mode 100644 packages/oauth/oauth-types/tsconfig.build.json create mode 100644 packages/oauth/oauth-types/tsconfig.json create mode 100644 packages/pds/src/account-manager/db/migrations/004-oauth.ts create mode 100644 packages/pds/src/account-manager/db/schema/authorization-request.ts create mode 100644 packages/pds/src/account-manager/db/schema/device-account.ts create mode 100644 packages/pds/src/account-manager/db/schema/device.ts create mode 100644 packages/pds/src/account-manager/db/schema/token.ts create mode 100644 packages/pds/src/account-manager/db/schema/used-refresh-token.ts create mode 100644 packages/pds/src/account-manager/helpers/authorization-request.ts create mode 100644 packages/pds/src/account-manager/helpers/device-account.ts create mode 100644 packages/pds/src/account-manager/helpers/device.ts create mode 100644 packages/pds/src/account-manager/helpers/token.ts create mode 100644 packages/pds/src/account-manager/helpers/used-refresh-token.ts create mode 100644 packages/pds/src/auth-routes.ts create mode 100644 packages/pds/src/db/cast.ts create mode 100644 packages/pds/src/oauth/detailed-account-store.ts create mode 100644 packages/pds/src/oauth/provider.ts create mode 100644 services/pds/tracer.js create mode 100644 tsconfig/browser.json create mode 100644 tsconfig/bundler.json create mode 100644 tsconfig/nodenext.json diff --git a/.changeset/clever-monkeys-sparkle.md b/.changeset/clever-monkeys-sparkle.md new file mode 100644 index 00000000000..2d9e4422fde --- /dev/null +++ b/.changeset/clever-monkeys-sparkle.md @@ -0,0 +1,23 @@ +--- +"@atproto/pds": patch +"@atproto-labs/rollup-plugin-bundle-manifest": minor +"@atproto-labs/handle-resolver-node": minor +"@atproto-labs/simple-store-memory": minor +"@atproto-labs/identity-resolver": minor +"@atproto/oauth-client-browser": minor +"@atproto-labs/handle-resolver": minor +"@atproto-labs/did-resolver": minor +"@atproto-labs/simple-store": minor +"@atproto/oauth-provider": minor +"@atproto-labs/fetch-node": minor +"@atproto/jwk-webcrypto": minor +"@atproto/oauth-client": minor +"@atproto/oauth-types": minor +"@atproto-labs/fetch": minor +"@atproto/jwk-jose": minor +"@atproto-labs/pipe": minor +"@atproto/jwk": minor +"@atproto/did": minor +--- + +Add OAuth provider capability & support for DPoP signed tokens diff --git a/.eslintrc b/.eslintrc index 650df576b78..8529167b1fd 100644 --- a/.eslintrc +++ b/.eslintrc @@ -18,6 +18,7 @@ "no-var": "error", "prefer-const": "warn", "no-misleading-character-class": "warn", + "eqeqeq": ["error", "always", { "null": "ignore" }], "@typescript-eslint/no-unused-vars": [ "warn", { "argsIgnorePattern": "^_", "varsIgnorePattern": "^_" } diff --git a/.github/workflows/repo.yaml b/.github/workflows/repo.yaml index 3b46597f514..b9840f6162c 100644 --- a/.github/workflows/repo.yaml +++ b/.github/workflows/repo.yaml @@ -27,7 +27,9 @@ jobs: - uses: actions/upload-artifact@v4 with: name: dist - path: packages/*/dist + path: | + packages/*/dist + packages/*/*/dist retention-days: 1 test: name: Test diff --git a/.gitignore b/.gitignore index 3648e5a3f73..dccb907cf77 100644 --- a/.gitignore +++ b/.gitignore @@ -2,14 +2,14 @@ node_modules lerna-debug.log npm-debug.log yarn-error.log -packages/*/dist +packages/**/dist .idea packages/*/coverage .vscode/ test.sqlite .DS_Store *.log -tsconfig.build.tsbuildinfo +*.tsbuildinfo .*.env .env \#*\# diff --git a/package.json b/package.json index f2a0a1561b2..644fff5ef8c 100644 --- a/package.json +++ b/package.json @@ -19,7 +19,7 @@ "verify:types": "tsc --build tsconfig.json", "format": "pnpm lint:fix && pnpm style:fix", "build": "pnpm --recursive --stream build", - "dev": "pnpm --stream '/^dev:.+$/'", + "dev": "NODE_ENV=development pnpm --stream '/^dev:.+$/'", "dev:tsc": "tsc --build tsconfig.json --watch", "dev:pkg": "pnpm --recursive --parallel --stream dev", "test": "LOG_ENABLED=false ./packages/dev-infra/with-test-redis-and-db.sh pnpm --stream -r test", @@ -51,7 +51,9 @@ }, "workspaces": { "packages": [ - "packages/*" + "packages/*", + "packages/oauth/*", + "packages/internal/*" ] } } diff --git a/packages/README.md b/packages/README.md index 68835de45f7..7e433ff8000 100644 --- a/packages/README.md +++ b/packages/README.md @@ -13,6 +13,7 @@ - [Crypto](./crypto): Atproto's common cryptographic operations. - [Syntax](./syntax): A library for identifier syntax: NSID, AT URI, handles, etc. - [Lexicon](./lexicon): A library for validating data using atproto's schema system. +- [OAuth Provider](./oauth/oauth-provider): A library for supporting ATPROTO's OAuth. - [Repo](./repo): The "atproto repository" core implementation (a Merkle Search Tree). - [XRPC](./xrpc): An XRPC client implementation. - [XRPC Server](./xrpc-server): An XRPC server implementation. diff --git a/packages/api/tests/agent.test.ts b/packages/api/tests/agent.test.ts index fae74d84323..d80036f474f 100644 --- a/packages/api/tests/agent.test.ts +++ b/packages/api/tests/agent.test.ts @@ -355,7 +355,7 @@ describe('agent', () => { expect(events.length).toEqual(2) expect(events[0]).toEqual('create-failed') - expect(events[1]).toEqual('network-error') + expect(events[1]).toEqual('expired') expect(sessions.length).toEqual(2) expect(typeof sessions[0]).toEqual('undefined') expect(typeof sessions[1]).toEqual('undefined') diff --git a/packages/bsky/package.json b/packages/bsky/package.json index d0afdf5a61e..cf4a4063e5f 100644 --- a/packages/bsky/package.json +++ b/packages/bsky/package.json @@ -51,7 +51,7 @@ "multiformats": "^9.9.0", "p-queue": "^6.6.2", "pg": "^8.10.0", - "pino": "^8.15.0", + "pino": "^8.21.0", "pino-http": "^8.2.1", "sharp": "^0.32.6", "structured-headers": "^1.0.1", diff --git a/packages/bsky/src/hydration/util.ts b/packages/bsky/src/hydration/util.ts index 7654a7bbbe9..771dc2a2de6 100644 --- a/packages/bsky/src/hydration/util.ts +++ b/packages/bsky/src/hydration/util.ts @@ -67,9 +67,7 @@ export const parseRecordBytes = ( return parseJsonBytes(bytes) as T } -export const parseJsonBytes = ( - bytes: Uint8Array | undefined, -): JSON | undefined => { +export const parseJsonBytes = (bytes: Uint8Array | undefined): unknown => { if (!bytes || bytes.byteLength === 0) return const parsed = JSON.parse(ui8.toString(bytes, 'utf8')) return parsed ?? undefined diff --git a/packages/bsky/src/logger.ts b/packages/bsky/src/logger.ts index cc28490143f..699ba9e8175 100644 --- a/packages/bsky/src/logger.ts +++ b/packages/bsky/src/logger.ts @@ -1,8 +1,6 @@ -import pino from 'pino' +import { stdSerializers } from 'pino' import pinoHttp from 'pino-http' -import * as jose from 'jose' import { subsystemLogger } from '@atproto/common' -import { parseBasicAuth } from './auth-verifier' export const dbLogger: ReturnType = subsystemLogger('bsky:db') @@ -20,40 +18,85 @@ export const httpLogger: ReturnType = export const loggerMiddleware = pinoHttp({ logger: httpLogger, serializers: { - err: (err) => { - return { - code: err?.code, - message: err?.message, - } - }, - req: (req) => { - const serialized = pino.stdSerializers.req(req) - const authHeader = serialized.headers.authorization || '' - let auth: string | undefined = undefined - if (authHeader.startsWith('Bearer ')) { - const token = authHeader.slice('Bearer '.length) - const { iss } = jose.decodeJwt(token) - if (iss) { - auth = 'Bearer ' + iss - } else { - auth = 'Bearer Invalid' - } - } - if (authHeader.startsWith('Basic ')) { - const parsed = parseBasicAuth(authHeader) - if (!parsed) { - auth = 'Basic Invalid' - } else { - auth = 'Basic ' + parsed.username - } - } - return { - ...serialized, - headers: { - ...serialized.headers, - authorization: auth, - }, - } - }, + err: errSerializer, + req: reqSerializer, }, }) + +function errSerializer(err: any) { + return { + code: err?.code, + message: err?.message, + } +} + +function reqSerializer(req: any) { + const serialized = stdSerializers.req(req) + serialized.headers = obfuscateHeaders(serialized.headers) + return serialized +} + +function obfuscateHeaders(headers: Record) { + const obfuscatedHeaders: Record = {} + for (const key in headers) { + if (key.toLowerCase() === 'authorization') { + obfuscatedHeaders[key] = obfuscateAuthHeader(headers[key]) + } else if (key.toLowerCase() === 'dpop') { + obfuscatedHeaders[key] = obfuscateJws(headers[key]) || 'Invalid' + } else { + obfuscatedHeaders[key] = headers[key] + } + } + return obfuscatedHeaders +} + +function obfuscateAuthHeader(authHeader: string): string { + // This is a hot path (runs on every request). Avoid using split() or regex. + + const spaceIdx = authHeader.indexOf(' ') + if (spaceIdx === -1) return 'Invalid' + + const type = authHeader.slice(0, spaceIdx) + switch (type.toLowerCase()) { + case 'bearer': + return `${type} ${obfuscateBearer(authHeader.slice(spaceIdx + 1))}` + case 'dpop': + return `${type} ${obfuscateJws(authHeader.slice(spaceIdx + 1)) || 'Invalid'}` + case 'basic': + return `${type} ${obfuscateBasic(authHeader.slice(spaceIdx + 1)) || 'Invalid'}` + default: + return `Invalid` + } +} + +function obfuscateBasic(token: string): null | string { + if (!token) return null + const buffer = Buffer.from(token, 'base64') + if (!buffer.length) return null // Buffer.from will silently ignore invalid base64 chars + const authHeader = buffer.toString('utf8') + const colIdx = authHeader.indexOf(':') + if (colIdx === -1) return null + const username = authHeader.slice(0, colIdx) + return `${username}:***` +} + +function obfuscateBearer(token: string): string { + return obfuscateJws(token) || obfuscateToken(token) +} + +function obfuscateToken(token: string): string { + return token ? '***' : '' +} + +function obfuscateJws(token: string): null | string { + const firstDot = token.indexOf('.') + if (firstDot === -1) return null + + const secondDot = token.indexOf('.', firstDot + 1) + if (secondDot === -1) return null + + if (token.indexOf('.', secondDot + 1) !== -1) return null + + // Strip the signature + return token.slice(0, secondDot) + '.obfuscated' +} diff --git a/packages/common-web/package.json b/packages/common-web/package.json index 51ef70fa93c..fa60f5c2b86 100644 --- a/packages/common-web/package.json +++ b/packages/common-web/package.json @@ -22,7 +22,7 @@ "graphemer": "^1.4.0", "multiformats": "^9.9.0", "uint8arrays": "3.0.0", - "zod": "^3.21.4" + "zod": "^3.23.8" }, "devDependencies": { "jest": "^28.1.2" diff --git a/packages/common-web/src/check.ts b/packages/common-web/src/check.ts index 47bdce95958..f3ae32d5ea3 100644 --- a/packages/common-web/src/check.ts +++ b/packages/common-web/src/check.ts @@ -1,10 +1,11 @@ -import { ZodError } from 'zod' +// Explicitly not using "zod" types here to avoid mismatching types due to +// version differences. export interface Checkable { parse: (obj: unknown) => T safeParse: ( obj: unknown, - ) => { success: true; data: T } | { success: false; error: ZodError } + ) => { success: true; data: T } | { success: false; error: Error } } export interface Def { diff --git a/packages/common/package.json b/packages/common/package.json index 9921ab46ee9..2ac7f23dd65 100644 --- a/packages/common/package.json +++ b/packages/common/package.json @@ -24,7 +24,7 @@ "cbor-x": "^1.5.1", "iso-datestring-validator": "^2.2.2", "multiformats": "^9.9.0", - "pino": "^8.15.0" + "pino": "^8.21.0" }, "devDependencies": { "jest": "^28.1.2", diff --git a/packages/dev-env/package.json b/packages/dev-env/package.json index bde5c4bf2df..ab0d1aacfbf 100644 --- a/packages/dev-env/package.json +++ b/packages/dev-env/package.json @@ -17,7 +17,8 @@ "bin": "dist/bin.js", "scripts": { "build": "tsc --build tsconfig.build.json", - "start": "../dev-infra/with-test-redis-and-db.sh node dist/bin.js" + "start": "../dev-infra/with-test-redis-and-db.sh node dist/bin.js", + "dev": "../dev-infra/with-test-redis-and-db.sh node --watch dist/bin.js" }, "dependencies": { "@atproto/api": "workspace:^", diff --git a/packages/dev-env/src/bin.ts b/packages/dev-env/src/bin.ts index 57a09db5572..c0b8c5c7e2c 100644 --- a/packages/dev-env/src/bin.ts +++ b/packages/dev-env/src/bin.ts @@ -27,6 +27,7 @@ const run = async () => { }, plc: { port: 2582 }, ozone: { + port: 2587, chatUrl: 'http://localhost:2590', // must run separate chat service chatDid: 'did:example:chat', }, diff --git a/packages/dev-env/src/pds.ts b/packages/dev-env/src/pds.ts index 0828f2f3f03..f7d088c63fb 100644 --- a/packages/dev-env/src/pds.ts +++ b/packages/dev-env/src/pds.ts @@ -45,6 +45,16 @@ export class TestPds { modServiceDid: 'did:example:invalid', plcRotationKeyK256PrivateKeyHex: plcRotationPriv, inviteRequired: false, + fetchDisableSsrfProtection: true, + serviceName: 'Development PDS', + primaryColor: '#ffcb1e', + errorColor: undefined, + logoUrl: + 'https://uxwing.com/wp-content/themes/uxwing/download/animals-and-birds/bee-icon.png', + homeUrl: 'https://bsky.social/', + termsOfServiceUrl: 'https://bsky.social/about/support/tos', + privacyPolicyUrl: 'https://bsky.social/about/support/privacy-policy', + supportUrl: 'https://blueskyweb.zendesk.com/hc/en-us', ...config, } const cfg = pds.envToCfg(env) diff --git a/packages/did/package.json b/packages/did/package.json new file mode 100644 index 00000000000..398adb4ef91 --- /dev/null +++ b/packages/did/package.json @@ -0,0 +1,36 @@ +{ + "name": "@atproto/did", + "version": "0.0.1", + "license": "MIT", + "description": "DID resolution and verification library", + "keywords": [ + "atproto", + "did", + "validation", + "types" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/did" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/did/src/did-document.ts b/packages/did/src/did-document.ts new file mode 100644 index 00000000000..c22af71e25c --- /dev/null +++ b/packages/did/src/did-document.ts @@ -0,0 +1,151 @@ +import { z } from 'zod' + +import { Did, didSchema } from './did.js' + +/** + * RFC3968 compliant URI + * + * @see {@link https://www.rfc-editor.org/rfc/rfc3986} + */ +const rfc3968UriSchema = z.string().refine((data) => { + try { + new URL(data) + return true + } catch { + return false + } +}, 'RFC3968 compliant URI') + +const didControllerSchema = z.union([didSchema, z.array(didSchema)]) + +/** + * @note this schema might be too permissive + */ +const didRelativeUriSchema = z.union([ + rfc3968UriSchema, + z.string().regex(/^#[^#]+$/), +]) + +const didVerificationMethodSchema = z.object({ + id: didRelativeUriSchema, + type: z.string().min(1), + controller: didControllerSchema, + publicKeyJwk: z.record(z.string(), z.unknown()).optional(), + publicKeyMultibase: z.string().optional(), +}) + +/** + * The value of the id property MUST be a URI conforming to [RFC3986]. A + * conforming producer MUST NOT produce multiple service entries with the same + * id. A conforming consumer MUST produce an error if it detects multiple + * service entries with the same id. + * + * @note Normally, only rfc3968UriSchema should be allowed here. However, the + * did:plc uses relative URI. For this reason, we also allow relative URIs + * here. + */ +const didServiceIdSchema = didRelativeUriSchema + +/** + * The value of the type property MUST be a string or a set of strings. In order + * to maximize interoperability, the service type and its associated properties + * SHOULD be registered in the DID Specification Registries + * [DID-SPEC-REGISTRIES]. + */ +const didServiceTypeSchema = z.union([z.string(), z.array(z.string())]) + +/** + * The value of the serviceEndpoint property MUST be a string, a map, or a set + * composed of one or more strings and/or maps. All string values MUST be valid + * URIs conforming to [RFC3986] and normalized according to the Normalization + * and Comparison rules in RFC3986 and to any normalization rules in its + * applicable URI scheme specification. + */ +const didServiceEndpointSchema = z.union([ + rfc3968UriSchema, + z.record(z.string(), rfc3968UriSchema), + z + .array(z.union([rfc3968UriSchema, z.record(z.string(), rfc3968UriSchema)])) + .nonempty(), +]) + +/** + * Each service map MUST contain id, type, and serviceEndpoint properties. + * @see {@link https://www.w3.org/TR/did-core/#services} + */ +const didServiceSchema = z.object({ + id: didServiceIdSchema, + type: didServiceTypeSchema, + serviceEndpoint: didServiceEndpointSchema, +}) + +export type DidService = z.infer + +const didAuthenticationSchema = z.union([ + // + didRelativeUriSchema, + didVerificationMethodSchema, +]) + +/** + * @note This schema is incomplete + * @see {@link https://www.w3.org/TR/did-core/#production-0} + */ +export const didDocumentSchema = z.object({ + '@context': z.union([ + z.literal('https://www.w3.org/ns/did/v1'), + z + .array(z.string().url()) + .nonempty() + .refine((data) => data[0] === 'https://www.w3.org/ns/did/v1', { + message: 'First @context must be https://www.w3.org/ns/did/v1', + }), + ]), + id: didSchema, + controller: didControllerSchema.optional(), + alsoKnownAs: z.array(rfc3968UriSchema).optional(), + service: z.array(didServiceSchema).optional(), + authentication: z.array(didAuthenticationSchema).optional(), + verificationMethod: z + .array(z.union([didVerificationMethodSchema, didRelativeUriSchema])) + .optional(), +}) + +export type DidDocument = z.infer< + typeof didDocumentSchema +> & { id: Did } + +// @TODO: add other refinements ? +export const didDocumentValidator = didDocumentSchema + .superRefine((data, ctx) => { + if (data.service) { + for (let i = 0; i < data.service.length; i++) { + if (data.service[i].id === data.id) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: `Service id must be different from the document id`, + path: ['service', i, 'id'], + }) + } + } + } + }) + .superRefine((data, ctx) => { + if (data.service) { + const normalizedIds = data.service.map((s) => + s.id?.startsWith('#') ? `${data.id}${s.id}` : s.id, + ) + + for (let i = 0; i < normalizedIds.length; i++) { + for (let j = i + 1; j < normalizedIds.length; j++) { + if (normalizedIds[i] === normalizedIds[j]) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: `Duplicate service id (${normalizedIds[j]}) found in the document`, + path: ['service', j, 'id'], + }) + } + } + } + } + }) diff --git a/packages/did/src/did-error.ts b/packages/did/src/did-error.ts new file mode 100644 index 00000000000..17fc96c296d --- /dev/null +++ b/packages/did/src/did-error.ts @@ -0,0 +1,49 @@ +export class DidError extends Error { + constructor( + public readonly did: string, + message: string, + public readonly code: string, + public readonly status = 400, + cause?: unknown, + ) { + super(message, { cause }) + } + + /** + * For compatibility with error handlers in common HTTP frameworks. + */ + get statusCode() { + return this.status + } + + override toString() { + return `${this.constructor.name} ${this.code} (${this.did}): ${this.message}` + } + + static from(cause: unknown, did: string): DidError { + if (cause instanceof DidError) { + return cause + } + + const message = + cause instanceof Error + ? cause.message + : typeof cause === 'string' + ? cause + : 'An unknown error occurred' + + const status = + (typeof cause?.['statusCode'] === 'number' + ? cause['statusCode'] + : undefined) ?? + (typeof cause?.['status'] === 'number' ? cause['status'] : undefined) + + return new DidError(did, message, 'did-unknown-error', status, cause) + } +} + +export class InvalidDidError extends DidError { + constructor(did: string, message: string, cause?: unknown) { + super(did, message, 'did-invalid', 400, cause) + } +} diff --git a/packages/did/src/did.ts b/packages/did/src/did.ts new file mode 100644 index 00000000000..e12874e7995 --- /dev/null +++ b/packages/did/src/did.ts @@ -0,0 +1,258 @@ +import { z } from 'zod' +import { DidError, InvalidDidError } from './did-error.js' + +const DID_PREFIX = 'did:' +const DID_PREFIX_LENGTH = DID_PREFIX.length +export { DID_PREFIX } + +/** + * Type representation of a Did, with method. + * + * ```bnf + * did = "did:" method-name ":" method-specific-id + * method-name = 1*method-char + * method-char = %x61-7A / DIGIT + * method-specific-id = *( *idchar ":" ) 1*idchar + * idchar = ALPHA / DIGIT / "." / "-" / "_" / pct-encoded + * pct-encoded = "%" HEXDIG HEXDIG + * ``` + * + * @example + * ```ts + * type DidWeb = Did<'web'> // `did:web:${string}` + * type DidCustom = Did<'web' | 'plc'> // `did:${'web' | 'plc'}:${string}` + * type DidNever = Did<' invalid 🥴 '> // never + * type DidFoo = Did<'foo' | ' invalid 🥴 '> // `did:foo:${string}` + * ``` + * + * @see {@link https://www.w3.org/TR/did-core/#did-syntax} + */ +export type Did = `did:${AsDidMethod}:${string}` + +/** + * DID Method + */ +export type AsDidMethod = string extends M + ? string // can't know... + : AsDidMethodInternal + +type AlphanumericChar = DigitChar | LowerAlphaChar +type DigitChar = '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' +type LowerAlphaChar = + | 'a' + | 'b' + | 'c' + | 'd' + | 'e' + | 'f' + | 'g' + | 'h' + | 'i' + | 'j' + | 'k' + | 'l' + | 'm' + | 'n' + | 'o' + | 'p' + | 'q' + | 'r' + | 's' + | 't' + | 'u' + | 'v' + | 'w' + | 'x' + | 'y' + | 'z' + +type AsDidMethodInternal< + S, + Acc extends string, +> = S extends `${infer H}${infer T}` + ? H extends AlphanumericChar + ? AsDidMethodInternal + : never + : Acc extends '' + ? never + : Acc + +/** + * DID Method-name check function. + * + * Check if the input is a valid DID method name, at the position between + * `start` (inclusive) and `end` (exclusive). + */ +export function checkDidMethod( + input: string, + start = 0, + end = input.length, +): void { + if ( + !Number.isFinite(end) || + !Number.isFinite(start) || + end < start || + end > input.length + ) { + throw new TypeError('Invalid start or end position') + } + if (end === start) { + throw new InvalidDidError(input, `Empty method name`) + } + + let c: number + for (let i = start; i < end; i++) { + c = input.charCodeAt(i) + if ( + (c < 0x61 || c > 0x7a) && // a-z + (c < 0x30 || c > 0x39) // 0-9 + ) { + throw new InvalidDidError( + input, + `Invalid character at position ${i} in DID method name`, + ) + } + } +} + +/** + * This method assumes the input is a valid Did + */ +export function extractDidMethod(did: D) { + const msidSep = did.indexOf(':', DID_PREFIX_LENGTH) + const method = did.slice(DID_PREFIX_LENGTH, msidSep) + return method as D extends Did ? M : string +} + +/** + * DID Method-specific identifier check function. + * + * Check if the input is a valid DID method-specific identifier, at the position + * between `start` (inclusive) and `end` (exclusive). + */ +export function checkDidMsid( + input: string, + start = 0, + end = input.length, +): void { + if ( + !Number.isFinite(end) || + !Number.isFinite(start) || + end < start || + end > input.length + ) { + throw new TypeError('Invalid start or end position') + } + if (end === start) { + throw new InvalidDidError(input, `DID method-specific id must not be empty`) + } + + let c: number + for (let i = start; i < end; i++) { + c = input.charCodeAt(i) + + // Check for frequent chars first + if ( + (c < 0x61 || c > 0x7a) && // a-z + (c < 0x41 || c > 0x5a) && // A-Z + (c < 0x30 || c > 0x39) && // 0-9 + c !== 0x2e && // . + c !== 0x2d && // - + c !== 0x5f // _ + ) { + // Less frequent chars are checked here + + // ":" + if (c === 0x3a) { + if (i === end - 1) { + throw new InvalidDidError(input, `DID cannot end with ":"`) + } + continue + } + + // pct-encoded + if (c === 0x25) { + c = input.charCodeAt(++i) + if ((c < 0x30 || c > 0x39) && (c < 0x41 || c > 0x46)) { + throw new InvalidDidError( + input, + `Invalid pct-encoded character at position ${i}`, + ) + } + + c = input.charCodeAt(++i) + if ((c < 0x30 || c > 0x39) && (c < 0x41 || c > 0x46)) { + throw new InvalidDidError( + input, + `Invalid pct-encoded character at position ${i}`, + ) + } + + // There must always be 2 HEXDIG after a "%" + if (i >= end) { + throw new InvalidDidError( + input, + `Incomplete pct-encoded character at position ${i - 2}`, + ) + } + + continue + } + + throw new InvalidDidError( + input, + `Disallowed character in DID at position ${i}`, + ) + } + } +} + +export function checkDid(input: unknown): asserts input is Did { + if (typeof input !== 'string') { + throw new InvalidDidError(typeof input, `DID must be a string`) + } + + const { length } = input + if (length > 2048) { + throw new InvalidDidError(input, `DID is too long (2048 chars max)`) + } + + if (!input.startsWith(DID_PREFIX)) { + throw new InvalidDidError(input, `DID requires "${DID_PREFIX}" prefix`) + } + + const idSep = input.indexOf(':', DID_PREFIX_LENGTH) + if (idSep === -1) { + throw new InvalidDidError(input, `Missing colon after method name`) + } + + checkDidMethod(input, DID_PREFIX_LENGTH, idSep) + checkDidMsid(input, idSep + 1, length) +} + +export function isDid(input: unknown): input is Did { + try { + checkDid(input) + return true + } catch (err) { + if (err instanceof DidError) { + return false + } + throw err + } +} + +export const didSchema = z + .string() + .superRefine((value: string, ctx: z.RefinementCtx): value is Did => { + try { + checkDid(value) + return true + } catch (err) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: err instanceof Error ? err.message : 'Unexpected error', + }) + return false + } + }) diff --git a/packages/did/src/index.ts b/packages/did/src/index.ts new file mode 100644 index 00000000000..26c57de9942 --- /dev/null +++ b/packages/did/src/index.ts @@ -0,0 +1,4 @@ +export * from './did-document.js' +export * from './did-error.js' +export * from './did.js' +export * from './methods.js' diff --git a/packages/did/src/methods.ts b/packages/did/src/methods.ts new file mode 100644 index 00000000000..9fb8254af2d --- /dev/null +++ b/packages/did/src/methods.ts @@ -0,0 +1,2 @@ +export * from './methods/plc.js' +export * from './methods/web.js' diff --git a/packages/did/src/methods/plc.ts b/packages/did/src/methods/plc.ts new file mode 100644 index 00000000000..f1e7138b23d --- /dev/null +++ b/packages/did/src/methods/plc.ts @@ -0,0 +1,40 @@ +import { InvalidDidError } from '../did-error.js' +import { Did } from '../did.js' + +const DID_PLC_PREFIX = `did:plc:` +const DID_PLC_PREFIX_LENGTH = DID_PLC_PREFIX.length +const DID_PLC_LENGTH = 32 + +export { DID_PLC_PREFIX } + +export function isDidPlc(input: unknown): input is Did<'plc'> { + if (typeof input !== 'string') return false + try { + checkDidPlc(input) + return true + } catch { + return false + } +} + +export function checkDidPlc(input: string): asserts input is Did<'plc'> { + if (input.length !== DID_PLC_LENGTH) { + throw new InvalidDidError( + input, + `did:plc must be ${DID_PLC_LENGTH} characters long`, + ) + } + + if (!input.startsWith(DID_PLC_PREFIX)) { + throw new InvalidDidError(input, `Invalid did:plc prefix`) + } + + let c: number + for (let i = DID_PLC_PREFIX_LENGTH; i < DID_PLC_LENGTH; i++) { + c = input.charCodeAt(i) + // Base32 encoding ([a-z2-7]) + if ((c < 0x61 || c > 0x7a) && (c < 0x32 || c > 0x37)) { + throw new InvalidDidError(input, `Invalid character at position ${i}`) + } + } +} diff --git a/packages/did/src/methods/web.ts b/packages/did/src/methods/web.ts new file mode 100644 index 00000000000..05a187c53c0 --- /dev/null +++ b/packages/did/src/methods/web.ts @@ -0,0 +1,78 @@ +import { InvalidDidError } from '../did-error.js' +import { Did, checkDidMsid } from '../did.js' + +export const DID_WEB_PREFIX = `did:web:` + +/** + * This function checks if the input is a valid Web DID, as per DID spec. + * ATPROTO adds additional constraints to allowed DID values for the `did:web` + * method. Use {@link isAtprotoDidWeb} if that's what you need. + */ +export function isDidWeb(input: unknown): input is Did<'web'> { + if (typeof input !== 'string') return false + try { + didWebToUrl(input) + return true + } catch { + return false + } +} + +/** + * @see {@link https://atproto.com/specs/did#blessed-did-methods} + */ +export function isAtprotoDidWeb(input: unknown): input is Did<'web'> { + // Optimization: make cheap checks first + if (typeof input !== 'string') { + return false + } + + // Path are not allowed + if (input.includes(':', DID_WEB_PREFIX.length)) { + return false + } + + // Port numbers are not allowed, except for localhost + if ( + input.includes('%3A', DID_WEB_PREFIX.length) && + !input.startsWith('did:web:localhost%3A') + ) { + return false + } + + return isDidWeb(input) +} + +export function checkDidWeb(input: string): asserts input is Did<'web'> { + didWebToUrl(input) +} + +export function didWebToUrl(did: string): URL { + if (!did.startsWith(DID_WEB_PREFIX)) { + throw new InvalidDidError(did, `did:web must start with ${DID_WEB_PREFIX}`) + } + + if (did.charAt(DID_WEB_PREFIX.length) === ':') { + throw new InvalidDidError(did, 'did:web MSID must not start with a colon') + } + + // Make sure every char is valid (per DID spec) + checkDidMsid(did, DID_WEB_PREFIX.length) + + try { + const msid = did.slice(DID_WEB_PREFIX.length) + const parts = msid.split(':').map(decodeURIComponent) + return new URL(`https://${parts.join('/')}`) + } catch (cause) { + throw new InvalidDidError(did, 'Invalid Web DID', cause) + } +} + +export function urlToDidWeb(url: URL): Did<'web'> { + const path = + url.pathname === '/' + ? '' + : url.pathname.slice(1).split('/').map(encodeURIComponent).join(':') + + return `did:web:${encodeURIComponent(url.host)}${path ? `:${path}` : ''}` +} diff --git a/packages/did/tsconfig.build.json b/packages/did/tsconfig.build.json new file mode 100644 index 00000000000..436d8ecb628 --- /dev/null +++ b/packages/did/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": "../../tsconfig/isomorphic.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/did/tsconfig.json b/packages/did/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/did/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/did-resolver/package.json b/packages/internal/did-resolver/package.json new file mode 100644 index 00000000000..ac880e82ff8 --- /dev/null +++ b/packages/internal/did-resolver/package.json @@ -0,0 +1,40 @@ +{ + "name": "@atproto-labs/did-resolver", + "version": "0.0.1", + "license": "MIT", + "description": "DID resolution and verification library", + "keywords": [ + "atproto", + "did", + "resolver" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/did-resolver" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/fetch": "workspace:*", + "@atproto-labs/pipe": "workspace:*", + "@atproto-labs/simple-store": "workspace:*", + "@atproto-labs/simple-store-memory": "workspace:*", + "@atproto/did": "workspace:*", + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/internal/did-resolver/src/did-cache-memory.ts b/packages/internal/did-resolver/src/did-cache-memory.ts new file mode 100644 index 00000000000..1c05bb9739d --- /dev/null +++ b/packages/internal/did-resolver/src/did-cache-memory.ts @@ -0,0 +1,25 @@ +import { Did, DidDocument } from '@atproto/did' +import { + SimpleStoreMemory, + SimpleStoreMemoryOptions, +} from '@atproto-labs/simple-store-memory' + +import { DidCache } from './did-cache.js' + +const DEFAULT_TTL = 3600 * 1000 // 1 hour +const DEFAULT_MAX_SIZE = 50 * 1024 * 1024 // ~50MB + +export type DidCacheMemoryOptions = SimpleStoreMemoryOptions + +export class DidCacheMemory + extends SimpleStoreMemory + implements DidCache +{ + constructor(options?: DidCacheMemoryOptions) { + super( + options?.max == null + ? { ttl: DEFAULT_TTL, maxSize: DEFAULT_MAX_SIZE, ...options } + : { ttl: DEFAULT_TTL, ...options }, + ) + } +} diff --git a/packages/internal/did-resolver/src/did-cache.ts b/packages/internal/did-resolver/src/did-cache.ts new file mode 100644 index 00000000000..4a379e532dd --- /dev/null +++ b/packages/internal/did-resolver/src/did-cache.ts @@ -0,0 +1,31 @@ +import { CachedGetter, SimpleStore } from '@atproto-labs/simple-store' +import { Did, DidDocument } from '@atproto/did' + +import { DidCacheMemory } from './did-cache-memory.js' +import { DidMethod, ResolveOptions } from './did-method.js' +import { DidResolver, ResolvedDocument } from './did-resolver.js' + +export type { DidMethod, ResolveOptions, ResolvedDocument } + +export type DidCache = SimpleStore + +export type DidResolverCachedOptions = { cache?: DidCache } + +export class DidResolverCached + implements DidResolver +{ + protected readonly getter: CachedGetter + constructor( + resolver: DidResolver, + cache: DidCache = new DidCacheMemory(), + ) { + this.getter = new CachedGetter( + (did, options) => resolver.resolve(did, options), + cache, + ) + } + + public async resolve(did: D, options?: ResolveOptions) { + return this.getter.get(did, options) as Promise> + } +} diff --git a/packages/internal/did-resolver/src/did-method.ts b/packages/internal/did-resolver/src/did-method.ts new file mode 100644 index 00000000000..3b45ec1bd06 --- /dev/null +++ b/packages/internal/did-resolver/src/did-method.ts @@ -0,0 +1,17 @@ +import { Did, DidDocument } from '@atproto/did' + +export type ResolveOptions = { + signal?: AbortSignal + noCache?: boolean +} + +export interface DidMethod { + resolve: ( + did: Did, + options?: ResolveOptions, + ) => DidDocument | PromiseLike +} + +export type DidMethods = { + [K in M]: DidMethod +} diff --git a/packages/internal/did-resolver/src/did-resolver-base.ts b/packages/internal/did-resolver/src/did-resolver-base.ts new file mode 100644 index 00000000000..95b61e43693 --- /dev/null +++ b/packages/internal/did-resolver/src/did-resolver-base.ts @@ -0,0 +1,66 @@ +import { FetchRequestError } from '@atproto-labs/fetch' +import { Did, DidError, extractDidMethod } from '@atproto/did' +import { ZodError } from 'zod' + +import { DidMethod, DidMethods, ResolveOptions } from './did-method.js' +import { DidResolver, ResolvedDocument } from './did-resolver.js' + +export type { DidMethod, ResolveOptions, ResolvedDocument } + +export class DidResolverBase + implements DidResolver +{ + protected readonly methods: Map> + + constructor(methods: DidMethods) { + this.methods = new Map(Object.entries(methods)) + } + + async resolve( + did: D, + options?: ResolveOptions, + ): Promise> { + options?.signal?.throwIfAborted() + + const method = extractDidMethod(did) + const resolver = this.methods.get(method) + if (!resolver) { + throw new DidError( + did, + `Unsupported DID method`, + 'did-method-invalid', + 400, + ) + } + + try { + const document = await resolver.resolve(did as Did, options) + if (document.id !== did) { + throw new DidError( + did, + `DID document id (${document.id}) does not match DID`, + 'did-document-id-mismatch', + 400, + ) + } + + return document as ResolvedDocument + } catch (err) { + if (err instanceof FetchRequestError) { + throw new DidError(did, err.message, 'did-fetch-error', 400, err) + } + + if (err instanceof ZodError) { + throw new DidError( + did, + err.message, + 'did-document-format-error', + 503, + err, + ) + } + + throw DidError.from(err, did) + } + } +} diff --git a/packages/internal/did-resolver/src/did-resolver-common.ts b/packages/internal/did-resolver/src/did-resolver-common.ts new file mode 100644 index 00000000000..f0be8de1943 --- /dev/null +++ b/packages/internal/did-resolver/src/did-resolver-common.ts @@ -0,0 +1,20 @@ +import { DidResolverBase } from './did-resolver-base.js' +import { DidPlcMethod, DidPlcMethodOptions } from './methods/plc.js' +import { DidWebMethod, DidWebMethodOptions } from './methods/web.js' +import { Simplify } from './util.js' + +export type DidResolverCommonOptions = Simplify< + DidPlcMethodOptions & DidWebMethodOptions +> + +export class DidResolverCommon + extends DidResolverBase<'plc' | 'web'> + implements DidResolverBase<'plc' | 'web'> +{ + constructor(options?: DidResolverCommonOptions) { + super({ + plc: new DidPlcMethod(options), + web: new DidWebMethod(options), + }) + } +} diff --git a/packages/internal/did-resolver/src/did-resolver.ts b/packages/internal/did-resolver/src/did-resolver.ts new file mode 100644 index 00000000000..d389c36578c --- /dev/null +++ b/packages/internal/did-resolver/src/did-resolver.ts @@ -0,0 +1,15 @@ +import { Did, DidDocument } from '@atproto/did' + +import { ResolveOptions } from './did-method.js' + +export type ResolvedDocument = + D extends Did + ? DidDocument + : never + +export interface DidResolver { + resolve( + did: D, + options?: ResolveOptions, + ): Promise> +} diff --git a/packages/internal/did-resolver/src/index.ts b/packages/internal/did-resolver/src/index.ts new file mode 100644 index 00000000000..fb90e6ab373 --- /dev/null +++ b/packages/internal/did-resolver/src/index.ts @@ -0,0 +1,9 @@ +export * from '@atproto/did' + +export * from './did-cache-memory.js' +export * from './did-cache.js' +export * from './did-method.js' +export * from './did-resolver-common.js' +export * from './did-resolver.js' +export * from './methods.js' +export * from './util.js' diff --git a/packages/internal/did-resolver/src/methods.ts b/packages/internal/did-resolver/src/methods.ts new file mode 100644 index 00000000000..9fb8254af2d --- /dev/null +++ b/packages/internal/did-resolver/src/methods.ts @@ -0,0 +1,2 @@ +export * from './methods/plc.js' +export * from './methods/web.js' diff --git a/packages/internal/did-resolver/src/methods/plc.ts b/packages/internal/did-resolver/src/methods/plc.ts new file mode 100644 index 00000000000..c76118bd202 --- /dev/null +++ b/packages/internal/did-resolver/src/methods/plc.ts @@ -0,0 +1,56 @@ +import { + Fetch, + bindFetch, + fetchJsonProcessor, + fetchJsonZodProcessor, + fetchOkProcessor, +} from '@atproto-labs/fetch' +import { pipe } from '@atproto-labs/pipe' +import { Did, checkDidPlc, didDocumentValidator } from '@atproto/did' + +import { DidMethod, ResolveOptions } from '../did-method.js' + +const fetchSuccessHandler = pipe( + fetchOkProcessor(), + fetchJsonProcessor(/^application\/(did\+ld\+)?json$/), + fetchJsonZodProcessor(didDocumentValidator), +) + +export type DidPlcMethodOptions = { + /** + * @default globalThis.fetch + */ + fetch?: Fetch + + /** + * @default 'https://plc.directory/' + */ + plcDirectoryUrl?: string | URL +} + +export class DidPlcMethod implements DidMethod<'plc'> { + protected readonly fetch: Fetch + + public readonly plcDirectoryUrl: URL + + constructor(options?: DidPlcMethodOptions) { + this.plcDirectoryUrl = new URL( + options?.plcDirectoryUrl || 'https://plc.directory/', + ) + this.fetch = bindFetch(options?.fetch) + } + + async resolve(did: Did<'plc'>, options?: ResolveOptions) { + // Although the did should start with `did:plc:` (thanks to typings), we + // should still check if the msid is valid. + checkDidPlc(did) + + const url = new URL(`/${did}`, this.plcDirectoryUrl) + + return this.fetch(url, { + redirect: 'error', + headers: { accept: 'application/did+ld+json,application/json' }, + signal: options?.signal, + }).then(fetchSuccessHandler) + } +} diff --git a/packages/internal/did-resolver/src/methods/web.ts b/packages/internal/did-resolver/src/methods/web.ts new file mode 100644 index 00000000000..2e1798daa38 --- /dev/null +++ b/packages/internal/did-resolver/src/methods/web.ts @@ -0,0 +1,58 @@ +import { + bindFetch, + Fetch, + fetchJsonProcessor, + fetchJsonZodProcessor, + fetchOkProcessor, +} from '@atproto-labs/fetch' +import { pipe } from '@atproto-labs/pipe' +import { Did, didDocumentValidator, didWebToUrl } from '@atproto/did' + +import { DidMethod, ResolveOptions } from '../did-method.js' + +const fetchSuccessHandler = pipe( + fetchOkProcessor(), + fetchJsonProcessor(/^application\/(did\+ld\+)?json$/), + fetchJsonZodProcessor(didDocumentValidator), +) + +export type DidWebMethodOptions = { + fetch?: Fetch +} + +export class DidWebMethod implements DidMethod<'web'> { + protected readonly fetch: Fetch + + constructor({ fetch = globalThis.fetch }: DidWebMethodOptions = {}) { + this.fetch = bindFetch(fetch) + } + + async resolve(did: Did<'web'>, options?: ResolveOptions) { + const didDocumentUrl = buildDidWebDocumentUrl(did) + + return this.fetch(didDocumentUrl, { + redirect: 'error', + headers: { accept: 'application/did+ld+json,application/json' }, + signal: options?.signal, + }).then(fetchSuccessHandler) + } +} + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc8615} + * @see {@link https://w3c-ccg.github.io/did-method-web/#create-register} + */ +export function buildDidWebDocumentUrl(did: Did<'web'>) { + const url = didWebToUrl(did) // Will throw if the DID is invalid + + // Note: DID cannot end with an `:`, so they cannot end with a `/`. This is + // true unless when there is no path at all, in which case the URL constructor + // will set the pathname to `/`. + + // https://w3c-ccg.github.io/did-method-web/#read-resolve + if (url.pathname === '/') { + return new URL(`/.well-known/did.json`, url) + } else { + return new URL(`${url.pathname}/did.json`, url) + } +} diff --git a/packages/internal/did-resolver/src/util.ts b/packages/internal/did-resolver/src/util.ts new file mode 100644 index 00000000000..9b0a81eb26b --- /dev/null +++ b/packages/internal/did-resolver/src/util.ts @@ -0,0 +1 @@ +export type Simplify = { [K in keyof T]: T[K] } & NonNullable diff --git a/packages/internal/did-resolver/tsconfig.build.json b/packages/internal/did-resolver/tsconfig.build.json new file mode 100644 index 00000000000..0e5790a2e3f --- /dev/null +++ b/packages/internal/did-resolver/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src/**/*.ts"] +} diff --git a/packages/internal/did-resolver/tsconfig.json b/packages/internal/did-resolver/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/did-resolver/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/fetch-node/package.json b/packages/internal/fetch-node/package.json new file mode 100644 index 00000000000..59747efee74 --- /dev/null +++ b/packages/internal/fetch-node/package.json @@ -0,0 +1,40 @@ +{ + "name": "@atproto-labs/fetch-node", + "version": "0.0.1", + "license": "MIT", + "description": "SSRF protection for fetch() in Node.js", + "keywords": [ + "atproto", + "fetch", + "node" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/fetch-node" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/fetch": "workspace:*", + "@atproto-labs/pipe": "workspace:*", + "ipaddr.js": "^2.1.0", + "psl": "^1.9.0", + "undici": "^6.14.1" + }, + "devDependencies": { + "@types/psl": "1.1.3", + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/internal/fetch-node/src/index.ts b/packages/internal/fetch-node/src/index.ts new file mode 100644 index 00000000000..b7022c681b9 --- /dev/null +++ b/packages/internal/fetch-node/src/index.ts @@ -0,0 +1,4 @@ +export * from '@atproto-labs/fetch' + +export * from './safe.js' +export * from './ssrf.js' diff --git a/packages/internal/fetch-node/src/safe.ts b/packages/internal/fetch-node/src/safe.ts new file mode 100644 index 00000000000..b64915f2e55 --- /dev/null +++ b/packages/internal/fetch-node/src/safe.ts @@ -0,0 +1,78 @@ +import { + DEFAULT_FORBIDDEN_DOMAIN_NAMES, + Fetch, + fetchMaxSizeProcessor, + forbiddenDomainNameRequestTransform, + protocolCheckRequestTransform, + requireHostHeaderTranform, + timedFetch, + toRequestTransformer, +} from '@atproto-labs/fetch' +import { pipe } from '@atproto-labs/pipe' + +import { ssrfFetchWrap } from './ssrf.js' + +export type SafeFetchWrapOptions = NonNullable< + Parameters[0] +> + +/** + * Wrap a fetch function with safety checks so that it can be safely used + * with user provided input (URL). + */ +export function safeFetchWrap({ + fetch = globalThis.fetch as Fetch, + responseMaxSize = 512 * 1024, // 512kB + allowHttp = false, + allowData = false, + ssrfProtection = true, + timeout = 10e3, + forbiddenDomainNames = DEFAULT_FORBIDDEN_DOMAIN_NAMES as Iterable, +} = {}): Fetch { + return toRequestTransformer( + pipe( + /** + * Prevent using http:, file: or data: protocols. + */ + protocolCheckRequestTransform( + ['https:'] + .concat(allowHttp ? ['http:'] : []) + .concat(allowData ? ['data:'] : []), + ), + + /** + * Only requests that will be issued with a "Host" header are allowed. + */ + requireHostHeaderTranform(), + + /** + * Disallow fetching from domains we know are not atproto/OIDC client + * implementation. Note that other domains can be blocked by providing a + * custom fetch function combined with another + * forbiddenDomainNameRequestTransform. + */ + forbiddenDomainNameRequestTransform(forbiddenDomainNames), + + /** + * Since we will be fetching from the network based on user provided + * input, let's mitigate resource exhaustion attacks by setting a timeout. + */ + timedFetch( + timeout, + + /** + * Since we will be fetching from the network based on user provided + * input, we need to make sure that the request is not vulnerable to SSRF + * attacks. + */ + ssrfProtection ? ssrfFetchWrap({ fetch }) : fetch, + ), + + /** + * Since we will be fetching user owned data, we need to make sure that an + * attacker cannot force us to download a large amounts of data. + */ + fetchMaxSizeProcessor(responseMaxSize), + ), + ) +} diff --git a/packages/internal/fetch-node/src/ssrf.ts b/packages/internal/fetch-node/src/ssrf.ts new file mode 100644 index 00000000000..7879ad9782b --- /dev/null +++ b/packages/internal/fetch-node/src/ssrf.ts @@ -0,0 +1,214 @@ +import dns, { LookupAddress } from 'node:dns' +import { LookupFunction } from 'node:net' + +import { + Fetch, + FetchContext, + FetchRequestError, + toRequestTransformer, +} from '@atproto-labs/fetch' +import ipaddr from 'ipaddr.js' +import { isValid as isValidDomain } from 'psl' +import { Agent } from 'undici' + +const { IPv4, IPv6 } = ipaddr + +const [NODE_VERSION] = process.versions.node.split('.').map(Number) + +export type SsrfFetchWrapOptions = { + allowCustomPort?: boolean + allowUnknownTld?: boolean + fetch?: Fetch +} + +/** + * @see {@link https://owasp.org/Top10/A10_2021-Server-Side_Request_Forgery_%28SSRF%29/} + */ +export function ssrfFetchWrap({ + allowCustomPort = false, + allowUnknownTld = false, + fetch = globalThis.fetch, +}: SsrfFetchWrapOptions): Fetch { + const ssrfAgent = new Agent({ connect: { lookup } }) + + return toRequestTransformer(async function ( + this: C, + request, + ): Promise { + const url = new URL(request.url) + + if (url.protocol === 'data:') { + // No SSRF issue + return fetch.call(this, request) + } + + if (url.protocol === 'http:' || url.protocol === 'https:') { + // @ts-expect-error non-standard option + if (request.dispatcher) { + throw new FetchRequestError( + request, + 500, + 'SSRF protection cannot be used with a custom request dispatcher', + ) + } + + // Check port (OWASP) + if (url.port && !allowCustomPort) { + throw new FetchRequestError( + request, + 400, + 'Request port must be omitted or standard when SSRF is enabled', + ) + } + + // Disable HTTP redirections (OWASP) + if (request.redirect === 'follow') { + throw new FetchRequestError( + request, + 500, + 'Request redirect must be "error" or "manual" when SSRF is enabled', + ) + } + + // If the hostname is an IP address, it must be a unicast address. + const ip = parseIpHostname(url.hostname) + if (ip) { + if (ip.range() !== 'unicast') { + throw new FetchRequestError( + request, + 400, + 'Hostname resolved to non-unicast address', + ) + } + // No additional check required + return fetch.call(this, request) + } + + if (allowUnknownTld !== true && !isValidDomain(url.hostname)) { + throw new FetchRequestError( + request, + 400, + 'Hostname is not a public domain', + ) + } + + // Else hostname is a domain name, use DNS lookup to check if it resolves + // to a unicast address + + if (NODE_VERSION < 21) { + // Note: due to the issue nodejs/undici#2828 (fixed in undici >=6.7.0, + // Node >=21), the "dispatcher" property of the request object will not + // be used by fetch(). As a workaround, we pass the dispatcher as second + // argument to fetch() here, and make sure it is used (which might not be + // the case if a custom fetch() function is used). + + if (fetch === globalThis.fetch) { + // If the global fetch function is used, we can pass the dispatcher + // singleton directly to the fetch function as we know it will be + // used. + + // @ts-expect-error non-standard option + return fetch.call(this, request, { dispatcher: ssrfAgent }) + } + + let didLookup = false + const dispatcher = new Agent({ + connect: { + lookup(...args) { + didLookup = true + lookup(...args) + }, + }, + }) + + try { + // @ts-expect-error non-standard option + return await fetch.call(this, request, { dispatcher }) + } finally { + // Free resources (we cannot await here since the response was not + // consumed yet). + void dispatcher.close().catch((err) => { + // No biggie, but let's still log it + console.warn('Failed to close dispatcher', err) + }) + + if (!didLookup) { + // If you encounter this error, either upgrade to Node.js >=21 or + // make sure that the requestInit object is passed as second + // argument to the global fetch function. + + // eslint-disable-next-line no-unsafe-finally + throw new FetchRequestError( + request, + 500, + 'Unable to enforce SSRF protection', + ) + } + } + } + + // @ts-expect-error non-standard option + return fetch(new Request(request, { dispatcher: ssrfAgent })) + } + + // blob: about: file: all should be rejected + throw new FetchRequestError( + request, + 400, + `Forbidden protocol "${url.protocol}"`, + ) + }) +} + +function parseIpHostname( + hostname: string, +): ipaddr.IPv4 | ipaddr.IPv6 | undefined { + if (IPv4.isIPv4(hostname)) { + return IPv4.parse(hostname) + } + + if (hostname.startsWith('[') && hostname.endsWith(']')) { + return IPv6.parse(hostname.slice(1, -1)) + } + + return undefined +} + +function lookup( + hostname: string, + options: dns.LookupOptions, + callback: Parameters[2], +) { + dns.lookup(hostname, options, (err, address, family) => { + if (err) { + callback(err, address, family) + } else { + const ips = Array.isArray(address) + ? address.map(parseLookupAddress) + : [parseLookupAddress({ address, family })] + + if (ips.some((ip) => ip.range() !== 'unicast')) { + callback( + new Error('Hostname resolved to non-unicast address'), + address, + family, + ) + } else { + callback(null, address, family) + } + } + }) +} + +function parseLookupAddress({ + address, + family, +}: LookupAddress): ipaddr.IPv4 | ipaddr.IPv6 { + const ip = family === 4 ? IPv4.parse(address) : IPv6.parse(address) + + if (ip instanceof IPv6 && ip.isIPv4MappedAddress()) { + return ip.toIPv4Address() + } else { + return ip + } +} diff --git a/packages/internal/fetch-node/tsconfig.build.json b/packages/internal/fetch-node/tsconfig.build.json new file mode 100644 index 00000000000..ea00aba058e --- /dev/null +++ b/packages/internal/fetch-node/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/node.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/internal/fetch-node/tsconfig.json b/packages/internal/fetch-node/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/fetch-node/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/fetch/package.json b/packages/internal/fetch/package.json new file mode 100644 index 00000000000..cf6ae61a05b --- /dev/null +++ b/packages/internal/fetch/package.json @@ -0,0 +1,37 @@ +{ + "name": "@atproto-labs/fetch", + "version": "0.0.1", + "license": "MIT", + "description": "Isomorphic wrapper utilities for fetch API", + "keywords": [ + "atproto", + "fetch" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/fetch" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/pipe": "workspace:*" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "optionalDependencies": { + "zod": "^3.23.8" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/internal/fetch/src/fetch-error.ts b/packages/internal/fetch/src/fetch-error.ts new file mode 100644 index 00000000000..d73e818ea94 --- /dev/null +++ b/packages/internal/fetch/src/fetch-error.ts @@ -0,0 +1,59 @@ +export class FetchError extends Error { + public readonly statusCode: number + + constructor(statusCode?: number, message?: string, options?: ErrorOptions) { + if (statusCode == null || !message) { + const info = extractInfo(extractRootCause(options?.cause)) + statusCode = statusCode ?? info[0] + message = message || info[1] + } + + super(message, options) + + this.statusCode = statusCode + } +} + +function extractRootCause(err: unknown): unknown { + // Unwrap the Network error from undici (i.e. Node's internal fetch() implementation) + // https://github.com/nodejs/undici/blob/3274c975947ce11a08508743df026f73598bfead/lib/web/fetch/index.js#L223-L228 + if ( + err instanceof TypeError && + err.message === 'fetch failed' && + err.cause !== undefined + ) { + return err.cause + } + + return err +} + +function extractInfo(err: unknown): [statusCode: number, message: string] { + if (typeof err === 'string' && err.length > 0) { + return [500, err] + } + + if (!(err instanceof Error)) { + return [500, 'Failed to fetch'] + } + + const code = err['code'] + if (typeof code === 'string') { + switch (true) { + case code === 'ENOTFOUND': + return [400, 'Invalid hostname'] + case code === 'ECONNREFUSED': + return [502, 'Connection refused'] + case code === 'DEPTH_ZERO_SELF_SIGNED_CERT': + return [502, 'Self-signed certificate'] + case code.startsWith('ERR_TLS'): + return [502, 'TLS error'] + case code.startsWith('ECONN'): + return [502, 'Connection error'] + default: + return [500, `${code} error`] + } + } + + return [500, err.message] +} diff --git a/packages/internal/fetch/src/fetch-request.ts b/packages/internal/fetch/src/fetch-request.ts new file mode 100644 index 00000000000..6d84ec8cb2a --- /dev/null +++ b/packages/internal/fetch/src/fetch-request.ts @@ -0,0 +1,118 @@ +import { FetchError } from './fetch-error.js' +import { asRequest } from './fetch.js' +import { isIp } from './util.js' + +export class FetchRequestError extends FetchError { + constructor( + public readonly request: Request, + statusCode?: number, + message?: string, + options?: ErrorOptions, + ) { + super(statusCode, message, options) + } + + static from(request: Request, cause: unknown): FetchRequestError { + if (cause instanceof FetchRequestError) return cause + return new FetchRequestError(request, undefined, undefined, { cause }) + } +} + +const extractUrl = (input: Request | string | URL) => + typeof input === 'string' + ? new URL(input) + : input instanceof URL + ? input + : new URL(input.url) + +export function protocolCheckRequestTransform(protocols: Iterable) { + const allowedProtocols = new Set(protocols) + + return (input: Request | string | URL, init?: RequestInit) => { + const { protocol } = extractUrl(input) + + const request = asRequest(input, init) + + if (!allowedProtocols.has(protocol)) { + throw new FetchRequestError( + request, + 400, + `"${protocol}" protocol is not allowed`, + ) + } + + return request + } +} + +export function requireHostHeaderTranform() { + return (input: Request | string | URL, init?: RequestInit) => { + // Note that fetch() will automatically add the Host header from the URL and + // discard any Host header manually set in the request. + + const { protocol, hostname } = extractUrl(input) + + const request = asRequest(input, init) + + // "Host" header only makes sense in the context of an HTTP request + if (protocol !== 'http:' && protocol !== 'https:') { + throw new FetchRequestError( + request, + 400, + `"${protocol}" requests are not allowed`, + ) + } + + if (!hostname || isIp(hostname)) { + throw new FetchRequestError(request, 400, 'Invalid hostname') + } + + return request + } +} + +export const DEFAULT_FORBIDDEN_DOMAIN_NAMES = [ + 'example.com', + '*.example.com', + 'example.org', + '*.example.org', + 'example.net', + '*.example.net', + 'googleusercontent.com', + '*.googleusercontent.com', +] + +export function forbiddenDomainNameRequestTransform( + denyList: Iterable = DEFAULT_FORBIDDEN_DOMAIN_NAMES, +) { + const denySet = new Set(denyList) + + // Optimization: if no forbidden domain names are provided, we can skip the + // check entirely. + if (denySet.size === 0) { + return async (request) => request + } + + return async (input: Request | string | URL, init?: RequestInit) => { + const { hostname } = extractUrl(input) + + const request = asRequest(input, init) + + // Full domain name check + if (denySet.has(hostname)) { + throw new FetchRequestError(request, 403, 'Forbidden hostname') + } + + // Sub domain name check + let curDot = hostname.indexOf('.') + while (curDot !== -1) { + const subdomain = hostname.slice(curDot + 1) + if (denySet.has(`*.${subdomain}`)) { + throw new FetchRequestError(request, 403, 'Forbidden hostname') + } + curDot = hostname.indexOf('.', curDot + 1) + } + + return request + } +} diff --git a/packages/internal/fetch/src/fetch-response.ts b/packages/internal/fetch/src/fetch-response.ts new file mode 100644 index 00000000000..36f5a06aa0f --- /dev/null +++ b/packages/internal/fetch/src/fetch-response.ts @@ -0,0 +1,278 @@ +import { Transformer, pipe } from '@atproto-labs/pipe' + +// optional dependency for typing purposes +import type { ZodTypeAny, ParseParams, TypeOf } from 'zod' + +import { FetchError } from './fetch-error.js' +import { TransformedResponse } from './transformed-response.js' +import { + Json, + MaxBytesTransformStream, + cancelBody, + ifObject, + ifString, + logCancellationError, +} from './util.js' + +export type ResponseTranformer = Transformer +export type ResponseMessageGetter = Transformer + +export class FetchResponseError extends FetchError { + constructor( + public readonly response: Response, + statusCode: number = response.status, + message: string = response.statusText, + options?: ErrorOptions, + ) { + super(statusCode, message, options) + } + + static async from( + response: Response, + customMessage: string | ResponseMessageGetter = extractResponseMessage, + statusCode = response.status, + options?: ErrorOptions, + ) { + const message = + typeof customMessage === 'string' + ? customMessage + : typeof customMessage === 'function' + ? await customMessage(response) + : undefined + + return new FetchResponseError(response, statusCode, message, options) + } +} + +const extractResponseMessage: ResponseMessageGetter = async (response) => { + const mimeType = extractMime(response) + if (!mimeType) return undefined + + try { + if (mimeType === 'text/plain') { + return await response.text() + } else if (/^application\/(?:[^+]+\+)?json$/i.test(mimeType)) { + const json: unknown = await response.json() + + if (typeof json === 'string') return json + + const errorDescription = ifString(ifObject(json)?.['error_description']) + if (errorDescription) return errorDescription + + const error = ifString(ifObject(json)?.['error']) + if (error) return error + + const message = ifString(ifObject(json)?.['message']) + if (message) return message + } + } catch { + // noop + } + + return undefined +} + +export async function peekJson( + response: Response, + maxSize = Infinity, +): Promise { + const type = extractMime(response) + if (type !== 'application/json') return undefined + checkLength(response, maxSize) + + // 1) Clone the request so we can consume the body + const clonedResponse = response.clone() + + // 2) Make sure the request's body is not too large + const limitedResponse = + response.body && maxSize < Infinity + ? new TransformedResponse( + clonedResponse, + new MaxBytesTransformStream(maxSize), + ) + : // Note: some runtimes (e.g. react-native) don't expose a body property + clonedResponse + + // 3) Parse the JSON + return limitedResponse.json() +} + +export function checkLength(response: Response, maxBytes: number) { + // Note: negation accounts for invalid value types (NaN, non numbers) + if (!(maxBytes >= 0)) { + throw new TypeError('maxBytes must be a non-negative number') + } + const length = extractLength(response) + if (length != null && length > maxBytes) { + throw new FetchResponseError(response, 502, 'Response too large') + } + return length +} + +export function extractLength(response: Response) { + const contentLength = response.headers.get('Content-Length') + if (contentLength == null) return undefined + if (!/^\d+$/.test(contentLength)) { + throw new FetchResponseError(response, 502, 'Invalid Content-Length') + } + const length = Number(contentLength) + if (!Number.isSafeInteger(length)) { + throw new FetchResponseError(response, 502, 'Content-Length too large') + } + return length +} + +export function extractMime(response: Response) { + const contentType = response.headers.get('Content-Type') + if (contentType == null) return undefined + + return contentType.split(';', 1)[0]!.trim() +} + +/** + * If the transformer results in an error, ensure that the response body is + * consumed as, in some environments (Node 👀), the response will not + * automatically be GC'd. + * + * @see {@link https://undici.nodejs.org/#/?id=garbage-collection} + * @param [onCancellationError] - Callback to handle any async body cancelling + * error. Defaults to logging the error. Do not use `null` if the request is + * cloned. + */ +export function cancelBodyOnError( + transformer: Transformer, + onCancellationError: null | ((err: unknown) => void) = logCancellationError, +): (response: Response) => Promise { + return async (response) => { + try { + return await transformer(response) + } catch (err) { + await cancelBody(response, onCancellationError ?? undefined) + throw err + } + } +} + +export function fetchOkProcessor( + customMessage?: string | ResponseMessageGetter, +): ResponseTranformer { + return cancelBodyOnError((response) => { + return fetchOkTransformer(response, customMessage) + }) +} + +export async function fetchOkTransformer( + response: Response, + customMessage?: string | ResponseMessageGetter, +) { + if (response.ok) return response + throw await FetchResponseError.from(response, customMessage) +} + +export function fetchMaxSizeProcessor(maxBytes: number): ResponseTranformer { + if (maxBytes === Infinity) return (response) => response + if (!Number.isFinite(maxBytes) || maxBytes < 0) { + throw new TypeError('maxBytes must be a 0, Infinity or a positive number') + } + return cancelBodyOnError((response) => { + return fetchResponseMaxSizeChecker(response, maxBytes) + }) +} + +export function fetchResponseMaxSizeChecker( + response: Response, + maxBytes: number, +): Response { + if (maxBytes === Infinity) return response + checkLength(response, maxBytes) + + // Some engines (react-native 👀) don't expose a body property. In that case, + // we will only rely on the Content-Length header. + if (!response.body) return response + + const transform = new MaxBytesTransformStream(maxBytes) + return new TransformedResponse(response, transform) +} + +export type MimeTypeCheckFn = (mimeType: string) => boolean +export type MimeTypeCheck = string | RegExp | MimeTypeCheckFn + +export function fetchTypeProcessor( + expectedMime: MimeTypeCheck, + contentTypeRequired = true, +): ResponseTranformer { + const isExpected: MimeTypeCheckFn = + typeof expectedMime === 'string' + ? (mimeType) => mimeType === expectedMime + : expectedMime instanceof RegExp + ? (mimeType) => expectedMime.test(mimeType) + : expectedMime + + return cancelBodyOnError((response) => { + return fetchResponseTypeChecker(response, isExpected, contentTypeRequired) + }) +} + +export async function fetchResponseTypeChecker( + response: Response, + isExpectedMime: MimeTypeCheckFn, + contentTypeRequired = true, +): Promise { + const mimeType = extractMime(response) + if (mimeType) { + if (!isExpectedMime(mimeType)) { + throw await FetchResponseError.from( + response, + `Unexpected response Content-Type (${mimeType})`, + 502, + ) + } + } else if (contentTypeRequired) { + throw await FetchResponseError.from( + response, + 'Missing response Content-Type header', + 502, + ) + } + + return response +} + +export type ParsedJsonResponse = { + response: Response + json: T +} + +export async function fetchResponseJsonTranformer( + response: Response, +): Promise> { + try { + const json = (await response.json()) as T + return { response, json } + } catch (cause) { + throw new FetchResponseError( + response, + 502, + 'Unable to parse response as JSON', + { cause }, + ) + } +} + +export function fetchJsonProcessor( + expectedMime: MimeTypeCheck = /^application\/(?:[^+]+\+)?json$/, + contentTypeRequired = true, +): Transformer> { + return pipe( + fetchTypeProcessor(expectedMime, contentTypeRequired), + cancelBodyOnError(fetchResponseJsonTranformer), + ) +} + +export function fetchJsonZodProcessor( + schema: S, + params?: Partial, +): Transformer> { + return async (jsonResponse: ParsedJsonResponse): Promise> => + schema.parseAsync(jsonResponse.json, params) +} diff --git a/packages/internal/fetch/src/fetch-wrap.ts b/packages/internal/fetch/src/fetch-wrap.ts new file mode 100644 index 00000000000..71291003b81 --- /dev/null +++ b/packages/internal/fetch/src/fetch-wrap.ts @@ -0,0 +1,122 @@ +import { FetchRequestError } from './fetch-request.js' +import { Fetch, FetchContext, toRequestTransformer } from './fetch.js' +import { TransformedResponse } from './transformed-response.js' +import { padLines, stringifyMessage } from './util.js' + +export function loggedFetch( + fetch: Fetch = globalThis.fetch, +) { + return toRequestTransformer(async function ( + this: C, + request, + ): Promise { + const requestMessage = await stringifyMessage(request) + console.info( + `> ${request.method} ${request.url}\n${padLines(requestMessage, ' ')}`, + ) + + try { + const response = await fetch.call(this, request) + + const responseMessage = await stringifyMessage(response.clone()) + console.info( + `< HTTP/1.1 ${response.status} ${response.statusText}\n${padLines(responseMessage, ' ')}`, + ) + + return response + } catch (error) { + console.error(`< Error:`, error) + + throw error + } + }) +} + +export const timedFetch = ( + timeout = 60e3, + fetch: Fetch = globalThis.fetch, +): Fetch => { + if (timeout === Infinity) return fetch + if (!Number.isFinite(timeout) || timeout <= 0) { + throw new TypeError('Timeout must be positive') + } + return toRequestTransformer(async function ( + this: C, + request, + ): Promise { + const controller = new AbortController() + const signal = controller.signal + + const abort = () => { + controller.abort() + } + const cleanup = () => { + clearTimeout(timer) + request.signal?.removeEventListener('abort', abort) + } + + const timer = setTimeout(abort, timeout) + if (typeof timer === 'object') timer.unref?.() // only on node + request.signal?.addEventListener('abort', abort) + + signal.addEventListener('abort', cleanup) + + const response = await fetch.call(this, request, { signal }) + + if (!response.body) { + cleanup() + return response + } else { + // Cleanup the timer & event listeners when the body stream is closed + const transform = new TransformStream({ flush: cleanup }) + return new TransformedResponse(response, transform) + } + }) +} + +/** + * Wraps a fetch function to bind it to a specific context, and wrap any thrown + * errors into a FetchRequestError. + * + * @example + * + * ```ts + * class MyClient { + * constructor(private fetch = globalThis.fetch) {} + * + * async get(url: string) { + * // This will generate an error, because the context used is not a + * // FetchContext (it's a MyClient instance). + * return this.fetch(url) + * } + * } + * ``` + * + * @example + * + * ```ts + * class MyClient { + * private fetch: Fetch + * + * constructor(fetch = globalThis.fetch) { + * this.fetch = bindFetch(fetch) + * } + * + * async get(url: string) { + * return this.fetch(url) // no more error + * } + * } + * ``` + */ +export function bindFetch( + fetch: Fetch = globalThis.fetch, + context: C = globalThis as C, +) { + return toRequestTransformer(async (request) => { + try { + return await fetch.call(context, request) + } catch (err) { + throw FetchRequestError.from(request, err) + } + }) +} diff --git a/packages/internal/fetch/src/fetch.ts b/packages/internal/fetch/src/fetch.ts new file mode 100644 index 00000000000..a99889f80b1 --- /dev/null +++ b/packages/internal/fetch/src/fetch.ts @@ -0,0 +1,39 @@ +import { ThisParameterOverride } from './util.js' + +export type FetchContext = void | null | typeof globalThis + +export type FetchBound = ( + input: string | URL | Request, + init?: RequestInit, +) => Promise + +// NOT using "typeof globalThis.fetch" here because "globalThis.fetch" does not +// have a "this" parameter, while runtimes do ensure that "fetch" is called with +// the correct "this" parameter (either null, undefined, or window). + +export type Fetch = ThisParameterOverride + +export type SimpleFetchBound = (input: Request) => Promise +export type SimpleFetch = ThisParameterOverride< + C, + SimpleFetchBound +> + +export function toRequestTransformer( + requestTransformer: (this: C, input: Request) => O, +): ThisParameterOverride< + C, + (input: string | URL | Request, init?: RequestInit) => O +> { + return function (this: C, input, init) { + return requestTransformer.call(this, asRequest(input, init)) + } +} + +export function asRequest( + input: string | URL | Request, + init?: RequestInit, +): Request { + if (!init && input instanceof Request) return input + return new Request(input, init) +} diff --git a/packages/internal/fetch/src/index.ts b/packages/internal/fetch/src/index.ts new file mode 100644 index 00000000000..d8fa808f3ec --- /dev/null +++ b/packages/internal/fetch/src/index.ts @@ -0,0 +1,6 @@ +export * from './fetch-error.js' +export * from './fetch-request.js' +export * from './fetch-response.js' +export * from './fetch-wrap.js' +export * from './fetch.js' +export * from './util.js' diff --git a/packages/internal/fetch/src/transformed-response.ts b/packages/internal/fetch/src/transformed-response.ts new file mode 100644 index 00000000000..f0d11389bec --- /dev/null +++ b/packages/internal/fetch/src/transformed-response.ts @@ -0,0 +1,36 @@ +export class TransformedResponse extends Response { + #response: Response + + constructor(response: Response, transform: TransformStream) { + if (!response.body) { + throw new TypeError('Response body is not available') + } + if (response.bodyUsed) { + throw new TypeError('Response body is already used') + } + + super(response.body.pipeThrough(transform), { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }) + + this.#response = response + } + + /** + * Some props can't be set through ResponseInit, so we need to proxy them + */ + get url() { + return this.#response.url + } + get redirected() { + return this.#response.redirected + } + get type() { + return this.#response.type + } + get statusText() { + return this.#response.statusText + } +} diff --git a/packages/internal/fetch/src/util.ts b/packages/internal/fetch/src/util.ts new file mode 100644 index 00000000000..ffc518a2b2a --- /dev/null +++ b/packages/internal/fetch/src/util.ts @@ -0,0 +1,169 @@ +// @TODO: Move some of these to a shared package ? + +export type JsonScalar = string | number | boolean | null +export type Json = JsonScalar | Json[] | { [key: string]: undefined | Json } +export type JsonObject = { [key: string]: Json } +export type JsonArray = Json[] + +export type ThisParameterOverride< + C, + Fn extends (...a: any) => any, +> = Fn extends (...args: infer P) => infer R + ? ((this: C, ...args: P) => R) & { + bind(context: C): (...args: P) => R + } + : never + +export function isIp(hostname: string) { + // IPv4 + if (hostname.match(/^\d+\.\d+\.\d+\.\d+$/)) return true + + // IPv6 + if (hostname.startsWith('[') && hostname.endsWith(']')) return true + + return false +} + +const plainObjectProto = Object.prototype +export const ifObject = (v: V) => { + if (typeof v === 'object' && v != null && !Array.isArray(v)) { + const proto = Object.getPrototypeOf(v) + if (proto === null || proto === plainObjectProto) { + // eslint-disable-next-line @typescript-eslint/ban-types + return v as V extends JsonScalar | JsonArray | Function | symbol + ? never + : V extends Json + ? V + : // Plain object are (mostly) safe to access using a string index + Record + } + } + + return undefined +} + +export const ifString = (v: V) => (typeof v === 'string' ? v : undefined) + +export class MaxBytesTransformStream extends TransformStream< + Uint8Array, + Uint8Array +> { + constructor(maxBytes: number) { + // Note: negation accounts for invalid value types (NaN, non numbers) + if (!(maxBytes >= 0)) { + throw new TypeError('maxBytes must be a non-negative number') + } + + let bytesRead = 0 + + super({ + transform: ( + chunk: Uint8Array, + ctrl: TransformStreamDefaultController, + ) => { + if ((bytesRead += chunk.length) <= maxBytes) { + ctrl.enqueue(chunk) + } else { + ctrl.error(new Error('Response too large')) + } + }, + }) + } +} + +const LINE_BREAK = /\r?\n/g +export function padLines(input: string, pad: string) { + if (!input) return input + return pad + input.replace(LINE_BREAK, `$&${pad}`) +} + +/** + * @param [onCancellationError] - Callback that will trigger to asynchronously + * handle any error that occurs while cancelling the response body. Providing + * this will speed up the process and avoid potential deadlocks. Defaults to + * awaiting the cancellation operation. use `"log"` to log the error. + * @see {@link https://undici.nodejs.org/#/?id=garbage-collection} + * @note awaiting this function's result, when no `onCancellationError` is + * provided, might result in a dead lock. Indeed, if the response was cloned(), + * the response.body.cancel() method will not resolve until the other response's + * body is consumed/cancelled. + * + * @example + * ```ts + * // Make sure response was not cloned, or that every cloned response was + * // consumed/cancelled before awaiting this function's result. + * await cancelBody(response) + * ``` + * @example + * ```ts + * await cancelBody(response, (err) => { + * // No biggie, let's just log the error + * console.warn('Failed to cancel response body', err) + * }) + * ``` + * @example + * ```ts + * // Will generate an "unhandledRejection" if an error occurs while cancelling + * // the response body. This will likely crash the process. + * await cancelBody(response, (err) => { throw err }) + * ``` + */ +export async function cancelBody( + body: Body, + onCancellationError?: 'log' | ((err: unknown) => void), +): Promise { + if ( + body.body && + !body.bodyUsed && + !body.body.locked && + // Support for alternative fetch implementations + typeof body.body.cancel === 'function' + ) { + if (typeof onCancellationError === 'function') { + void body.body.cancel().catch(onCancellationError) + } else if (onCancellationError === 'log') { + void body.body.cancel().catch(logCancellationError) + } else { + await body.body.cancel() + } + } +} + +export function logCancellationError(err: unknown): void { + console.warn('Failed to cancel response body', err) +} + +export async function stringifyMessage(input: Body & { headers: Headers }) { + try { + const headers = stringifyHeaders(input.headers) + const payload = await stringifyBody(input) + return headers && payload ? `${headers}\n${payload}` : headers || payload + } finally { + void cancelBody(input, 'log') + } +} + +function stringifyHeaders(headers: Headers) { + return Array.from(headers) + .map(([name, value]) => `${name}: ${value}`) + .join('\n') +} + +async function stringifyBody(body: Body) { + try { + const blob = await body.blob() + if (blob.type?.startsWith('text/')) { + const text = await blob.text() + return JSON.stringify(text) + } + + if (/application\/(?:\w+\+)?json/.test(blob.type)) { + const text = await blob.text() + return text.includes('\n') ? JSON.stringify(JSON.parse(text)) : text + } + + return `[Body size: ${blob.size}, type: ${JSON.stringify(blob.type)} ]` + } catch { + return '[Body could not be read]' + } +} diff --git a/packages/internal/fetch/tsconfig.build.json b/packages/internal/fetch/tsconfig.build.json new file mode 100644 index 00000000000..9f3c5252d4c --- /dev/null +++ b/packages/internal/fetch/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/internal/fetch/tsconfig.json b/packages/internal/fetch/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/fetch/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/handle-resolver-node/package.json b/packages/internal/handle-resolver-node/package.json new file mode 100644 index 00000000000..ba77018f5ac --- /dev/null +++ b/packages/internal/handle-resolver-node/package.json @@ -0,0 +1,39 @@ +{ + "name": "@atproto-labs/handle-resolver-node", + "version": "0.0.1", + "license": "MIT", + "description": "Node specific ATProto handle to DID resolver", + "keywords": [ + "atproto", + "oauth", + "handle", + "identity", + "node" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/handle-resolver-node" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/fetch-node": "workspace:*", + "@atproto-labs/handle-resolver": "workspace:*", + "@atproto/did": "workspace:*" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/internal/handle-resolver-node/src/atproto-handle-resolver-node.ts b/packages/internal/handle-resolver-node/src/atproto-handle-resolver-node.ts new file mode 100644 index 00000000000..626dca43886 --- /dev/null +++ b/packages/internal/handle-resolver-node/src/atproto-handle-resolver-node.ts @@ -0,0 +1,51 @@ +import { Fetch, safeFetchWrap } from '@atproto-labs/fetch-node' +import { + AtprotoHandleResolver, + HandleResolver, +} from '@atproto-labs/handle-resolver' + +import { + nodeResolveTxtDefault, + nodeResolveTxtFactory, +} from './node-resolve-txt-factory.js' + +export type AtprotoHandleResolverNodeOptions = { + /** + * List of backup nameservers to use in case the primary ones fail. Will + * default to no fallback nameservers. + */ + fallbackNameservers?: string[] + + /** + * Fetch function to use for HTTP requests. Allows customizing the request + * behavior, e.g. adding headers, setting a timeout, mocking, etc. The + * provided fetch function will be wrapped with a safeFetchWrap function that + * adds SSRF protection. + * + * @default `globalThis.fetch` + */ + fetch?: Fetch +} + +export class AtprotoHandleResolverNode + extends AtprotoHandleResolver + implements HandleResolver +{ + constructor({ + fetch = globalThis.fetch, + fallbackNameservers, + }: AtprotoHandleResolverNodeOptions = {}) { + super({ + fetch: safeFetchWrap({ + fetch, + timeout: 3000, // 3 seconds + ssrfProtection: true, + responseMaxSize: 10 * 1048, // DID are max 2048 characters, 10kb for safety + }), + resolveTxt: nodeResolveTxtDefault, + resolveTxtFallback: fallbackNameservers?.length + ? nodeResolveTxtFactory(fallbackNameservers) + : undefined, + }) + } +} diff --git a/packages/internal/handle-resolver-node/src/index.ts b/packages/internal/handle-resolver-node/src/index.ts new file mode 100644 index 00000000000..3dfc49a73ff --- /dev/null +++ b/packages/internal/handle-resolver-node/src/index.ts @@ -0,0 +1,4 @@ +// Main export +export * from './atproto-handle-resolver-node.js' +export * from './node-resolve-txt-factory.js' +export { AtprotoHandleResolverNode as default } from './atproto-handle-resolver-node.js' diff --git a/packages/internal/handle-resolver-node/src/node-resolve-txt-factory.ts b/packages/internal/handle-resolver-node/src/node-resolve-txt-factory.ts new file mode 100644 index 00000000000..2429ea139f0 --- /dev/null +++ b/packages/internal/handle-resolver-node/src/node-resolve-txt-factory.ts @@ -0,0 +1,100 @@ +import { Resolver, lookup, resolveTxt } from 'node:dns/promises' +import { isIP } from 'node:net' + +import { ResolveTxt } from '@atproto-labs/handle-resolver' + +export const nodeResolveTxtDefault: ResolveTxt = (hostname) => + resolveTxt(hostname).then(groupChunks, handleError) + +export function nodeResolveTxtFactory(nameservers: string[]): ResolveTxt { + // Optimization + if (!nameservers.length) return async () => null + + // Build the resolver asynchronously (will be awaited on every use) + const resolverPromise: Promise = Promise.all( + nameservers.map((nameserver) => { + const [domain, port = null] = nameserver.split(':', 2) + + if (port !== null && !/^\d+$/.test(port)) { + throw new TypeError(`Invalid name server "${nameserver}"`) + } + + return isIP(domain) === 4 || isBracedIPv6(domain) + ? [nameserver] // No need to lookup + : lookup(domain, { all: true }).then( + (r) => r.map((a) => appendPort(a.address, port)), + // Let's just ignore failed nameservers resolution + (_err) => [], + ) + }), + ).then((results) => { + const backupIps = results.flat(1) + // No resolver if no valid IP + if (!backupIps.length) return null + + const resolver = new Resolver() + resolver.setServers(backupIps) + return resolver + }) + + // Avoid uncaught promise rejection + void resolverPromise.catch(() => { + // Should never happen though... + }) + + return async (hostname) => { + const resolver = await resolverPromise + return resolver + ? resolver.resolveTxt(hostname).then(groupChunks, handleError) + : null + } +} + +function isBracedIPv6(address: string): boolean { + return ( + address.startsWith('[') && + address.endsWith(']') && + isIP(address.slice(1, -1)) === 6 + ) +} + +function groupChunks(results: string[][]): string[] { + return results.map((chunks) => chunks.join('')) +} + +function handleError(err: unknown) { + // Invalid argument type (e.g. hostname is a number) + if (err instanceof TypeError) throw err + + // If the hostname does not resolve, return null + if (err instanceof Error) { + if (err['code'] === 'ENOTFOUND') return null + + // Hostname is not a valid domain name + if (err['code'] === 'EBADNAME') throw err + + // DNS server unreachable + // if (err['code'] === 'ETIMEOUT') throw err + } + + // Historically, errors were not thrown here. A "null" value indicates to the + // AtprotoHandleResolver that it should try the fallback resolver. + + // @TODO We might want to re-visit this to only apply when an unexpected error + // occurs (by throwing here). For now, let's keep the same behavior as before. + + // throw err + + return null +} + +function appendPort(address: string, port: string | null): string { + switch (isIP(address)) { + case 4: + return port ? `${address}:${port}` : address + case 6: + return port ? `[${address}]:${port}` : `[${address}]` + default: + throw new TypeError(`Invalid IP address "${address}"`) + } +} diff --git a/packages/internal/handle-resolver-node/tsconfig.build.json b/packages/internal/handle-resolver-node/tsconfig.build.json new file mode 100644 index 00000000000..b60d50d0927 --- /dev/null +++ b/packages/internal/handle-resolver-node/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/node.json"], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/internal/handle-resolver-node/tsconfig.json b/packages/internal/handle-resolver-node/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/handle-resolver-node/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/handle-resolver/README.md b/packages/internal/handle-resolver/README.md new file mode 100644 index 00000000000..5847a119d5b --- /dev/null +++ b/packages/internal/handle-resolver/README.md @@ -0,0 +1,148 @@ +# Universal Handle Resolver implementation for ATPROTO + +This package provides a handle resolver implementation for ATPROTO. It is used +to resolve handles to their corresponding DID. + +This package is meant to be used in any JavaScript environment that support the +`fetch()` function. Because APTORO handle resolution requires DNS resolution, +you will need to provide your own DNS resolution function when using this +package. + +There are two main classes in this package: + +- `AtprotoHandleResolver` This implements the official ATPROTO handle resolution + algorithm (and requires a DNS resolver). +- `AppViewHandleResolver` This uses HTTP requests to the Bluesky AppView + (bsky.app) to provide handle resolution. + +## Usage + +### From a front-end app + +Since the ATPROTO handle resolution algorithm requires DNS resolution, and the +browser does not provide a built-in DNS resolver, this package offers two +options: + +- Delegate handle resolution to an AppView (`AppViewHandleResolver`). This is + the recommended approach for front-end apps. +- Use a DNS-over-HTTPS (DoH) server (`DohHandleResolver`). Prefer this method + if you don't own an AppView and already have a DoH server that you trust. + +Using an AppView: + +> [!CAUTION] +> Use the Bluesky owned AppView (`https://api.bsky.app/`), or PDS +> (`https://bsky.social/`), at your own risk. Using these servers in a +> third-party application might expose your users' data (IP address) to Bluesky. +> Bluesky might log the data sent to it when your app is resolving handles. +> Bluesky might also change the API, or terms or use, at any time without +> notice. Make sure you are compliant with the Bluesky terms of use as well as +> any laws and regulations that apply to your use case. + +```ts +import { AppViewHandleResolver } from '@atproto-labs/handle-resolver' + +const resolver = new AppViewHandleResolver({ + service: 'https://my-app-view.com/', +}) +const did = await resolver.resolve('my-handle.bsky.social') +``` + +Using DNS-over-HTTPS (DoH) for DNS resolution: + +> [!CAUTION] +> Using a DoH server that you don't own might expose your users' data to +> the DoH server provider. The DoH server provider might log the data sent to it +> by your app, allowing them to track which handles are being resolved by your +> users. In the browser, it is recommended to use a DoH server that you own and +> control. Or to implement your own AppView and use the `AppViewHandleResolver` +> class. + +> [!NOTE] +> Using the `DohHandleResolver` requires a DNS-over-HTTPS server that +> supports the DNS-over-HTTPS protocol with "application/dns-json" responses. + +```ts +import { DohHandleResolver } from '@atproto-labs/handle-resolver' + +// Also works with 'https://cloudflare-dns.com/dns-query' +const resolver = new DohHandleResolver('https://dns.google/resolve', { + // Optional: Custom fetch function that will be used both for DNS resolution + // and well-known resolution. + fetch: globalThis.fetch.bind(globalThis), +}) + +const did = await resolver.resolve('my-handle.bsky.social') +``` + +### From a Node.js app + +> [!NOTE] +> On a Node.js backend, you will probably want to use the +> "@atproto-labs/handle-resolver-node" package. The example below applies to +> Node.js code running on a user's machine (e.g. through Electron). + +```ts +import { AtprotoHandleResolver } from '@atproto-labs/handle-resolver' +import { resolveTxt } from 'node:dns/promises' + +const resolver = new AtprotoHandleResolver({ + // Optional: Custom fetch function (used for well-known resolution) + fetch: globalThis.fetch.bind(globalThis), + + resolveTxt: async (domain: string) => + resolveTxt(domain).then((chunks) => chunks.join('')), +}) +``` + +### Caching + +Using a default, in-memory cache, in which items expire after 10 minutes: + +```ts +import { + AppViewHandleResolver, + CachedHandleResolver, + HandleResolver, + HandleCache, +} from '@atproto-labs/handle-resolver' + +// See previous examples for creating a resolver +declare const sourceResolver: HandleResolver + +const resolver = new CachedHandleResolver(sourceResolver) +const did = await resolver.resolve('my-handle.bsky.social') +const did = await resolver.resolve('my-handle.bsky.social') // Result from cache +const did = await resolver.resolve('my-handle.bsky.social') // Result from cache +``` + +Using a custom cache: + +```ts +import { + AppViewHandleResolver, + CachedHandleResolver, + HandleResolver, + HandleCache, +} from '@atproto-labs/handle-resolver' + +// See previous examples for creating a resolver +declare const sourceResolver: HandleResolver + +const cache: HandleCache = { + set(handle, did): Promise { + /* TODO */ + }, + get(handle): Promise { + /* TODO */ + }, + del(handle): Promise { + /* TODO */ + }, +} + +const resolver = new CachedHandleResolver(sourceResolver, cache) +const did = await resolver.resolve('my-handle.bsky.social') +const did = await resolver.resolve('my-handle.bsky.social') // Result from cache +const did = await resolver.resolve('my-handle.bsky.social') // Result from cache +``` diff --git a/packages/internal/handle-resolver/package.json b/packages/internal/handle-resolver/package.json new file mode 100644 index 00000000000..9c9258053a6 --- /dev/null +++ b/packages/internal/handle-resolver/package.json @@ -0,0 +1,42 @@ +{ + "name": "@atproto-labs/handle-resolver", + "version": "0.0.1", + "license": "MIT", + "description": "Isomorphic ATProto handle to DID resolver", + "keywords": [ + "atproto", + "oauth", + "handle", + "identity", + "browser", + "node", + "isomorphic" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/handle-resolver" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/simple-store": "workspace:*", + "@atproto-labs/simple-store-memory": "workspace:*", + "@atproto/did": "workspace:*", + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/internal/handle-resolver/src/app-view-handle-resolver.ts b/packages/internal/handle-resolver/src/app-view-handle-resolver.ts new file mode 100644 index 00000000000..444ba10f625 --- /dev/null +++ b/packages/internal/handle-resolver/src/app-view-handle-resolver.ts @@ -0,0 +1,95 @@ +import z from 'zod' + +import { + HandleResolver, + ResolveOptions, + ResolvedHandle, + isResolvedHandle, +} from './types.js' + +export const xrpcErrorSchema = z.object({ + error: z.string(), + message: z.string().optional(), +}) + +export type AppViewHandleResolverOptions = { + /** + * Fetch function to use for HTTP requests. Allows customizing the request + * behavior, e.g. adding headers, setting a timeout, mocking, etc. + * + * @default globalThis.fetch + */ + fetch?: typeof globalThis.fetch +} + +export class AppViewHandleResolver implements HandleResolver { + static from( + service: URL | string | HandleResolver, + options?: AppViewHandleResolverOptions, + ): HandleResolver { + if (typeof service === 'string' || service instanceof URL) { + return new AppViewHandleResolver(service, options) + } + return service + } + + /** + * URL of the atproto lexicon server. This is the base URL where the + * `com.atproto.identity.resolveHandle` XRPC method is located. + */ + protected readonly serviceUrl: URL + protected readonly fetch: typeof globalThis.fetch + + constructor(service: URL | string, options?: AppViewHandleResolverOptions) { + this.serviceUrl = new URL(service) + this.fetch = options?.fetch ?? globalThis.fetch + } + + public async resolve( + handle: string, + options?: ResolveOptions, + ): Promise { + const url = new URL( + '/xrpc/com.atproto.identity.resolveHandle', + this.serviceUrl, + ) + url.searchParams.set('handle', handle) + + const headers = new Headers() + if (options?.noCache) headers.set('cache-control', 'no-cache') + + const response = await this.fetch.call(null, url, { + headers, + signal: options?.signal, + redirect: 'error', + }) + const payload = await response.json() + + // The response should either be + // - 400 Bad Request with { error: 'InvalidRequest', message: 'Unable to resolve handle' } + // - 200 OK with { did: NonNullable } + // Any other response is considered unexpected behavior an should throw an error. + + if (response.status === 400) { + const data = xrpcErrorSchema.parse(payload) + if ( + data.error === 'InvalidRequest' && + data.message === 'Unable to resolve handle' + ) { + return null + } + } + + if (!response.ok) { + throw new TypeError('Invalid response from resolveHandle method') + } + + const value: unknown = payload?.did + + if (!isResolvedHandle(value)) { + throw new TypeError('Invalid DID returned from resolveHandle method') + } + + return value + } +} diff --git a/packages/internal/handle-resolver/src/atproto-doh-handle-resolver.ts b/packages/internal/handle-resolver/src/atproto-doh-handle-resolver.ts new file mode 100644 index 00000000000..87b6ed639a2 --- /dev/null +++ b/packages/internal/handle-resolver/src/atproto-doh-handle-resolver.ts @@ -0,0 +1,123 @@ +import { + AtprotoHandleResolver, + AtprotoHandleResolverOptions, +} from './atproto-handle-resolver.js' +import { HandleResolver } from './types.js' +import { ResolveTxt } from './internal-resolvers/dns-handle-resolver.js' + +export type AtprotoDohHandleResolverOptions = Omit< + AtprotoHandleResolverOptions, + 'resolveTxt' | 'resolveTxtFallback' +> & { + dohEndpoint: string | URL +} + +export class AtprotoDohHandleResolver + extends AtprotoHandleResolver + implements HandleResolver +{ + constructor(options: AtprotoDohHandleResolverOptions) { + super({ + ...options, + resolveTxt: dohResolveTxtFactory(options), + resolveTxtFallback: undefined, + }) + } +} + +/** + * Resolver for DNS-over-HTTPS (DoH) handles. Only works with servers supporting + * Google Flavoured "application/dns-json" queries. + * + * @see {@link https://developers.google.com/speed/public-dns/docs/doh/json} + * @see {@link https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json/} + * @todo Add support for DoH using application/dns-message (?) + */ +function dohResolveTxtFactory({ + dohEndpoint, + fetch = globalThis.fetch, +}: AtprotoDohHandleResolverOptions): ResolveTxt { + return async (hostname) => { + const url = new URL(dohEndpoint) + url.searchParams.set('type', 'TXT') + url.searchParams.set('name', hostname) + + const response = await fetch(url, { + method: 'GET', + headers: { accept: 'application/dns-json' }, + redirect: 'follow', + }) + try { + const contentType = response.headers.get('content-type')?.trim() + if (!response.ok) { + const message = contentType?.startsWith('text/plain') + ? await response.text() + : `Failed to resolve ${hostname}` + throw new TypeError(message) + } else if (contentType !== 'application/dns-json') { + throw new TypeError('Unexpected response from DoH server') + } + + const result = asResult(await response.json()) + return result.Answer?.filter(isAnswerTxt).map(extractTxtData) ?? null + } finally { + // Make sure to always cancel the response body as some engines (Node 👀) + // do not do this automatically. + // https://undici.nodejs.org/#/?id=garbage-collection + if (response.bodyUsed === false) { + // Handle rejection asynchronously + void response.body?.cancel().catch(onCancelError) + } + } + } +} + +function onCancelError(err: unknown) { + if (!(err instanceof DOMException) || err.name !== 'AbortError') { + console.error('An error occurred while cancelling the response body:', err) + } +} + +type Result = { Status: number; Answer?: Answer[] } +function isResult(result: unknown): result is Result { + if (typeof result !== 'object' || result === null) return false + if (!('Status' in result) || typeof result.Status !== 'number') return false + if ('Answer' in result && !isArrayOf(result.Answer, isAnswer)) return false + return true +} +function asResult(result: unknown): Result { + if (isResult(result)) return result + throw new TypeError('Invalid DoH response') +} + +function isArrayOf( + value: unknown, + predicate: (v: unknown) => v is T, +): value is T[] { + return Array.isArray(value) && value.every(predicate) +} + +type Answer = { name: string; type: number; data: string; TTL: number } +function isAnswer(answer: unknown): answer is Answer { + return ( + typeof answer === 'object' && + answer !== null && + 'name' in answer && + typeof answer.name === 'string' && + 'type' in answer && + typeof answer.type === 'number' && + 'data' in answer && + typeof answer.data === 'string' && + 'TTL' in answer && + typeof answer.TTL === 'number' + ) +} + +type AnswerTxt = Answer & { type: 16 } +function isAnswerTxt(answer: Answer): answer is AnswerTxt { + return answer.type === 16 +} + +function extractTxtData(answer: AnswerTxt): string { + return answer.data.replace(/^"|"$/g, '').replace(/\\"/g, '"') +} diff --git a/packages/internal/handle-resolver/src/atproto-handle-resolver.ts b/packages/internal/handle-resolver/src/atproto-handle-resolver.ts new file mode 100644 index 00000000000..640739d2b22 --- /dev/null +++ b/packages/internal/handle-resolver/src/atproto-handle-resolver.ts @@ -0,0 +1,75 @@ +import { + DnsHandleResolver, + ResolveTxt, +} from './internal-resolvers/dns-handle-resolver.js' +import { + WellKnownHandleResolver, + WellKnownHandleResolverOptions, +} from './internal-resolvers/well-known-handler-resolver.js' +import { HandleResolver, ResolveOptions, ResolvedHandle } from './types.js' + +export type { ResolveTxt } +export type AtprotoHandleResolverOptions = WellKnownHandleResolverOptions & { + resolveTxt: ResolveTxt + resolveTxtFallback?: ResolveTxt +} + +const noop = () => {} + +/** + * Implementation of the official ATPROTO handle resolution strategy. + * This implementation relies on two primitives: + * - HTTP Well-Known URI resolution (requires a `fetch()` implementation) + * - DNS TXT record resolution (requires a `resolveTxt()` function) + */ +export class AtprotoHandleResolver implements HandleResolver { + private readonly httpResolver: HandleResolver + private readonly dnsResolver: HandleResolver + private readonly dnsResolverFallback?: HandleResolver + + constructor(options: AtprotoHandleResolverOptions) { + this.httpResolver = new WellKnownHandleResolver(options) + this.dnsResolver = new DnsHandleResolver(options.resolveTxt) + this.dnsResolverFallback = options.resolveTxtFallback + ? new DnsHandleResolver(options.resolveTxtFallback) + : undefined + } + + async resolve( + handle: string, + options?: ResolveOptions, + ): Promise { + options?.signal?.throwIfAborted() + + const abortController = new AbortController() + const { signal } = abortController + options?.signal?.addEventListener('abort', () => abortController.abort(), { + signal, + }) + + const wrappedOptions = { ...options, signal } + + try { + const dnsPromise = this.dnsResolver.resolve(handle, wrappedOptions) + const httpPromise = this.httpResolver.resolve(handle, wrappedOptions) + + // Prevent uncaught promise rejection + httpPromise.catch(noop) + + const dnsRes = await dnsPromise + if (dnsRes) return dnsRes + + signal.throwIfAborted() + + const res = await httpPromise + if (res) return res + + signal.throwIfAborted() + + return this.dnsResolverFallback?.resolve(handle, wrappedOptions) ?? null + } finally { + // Cancel pending requests, and remove "abort" listener on incoming signal + abortController.abort() + } + } +} diff --git a/packages/internal/handle-resolver/src/cached-handle-resolver.ts b/packages/internal/handle-resolver/src/cached-handle-resolver.ts new file mode 100644 index 00000000000..d0ce2b238a4 --- /dev/null +++ b/packages/internal/handle-resolver/src/cached-handle-resolver.ts @@ -0,0 +1,32 @@ +import { CachedGetter, SimpleStore } from '@atproto-labs/simple-store' +import { SimpleStoreMemory } from '@atproto-labs/simple-store-memory' +import { ResolveOptions, HandleResolver, ResolvedHandle } from './types.js' + +export type HandleCache = SimpleStore + +export class CachedHandleResolver implements HandleResolver { + private getter: CachedGetter + + constructor( + /** + * The resolver that will be used to resolve handles. + */ + resolver: HandleResolver, + cache: HandleCache = new SimpleStoreMemory({ + max: 1000, + ttl: 10 * 60e3, + }), + ) { + this.getter = new CachedGetter( + (handle, options) => resolver.resolve(handle, options), + cache, + ) + } + + async resolve( + handle: string, + options?: ResolveOptions, + ): Promise { + return this.getter.get(handle, options) + } +} diff --git a/packages/internal/handle-resolver/src/index.ts b/packages/internal/handle-resolver/src/index.ts new file mode 100644 index 00000000000..f9253ab69df --- /dev/null +++ b/packages/internal/handle-resolver/src/index.ts @@ -0,0 +1,9 @@ +export * from './types.js' + +// Main Handle Resolvers strategies +export * from './app-view-handle-resolver.js' +export * from './atproto-doh-handle-resolver.js' +export * from './atproto-handle-resolver.js' + +// Handle Resolver Caching utility +export * from './cached-handle-resolver.js' diff --git a/packages/internal/handle-resolver/src/internal-resolvers/dns-handle-resolver.ts b/packages/internal/handle-resolver/src/internal-resolvers/dns-handle-resolver.ts new file mode 100644 index 00000000000..ada6ba4ecf5 --- /dev/null +++ b/packages/internal/handle-resolver/src/internal-resolvers/dns-handle-resolver.ts @@ -0,0 +1,38 @@ +import { HandleResolver, ResolvedHandle, isResolvedHandle } from '../types' + +const SUBDOMAIN = '_atproto' +const PREFIX = 'did=' + +/** + * DNS TXT record resolver. Return `null` if the hostname successfully does not + * resolve to a valid DID. Throw an error if an unexpected error occurs. + */ +export type ResolveTxt = (hostname: string) => Promise + +export class DnsHandleResolver implements HandleResolver { + constructor(protected resolveTxt: ResolveTxt) {} + + async resolve(handle: string): Promise { + const results = await this.resolveTxt.call(null, `${SUBDOMAIN}.${handle}`) + + if (!results) return null + + for (let i = 0; i < results.length; i++) { + // If the line does not start with "did=", skip it + if (!results[i].startsWith(PREFIX)) continue + + // Ensure no other entry starting with "did=" follows + for (let j = i + 1; j < results.length; j++) { + if (results[j].startsWith(PREFIX)) return null + } + + // Note: No trimming (to be consistent with spec) + const did = results[i].slice(PREFIX.length) + + // Invalid DBS record + return isResolvedHandle(did) ? did : null + } + + return null + } +} diff --git a/packages/internal/handle-resolver/src/internal-resolvers/well-known-handler-resolver.ts b/packages/internal/handle-resolver/src/internal-resolvers/well-known-handler-resolver.ts new file mode 100644 index 00000000000..24eaf675ea6 --- /dev/null +++ b/packages/internal/handle-resolver/src/internal-resolvers/well-known-handler-resolver.ts @@ -0,0 +1,56 @@ +import { + ResolveOptions, + HandleResolver, + ResolvedHandle, + isResolvedHandle, +} from '../types.js' + +export type WellKnownHandleResolverOptions = { + /** + * Fetch function to use for HTTP requests. Allows customizing the request + * behavior, e.g. adding headers, setting a timeout, mocking, etc. The + * provided fetch function will be wrapped with a safeFetchWrap function that + * adds SSRF protection. + * + * @default `globalThis.fetch` + */ + fetch?: typeof globalThis.fetch +} + +export class WellKnownHandleResolver implements HandleResolver { + protected readonly fetch: typeof globalThis.fetch + + constructor(options?: WellKnownHandleResolverOptions) { + this.fetch = options?.fetch ?? globalThis.fetch + } + + public async resolve( + handle: string, + options?: ResolveOptions, + ): Promise { + const url = new URL('/.well-known/atproto-did', `https://${handle}`) + + const headers = new Headers() + if (options?.noCache) headers.set('cache-control', 'no-cache') + + try { + const response = await this.fetch.call(null, url, { + headers, + signal: options?.signal, + redirect: 'error', + }) + const text = await response.text() + const firstLine = text.split('\n')[0]!.trim() + + if (isResolvedHandle(firstLine)) return firstLine + + return null + } catch (err) { + // The the request failed, assume the handle does not resolve to a DID, + // unless the failure was due to the signal being aborted. + options?.signal?.throwIfAborted() + + return null + } + } +} diff --git a/packages/internal/handle-resolver/src/types.ts b/packages/internal/handle-resolver/src/types.ts new file mode 100644 index 00000000000..8f49f3ceae1 --- /dev/null +++ b/packages/internal/handle-resolver/src/types.ts @@ -0,0 +1,33 @@ +import { Did, isAtprotoDidWeb, isDidPlc } from '@atproto/did' + +export type ResolveOptions = { + signal?: AbortSignal + noCache?: boolean +} + +/** + * @see {@link https://atproto.com/specs/did#blessed-did-methods} + */ +export type ResolvedHandle = null | Did<'plc' | 'web'> + +export { type Did } + +/** + * @see {@link https://atproto.com/specs/did#blessed-did-methods} + */ +export function isResolvedHandle( + value: T, +): value is T & ResolvedHandle { + return value === null || isDidPlc(value) || isAtprotoDidWeb(value) +} + +export interface HandleResolver { + /** + * @returns the DID that corresponds to the given handle, or `null` if no DID + * is found. `null` should only be returned if no unexpected behavior occurred + * during the resolution process. + * @throws Error if the resolution method fails due to an unexpected error, or + * if the resolution is aborted ({@link ResolveOptions}). + */ + resolve(handle: string, options?: ResolveOptions): Promise +} diff --git a/packages/internal/handle-resolver/tsconfig.build.json b/packages/internal/handle-resolver/tsconfig.build.json new file mode 100644 index 00000000000..2ef4f334355 --- /dev/null +++ b/packages/internal/handle-resolver/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig/isomorphic.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/internal/handle-resolver/tsconfig.json b/packages/internal/handle-resolver/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/handle-resolver/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/identity-resolver/package.json b/packages/internal/identity-resolver/package.json new file mode 100644 index 00000000000..f291cf99ae2 --- /dev/null +++ b/packages/internal/identity-resolver/package.json @@ -0,0 +1,38 @@ +{ + "name": "@atproto-labs/identity-resolver", + "version": "0.0.1", + "license": "MIT", + "description": "A library resolving ATPROTO identities", + "keywords": [ + "atproto", + "identity", + "isomorphic", + "resolver" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/identity-resolver" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/did-resolver": "workspace:*", + "@atproto-labs/handle-resolver": "workspace:*", + "@atproto/syntax": "workspace:*" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/internal/identity-resolver/src/identity-resolver.ts b/packages/internal/identity-resolver/src/identity-resolver.ts new file mode 100644 index 00000000000..3aa6fb90f0d --- /dev/null +++ b/packages/internal/identity-resolver/src/identity-resolver.ts @@ -0,0 +1,78 @@ +import { + Did, + DidDocument, + ResolveOptions as DidResolveOptions, + DidResolver, + DidService, +} from '@atproto-labs/did-resolver' +import { + ResolveOptions as HandleResolveOptions, + HandleResolver, + ResolvedHandle, + isResolvedHandle, +} from '@atproto-labs/handle-resolver' +import { normalizeAndEnsureValidHandle } from '@atproto/syntax' + +export type ResolvedIdentity = { + did: NonNullable + pds: URL +} + +export type ResolveOptions = DidResolveOptions & HandleResolveOptions + +export class IdentityResolver { + constructor( + readonly didResolver: DidResolver<'plc' | 'web'>, + readonly handleResolver: HandleResolver, + ) {} + + public async resolve( + input: string, + options?: ResolveOptions, + ): Promise { + const did = isResolvedHandle(input) + ? input // Already a did + : await this.handleResolver.resolve( + normalizeAndEnsureValidHandle(input), + options, + ) + + options?.signal?.throwIfAborted() + + if (!did) { + throw new TypeError(`Handle "${input}" does not resolve to a DID`) + } + + const document = await this.didResolver.resolve(did, options) + + const service = document.service?.find( + isAtprotoPersonalDataServerService<'plc' | 'web'>, + document, + ) + + if (!service) { + throw new TypeError( + `No valid "AtprotoPersonalDataServer" service found in "${did}" DID document`, + ) + } + + const pds = new URL(service.serviceEndpoint) + + return { did, pds } + } +} + +function isAtprotoPersonalDataServerService( + this: DidDocument, + s: DidService, +): s is { + id: '#atproto_pds' | `${Did}#atproto_pds` + type: 'AtprotoPersonalDataServer' + serviceEndpoint: string +} { + return ( + typeof s.serviceEndpoint === 'string' && + s.type === 'AtprotoPersonalDataServer' && + (s.id === '#atproto_pds' || s.id === `${this.id}#atproto_pds`) + ) +} diff --git a/packages/internal/identity-resolver/src/index.ts b/packages/internal/identity-resolver/src/index.ts new file mode 100644 index 00000000000..69c3c9e910d --- /dev/null +++ b/packages/internal/identity-resolver/src/index.ts @@ -0,0 +1 @@ +export * from './identity-resolver.js' diff --git a/packages/internal/identity-resolver/tsconfig.build.json b/packages/internal/identity-resolver/tsconfig.build.json new file mode 100644 index 00000000000..9f3c5252d4c --- /dev/null +++ b/packages/internal/identity-resolver/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/internal/identity-resolver/tsconfig.json b/packages/internal/identity-resolver/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/identity-resolver/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/pipe/package.json b/packages/internal/pipe/package.json new file mode 100644 index 00000000000..734dcd090b8 --- /dev/null +++ b/packages/internal/pipe/package.json @@ -0,0 +1,32 @@ +{ + "name": "@atproto-labs/pipe", + "version": "0.0.1", + "license": "MIT", + "description": "Library for combining multiple functions into a single function.", + "keywords": [ + "atproto", + "transformer" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/pipe" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": {}, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/internal/pipe/src/index.ts b/packages/internal/pipe/src/index.ts new file mode 100644 index 00000000000..fee2d89a918 --- /dev/null +++ b/packages/internal/pipe/src/index.ts @@ -0,0 +1,2 @@ +export { pipe, pipeTwo } from './pipe.js' +export { type Transformer } from './transformer.js' diff --git a/packages/internal/pipe/src/pipe.ts b/packages/internal/pipe/src/pipe.ts new file mode 100644 index 00000000000..2ded877cc20 --- /dev/null +++ b/packages/internal/pipe/src/pipe.ts @@ -0,0 +1,63 @@ +import { Transformer } from './transformer.js' + +type PipelineInput[]> = T extends [ + Transformer, + ...any[], +] + ? I + : T extends Transformer[] + ? I + : never + +type PipelineOutput[]> = T extends [ + ...any[], + Transformer, +] + ? O + : T extends Transformer[] + ? O + : never + +type Pipeline< + F extends readonly Transformer[], + Acc extends readonly Transformer[] = [], +> = F extends [Transformer] + ? [...Acc, Transformer] + : F extends [Transformer, ...infer Tail] + ? Tail extends [Transformer, ...any[]] + ? Pipeline]> + : Acc + : Acc + +/** + * This utility function allows to properly type a pipeline of transformers. + * + * @example + * ```ts + * // Will be typed as "(input: string) => Promise" + * const parse = pipe( + * async (input: string) => JSON.parse(input), + * async (input: unknown) => { + * if (typeof input === 'number') return input + * throw new TypeError('Invalid input') + * }, + * (input: number) => input * 2, + * ) + * ``` + */ +export function pipe(): never +export function pipe[]>( + ...pipeline: Pipeline extends T ? T : Pipeline +): (input: PipelineInput) => Promise> +export function pipe[]>( + ...pipeline: Pipeline extends T ? T : Pipeline +): (input: PipelineInput) => Promise> { + return pipeline.reduce(pipeTwo) +} + +export function pipeTwo( + first: Transformer, + second: Transformer, +): (input: I) => Promise { + return async (input) => second(await first(input)) +} diff --git a/packages/internal/pipe/src/transformer.ts b/packages/internal/pipe/src/transformer.ts new file mode 100644 index 00000000000..9cd7f965f00 --- /dev/null +++ b/packages/internal/pipe/src/transformer.ts @@ -0,0 +1 @@ +export type Transformer = (input: I) => O | PromiseLike diff --git a/packages/internal/pipe/tsconfig.build.json b/packages/internal/pipe/tsconfig.build.json new file mode 100644 index 00000000000..9f3c5252d4c --- /dev/null +++ b/packages/internal/pipe/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/internal/pipe/tsconfig.json b/packages/internal/pipe/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/pipe/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/rollup-plugin-bundle-manifest/README.md b/packages/internal/rollup-plugin-bundle-manifest/README.md new file mode 100644 index 00000000000..87d0f82ba8e --- /dev/null +++ b/packages/internal/rollup-plugin-bundle-manifest/README.md @@ -0,0 +1,99 @@ +# @atproto-labs/rollup-plugin-bundle-manifest + +This Rollup plugin allows to generate a (JSON) manifest containing the output +files of a Rollup build. The manifest will look as follows: + +```json +{ + "main.js": { + "type": "chunk", + "mime": "application/javascript", + "dynamicImports": [], + "isDynamicEntry": false, + "isEntry": true, + "isImplicitEntry": false, + "name": "main", + "sha256": "", + "data": "" + }, + "main.js.map": { + "type": "asset", + "mime": "application/json", + "sha256": "", + "data": "" + }, + "main.css": { + "type": "asset", + "mime": "text/css", + "sha256": "", + "data": "" + } + // ... more entries as needed +} +``` + +This manifest will typically be useful for a backend service that serves the +frontend assets, as it can be used to determine the correct `Content-Type` and +and file integrity (via the SHA-256 hash), without having to read the files +themselves. + +## Usage + +```js +// rollup.config.js + +import bundleManifest from '@atproto-labs/rollup-plugin-bundle-manifest' + +export default { + input: 'src/index.js', + output: { + dir: 'dist', + format: 'es', + }, + plugins: [ + bundleManifest({ + name: 'bundle-manifest.json', + + // Optional: should the asset data be embedded (as base64 string) in the manifest? + data: false, + }), + ], +} +``` + +## Options + +- `name` (string): The name of the manifest file. Defaults to `bundle-manifest.json`. +- `data` (boolean): Whether to embed the asset data in the manifest. Defaults to `false`. + +## Example + +```js +const assetManifest = require('./dist/bundle-manifest.json') + +const app = express() + +app.use((req, res, next) => { + const asset = assetManifest[req.path.slice(1)] + if (!asset) return next() + + res.setHeader('Content-Type', asset.mime) + res.setHeader('Content-Length', asset.data.length) + + res.end(Buffer.from(asset.data, 'base64')) +}) + +app.use((req, res, next) => { + res.setHeader( + 'Content-Security-Policy', + buildCSP(assetManifest), // Not provided here + ) + + // Serve the index.html file + res.sendFile('index.html') +}) +``` + +## License + +MIT diff --git a/packages/internal/rollup-plugin-bundle-manifest/package.json b/packages/internal/rollup-plugin-bundle-manifest/package.json new file mode 100644 index 00000000000..f128853355a --- /dev/null +++ b/packages/internal/rollup-plugin-bundle-manifest/package.json @@ -0,0 +1,39 @@ +{ + "name": "@atproto-labs/rollup-plugin-bundle-manifest", + "version": "0.0.1", + "license": "MIT", + "description": "Library for generating a manifest of bundled files from a Rollup build", + "keywords": [ + "atproto", + "rollup", + "manifest" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/rollup-plugin-bundle-manifest" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "mime": "^3.0.0" + }, + "peerDependencies": { + "rollup": "^4.0.0" + }, + "devDependencies": { + "rollup": "^4.10.0", + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/internal/rollup-plugin-bundle-manifest/src/index.ts b/packages/internal/rollup-plugin-bundle-manifest/src/index.ts new file mode 100644 index 00000000000..83c4b0266cd --- /dev/null +++ b/packages/internal/rollup-plugin-bundle-manifest/src/index.ts @@ -0,0 +1,76 @@ +import { createHash } from 'node:crypto' +import { extname } from 'node:path' + +import mime from 'mime' +import { Plugin } from 'rollup' + +type AssetItem = { + type: 'asset' + mime?: string + sha256: string + data?: string +} + +type ChunkItem = { + type: 'chunk' + mime: string + sha256: string + dynamicImports: string[] + isDynamicEntry: boolean + isEntry: boolean + isImplicitEntry: boolean + name: string + data?: string +} + +export type ManifestItem = AssetItem | ChunkItem + +export type Manifest = Record + +export default function bundleManifest({ + name = 'bundle-manifest.json', + data = false, +}: { + name?: string + data?: boolean +} = {}): Plugin { + return { + name: 'bundle-manifest', + generateBundle(outputOptions, bundle) { + const manifest: Manifest = {} + + for (const [fileName, chunk] of Object.entries(bundle)) { + if (chunk.type === 'asset') { + manifest[fileName] = { + type: chunk.type, + data: data + ? Buffer.from(chunk.source).toString('base64') + : undefined, + mime: mime.getType(extname(fileName)) || undefined, + sha256: createHash('sha256').update(chunk.source).digest('base64'), + } + } + + if (chunk.type === 'chunk') { + manifest[fileName] = { + type: chunk.type, + data: data ? Buffer.from(chunk.code).toString('base64') : undefined, + mime: 'application/javascript', + sha256: createHash('sha256').update(chunk.code).digest('base64'), + dynamicImports: chunk.dynamicImports, + isDynamicEntry: chunk.isDynamicEntry, + isEntry: chunk.isEntry, + isImplicitEntry: chunk.isImplicitEntry, + name: chunk.name, + } + } + } + + this.emitFile({ + type: 'asset', + fileName: name, + source: JSON.stringify(manifest, null, 2), + }) + }, + } +} diff --git a/packages/internal/rollup-plugin-bundle-manifest/tsconfig.build.json b/packages/internal/rollup-plugin-bundle-manifest/tsconfig.build.json new file mode 100644 index 00000000000..ea00aba058e --- /dev/null +++ b/packages/internal/rollup-plugin-bundle-manifest/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/node.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/internal/rollup-plugin-bundle-manifest/tsconfig.json b/packages/internal/rollup-plugin-bundle-manifest/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/rollup-plugin-bundle-manifest/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/simple-store-memory/package.json b/packages/internal/simple-store-memory/package.json new file mode 100644 index 00000000000..66fe801c025 --- /dev/null +++ b/packages/internal/simple-store-memory/package.json @@ -0,0 +1,36 @@ +{ + "name": "@atproto-labs/simple-store-memory", + "version": "0.0.1", + "license": "MIT", + "description": "Memory based simple-store implementation", + "keywords": [ + "cache", + "isomorphic", + "memory" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/simple-store-memory" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/simple-store": "workspace:*", + "lru-cache": "^10.2.0" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/internal/simple-store-memory/src/index.ts b/packages/internal/simple-store-memory/src/index.ts new file mode 100644 index 00000000000..b2150c969a8 --- /dev/null +++ b/packages/internal/simple-store-memory/src/index.ts @@ -0,0 +1,99 @@ +import { SimpleStore, Key, Value } from '@atproto-labs/simple-store' +import { LRUCache } from 'lru-cache' + +import { roughSizeOfObject } from './util.js' + +export type SimpleStoreMemoryOptions = { + /** + * The maximum number of entries in the cache. + */ + max?: number + + /** + * The time-to-live of a cache entry, in milliseconds. + */ + ttl?: number + + /** + * Whether to automatically prune expired entries. + */ + ttlAutopurge?: boolean + + /** + * The maximum total size of the cache, in units defined by the sizeCalculation + * function. + * + * @default No limit + */ + maxSize?: number + + /** + * The maximum size of a single cache entry, in units defined by the + * sizeCalculation function. + * + * @default No limit + */ + maxEntrySize?: number + + /** + * A function that returns the size of a value. The size is used to determine + * when the cache should be pruned, based on `maxSize`. + * + * @default The (rough) size in bytes used in memory. + */ + sizeCalculation?: (value: V, key: K) => number +} & ( // Memory is not infinite, so at least one pruning option is required. + | { max: number } + | { maxSize: number } + | { ttl: number; ttlAutopurge: boolean } +) + +// LRUCache does not allow storing "null", so we use a symbol to represent it. +const nullSymbol = Symbol('nullItem') +type AsLruValue = V extends null + ? typeof nullSymbol + : Exclude +const toLruValue = (value: V) => + (value === null ? nullSymbol : value) as AsLruValue +const fromLruValue = (value: AsLruValue) => + (value === nullSymbol ? null : value) as V + +export class SimpleStoreMemory + implements SimpleStore +{ + #cache: LRUCache> + + constructor({ sizeCalculation, ...options }: SimpleStoreMemoryOptions) { + this.#cache = new LRUCache>({ + ...options, + allowStale: false, + updateAgeOnGet: false, + updateAgeOnHas: false, + sizeCalculation: sizeCalculation + ? (value, key) => sizeCalculation(fromLruValue(value), key) + : options.maxEntrySize != null || options.maxSize != null + ? // maxEntrySize and maxSize require a size calculation function. + roughSizeOfObject + : undefined, + }) + } + + get(key: K): V | undefined { + const value = this.#cache.get(key) + if (value === undefined) return undefined + + return fromLruValue(value) + } + + set(key: K, value: V): void { + this.#cache.set(key, toLruValue(value)) + } + + del(key: K): void { + this.#cache.delete(key) + } + + clear(): void { + this.#cache.clear() + } +} diff --git a/packages/internal/simple-store-memory/src/util.ts b/packages/internal/simple-store-memory/src/util.ts new file mode 100644 index 00000000000..0fb05fee470 --- /dev/null +++ b/packages/internal/simple-store-memory/src/util.ts @@ -0,0 +1,77 @@ +const knownSizes = new WeakMap() + +/** + * @see {@link https://stackoverflow.com/a/11900218/356537} + */ +export function roughSizeOfObject(value: unknown): number { + const objectList = new Set() + const stack = [value] // This would be more efficient using a circular buffer + let bytes = 0 + + while (stack.length) { + const value = stack.pop() + + // > All objects on the heap start with a shape descriptor, which takes one + // > pointer size (usually 4 bytes these days, thanks to "pointer + // > compression" on 64-bit platforms). + + switch (typeof value) { + // Types are ordered by frequency + case 'string': + // https://stackoverflow.com/a/68791382/356537 + bytes += 12 + 4 * Math.ceil(value.length / 4) + break + case 'number': + bytes += 12 // Shape descriptor + double + break + case 'boolean': + bytes += 4 // Shape descriptor + break + case 'object': + bytes += 4 // Shape descriptor + + if (value === null) { + break + } + + if (knownSizes.has(value)) { + bytes += knownSizes.get(value)! + break + } + + if (objectList.has(value)) continue + objectList.add(value) + + if (Array.isArray(value)) { + bytes += 4 + stack.push(...value) + } else { + bytes += 8 + const keys = Object.getOwnPropertyNames(value) + for (let i = 0; i < keys.length; i++) { + bytes += 4 + const key = keys[i] + const val = value[key] + if (val !== undefined) stack.push(val) + stack.push(key) + } + } + break + case 'function': + bytes += 8 // Shape descriptor + pointer (assuming functions are shared) + break + case 'symbol': + bytes += 8 // Shape descriptor + pointer + break + case 'bigint': + bytes += 16 // Shape descriptor + BigInt + break + } + } + + if (typeof value === 'object' && value !== null) { + knownSizes.set(value, bytes) + } + + return bytes +} diff --git a/packages/internal/simple-store-memory/tsconfig.build.json b/packages/internal/simple-store-memory/tsconfig.build.json new file mode 100644 index 00000000000..2ef4f334355 --- /dev/null +++ b/packages/internal/simple-store-memory/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig/isomorphic.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/internal/simple-store-memory/tsconfig.json b/packages/internal/simple-store-memory/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/simple-store-memory/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/internal/simple-store/package.json b/packages/internal/simple-store/package.json new file mode 100644 index 00000000000..c62ab5b46c3 --- /dev/null +++ b/packages/internal/simple-store/package.json @@ -0,0 +1,32 @@ +{ + "name": "@atproto-labs/simple-store", + "version": "0.0.1", + "license": "MIT", + "description": "Simple store interfaces & utilities", + "keywords": [ + "cache", + "isomorphic" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/internal/simple-store" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": {}, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/internal/simple-store/src/cached-getter.ts b/packages/internal/simple-store/src/cached-getter.ts new file mode 100644 index 00000000000..077c0b92e89 --- /dev/null +++ b/packages/internal/simple-store/src/cached-getter.ts @@ -0,0 +1,160 @@ +import { Awaitable, SimpleStore, Key, Value } from './simple-store.js' + +export type GetCachedOptions = { + signal?: AbortSignal + + /** + * Do not use the cache to get the value. Always get a new value from the + * getter function. + * + * @default false + */ + noCache?: boolean + + /** + * When getting a value from the cache, allow the value to be returned even if + * it is stale. + * + * Has no effect if the `isStale` option was not provided to the CachedGetter. + * + * @default true // If the CachedGetter has an isStale option + * @default false // If no isStale option was provided to the CachedGetter + */ + allowStale?: boolean +} + +export type Getter = ( + key: K, + options: undefined | GetCachedOptions, + storedValue: undefined | V, +) => Awaitable + +export type CachedGetterOptions = { + isStale?: (key: K, value: V) => boolean | PromiseLike + onStoreError?: (err: unknown, key: K, value: V) => void | PromiseLike + deleteOnError?: ( + err: unknown, + key: K, + value: V, + ) => boolean | PromiseLike +} + +type PendingItem = Promise<{ value: V; isFresh: boolean }> + +const returnTrue = () => true +const returnFalse = () => false + +/** + * Wrapper utility that uses a store to speed up the retrieval of values from an + * (expensive) getter function. + */ +export class CachedGetter { + private pending = new Map>() + + constructor( + readonly getter: Getter, + readonly store: SimpleStore, + readonly options?: Readonly>, + ) {} + + async get(key: K, options?: GetCachedOptions): Promise { + options?.signal?.throwIfAborted() + + const isStale = this.options?.isStale + + const allowStored: (value: V) => Awaitable = options?.noCache + ? returnFalse // Never allow stored values to be returned + : options?.allowStale || isStale == null + ? returnTrue // Always allow stored values to be returned + : async (value: V) => !(await isStale(key, value)) + + // As long as concurrent requests are made for the same key, only one + // request will be made to the cache & getter function at a time. This works + // because there is no async operation between the while() loop and the + // pending.set() call. Because of the "single threaded" nature of + // JavaScript, the pending item will be set before the next iteration of the + // while loop. + let previousExecutionFlow: undefined | PendingItem + while ((previousExecutionFlow = this.pending.get(key))) { + try { + const { isFresh, value } = await previousExecutionFlow + + if (isFresh) return value + if (await allowStored(value)) return value + } catch { + // Ignore errors from previous execution flows (they will have been + // propagated by that flow). + } + + options?.signal?.throwIfAborted() + } + + const currentExecutionFlow: PendingItem = Promise.resolve() + .then(async () => { + const storedValue = await this.getStored(key, options) + if (storedValue !== undefined && (await allowStored(storedValue))) { + // Use the stored value as return value for the current execution + // flow. Notify other concurrent execution flows (that should be + // "stuck" in the loop before until this promise resolves) that we got + // a value, but that it came from the store (isFresh = false). + return { isFresh: false, value: storedValue } + } + + return Promise.resolve() + .then(async () => (0, this.getter)(key, options, storedValue)) + .catch(async (err) => { + if (storedValue !== undefined) { + if (await this.options?.deleteOnError?.(err, key, storedValue)) { + await this.delStored(key) + } + } + throw err + }) + .then(async (value) => { + // The value should be stored even is the signal was aborted. + await this.setStored(key, value) + return { isFresh: true, value } + }) + }) + .finally(() => { + this.pending.delete(key) + }) + + if (this.pending.has(key)) { + // This should never happen. Indeed, there must not be any 'await' + // statement between this and the loop iteration check meaning that + // this.pending.get returned undefined. It is there to catch bugs that + // would occur in future changes to the code. + throw new Error('Concurrent request for the same key') + } + + this.pending.set(key, currentExecutionFlow) + + const { value } = await currentExecutionFlow + return value + } + + bind(key: K): (options?: GetCachedOptions) => Promise { + return async (options) => this.get(key, options) + } + + async getStored(key: K, options?: GetCachedOptions): Promise { + try { + return await this.store.get(key, options) + } catch (err) { + return undefined + } + } + + async setStored(key: K, value: V): Promise { + try { + await this.store.set(key, value) + } catch (err) { + await this.options?.onStoreError?.(err, key, value) + } + } + + async delStored(key: K): Promise { + await this.store.del(key) + } +} diff --git a/packages/internal/simple-store/src/index.ts b/packages/internal/simple-store/src/index.ts new file mode 100644 index 00000000000..c217e9d27fd --- /dev/null +++ b/packages/internal/simple-store/src/index.ts @@ -0,0 +1,2 @@ +export * from './cached-getter.js' +export * from './simple-store.js' diff --git a/packages/internal/simple-store/src/simple-store.ts b/packages/internal/simple-store/src/simple-store.ts new file mode 100644 index 00000000000..1ae6c6fa3bc --- /dev/null +++ b/packages/internal/simple-store/src/simple-store.ts @@ -0,0 +1,16 @@ +export type Awaitable = V | PromiseLike + +export type Key = string | number +export type Value = NonNullable | null + +export type GetOptions = { signal?: AbortSignal } + +export interface SimpleStore { + /** + * @return undefined if the key is not in the store (which is why Value cannot contain "undefined"). + */ + get: (key: K, options?: GetOptions) => Awaitable + set: (key: K, value: V) => Awaitable + del: (key: K) => Awaitable + clear?: () => Awaitable +} diff --git a/packages/internal/simple-store/tsconfig.build.json b/packages/internal/simple-store/tsconfig.build.json new file mode 100644 index 00000000000..2ef4f334355 --- /dev/null +++ b/packages/internal/simple-store/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig/isomorphic.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/internal/simple-store/tsconfig.json b/packages/internal/simple-store/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/internal/simple-store/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/lex-cli/package.json b/packages/lex-cli/package.json index 6eb3dad6ee9..f14cb0f8324 100644 --- a/packages/lex-cli/package.json +++ b/packages/lex-cli/package.json @@ -29,6 +29,6 @@ "prettier": "^3.2.5", "ts-morph": "^16.0.0", "yesno": "^0.4.0", - "zod": "^3.21.4" + "zod": "^3.23.8" } } diff --git a/packages/lexicon/package.json b/packages/lexicon/package.json index 3a7a035b985..abea5d22607 100644 --- a/packages/lexicon/package.json +++ b/packages/lexicon/package.json @@ -24,7 +24,7 @@ "@atproto/syntax": "workspace:^", "iso-datestring-validator": "^2.2.2", "multiformats": "^9.9.0", - "zod": "^3.21.4" + "zod": "^3.23.8" }, "devDependencies": { "jest": "^28.1.2" diff --git a/packages/lexicon/src/lexicons.ts b/packages/lexicon/src/lexicons.ts index dfd84be9bcc..382c8e898c7 100644 --- a/packages/lexicon/src/lexicons.ts +++ b/packages/lexicon/src/lexicons.ts @@ -22,18 +22,33 @@ import * as ComplexValidators from './validators/complex' /** * A collection of compiled lexicons. */ -export class Lexicons { +export class Lexicons implements Iterable { docs: Map = new Map() defs: Map = new Map() - constructor(docs?: LexiconDoc[]) { - if (docs?.length) { + constructor(docs?: Iterable) { + if (docs) { for (const doc of docs) { this.add(doc) } } } + /** + * @example clone a lexicon: + * ```ts + * const clone = new Lexicons(originalLexicon) + * ``` + * + * @example get docs array: + * ```ts + * const docs = Array.from(lexicons) + * ``` + */ + [Symbol.iterator](): Iterator { + return this.docs.values() + } + /** * Add a lexicon doc. */ diff --git a/packages/oauth/jwk-jose/package.json b/packages/oauth/jwk-jose/package.json new file mode 100644 index 00000000000..3bdd4b8372d --- /dev/null +++ b/packages/oauth/jwk-jose/package.json @@ -0,0 +1,36 @@ +{ + "name": "@atproto/jwk-jose", + "version": "0.0.1", + "license": "MIT", + "description": "`jose` based implementation of @atproto/jwk Key's", + "keywords": [ + "atproto", + "jwk", + "jose" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/jwk-jose" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto/jwk": "workspace:*", + "jose": "^5.2.0" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/oauth/jwk-jose/src/index.ts b/packages/oauth/jwk-jose/src/index.ts new file mode 100644 index 00000000000..84d51ccddd8 --- /dev/null +++ b/packages/oauth/jwk-jose/src/index.ts @@ -0,0 +1 @@ +export * from './jose-key.js' diff --git a/packages/oauth/jwk-jose/src/jose-key.ts b/packages/oauth/jwk-jose/src/jose-key.ts new file mode 100644 index 00000000000..75253b5234a --- /dev/null +++ b/packages/oauth/jwk-jose/src/jose-key.ts @@ -0,0 +1,194 @@ +import { JwtVerifyError } from '@atproto/jwk' +import { + SignJWT, + exportJWK, + generateKeyPair, + importJWK, + importPKCS8, + jwtVerify, + type GenerateKeyPairOptions, + type GenerateKeyPairResult, + type JWK, + type JWTVerifyOptions, + type KeyLike, +} from 'jose' +import { JOSEError } from 'jose/errors' + +import { + Jwk, + JwkError, + JwtCreateError, + JwtHeader, + JwtPayload, + Key, + SignedJwt, + VerifyOptions, + VerifyPayload, + VerifyResult, + jwkValidator, +} from '@atproto/jwk' +import { either } from './util' + +export type Importable = string | KeyLike | Jwk + +export type { GenerateKeyPairOptions, GenerateKeyPairResult } + +export class JoseKey extends Key { + #keyObj?: KeyLike | Uint8Array + + protected async getKey() { + try { + return (this.#keyObj ||= await importJWK(this.jwk as JWK)) + } catch (cause) { + throw new JwkError('Failed to import JWK', undefined, { cause }) + } + } + + async createJwt(header: JwtHeader, payload: JwtPayload) { + if (header.kid && header.kid !== this.kid) { + throw new JwtCreateError( + `Invalid "kid" (${header.kid}) used to sign with key "${this.kid}"`, + ) + } + + if (!header.alg || !this.algorithms.includes(header.alg)) { + throw new JwtCreateError( + `Invalid "alg" (${header.alg}) used to sign with key "${this.kid}"`, + ) + } + + const keyObj = await this.getKey() + return new SignJWT(payload) + .setProtectedHeader({ ...header, kid: this.kid }) + .sign(keyObj) as Promise + } + + async verifyJwt< + P extends VerifyPayload = JwtPayload, + C extends string = string, + >(token: SignedJwt, options?: VerifyOptions): Promise> { + try { + const keyObj = await this.getKey() + const result = await jwtVerify(token, keyObj, { + ...options, + algorithms: this.algorithms, + } as JWTVerifyOptions) + + return result as VerifyResult + } catch (error) { + if (error instanceof JOSEError) { + throw new JwtVerifyError(error.message, error.code, { cause: error }) + } else { + throw JwtVerifyError.from(error) + } + } + } + + static async generateKeyPair( + allowedAlgos: readonly string[] = ['ES256'], + options?: GenerateKeyPairOptions, + ) { + if (!allowedAlgos.length) { + throw new JwkError('No algorithms provided for key generation') + } + + const errors: unknown[] = [] + for (const alg of allowedAlgos) { + try { + return await generateKeyPair(alg, options) + } catch (err) { + errors.push(err) + } + } + + throw new JwkError('Failed to generate key pair', undefined, { + cause: new AggregateError(errors, 'None of the algorithms worked'), + }) + } + + static async generate( + allowedAlgos: string[] = ['ES256'], + kid?: string, + options?: Omit, + ) { + const kp = await this.generateKeyPair(allowedAlgos, { + ...options, + extractable: true, + }) + return this.fromImportable(kp.privateKey, kid) + } + + static async fromImportable( + input: Importable, + kid?: string, + ): Promise { + if (typeof input === 'string') { + // PKCS8 + if (input.startsWith('-----')) { + // The "alg" is only needed in WebCrypto (NodeJS will be fine) + return this.fromPKCS8(input, '', kid) + } + + // Jwk (string) + if (input.startsWith('{')) { + return this.fromJWK(input, kid) + } + + throw new JwkError('Invalid input') + } + + if (typeof input === 'object') { + // Jwk + if ('kty' in input || 'alg' in input) { + return this.fromJWK(input, kid) + } + + // KeyLike + if (!kid) throw new JwkError('Missing "kid" for KeyLike key') + return this.fromKeyLike(input, kid) + } + + throw new JwkError('Invalid input') + } + + /** + * @see {@link exportJWK} + */ + static async fromKeyLike( + keyLike: KeyLike | Uint8Array, + kid?: string, + alg?: string, + ): Promise { + const jwk = await exportJWK(keyLike) + if (alg) { + if (!jwk.alg) jwk.alg = alg + else if (jwk.alg !== alg) throw new JwkError('Invalid "alg" in JWK') + } + return this.fromJWK(jwk, kid) + } + + /** + * @see {@link importPKCS8} + */ + static async fromPKCS8( + pem: string, + alg: string, + kid?: string, + ): Promise { + const keyLike = await importPKCS8(pem, alg, { extractable: true }) + return this.fromKeyLike(keyLike, kid) + } + + static async fromJWK( + input: string | Record, + inputKid?: string, + ): Promise { + const jwk = typeof input === 'string' ? JSON.parse(input) : input + if (!jwk || typeof jwk !== 'object') throw new JwkError('Invalid JWK') + + const kid = either(jwk.kid, inputKid) + const use = jwk.use || 'sig' + + return new JoseKey(jwkValidator.parse({ ...jwk, kid, use })) + } +} diff --git a/packages/oauth/jwk-jose/src/util.ts b/packages/oauth/jwk-jose/src/util.ts new file mode 100644 index 00000000000..f75cdb66718 --- /dev/null +++ b/packages/oauth/jwk-jose/src/util.ts @@ -0,0 +1,9 @@ +export function either( + a?: T, + b?: T, +): T | undefined { + if (a != null && b != null && a !== b) { + throw new TypeError(`Expected "${b}", got "${a}"`) + } + return a ?? b ?? undefined +} diff --git a/packages/oauth/jwk-jose/tsconfig.build.json b/packages/oauth/jwk-jose/tsconfig.build.json new file mode 100644 index 00000000000..10eec9ceea2 --- /dev/null +++ b/packages/oauth/jwk-jose/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/nodenext.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/oauth/jwk-jose/tsconfig.json b/packages/oauth/jwk-jose/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/oauth/jwk-jose/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/oauth/jwk-webcrypto/package.json b/packages/oauth/jwk-webcrypto/package.json new file mode 100644 index 00000000000..4ab2b7454b5 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/package.json @@ -0,0 +1,36 @@ +{ + "name": "@atproto/jwk-webcrypto", + "version": "0.0.1", + "license": "MIT", + "description": "Webcrypto based implementation of @atproto/jwk Key's", + "keywords": [ + "atproto", + "jwk", + "webcrypto" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/jwk-webcrypto" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto/jwk": "workspace:*", + "@atproto/jwk-jose": "workspace:*" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/oauth/jwk-webcrypto/src/index.ts b/packages/oauth/jwk-webcrypto/src/index.ts new file mode 100644 index 00000000000..fd2837bd208 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/src/index.ts @@ -0,0 +1 @@ +export * from './webcrypto-key.js' diff --git a/packages/oauth/jwk-webcrypto/src/util.ts b/packages/oauth/jwk-webcrypto/src/util.ts new file mode 100644 index 00000000000..b0d8eb51a14 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/src/util.ts @@ -0,0 +1,122 @@ +export type JWSAlgorithm = + // HMAC + | 'HS256' + | 'HS384' + | 'HS512' + // RSA + | 'PS256' + | 'PS384' + | 'PS512' + | 'RS256' + | 'RS384' + | 'RS512' + // EC + | 'ES256' + | 'ES256K' + | 'ES384' + | 'ES512' + // OKP + | 'EdDSA' + +export type SubtleAlgorithm = RsaHashedKeyGenParams | EcKeyGenParams + +export function toSubtleAlgorithm( + alg: string, + crv?: string, + options?: { modulusLength?: number }, +): SubtleAlgorithm { + switch (alg) { + case 'PS256': + case 'PS384': + case 'PS512': + return { + name: 'RSA-PSS', + hash: `SHA-${alg.slice(-3) as '256' | '384' | '512'}`, + modulusLength: options?.modulusLength ?? 2048, + publicExponent: new Uint8Array([0x01, 0x00, 0x01]), + } + case 'RS256': + case 'RS384': + case 'RS512': + return { + name: 'RSASSA-PKCS1-v1_5', + hash: `SHA-${alg.slice(-3) as '256' | '384' | '512'}`, + modulusLength: options?.modulusLength ?? 2048, + publicExponent: new Uint8Array([0x01, 0x00, 0x01]), + } + case 'ES256': + case 'ES384': + return { + name: 'ECDSA', + namedCurve: `P-${alg.slice(-3) as '256' | '384'}`, + } + case 'ES512': + return { + name: 'ECDSA', + namedCurve: 'P-521', + } + default: + // https://github.com/w3c/webcrypto/issues/82#issuecomment-849856773 + + throw new TypeError(`Unsupported alg "${alg}"`) + } +} + +export function fromSubtleAlgorithm(algorithm: KeyAlgorithm): JWSAlgorithm { + switch (algorithm.name) { + case 'RSA-PSS': + case 'RSASSA-PKCS1-v1_5': { + const hash = (algorithm).hash.name + switch (hash) { + case 'SHA-256': + case 'SHA-384': + case 'SHA-512': { + const prefix = algorithm.name === 'RSA-PSS' ? 'PS' : 'RS' + return `${prefix}${hash.slice(-3) as '256' | '384' | '512'}` + } + default: + throw new TypeError('unsupported RsaHashedKeyAlgorithm hash') + } + } + case 'ECDSA': { + const namedCurve = (algorithm).namedCurve + switch (namedCurve) { + case 'P-256': + case 'P-384': + case 'P-512': + return `ES${namedCurve.slice(-3) as '256' | '384' | '512'}` + case 'P-521': + return 'ES512' + default: + throw new TypeError('unsupported EcKeyAlgorithm namedCurve') + } + } + case 'Ed448': + case 'Ed25519': + return 'EdDSA' + default: + // https://github.com/w3c/webcrypto/issues/82#issuecomment-849856773 + + throw new TypeError(`Unexpected algorithm "${algorithm.name}"`) + } +} + +export function isCryptoKeyPair( + v: unknown, + extractable?: boolean, +): v is CryptoKeyPair { + return ( + typeof v === 'object' && + v !== null && + 'privateKey' in v && + v.privateKey instanceof CryptoKey && + v.privateKey.type === 'private' && + (extractable == null || v.privateKey.extractable === extractable) && + v.privateKey.usages.includes('sign') && + 'publicKey' in v && + v.publicKey instanceof CryptoKey && + v.publicKey.type === 'public' && + v.publicKey.extractable === true && + v.publicKey.usages.includes('verify') + ) +} diff --git a/packages/oauth/jwk-webcrypto/src/webcrypto-key.ts b/packages/oauth/jwk-webcrypto/src/webcrypto-key.ts new file mode 100644 index 00000000000..32ba8757ac0 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/src/webcrypto-key.ts @@ -0,0 +1,69 @@ +import { Jwk, jwkSchema } from '@atproto/jwk' +import { GenerateKeyPairOptions, JoseKey } from '@atproto/jwk-jose' + +import { fromSubtleAlgorithm, isCryptoKeyPair } from './util.js' + +export class WebcryptoKey extends JoseKey { + // We need to override the static method generate from JoseKey because + // the browser needs both the private and public keys + static override async generate( + allowedAlgos: string[] = ['ES256'], + kid: string = crypto.randomUUID(), + options?: GenerateKeyPairOptions, + ) { + const keyPair = await this.generateKeyPair(allowedAlgos, options) + + // Type safety only: in the browser, 'jose' always generates a CryptoKeyPair + if (!isCryptoKeyPair(keyPair)) { + throw new TypeError('Invalid CryptoKeyPair') + } + + return this.fromKeypair(keyPair, kid) + } + + static async fromKeypair(cryptoKeyPair: CryptoKeyPair, kid?: string) { + // https://datatracker.ietf.org/doc/html/rfc7517 + // > The "use" and "key_ops" JWK members SHOULD NOT be used together; [...] + // > Applications should specify which of these members they use. + + const { key_ops: _, ...jwk } = await crypto.subtle.exportKey( + 'jwk', + cryptoKeyPair.privateKey.extractable + ? cryptoKeyPair.privateKey + : cryptoKeyPair.publicKey, + ) + + const use = jwk.use ?? 'sig' + const alg = + jwk.alg ?? fromSubtleAlgorithm(cryptoKeyPair.privateKey.algorithm) + + if (use !== 'sig') { + throw new TypeError('Unsupported JWK use') + } + + return new WebcryptoKey( + jwkSchema.parse({ ...jwk, use, kid, alg }), + cryptoKeyPair, + ) + } + + constructor( + jwk: Jwk, + readonly cryptoKeyPair: CryptoKeyPair, + ) { + super(jwk) + } + + get isPrivate() { + return true + } + + get privateJwk(): Jwk | undefined { + if (super.isPrivate) return this.jwk + throw new Error('Private Webcrypto Key not exportable') + } + + protected override async getKey() { + return this.cryptoKeyPair.privateKey + } +} diff --git a/packages/oauth/jwk-webcrypto/tsconfig.build.json b/packages/oauth/jwk-webcrypto/tsconfig.build.json new file mode 100644 index 00000000000..d37ddd8d394 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/oauth/jwk-webcrypto/tsconfig.json b/packages/oauth/jwk-webcrypto/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/oauth/jwk-webcrypto/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/oauth/jwk/package.json b/packages/oauth/jwk/package.json new file mode 100644 index 00000000000..491888ceeca --- /dev/null +++ b/packages/oauth/jwk/package.json @@ -0,0 +1,38 @@ +{ + "name": "@atproto/jwk", + "version": "0.0.1", + "license": "MIT", + "description": "A library for working with JSON Web Keys (JWKs) in TypeScript. This is meant to be extended by environment-specific libraries like @atproto/jwk-jose.", + "keywords": [ + "atproto", + "jwk", + "jwks", + "jwt", + "json web key" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/jwk" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "multiformats": "^9.9.0", + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.json" + } +} diff --git a/packages/oauth/jwk/src/alg.ts b/packages/oauth/jwk/src/alg.ts new file mode 100644 index 00000000000..8af66ebcdf1 --- /dev/null +++ b/packages/oauth/jwk/src/alg.ts @@ -0,0 +1,98 @@ +import { JwkError } from './errors.js' +import { Jwk } from './jwk.js' + +declare const process: undefined | { versions?: { node?: string } } +const IS_NODE_RUNTIME = + typeof process !== 'undefined' && typeof process?.versions?.node === 'string' + +export function* jwkAlgorithms(jwk: Jwk): Generator { + // Ed25519, Ed448, and secp256k1 always have "alg" + // OKP always has "use" + if (jwk.alg) { + yield jwk.alg + return + } + + switch (jwk.kty) { + case 'EC': { + if (jwk.use === 'enc' || jwk.use === undefined) { + yield 'ECDH-ES' + yield 'ECDH-ES+A128KW' + yield 'ECDH-ES+A192KW' + yield 'ECDH-ES+A256KW' + } + + if (jwk.use === 'sig' || jwk.use === undefined) { + const crv = 'crv' in jwk ? jwk.crv : undefined + switch (crv) { + case 'P-256': + case 'P-384': + yield `ES${crv.slice(-3)}` + break + case 'P-521': + yield 'ES512' + break + case 'secp256k1': + if (IS_NODE_RUNTIME) yield 'ES256K' + break + default: + throw new JwkError(`Unsupported crv "${crv}"`) + } + } + + return + } + + case 'OKP': { + if (!jwk.use) throw new JwkError('Missing "use" Parameter value') + yield 'ECDH-ES' + yield 'ECDH-ES+A128KW' + yield 'ECDH-ES+A192KW' + yield 'ECDH-ES+A256KW' + return + } + + case 'RSA': { + if (jwk.use === 'enc' || jwk.use === undefined) { + yield 'RSA-OAEP' + yield 'RSA-OAEP-256' + yield 'RSA-OAEP-384' + yield 'RSA-OAEP-512' + if (IS_NODE_RUNTIME) yield 'RSA1_5' + } + + if (jwk.use === 'sig' || jwk.use === undefined) { + yield 'PS256' + yield 'PS384' + yield 'PS512' + yield 'RS256' + yield 'RS384' + yield 'RS512' + } + + return + } + + case 'oct': { + if (jwk.use === 'enc' || jwk.use === undefined) { + yield 'A128GCMKW' + yield 'A192GCMKW' + yield 'A256GCMKW' + yield 'A128KW' + yield 'A192KW' + yield 'A256KW' + } + + if (jwk.use === 'sig' || jwk.use === undefined) { + yield 'HS256' + yield 'HS384' + yield 'HS512' + } + + return + } + + default: + throw new JwkError(`Unsupported kty "${jwk.kty}"`) + } +} diff --git a/packages/oauth/jwk/src/errors.ts b/packages/oauth/jwk/src/errors.ts new file mode 100644 index 00000000000..afaf7172e50 --- /dev/null +++ b/packages/oauth/jwk/src/errors.ts @@ -0,0 +1,56 @@ +export type ErrorOptions = { cause?: unknown } + +export const ERR_JWKS_NO_MATCHING_KEY = 'ERR_JWKS_NO_MATCHING_KEY' +export const ERR_JWK_INVALID = 'ERR_JWK_INVALID' +export const ERR_JWK_NOT_FOUND = 'ERR_JWK_NOT_FOUND' +export const ERR_JWT_INVALID = 'ERR_JWT_INVALID' +export const ERR_JWT_CREATE = 'ERR_JWT_CREATE' +export const ERR_JWT_VERIFY = 'ERR_JWT_VERIFY' + +export class JwkError extends TypeError { + constructor( + message = 'JWK error', + public readonly code = ERR_JWK_INVALID, + options?: ErrorOptions, + ) { + super(message, options) + } +} + +export class JwtCreateError extends Error { + constructor( + message = 'Unable to create JWT', + public readonly code = ERR_JWT_CREATE, + options?: ErrorOptions, + ) { + super(message, options) + } + + static from(cause: unknown, code?: string, message?: string): JwtCreateError { + if (cause instanceof JwtCreateError) return cause + if (cause instanceof JwkError) { + return new JwtCreateError(message, cause.code, { cause }) + } + + return new JwtCreateError(message, code, { cause }) + } +} + +export class JwtVerifyError extends Error { + constructor( + message = 'Invalid JWT', + public readonly code = ERR_JWT_VERIFY, + options?: ErrorOptions, + ) { + super(message, options) + } + + static from(cause: unknown, code?: string, message?: string): JwtVerifyError { + if (cause instanceof JwtVerifyError) return cause + if (cause instanceof JwkError) { + return new JwtVerifyError(message, cause.code, { cause }) + } + + return new JwtVerifyError(message, code, { cause }) + } +} diff --git a/packages/oauth/jwk/src/index.ts b/packages/oauth/jwk/src/index.ts new file mode 100644 index 00000000000..e7296842cef --- /dev/null +++ b/packages/oauth/jwk/src/index.ts @@ -0,0 +1,10 @@ +export * from './alg.js' +export * from './errors.js' +export * from './jwk.js' +export * from './jwks.js' +export * from './jwt-decode.js' +export * from './jwt-verify.js' +export * from './jwt.js' +export * from './key.js' +export * from './keyset.js' +export * from './util.js' diff --git a/packages/oauth/jwk/src/jwk.ts b/packages/oauth/jwk/src/jwk.ts new file mode 100644 index 00000000000..de94f5afd4b --- /dev/null +++ b/packages/oauth/jwk/src/jwk.ts @@ -0,0 +1,141 @@ +import { z } from 'zod' + +export const keyUsageSchema = z.enum([ + 'sign', + 'verify', + 'encrypt', + 'decrypt', + 'wrapKey', + 'unwrapKey', + 'deriveKey', + 'deriveBits', +]) + +export type KeyUsage = z.infer + +/** + * The "use" and "key_ops" JWK members SHOULD NOT be used together; + * however, if both are used, the information they convey MUST be + * consistent. Applications should specify which of these members they + * use, if either is to be used by the application. + * + * @todo Actually check that "use" and "key_ops" are consistent when both are present. + * @see {@link https://datatracker.ietf.org/doc/html/rfc7517#section-4.3} + */ +export const jwkBaseSchema = z.object({ + kty: z.string().min(1), + alg: z.string().min(1).optional(), + kid: z.string().min(1).optional(), + ext: z.boolean().optional(), + use: z.enum(['sig', 'enc']).optional(), + key_ops: z.array(keyUsageSchema).optional(), + + x5c: z.array(z.string()).optional(), // X.509 Certificate Chain + x5t: z.string().min(1).optional(), // X.509 Certificate SHA-1 Thumbprint + 'x5t#S256': z.string().min(1).optional(), // X.509 Certificate SHA-256 Thumbprint + x5u: z.string().url().optional(), // X.509 URL +}) + +/** + * @todo: properly implement this + */ +export const jwkRsaKeySchema = jwkBaseSchema.extend({ + kty: z.literal('RSA'), + alg: z + .enum(['RS256', 'RS384', 'RS512', 'PS256', 'PS384', 'PS512']) + .optional(), + + n: z.string().min(1), // Modulus + e: z.string().min(1), // Exponent + + d: z.string().min(1).optional(), // Private Exponent + p: z.string().min(1).optional(), // First Prime Factor + q: z.string().min(1).optional(), // Second Prime Factor + dp: z.string().min(1).optional(), // First Factor CRT Exponent + dq: z.string().min(1).optional(), // Second Factor CRT Exponent + qi: z.string().min(1).optional(), // First CRT Coefficient + oth: z + .array( + z.object({ + r: z.string().optional(), + d: z.string().optional(), + t: z.string().optional(), + }), + ) + .nonempty() + + .optional(), // Other Primes Info +}) + +export const jwkEcKeySchema = jwkBaseSchema.extend({ + kty: z.literal('EC'), + alg: z.enum(['ES256', 'ES384', 'ES512']).optional(), + crv: z.enum(['P-256', 'P-384', 'P-521']), + + x: z.string().min(1), + y: z.string().min(1), + + d: z.string().min(1).optional(), // ECC Private Key +}) + +export const jwkEcSecp256k1KeySchema = jwkBaseSchema.extend({ + kty: z.literal('EC'), + alg: z.enum(['ES256K']).optional(), + crv: z.enum(['secp256k1']), + + x: z.string().min(1), + y: z.string().min(1), + + d: z.string().min(1).optional(), // ECC Private Key +}) + +export const jwkOkpKeySchema = jwkBaseSchema.extend({ + kty: z.literal('OKP'), + alg: z.enum(['EdDSA']).optional(), + crv: z.enum(['Ed25519', 'Ed448']), + + x: z.string().min(1), + d: z.string().min(1).optional(), // ECC Private Key +}) + +export const jwkSymKeySchema = jwkBaseSchema.extend({ + kty: z.literal('oct'), // Octet Sequence (used to represent symmetric keys) + alg: z.enum(['HS256', 'HS384', 'HS512']).optional(), + + k: z.string(), // Key Value (base64url encoded) +}) + +export const jwkUnknownKeySchema = jwkBaseSchema.extend({ + kty: z + .string() + .refine((v) => v !== 'RSA' && v !== 'EC' && v !== 'OKP' && v !== 'oct'), +}) + +export const jwkSchema = z.union([ + jwkUnknownKeySchema, + jwkRsaKeySchema, + jwkEcKeySchema, + jwkEcSecp256k1KeySchema, + jwkOkpKeySchema, + jwkSymKeySchema, +]) + +export type Jwk = z.infer + +export const jwkValidator = jwkSchema + .refine((k) => k.use != null || k.key_ops != null, 'use or key_ops required') + .refine( + (k) => + !k.use || + !k.key_ops || + k.key_ops.every((o) => + k.use === 'sig' + ? o === 'sign' || o === 'verify' + : o === 'encrypt' || o === 'decrypt', + ), + 'use and key_ops must be consistent', + ) + +export const jwkPubSchema = jwkValidator + .refine((k) => k.kid != null, 'kid is required') + .refine((k) => !('k' in k) && !('d' in k), 'private key not allowed') diff --git a/packages/oauth/jwk/src/jwks.ts b/packages/oauth/jwk/src/jwks.ts new file mode 100644 index 00000000000..25d5af452cd --- /dev/null +++ b/packages/oauth/jwk/src/jwks.ts @@ -0,0 +1,15 @@ +import { z } from 'zod' + +import { jwkPubSchema, jwkSchema } from './jwk.js' + +export const jwksSchema = z.object({ + keys: z.array(jwkSchema), +}) + +export type Jwks = z.infer + +export const jwksPubSchema = z.object({ + keys: z.array(jwkPubSchema), +}) + +export type JwksPub = z.infer diff --git a/packages/oauth/jwk/src/jwt-decode.ts b/packages/oauth/jwk/src/jwt-decode.ts new file mode 100644 index 00000000000..acf418f657f --- /dev/null +++ b/packages/oauth/jwk/src/jwt-decode.ts @@ -0,0 +1,27 @@ +import { ERR_JWT_INVALID, JwtVerifyError } from './errors.js' +import { + JwtHeader, + JwtPayload, + jwtHeaderSchema, + jwtPayloadSchema, +} from './jwt.js' +import { parseB64uJson } from './util.js' + +export function unsafeDecodeJwt(jwt: string): { + header: JwtHeader + payload: JwtPayload +} { + const { 0: headerEnc, 1: payloadEnc, length } = jwt.split('.') + if (length > 3 || length < 2) { + throw new JwtVerifyError(undefined, ERR_JWT_INVALID) + } + + const header = jwtHeaderSchema.parse(parseB64uJson(headerEnc!)) + if (length === 2 && header?.alg !== 'none') { + throw new JwtVerifyError(undefined, ERR_JWT_INVALID) + } + + const payload = jwtPayloadSchema.parse(parseB64uJson(payloadEnc!)) + + return { header, payload } +} diff --git a/packages/oauth/jwk/src/jwt-verify.ts b/packages/oauth/jwk/src/jwt-verify.ts new file mode 100644 index 00000000000..8c80c510d81 --- /dev/null +++ b/packages/oauth/jwk/src/jwt-verify.ts @@ -0,0 +1,22 @@ +import { JwtHeader, JwtPayload } from './jwt.js' +import { RequiredKey } from './util.js' + +export type VerifyOptions = { + audience?: string | readonly string[] + /** in seconds */ + clockTolerance?: number + issuer?: string | readonly string[] + /** in seconds */ + maxTokenAge?: number + subject?: string + typ?: string + currentDate?: Date + requiredClaims?: readonly C[] +} + +export type VerifyPayload = Record + +export type VerifyResult

= { + payload: RequiredKey

+ protectedHeader: JwtHeader +} diff --git a/packages/oauth/jwk/src/jwt.ts b/packages/oauth/jwk/src/jwt.ts new file mode 100644 index 00000000000..76bfdd34084 --- /dev/null +++ b/packages/oauth/jwk/src/jwt.ts @@ -0,0 +1,173 @@ +import { z } from 'zod' + +import { jwkPubSchema } from './jwk.js' +import { jwtCharsRefinement, segmentedStringRefinementFactory } from './util.js' + +export const signedJwtSchema = z + .string() + .superRefine(jwtCharsRefinement) + .superRefine(segmentedStringRefinementFactory(3)) + +export type SignedJwt = z.infer +export const isSignedJwt = (data: unknown): data is SignedJwt => + signedJwtSchema.safeParse(data).success + +export const unsignedJwtSchema = z + .string() + .superRefine(jwtCharsRefinement) + .superRefine(segmentedStringRefinementFactory(2)) + +export type UnsignedJwt = z.infer +export const isUnsignedJwt = (data: unknown): data is UnsignedJwt => + unsignedJwtSchema.safeParse(data).success + +/** + * @see {@link https://www.rfc-editor.org/rfc/rfc7515.html#section-4} + */ +export const jwtHeaderSchema = z.object({ + /** "alg" (Algorithm) Header Parameter */ + alg: z.string(), + /** "jku" (JWK Set URL) Header Parameter */ + jku: z.string().url().optional(), + /** "jwk" (JSON Web Key) Header Parameter */ + jwk: z + .object({ + kty: z.string(), + crv: z.string().optional(), + x: z.string().optional(), + y: z.string().optional(), + e: z.string().optional(), + n: z.string().optional(), + }) + .optional(), + /** "kid" (Key ID) Header Parameter */ + kid: z.string().optional(), + /** "x5u" (X.509 URL) Header Parameter */ + x5u: z.string().optional(), + /** "x5c" (X.509 Certificate Chain) Header Parameter */ + x5c: z.array(z.string()).optional(), + /** "x5t" (X.509 Certificate SHA-1 Thumbprint) Header Parameter */ + x5t: z.string().optional(), + /** "x5t#S256" (X.509 Certificate SHA-256 Thumbprint) Header Parameter */ + 'x5t#S256': z.string().optional(), + /** "typ" (Type) Header Parameter */ + typ: z.string().optional(), + /** "cty" (Content Type) Header Parameter */ + cty: z.string().optional(), + /** "crit" (Critical) Header Parameter */ + crit: z.array(z.string()).optional(), +}) + +export type JwtHeader = z.infer + +// https://www.iana.org/assignments/jwt/jwt.xhtml +export const jwtPayloadSchema = z.object({ + iss: z.string().optional(), + aud: z.union([z.string(), z.array(z.string()).nonempty()]).optional(), + sub: z.string().optional(), + exp: z.number().int().optional(), + nbf: z.number().int().optional(), + iat: z.number().int().optional(), + jti: z.string().optional(), + htm: z.string().optional(), + htu: z.string().optional(), + ath: z.string().optional(), + acr: z.string().optional(), + azp: z.string().optional(), + amr: z.array(z.string()).optional(), + // https://datatracker.ietf.org/doc/html/rfc7800 + cnf: z + .object({ + kid: z.string().optional(), // Key ID + jwk: jwkPubSchema.optional(), // JWK + jwe: z.string().optional(), // Encrypted key + jku: z.string().url().optional(), // JWK Set URI ("kid" should also be provided) + + // https://datatracker.ietf.org/doc/html/rfc9449#section-6.1 + jkt: z.string().optional(), + + // https://datatracker.ietf.org/doc/html/rfc8705 + 'x5t#S256': z.string().optional(), // X.509 Certificate SHA-256 Thumbprint + + // https://datatracker.ietf.org/doc/html/rfc9203 + osc: z.string().optional(), // OSCORE_Input_Material carrying the parameters for using OSCORE per-message security with implicit key confirmation + }) + .optional(), + + client_id: z.string().optional(), + + scope: z.string().optional(), + nonce: z.string().optional(), + + at_hash: z.string().optional(), + c_hash: z.string().optional(), + s_hash: z.string().optional(), + auth_time: z.number().int().optional(), + + // https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims + + // OpenID: "profile" scope + name: z.string().optional(), + family_name: z.string().optional(), + given_name: z.string().optional(), + middle_name: z.string().optional(), + nickname: z.string().optional(), + preferred_username: z.string().optional(), + gender: z.string().optional(), // OpenID only defines "male" and "female" without forbidding other values + picture: z.string().url().optional(), + profile: z.string().url().optional(), + website: z.string().url().optional(), + birthdate: z + .string() + .regex(/\d{4}-\d{2}-\d{2}/) // YYYY-MM-DD + .optional(), + zoneinfo: z + .string() + .regex(/^[A-Za-z0-9_/]+$/) + .optional(), + locale: z + .string() + .regex(/^[a-z]{2}(-[A-Z]{2})?$/) + .optional(), + updated_at: z.number().int().optional(), + + // OpenID: "email" scope + email: z.string().optional(), + email_verified: z.boolean().optional(), + + // OpenID: "phone" scope + phone_number: z.string().optional(), + phone_number_verified: z.boolean().optional(), + + // OpenID: "address" scope + // https://openid.net/specs/openid-connect-core-1_0.html#AddressClaim + address: z + .object({ + formatted: z.string().optional(), + street_address: z.string().optional(), + locality: z.string().optional(), + region: z.string().optional(), + postal_code: z.string().optional(), + country: z.string().optional(), + }) + .optional(), + + // https://datatracker.ietf.org/doc/html/rfc9396#section-14.2 + authorization_details: z + .array( + z + .object({ + type: z.string(), + // https://datatracker.ietf.org/doc/html/rfc9396#section-2.2 + locations: z.array(z.string()).optional(), + actions: z.array(z.string()).optional(), + datatypes: z.array(z.string()).optional(), + identifier: z.string().optional(), + privileges: z.array(z.string()).optional(), + }) + .passthrough(), + ) + .optional(), +}) + +export type JwtPayload = z.infer diff --git a/packages/oauth/jwk/src/key.ts b/packages/oauth/jwk/src/key.ts new file mode 100644 index 00000000000..519661615ed --- /dev/null +++ b/packages/oauth/jwk/src/key.ts @@ -0,0 +1,93 @@ +import { jwkAlgorithms } from './alg.js' +import { JwkError } from './errors.js' +import { Jwk, jwkSchema } from './jwk.js' +import { VerifyOptions, VerifyPayload, VerifyResult } from './jwt-verify.js' +import { JwtHeader, JwtPayload, SignedJwt } from './jwt.js' +import { cachedGetter } from './util.js' + +export abstract class Key { + constructor(protected readonly jwk: Readonly) { + // A key should always be used either for signing or encryption. + if (!jwk.use) throw new JwkError('Missing "use" Parameter value') + } + + get isPrivate(): boolean { + const { jwk } = this + if ('d' in jwk && jwk.d !== undefined) return true + if ('k' in jwk && jwk.k !== undefined) return true + return false + } + + get isSymetric(): boolean { + const { jwk } = this + if ('k' in jwk && jwk.k !== undefined) return true + return false + } + + get privateJwk(): Jwk | undefined { + return this.isPrivate ? this.jwk : undefined + } + + @cachedGetter + get publicJwk(): Jwk | undefined { + if (this.isSymetric) return undefined + if (this.isPrivate) { + const { d: _, ...jwk } = this.jwk as any + return jwk + } + return this.jwk + } + + @cachedGetter + get bareJwk(): Jwk | undefined { + if (this.isSymetric) return undefined + const { kty, crv, e, n, x, y } = this.jwk as any + return jwkSchema.parse({ crv, e, kty, n, x, y }) + } + + get use() { + return this.jwk.use! + } + + /** + * The (forced) algorithm to use. If not provided, the key will be usable with + * any of the algorithms in {@link algorithms}. + * + * @see {@link https://datatracker.ietf.org/doc/html/rfc7518#section-3.1 | "alg" (Algorithm) Header Parameter Values for JWS} + */ + get alg() { + return this.jwk.alg + } + + get kid() { + return this.jwk.kid + } + + get crv() { + return (this.jwk as { crv: undefined } | Extract).crv + } + + /** + * All the algorithms that this key can be used with. If `alg` is provided, + * this set will only contain that algorithm. + */ + @cachedGetter + get algorithms(): readonly string[] { + return Array.from(jwkAlgorithms(this.jwk)) + } + + /** + * Create a signed JWT + */ + abstract createJwt(header: JwtHeader, payload: JwtPayload): Promise + + /** + * Verify the signature, headers and payload of a JWT + * + * @throws {JwtVerifyError} if the JWT is invalid + */ + abstract verifyJwt< + P extends VerifyPayload = JwtPayload, + C extends string = string, + >(token: SignedJwt, options?: VerifyOptions): Promise> +} diff --git a/packages/oauth/jwk/src/keyset.ts b/packages/oauth/jwk/src/keyset.ts new file mode 100644 index 00000000000..dd2534c73b4 --- /dev/null +++ b/packages/oauth/jwk/src/keyset.ts @@ -0,0 +1,240 @@ +import { + ERR_JWKS_NO_MATCHING_KEY, + ERR_JWK_NOT_FOUND, + ERR_JWT_INVALID, + JwkError, + JwtCreateError, + JwtVerifyError, +} from './errors.js' +import { Jwk } from './jwk.js' +import { Jwks } from './jwks.js' +import { unsafeDecodeJwt } from './jwt-decode.js' +import { VerifyOptions, VerifyResult } from './jwt-verify.js' +import { JwtHeader, JwtPayload, SignedJwt } from './jwt.js' +import { Key } from './key.js' +import { + Override, + cachedGetter, + isDefined, + matchesAny, + preferredOrderCmp, +} from './util.js' + +export type JwtSignHeader = Override> + +export type JwtPayloadGetter

= ( + header: JwtHeader, + key: Key, +) => P | PromiseLike

+ +export type KeySearch = { + use?: 'sig' | 'enc' + kid?: string | string[] + alg?: string | string[] +} + +const extractPrivateJwk = (key: Key): Jwk | undefined => key.privateJwk +const extractPublicJwk = (key: Key): Jwk | undefined => key.publicJwk + +export class Keyset implements Iterable { + private readonly keys: readonly K[] + + constructor( + iterable: Iterable, + /** + * The preferred algorithms to use when signing a JWT using this keyset. + * + * @see {@link https://datatracker.ietf.org/doc/html/rfc7518#section-3.1} + */ + public readonly preferredSigningAlgorithms: readonly string[] = iterable instanceof + Keyset + ? [...iterable.preferredSigningAlgorithms] + : [ + // Prefer elliptic curve algorithms + 'EdDSA', + 'ES256K', + 'ES256', + // https://datatracker.ietf.org/doc/html/rfc7518#section-3.5 + 'PS256', + 'PS384', + 'PS512', + 'HS256', + 'HS384', + 'HS512', + ], + ) { + const keys: K[] = [] + + const kids = new Set() + for (const key of iterable) { + if (!key) continue + + keys.push(key) + + if (key.kid) { + if (kids.has(key.kid)) throw new JwkError(`Duplicate key: ${key.kid}`) + else kids.add(key.kid) + } + } + + this.keys = Object.freeze(keys) + } + + @cachedGetter + get signAlgorithms(): readonly string[] { + const algorithms = new Set() + for (const key of this) { + if (key.use !== 'sig') continue + for (const alg of key.algorithms) { + algorithms.add(alg) + } + } + return Object.freeze( + [...algorithms].sort(preferredOrderCmp(this.preferredSigningAlgorithms)), + ) + } + + @cachedGetter + get publicJwks(): Jwks { + return { + keys: Array.from(this, extractPublicJwk).filter(isDefined), + } + } + + @cachedGetter + get privateJwks(): Jwks { + return { + keys: Array.from(this, extractPrivateJwk).filter(isDefined), + } + } + + has(kid: string): boolean { + return this.keys.some((key) => key.kid === kid) + } + + get(search: KeySearch): K { + for (const key of this.list(search)) { + return key + } + + throw new JwkError( + `Key not found ${search.kid || search.alg || ''}`, + ERR_JWK_NOT_FOUND, + ) + } + + *list(search: KeySearch): Generator { + // Optimization: Empty string or empty array will not match any key + if (search.kid?.length === 0) return + if (search.alg?.length === 0) return + + for (const key of this) { + if (search.use && key.use !== search.use) continue + + if (Array.isArray(search.kid)) { + if (!key.kid || !search.kid.includes(key.kid)) continue + } else if (search.kid) { + if (key.kid !== search.kid) continue + } + + if (Array.isArray(search.alg)) { + if (!search.alg.some((a) => key.algorithms.includes(a))) continue + } else if (typeof search.alg === 'string') { + if (!key.algorithms.includes(search.alg)) continue + } + + yield key + } + } + + findKey({ kid, alg, use }: KeySearch): [key: Key, alg: string] { + const matchingKeys: Key[] = [] + + for (const key of this.list({ kid, alg, use })) { + // Not a signing key + if (!key.isPrivate) continue + + // Skip negotiation if a specific "alg" was provided + if (typeof alg === 'string') return [key, alg] + + matchingKeys.push(key) + } + + const isAllowedAlg = matchesAny(alg) + const candidates = matchingKeys.map( + (key) => [key, key.algorithms.filter(isAllowedAlg)] as const, + ) + + // Return the first candidates that matches the preferred algorithms + for (const prefAlg of this.preferredSigningAlgorithms) { + for (const [matchingKey, matchingAlgs] of candidates) { + if (matchingAlgs.includes(prefAlg)) return [matchingKey, prefAlg] + } + } + + // Return any candidate + for (const [matchingKey, matchingAlgs] of candidates) { + for (const alg of matchingAlgs) { + return [matchingKey, alg] + } + } + + throw new JwkError( + `No singing key found for ${kid || alg || use || ''}`, + ERR_JWK_NOT_FOUND, + ) + } + + [Symbol.iterator](): IterableIterator { + return this.keys.values() + } + + async createJwt( + { alg: sAlg, kid: sKid, ...header }: JwtSignHeader, + payload: JwtPayload | JwtPayloadGetter, + ): Promise { + try { + const [key, alg] = this.findKey({ alg: sAlg, kid: sKid, use: 'sig' }) + const protectedHeader = { ...header, alg, kid: key.kid } + + if (typeof payload === 'function') { + payload = await payload(protectedHeader, key) + } + + return await key.createJwt(protectedHeader, payload) + } catch (err) { + throw JwtCreateError.from(err) + } + } + + async verifyJwt< + P extends Record = JwtPayload, + C extends string = string, + >( + token: SignedJwt, + options?: VerifyOptions, + ): Promise & { key: K }> { + const { header } = unsafeDecodeJwt(token) + const { kid, alg } = header + + const errors: unknown[] = [] + + for (const key of this.list({ kid, alg })) { + try { + const result = await key.verifyJwt(token, options) + return { ...result, key } + } catch (err) { + errors.push(err) + } + } + + switch (errors.length) { + case 0: + throw new JwtVerifyError('No key matched', ERR_JWKS_NO_MATCHING_KEY) + case 1: + throw JwtVerifyError.from(errors[0], ERR_JWT_INVALID) + default: + throw JwtVerifyError.from(errors, ERR_JWT_INVALID) + } + } +} diff --git a/packages/oauth/jwk/src/util.ts b/packages/oauth/jwk/src/util.ts new file mode 100644 index 00000000000..805417eb94b --- /dev/null +++ b/packages/oauth/jwk/src/util.ts @@ -0,0 +1,181 @@ +import { base64url } from 'multiformats/bases/base64' +import { RefinementCtx, ZodIssueCode } from 'zod' + +// eslint-disable-next-line @typescript-eslint/ban-types +export type Simplify = { [K in keyof T]: T[K] } & {} +export type Override = Simplify> + +export type RequiredKey = Simplify< + string extends K + ? T + : { + [L in K]: Exclude + } & Omit +> + +export const isDefined = (i: T | undefined): i is T => i !== undefined + +export const preferredOrderCmp = + (order: readonly T[]) => + (a: T, b: T) => { + const aIdx = order.indexOf(a) + const bIdx = order.indexOf(b) + if (aIdx === bIdx) return 0 + if (aIdx === -1) return 1 + if (bIdx === -1) return -1 + return aIdx - bIdx + } + +export function matchesAny( + value: null | undefined | T | readonly T[], +): (v: unknown) => v is T { + return value == null + ? (v): v is T => true + : Array.isArray(value) + ? (v): v is T => value.includes(v) + : (v): v is T => v === value +} + +/** + * Decorator to cache the result of a getter on a class instance. + */ +export const cachedGetter = ( + target: (this: T) => V, + _context: ClassGetterDecoratorContext, +) => { + return function (this: T) { + const value = target.call(this) + Object.defineProperty(this, target.name, { + get: () => value, + enumerable: true, + configurable: true, + }) + return value + } +} + +const decoder = new TextDecoder() +export function parseB64uJson(input: string): unknown { + const inputBytes = base64url.baseDecode(input) + const json = decoder.decode(inputBytes) + return JSON.parse(json) +} + +/** + * @example + * ```ts + * // jwtSchema will only allow base64url chars & "." (dot) + * const jwtSchema = z.string().superRefine(jwtCharsRefinement) + * ``` + */ +export const jwtCharsRefinement = (data: string, ctx: RefinementCtx): void => { + // Note: this is a hot path, let's avoid using a RegExp + let char + + for (let i = 0; i < data.length; i++) { + char = data.charCodeAt(i) + + if ( + // Base64 URL encoding (most frequent) + (65 <= char && char <= 90) || // A-Z + (97 <= char && char <= 122) || // a-z + (48 <= char && char <= 57) || // 0-9 + char === 45 || // - + char === 95 || // _ + // Boundary (least frequent, check last) + char === 46 // . + ) { + // continue + } else { + // Invalid char might be a surrogate pair + const invalidChar = String.fromCodePoint(data.codePointAt(i)!) + return ctx.addIssue({ + code: ZodIssueCode.custom, + message: `Invalid character "${invalidChar}" in JWT at position ${i}`, + }) + } + } +} + +/** + * @example + * ```ts + * type SegmentedString3 = SegmentedString<3> // `${string}.${string}.${string}` + * type SegmentedString4 = SegmentedString<4> // `${string}.${string}.${string}.${string}` + * ``` + * + * @note + * This utility only provides one way type safety (A SegmentedString<4> can be + * assigned to SegmentedString<3> but not vice versa). The purpose of this + * utility is to improve DX by avoiding as many potential errors as build time. + * DO NOT rely on this to enforce security or data integrity. + */ +type SegmentedString< + C extends number, + Acc extends string[] = [string], +> = Acc['length'] extends C + ? `${Acc[0]}` + : `${Acc[0]}.${SegmentedString}` + +/** + * @example + * ```ts + * const jwtSchema = z.string().superRefine(segmentedStringRefinementFactory(3)) + * type Jwt = z.infer // `${string}.${string}.${string}` + * ``` + */ +export const segmentedStringRefinementFactory = ( + count: C, + minPartLength = 2, +) => { + if (!Number.isFinite(count) || count < 1 || (count | 0) !== count) { + throw new TypeError(`Count must be a natural number (got ${count})`) + } + + const minTotalLength = count * minPartLength + (count - 1) + const errorPrefix = `Invalid JWT format` + + return (data: string, ctx: RefinementCtx): data is SegmentedString => { + if (data.length < minTotalLength) { + ctx.addIssue({ + code: ZodIssueCode.custom, + message: `${errorPrefix}: too short`, + }) + return false + } + let currentStart = 0 + for (let i = 0; i < count - 1; i++) { + const nextDot = data.indexOf('.', currentStart) + if (nextDot === -1) { + ctx.addIssue({ + code: ZodIssueCode.custom, + message: `${errorPrefix}: expected ${count} segments, got ${i + 1}`, + }) + return false + } + if (nextDot - currentStart < minPartLength) { + ctx.addIssue({ + code: ZodIssueCode.custom, + message: `${errorPrefix}: segment ${i + 1} is too short`, + }) + return false + } + currentStart = nextDot + 1 + } + if (data.indexOf('.', currentStart) !== -1) { + ctx.addIssue({ + code: ZodIssueCode.custom, + message: `${errorPrefix}: too many segments`, + }) + return false + } + if (data.length - currentStart < minPartLength) { + ctx.addIssue({ + code: ZodIssueCode.custom, + message: `${errorPrefix}: last segment is too short`, + }) + return false + } + return true + } +} diff --git a/packages/oauth/jwk/tsconfig.build.json b/packages/oauth/jwk/tsconfig.build.json new file mode 100644 index 00000000000..9f3c5252d4c --- /dev/null +++ b/packages/oauth/jwk/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "outDir": "dist", + "rootDir": "src" + }, + "include": ["src"] +} diff --git a/packages/oauth/jwk/tsconfig.json b/packages/oauth/jwk/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/oauth/jwk/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/oauth/oauth-client-browser/example/.gitignore b/packages/oauth/oauth-client-browser/example/.gitignore new file mode 100644 index 00000000000..1521c8b7652 --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/.gitignore @@ -0,0 +1 @@ +dist diff --git a/packages/oauth/oauth-client-browser/example/.postcssrc.yml b/packages/oauth/oauth-client-browser/example/.postcssrc.yml new file mode 100644 index 00000000000..0114fbc9e78 --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/.postcssrc.yml @@ -0,0 +1,3 @@ +plugins: + tailwindcss: {} + autoprefixer: {} diff --git a/packages/oauth/oauth-client-browser/example/package.json b/packages/oauth/oauth-client-browser/example/package.json new file mode 100644 index 00000000000..3dbc1ca591c --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/packages/oauth/oauth-client-browser/example/rollup.config.js b/packages/oauth/oauth-client-browser/example/rollup.config.js new file mode 100644 index 00000000000..bea6c33c379 --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/rollup.config.js @@ -0,0 +1,91 @@ +/* eslint-env node */ + +import { defineConfig } from 'rollup' + +import commonjs from '@rollup/plugin-commonjs' +import html, { makeHtmlAttributes } from '@rollup/plugin-html' +import json from '@rollup/plugin-json' +import nodeResolve from '@rollup/plugin-node-resolve' +import replace from '@rollup/plugin-replace' +import typescript from '@rollup/plugin-typescript' +import postcss from 'rollup-plugin-postcss' +import serve from 'rollup-plugin-serve' + +export default defineConfig((commandLineArguments) => { + const NODE_ENV = + process.env['NODE_ENV'] ?? + (commandLineArguments.watch ? 'development' : 'production') + + return { + input: 'src/main.tsx', + output: { dir: 'dist', sourcemap: true }, + plugins: [ + nodeResolve({ preferBuiltins: false, browser: true }), + commonjs(), + postcss({ config: true, extract: true, minimize: false }), + json(), + typescript({ + tsconfig: './tsconfig.build.json', + outputToFilesystem: true, + }), + replace({ + preventAssignment: true, + values: { 'process.env.NODE_ENV': JSON.stringify(NODE_ENV) }, + }), + html({ + title: 'OAuth Client Example', + template: (templateOptions) => { + // https://github.com/rollup/plugins/pull/1718 + if (!templateOptions) throw new Error('No template options provided') + const { attributes, files, meta, publicPath, title } = templateOptions + + return ` + + + + ${meta + .map((attrs) => ``) + .join('\n')} + + ${title} + ${files.css + .map( + (asset) => + ``, + ) + .join('\n')} + + +

+ ${files.js + .map( + (asset) => + ``, + ) + .join('\n')} + + + ` + }, + }), + commandLineArguments.watch && + serve({ + contentBase: 'dist', + port: 8080, + headers: { 'Cache-Control': 'no-store' }, + }), + ], + onwarn(warning, warn) { + // 'use client' directives are fine + if (warning.code === 'MODULE_LEVEL_DIRECTIVE') return + warn(warning) + }, + } +}) diff --git a/packages/oauth/oauth-client-browser/example/src/app.tsx b/packages/oauth/oauth-client-browser/example/src/app.tsx new file mode 100644 index 00000000000..af9db9d1b7b --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/src/app.tsx @@ -0,0 +1,65 @@ +import { BrowserOAuthClient } from '@atproto/oauth-client-browser' +import { useCallback, useState } from 'react' + +import LoginForm from './login-form' +import { useOAuth } from './oauth' + +const client = new BrowserOAuthClient({ + plcDirectoryUrl: 'http://localhost:2582', // dev-env + handleResolver: 'http://localhost:2584', // dev-env +}) + +function App() { + const { agent, signedIn, signOut, loading, signIn } = useOAuth(client) + const [profile, setProfile] = useState<{ + value: { displayName?: string } + } | null>(null) + + const loadProfile = useCallback(async () => { + if (!agent) return + + const info = await agent.getInfo() + console.log('info', info) + + // A call that requires to be authenticated + console.log( + await agent + .request( + '/xrpc/com.atproto.server.getServiceAuth?' + + new URLSearchParams({ aud: info.sub }).toString(), + ) + .then((r) => r.json()), + ) + + // This call does not require authentication + const profile = await agent + .request( + '/xrpc/com.atproto.repo.getRecord?' + + new URLSearchParams({ + repo: info.sub, + collection: 'app.bsky.actor.profile', + rkey: 'self', + }).toString(), + ) + .then((r) => r.json()) + console.log(profile) + + setProfile(profile.data) + }, [agent]) + + return signedIn ? ( +
+

Logged in!

+ + +
{profile ? JSON.stringify(profile, undefined, 2) : null}
+
+ + +
+ ) : ( + + ) +} + +export default App diff --git a/packages/oauth/oauth-client-browser/example/src/index.css b/packages/oauth/oauth-client-browser/example/src/index.css new file mode 100644 index 00000000000..b5c61c95671 --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/src/index.css @@ -0,0 +1,3 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/packages/oauth/oauth-client-browser/example/src/login-form.tsx b/packages/oauth/oauth-client-browser/example/src/login-form.tsx new file mode 100644 index 00000000000..ac0d11b21fe --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/src/login-form.tsx @@ -0,0 +1,103 @@ +import { FormEvent, useEffect, useState } from 'react' + +/** + * @returns Nice tailwind css form asking to enter either a handle or the host + * to use to login. + */ +export default function LoginForm({ + onLogin, + loading, + error = null, + ...props +}: { + loading?: boolean + error?: null | string + onLogin: (input: string, options?: { display?: 'popup' | 'page' }) => void +} & React.HTMLAttributes) { + const [value, setValue] = useState('') + const [display, setDisplay] = useState<'popup' | 'page'>('popup') + const [localError, setLocalError] = useState(error) + + useEffect(() => { + setLocalError(null) + }, [value]) + + useEffect(() => { + setLocalError(error) + }, [error]) + + const onSubmit = (e: FormEvent) => { + e.preventDefault() + if (loading) return + + if (value.startsWith('did:')) { + if (value.length > 5) onLogin(value, { display }) + else setLocalError('DID must be at least 6 characters') + return + } + + if (value.startsWith('https://') || value.startsWith('http://')) { + try { + const url = new URL(value) + if (value !== url.origin) throw new Error('PDS URL must be a origin') + onLogin(value, { display }) + } catch (err) { + setLocalError((err as any)?.message || String(err)) + } + return + } + + if (value.includes('.') && value.length > 3) { + const handle = value.startsWith('@') ? value.slice(1) : value + if (handle.length > 3) onLogin(handle, { display }) + else setLocalError('Handle must be at least 4 characters') + return + } + + setLocalError('Please provide a valid handle, DID or PDS URL') + } + + return ( +
+
+
+ +
+ + {/*
*/} + +
+ setValue(e.target.value)} + /> + +
+
+ + {localError ? ( +
{localError}
+ ) : null} +
+ ) +} diff --git a/packages/oauth/oauth-client-browser/example/src/main.tsx b/packages/oauth/oauth-client-browser/example/src/main.tsx new file mode 100644 index 00000000000..d8db51390fa --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/src/main.tsx @@ -0,0 +1,12 @@ +import './index.css' + +import React from 'react' +import ReactDOM from 'react-dom/client' + +import App from './app' + +ReactDOM.createRoot(document.getElementById('root')!).render( + + + , +) diff --git a/packages/oauth/oauth-client-browser/example/src/oauth.ts b/packages/oauth/oauth-client-browser/example/src/oauth.ts new file mode 100644 index 00000000000..b636c237aba --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/src/oauth.ts @@ -0,0 +1,119 @@ +import { OAuthAgent, AuthorizeOptions } from '@atproto/oauth-client' +import { + BrowserOAuthClient, + LoginContinuedInParentWindowError, +} from '@atproto/oauth-client-browser' +import { useCallback, useEffect, useRef, useState } from 'react' + +const CURRENT_AUTHENTICATED_SUB = 'CURRENT_AUTHENTICATED_SUB' + +export function useOAuth(client: BrowserOAuthClient) { + const [agent, setAgent] = useState(null) + const [loading, setLoading] = useState(true) + + useEffect(() => { + // Ignore init step + if (loading) return + + if (agent) { + localStorage.setItem(CURRENT_AUTHENTICATED_SUB, agent.sub) + } else { + localStorage.removeItem(CURRENT_AUTHENTICATED_SUB) + } + }, [loading, agent]) + + const clientRef = useRef() + useEffect(() => { + // In strict mode, we don't want to reinitialize the client if it's the same + if (clientRef.current === client) return + clientRef.current = client + + setLoading(true) + setAgent(null) + + const subToLoad = + localStorage.getItem(CURRENT_AUTHENTICATED_SUB) || undefined + + client + .init(subToLoad) + .then(async (r) => { + if (clientRef.current !== client) return + + setAgent(r?.agent || null) + }) + .catch((err) => { + console.error('Failed to init:', err) + + if (clientRef.current !== client) return + if (err instanceof LoginContinuedInParentWindowError) return + + localStorage.removeItem(CURRENT_AUTHENTICATED_SUB) + setAgent(null) + }) + .finally(() => { + if (clientRef.current !== client) return + + setLoading(false) + }) + }, [client]) + + useEffect(() => { + if (!agent) return + + const clear = ({ detail }: { detail: { sub: string } }) => { + if (detail.sub === agent.sub) { + setAgent(null) + setLoading(true) + } + } + + client.addEventListener('deleted', clear) + + return () => { + client.removeEventListener('deleted', clear) + } + }, [client, agent]) + + const signOut = useCallback(async () => { + if (!agent) return + + setAgent(null) + setLoading(true) + + try { + await agent.signOut() + } catch (err) { + console.error('Failed to clear credentials', err) + throw err + } finally { + setLoading(false) + } + }, [agent]) + + const signIn = useCallback( + async (input: string, options?: AuthorizeOptions) => { + if (agent) return + + setLoading(true) + + try { + const agent = await client.signIn(input, options) + setAgent(agent) + } catch (err) { + console.error('Failed to login', err) + throw err + } finally { + setLoading(false) + } + }, + [agent, client], + ) + + return { + agent, + loading, + signedIn: agent != null, + signIn, + signOut, + } +} diff --git a/packages/oauth/oauth-client-browser/example/tailwind.config.js b/packages/oauth/oauth-client-browser/example/tailwind.config.js new file mode 100644 index 00000000000..7141e4528c6 --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/tailwind.config.js @@ -0,0 +1,8 @@ +/** @type {import('tailwindcss').Config} */ +export default { + content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'], + theme: { + extend: {}, + }, + plugins: [], +} diff --git a/packages/oauth/oauth-client-browser/example/tsconfig.build.json b/packages/oauth/oauth-client-browser/example/tsconfig.build.json new file mode 100644 index 00000000000..5d2476f34aa --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/tsconfig.build.json @@ -0,0 +1,11 @@ +{ + "extends": [ + "../../../../tsconfig/browser.json", + "../../../../tsconfig/bundler.json" + ], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src/**/*.ts", "./src/**/*.tsx"] +} diff --git a/packages/oauth/oauth-client-browser/example/tsconfig.json b/packages/oauth/oauth-client-browser/example/tsconfig.json new file mode 100644 index 00000000000..ad9365d269b --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/tsconfig.json @@ -0,0 +1,7 @@ +{ + "include": [], + "references": [ + { "path": "./tsconfig.build.json" }, + { "path": "./tsconfig.tools.json" } + ] +} diff --git a/packages/oauth/oauth-client-browser/example/tsconfig.tools.json b/packages/oauth/oauth-client-browser/example/tsconfig.tools.json new file mode 100644 index 00000000000..1588d29506e --- /dev/null +++ b/packages/oauth/oauth-client-browser/example/tsconfig.tools.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../../tsconfig/node.json", + "compilerOptions": { + "rootDir": ".", + "noEmit": true + }, + "include": ["./*.js", "./*.ts"] +} diff --git a/packages/oauth/oauth-client-browser/package.json b/packages/oauth/oauth-client-browser/package.json new file mode 100644 index 00000000000..a4fb66717b3 --- /dev/null +++ b/packages/oauth/oauth-client-browser/package.json @@ -0,0 +1,72 @@ +{ + "name": "@atproto/oauth-client-browser", + "version": "0.0.1", + "license": "MIT", + "description": "ATPROTO OAuth client for the browser (relies on WebCrypto & Indexed DB)", + "keywords": [ + "atproto", + "oauth", + "client", + "browser", + "webcrypto", + "indexed", + "db" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/oauth-client-browser" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "files": [ + "dist" + ], + "dependencies": { + "@atproto-labs/did-resolver": "workspace:*", + "@atproto-labs/handle-resolver": "workspace:*", + "@atproto-labs/simple-store": "workspace:*", + "@atproto/did": "workspace:*", + "@atproto/jwk": "workspace:*", + "@atproto/jwk-webcrypto": "workspace:*", + "@atproto/oauth-client": "workspace:*", + "@atproto/oauth-types": "workspace:*" + }, + "devDependencies": { + "@atproto/api": "workspace:*", + "@atproto/oauth-client": "workspace:*", + "@atproto/oauth-client-browser": "workspace:*", + "@atproto/oauth-types": "workspace:*", + "@atproto/xrpc": "workspace:*", + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-html": "^1.0.3", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@rollup/plugin-terser": "^0.4.4", + "@rollup/plugin-typescript": "^11.1.6", + "@types/react": "^18.2.50", + "@types/react-dom": "^18.2.18", + "autoprefixer": "^10.4.17", + "postcss": "^8.4.33", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "rollup": "^4.13.0", + "rollup-plugin-postcss": "^4.0.2", + "rollup-plugin-serve": "^1.1.1", + "tailwindcss": "^3.4.1", + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json", + "dev": "cd ./example && rollup --config rollup.config.js --watch" + } +} diff --git a/packages/oauth/oauth-client-browser/src/browser-oauth-client.ts b/packages/oauth/oauth-client-browser/src/browser-oauth-client.ts new file mode 100644 index 00000000000..240531e5b1e --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/browser-oauth-client.ts @@ -0,0 +1,514 @@ +import { HandleResolver } from '@atproto-labs/handle-resolver' +import { + AuthorizeOptions, + OAuthAgent, + OAuthCallbackError, + OAuthClient, + Session, + TokenSet, +} from '@atproto/oauth-client' +import { + OAuthClientId, + OAuthClientMetadataInput, + OAuthResponseMode, + atprotoLoopbackClientMetadata, + isOAuthClientIdDiscoverable, + isOAuthClientIdLoopback, + oauthClientMetadataSchema, +} from '@atproto/oauth-types' + +import { + BrowserOAuthDatabase, + DatabaseStore, +} from './browser-oauth-database.js' +import { BrowserRuntimeImplementation } from './browser-runtime-implementation.js' +import { LoginContinuedInParentWindowError } from './errors.js' +import { buildLoopbackClientId } from './util.js' + +export type BrowserOAuthClientOptions = { + clientMetadata?: OAuthClientMetadataInput + handleResolver?: HandleResolver | string | URL + responseMode?: OAuthResponseMode + plcDirectoryUrl?: string | URL + + crypto?: typeof globalThis.crypto + fetch?: typeof globalThis.fetch +} + +type EventDetails = { + updated: TokenSet + deleted: { sub: string } +} + +type CustomEventListener = ( + event: CustomEvent, +) => void + +const initEvent = ( + type: T, + detail: EventDetails[T], +) => new CustomEvent(type, { detail, cancelable: false, bubbles: false }) + +const NAMESPACE = `@@atproto/oauth-client-browser` + +//- Popup channel + +const POPUP_CHANNEL_NAME = `${NAMESPACE}(popup-channel)` +const POPUP_STATE_PREFIX = `${NAMESPACE}(popup-state):` + +type PopupChannelResultData = { + key: string + result: PromiseRejectedResult | PromiseFulfilledResult +} + +type PopupChannelAckData = { + key: string + ack: true +} + +type PopupChannelData = PopupChannelResultData | PopupChannelAckData + +//- Deleted channel + +const deletedChannel = new BroadcastChannel(`${NAMESPACE}(deleted-channel)`) + +type WrappedSessionStore = Disposable & DatabaseStore +const wrapSessionStore = ( + dbStore: DatabaseStore, + eventTarget: EventTarget, +) => { + const store: WrappedSessionStore = { + getKeys: async () => { + return dbStore.getKeys() + }, + get: async (sub) => { + return dbStore.get(sub) + }, + set: async (sub, session) => { + await dbStore.set(sub, session) + + eventTarget.dispatchEvent(initEvent('updated', session.tokenSet)) + }, + del: async (sub) => { + await dbStore.del(sub) + deletedChannel.postMessage(sub) + + eventTarget.dispatchEvent(initEvent('deleted', { sub })) + }, + clear: async () => { + await dbStore.clear?.() + }, + [Symbol.dispose]: () => { + deletedChannel.removeEventListener('message', onMessage) + }, + } + + const onMessage = (event: MessageEvent) => { + // Listen for "deleted" events from other windows. The content will already + // have been deleted from the store so we only need to notify the listeners. + if (event.source !== window) { + const sub = event.data + eventTarget.dispatchEvent(initEvent('deleted', { sub })) + } + } + + deletedChannel.addEventListener('message', onMessage) + + return store +} + +export type BrowserOAuthClientLoadOptions = Omit< + BrowserOAuthClientOptions, + 'clientMetadata' +> & { + clientId: OAuthClientId + signal?: AbortSignal +} + +export class BrowserOAuthClient extends OAuthClient { + static async load({ clientId, ...options }: BrowserOAuthClientLoadOptions) { + if (isOAuthClientIdLoopback(clientId)) { + return new BrowserOAuthClient({ + clientMetadata: atprotoLoopbackClientMetadata(clientId), + ...options, + }) + } else if (isOAuthClientIdDiscoverable(clientId)) { + const fetch = options?.fetch ?? globalThis.fetch + const request = new Request(clientId, { + redirect: 'error', + signal: options.signal, + }) + const response = await fetch(request) + + if (response.status !== 200) { + throw new TypeError( + `Failed to fetch client metadata: ${response.status}`, + ) + } + + const mime = response.headers.get('content-type')?.split(';')[0].trim() + if (mime !== 'application/json') { + throw new TypeError(`Invalid content type: ${mime}`) + } + + const json: unknown = await response.json() + + options.signal?.throwIfAborted() + + return new BrowserOAuthClient({ + clientMetadata: oauthClientMetadataSchema.parse(json), + ...options, + }) + } else { + throw new TypeError(`Invalid client id: ${clientId}`) + } + } + + readonly sessionStore: WrappedSessionStore + + private readonly eventTarget: EventTarget + private readonly database: BrowserOAuthDatabase + + constructor({ + clientMetadata, + handleResolver = 'https://bsky.social', + // "fragment" is safer as it is not sent to the server + responseMode = 'fragment', + plcDirectoryUrl = 'https://plc.directory', + crypto = globalThis.crypto, + fetch = globalThis.fetch, + }: BrowserOAuthClientOptions = {}) { + const database = new BrowserOAuthDatabase() + + const eventTarget = new EventTarget() + const sessionStore = wrapSessionStore( + database.getSessionStore(), + eventTarget, + ) + + super({ + clientMetadata: + clientMetadata == null + ? atprotoLoopbackClientMetadata( + buildLoopbackClientId(window.location), + ) + : clientMetadata, + responseMode, + fetch, + runtimeImplementation: new BrowserRuntimeImplementation(crypto), + plcDirectoryUrl, + handleResolver, + sessionStore, + stateStore: database.getStateStore(), + + didCache: database.getDidCache(), + handleCache: database.getHandleCache(), + dpopNonceCache: database.getDpopNonceCache(), + authorizationServerMetadataCache: + database.getAuthorizationServerMetadataCache(), + protectedResourceMetadataCache: + database.getProtectedResourceMetadataCache(), + }) + + this.sessionStore = sessionStore + + this.eventTarget = eventTarget + this.database = database + + fixLocation(this.clientMetadata) + } + + addEventListener( + type: T, + callback: CustomEventListener | null, + options?: AddEventListenerOptions | boolean, + ) { + this.eventTarget.addEventListener(type, callback as EventListener, options) + } + + removeEventListener( + type: string, + callback: CustomEventListener | null, + options?: EventListenerOptions | boolean, + ) { + this.eventTarget.removeEventListener( + type, + callback as EventListener, + options, + ) + } + + async restoreAll() { + const subs = await this.sessionStore.getKeys() + return Object.fromEntries( + await Promise.all( + subs.map(async (sub) => [sub, await this.restore(sub, false)] as const), + ), + ) + } + + async init(sub?: string, refresh?: boolean) { + const signInResult = await this.signInCallback() + if (signInResult) { + return signInResult + } else if (sub) { + const agent = await this.restore(sub, refresh) + return { agent } + } + } + + async signIn( + input: string, + options?: AuthorizeOptions & { signal?: AbortSignal }, + ) { + if (options?.display === 'popup') { + return this.signInPopup(input, options) + } else { + return this.signInRedirect(input, options) + } + } + + async signInRedirect(input: string, options?: AuthorizeOptions) { + const url = await this.authorize(input, options) + + window.location.href = url.href + + // back-forward cache + return new Promise((resolve, reject) => { + setTimeout(() => reject(new Error('User navigated back')), 5e3) + }) + } + + async signInPopup( + input: string, + options?: Omit & { signal?: AbortSignal }, + ): Promise { + // Open new window asap to prevent popup busting by browsers + const popupFeatures = 'width=600,height=600,menubar=no,toolbar=no' + let popup: Window | null = window.open( + 'about:blank', + '_blank', + popupFeatures, + ) + + const stateKey = `${Math.random().toString(36).slice(2)}` + + const url = await this.authorize(input, { + ...options, + state: `${POPUP_STATE_PREFIX}${stateKey}`, + display: options?.display ?? 'popup', + }) + + options?.signal?.throwIfAborted() + + if (popup) { + popup.window.location.href = url.href + } else { + popup = window.open(url.href, '_blank', popupFeatures) + } + + popup?.focus() + + return new Promise((resolve, reject) => { + const popupChannel = new BroadcastChannel(POPUP_CHANNEL_NAME) + + const cleanup = () => { + clearTimeout(timeout) + popupChannel.removeEventListener('message', onMessage) + popupChannel.close() + options?.signal?.removeEventListener('abort', cancel) + popup?.close() + } + + const cancel = () => { + // @TODO: Store fact that the request was cancelled, allowing any + // callback (e.g. in the popup) to revoke the session or credentials. + + reject(new Error(options?.signal?.aborted ? 'Aborted' : 'Timeout')) + cleanup() + } + + options?.signal?.addEventListener('abort', cancel) + + const timeout = setTimeout(cancel, 5 * 60e3) + + const onMessage = async ({ data }: MessageEvent) => { + if (data.key !== stateKey) return + if (!('result' in data)) return + + // Send acknowledgment to popup window + popupChannel.postMessage({ key: stateKey, ack: true }) + + cleanup() + + const { result } = data + if (result.status === 'fulfilled') { + const sub = result.value + try { + options?.signal?.throwIfAborted() + resolve(await this.restore(sub)) + } catch (err) { + reject(err) + void this.revoke(sub) + } + } else { + const { message, params } = result.reason + reject(new OAuthCallbackError(new URLSearchParams(params), message)) + } + } + + popupChannel.addEventListener('message', onMessage) + }) + } + + private readCallbackParams(): URLSearchParams | null { + const params = + this.responseMode === 'fragment' + ? new URLSearchParams(location.hash.slice(1)) + : new URLSearchParams(location.search) + + // Only if the current URL contains a valid oauth response params + if (!params.has('state') || !(params.has('code') || params.has('error'))) { + return null + } + + const matchesLocation = (url: URL) => + location.origin === url.origin && location.pathname === url.pathname + const redirectUrls = this.clientMetadata.redirect_uris.map( + (uri) => new URL(uri), + ) + + // Only if the current URL is one of the redirect_uris + if (!redirectUrls.some(matchesLocation)) return null + + return params + } + + async signInCallback() { + const params = this.readCallbackParams() + + // Not a (valid) OAuth redirect + if (!params) return null + + // Replace the current history entry without the params (this will prevent + // the following code to run again if the user refreshes the page) + history.replaceState(null, '', location.pathname) + + const sendResult = (message: PopupChannelResultData) => { + const popupChannel = new BroadcastChannel(POPUP_CHANNEL_NAME) + + return new Promise((resolve) => { + const cleanup = (result: boolean) => { + clearTimeout(timer) + popupChannel.removeEventListener('message', onMessage) + popupChannel.close() + resolve(result) + } + + const onTimeout = () => { + cleanup(false) + } + + const onMessage = ({ data }: MessageEvent) => { + if ('ack' in data && message.key === data.key) cleanup(true) + } + + popupChannel.addEventListener('message', onMessage) + popupChannel.postMessage(message) + // Receiving of "ack" should be very fast, giving it 500 ms anyway + const timer = setTimeout(onTimeout, 500) + }) + } + + return this.callback(params) + .then(async (result) => { + if (result.state?.startsWith(POPUP_STATE_PREFIX)) { + const receivedByParent = await sendResult({ + key: result.state.slice(POPUP_STATE_PREFIX.length), + result: { + status: 'fulfilled', + value: result.agent.sub, + }, + }) + + // Revoke the credentials if the parent window was closed + if (!receivedByParent) await result.agent.signOut() + + throw new LoginContinuedInParentWindowError() // signInPopup + } + + return result + }) + .catch(async (err) => { + if ( + err instanceof OAuthCallbackError && + err.state?.startsWith(POPUP_STATE_PREFIX) + ) { + await sendResult({ + key: err.state.slice(POPUP_STATE_PREFIX.length), + result: { + status: 'rejected', + reason: { + message: err.message, + params: Array.from(err.params.entries()), + }, + }, + }) + + throw new LoginContinuedInParentWindowError() // signInPopup + } + + // Most probable cause at this point is that the "state" parameter is + // invalid. + throw err + }) + .catch((err) => { + if (err instanceof LoginContinuedInParentWindowError) { + // parent will also try to close the popup + window.close() + } + + throw err + }) + } + + async [Symbol.asyncDispose]() { + // TODO This should be implemented using a DisposableStack + await this.sessionStore[Symbol.dispose]() + await this.database[Symbol.asyncDispose]() + } +} + +/** + * Since "localhost" is often used either in IP mode or in hostname mode, + * and because the redirect uris must use the IP mode, we need to make sure + * that the current location url is not using "localhost". + * + * This is required for the IndexedDB to work properly. Indeed, the IndexedDB + * is shared by origin, so we must ensure to be on the same origin as the + * redirect uris. + */ +function fixLocation(clientMetadata: OAuthClientMetadataInput) { + if (clientMetadata.client_id !== 'http://localhost/') return + if (window.location.hostname !== 'localhost') return + + const locationUrl = new URL(window.location.href) + + for (const uri of clientMetadata.redirect_uris) { + const url = new URL(uri) + if ( + url.port === locationUrl.port && + url.protocol === locationUrl.protocol && + (url.hostname === '127.0.0.1' || url.hostname === '[::1]') + ) { + window.location.hostname = url.hostname + + // Prevent APP from loading on the wrong hostname + throw new Error('Redirecting to loopback IP...') + } + } + + throw new Error( + `Please use the loopback IP address instead of ${locationUrl}`, + ) +} diff --git a/packages/oauth/oauth-client-browser/src/browser-oauth-database.ts b/packages/oauth/oauth-client-browser/src/browser-oauth-database.ts new file mode 100644 index 00000000000..feff7d87718 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/browser-oauth-database.ts @@ -0,0 +1,267 @@ +import { ResolvedHandle } from '@atproto-labs/handle-resolver' +import { SimpleStore, Value } from '@atproto-labs/simple-store' +import { DidDocument } from '@atproto/did' +import { Key } from '@atproto/jwk' +import { WebcryptoKey } from '@atproto/jwk-webcrypto' +import { InternalStateData, Session, TokenSet } from '@atproto/oauth-client' +import { + OAuthAuthorizationServerMetadata, + OAuthProtectedResourceMetadata, +} from '@atproto/oauth-types' + +import { DB, DBObjectStore } from './indexed-db/index.js' +import { TupleUnion } from './util.js' + +type Item = { + value: V + expiresAt?: string // ISO Date +} + +type EncodedKey = { + keyId: string + keyPair: CryptoKeyPair +} + +function encodeKey(key: Key): EncodedKey { + if (!(key instanceof WebcryptoKey) || !key.kid) { + throw new Error('Invalid key object') + } + return { + keyId: key.kid, + keyPair: key.cryptoKeyPair, + } +} + +async function decodeKey(encoded: EncodedKey): Promise { + return WebcryptoKey.fromKeypair(encoded.keyPair, encoded.keyId) +} + +export type Schema = { + state: Item<{ + dpopKey: EncodedKey + + iss: string + nonce: string + verifier?: string + appState?: string + }> + session: Item<{ + dpopKey: EncodedKey + + tokenSet: TokenSet + }> + + didCache: Item + dpopNonceCache: Item + handleCache: Item + authorizationServerMetadataCache: Item + protectedResourceMetadataCache: Item +} + +export type DatabaseStore = SimpleStore & { + getKeys: () => Promise +} + +const STORES: TupleUnion = [ + 'state', + 'session', + + 'didCache', + 'dpopNonceCache', + 'handleCache', + 'authorizationServerMetadataCache', + 'protectedResourceMetadataCache', +] + +export type BrowserOAuthDatabaseOptions = { + name?: string + durability?: 'strict' | 'relaxed' + cleanupInterval?: number +} + +export class BrowserOAuthDatabase { + #dbPromise: Promise> + #cleanupInterval?: ReturnType + + constructor(options?: BrowserOAuthDatabaseOptions) { + this.#dbPromise = DB.open( + options?.name ?? '@atproto-oauth-client', + [ + (db) => { + for (const name of STORES) { + const store = db.createObjectStore(name, { autoIncrement: true }) + store.createIndex('expiresAt', 'expiresAt', { unique: false }) + } + }, + ], + { durability: options?.durability ?? 'strict' }, + ) + + this.#cleanupInterval = setInterval(() => { + void this.cleanup() + }, options?.cleanupInterval ?? 30e3) + } + + protected async run( + storeName: N, + mode: 'readonly' | 'readwrite', + fn: (s: DBObjectStore) => R | Promise, + ): Promise { + const db = await this.#dbPromise + return await db.transaction([storeName], mode, (tx) => + fn(tx.objectStore(storeName)), + ) + } + + protected createStore( + name: N, + { + encode, + decode, + expiresAt, + }: { + encode: (value: V) => Schema[N]['value'] | PromiseLike + decode: (encoded: Schema[N]['value']) => V | PromiseLike + expiresAt: (value: V) => null | Date + }, + ): DatabaseStore { + return { + get: async (key) => { + // Find item in store + const item = await this.run(name, 'readonly', (store) => store.get(key)) + + // Not found + if (item === undefined) return undefined + + // Too old (delete) + if (item.expiresAt != null && new Date(item.expiresAt) < new Date()) { + await this.run(name, 'readwrite', (store) => store.delete(key)) + return undefined + } + + // Item found and valid. Decode + return decode(item.value) + }, + + getKeys: async () => { + const keys = await this.run(name, 'readonly', (store) => + store.getAllKeys(), + ) + return keys.filter((key): key is string => typeof key === 'string') + }, + + set: async (key, value) => { + // Create encoded item record + const item = { + value: await encode(value), + expiresAt: expiresAt(value)?.toISOString(), + } as Schema[N] + + // Store item record + await this.run(name, 'readwrite', (store) => store.put(item, key)) + }, + + del: async (key) => { + // Delete + await this.run(name, 'readwrite', (store) => store.delete(key)) + }, + } + } + + getSessionStore(): DatabaseStore { + return this.createStore('session', { + expiresAt: ({ tokenSet }) => + tokenSet.refresh_token || tokenSet.expires_at == null + ? null + : new Date(tokenSet.expires_at), + encode: ({ dpopKey, ...session }) => ({ + ...session, + dpopKey: encodeKey(dpopKey), + }), + decode: async ({ dpopKey, ...encoded }) => ({ + ...encoded, + dpopKey: await decodeKey(dpopKey), + }), + }) + } + + getStateStore(): DatabaseStore { + return this.createStore('state', { + expiresAt: (_value) => new Date(Date.now() + 10 * 60e3), + encode: ({ dpopKey, ...session }) => ({ + ...session, + dpopKey: encodeKey(dpopKey), + }), + decode: async ({ dpopKey, ...encoded }) => ({ + ...encoded, + dpopKey: await decodeKey(dpopKey), + }), + }) + } + + getDpopNonceCache(): undefined | DatabaseStore { + return this.createStore('dpopNonceCache', { + expiresAt: (_value) => new Date(Date.now() + 600e3), + encode: (value) => value, + decode: (encoded) => encoded, + }) + } + + getDidCache(): undefined | DatabaseStore { + return this.createStore('didCache', { + expiresAt: (_value) => new Date(Date.now() + 60e3), + encode: (value) => value, + decode: (encoded) => encoded, + }) + } + + getHandleCache(): undefined | DatabaseStore { + return this.createStore('handleCache', { + expiresAt: (_value) => new Date(Date.now() + 60e3), + encode: (value) => value, + decode: (encoded) => encoded, + }) + } + + getAuthorizationServerMetadataCache(): + | undefined + | DatabaseStore { + return this.createStore('authorizationServerMetadataCache', { + expiresAt: (_value) => new Date(Date.now() + 60e3), + encode: (value) => value, + decode: (encoded) => encoded, + }) + } + + getProtectedResourceMetadataCache(): + | undefined + | DatabaseStore { + return this.createStore('protectedResourceMetadataCache', { + expiresAt: (_value) => new Date(Date.now() + 60e3), + encode: (value) => value, + decode: (encoded) => encoded, + }) + } + + async cleanup() { + const db = await this.#dbPromise + + for (const name of STORES) { + await db.transaction([name], 'readwrite', (tx) => + tx + .objectStore(name) + .index('expiresAt') + .deleteAll(IDBKeyRange.upperBound(Date.now())), + ) + } + } + + async [Symbol.asyncDispose]() { + clearInterval(this.#cleanupInterval) + const dbPromise = this.#dbPromise + this.#dbPromise = Promise.reject(new Error('Database has been disposed')) + + const db = await dbPromise + await (db[Symbol.asyncDispose] || db[Symbol.dispose]).call(db) + } +} diff --git a/packages/oauth/oauth-client-browser/src/browser-runtime-implementation.ts b/packages/oauth/oauth-client-browser/src/browser-runtime-implementation.ts new file mode 100644 index 00000000000..d230ccd22c7 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/browser-runtime-implementation.ts @@ -0,0 +1,65 @@ +import { WebcryptoKey } from '@atproto/jwk-webcrypto' +import { + DigestAlgorithm, + Key, + RuntimeImplementation, +} from '@atproto/oauth-client' + +export class BrowserRuntimeImplementation implements RuntimeImplementation { + // https://developer.mozilla.org/en-US/docs/Web/API/LockManager/request + requestLock = navigator.locks?.request + ? (name: string, fn: () => T | PromiseLike): Promise => + navigator.locks.request(name, { mode: 'exclusive' }, async () => fn()) + : undefined + + constructor(private crypto = globalThis.crypto) { + if (!crypto?.subtle) { + throw new Error( + 'Crypto with CryptoSubtle is required. If running in a browser, make sure the current page is loaded over HTTPS.', + ) + } + + if (!this.requestLock) { + // There is no real need to polyfill this on older browsers. The + // oauth-client library will try and recover from concurrency issues when + // refreshing tokens. + console.warn( + 'Locks API not available. You should consider using a more recent browser.', + ) + } + } + + async createKey(algs: string[]): Promise { + return WebcryptoKey.generate(algs) + } + + getRandomValues(byteLength: number): Uint8Array { + const bytes = new Uint8Array(byteLength) + this.crypto.getRandomValues(bytes) + return bytes + } + + async digest( + bytes: Uint8Array, + algorithm: DigestAlgorithm, + ): Promise { + const buffer = await this.crypto.subtle.digest( + digestAlgorithmToSubtle(algorithm), + bytes, + ) + return new Uint8Array(buffer) + } +} + +function digestAlgorithmToSubtle({ + name, +}: DigestAlgorithm): AlgorithmIdentifier { + switch (name) { + case 'sha256': + case 'sha384': + case 'sha512': + return `SHA-${name.slice(-3)}` + default: + throw new TypeError(`Unknown hash algorithm ${name}`) + } +} diff --git a/packages/oauth/oauth-client-browser/src/disposable-polyfill/README.md b/packages/oauth/oauth-client-browser/src/disposable-polyfill/README.md new file mode 100644 index 00000000000..f7de611a52a --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/disposable-polyfill/README.md @@ -0,0 +1,9 @@ +# Ppolyfill for Symbol.dispose and Symbol.asyncDispose + +While typescript does transpile `using` to `try`/`finally` blocks, it does not +provide a polyfill for the `Symbol.dispose` and `Symbol.asyncDispose` symbols. +This package provides a polyfill for these symbols. + +This _could_ be used as a standalone library, but the Bluesky dev team does not +want to maintain it as such. As it is currently only used by the +`@atproto/oauth-client-browser` package, it is included here. diff --git a/packages/oauth/oauth-client-browser/src/disposable-polyfill/index.ts b/packages/oauth/oauth-client-browser/src/disposable-polyfill/index.ts new file mode 100644 index 00000000000..ddb9073b163 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/disposable-polyfill/index.ts @@ -0,0 +1,10 @@ +// Code compiled with tsc supports "using" and "await using" syntax. This +// features is supported by downleveling the code to ES2017. The downleveling +// relies on `Symbol.dispose` and `Symbol.asyncDispose` symbols. These symbols +// might not be available in all environments. This package provides a polyfill +// for these symbols. + +// @ts-expect-error +Symbol.dispose ??= Symbol('@@dispose') +// @ts-expect-error +Symbol.asyncDispose ??= Symbol('@@asyncDispose') diff --git a/packages/oauth/oauth-client-browser/src/errors.ts b/packages/oauth/oauth-client-browser/src/errors.ts new file mode 100644 index 00000000000..37d084d577d --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/errors.ts @@ -0,0 +1,10 @@ +/** + * Special error class destined to be thrown when the login process was + * performed in a popup and should be continued in the parent/initiating window. + */ +export class LoginContinuedInParentWindowError extends Error { + code = 'LOGIN_CONTINUED_IN_PARENT_WINDOW' + constructor() { + super('Login complete, please close the popup window.') + } +} diff --git a/packages/oauth/oauth-client-browser/src/index.ts b/packages/oauth/oauth-client-browser/src/index.ts new file mode 100644 index 00000000000..7e2ad2aa2b3 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/index.ts @@ -0,0 +1,16 @@ +import './disposable-polyfill/index.js' + +export * from '@atproto-labs/did-resolver' +export { + AppViewHandleResolver, + AtprotoHandleResolver, +} from '@atproto-labs/handle-resolver' + +export * from '@atproto/did' +export * from '@atproto/jwk-webcrypto' +export * from '@atproto/oauth-client' +export * from '@atproto/oauth-types' + +export * from './browser-oauth-client.js' +export * from './errors.js' +export { buildLoopbackClientId } from './util.js' diff --git a/packages/oauth/oauth-client-browser/src/indexed-db-store.ts b/packages/oauth/oauth-client-browser/src/indexed-db-store.ts new file mode 100644 index 00000000000..53980b811e8 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db-store.ts @@ -0,0 +1,79 @@ +import { SimpleStore, Key, Value } from '@atproto-labs/simple-store' +import { DB, DBObjectStore } from './indexed-db/index.js' + +const storeName = 'store' +type Item = { + value: V + createdAt: Date +} + +export class IndexedDBStore< + K extends Extract, + V extends Value, +> implements SimpleStore +{ + constructor( + private dbName: string, + protected maxAge = 600e3, + ) {} + + protected async run( + mode: 'readonly' | 'readwrite', + fn: (s: DBObjectStore>) => R | Promise, + ): Promise { + const db = await DB.open<{ store: Item }>( + this.dbName, + [ + (db) => { + const store = db.createObjectStore(storeName) + store.createIndex('createdAt', 'createdAt', { unique: false }) + }, + ], + { durability: 'strict' }, + ) + try { + return await db.transaction([storeName], mode, (tx) => + fn(tx.objectStore(storeName)), + ) + } finally { + await db[Symbol.dispose]() + } + } + + async get(key: K): Promise { + const item = await this.run('readonly', (store) => store.get(key)) + + if (!item) return undefined + + const age = Date.now() - item.createdAt.getTime() + if (age > this.maxAge) { + await this.del(key) + return undefined + } + + return item?.value + } + + async set(key: K, value: V): Promise { + await this.run('readwrite', (store) => { + store.put({ value, createdAt: new Date() }, key) + }) + } + + async del(key: K): Promise { + await this.run('readwrite', (store) => { + store.delete(key) + }) + } + + async deleteOutdated() { + const upperBound = new Date(Date.now() - this.maxAge) + const query = IDBKeyRange.upperBound(upperBound) + + await this.run('readwrite', async (store) => { + const index = store.index('createdAt') + const keys = await index.getAllKeys(query) + for (const key of keys) store.delete(key) + }) + } +} diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/README.md b/packages/oauth/oauth-client-browser/src/indexed-db/README.md new file mode 100644 index 00000000000..595be23f4ba --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/README.md @@ -0,0 +1,8 @@ +# IndexedDB utilities + +This is a small wrapper around the IndexedDB API that provides a simple way to +store and retrieve data from an IndexedDB database. + +This _could_ be used as a standalone library, but the Bluesky dev team does not +want to maintain it as such. As it is currently only used by the +`@atproto/oauth-client-browser` package, it is included here. diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/db-index.ts b/packages/oauth/oauth-client-browser/src/indexed-db/db-index.ts new file mode 100644 index 00000000000..dc041f024b5 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/db-index.ts @@ -0,0 +1,44 @@ +import { ObjectStoreSchema } from './schema.js' +import { promisify } from './util.js' + +export class DBIndex { + constructor(private idbIndex: IDBIndex) {} + + count(query?: IDBValidKey | IDBKeyRange) { + return promisify(this.idbIndex.count(query)) + } + + get(query: IDBValidKey | IDBKeyRange) { + return promisify(this.idbIndex.get(query)) + } + + getKey(query: IDBValidKey | IDBKeyRange) { + return promisify(this.idbIndex.getKey(query)) + } + + getAll(query?: IDBValidKey | IDBKeyRange | null, count?: number) { + return promisify(this.idbIndex.getAll(query, count)) + } + + getAllKeys(query?: IDBValidKey | IDBKeyRange | null, count?: number) { + return promisify(this.idbIndex.getAllKeys(query, count)) + } + + deleteAll(query?: IDBValidKey | IDBKeyRange | null): Promise { + return new Promise((resolve, reject) => { + const result = this.idbIndex.openCursor(query) + result.onsuccess = function (event) { + const cursor = (event as any).target.result as IDBCursorWithValue + if (cursor) { + cursor.delete() + cursor.continue() + } else { + resolve() + } + } + result.onerror = function (event) { + reject((event.target as any)?.error || new Error('Unexpected error')) + } + }) + } +} diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/db-object-store.ts b/packages/oauth/oauth-client-browser/src/indexed-db/db-object-store.ts new file mode 100644 index 00000000000..9b15fcad6a9 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/db-object-store.ts @@ -0,0 +1,47 @@ +import { DBIndex } from './db-index.js' +import { ObjectStoreSchema } from './schema.js' +import { promisify } from './util.js' + +export class DBObjectStore { + constructor(private idbObjStore: IDBObjectStore) {} + + get name() { + return this.idbObjStore.name + } + + index(name: string) { + return new DBIndex(this.idbObjStore.index(name)) + } + + get(key: IDBValidKey | IDBKeyRange) { + return promisify(this.idbObjStore.get(key)) + } + + getKey(query: IDBValidKey | IDBKeyRange) { + return promisify(this.idbObjStore.getKey(query)) + } + + getAll(query?: IDBValidKey | IDBKeyRange | null, count?: number) { + return promisify(this.idbObjStore.getAll(query, count)) + } + + getAllKeys(query?: IDBValidKey | IDBKeyRange | null, count?: number) { + return promisify(this.idbObjStore.getAllKeys(query, count)) + } + + add(value: Schema, key?: IDBValidKey) { + return promisify(this.idbObjStore.add(value, key)) + } + + put(value: Schema, key?: IDBValidKey) { + return promisify(this.idbObjStore.put(value, key)) + } + + delete(key: IDBValidKey | IDBKeyRange) { + return promisify(this.idbObjStore.delete(key)) + } + + clear() { + return promisify(this.idbObjStore.clear()) + } +} diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/db-transaction.ts b/packages/oauth/oauth-client-browser/src/indexed-db/db-transaction.ts new file mode 100644 index 00000000000..79905e796e8 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/db-transaction.ts @@ -0,0 +1,52 @@ +import { DBObjectStore } from './db-object-store.js' +import { DatabaseSchema } from './schema.js' + +export class DBTransaction + implements Disposable +{ + #tx: IDBTransaction | null + + constructor(tx: IDBTransaction) { + this.#tx = tx + + const onAbort = () => { + cleanup() + } + const onComplete = () => { + cleanup() + } + const cleanup = () => { + this.#tx = null + tx.removeEventListener('abort', onAbort) + tx.removeEventListener('complete', onComplete) + } + tx.addEventListener('abort', onAbort) + tx.addEventListener('complete', onComplete) + } + + protected get tx(): IDBTransaction { + if (!this.#tx) throw new Error('Transaction already ended') + return this.#tx + } + + async abort() { + const { tx } = this + this.#tx = null + tx.abort() + } + + async commit() { + const { tx } = this + this.#tx = null + tx.commit?.() + } + + objectStore(name: T) { + const store = this.tx.objectStore(name) + return new DBObjectStore(store) + } + + [Symbol.dispose](): void { + if (this.#tx) this.commit() + } +} diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/db.ts b/packages/oauth/oauth-client-browser/src/indexed-db/db.ts new file mode 100644 index 00000000000..03ffe61d8c2 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/db.ts @@ -0,0 +1,114 @@ +import { DatabaseSchema } from './schema.js' +import { DBTransaction } from './db-transaction.js' + +export class DB implements Disposable { + static async open( + dbName: string, + migrations: ReadonlyArray<(db: IDBDatabase) => void>, + txOptions?: IDBTransactionOptions, + ) { + const db = await new Promise((resolve, reject) => { + const request = indexedDB.open(dbName, migrations.length) + + request.onerror = () => reject(request.error) + request.onsuccess = () => resolve(request.result) + request.onupgradeneeded = ({ oldVersion, newVersion }) => { + const db = request.result + try { + for ( + let version = oldVersion; + version < (newVersion ?? migrations.length); + ++version + ) { + const migration = migrations[version] + if (migration) migration(db) + else throw new Error(`Missing migration for version ${version}`) + } + } catch (err) { + db.close() + reject(err) + } + } + }) + + return new DB(db, txOptions) + } + + #db: null | IDBDatabase + + constructor( + db: IDBDatabase, + protected readonly txOptions?: IDBTransactionOptions, + ) { + this.#db = db + + const cleanup = () => { + this.#db = null + db.removeEventListener('versionchange', cleanup) + db.removeEventListener('close', cleanup) + db.close() // Can we call close on a "closed" database? + } + + db.addEventListener('versionchange', cleanup) + db.addEventListener('close', cleanup) + } + + protected get db(): IDBDatabase { + if (!this.#db) throw new Error('Database closed') + return this.#db + } + + get name() { + return this.db.name + } + + get objectStoreNames() { + return this.db.objectStoreNames + } + + get version() { + return this.db.version + } + + async transaction( + storeNames: T, + mode: IDBTransactionMode, + run: (tx: DBTransaction>) => R | PromiseLike, + ): Promise { + // eslint-disable-next-line no-async-promise-executor + return new Promise(async (resolve, reject) => { + try { + const tx = this.db.transaction(storeNames, mode, this.txOptions) + let result: { done: false } | { done: true; value: R } = { done: false } + + tx.oncomplete = () => { + if (result.done) resolve(result.value) + else reject(new Error('Transaction completed without result')) + } + tx.onerror = () => reject(tx.error) + tx.onabort = () => reject(tx.error || new Error('Transaction aborted')) + + try { + const value = await run(new DBTransaction(tx)) + result = { done: true, value } + tx.commit() + } catch (err) { + tx.abort() + throw err + } + } catch (err) { + reject(err) + } + }) + } + + close() { + const { db } = this + this.#db = null + db.close() + } + + [Symbol.dispose]() { + if (this.#db) return this.close() + } +} diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/index.ts b/packages/oauth/oauth-client-browser/src/indexed-db/index.ts new file mode 100644 index 00000000000..c14ed746899 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/index.ts @@ -0,0 +1,6 @@ +import '../disposable-polyfill/index.js' + +export * from './db.js' +export * from './db-index.js' +export * from './db-object-store.js' +export * from './db-transaction.js' diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/schema.ts b/packages/oauth/oauth-client-browser/src/indexed-db/schema.ts new file mode 100644 index 00000000000..f8736b2a19d --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/schema.ts @@ -0,0 +1,2 @@ +export type ObjectStoreSchema = NonNullable +export type DatabaseSchema = Record diff --git a/packages/oauth/oauth-client-browser/src/indexed-db/util.ts b/packages/oauth/oauth-client-browser/src/indexed-db/util.ts new file mode 100644 index 00000000000..6e52b5919c4 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/indexed-db/util.ts @@ -0,0 +1,20 @@ +export function promisify(request: IDBRequest) { + const promise = new Promise((resolve, reject) => { + const cleanup = () => { + request.removeEventListener('success', success) + request.removeEventListener('error', error) + } + const success = () => { + resolve(request.result) + cleanup() + } + const error = () => { + reject(request.error) + cleanup() + } + request.addEventListener('success', success) + request.addEventListener('error', error) + }) + + return promise +} diff --git a/packages/oauth/oauth-client-browser/src/util.ts b/packages/oauth/oauth-client-browser/src/util.ts new file mode 100644 index 00000000000..aabc9f931c2 --- /dev/null +++ b/packages/oauth/oauth-client-browser/src/util.ts @@ -0,0 +1,33 @@ +import { isLoopbackHost } from '@atproto/oauth-types' + +export type Simplify = { [K in keyof T]: T[K] } & NonNullable +export type TupleUnion = { + [S in U]: Exclude extends never + ? [...R, S] + : TupleUnion, [...R, S]> +}[U] + +/** + * @example + * ```ts + * const clientId = buildLoopbackClientId(window.location) + * ``` + */ +export function buildLoopbackClientId( + location: { + hostname: string + pathname: string + port: string + }, + localhost = '127.0.0.1', +): string { + if (!isLoopbackHost(location.hostname)) { + throw new TypeError(`Expected a loopback host, got ${location.hostname}`) + } + + const redirectUri = `http://${location.hostname === 'localhost' ? localhost : location.hostname}${location.port && !location.port.startsWith(':') ? `:${location.port}` : location.port}${location.pathname}` + + return `http://localhost${ + location.pathname === '/' ? '' : location.pathname + }?redirect_uri=${encodeURIComponent(redirectUri)}` +} diff --git a/packages/oauth/oauth-client-browser/tsconfig.build.json b/packages/oauth/oauth-client-browser/tsconfig.build.json new file mode 100644 index 00000000000..2ef4f334355 --- /dev/null +++ b/packages/oauth/oauth-client-browser/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig/isomorphic.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/oauth/oauth-client-browser/tsconfig.json b/packages/oauth/oauth-client-browser/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/oauth/oauth-client-browser/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/oauth/oauth-client/README.md b/packages/oauth/oauth-client/README.md new file mode 100644 index 00000000000..365ad454f5b --- /dev/null +++ b/packages/oauth/oauth-client/README.md @@ -0,0 +1,124 @@ +# @atproto/oauth-client: atproto flavoured OAuth client + +Core library for implementing ATPROTO OAuth clients. + +For a browser specific implementation, see `@atproto/oauth-client-browser`. +For a node specific implementation, see `@atproto/oauth-client-node`. + +```ts +import { OAuthClient } from '@atproto/oauth-client' +import { JoseKey } from '@atproto/jwk-jose' // NodeJS/Browser only + +const client = new OAuthClient({ + handleResolver: 'https://bsky.social', // On node, you should use a DNS based resolver + responseMode: 'query', // or "fragment" or "form_post" (for backend clients only) + clientMetadata: { + // These must be the same metadata as the one exposed on the + // "/.well-known/oauth-client-metadata" endpoint (except when using a + // loopback client) + }, + + runtimeImplementation: { + // A runtime specific implementation of the crypto operations needed by the + // OAuth client. + + createKey(algs: string[]): Promise { + // algs is an ordered array of preferred algorithms (e.g. ['RS256', 'ES256']) + + // Note, in browser environments, it is better to use non extractable keys + // to prevent leaking the private key. This can be done using the + // WebcryptoKey class from the "@atproto/jwk-webcrypto" package. The + // inconvenient of these keys (which is also what makes them stronger) is + // that the only way to persist them across browser reloads is to save + // them in the indexed DB. + return JoseKey.generate(algs) + }, + getRandomValues(length: number): Uint8Array | PromiseLike { + // length is the number of bytes to generate + + const bytes = new Uint8Array(byteLength) + crypto.getRandomValues(bytes) + return bytes + }, + digest( + bytes: Uint8Array, + algorithm: { name: 'sha256' | 'sha384' | 'sha512' }, + ): Uint8Array | PromiseLike { + // sha256 is required. Unsupported algorithms should throw an error. + + const buffer = await this.crypto.subtle.digest( + algorithm.name.startsWith('sha') + ? `SHA-${algorithm.name.slice(-3)}` + : 'invalid', + bytes, + ) + return new Uint8Array(buffer) + }, + }, + + stateStore: { + // A store for saving state data while the user is being redirected to the + // authorization server. + + set(key: string, internalState: InternalStateData): Promise { + throw new Error('Not implemented') + }, + get(key: string): Promise { + throw new Error('Not implemented') + }, + del(key: string): Promise { + throw new Error('Not implemented') + }, + }, + + sessionStore: { + // A store for saving session data. + + set(sub: string, session: Session): Promise { + throw new Error('Not implemented') + }, + get(sub: string): Promise { + throw new Error('Not implemented') + }, + del(sub: string): Promise { + throw new Error('Not implemented') + }, + }, + + keyset: [ + // For backend clients only, a list of private keys to use for signing + // credentials. These keys MUST correspond to the public keys exposed on the + // "jwks_uri" of the client metadata. + await JoseKey.fromImportable(process.env.PRIVATE_KEY_1), + await JoseKey.fromImportable(process.env.PRIVATE_KEY_2), + await JoseKey.fromImportable(process.env.PRIVATE_KEY_3), + ], +}) +``` + +```ts +const url = await client.authorize('foo.bsky.team', { + state: '434321', + prompt: 'consent', + scope: 'email', + ui_locales: 'fr', +}) + +// Make user visit "url". Then, once it was redirected to the callback URI, call: + +const params = new URLSearchParams('code=...&state=...') +const result = await client.callback(params) + +// Verify the state (e.g. to link to an internal user) +result.state === '434321' + +// The authenticated user's identifier +result.agent.sub + +// Make an authenticated request to the server. New credentials will be +// automatically fetched if needed (causing sessionStore.set() to be called). +await result.agent.request('/xrpc/foo.bar') + +// revoke credentials on the server (causing sessionStore.del() to be called) +await result.agent.signOut() +``` diff --git a/packages/oauth/oauth-client/package.json b/packages/oauth/oauth-client/package.json new file mode 100644 index 00000000000..a7d21d4bdbb --- /dev/null +++ b/packages/oauth/oauth-client/package.json @@ -0,0 +1,46 @@ +{ + "name": "@atproto/oauth-client", + "version": "0.0.1", + "license": "MIT", + "description": "OAuth client for ATPROTO PDS. This package serves as common base for environment-specific implementations (NodeJS, Browser, React-Native).", + "keywords": [ + "atproto", + "oauth", + "client", + "isomorphic" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/oauth-client" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/did-resolver": "workspace:*", + "@atproto-labs/fetch": "workspace:*", + "@atproto-labs/handle-resolver": "workspace:*", + "@atproto-labs/identity-resolver": "workspace:*", + "@atproto-labs/simple-store": "workspace:*", + "@atproto-labs/simple-store-memory": "workspace:*", + "@atproto/did": "workspace:*", + "@atproto/jwk": "workspace:*", + "@atproto/oauth-types": "workspace:*", + "multiformats": "^9.9.0", + "zod": "^3.23.8" + }, + "devDependencies": { + "typescript": "^5.3.3" + }, + "scripts": { + "build": "tsc --build tsconfig.build.json" + } +} diff --git a/packages/oauth/oauth-client/src/constants.ts b/packages/oauth/oauth-client/src/constants.ts new file mode 100644 index 00000000000..2ee644cde0f --- /dev/null +++ b/packages/oauth/oauth-client/src/constants.ts @@ -0,0 +1,4 @@ +/** + * Per ATProto spec (OpenID uses RS256) + */ +export const FALLBACK_ALG = 'ES256' diff --git a/packages/oauth/oauth-client/src/fetch-dpop.ts b/packages/oauth/oauth-client/src/fetch-dpop.ts new file mode 100644 index 00000000000..f61afbe8e25 --- /dev/null +++ b/packages/oauth/oauth-client/src/fetch-dpop.ts @@ -0,0 +1,235 @@ +import { Fetch, FetchContext, cancelBody, peekJson } from '@atproto-labs/fetch' +import { SimpleStore } from '@atproto-labs/simple-store' +import { Key } from '@atproto/jwk' +import { base64url } from 'multiformats/bases/base64' + +// "undefined" in non https environments or environments without crypto +const subtle = globalThis.crypto?.subtle as SubtleCrypto | undefined + +const ReadableStream = globalThis.ReadableStream as + | typeof globalThis.ReadableStream + | undefined + +export type DpopFetchWrapperOptions = { + key: Key + iss: string + nonces: SimpleStore + supportedAlgs?: string[] + sha256?: (input: string) => Promise + + /** + * Is the intended server an authorization server (true) or a resource server + * (false)? Setting this may allow to avoid parsing the response body to + * determine the dpop-nonce. + * + * @default undefined + */ + isAuthServer?: boolean + fetch?: Fetch +} + +export function dpopFetchWrapper({ + key, + iss, + supportedAlgs, + nonces, + sha256 = typeof subtle !== 'undefined' ? subtleSha256 : undefined, + isAuthServer, + fetch = globalThis.fetch, +}: DpopFetchWrapperOptions): Fetch { + if (!sha256) { + throw new TypeError( + `crypto.subtle is not available in this environment. Please provide a sha256 function.`, + ) + } + + const alg = negotiateAlg(key, supportedAlgs) + + return async function (this: C, input, init) { + if (!key.algorithms.includes(alg)) { + throw new TypeError(`Key does not support the algorithm ${alg}`) + } + + const request: Request = + init == null && input instanceof Request + ? input + : new Request(input, init) + + const authorizationHeader = request.headers.get('Authorization') + const ath = authorizationHeader?.startsWith('DPoP ') + ? await sha256(authorizationHeader.slice(5)) + : undefined + + const { method, url } = request + const { origin } = new URL(url) + + let initNonce: string | undefined + try { + initNonce = await nonces.get(origin) + } catch { + // Ignore get errors, we will just not send a nonce + } + + const initProof = await buildProof( + key, + alg, + iss, + method, + url, + initNonce, + ath, + ) + request.headers.set('DPoP', initProof) + + const initResponse = await fetch.call(this, request) + + // Make sure the response body is consumed. Either by the caller (when the + // response is returned), of if an error is thrown (catch block). + + const nextNonce = initResponse.headers.get('DPoP-Nonce') + if (!nextNonce || nextNonce === initNonce) { + // No nonce was returned or it is the same as the one we sent. No need to + // update the nonce store, or retry the request. + return initResponse + } + + // Store the fresh nonce for future requests + try { + await nonces.set(origin, nextNonce) + } catch { + // Ignore set errors + } + + const shouldRetry = await isUseDpopNonceError(initResponse, isAuthServer) + if (!shouldRetry) { + // Not a "use_dpop_nonce" error, so there is no need to retry + return initResponse + } + + // If the input stream was already consumed, we cannot retry the request. A + // solution would be to clone() the request but that would bufferize the + // entire stream in memory which can lead to memory starvation. Instead, we + // will return the original response and let the calling code handle retries. + + if (input === request) { + // The input request body was consumed. We cannot retry the request. + return initResponse + } + + if (ReadableStream && init?.body instanceof ReadableStream) { + // The init body was consumed. We cannot retry the request. + return initResponse + } + + // We will now retry the request with the fresh nonce. + + // The initial response body must be consumed (see cancelBody's doc). + await cancelBody(initResponse, 'log') + + const nextProof = await buildProof( + key, + alg, + iss, + method, + url, + nextNonce, + ath, + ) + const nextRequest = new Request(input, init) + nextRequest.headers.set('DPoP', nextProof) + + return fetch.call(this, nextRequest) + } +} + +async function buildProof( + key: Key, + alg: string, + iss: string, + htm: string, + htu: string, + nonce?: string, + ath?: string, +) { + if (!key.bareJwk) { + throw new Error('Only asymmetric keys can be used as DPoP proofs') + } + + const now = Math.floor(Date.now() / 1e3) + + return key.createJwt( + { + alg, + typ: 'dpop+jwt', + jwk: key.bareJwk, + }, + { + iss, + iat: now, + exp: now + 10, + // Any collision will cause the request to be rejected by the server. no biggie. + jti: Math.random().toString(36).slice(2), + htm, + htu, + nonce, + ath, + }, + ) +} + +async function isUseDpopNonceError( + response: Response, + isAuthServer?: boolean, +): Promise { + // https://datatracker.ietf.org/doc/html/rfc6750#section-3 + // https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no + if (isAuthServer === undefined || isAuthServer === false) { + if (response.status === 401) { + const wwwAuth = response.headers.get('WWW-Authenticate') + if (wwwAuth?.startsWith('DPoP')) { + return wwwAuth.includes('error="use_dpop_nonce"') + } + } + } + + // https://datatracker.ietf.org/doc/html/rfc9449#name-authorization-server-provid + if (isAuthServer === undefined || isAuthServer === true) { + if (response.status === 400) { + try { + const json = await peekJson(response, 10 * 1024) + return typeof json === 'object' && json?.['error'] === 'use_dpop_nonce' + } catch { + // Response too big (to be "use_dpop_nonce" error) or invalid JSON + return false + } + } + } + + return false +} + +function negotiateAlg(key: Key, supportedAlgs: string[] | undefined): string { + if (supportedAlgs) { + // Use order of supportedAlgs as preference + const alg = supportedAlgs.find((a) => key.algorithms.includes(a)) + if (alg) return alg + } else { + const [alg] = key.algorithms + if (alg) return alg + } + + throw new Error('Key does not match any alg supported by the server') +} + +async function subtleSha256(input: string): Promise { + if (subtle == null) { + throw new Error( + `crypto.subtle is not available in this environment. Please provide a sha256 function.`, + ) + } + + const bytes = new TextEncoder().encode(input) + const digest = await subtle.digest('SHA-256', bytes) + const digestBytes = new Uint8Array(digest) + return base64url.baseEncode(digestBytes) +} diff --git a/packages/oauth/oauth-client/src/index.ts b/packages/oauth/oauth-client/src/index.ts new file mode 100644 index 00000000000..1e14908c48b --- /dev/null +++ b/packages/oauth/oauth-client/src/index.ts @@ -0,0 +1,18 @@ +export { + FetchError, + FetchRequestError, + FetchResponseError, +} from '@atproto-labs/fetch' +export * from './oauth-agent.js' +export * from './oauth-authorization-server-metadata-resolver.js' +export * from './oauth-callback-error.js' +export * from './oauth-client.js' +export * from './oauth-protected-resource-metadata-resolver.js' +export * from './oauth-resolver-error.js' +export * from './oauth-response-error.js' +export * from './oauth-server-agent.js' +export * from './oauth-server-factory.js' +export * from './refresh-error.js' +export * from './runtime-implementation.js' +export * from './session-getter.js' +export * from './types.js' diff --git a/packages/oauth/oauth-client/src/lock.ts b/packages/oauth/oauth-client/src/lock.ts new file mode 100644 index 00000000000..2a89b1a967f --- /dev/null +++ b/packages/oauth/oauth-client/src/lock.ts @@ -0,0 +1,34 @@ +const locks = new Map>() + +function acquireLocalLock(name: unknown): Promise<() => void> { + return new Promise((resolveAcquire) => { + const prev = locks.get(name) ?? Promise.resolve() + const next = prev.then(() => { + return new Promise((resolveRelease) => { + const release = () => { + // Only delete the lock if it is still the current one + if (locks.get(name) === next) locks.delete(name) + + resolveRelease() + } + + resolveAcquire(release) + }) + }) + + locks.set(name, next) + }) +} + +export function requestLocalLock( + name: string, + fn: () => T | PromiseLike, +): Promise { + return acquireLocalLock(name).then(async (release) => { + try { + return await fn() + } finally { + release() + } + }) +} diff --git a/packages/oauth/oauth-client/src/oauth-agent.ts b/packages/oauth/oauth-client/src/oauth-agent.ts new file mode 100644 index 00000000000..1b1df7e0178 --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-agent.ts @@ -0,0 +1,150 @@ +import { Fetch, bindFetch } from '@atproto-labs/fetch' +import { JwtPayload, unsafeDecodeJwt } from '@atproto/jwk' +import { OAuthAuthorizationServerMetadata } from '@atproto/oauth-types' + +import { dpopFetchWrapper } from './fetch-dpop.js' +import { OAuthServerAgent, TokenSet } from './oauth-server-agent.js' +import { SessionGetter } from './session-getter.js' + +const ReadableStream = globalThis.ReadableStream as + | typeof globalThis.ReadableStream + | undefined + +export class OAuthAgent { + protected dpopFetch: Fetch + + constructor( + public readonly server: OAuthServerAgent, + public readonly sub: string, + private readonly sessionGetter: SessionGetter, + fetch: Fetch = globalThis.fetch, + ) { + this.dpopFetch = dpopFetchWrapper({ + fetch: bindFetch(fetch), + iss: server.clientMetadata.client_id, + key: server.dpopKey, + supportedAlgs: server.serverMetadata.dpop_signing_alg_values_supported, + sha256: async (v) => server.runtime.sha256(v), + nonces: server.dpopNonces, + isAuthServer: false, + }) + } + + get serverMetadata(): Readonly { + return this.server.serverMetadata + } + + public async refreshIfNeeded(): Promise { + await this.getTokenSet(undefined) + } + + /** + * @param refresh See {@link SessionGetter.getSession} + */ + protected async getTokenSet(refresh?: boolean): Promise { + const { tokenSet } = await this.sessionGetter.getSession(this.sub, refresh) + return tokenSet + } + + async getInfo(): Promise<{ + userinfo?: JwtPayload + expired?: boolean + scope?: string + iss: string + aud: string + sub: string + }> { + const tokenSet = await this.getTokenSet() + + return { + userinfo: tokenSet.id_token + ? unsafeDecodeJwt(tokenSet.id_token).payload + : undefined, + expired: + tokenSet.expires_at == null + ? undefined + : new Date(tokenSet.expires_at).getTime() < Date.now() - 5e3, + scope: tokenSet.scope, + iss: tokenSet.iss, + aud: tokenSet.aud, + sub: tokenSet.sub, + } + } + + async signOut(): Promise { + try { + const { tokenSet } = await this.sessionGetter.getSession(this.sub, false) + await this.server.revoke(tokenSet.access_token) + } finally { + await this.sessionGetter.delStored(this.sub) + } + } + + async request(pathname: string, init?: RequestInit): Promise { + // This will try and refresh the token if it is known to be expired + const tokenSet = await this.getTokenSet(undefined) + + const initialUrl = new URL(pathname, tokenSet.aud) + const initialAuth = `${tokenSet.token_type} ${tokenSet.access_token}` + + const headers = new Headers(init?.headers) + headers.set('Authorization', initialAuth) + + const initialResponse = await this.dpopFetch(initialUrl, { + ...init, + headers, + }) + + // If the token is not expired, we don't need to refresh it + if (!isTokenExpiredResponse(initialResponse)) { + return initialResponse + } + + let tokenSetFresh: TokenSet + try { + // "true" here will cause the token to be refreshed + tokenSetFresh = await this.getTokenSet(true) + } catch (err) { + return initialResponse + } + + // The stream was already consumed. We cannot retry the request. A solution + // would be to tee() the input stream but that would bufferize the entire + // stream in memory which can lead to memory starvation. Instead, we will + // return the original response and let the calling code handle retries. + if (ReadableStream && init?.body instanceof ReadableStream) { + return initialResponse + } + + const finalAuth = `${tokenSetFresh.token_type} ${tokenSetFresh.access_token}` + const finalUrl = new URL(pathname, tokenSetFresh.aud) + + headers.set('Authorization', finalAuth) + + const finalResponse = await this.dpopFetch(finalUrl, { ...init, headers }) + + // There is no need to keep the session in the store if the token is expired + // and there is no way to refresh it. + if (isTokenExpiredResponse(finalResponse)) { + // TODO: Is there a "softer" way to handle this, e.g. by marking the + // session as "expired" and allow the user to trigger a new login? + await this.sessionGetter.delStored(this.sub) + } + + return finalResponse + } +} + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc6750#section-3} + * @see {@link https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no} + */ +function isTokenExpiredResponse(response: Response) { + if (response.status !== 401) return false + const wwwAuth = response.headers.get('WWW-Authenticate') + return ( + wwwAuth != null && + (wwwAuth.startsWith('Bearer ') || wwwAuth.startsWith('DPoP ')) && + wwwAuth.includes('error="invalid_token"') + ) +} diff --git a/packages/oauth/oauth-client/src/oauth-authorization-server-metadata-resolver.ts b/packages/oauth/oauth-client/src/oauth-authorization-server-metadata-resolver.ts new file mode 100644 index 00000000000..44e7ce60be4 --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-authorization-server-metadata-resolver.ts @@ -0,0 +1,98 @@ +import { + bindFetch, + cancelBody, + Fetch, + FetchResponseError, +} from '@atproto-labs/fetch' +import { + CachedGetter, + GetCachedOptions, + SimpleStore, +} from '@atproto-labs/simple-store' +import { + OAuthAuthorizationServerMetadata, + oauthAuthorizationServerMetadataValidator, + oauthIssuerIdentifierSchema, +} from '@atproto/oauth-types' +import { contentMime } from './util' + +export type { GetCachedOptions, OAuthAuthorizationServerMetadata } + +export type AuthorizationServerMetadataCache = SimpleStore< + string, + OAuthAuthorizationServerMetadata +> + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc8414} + */ +export class OAuthAuthorizationServerMetadataResolver extends CachedGetter< + string, + OAuthAuthorizationServerMetadata +> { + private readonly fetch: Fetch + + constructor(cache: AuthorizationServerMetadataCache, fetch?: Fetch) { + super(async (issuer, options) => this.fetchMetadata(issuer, options), cache) + + this.fetch = bindFetch(fetch) + } + + async get( + issuer: string, + options?: GetCachedOptions, + ): Promise { + return super.get(oauthIssuerIdentifierSchema.parse(issuer), options) + } + + private async fetchMetadata( + issuer: string, + options?: GetCachedOptions, + ): Promise { + const headers = new Headers([['accept', 'application/json']]) + if (options?.noCache) headers.set('cache-control', 'no-cache') + + const url = new URL(`/.well-known/oauth-authorization-server`, issuer) + const request = new Request(url, { + signal: options?.signal, + headers, + redirect: 'manual', // response must be 200 OK + }) + + const response = await this.fetch(request) + + // https://datatracker.ietf.org/doc/html/rfc8414#section-3.2 + if (response.status !== 200) { + await cancelBody(response, 'log') + throw await FetchResponseError.from( + response, + `Unexpected status code ${response.status} for "${url}"`, + undefined, + { cause: request }, + ) + } + + if (contentMime(response.headers) !== 'application/json') { + await cancelBody(response, 'log') + throw await FetchResponseError.from( + response, + `Unexpected content type for "${url}"`, + undefined, + { cause: request }, + ) + } + + const metadata = oauthAuthorizationServerMetadataValidator.parse( + await response.json(), + ) + + // Validate the issuer (MIX-UP attacks) + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#name-mix-up-attacks + // https://datatracker.ietf.org/doc/html/rfc8414#section-2 + if (metadata.issuer !== issuer) { + throw new TypeError(`Invalid issuer ${metadata.issuer}`) + } + + return metadata + } +} diff --git a/packages/oauth/oauth-client/src/oauth-callback-error.ts b/packages/oauth/oauth-client/src/oauth-callback-error.ts new file mode 100644 index 00000000000..9c9c26d19da --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-callback-error.ts @@ -0,0 +1,16 @@ +export class OAuthCallbackError extends Error { + static from(err: unknown, params: URLSearchParams, state?: string) { + if (err instanceof OAuthCallbackError) return err + const message = err instanceof Error ? err.message : undefined + return new OAuthCallbackError(params, message, state, err) + } + + constructor( + public readonly params: URLSearchParams, + message = params.get('error_description') || 'OAuth callback error', + public readonly state?: string, + cause?: unknown, + ) { + super(message, { cause }) + } +} diff --git a/packages/oauth/oauth-client/src/oauth-client.ts b/packages/oauth/oauth-client/src/oauth-client.ts new file mode 100644 index 00000000000..d45cd397531 --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-client.ts @@ -0,0 +1,440 @@ +import { + DidCache, + DidResolverCached, + DidResolverCommon, +} from '@atproto-labs/did-resolver' +import { Fetch } from '@atproto-labs/fetch' +import { + AppViewHandleResolver, + CachedHandleResolver, + HandleCache, + HandleResolver, +} from '@atproto-labs/handle-resolver' +import { IdentityResolver } from '@atproto-labs/identity-resolver' +import { SimpleStore } from '@atproto-labs/simple-store' +import { SimpleStoreMemory } from '@atproto-labs/simple-store-memory' +import { Key, Keyset } from '@atproto/jwk' +import { + OAuthClientMetadata, + OAuthClientMetadataInput, + OAuthResponseMode, +} from '@atproto/oauth-types' + +import { FALLBACK_ALG } from './constants.js' +import { OAuthAgent } from './oauth-agent.js' +import { + AuthorizationServerMetadataCache, + OAuthAuthorizationServerMetadataResolver, +} from './oauth-authorization-server-metadata-resolver.js' +import { OAuthCallbackError } from './oauth-callback-error.js' +import { + OAuthProtectedResourceMetadataResolver, + ProtectedResourceMetadataCache, +} from './oauth-protected-resource-metadata-resolver.js' +import { OAuthResolver } from './oauth-resolver.js' +import { DpopNonceCache, OAuthServerAgent } from './oauth-server-agent.js' +import { OAuthServerFactory } from './oauth-server-factory.js' +import { RuntimeImplementation } from './runtime-implementation.js' +import { Runtime } from './runtime.js' +import { SessionGetter, SessionStore } from './session-getter.js' +import { AuthorizeOptions, ClientMetadata } from './types.js' +import { validateClientMetadata } from './validate-client-metadata.js' + +export type InternalStateData = { + iss: string + nonce: string + dpopKey: Key + verifier?: string + + /** + * @note This could be parametrized to be of any type. This wasn't done for + * the sake of simplicity but could be added in a later development. + */ + appState?: string +} + +export type StateStore = SimpleStore + +// Export all types needed to construct OAuthClientOptions +export type { + AuthorizationServerMetadataCache, + DpopNonceCache, + Fetch, + Keyset, + OAuthClientMetadata, + OAuthClientMetadataInput, + OAuthResponseMode, + ProtectedResourceMetadataCache, + RuntimeImplementation, + SessionStore, +} + +export type OAuthClientOptions = { + // Config + responseMode: OAuthResponseMode + clientMetadata: Readonly + keyset?: Keyset | Iterable + + // Stores + stateStore: StateStore + sessionStore: SessionStore + didCache?: DidCache + handleCache?: HandleCache + authorizationServerMetadataCache?: AuthorizationServerMetadataCache + protectedResourceMetadataCache?: ProtectedResourceMetadataCache + dpopNonceCache?: DpopNonceCache + + // Services + handleResolver: HandleResolver | URL | string + plcDirectoryUrl?: URL | string + runtimeImplementation: RuntimeImplementation + fetch?: Fetch +} + +export class OAuthClient { + // Config + readonly clientMetadata: ClientMetadata + readonly responseMode: OAuthResponseMode + readonly keyset?: Keyset + + // Services + readonly runtime: Runtime + readonly fetch: Fetch + readonly oauthResolver: OAuthResolver + readonly serverFactory: OAuthServerFactory + + // Stores + readonly sessionGetter: SessionGetter + readonly stateStore: StateStore + + constructor({ + fetch = globalThis.fetch, + + stateStore, + sessionStore, + + didCache = undefined, + dpopNonceCache = new SimpleStoreMemory({ ttl: 60e3, max: 100 }), + handleCache = undefined, + authorizationServerMetadataCache = new SimpleStoreMemory({ + ttl: 60e3, + max: 100, + }), + protectedResourceMetadataCache = new SimpleStoreMemory({ + ttl: 60e3, + max: 100, + }), + + responseMode, + clientMetadata, + handleResolver, + plcDirectoryUrl, + runtimeImplementation, + keyset, + }: OAuthClientOptions) { + this.keyset = keyset + ? keyset instanceof Keyset + ? keyset + : new Keyset(keyset) + : undefined + this.clientMetadata = validateClientMetadata(clientMetadata, this.keyset) + this.responseMode = responseMode + + this.runtime = new Runtime(runtimeImplementation) + this.fetch = fetch + this.oauthResolver = new OAuthResolver( + new IdentityResolver( + new DidResolverCached( + new DidResolverCommon({ fetch, plcDirectoryUrl }), + didCache, + ), + new CachedHandleResolver( + AppViewHandleResolver.from(handleResolver, { fetch }), + handleCache, + ), + ), + new OAuthProtectedResourceMetadataResolver( + protectedResourceMetadataCache, + fetch, + ), + new OAuthAuthorizationServerMetadataResolver( + authorizationServerMetadataCache, + fetch, + ), + ) + this.serverFactory = new OAuthServerFactory( + this.clientMetadata, + this.runtime, + this.oauthResolver, + this.fetch, + this.keyset, + dpopNonceCache, + ) + + this.sessionGetter = new SessionGetter( + sessionStore, + this.serverFactory, + this.runtime, + ) + this.stateStore = stateStore + } + + // Exposed as public API for convenience + get identityResolver() { + return this.oauthResolver.identityResolver + } + + // Exposed as public API for convenience + get didResolver() { + return this.identityResolver.didResolver + } + + // Exposed as public API for convenience + get handleResolver() { + return this.identityResolver.handleResolver + } + + async authorize( + input: string, + options?: AuthorizeOptions & { signal?: AbortSignal }, + ): Promise { + const redirectUri = + options?.redirect_uri ?? this.clientMetadata.redirect_uris[0] + if (!this.clientMetadata.redirect_uris.includes(redirectUri)) { + // The server will enforce this, but let's catch it early + throw new TypeError('Invalid redirect_uri') + } + + const signal = options?.signal + const { identity, metadata } = /^https?:\/\//.test(input) + ? // Allow using an entryway url directly as login input (e.g. when the + // user forgot their handle, or when the handle does not resolve to a + // DID) + { + identity: undefined, + metadata: await this.oauthResolver.resolveMetadata(input, { signal }), + } + : await this.oauthResolver.resolve(input, { signal }) + + const nonce = await this.runtime.generateNonce() + const pkce = await this.runtime.generatePKCE() + const dpopKey = await this.runtime.generateKey( + metadata.dpop_signing_alg_values_supported || [FALLBACK_ALG], + ) + + const state = await this.runtime.generateNonce() + + await this.stateStore.set(state, { + iss: metadata.issuer, + dpopKey, + nonce, + verifier: pkce?.verifier, + appState: options?.state, + }) + + const parameters = { + client_id: this.clientMetadata.client_id, + redirect_uri: redirectUri, + code_challenge: pkce?.challenge, + code_challenge_method: pkce?.method, + nonce, + state, + login_hint: identity?.did || undefined, + response_mode: this.responseMode, + response_type: + // Negotiate by using the order in the client metadata + this.clientMetadata.response_types?.find((t) => + metadata['response_types_supported']?.includes(t), + ) ?? 'code', + + display: options?.display, + id_token_hint: options?.id_token_hint, + max_age: options?.max_age, // this.clientMetadata.default_max_age + prompt: options?.prompt, + scope: options?.scope + ?.split(' ') + .filter((s) => metadata.scopes_supported?.includes(s)) + .join(' '), + ui_locales: options?.ui_locales, + } + + if (metadata.pushed_authorization_request_endpoint) { + const server = await this.serverFactory.fromMetadata(metadata, dpopKey) + const parResponse = await server.request( + 'pushed_authorization_request', + parameters, + ) + + const authorizationUrl = new URL(metadata.authorization_endpoint) + authorizationUrl.searchParams.set( + 'client_id', + this.clientMetadata.client_id, + ) + authorizationUrl.searchParams.set('request_uri', parResponse.request_uri) + return authorizationUrl + } else if (metadata.require_pushed_authorization_requests) { + throw new Error( + 'Server requires pushed authorization requests (PAR) but no PAR endpoint is available', + ) + } else { + const authorizationUrl = new URL(metadata.authorization_endpoint) + for (const [key, value] of Object.entries(parameters)) { + if (value) authorizationUrl.searchParams.set(key, String(value)) + } + + // Length of the URL that will be sent to the server + const urlLength = + authorizationUrl.pathname.length + authorizationUrl.search.length + if (urlLength < 2048) { + return authorizationUrl + } else if (!metadata.pushed_authorization_request_endpoint) { + throw new Error('Login URL too long') + } + } + + throw new Error( + 'Server does not support pushed authorization requests (PAR)', + ) + } + + async callback(params: URLSearchParams): Promise<{ + agent: OAuthAgent + state: string | null + }> { + const responseJwt = params.get('response') + if (responseJwt != null) { + // https://openid.net/specs/oauth-v2-jarm.html + throw new OAuthCallbackError(params, 'JARM not supported') + } + + const issuerParam = params.get('iss') + const stateParam = params.get('state') + const errorParam = params.get('error') + const codeParam = params.get('code') + + if (!stateParam) { + throw new OAuthCallbackError(params, 'Missing "state" parameter') + } + const stateData = await this.stateStore.get(stateParam) + if (stateData) { + // Prevent any kind of replay + await this.stateStore.del(stateParam) + } else { + throw new OAuthCallbackError( + params, + `Unknown authorization session "${stateParam}"`, + ) + } + + try { + if (errorParam != null) { + throw new OAuthCallbackError(params, undefined, stateData.appState) + } + + if (!codeParam) { + throw new OAuthCallbackError( + params, + 'Missing "code" query param', + stateData.appState, + ) + } + + const server = await this.serverFactory.fromIssuer( + stateData.iss, + stateData.dpopKey, + ) + + if (issuerParam != null) { + if (!server.serverMetadata.issuer) { + throw new OAuthCallbackError( + params, + 'Issuer not found in metadata', + stateData.appState, + ) + } + if (server.serverMetadata.issuer !== issuerParam) { + throw new OAuthCallbackError( + params, + 'Issuer mismatch', + stateData.appState, + ) + } + } else if ( + server.serverMetadata.authorization_response_iss_parameter_supported + ) { + throw new OAuthCallbackError( + params, + 'iss missing from the response', + stateData.appState, + ) + } + + const tokenSet = await server.exchangeCode(codeParam, stateData.verifier) + try { + if (tokenSet.id_token) { + await this.runtime.validateIdTokenClaims( + tokenSet.id_token, + stateParam, + stateData.nonce, + codeParam, + tokenSet.access_token, + ) + } + + const { sub } = tokenSet + + await this.sessionGetter.setStored(sub, { + dpopKey: stateData.dpopKey, + tokenSet, + }) + + const agent = this.createAgent(server, sub) + + return { agent, state: stateData.appState ?? null } + } catch (err) { + await server.revoke(tokenSet.access_token) + + throw err + } + } catch (err) { + // Make sure, whatever the underlying error, that the appState is + // available in the calling code + throw OAuthCallbackError.from(err, params, stateData.appState) + } + } + + /** + * Build an agent from a stored session. This will refresh the token only if + * needed (about to expire) by default. + * + * @param refresh See {@link SessionGetter.getSession} + */ + async restore(sub: string, refresh?: boolean): Promise { + const { dpopKey, tokenSet } = await this.sessionGetter.getSession( + sub, + refresh, + ) + + const server = await this.serverFactory.fromIssuer(tokenSet.iss, dpopKey, { + noCache: refresh === true, + allowStale: refresh === false, + }) + + return this.createAgent(server, sub) + } + + async revoke(sub: string) { + const { dpopKey, tokenSet } = await this.sessionGetter.get(sub, { + allowStale: true, + }) + + const server = await this.serverFactory.fromIssuer(tokenSet.iss, dpopKey) + + await server.revoke(tokenSet.access_token) + await this.sessionGetter.delStored(sub) + } + + createAgent(server: OAuthServerAgent, sub: string): OAuthAgent { + return new OAuthAgent(server, sub, this.sessionGetter, this.fetch) + } +} diff --git a/packages/oauth/oauth-client/src/oauth-protected-resource-metadata-resolver.ts b/packages/oauth/oauth-client/src/oauth-protected-resource-metadata-resolver.ts new file mode 100644 index 00000000000..52c9e8828a2 --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-protected-resource-metadata-resolver.ts @@ -0,0 +1,102 @@ +import { + Fetch, + FetchResponseError, + bindFetch, + cancelBody, +} from '@atproto-labs/fetch' +import { + CachedGetter, + GetCachedOptions, + SimpleStore, +} from '@atproto-labs/simple-store' +import { + OAuthProtectedResourceMetadata, + oauthProtectedResourceMetadataSchema, +} from '@atproto/oauth-types' +import { contentMime } from './util' + +export type { GetCachedOptions, OAuthProtectedResourceMetadata } + +export type ProtectedResourceMetadataCache = SimpleStore< + string, + OAuthProtectedResourceMetadata +> + +/** + * @see {@link https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-05} + */ +export class OAuthProtectedResourceMetadataResolver extends CachedGetter< + string, + OAuthProtectedResourceMetadata +> { + private readonly fetch: Fetch + + constructor( + cache: ProtectedResourceMetadataCache, + fetch: Fetch = globalThis.fetch, + ) { + super(async (origin, options) => this.fetchMetadata(origin, options), cache) + + this.fetch = bindFetch(fetch) + } + + async get( + resource: string | URL, + options?: GetCachedOptions, + ): Promise { + const { protocol, origin } = new URL(resource) + if (protocol !== 'https:' && protocol !== 'http:') { + throw new TypeError(`Invalid resource server ${protocol}`) + } + return super.get(origin, options) + } + + private async fetchMetadata( + origin: string, + options?: GetCachedOptions, + ): Promise { + const headers = new Headers([['accept', 'application/json']]) + if (options?.noCache) headers.set('cache-control', 'no-cache') + + const url = new URL(`/.well-known/oauth-protected-resource`, origin) + const request = new Request(url, { + signal: options?.signal, + headers, + redirect: 'error', // response must be 200 OK + }) + + const response = await this.fetch(request) + + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-05#section-3.2 + if (response.status !== 200) { + await cancelBody(response, 'log') + throw await FetchResponseError.from( + response, + `Unexpected status code ${response.status} for "${url}"`, + undefined, + { cause: request }, + ) + } + + if (contentMime(response.headers) !== 'application/json') { + await cancelBody(response, 'log') + throw await FetchResponseError.from( + response, + `Unexpected content type for "${url}"`, + undefined, + { cause: request }, + ) + } + + const metadata = oauthProtectedResourceMetadataSchema.parse( + await response.json(), + ) + + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-05#section-3.3 + if (metadata.resource !== origin) { + throw new TypeError(`Invalid issuer ${metadata.resource}`) + } + + return metadata + } +} diff --git a/packages/oauth/oauth-client/src/oauth-resolver-error.ts b/packages/oauth/oauth-client/src/oauth-resolver-error.ts new file mode 100644 index 00000000000..6541b4b1c4e --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-resolver-error.ts @@ -0,0 +1,12 @@ +export class OAuthResolverError extends Error { + constructor(message: string, options?: { cause?: unknown }) { + super(message, options) + } + + static from(cause: unknown, message?: string): OAuthResolverError { + if (cause instanceof OAuthResolverError) return cause + return new OAuthResolverError(message ?? `Unable to resolve identity`, { + cause, + }) + } +} diff --git a/packages/oauth/oauth-client/src/oauth-resolver.ts b/packages/oauth/oauth-client/src/oauth-resolver.ts new file mode 100644 index 00000000000..43acfacbb0d --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-resolver.ts @@ -0,0 +1,111 @@ +import { + ResolveOptions as IdentityResolveOptions, + IdentityResolver, + ResolvedIdentity, +} from '@atproto-labs/identity-resolver' +import { OAuthAuthorizationServerMetadata } from '@atproto/oauth-types' + +import { OAuthResolverError } from './oauth-resolver-error.js' +import { + GetCachedOptions, + OAuthAuthorizationServerMetadataResolver, +} from './oauth-authorization-server-metadata-resolver.js' +import { OAuthProtectedResourceMetadataResolver } from './oauth-protected-resource-metadata-resolver.js' + +export type { GetCachedOptions } +export type ResolveOptions = GetCachedOptions & IdentityResolveOptions + +export class OAuthResolver { + constructor( + readonly identityResolver: IdentityResolver, + readonly protectedResourceMetadataResolver: OAuthProtectedResourceMetadataResolver, + readonly authorizationServerMetadataResolver: OAuthAuthorizationServerMetadataResolver, + ) {} + + public async resolveIdentity( + input: string, + options?: IdentityResolveOptions, + ): Promise { + try { + return await this.identityResolver.resolve(input, options) + } catch (cause) { + throw OAuthResolverError.from( + cause, + `Failed to resolve identity: ${input}`, + ) + } + } + + public async resolveMetadata( + issuer: string, + options?: GetCachedOptions, + ): Promise { + try { + return await this.authorizationServerMetadataResolver.get(issuer, options) + } catch (cause) { + throw OAuthResolverError.from( + cause, + `Failed to resolve OAuth server metadata for issuer: ${issuer}`, + ) + } + } + + public async resolvePdsMetadata( + pds: string | URL, + options?: GetCachedOptions, + ) { + try { + const rsMetadata = await this.protectedResourceMetadataResolver.get( + pds, + options, + ) + + const issuer = rsMetadata.authorization_servers?.[0] + if (!issuer) { + throw new OAuthResolverError( + `No authorization servers found for PDS: ${pds}`, + ) + } + + options?.signal?.throwIfAborted() + + const asMetadata = await this.resolveMetadata(issuer, options) + + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-05#section-4 + if (asMetadata.protected_resources) { + if (!asMetadata.protected_resources.includes(rsMetadata.resource)) { + throw new OAuthResolverError( + `PDS "${pds}" not protected by issuer "${issuer}"`, + ) + } + } + + return asMetadata + } catch (cause) { + options?.signal?.throwIfAborted() + + throw OAuthResolverError.from( + cause, + `Failed to resolve OAuth server metadata for resource: ${pds}`, + ) + } + } + + public async resolve( + input: string, + options?: ResolveOptions, + ): Promise<{ + identity: ResolvedIdentity + metadata: OAuthAuthorizationServerMetadata + }> { + options?.signal?.throwIfAborted() + + const identity = await this.resolveIdentity(input, options) + + options?.signal?.throwIfAborted() + + const metadata = await this.resolvePdsMetadata(identity.pds, options) + + return { identity, metadata } + } +} diff --git a/packages/oauth/oauth-client/src/oauth-response-error.ts b/packages/oauth/oauth-client/src/oauth-response-error.ts new file mode 100644 index 00000000000..b49c2f5c1d2 --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-response-error.ts @@ -0,0 +1,31 @@ +import { Json, ifString, ifObject } from '@atproto-labs/fetch' + +export class OAuthResponseError extends Error { + readonly error?: string + readonly errorDescription?: string + + constructor( + public readonly response: Response, + public readonly payload: Json, + ) { + const error = ifString(ifObject(payload)?.['error']) + const errorDescription = ifString(ifObject(payload)?.['error_description']) + + const messageError = error ? `"${error}"` : 'unknown' + const messageDesc = errorDescription ? `: ${errorDescription}` : '' + const message = `OAuth ${messageError} error${messageDesc}` + + super(message) + + this.error = error + this.errorDescription = errorDescription + } + + get status() { + return this.response.status + } + + get headers() { + return this.response.headers + } +} diff --git a/packages/oauth/oauth-client/src/oauth-server-agent.ts b/packages/oauth/oauth-client/src/oauth-server-agent.ts new file mode 100644 index 00000000000..35403e0723e --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-server-agent.ts @@ -0,0 +1,275 @@ +import { Fetch, Json, fetchJsonProcessor, bindFetch } from '@atproto-labs/fetch' +import { SimpleStore } from '@atproto-labs/simple-store' +import { Key, Keyset, SignedJwt } from '@atproto/jwk' +import { + CLIENT_ASSERTION_TYPE_JWT_BEARER, + OAuthAuthorizationServerMetadata, + OAuthClientIdentification, + OAuthEndpointName, + OAuthParResponse, + OAuthTokenResponse, + OAuthTokenType, + oauthParResponseSchema, + oauthTokenResponseSchema, +} from '@atproto/oauth-types' + +import { FALLBACK_ALG } from './constants.js' +import { dpopFetchWrapper } from './fetch-dpop.js' +import { OAuthResolver } from './oauth-resolver.js' +import { OAuthResponseError } from './oauth-response-error.js' +import { RefreshError } from './refresh-error.js' +import { Runtime } from './runtime.js' +import { ClientMetadata } from './types.js' +import { withSignal } from './util.js' + +export type TokenSet = { + iss: string + sub: string + aud: string + scope?: string + + id_token?: SignedJwt + refresh_token?: string + access_token: string + token_type: OAuthTokenType + /** ISO Date */ + expires_at?: string +} + +export type DpopNonceCache = SimpleStore + +export class OAuthServerAgent { + protected dpopFetch: Fetch + + constructor( + readonly dpopKey: Key, + readonly serverMetadata: OAuthAuthorizationServerMetadata, + readonly clientMetadata: ClientMetadata, + readonly dpopNonces: DpopNonceCache, + readonly oauthResolver: OAuthResolver, + readonly runtime: Runtime, + readonly keyset?: Keyset, + fetch?: Fetch, + ) { + this.dpopFetch = dpopFetchWrapper({ + fetch: bindFetch(fetch), + iss: clientMetadata.client_id, + key: dpopKey, + supportedAlgs: serverMetadata.dpop_signing_alg_values_supported, + sha256: async (v) => runtime.sha256(v), + nonces: dpopNonces, + isAuthServer: true, + }) + } + + async revoke(token: string) { + try { + await this.request('revocation', { token }) + } catch { + // Don't care + } + } + + async exchangeCode(code: string, verifier?: string): Promise { + const tokenResponse = await this.request('token', { + grant_type: 'authorization_code', + redirect_uri: this.clientMetadata.redirect_uris[0]!, + code, + code_verifier: verifier, + }) + + try { + return this.processTokenResponse(tokenResponse) + } catch (err) { + await this.revoke(tokenResponse.access_token) + + throw err + } + } + + async refresh(tokenSet: TokenSet): Promise { + if (!tokenSet.refresh_token) { + throw new RefreshError(tokenSet.sub, 'No refresh token available') + } + + const tokenResponse = await this.request('token', { + grant_type: 'refresh_token', + refresh_token: tokenSet.refresh_token, + }) + + try { + if (tokenSet.sub !== tokenResponse.sub) { + throw new RefreshError( + tokenSet.sub, + `Unexpected "sub" in token response (${tokenResponse.sub})`, + ) + } + if (tokenSet.iss !== this.serverMetadata.issuer) { + throw new RefreshError(tokenSet.sub, 'Issuer mismatch') + } + + return this.processTokenResponse(tokenResponse) + } catch (err) { + await this.revoke(tokenResponse.access_token) + + throw err + } + } + + /** + * VERY IMPORTANT ! Always call this to process token responses. + * + * Whenever an OAuth token response is received, we **MUST** verify that the + * "sub" is a DID, whose issuer authority is indeed the server we just + * obtained credentials from. This check is a critical step to actually be + * able to use the "sub" (DID) as being the actual user's identifier. + */ + private async processTokenResponse( + tokenResponse: OAuthTokenResponse, + ): Promise { + const { sub } = tokenResponse + // ATPROTO requires that the "sub" is always present in the token response. + if (!sub) throw new TypeError(`Missing "sub" in token response`) + + // @TODO (?) make timeout configurable + const resolved = await withSignal({ timeout: 10e3 }, (signal) => + this.oauthResolver.resolve(sub, { signal }), + ) + + if (resolved.metadata.issuer !== this.serverMetadata.issuer) { + // Best case scenario; the user switched PDS. Worst case scenario; a bad + // actor is trying to impersonate a user. In any case, we must not allow + // this token to be used. + throw new TypeError('Issuer mismatch') + } + + return { + sub, + aud: resolved.identity.pds.href, + iss: resolved.metadata.issuer, + + scope: tokenResponse.scope, + id_token: tokenResponse.id_token, + refresh_token: tokenResponse.refresh_token, + access_token: tokenResponse.access_token, + token_type: tokenResponse.token_type ?? 'Bearer', + expires_at: + typeof tokenResponse.expires_in === 'number' + ? new Date(Date.now() + tokenResponse.expires_in * 1000).toISOString() + : undefined, + } + } + + async request( + endpoint: 'token', + payload: Record, + ): Promise + async request( + endpoint: 'pushed_authorization_request', + payload: Record, + ): Promise + async request( + endpoint: OAuthEndpointName, + payload: Record, + ): Promise + + async request(endpoint: OAuthEndpointName, payload: Record) { + const url = this.serverMetadata[`${endpoint}_endpoint`] + if (!url) throw new Error(`No ${endpoint} endpoint available`) + + const auth = await this.buildClientAuth(endpoint) + + const { response, json } = await this.dpopFetch(url, { + method: 'POST', + headers: { ...auth.headers, 'Content-Type': 'application/json' }, + body: JSON.stringify({ ...payload, ...auth.payload }), + }).then(fetchJsonProcessor()) + + if (response.ok) { + switch (endpoint) { + case 'token': + return oauthTokenResponseSchema.parse(json) + case 'pushed_authorization_request': + return oauthParResponseSchema.parse(json) + default: + return json + } + } else { + throw new OAuthResponseError(response, json) + } + } + + async buildClientAuth(endpoint: OAuthEndpointName): Promise<{ + headers?: Record + payload: OAuthClientIdentification + }> { + const methodSupported = + this.serverMetadata[`${endpoint}_endpoint_auth_methods_supported`] || + this.serverMetadata[`token_endpoint_auth_methods_supported`] + + const method = + this.clientMetadata[`${endpoint}_endpoint_auth_method`] || + this.clientMetadata[`token_endpoint_auth_method`] + + if ( + method === 'private_key_jwt' || + (this.keyset && + !method && + (methodSupported?.includes('private_key_jwt') ?? false)) + ) { + if (!this.keyset) throw new Error('No keyset available') + + try { + const alg = + this.serverMetadata[ + `${endpoint}_endpoint_auth_signing_alg_values_supported` + ] ?? + this.serverMetadata[ + `token_endpoint_auth_signing_alg_values_supported` + ] ?? + FALLBACK_ALG + + // If jwks is defined, make sure to only sign using a key that exists in + // the jwks. If jwks_uri is defined, we can't be sure that the key we're + // looking for is in there so we will just assume it is. + const kid = this.clientMetadata.jwks?.keys + .map(({ kid }) => kid) + .filter((v): v is string => typeof v === 'string') + + return { + payload: { + client_id: this.clientMetadata.client_id, + client_assertion_type: CLIENT_ASSERTION_TYPE_JWT_BEARER, + client_assertion: await this.keyset.createJwt( + { alg, kid }, + { + iss: this.clientMetadata.client_id, + sub: this.clientMetadata.client_id, + aud: this.serverMetadata.issuer, + jti: await this.runtime.generateNonce(), + iat: Math.floor(Date.now() / 1000), + }, + ), + }, + } + } catch (err) { + if (method === 'private_key_jwt') throw err + + // Else try next method + } + } + + if ( + method === 'none' || + (!method && (methodSupported?.includes('none') ?? true)) + ) { + return { + payload: { + client_id: this.clientMetadata.client_id, + }, + } + } + + throw new Error(`Unsupported ${endpoint} authentication method`) + } +} diff --git a/packages/oauth/oauth-client/src/oauth-server-factory.ts b/packages/oauth/oauth-client/src/oauth-server-factory.ts new file mode 100644 index 00000000000..def97d3b22e --- /dev/null +++ b/packages/oauth/oauth-client/src/oauth-server-factory.ts @@ -0,0 +1,41 @@ +import { Fetch } from '@atproto-labs/fetch' +import { Key, Keyset } from '@atproto/jwk' +import { OAuthAuthorizationServerMetadata } from '@atproto/oauth-types' + +import { GetCachedOptions } from './oauth-authorization-server-metadata-resolver.js' +import { OAuthResolver } from './oauth-resolver.js' +import { DpopNonceCache, OAuthServerAgent } from './oauth-server-agent.js' +import { Runtime } from './runtime.js' +import { ClientMetadata } from './types.js' + +export class OAuthServerFactory { + constructor( + readonly clientMetadata: ClientMetadata, + readonly runtime: Runtime, + readonly resolver: OAuthResolver, + readonly fetch: Fetch, + readonly keyset: Keyset | undefined, + readonly dpopNonceCache: DpopNonceCache, + ) {} + + async fromIssuer(issuer: string, dpopKey: Key, options?: GetCachedOptions) { + const serverMetadata = await this.resolver.resolveMetadata(issuer, options) + return this.fromMetadata(serverMetadata, dpopKey) + } + + async fromMetadata( + serverMetadata: OAuthAuthorizationServerMetadata, + dpopKey: Key, + ) { + return new OAuthServerAgent( + dpopKey, + serverMetadata, + this.clientMetadata, + this.dpopNonceCache, + this.resolver, + this.runtime, + this.keyset, + this.fetch, + ) + } +} diff --git a/packages/oauth/oauth-client/src/refresh-error.ts b/packages/oauth/oauth-client/src/refresh-error.ts new file mode 100644 index 00000000000..4bfc6283aea --- /dev/null +++ b/packages/oauth/oauth-client/src/refresh-error.ts @@ -0,0 +1,9 @@ +export class RefreshError extends Error { + constructor( + public readonly sub: string, + message: string, + options?: { cause?: unknown }, + ) { + super(message, options) + } +} diff --git a/packages/oauth/oauth-client/src/runtime-implementation.ts b/packages/oauth/oauth-client/src/runtime-implementation.ts new file mode 100644 index 00000000000..0cb12312dd8 --- /dev/null +++ b/packages/oauth/oauth-client/src/runtime-implementation.ts @@ -0,0 +1,17 @@ +import { Key } from '@atproto/jwk' + +export type DigestAlgorithm = { + name: 'sha256' | 'sha384' | 'sha512' +} + +export type { Key } + +export interface RuntimeImplementation { + createKey(algs: string[]): Key | PromiseLike + getRandomValues: (length: number) => Uint8Array | PromiseLike + digest: ( + bytes: Uint8Array, + algorithm: DigestAlgorithm, + ) => Uint8Array | PromiseLike + requestLock?: (name: string, fn: () => T | PromiseLike) => Promise +} diff --git a/packages/oauth/oauth-client/src/runtime.ts b/packages/oauth/oauth-client/src/runtime.ts new file mode 100644 index 00000000000..82dd8b16b55 --- /dev/null +++ b/packages/oauth/oauth-client/src/runtime.ts @@ -0,0 +1,211 @@ +import { JwtHeader, JwtPayload, Key, unsafeDecodeJwt } from '@atproto/jwk' +import { base64url } from 'multiformats/bases/base64' + +import { requestLocalLock } from './lock.js' +import { + DigestAlgorithm, + RuntimeImplementation, +} from './runtime-implementation.js' + +export class Runtime { + constructor(protected implementation: RuntimeImplementation) {} + + public async generateKey(algs: string[]): Promise { + const algsSorted = Array.from(algs).sort(compareAlgos) + return this.implementation.createKey(algsSorted) + } + + public async sha256(text: string): Promise { + const bytes = new TextEncoder().encode(text) + const digest = await this.implementation.digest(bytes, { name: 'sha256' }) + return base64url.baseEncode(digest) + } + + public async generateNonce(length = 16): Promise { + const bytes = await this.implementation.getRandomValues(length) + return base64url.baseEncode(bytes) + } + + get hasLock() { + return !!this.implementation.requestLock + } + + public async withLock( + name: string, + fn: () => T | PromiseLike, + ): Promise { + if (this.implementation.requestLock) { + return this.implementation.requestLock(name, fn) + } else { + // Falling back to a local lock + return requestLocalLock(name, fn) + } + } + + public async validateIdTokenClaims( + token: string, + state: string, + nonce: string, + code?: string, + accessToken?: string, + ): Promise<{ + header: JwtHeader + payload: JwtPayload + }> { + // It's fine to use unsafeDecodeJwt here because the token was received from + // the server's token endpoint. The following checks are to ensure that the + // oauth flow was indeed initiated by the client. + const { header, payload } = unsafeDecodeJwt(token) + if (!payload.nonce || payload.nonce !== nonce) { + throw new TypeError('Nonce mismatch') + } + if (payload.c_hash) { + await this.validateHashClaim(payload.c_hash, code, header) + } + if (payload.s_hash) { + await this.validateHashClaim(payload.s_hash, state, header) + } + if (payload.at_hash) { + await this.validateHashClaim(payload.at_hash, accessToken, header) + } + return { header, payload } + } + + private async validateHashClaim( + claim: unknown, + source: unknown, + header: { alg: string; crv?: string }, + ): Promise { + if (typeof claim !== 'string' || !claim) { + throw new TypeError(`string "_hash" claim expected`) + } + if (typeof source !== 'string' || !source) { + throw new TypeError(`string value expected`) + } + const expected = await this.generateHashClaim(source, header) + if (expected !== claim) { + throw new TypeError(`"_hash" does not match`) + } + } + + protected async generateHashClaim( + source: string, + header: { alg: string; crv?: string }, + ) { + const algo = getHashAlgo(header) + const bytes = new TextEncoder().encode(source) + const digest = await this.implementation.digest(bytes, algo) + if (digest.length % 2 !== 0) throw new TypeError('Invalid digest length') + const digestHalf = digest.slice(0, digest.length / 2) + return base64url.baseEncode(digestHalf) + } + + public async generatePKCE(byteLength?: number) { + const verifier = await this.generateVerifier(byteLength) + return { + verifier, + challenge: await this.sha256(verifier), + method: 'S256', + } + } + + public async calculateJwkThumbprint(jwk) { + const components = extractJktComponents(jwk) + const data = JSON.stringify(components) + return this.sha256(data) + } + + /** + * @see {@link https://datatracker.ietf.org/doc/html/rfc7636#section-4.1} + * @note It is RECOMMENDED that the output of a suitable random number generator + * be used to create a 32-octet sequence. The octet sequence is then + * base64url-encoded to produce a 43-octet URL safe string to use as the code + * verifier. + */ + protected async generateVerifier(byteLength = 32) { + if (byteLength < 32 || byteLength > 96) { + throw new TypeError('Invalid code_verifier length') + } + const bytes = await this.implementation.getRandomValues(byteLength) + return base64url.baseEncode(bytes) + } +} + +function getHashAlgo(header: { alg: string; crv?: string }): DigestAlgorithm { + switch (header.alg) { + case 'HS256': + case 'RS256': + case 'PS256': + case 'ES256': + case 'ES256K': + return { name: 'sha256' } + case 'HS384': + case 'RS384': + case 'PS384': + case 'ES384': + return { name: 'sha384' } + case 'HS512': + case 'RS512': + case 'PS512': + case 'ES512': + return { name: 'sha512' } + case 'EdDSA': + switch (header.crv) { + case 'Ed25519': + return { name: 'sha512' } + default: + throw new TypeError('unrecognized or invalid EdDSA curve provided') + } + default: + throw new TypeError('unrecognized or invalid JWS algorithm provided') + } +} + +function extractJktComponents(jwk) { + const get = (field) => { + const value = jwk[field] + if (typeof value !== 'string' || !value) { + throw new TypeError(`"${field}" Parameter missing or invalid`) + } + return value + } + + switch (jwk.kty) { + case 'EC': + return { crv: get('crv'), kty: get('kty'), x: get('x'), y: get('y') } + case 'OKP': + return { crv: get('crv'), kty: get('kty'), x: get('x') } + case 'RSA': + return { e: get('e'), kty: get('kty'), n: get('n') } + case 'oct': + return { k: get('k'), kty: get('kty') } + default: + throw new TypeError('"kty" (Key Type) Parameter missing or unsupported') + } +} + +/** + * 256K > ES (256 > 384 > 512) > PS (256 > 384 > 512) > RS (256 > 384 > 512) > other (in original order) + */ +function compareAlgos(a: string, b: string): number { + if (a === 'ES256K') return -1 + if (b === 'ES256K') return 1 + + for (const prefix of ['ES', 'PS', 'RS']) { + if (a.startsWith(prefix)) { + if (b.startsWith(prefix)) { + const aLen = parseInt(a.slice(2, 5)) + const bLen = parseInt(b.slice(2, 5)) + + // Prefer shorter key lengths + return aLen - bLen + } + return -1 + } else if (b.startsWith(prefix)) { + return 1 + } + } + + // Don't know how to compare, keep original order + return 0 +} diff --git a/packages/oauth/oauth-client/src/session-getter.ts b/packages/oauth/oauth-client/src/session-getter.ts new file mode 100644 index 00000000000..20294e2537a --- /dev/null +++ b/packages/oauth/oauth-client/src/session-getter.ts @@ -0,0 +1,182 @@ +import { + CachedGetter, + GetCachedOptions, + SimpleStore, +} from '@atproto-labs/simple-store' +import { Key } from '@atproto/jwk' +import { OAuthResponseError } from './oauth-response-error.js' +import { TokenSet } from './oauth-server-agent.js' +import { OAuthServerFactory } from './oauth-server-factory.js' +import { RefreshError } from './refresh-error.js' +import { Runtime } from './runtime.js' +import { withSignal } from './util.js' + +export type Session = { + dpopKey: Key + tokenSet: TokenSet +} + +export type SessionStore = SimpleStore + +/** + * There are several advantages to wrapping the sessionStore in a (single) + * CachedGetter, the main of which is that the cached getter will ensure that at + * most one fresh call is ever being made. Another advantage, is that it + * contains the logic for reading from the cache which, if the cache is based on + * localStorage/indexedDB, will sync across multiple tabs (for a given sub). + */ +export class SessionGetter extends CachedGetter { + constructor( + sessionStore: SessionStore, + serverFactory: OAuthServerFactory, + private readonly runtime: Runtime, + ) { + super( + async (sub, options, storedSession) => { + // There needs to be a previous session to be able to refresh. If + // storedSession is undefined, it means that the store does not contain + // a session for the given sub. Since this might have been caused by the + // value being cleared in another process (e.g. another tab), we will + // give a chance to the process running this code to detect that the + // session was revoked. This should allow processes not implementing a + // subscribe/notify between instances to still be "notified" that the + // session was revoked. + if (storedSession === undefined) { + // Because the session is not in the store, the sessionStore.del + // function will not be called, even if the "deleteOnError" callback + // returns true when the error is an "OAuthRefreshError". Let's + // call it here manually. + await sessionStore.del(sub) + throw new RefreshError(sub, 'The session was revoked') + } + + if (sub !== storedSession.tokenSet.sub) { + // Fool-proofing (e.g. against invalid session storage) + throw new RefreshError(sub, 'Stored session sub mismatch') + } + + // Since refresh tokens can only be used once, we might run into + // concurrency issues if multiple tabs/instances are trying to refresh + // the same token. The chances of this happening when multiple instances + // are started simultaneously is reduced by randomizing the expiry time + // (see isStale() bellow). Even so, There still exist chances that + // multiple tabs will try to refresh the token at the same time. The + // best solution would be to use a mutex/lock to ensure that only one + // instance is refreshing the token at a time. A simpler workaround is + // to check if the value stored in the session store is the same as the + // one in memory. If it isn't, then another instance has already + // refreshed the token. + + const { tokenSet, dpopKey } = storedSession + const server = await serverFactory.fromIssuer(tokenSet.iss, dpopKey) + + // We must not use the "signal" to cancel the refresh or its storage in + // case of successful refresh. If we obtain a new refresh token, we must + // ensure that is gets stored in the session store (by returning the new + // session object). Failing to do so would result in the new credentials + // being lost. + options?.signal?.throwIfAborted() + + const newTokenSet = await server + .refresh(tokenSet) + .catch(async (cause) => { + if ( + cause instanceof OAuthResponseError && + cause.status === 400 && + cause.error === 'invalid_grant' + ) { + // In case there is no lock implementation in the runtime, we will + // wait for a short time to give the other concurrent instances a + // chance to finish their refreshing of the token. If a concurrent + // refresh did occur, we will pretend that this one succeeded. + if (!runtime.hasLock) { + await new Promise((r) => setTimeout(r, 1000)) + + const stored = await this.getStored(sub) + if (stored === undefined) { + // Using a distinct error message mainly for debugging + // purposes + const msg = 'The session was revoked by another process' + throw new RefreshError(sub, msg, { cause }) + } else if ( + stored.tokenSet.access_token !== tokenSet.access_token || + stored.tokenSet.refresh_token !== tokenSet.refresh_token + ) { + // A concurrent refresh occurred. Pretend this one succeeded. + return stored.tokenSet + } else { + // There were no concurrent refresh. The token is (likely) + // simply no longer valid. + } + } + + // Throwing an RefreshError to trigger deletion through the + // deleteOnError callback. + const msg = cause.errorDescription ?? 'The session was revoked' + throw new RefreshError(sub, msg, { cause }) + } + + throw cause + }) + + if (sub !== newTokenSet.sub) { + // The server returned another sub. Was the tokenSet manipulated? + throw new RefreshError(sub, 'Token set sub mismatch') + } + + return { ...storedSession, tokenSet: newTokenSet } + }, + sessionStore, + { + isStale: (sub, { tokenSet }) => { + return ( + tokenSet.expires_at != null && + new Date(tokenSet.expires_at).getTime() < + // Add some lee way to ensure the token is not expired when it + // reaches the server. + Date.now() + 60e3 + ) + }, + onStoreError: async (err, sub, { tokenSet, dpopKey }) => { + // If the token data cannot be stored, let's revoke it + const server = await serverFactory.fromIssuer(tokenSet.iss, dpopKey) + await server.revoke(tokenSet.refresh_token ?? tokenSet.access_token) + throw err + }, + deleteOnError: async (err) => { + return err instanceof RefreshError + }, + }, + ) + } + + /** + * @param refresh When `true`, the credentials will be refreshed even if they + * are not expired. When `false`, the credentials will not be refreshed even + * if they are expired. When `undefined`, the credentials will be refreshed + * if, and only if, they are (about to be) expired. Defaults to `undefined`. + */ + async getSession(sub: string, refresh?: boolean) { + const session = await this.get(sub, { + noCache: refresh === true, + allowStale: refresh === false, + }) + + if (sub !== session.tokenSet.sub) { + // Fool-proofing (e.g. against invalid session storage) + throw new Error('Token set does not match the expected sub') + } + + return session + } + + async get(sub: string, options?: GetCachedOptions): Promise { + return this.runtime.withLock(`@atproto-oauth-client-${sub}`, async () => { + // Make sure, even if there is no signal in the options, that the request + // will be cancelled after at most 30 seconds. + return withSignal({ signal: options?.signal, timeout: 30e3 }, (signal) => + super.get(sub, { ...options, signal }), + ) + }) + } +} diff --git a/packages/oauth/oauth-client/src/types.ts b/packages/oauth/oauth-client/src/types.ts new file mode 100644 index 00000000000..630e0ad383c --- /dev/null +++ b/packages/oauth/oauth-client/src/types.ts @@ -0,0 +1,26 @@ +import { + oauthClientIdSchema, + oauthClientMetadataSchema, +} from '@atproto/oauth-types' +import z from 'zod' + +// Note: These types are not prefixed with `OAuth` because they are not specific +// to OAuth. They are specific to this packages. OAuth specific types are in +// `@atproto/oauth-types`. + +export type AuthorizeOptions = { + display?: 'page' | 'popup' | 'touch' | 'wap' + redirect_uri?: string + id_token_hint?: string + max_age?: number + prompt?: 'login' | 'none' | 'consent' | 'select_account' + scope?: string + state?: string + ui_locales?: string +} + +export const clientMetadataSchema = oauthClientMetadataSchema.extend({ + client_id: oauthClientIdSchema.url(), +}) + +export type ClientMetadata = z.infer diff --git a/packages/oauth/oauth-client/src/util.ts b/packages/oauth/oauth-client/src/util.ts new file mode 100644 index 00000000000..7948fbac9f5 --- /dev/null +++ b/packages/oauth/oauth-client/src/util.ts @@ -0,0 +1,51 @@ +/** + * @todo (?) move to common package + */ +export const withSignal = async ( + options: + | undefined + | { + signal?: AbortSignal + timeout: number + }, + fn: (signal: AbortSignal) => T | PromiseLike, +): Promise => { + options?.signal?.throwIfAborted() + + const abortController = new AbortController() + const { signal } = abortController + + options?.signal?.addEventListener( + 'abort', + (reason) => abortController.abort(reason), + { once: true, signal }, + ) + + if (options?.timeout != null) { + const timeoutId = setTimeout( + (err) => abortController.abort(err), + options.timeout, + new Error('Timeout'), + ) + + timeoutId.unref?.() // NodeJS only + + signal.addEventListener('abort', () => clearTimeout(timeoutId), { + once: true, + signal, + }) + } + + try { + return await fn(signal) + } finally { + // - Remove listener on incoming signal + // - Cancel timeout + // - Cancel pending (async) tasks + abortController.abort() + } +} + +export function contentMime(headers: Headers): string | undefined { + return headers.get('content-type')?.split(';')[0]!.trim() +} diff --git a/packages/oauth/oauth-client/src/validate-client-metadata.ts b/packages/oauth/oauth-client/src/validate-client-metadata.ts new file mode 100644 index 00000000000..a0dc5d7c629 --- /dev/null +++ b/packages/oauth/oauth-client/src/validate-client-metadata.ts @@ -0,0 +1,61 @@ +import { Keyset } from '@atproto/jwk' +import { + OAUTH_AUTHENTICATED_ENDPOINT_NAMES, + OAuthClientMetadataInput, +} from '@atproto/oauth-types' + +import { ClientMetadata, clientMetadataSchema } from './types.js' + +// Improve bundle size by using concatenation +const _ENDPOINT_AUTH_METHOD = '_endpoint_auth_method' +const _ENDPOINT_AUTH_SIGNING_ALG = '_endpoint_auth_signing_alg' + +const TOKEN_ENDPOINT_AUTH_METHOD = `token${_ENDPOINT_AUTH_METHOD}` + +export function validateClientMetadata( + input: OAuthClientMetadataInput, + keyset?: Keyset, +): ClientMetadata { + const metadata = clientMetadataSchema.parse(input) + + // ATPROTO uses client metadata discovery + try { + new URL(metadata.client_id) + } catch (cause) { + throw new TypeError(`client_id must be a valid URL`, { cause }) + } + + if (!metadata[TOKEN_ENDPOINT_AUTH_METHOD]) { + throw new TypeError(`${TOKEN_ENDPOINT_AUTH_METHOD} must be provided`) + } + + for (const endpointName of OAUTH_AUTHENTICATED_ENDPOINT_NAMES) { + const method = metadata[`${endpointName}${_ENDPOINT_AUTH_METHOD}`] + switch (method) { + case undefined: + case 'none': + if (metadata[`${endpointName}${_ENDPOINT_AUTH_SIGNING_ALG}`]) { + throw new TypeError( + `${endpointName}${_ENDPOINT_AUTH_SIGNING_ALG} must not be provided`, + ) + } + break + case 'client_secret_jwt': + if (!keyset) { + throw new TypeError(`Keyset is required for ${method} method`) + } + if (!metadata[`${endpointName}${_ENDPOINT_AUTH_SIGNING_ALG}`]) { + throw new TypeError( + `${endpointName}${_ENDPOINT_AUTH_SIGNING_ALG} must be provided`, + ) + } + break + default: + throw new TypeError( + `Invalid "${endpointName}${_ENDPOINT_AUTH_METHOD}" value: ${method}`, + ) + } + } + + return metadata +} diff --git a/packages/oauth/oauth-client/tsconfig.build.json b/packages/oauth/oauth-client/tsconfig.build.json new file mode 100644 index 00000000000..d37ddd8d394 --- /dev/null +++ b/packages/oauth/oauth-client/tsconfig.build.json @@ -0,0 +1,8 @@ +{ + "extends": ["../../../tsconfig/isomorphic.json"], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist" + }, + "include": ["./src"] +} diff --git a/packages/oauth/oauth-client/tsconfig.json b/packages/oauth/oauth-client/tsconfig.json new file mode 100644 index 00000000000..e84b8178b47 --- /dev/null +++ b/packages/oauth/oauth-client/tsconfig.json @@ -0,0 +1,4 @@ +{ + "include": [], + "references": [{ "path": "./tsconfig.build.json" }] +} diff --git a/packages/oauth/oauth-provider/.postcssrc.yml b/packages/oauth/oauth-provider/.postcssrc.yml new file mode 100644 index 00000000000..0114fbc9e78 --- /dev/null +++ b/packages/oauth/oauth-provider/.postcssrc.yml @@ -0,0 +1,3 @@ +plugins: + tailwindcss: {} + autoprefixer: {} diff --git a/packages/oauth/oauth-provider/package.json b/packages/oauth/oauth-provider/package.json new file mode 100644 index 00000000000..4cca3762672 --- /dev/null +++ b/packages/oauth/oauth-provider/package.json @@ -0,0 +1,83 @@ +{ + "name": "@atproto/oauth-provider", + "version": "0.0.0", + "license": "MIT", + "description": "Generic OAuth2 and OpenID Connect provider for Node.js. Currently only supports features needed for Atproto.", + "keywords": [ + "atproto", + "oauth", + "oauth2", + "open id connect", + "oidc", + "provider", + "oidc provider" + ], + "homepage": "https://atproto.com", + "repository": { + "type": "git", + "url": "https://github.com/bluesky-social/atproto", + "directory": "packages/oauth/oauth-provider" + }, + "type": "commonjs", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + }, + "dependencies": { + "@atproto-labs/fetch": "workspace:*", + "@atproto-labs/fetch-node": "workspace:*", + "@atproto-labs/pipe": "workspace:*", + "@atproto-labs/simple-store": "workspace:*", + "@atproto-labs/simple-store-memory": "workspace:*", + "@atproto/jwk": "workspace:*", + "@atproto/jwk-jose": "workspace:*", + "@atproto/oauth-types": "workspace:*", + "@hapi/accept": "^6.0.3", + "@hapi/bourne": "^3.0.0", + "cookie": "^0.6.0", + "http-errors": "^2.0.0", + "ioredis": "^5.3.2", + "jose": "^5.2.0", + "keygrip": "^1.1.0", + "oidc-token-hash": "^5.0.3", + "psl": "^1.9.0", + "zod": "^3.23.8" + }, + "devDependencies": { + "@atproto-labs/rollup-plugin-bundle-manifest": "workspace:*", + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@rollup/plugin-terser": "^0.4.4", + "@rollup/plugin-typescript": "^11.1.6", + "@types/cookie": "^0.6.0", + "@types/keygrip": "^1.0.6", + "@types/psl": "1.1.3", + "@types/react": "^18.2.50", + "@types/react-dom": "^18.2.18", + "@types/send": "^0.17.4", + "@web/rollup-plugin-import-meta-assets": "^2.2.1", + "autoprefixer": "^10.4.17", + "postcss": "^8.4.33", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "rollup": "^4.13.0", + "rollup-plugin-postcss": "^4.0.2", + "tailwindcss": "^3.4.1", + "typescript": "^5.3.3" + }, + "optionalDependencies": { + "ioredis": "^5.3.2", + "keygrip": "^1.1.0" + }, + "scripts": { + "build:frontend": "rollup --config rollup.config.js", + "build:backend": "tsc --build --force tsconfig.backend.json", + "build": "pnpm --parallel --stream '/^build:.+$/'", + "dev": "rollup --config rollup.config.js --watch" + } +} diff --git a/packages/oauth/oauth-provider/rollup.config.js b/packages/oauth/oauth-provider/rollup.config.js new file mode 100644 index 00000000000..54a754a6d6f --- /dev/null +++ b/packages/oauth/oauth-provider/rollup.config.js @@ -0,0 +1,55 @@ +/* eslint-env node */ + +const { defineConfig } = require('rollup') + +const { + default: manifest, +} = require('@atproto-labs/rollup-plugin-bundle-manifest') +const { default: commonjs } = require('@rollup/plugin-commonjs') +const { default: nodeResolve } = require('@rollup/plugin-node-resolve') +const { default: replace } = require('@rollup/plugin-replace') +const { default: terser } = require('@rollup/plugin-terser') +const { default: typescript } = require('@rollup/plugin-typescript') +const postcss = ((m) => m.default || m)(require('rollup-plugin-postcss')) + +module.exports = defineConfig((commandLineArguments) => { + const NODE_ENV = + process.env['NODE_ENV'] ?? + (commandLineArguments.watch ? 'development' : 'production') + + const minify = NODE_ENV !== 'development' + + return { + input: 'src/assets/app/main.tsx', + output: { + manualChunks: undefined, + sourcemap: true, + file: 'dist/assets/app/main.js', + format: 'iife', + }, + plugins: [ + nodeResolve({ preferBuiltins: false, browser: true }), + commonjs(), + postcss({ config: true, extract: true, minimize: minify }), + typescript({ + tsconfig: './tsconfig.frontend.json', + outputToFilesystem: true, + }), + replace({ + preventAssignment: true, + values: { 'process.env.NODE_ENV': JSON.stringify(NODE_ENV) }, + }), + // Change `data` to `true` to include assets data in the manifest, + // allowing for easier bundling of the backend code (eg. using esbuild) as + // bundlers know how to bundle JSON files but not how to bundle assets + // referenced at runtime. + manifest({ data: false }), + minify && terser({}), + ], + onwarn(warning, warn) { + // 'use client' directives are fine + if (warning.code === 'MODULE_LEVEL_DIRECTIVE') return + warn(warning) + }, + } +}) diff --git a/packages/oauth/oauth-provider/src/access-token/access-token-type.ts b/packages/oauth/oauth-provider/src/access-token/access-token-type.ts new file mode 100644 index 00000000000..d75a9711319 --- /dev/null +++ b/packages/oauth/oauth-provider/src/access-token/access-token-type.ts @@ -0,0 +1,5 @@ +export enum AccessTokenType { + auto = 'auto', + jwt = 'jwt', + id = 'id', +} diff --git a/packages/oauth/oauth-provider/src/account/account-manager.ts b/packages/oauth/oauth-provider/src/account/account-manager.ts new file mode 100644 index 00000000000..b8f963ee814 --- /dev/null +++ b/packages/oauth/oauth-provider/src/account/account-manager.ts @@ -0,0 +1,55 @@ +import { isOAuthClientIdLoopback } from '@atproto/oauth-types' +import { Client } from '../client/client.js' +import { DeviceId } from '../device/device-id.js' +import { constantTime } from '../lib/util/time.js' +import { InvalidRequestError } from '../oauth-errors.js' +import { Sub } from '../oidc/sub.js' +import { ClientAuth } from '../token/token-store.js' +import { + Account, + AccountInfo, + AccountStore, + LoginCredentials, +} from './account-store.js' + +const TIMING_ATTACK_MITIGATION_DELAY = 400 + +export class AccountManager { + constructor(protected readonly store: AccountStore) {} + + public async signIn( + credentials: LoginCredentials, + deviceId: DeviceId, + ): Promise { + return constantTime(TIMING_ATTACK_MITIGATION_DELAY, async () => { + const result = await this.store.authenticateAccount(credentials, deviceId) + if (result) return result + + throw new InvalidRequestError('Invalid credentials') + }) + } + + public async get(deviceId: DeviceId, sub: Sub): Promise { + const result = await this.store.getDeviceAccount(deviceId, sub) + if (result) return result + + throw new InvalidRequestError(`Account not found`) + } + + public async addAuthorizedClient( + deviceId: DeviceId, + account: Account, + client: Client, + _clientAuth: ClientAuth, + ): Promise { + // "Loopback" clients are not distinguishable from one another. + if (isOAuthClientIdLoopback(client.id)) return + + await this.store.addAuthorizedClient(deviceId, account.sub, client.id) + } + + public async list(deviceId: DeviceId): Promise { + const results = await this.store.listDeviceAccounts(deviceId) + return results.filter((result) => result.info.remembered) + } +} diff --git a/packages/oauth/oauth-provider/src/account/account-store.ts b/packages/oauth/oauth-provider/src/account/account-store.ts new file mode 100644 index 00000000000..37d06ef898f --- /dev/null +++ b/packages/oauth/oauth-provider/src/account/account-store.ts @@ -0,0 +1,74 @@ +import { ClientId } from '../client/client-id.js' +import { DeviceId } from '../device/device-id.js' +import { Awaitable } from '../lib/util/type.js' +import { Sub } from '../oidc/sub.js' +import { Account } from './account.js' + +export type LoginCredentials = { + username: string + password: string + + /** + * If false, the account must not be returned from + * {@link AccountStore.listDeviceAccounts}. Note that this only makes sense when + * used with a device ID. + */ + remember?: boolean +} + +export type DeviceAccountInfo = { + remembered: boolean + authenticatedAt: Date + authorizedClients: readonly ClientId[] +} + +// Export all types needed to implement the AccountStore interface +export type { Account, DeviceId, Sub } + +export type AccountInfo = { + account: Account + info: DeviceAccountInfo +} + +export interface AccountStore { + authenticateAccount( + credentials: LoginCredentials, + deviceId: DeviceId, + ): Awaitable + + addAuthorizedClient( + deviceId: DeviceId, + sub: Sub, + clientId: ClientId, + ): Awaitable + + getDeviceAccount(deviceId: DeviceId, sub: Sub): Awaitable + removeDeviceAccount(deviceId: DeviceId, sub: Sub): Awaitable + + /** + * @note Only the accounts that where logged in with `remember: true` need to + * be returned. The others will be ignored. + */ + listDeviceAccounts(deviceId: DeviceId): Awaitable +} + +export function isAccountStore( + implementation: Record & Partial, +): implementation is Record & AccountStore { + return ( + typeof implementation.authenticateAccount === 'function' && + typeof implementation.getDeviceAccount === 'function' && + typeof implementation.addAuthorizedClient === 'function' && + typeof implementation.listDeviceAccounts === 'function' && + typeof implementation.removeDeviceAccount === 'function' + ) +} + +export function asAccountStore( + implementation?: Record & Partial, +): AccountStore { + if (!implementation || !isAccountStore(implementation)) { + throw new Error('Invalid AccountStore implementation') + } + return implementation +} diff --git a/packages/oauth/oauth-provider/src/account/account.ts b/packages/oauth/oauth-provider/src/account/account.ts new file mode 100644 index 00000000000..0728232375f --- /dev/null +++ b/packages/oauth/oauth-provider/src/account/account.ts @@ -0,0 +1,10 @@ +import { OIDCStandardPayload } from '../oidc/claims.js' +import { Sub } from '../oidc/sub.js' +import { Simplify } from '../lib/util/type.js' + +export type Account = Simplify< + { + sub: Sub // Account id + aud: string | [string, ...string[]] // Resource server URL + } & OIDCStandardPayload +> diff --git a/packages/oauth/oauth-provider/src/assets/app/app.tsx b/packages/oauth/oauth-provider/src/assets/app/app.tsx new file mode 100644 index 00000000000..e925ffae5dd --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/app.tsx @@ -0,0 +1,28 @@ +import type { + AuthorizeData, + CustomizationData, + ErrorData, +} from './backend-data' +import { AuthorizeView } from './views/authorize-view' +import { ErrorView } from './views/error-view' + +export type AppProps = { + authorizeData?: AuthorizeData + customizationData?: CustomizationData + errorData?: ErrorData +} + +export function App({ authorizeData, customizationData, errorData }: AppProps) { + if (authorizeData && !errorData) { + return ( + + ) + } else { + return ( + + ) + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/backend-data.ts b/packages/oauth/oauth-provider/src/assets/app/backend-data.ts new file mode 100644 index 00000000000..5bf4c4dac0d --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/backend-data.ts @@ -0,0 +1,65 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' + +// TODO: Find a way to share these types with the backend code + +export type Account = { + sub: string + aud: string + + email?: string + name?: string + preferred_username?: string + picture?: string +} + +export type Session = { + account: Account + + selected: boolean + loginRequired: boolean + consentRequired: boolean +} + +export type LinkDefinition = { + title: string + href: string + rel?: string +} + +export type CustomizationData = { + name?: string + logo?: string + links?: LinkDefinition[] +} + +export type ErrorData = { + error: string + error_description: string +} + +export type AuthorizeData = { + clientId: string + clientMetadata: OAuthClientMetadata + clientTrusted: boolean + requestUri: string + csrfCookie: string + sessions: Session[] + newSessionsRequireConsent: boolean + loginHint?: string +} + +// see "declareBackendData()" in the backend +const readBackendData = (key: string): T | undefined => { + const value = window[key] as T | undefined + delete window[key] // Prevent accidental usage / potential leaks to dependencies + return value +} + +// These values are injected by the backend when it builds the +// page HTML. + +export const customizationData = readBackendData( + '__customizationData', +) +export const errorData = readBackendData('__errorData') +export const authorizeData = readBackendData('__authorizeData') diff --git a/packages/oauth/oauth-provider/src/assets/app/components/accept-form.tsx b/packages/oauth/oauth-provider/src/assets/app/components/accept-form.tsx new file mode 100644 index 00000000000..d82377e097d --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/accept-form.tsx @@ -0,0 +1,112 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' +import { type HTMLAttributes } from 'react' + +import { Account } from '../backend-data' +import { clsx } from '../lib/clsx' +import { AccountIdentifier } from './account-identifier' +import { ClientIdentifier } from './client-identifier' +import { ClientName } from './client-name' + +export type AcceptFormProps = { + account: Account + clientId: string + clientMetadata: OAuthClientMetadata + clientTrusted: boolean + onAccept: () => void + acceptLabel?: string + + onReject: () => void + rejectLabel?: string + + onBack?: () => void + backLabel?: string +} + +export function AcceptForm({ + account, + clientId, + clientMetadata, + clientTrusted, + onAccept, + acceptLabel = 'Accept', + onReject, + rejectLabel = 'Deny access', + onBack, + backLabel = 'Back', + + ...attrs +}: AcceptFormProps & HTMLAttributes) { + return ( +
+ {clientTrusted && clientMetadata.logo_uri && ( +
+ {clientMetadata.client_name} +
+ )} + + + +

+ {' '} + is asking for permission to access your{' '} + account. +

+ +

+ By clicking {acceptLabel}, you allow this application to access + your information in accordance to its{' '} + + terms of service + + . +

+ +
+ +
+ + + {onBack && ( + + )} + +
+ + +
+
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/account-identifier.tsx b/packages/oauth/oauth-provider/src/assets/app/components/account-identifier.tsx new file mode 100644 index 00000000000..96b860e3261 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/account-identifier.tsx @@ -0,0 +1,18 @@ +import { HTMLAttributes } from 'react' + +import { Account } from '../backend-data' + +export type AccountIdentifierProps = { + account: Account +} + +export function AccountIdentifier({ + account, + ...attrs +}: AccountIdentifierProps & HTMLAttributes) { + return ( + + {account.preferred_username || account.email || account.sub} + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/account-picker.tsx b/packages/oauth/oauth-provider/src/assets/app/components/account-picker.tsx new file mode 100644 index 00000000000..76dcf008003 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/account-picker.tsx @@ -0,0 +1,108 @@ +import type { HTMLAttributes, ReactNode } from 'react' +import { Account } from '../backend-data' +import { clsx } from '../lib/clsx' + +export type AccountPickerProps = { + accounts: readonly Account[] + + onAccount: (account: Account) => void + accountAria?: (account: Account) => string + + onOther?: () => void + otherLabel?: ReactNode + otherAria?: string + + onBack?: () => void + backLabel?: ReactNode + backAria?: string +} + +export function AccountPicker({ + accounts, + + onAccount, + accountAria = (a) => `Sign in as ${a.name}`, + + onOther = undefined, + otherLabel = 'Other account', + otherAria = 'Login to account that is not listed', + + onBack, + backAria, + backLabel = backAria, + + className, + ...attrs +}: AccountPickerProps & HTMLAttributes) { + return ( +
+

Sign in as...

+ + {accounts.map((account) => { + const [name, identifier] = [ + account.name, + account.preferred_username, + account.email, + account.sub, + ].filter(Boolean) as [string, string?] + + return ( + + ) + })} + {onOther && ( + + )} + +
+ + {onBack && ( +
+ +
+ )} +
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/client-identifier.tsx b/packages/oauth/oauth-provider/src/assets/app/components/client-identifier.tsx new file mode 100644 index 00000000000..1de6f5f7ac6 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/client-identifier.tsx @@ -0,0 +1,32 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' +import { HTMLAttributes } from 'react' + +import { UrlViewer } from './url-viewer' + +export type ClientIdentifierProps = { + clientId: string + clientMetadata: OAuthClientMetadata + as?: keyof JSX.IntrinsicElements +} + +export function ClientIdentifier({ + clientId, + clientMetadata, + as: As = 'span', + ...attrs +}: ClientIdentifierProps & HTMLAttributes) { + if (clientMetadata.client_uri) { + return ( + + ) + } + + // Fallback to the client ID + return {clientId} +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/client-name.tsx b/packages/oauth/oauth-provider/src/assets/app/components/client-name.tsx new file mode 100644 index 00000000000..abafc87d23c --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/client-name.tsx @@ -0,0 +1,30 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' +import { HTMLAttributes } from 'react' + +import { ClientIdentifier } from './client-identifier' + +export type ClientNameProps = { + clientId: string + clientMetadata: OAuthClientMetadata + as?: keyof JSX.IntrinsicElements +} + +export function ClientName({ + clientId, + clientMetadata, + as: As = 'span', + ...attrs +}: ClientNameProps & HTMLAttributes) { + if (clientMetadata.client_name) { + return {clientMetadata.client_name} + } + + return ( + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/error-card.tsx b/packages/oauth/oauth-provider/src/assets/app/components/error-card.tsx new file mode 100644 index 00000000000..3a41112717b --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/error-card.tsx @@ -0,0 +1,41 @@ +import { HtmlHTMLAttributes } from 'react' +import { clsx } from '../lib/clsx' + +export type ErrorCardProps = { + message?: null | string + role?: 'alert' | 'status' +} + +export function ErrorCard({ + message, + + role = 'alert', + className, + ...attrs +}: Partial & + Omit, keyof ErrorCardProps | 'children'>) { + return ( +
+ + + + +
+

+ {typeof message === 'string' ? message : 'An unknown error occurred'} +

+
+
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/help-card.tsx b/packages/oauth/oauth-provider/src/assets/app/components/help-card.tsx new file mode 100644 index 00000000000..23c0434fd06 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/help-card.tsx @@ -0,0 +1,42 @@ +import { HTMLAttributes } from 'react' +import { LinkDefinition } from '../backend-data' +import { clsx } from '../lib/clsx' + +export type HelpCardProps = { + links?: readonly LinkDefinition[] +} + +export function HelpCard({ + links, + + className, + ...attrs +}: HelpCardProps & + Omit< + HTMLAttributes, + keyof HelpCardProps | 'children' + >) { + const helpLink = links?.find((l) => l.rel === 'help') + + if (!helpLink) return null + + return ( +

+ Having trouble?{' '} + + Contact {helpLink.title} + +

+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/layout-title-page.tsx b/packages/oauth/oauth-provider/src/assets/app/components/layout-title-page.tsx new file mode 100644 index 00000000000..c4985c22bed --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/layout-title-page.tsx @@ -0,0 +1,43 @@ +import { HTMLAttributes, ReactNode } from 'react' +import { clsx } from '../lib/clsx' + +export type LayoutTitlePageProps = { + title?: ReactNode + subtitle?: ReactNode +} + +export function LayoutTitlePage({ + children, + title, + subtitle, + ...attrs +}: LayoutTitlePageProps & + Omit, keyof LayoutTitlePageProps>) { + return ( +
+
+ {title && ( +

+ {title} +

+ )} + + {subtitle && ( +

+ {subtitle} +

+ )} +
+ +
+ {children} +
+
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/layout-welcome.tsx b/packages/oauth/oauth-provider/src/assets/app/components/layout-welcome.tsx new file mode 100644 index 00000000000..e3c02bba5b2 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/layout-welcome.tsx @@ -0,0 +1,58 @@ +import { PropsWithChildren } from 'react' + +export type LayoutWelcomeProps = { + name?: string + logo?: string + links?: Array<{ + title: string + href: string + rel?: string + }> + logoAlt?: string +} + +export function LayoutWelcome({ + name, + logo, + logoAlt = name || 'Logo', + links, + children, +}: PropsWithChildren) { + return ( +
+
+ {logo && ( + {logoAlt} + )} + + {name && ( +

+ {name} +

+ )} + + {children} +
+ + {links != null && links.length > 0 && ( + + )} +
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/sign-in-form.tsx b/packages/oauth/oauth-provider/src/assets/app/components/sign-in-form.tsx new file mode 100644 index 00000000000..d15789fb3d9 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/sign-in-form.tsx @@ -0,0 +1,290 @@ +import { + FormHTMLAttributes, + ReactNode, + SyntheticEvent, + useCallback, + useState, +} from 'react' + +import { clsx } from '../lib/clsx' +import { ErrorCard } from './error-card' + +export type SignInFormOutput = { + username: string + password: string + remember?: boolean +} + +export type SignInFormProps = { + title?: ReactNode + + onSubmit: (credentials: SignInFormOutput) => void | PromiseLike + submitLabel?: ReactNode + submitAria?: string + + onCancel?: () => void + cancelLabel?: ReactNode + cancelAria?: string + + usernameDefault?: string + usernameReadonly?: boolean + usernameLabel?: string + usernamePlaceholder?: string + usernameAria?: string + usernamePattern?: string + usernameTitle?: string + + passwordLabel?: string + passwordPlaceholder?: string + passwordWarning?: ReactNode + passwordAria?: string + passwordPattern?: string + passwordTitle?: string + + rememberVisible?: boolean + rememberDefault?: boolean + rememberLabel?: string + rememberAria?: string +} + +export function SignInForm({ + title = 'Sign in', + + onSubmit, + submitAria = 'Next', + submitLabel = submitAria, + + onCancel = undefined, + cancelAria = 'Cancel', + cancelLabel = cancelAria, + + usernameDefault = '', + usernameReadonly = false, + usernameLabel = 'Email address or handle', + usernameAria = usernameLabel, + usernamePlaceholder = usernameLabel, + usernamePattern, + usernameTitle = 'Username must not be empty', + + passwordLabel = 'Password', + passwordAria = passwordLabel, + passwordPlaceholder = passwordLabel, + passwordPattern, + passwordTitle = 'Password must not be empty', + passwordWarning = ( + <> +

Warning

+

+ Please verify the domain name of the website before entering your + password. Never enter your password on a domain you do not trust. +

+ + ), + + rememberVisible = true, + rememberDefault = false, + rememberLabel = 'Remember this account on this device', + rememberAria = rememberLabel, + + className, + ...attrs +}: SignInFormProps & + Omit< + FormHTMLAttributes, + keyof SignInFormProps | 'children' + >) { + const [focused, setFocused] = useState(false) + const [loading, setLoading] = useState(false) + const [errorMessage, setErrorMessage] = useState(null) + + const doSubmit = useCallback( + async ( + event: SyntheticEvent< + HTMLFormElement & { + username: HTMLInputElement + password: HTMLInputElement + remember?: HTMLInputElement + }, + SubmitEvent + >, + ) => { + event.preventDefault() + + const credentials = { + username: event.currentTarget.username.value, + password: event.currentTarget.password.value, + remember: event.currentTarget.remember?.checked, + } + + setLoading(true) + setErrorMessage(null) + try { + await onSubmit(credentials) + } catch (err) { + setErrorMessage(parseErrorMessage(err)) + } finally { + setLoading(false) + } + }, + [onSubmit, setErrorMessage, setLoading], + ) + + return ( +
+

{title}

+
+
+ @ + setErrorMessage(null)} + className="relative m-0 block w-[1px] min-w-0 flex-auto px-3 py-[0.25rem] leading-[1.6] bg-transparent bg-clip-padding text-base text-inherit outline-none dark:placeholder:text-neutral-100 disabled:text-gray-500" + placeholder={usernamePlaceholder} + aria-label={usernameAria} + autoCapitalize="none" + autoCorrect="off" + autoComplete="username" + spellCheck="false" + dir="auto" + enterKeyHint="next" + required + defaultValue={usernameDefault} + readOnly={usernameReadonly} + disabled={usernameReadonly} + pattern={usernamePattern} + title={usernameTitle} + /> +
+ +
+ +
+ * + setErrorMessage(null)} + onFocus={() => setFocused(true)} + onBlur={() => setTimeout(setFocused, 100, false)} + className="relative m-0 block w-[1px] min-w-0 flex-auto px-3 py-[0.25rem] leading-[1.6] bg-transparent bg-clip-padding text-base text-inherit outline-none dark:placeholder:text-neutral-100" + placeholder={passwordPlaceholder} + aria-label={passwordAria} + autoCapitalize="none" + autoCorrect="off" + autoComplete="current-password" + dir="auto" + enterKeyHint="done" + spellCheck="false" + required + pattern={passwordPattern} + title={passwordTitle} + /> +
+ + {passwordWarning && ( + <> +
+
+
+
+ + + +
+
{passwordWarning}
+
+
+ + )} + + {rememberVisible && ( + <> +
+ +
+ + setErrorMessage(null)} + /> + + + +
+ + )} +
+ + {errorMessage && } + +
+ +
+ + + {onCancel && ( + + )} + +
+
+ + ) +} + +function parseErrorMessage(err: unknown): string { + console.error('Sign-in failed:', err) + switch ((err as any)?.message) { + case 'Invalid credentials': + return 'Invalid username or password' + default: + return 'An unknown error occurred' + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/sign-up-account-form.tsx b/packages/oauth/oauth-provider/src/assets/app/components/sign-up-account-form.tsx new file mode 100644 index 00000000000..de7bbf410c3 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/sign-up-account-form.tsx @@ -0,0 +1,210 @@ +import { + FormHTMLAttributes, + ReactNode, + SyntheticEvent, + useCallback, + useState, +} from 'react' + +import { clsx } from '../lib/clsx' +import { ErrorCard } from './error-card' + +export type SignUpAccountFormOutput = { + username: string + password: string +} + +export type SignUpAccountFormProps = { + onSubmit: (credentials: SignUpAccountFormOutput) => void | PromiseLike + submitLabel?: ReactNode + submitAria?: string + + onCancel?: () => void + cancelLabel?: ReactNode + cancelAria?: string + + username?: string + usernamePlaceholder?: string + usernameLabel?: string + usernameAria?: string + usernamePattern?: string + usernameTitle?: string + + passwordPlaceholder?: string + passwordLabel?: string + passwordAria?: string + passwordPattern?: string + passwordTitle?: string +} + +export function SignUpAccountForm({ + onSubmit, + submitAria = 'Next', + submitLabel = submitAria, + + onCancel = undefined, + cancelAria = 'Cancel', + cancelLabel = cancelAria, + + username: defaultUsername = '', + usernameLabel = 'Username', + usernameAria = usernameLabel, + usernamePlaceholder = usernameLabel, + usernamePattern, + usernameTitle, + + passwordLabel = 'Password', + passwordAria = passwordLabel, + passwordPlaceholder = passwordLabel, + passwordPattern, + passwordTitle, + + className, + children, + ...attrs +}: SignUpAccountFormProps & + Omit, keyof SignUpAccountFormProps>) { + const [loading, setLoading] = useState(false) + const [errorMessage, setErrorMessage] = useState(null) + + const doSubmit = useCallback( + async ( + event: SyntheticEvent< + HTMLFormElement & { + username: HTMLInputElement + password: HTMLInputElement + }, + SubmitEvent + >, + ) => { + event.preventDefault() + + const credentials = { + username: event.currentTarget.username.value, + password: event.currentTarget.password.value, + } + + setLoading(true) + setErrorMessage(null) + try { + await onSubmit(credentials) + } catch (err) { + setErrorMessage(parseErrorMessage(err)) + } finally { + setLoading(false) + } + }, + [onSubmit, setErrorMessage, setLoading], + ) + + return ( +
+
+ + +
+ @ + setErrorMessage(null)} + className="relative m-1 block w-[1px] min-w-0 flex-auto leading-[1.6] bg-transparent bg-clip-padding text-base text-inherit outline-none dark:placeholder:text-neutral-100 disabled:text-gray-500" + placeholder={usernamePlaceholder} + aria-label={usernameAria} + autoCapitalize="none" + autoCorrect="off" + autoComplete="username" + spellCheck="false" + dir="auto" + enterKeyHint="next" + required + defaultValue={defaultUsername} + pattern={usernamePattern} + title={usernameTitle} + /> +
+ + + +
+ + * + + setErrorMessage(null)} + className="relative m-1 block w-[1px] min-w-0 flex-auto leading-[1.6] bg-transparent bg-clip-padding text-base text-inherit outline-none dark:placeholder:text-neutral-100" + placeholder={passwordPlaceholder} + aria-label={passwordAria} + autoCapitalize="none" + autoCorrect="off" + autoComplete="new-password" + dir="auto" + enterKeyHint="done" + spellCheck="false" + required + pattern={passwordPattern} + title={passwordTitle} + /> +
+
+ + {children &&
{children}
} + + {errorMessage && } + +
+ +
+ + + {onCancel && ( + + )} + +
+
+ + ) +} + +function parseErrorMessage(err: unknown): string { + switch ((err as any)?.message) { + case 'Invalid credentials': + return 'Invalid username or password' + default: + console.error(err) + return 'An unknown error occurred' + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/sign-up-disclaimer.tsx b/packages/oauth/oauth-provider/src/assets/app/components/sign-up-disclaimer.tsx new file mode 100644 index 00000000000..0acc39f6fa8 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/sign-up-disclaimer.tsx @@ -0,0 +1,44 @@ +import { HTMLAttributes } from 'react' +import { LinkDefinition } from '../backend-data' +import { clsx } from '../lib/clsx' + +export type SignUpDisclaimerProps = { + links?: readonly LinkDefinition[] +} + +export function SignUpDisclaimer({ + links, + + className, + ...attrs +}: SignUpDisclaimerProps & + Omit< + HTMLAttributes, + keyof SignUpDisclaimerProps | 'children' + >) { + const relevantLinks = links?.filter( + (l) => l.rel === 'privacy-policy' || l.rel === 'terms-of-service', + ) + + return ( +

+ By creating an account you agree to the{' '} + {relevantLinks && relevantLinks.length + ? relevantLinks.map((l, i, a) => ( + + {i > 0 && (i < a.length - 1 ? ', ' : ' and ')} + + {l.title} + + + )) + : 'Terms of Service and Privacy Policy'} + . +

+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/components/url-viewer.tsx b/packages/oauth/oauth-provider/src/assets/app/components/url-viewer.tsx new file mode 100644 index 00000000000..4879a5d8e23 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/components/url-viewer.tsx @@ -0,0 +1,70 @@ +import { HTMLAttributes, useMemo } from 'react' + +export type UrlPartRenderingOptions = { + faded?: boolean + bold?: boolean +} + +export type UrlRendererProps = { + url: string | URL + proto?: boolean | UrlPartRenderingOptions + host?: boolean | UrlPartRenderingOptions + path?: boolean | UrlPartRenderingOptions + query?: boolean | UrlPartRenderingOptions + hash?: boolean | UrlPartRenderingOptions + as?: keyof JSX.IntrinsicElements +} + +export function UrlViewer({ + url, + proto = false, + host = true, + path = false, + query = false, + hash = false, + as: As = 'span', + ...attrs +}: UrlRendererProps & HTMLAttributes) { + const urlObj = useMemo(() => new URL(url), [url]) + + return ( + + {proto && ( + + )} + {host && ( + + )} + {path && ( + + )} + {query && ( + + )} + {hash && ( + + )} + + ) +} + +function UrlPartViewer({ + value, + faded = true, + bold = false, +}: { value: string } & UrlPartRenderingOptions) { + const Comp = bold ? 'b' : 'span' + return {value} +} diff --git a/packages/oauth/oauth-provider/src/assets/app/cookies.ts b/packages/oauth/oauth-provider/src/assets/app/cookies.ts new file mode 100644 index 00000000000..c13a7369de6 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/cookies.ts @@ -0,0 +1,11 @@ +export const parseCookieString = ( + cookie: string, +): Record => + Object.fromEntries( + cookie + .split(';') + .filter(Boolean) + .map((str) => str.split('=', 2).map((s) => decodeURIComponent(s.trim()))), + ) + +export const cookies = parseCookieString(document.cookie) diff --git a/packages/oauth/oauth-provider/src/assets/app/hooks/use-api.ts b/packages/oauth/oauth-provider/src/assets/app/hooks/use-api.ts new file mode 100644 index 00000000000..60d0bf44526 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/hooks/use-api.ts @@ -0,0 +1,104 @@ +import { useCallback, useMemo, useState } from 'react' + +import { Account, AuthorizeData, Session } from '../backend-data' +import { Api } from '../lib/api' +import { upsert } from '../lib/util' +import { useCsrfToken } from './use-csrf-token' + +export type SignInCredentials = { + username: string + password: string + remember?: boolean +} + +export type SignUpData = { + username: string + password: string + extra?: Record +} + +export function useApi( + { + clientId, + requestUri, + csrfCookie, + sessions: initialSessions, + newSessionsRequireConsent, + }: AuthorizeData, + { + onRedirected, + }: { + onRedirected?: () => void + } = {}, +) { + const csrfToken = useCsrfToken(csrfCookie) ?? '' // Invalid value + const [sessions, setSessions] = useState(initialSessions) + + const setSession = useCallback( + (sub: string | null) => { + setSessions((sessions) => + sub === (sessions.find((s) => s.selected)?.account.sub || null) + ? sessions + : sessions.map((s) => ({ ...s, selected: s.account.sub === sub })), + ) + }, + [setSessions], + ) + + const api = useMemo( + () => new Api(requestUri, clientId, csrfToken, newSessionsRequireConsent), + [requestUri, clientId, csrfToken, newSessionsRequireConsent], + ) + + const performRedirect = useCallback( + (url: URL) => { + window.location.href = String(url) + if (onRedirected) setTimeout(onRedirected) + }, + [onRedirected], + ) + + const doSignIn = useCallback( + async (credentials: SignInCredentials): Promise => { + const session = await api.signIn(credentials) + const { sub } = session.account + + setSessions((sessions) => { + return upsert(sessions, session, (s) => s.account.sub === sub).map( + // Make sure to de-select any other selected session + (s) => (s === session || !s.selected ? s : { ...s, selected: false }), + ) + }) + }, + [api, performRedirect, clientId, setSessions], + ) + + const doSignUp = useCallback( + (data: SignUpData) => { + // + throw new Error('Not implemented') + }, + [api], + ) + + const doAccept = useCallback( + async (account: Account) => { + performRedirect(await api.accept(account)) + }, + [api, performRedirect], + ) + + const doReject = useCallback(async () => { + performRedirect(await api.reject()) + }, [api, performRedirect]) + + return { + sessions, + setSession, + + doSignIn, + doSignUp, + doAccept, + doReject, + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/hooks/use-bound-dispatch.ts b/packages/oauth/oauth-provider/src/assets/app/hooks/use-bound-dispatch.ts new file mode 100644 index 00000000000..8945ea41989 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/hooks/use-bound-dispatch.ts @@ -0,0 +1,5 @@ +import { Dispatch, useCallback } from 'react' + +export function useBoundDispatch(dispatch: Dispatch, value: A) { + return useCallback(() => dispatch(value), [dispatch, value]) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/hooks/use-csrf-token.ts b/packages/oauth/oauth-provider/src/assets/app/hooks/use-csrf-token.ts new file mode 100644 index 00000000000..ee272abcb0c --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/hooks/use-csrf-token.ts @@ -0,0 +1,5 @@ +import { cookies } from '../cookies' + +export function useCsrfToken(cookieName: string) { + return cookies[cookieName] +} diff --git a/packages/oauth/oauth-provider/src/assets/app/lib/api.ts b/packages/oauth/oauth-provider/src/assets/app/lib/api.ts new file mode 100644 index 00000000000..4933a217201 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/lib/api.ts @@ -0,0 +1,64 @@ +import { fetchJsonProcessor, fetchOkProcessor } from '@atproto-labs/fetch' + +import { Account, Session } from '../backend-data' + +export class Api { + constructor( + private requestUri: string, + private clientId: string, + private csrfToken: string, + private newSessionsRequireConsent: boolean, + ) {} + + async signIn(credentials: { + username: string + password: string + remember?: boolean + }): Promise { + const { json } = await fetch('/oauth/authorize/sign-in', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + mode: 'same-origin', + body: JSON.stringify({ + csrf_token: this.csrfToken, + request_uri: this.requestUri, + client_id: this.clientId, + credentials, + }), + }) + .then(fetchOkProcessor()) + .then( + fetchJsonProcessor<{ + account: Account + consentRequired: boolean + }>(), + ) + + return { + account: json.account, + + selected: true, + loginRequired: false, + consentRequired: this.newSessionsRequireConsent || json.consentRequired, + } + } + + async accept(account: Account): Promise { + const url = new URL('/oauth/authorize/accept', window.origin) + url.searchParams.set('request_uri', this.requestUri) + url.searchParams.set('account_sub', account.sub) + url.searchParams.set('client_id', this.clientId) + url.searchParams.set('csrf_token', this.csrfToken) + + return url + } + + async reject(): Promise { + const url = new URL('/oauth/authorize/reject', window.origin) + url.searchParams.set('request_uri', this.requestUri) + url.searchParams.set('client_id', this.clientId) + url.searchParams.set('csrf_token', this.csrfToken) + + return url + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/lib/clsx.ts b/packages/oauth/oauth-provider/src/assets/app/lib/clsx.ts new file mode 100644 index 00000000000..d60a7e5d722 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/lib/clsx.ts @@ -0,0 +1,4 @@ +export function clsx(a: string | undefined, b?: string) { + if (a && b) return `${a} ${b}` + return a || b +} diff --git a/packages/oauth/oauth-provider/src/assets/app/lib/util.ts b/packages/oauth/oauth-provider/src/assets/app/lib/util.ts new file mode 100644 index 00000000000..b27faabb0d1 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/lib/util.ts @@ -0,0 +1,10 @@ +export function upsert( + arr: readonly T[], + item: T, + predicate: (value: T, index: number, obj: readonly T[]) => boolean, +): T[] { + const idx = arr.findIndex(predicate) + return idx === -1 + ? [...arr, item] + : [...arr.slice(0, idx), item, ...arr.slice(idx + 1)] +} diff --git a/packages/oauth/oauth-provider/src/assets/app/main.css b/packages/oauth/oauth-provider/src/assets/app/main.css new file mode 100644 index 00000000000..74dcc88675c --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/main.css @@ -0,0 +1,11 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +/* Matches colors defined in tailwind.config.js */ +@layer base { + :root { + --color-primary: 255 115 179; + --color-error: 235 65 49; + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/main.tsx b/packages/oauth/oauth-provider/src/assets/app/main.tsx new file mode 100644 index 00000000000..f144dcf05f5 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/main.tsx @@ -0,0 +1,28 @@ +// This must be loaded before any dependency to ensure that global variables +// cannot be accessed by JS from node_modules. +import * as backendData from './backend-data' + +import { createRoot } from 'react-dom/client' + +import { App } from './app' +import './main.css' + +const { authorizeData } = backendData +if (authorizeData) { + // When the user is logging in, make sure the page URL contains the + // "request_uri" in case the user refreshes the page. + const url = new URL(window.location.href) + if ( + url.pathname === '/oauth/authorize' && + !url.searchParams.has('request_uri') + ) { + url.search = '' + url.searchParams.set('client_id', authorizeData.clientId) + url.searchParams.set('request_uri', authorizeData.requestUri) + window.history.replaceState(history.state, '', url.pathname + url.search) + } +} + +const container = document.getElementById('root')! +const root = createRoot(container) +root.render() diff --git a/packages/oauth/oauth-provider/src/assets/app/views/accept-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/accept-view.tsx new file mode 100644 index 00000000000..0bcb1b8d57f --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/accept-view.tsx @@ -0,0 +1,51 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' + +import { Session } from '../backend-data' +import { AcceptForm } from '../components/accept-form' +import { LayoutTitlePage } from '../components/layout-title-page' + +export type AcceptViewProps = { + clientId: string + clientMetadata: OAuthClientMetadata + clientTrusted: boolean + session: Session + + onAccept: () => void + onReject: () => void + onBack?: () => void +} + +export function AcceptView({ + clientId, + clientMetadata, + clientTrusted, + session, + onAccept, + onReject, + onBack, +}: AcceptViewProps) { + const { account } = session + return ( + + Grant access to your{' '} + {account.preferred_username || account.email || account.sub}{' '} + account. + + } + > + + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/views/authorize-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/authorize-view.tsx new file mode 100644 index 00000000000..d27e50b0edf --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/authorize-view.tsx @@ -0,0 +1,101 @@ +import { useEffect, useState } from 'react' + +import type { AuthorizeData, CustomizationData } from '../backend-data' +import { LayoutTitlePage } from '../components/layout-title-page' +import { useApi } from '../hooks/use-api' +import { useBoundDispatch } from '../hooks/use-bound-dispatch' +import { AcceptView } from './accept-view' +import { SignInView } from './sign-in-view' +import { SignUpView } from './sign-up-view' +import { WelcomeView } from './welcome-view' + +export type AuthorizeViewProps = { + authorizeData: AuthorizeData + customizationData?: CustomizationData +} + +export function AuthorizeView({ + authorizeData, + customizationData, +}: AuthorizeViewProps) { + const forceSignIn = authorizeData?.loginHint != null + + const [view, setView] = useState< + 'welcome' | 'sign-in' | 'sign-up' | 'accept' | 'done' + >(forceSignIn ? 'sign-in' : 'welcome') + + const showDone = useBoundDispatch(setView, 'done') + const showSignIn = useBoundDispatch(setView, 'sign-in') + // const showSignUp = useBoundDispatch(setView, 'sign-up') + const showAccept = useBoundDispatch(setView, 'accept') + const showWelcome = useBoundDispatch(setView, 'welcome') + + const { sessions, setSession, doAccept, doReject, doSignIn, doSignUp } = + useApi(authorizeData, { onRedirected: showDone }) + + const session = sessions.find((s) => s.selected && !s.loginRequired) + useEffect(() => { + if (session) { + if (session.consentRequired) showAccept() + else doAccept(session.account) + } + }, [session, doAccept, showAccept]) + + if (view === 'welcome') { + return ( + + ) + } + + if (view === 'sign-up') { + return ( + + ) + } + + if (view === 'sign-in') { + return ( + + ) + } + + if (view === 'accept' && session) { + return ( + doAccept(session.account)} + onReject={doReject} + onBack={() => { + setSession(null) + setView(sessions.length ? 'sign-in' : 'welcome') + }} + /> + ) + } + + return ( + + You are being redirected... + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/views/error-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/error-view.tsx new file mode 100644 index 00000000000..90a664d3abd --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/error-view.tsx @@ -0,0 +1,27 @@ +import { CustomizationData, ErrorData } from '../backend-data' +import { ErrorCard } from '../components/error-card' +import { LayoutWelcome } from '../components/layout-welcome' + +export type ErrorViewProps = { + customizationData?: CustomizationData + errorData?: ErrorData +} + +export function ErrorView({ errorData, customizationData }: ErrorViewProps) { + return ( + + + + ) +} + +function getUserFriendlyMessage(errorData?: ErrorData) { + const desc = errorData?.error_description + switch (desc) { + case 'Unknown request_uri': // Request was removed from database + case 'This request has expired': + return 'This sign-in session has expired' + default: + return desc || 'An unknown error occurred' + } +} diff --git a/packages/oauth/oauth-provider/src/assets/app/views/sign-in-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/sign-in-view.tsx new file mode 100644 index 00000000000..5af2e25e4fb --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/sign-in-view.tsx @@ -0,0 +1,121 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' + +import { Session } from '../backend-data' +import { AccountPicker } from '../components/account-picker' +import { LayoutTitlePage } from '../components/layout-title-page' +import { SignInForm, SignInFormOutput } from '../components/sign-in-form' + +export type SignInViewProps = { + sessions: readonly Session[] + setSession: (sub: string | null) => void + loginHint?: string + + onSignIn: (credentials: SignInFormOutput) => void | PromiseLike + onBack?: () => void +} + +export function SignInView({ + loginHint, + sessions, + setSession, + + onSignIn, + onBack, +}: SignInViewProps) { + const session = useMemo(() => sessions.find((s) => s.selected), [sessions]) + const clearSession = useCallback(() => setSession(null), [setSession]) + const accounts = useMemo(() => sessions.map((s) => s.account), [sessions]) + const [showSignInForm, setShowSignInForm] = useState(sessions.length === 0) + + useEffect(() => { + // Make sure the "back" action shows the account picker instead of the + // sign-in form (since the account was added to the list of current + // sessions). + if (session) setShowSignInForm(false) + }, [session]) + + if (session) { + // All set (parent view will handle the redirect) + if (!session.loginRequired) return null + + return ( + + + + ) + } + + if (loginHint) { + return ( + + + + ) + } + + if (sessions.length === 0) { + return ( + + + + ) + } + + if (showSignInForm) { + return ( + + setShowSignInForm(false)} + cancelAria="Back" // to account picker + /> + + ) + } + + return ( + + setSession(a.sub)} + onOther={() => setShowSignInForm(true)} + onBack={onBack} + backAria="Back" // to previous view + /> + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/views/sign-up-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/sign-up-view.tsx new file mode 100644 index 00000000000..a85780de411 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/sign-up-view.tsx @@ -0,0 +1,93 @@ +import { ReactNode, useCallback, useState } from 'react' + +import { LinkDefinition } from '../backend-data' +import { HelpCard } from '../components/help-card' +import { LayoutTitlePage } from '../components/layout-title-page' +import { + SignUpAccountForm, + SignUpAccountFormOutput, +} from '../components/sign-up-account-form' +import { SignUpDisclaimer } from '../components/sign-up-disclaimer' + +export type SignUpViewProps = { + stepName?: (step: number, total: number) => ReactNode + stepTitle?: (step: number, total: number) => ReactNode + + links?: LinkDefinition[] + onSignUp: (data: { + username: string + password: string + extra?: Record + }) => void | PromiseLike + onBack?: () => void +} + +export function SignUpView({ + stepName = (step, total) => `Step ${step} of ${total}`, + stepTitle = (step, total) => { + switch (step) { + case 1: + return 'Your account' + default: + return null + } + }, + + links, + + onSignUp, + onBack, +}: SignUpViewProps) { + const [_credentials, setCredentials] = + useState(null) + const [step, setStep] = useState<1 | 2>(1) + + const stepCount = 2 + + const doSubmitAccount = useCallback( + (credentials: SignUpAccountFormOutput) => { + setCredentials(credentials) + setStep(2) + }, + [onSignUp, setCredentials, setStep], + ) + + return ( + +
+

+ {stepName(step, stepCount)} +

+

+ {stepTitle(step, stepCount)} +

+ + {step === 1 && ( + + + + )} + + {step === 2 && ( + + )} + + +
+
+ ) +} diff --git a/packages/oauth/oauth-provider/src/assets/app/views/welcome-view.tsx b/packages/oauth/oauth-provider/src/assets/app/views/welcome-view.tsx new file mode 100644 index 00000000000..576d3bea8d6 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/app/views/welcome-view.tsx @@ -0,0 +1,61 @@ +import { LayoutWelcome, LayoutWelcomeProps } from '../components/layout-welcome' +import { clsx } from '../lib/clsx' + +export type WelcomeViewParams = LayoutWelcomeProps & { + onSignIn?: () => void + signInLabel?: string + + onSignUp?: () => void + signUpLabel?: string + + onCancel?: () => void + cancelLabel?: string +} + +export function WelcomeView({ + onSignUp, + signUpLabel = 'Create a new account', + onSignIn, + signInLabel = 'Sign in', + onCancel, + cancelLabel = 'Cancel', + + ...props +}: WelcomeViewParams) { + return ( + + {onSignUp && ( + + )} + + {onSignIn && ( + + )} + + {onCancel && ( + + )} + + ) +} diff --git a/packages/oauth/oauth-provider/src/assets/asset.ts b/packages/oauth/oauth-provider/src/assets/asset.ts new file mode 100644 index 00000000000..136e8794572 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/asset.ts @@ -0,0 +1,8 @@ +import type { Readable } from 'node:stream' + +export type Asset = { + url: string + type?: string + sha256: string + createStream: () => Readable +} diff --git a/packages/oauth/oauth-provider/src/assets/assets-middleware.ts b/packages/oauth/oauth-provider/src/assets/assets-middleware.ts new file mode 100644 index 00000000000..2471e4c9075 --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/assets-middleware.ts @@ -0,0 +1,32 @@ +import { writeStream } from '../lib/http/index.js' + +import { ASSETS_URL_PREFIX, getAsset } from './index.js' + +export function authorizeAssetsMiddleware() { + return async function assetsMiddleware(req, res, next): Promise { + if (req.method !== 'GET' && req.method !== 'HEAD') return next() + if (!req.url?.startsWith(ASSETS_URL_PREFIX)) return next() + + const [pathname, query] = req.url.split('?', 2) as [ + string, + string | undefined, + ] + const filename = pathname.slice(ASSETS_URL_PREFIX.length) + if (!filename) return next() + + const asset = await getAsset(filename).catch(() => null) + if (!asset) return next() + + if (req.headers['if-none-match'] === asset.sha256) { + return void res.writeHead(304).end() + } + + res.setHeader('ETag', asset.sha256) + + if (query === asset.sha256) { + res.setHeader('Cache-Control', 'public, max-age=31536000, immutable') + } + + await writeStream(res, asset.createStream(), asset.type) + } +} diff --git a/packages/oauth/oauth-provider/src/assets/index.ts b/packages/oauth/oauth-provider/src/assets/index.ts new file mode 100644 index 00000000000..c7cf4d6098d --- /dev/null +++ b/packages/oauth/oauth-provider/src/assets/index.ts @@ -0,0 +1,74 @@ +import type { ManifestItem } from '@atproto-labs/rollup-plugin-bundle-manifest' + +// If this library is used as a regular dependency (e.g. from node_modules), the +// assets will simply be referenced from the node_modules directory. However, if +// this library is bundled (e.g. via rollup), the assets need to be copied to +// the output directory. Most bundlers support this (webpack, rollup, etc.) by +// re-writing new URL('./path', import.meta.url) calls to point to the correct +// output directory. +// +// https://github.com/evanw/esbuild/issues/795 +// https://www.npmjs.com/package/@web/rollup-plugin-import-meta-assets + +// Note that the bundle-manifest, being a JSON file, can be imported directly +// without any special handling. This is because all bundlers support JSON +// imports out of the box. + +import { createReadStream } from 'node:fs' +import { join, posix } from 'node:path' +import { Readable } from 'node:stream' + +// @ts-expect-error: This file is generated at build time +import appBundleManifestJson from './app/bundle-manifest.json' +import { Asset } from './asset' + +const appBundleManifest: Map = new Map( + Object.entries(appBundleManifestJson), +) + +export const ASSETS_URL_PREFIX = '/@atproto/oauth-provider/~assets/' + +export async function getAsset(inputFilename: string): Promise { + const filename = posix.normalize(inputFilename) + + if ( + filename.startsWith('/') || // Prevent absolute paths + filename.startsWith('../') || // Prevent directory traversal attacks + /[<>:"|?*\\]/.test(filename) // Windows disallowed characters + ) { + throw new AssetNotFoundError(filename) + } + + const manifest = appBundleManifest.get(filename) + if (!manifest) throw new AssetNotFoundError(filename) + + // When this package is used as a regular "node_modules" dependency, and gets + // bundled by the consumer, the assets should be copied to the bundle's output + // directory. In case the bundler does not support copying assets from the + // "dist/assets/app" folder, this package's build system can be modified to + // embed the asset data directly into the bundle-manifest.json (see the `data` + // option of "@atproto-labs/rollup-plugin-bundle-manifest" in rollup.config.js). + + const { data } = manifest + + return { + url: posix.join(ASSETS_URL_PREFIX, filename), + type: manifest.mime, + sha256: manifest.sha256, + createStream: data + ? () => Readable.from(Buffer.from(data, 'base64')) + : () => + // ESM version: + // createReadStream(new URL(`./app/${filename}`, import.meta.url)) + // CJS version: + createReadStream(join(__dirname, './app', filename)), + } +} + +class AssetNotFoundError extends Error { + public readonly code = 'ENOENT' + public readonly statusCode = 404 + constructor(filename: string) { + super(`Asset not found: ${filename}`) + } +} diff --git a/packages/oauth/oauth-provider/src/client/client-auth.ts b/packages/oauth/oauth-provider/src/client/client-auth.ts new file mode 100644 index 00000000000..32b1b6661b4 --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-auth.ts @@ -0,0 +1,45 @@ +import { CLIENT_ASSERTION_TYPE_JWT_BEARER } from '@atproto/oauth-types' +import { KeyLike, calculateJwkThumbprint, exportJWK } from 'jose' +import { JOSEError } from 'jose/errors' + +import { InvalidClientError } from '../errors/invalid-client-error.js' + +export type ClientAuth = + | { method: 'none' } + | { + method: typeof CLIENT_ASSERTION_TYPE_JWT_BEARER + alg: string + kid: string + jkt: string + } + +export function compareClientAuth(a: ClientAuth, b: ClientAuth): boolean { + if (a.method === 'none') { + if (b.method !== a.method) return false + + return true + } + + if (a.method === CLIENT_ASSERTION_TYPE_JWT_BEARER) { + if (b.method !== a.method) return false + + return true + } + + // Fool-proof + throw new TypeError('Invalid ClientAuth method') +} + +export async function authJwkThumbprint( + key: Uint8Array | KeyLike, +): Promise { + try { + return await calculateJwkThumbprint(await exportJWK(key), 'sha512') + } catch (err) { + const message = + err instanceof JOSEError + ? err.message + : 'Failed to compute JWK thumbprint' + throw new InvalidClientError(message, err) + } +} diff --git a/packages/oauth/oauth-provider/src/client/client-data.ts b/packages/oauth/oauth-provider/src/client/client-data.ts new file mode 100644 index 00000000000..c6ad4c44029 --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-data.ts @@ -0,0 +1,9 @@ +import { Jwks } from '@atproto/jwk' +import { OAuthClientMetadata } from '@atproto/oauth-types' + +export type { OAuthClientMetadata } + +export type ClientData = { + metadata: OAuthClientMetadata + jwks?: Jwks +} diff --git a/packages/oauth/oauth-provider/src/client/client-id.ts b/packages/oauth/oauth-provider/src/client/client-id.ts new file mode 100644 index 00000000000..68bc8bcaaab --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-id.ts @@ -0,0 +1,4 @@ +import { OAuthClientId, oauthClientIdSchema } from '@atproto/oauth-types' + +export type ClientId = OAuthClientId +export const clientIdSchema = oauthClientIdSchema diff --git a/packages/oauth/oauth-provider/src/client/client-info.ts b/packages/oauth/oauth-provider/src/client/client-info.ts new file mode 100644 index 00000000000..2cd0756bb96 --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-info.ts @@ -0,0 +1,13 @@ +export type ClientInfo = { + /** + * Defaults to `false` + */ + isFirstParty: boolean + + /** + * Defaults to `true` if the client is isFirstParty, or if the client was + * loaded from the store. (i.e. false in case of "loopback" & "discoverable" + * clients) + */ + isTrusted: boolean +} diff --git a/packages/oauth/oauth-provider/src/client/client-manager.ts b/packages/oauth/oauth-provider/src/client/client-manager.ts new file mode 100644 index 00000000000..f896b3a747e --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-manager.ts @@ -0,0 +1,818 @@ +import { + bindFetch, + Fetch, + fetchJsonProcessor, + fetchJsonZodProcessor, + fetchOkProcessor, +} from '@atproto-labs/fetch' +import { pipe } from '@atproto-labs/pipe' +import { + CachedGetter, + GetCachedOptions, + SimpleStore, +} from '@atproto-labs/simple-store' +import { Jwks, jwksSchema, Keyset } from '@atproto/jwk' +import { + isLoopbackHost, + isLoopbackUrl, + isOAuthClientIdDiscoverable, + isOAuthClientIdLoopback, + OAUTH_AUTHENTICATED_ENDPOINT_NAMES, + OAuthClientIdDiscoverable, + OAuthClientIdLoopback, + OAuthClientMetadata, + OAuthClientMetadataInput, + oauthClientMetadataSchema, +} from '@atproto/oauth-types' + +import { ALLOW_LOOPBACK_CLIENT_REFRESH_TOKEN } from '../constants.js' +import { InvalidClientMetadataError } from '../errors/invalid-client-metadata-error.js' +import { InvalidRedirectUriError } from '../errors/invalid-redirect-uri-error.js' +import { OAuthError } from '../errors/oauth-error.js' +import { parseDomain, parseUrlDomain } from '../lib/util/hostname.js' +import { Awaitable } from '../lib/util/type.js' +import { OAuthHooks } from '../oauth-hooks.js' +import { ClientId } from './client-id.js' +import { ClientStore } from './client-store.js' +import { parseDiscoverableClientId, parseRedirectUri } from './client-utils.js' +import { Client } from './client.js' + +const fetchMetadataHandler = pipe( + fetchOkProcessor(), + fetchJsonProcessor('application/json', false), + fetchJsonZodProcessor(oauthClientMetadataSchema), +) + +const fetchJwksHandler = pipe( + fetchOkProcessor(), + fetchJsonProcessor('application/json', false), + fetchJsonZodProcessor(jwksSchema), +) + +export type LoopbackMetadataGetter = ( + url: string, +) => Awaitable + +export class ClientManager { + protected readonly jwks: CachedGetter + protected readonly metadata: CachedGetter + + constructor( + protected readonly keyset: Keyset, + protected readonly hooks: OAuthHooks, + protected readonly store: ClientStore | null, + protected readonly loopbackMetadata: LoopbackMetadataGetter | null = null, + safeFetch: Fetch, + clientJwksCache: SimpleStore, + clientMetadataCache: SimpleStore, + ) { + const fetch = bindFetch(safeFetch) + + this.jwks = new CachedGetter(async (uri, options) => { + const jwks = await fetch(buildJsonGetRequest(uri, options)).then( + fetchJwksHandler, + ) + + return jwks + }, clientJwksCache) + + this.metadata = new CachedGetter(async (uri, options) => { + const metadata = await fetch(buildJsonGetRequest(uri, options)).then( + fetchMetadataHandler, + ) + + // Validate within the getter to avoid caching invalid metadata + return this.validateClientMetadata(uri, metadata) + }, clientMetadataCache) + } + + /** + * + * @see {@link https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 OIDC Client Registration} + */ + public async getClient(clientId: string) { + try { + const metadata = await this.getClientMetadata(clientId) + + const jwks = metadata.jwks_uri + ? await this.jwks.get(metadata.jwks_uri) + : undefined + + const partialInfo = await this.hooks.onClientInfo?.(clientId, { + metadata, + jwks, + }) + + const isFirstParty = partialInfo?.isFirstParty ?? false + const isTrusted = + partialInfo?.isTrusted ?? + (isFirstParty || + // If the client was loaded from the store, we consider it trusted: + (!isOAuthClientIdLoopback(clientId) && + !isOAuthClientIdDiscoverable(clientId))) + + return new Client(clientId, metadata, jwks, { isFirstParty, isTrusted }) + } catch (err) { + if (err instanceof OAuthError) throw err + if (err?.['code'] === 'DEPTH_ZERO_SELF_SIGNED_CERT') { + throw new InvalidClientMetadataError('Self-signed certificate', err) + } + throw InvalidClientMetadataError.from(err) + } + } + + protected async getClientMetadata( + clientId: ClientId, + ): Promise { + if (isOAuthClientIdLoopback(clientId)) { + return this.getLoopbackClientMetadata(clientId) + } else if (isOAuthClientIdDiscoverable(clientId)) { + return this.getDiscoverableClientMetadata(clientId) + } else if (this.store) { + return this.getStoredClientMetadata(clientId) + } + + throw new InvalidClientMetadataError(`Invalid client ID "${clientId}"`) + } + + protected async getLoopbackClientMetadata( + clientId: OAuthClientIdLoopback, + ): Promise { + const { loopbackMetadata } = this + if (!loopbackMetadata) { + throw new InvalidClientMetadataError('Loopback clients are not allowed') + } + + const result = oauthClientMetadataSchema.safeParse( + await loopbackMetadata(clientId), + ) + + if (!result.success) { + throw InvalidClientMetadataError.from(result.error) + } + + return this.validateClientMetadata(clientId, result.data) + } + + protected async getDiscoverableClientMetadata( + clientId: OAuthClientIdDiscoverable, + ): Promise { + const metadataUrl = parseDiscoverableClientId(clientId) + + const metadata = await this.metadata.get(metadataUrl.href) + + // Note: we do *not* re-validate the metadata here, as the metadata is + // validated within the getter. This is to avoid double validation. + // + // return this.validateClientMetadata(metadataUrl.href, metadata) + return metadata + } + + protected async getStoredClientMetadata( + clientId: ClientId, + ): Promise { + if (this.store) { + const metadata = await this.store.findClient(clientId) + return this.validateClientMetadata(clientId, metadata) + } + + throw new InvalidClientMetadataError(`Invalid client ID "${clientId}"`) + } + + /** + * This method will ensure that the client metadata is valid w.r.t. the OAuth + * and OIDC specifications. It will also ensure that the metadata is + * compatible with the implementation of this library, and ATPROTO's + * requirements. + */ + protected validateClientMetadata( + clientId: ClientId, + metadata: OAuthClientMetadata, + ): OAuthClientMetadata { + if (metadata.jwks && metadata.jwks_uri) { + throw new InvalidClientMetadataError( + 'jwks_uri and jwks are mutually exclusive', + ) + } + + const clientUriUrl = metadata.client_uri + ? new URL(metadata.client_uri) + : null + const clientUriParsed = clientUriUrl ? parseUrlDomain(clientUriUrl) : null + + if (clientUriUrl && !clientUriParsed) { + throw new InvalidClientMetadataError('client_uri must be a valid URL') + } + + const scopes = metadata.scope?.split(' ') + if ( + metadata.grant_types.includes('refresh_token') !== + (scopes?.includes('offline_access') ?? false) + ) { + throw new InvalidClientMetadataError( + 'Grant type "refresh_token" requires scope "offline_access" (and vice versa)', + ) + } + + for (const grantType of metadata.grant_types) { + switch (grantType) { + case 'authorization_code': + case 'refresh_token': + case 'implicit': // Required by OIDC (for id_token) + continue + case 'password': + throw new InvalidClientMetadataError( + `Grant type "${grantType}" is not allowed`, + ) + default: + throw new InvalidClientMetadataError( + `Grant type "${grantType}" is not supported`, + ) + } + } + + if (metadata.client_id && metadata.client_id !== clientId) { + throw new InvalidClientMetadataError('client_id does not match') + } + + if (metadata.subject_type && metadata.subject_type !== 'public') { + throw new InvalidClientMetadataError( + 'Only "public" subject_type is supported', + ) + } + + if ( + metadata.userinfo_signed_response_alg && + !this.keyset.signAlgorithms.includes( + metadata.userinfo_signed_response_alg, + ) + ) { + throw new InvalidClientMetadataError( + `Unsupported "userinfo_signed_response_alg" ${metadata.userinfo_signed_response_alg}`, + ) + } + + if ( + metadata.id_token_signed_response_alg && + !this.keyset.signAlgorithms.includes( + metadata.id_token_signed_response_alg, + ) + ) { + throw new InvalidClientMetadataError( + `Unsupported "id_token_signed_response_alg" ${metadata.id_token_signed_response_alg}`, + ) + } + + if (metadata.userinfo_encrypted_response_alg) { + // We only support signature for now. + throw new InvalidClientMetadataError( + 'Encrypted userinfo response is not supported', + ) + } + + if (!metadata[`token_endpoint_auth_method`]) { + throw new InvalidClientMetadataError( + 'Missing token_endpoint_auth_method client metadata', + ) + } + + for (const endpoint of OAUTH_AUTHENTICATED_ENDPOINT_NAMES) { + const method = + metadata[`${endpoint}_endpoint_auth_method`] || + metadata[`token_endpoint_auth_method`] + + switch (method) { + case 'none': + if (metadata.token_endpoint_auth_signing_alg) { + throw new InvalidClientMetadataError( + `${endpoint}_endpoint_auth_method "none" must not have ${endpoint}_endpoint_auth_signing_alg`, + ) + } + break + + case 'private_key_jwt': + if (!metadata.jwks && !metadata.jwks_uri) { + throw new InvalidClientMetadataError( + `private_key_jwt auth method requires jwks or jwks_uri`, + ) + } + if (metadata.jwks?.keys.length === 0) { + throw new InvalidClientMetadataError( + `private_key_jwt auth method requires at least one key in jwks`, + ) + } + if (!metadata.token_endpoint_auth_signing_alg) { + throw new InvalidClientMetadataError( + `Missing token_endpoint_auth_signing_alg client metadata`, + ) + } + break + + default: + throw new InvalidClientMetadataError( + `${method} is not a supported "${endpoint}_endpoint_auth_method". Use "private_key_jwt" or "none".`, + ) + } + } + + if (metadata.authorization_encrypted_response_enc) { + throw new InvalidClientMetadataError( + 'Encrypted authorization response is not supported', + ) + } + + if (metadata.tls_client_certificate_bound_access_tokens) { + throw new InvalidClientMetadataError( + 'Mutual-TLS bound access tokens are not supported', + ) + } + + if ( + metadata.authorization_encrypted_response_enc && + !metadata.authorization_encrypted_response_alg + ) { + throw new InvalidClientMetadataError( + 'authorization_encrypted_response_enc requires authorization_encrypted_response_alg', + ) + } + + // ATPROTO spec requires the use of DPoP (OAuth spec defaults to false) + if (metadata.dpop_bound_access_tokens !== true) { + throw new InvalidClientMetadataError( + '"dpop_bound_access_tokens" must be true', + ) + } + + for (const responseType of metadata.response_types) { + const rt = responseType.split(' ') + + // ATPROTO spec requires the use of PKCE + if (rt.includes('token')) { + throw new InvalidClientMetadataError( + '"token" response type is not compatible with PKCE (use "code" instead)', + ) + } + + // Consistency check + if ( + rt.includes('code') && + !metadata.grant_types.includes('authorization_code') + ) { + throw new InvalidClientMetadataError( + `Response type "${responseType}" requires the "authorization_code" grant type`, + ) + } + + // Asking for "code token" or "code id_token" is fine (as long as the + // grant_types includes "authorization_code" and the scope includes + // "openid"). Asking for "token" or "id_token" (without "code") requires + // the "implicit" grant type. + if ( + (rt.includes('token') || rt.includes('id_token')) && + !metadata.grant_types.includes('implicit') + ) { + throw new InvalidClientMetadataError( + `Response type "${responseType}" requires the "implicit" grant type`, + ) + } + } + + if (metadata.application_type === 'native') { + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4 + // + // > Except when using a mechanism like Dynamic Client Registration + // > [RFC7591] to provision per-instance secrets, native apps are + // > classified as public clients, as defined by Section 2.1 of OAuth 2.0 + // > [RFC6749]; they MUST be registered with the authorization server as + // > such. Authorization servers MUST record the client type in the + // > client registration details in order to identify and process requests + // > accordingly. + } + + if (!metadata.redirect_uris?.length) { + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // + // > OPs can require that request_uri values used be pre-registered with + // > the require_request_uri_registration discovery parameter. + + throw new InvalidClientMetadataError( + 'At least one redirect_uri is required', + ) + } + + if ( + metadata.application_type === 'web' && + metadata.grant_types.includes('implicit') + ) { + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // + // > Web Clients [as defined by "application_type"] using the OAuth + // > Implicit Grant Type MUST only register URLs using the https + // > scheme as redirect_uris; they MUST NOT use localhost as the + // > hostname. + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + if (url.protocol !== 'https:') { + throw new InvalidRedirectUriError( + `Web clients must use HTTPS redirect URIs`, + ) + } + + if (url.hostname === 'localhost') { + throw new InvalidRedirectUriError( + `Web clients must not use localhost as the hostname`, + ) + } + } + } + + if (metadata.application_type === 'native') { + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // + // > Native Clients [as defined by "application_type"] MUST only + // > register redirect_uris using custom URI schemes or loopback URLs + // > using the http scheme; loopback URLs use localhost or the IP + // > loopback literals 127.0.0.1 or [::1] as the hostname. + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + if (url.protocol !== 'http:') { + throw new InvalidRedirectUriError( + `Native clients must use HTTP redirect URIs (got ${url})`, + ) + } + + if (!isLoopbackHost(url.hostname) && !isPrivateUseUriScheme(url)) { + throw new InvalidRedirectUriError( + 'Loopback redirect URIs are only allowed for native apps', + ) + } + } + } + + if (metadata.application_type === 'native') { + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // + // > Authorization Servers MAY reject Redirection URI values using + // > the http scheme, other than the loopback case for Native + // > Clients. + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + if (url.protocol === 'http:' && !isLoopbackUrl(url)) { + throw new InvalidRedirectUriError( + `Native clients must not use HTTP redirect URIs (got ${url})`, + ) + } + } + } + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + + if (url.username || url.password) { + // Is this a valid concern? Should we allow credentials in the URI? + throw new InvalidRedirectUriError( + `Redirect URI ${url} must not contain credentials`, + ) + } + + switch (true) { + // FIRST: Loopback redirect URI exception (only for native apps) + + case url.hostname === 'localhost': { + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.3 + // + // > While redirect URIs using localhost (i.e., + // > "http://localhost:{port}/{path}") function similarly to loopback IP + // > redirects described in Section 7.3, the use of localhost is NOT + // > RECOMMENDED. Specifying a redirect URI with the loopback IP literal + // > rather than localhost avoids inadvertently listening on network + // > interfaces other than the loopback interface. It is also less + // > susceptible to client-side firewalls and misconfigured host name + // > resolution on the user's device. + throw new InvalidRedirectUriError( + `Loopback redirect URI ${url} is not allowed (use explicit IPs instead)`, + ) + } + // falls through + case url.hostname === '127.0.0.1': + case url.hostname === '[::1]': { + // https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 + // + // > Loopback redirect URIs use the "http" scheme and are constructed + // > with the loopback IP literal and whatever port the client is + // > listening on. That is, "http://127.0.0.1:{port}/{path}" for IPv4, + // > and "http://[::1]:{port}/{path}" for IPv6. + + if (metadata.application_type !== 'native') { + throw new InvalidRedirectUriError( + `Loopback redirect URIs are only allowed for native apps`, + ) + } + + if (url.port) { + // https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 + // + // > The authorization server MUST allow any port to be specified at + // > the time of the request for loopback IP redirect URIs, to + // > accommodate clients that obtain an available ephemeral port + // > from the operating system at the time of the request. + // + // Note: although validation of the redirect_uri will ignore the + // port we still allow it to be specified, as the spec does not + // forbid it. If a port number is specified, ports will need to + // match when validating authorization requests. See + // "compareRedirectUri()". + } + + if (url.protocol !== 'http:') { + throw new InvalidRedirectUriError( + `Loopback redirect URI ${url} must use HTTP`, + ) + } + + break + } + + // SECOND: Protocol-based URI Redirection + + case url.protocol === 'http:': { + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // + // > request_uri [...] URLs MUST use the https scheme unless the + // > target Request Object is signed in a way that is verifiable by + // > the OP. + // + // TODO: Should we allow this (and check for signed request objects)? + throw new InvalidRedirectUriError( + `Non loopback redirect URI ${url} must use HTTPS`, + ) + } + + case url.protocol === 'https:': { + const redirectUriDomain = parseUrlDomain(url) + if (!redirectUriDomain) { + throw new InvalidRedirectUriError( + `Redirect URI ${url} must be a valid URL`, + ) + } + + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4 + // + // > In addition to the collision-resistant properties, requiring a + // > URI scheme based on a domain name that is under the control of + // > the app can help to prove ownership in the event of a dispute + // > where two apps claim the same private-use URI scheme (where one + // > app is acting maliciously). + // + // Although this only applies to "native" clients (extract being from + // rfc8252), we apply this rule to "web" clients as well. + if (!clientUriParsed) { + throw new InvalidClientMetadataError( + 'client_uri is required for HTTPS redirect URIs', + ) + } else { + if (redirectUriDomain.domain !== clientUriParsed.domain) { + throw new InvalidRedirectUriError( + `Redirect URI ${url} must be under the same domain as client_uri ${metadata.client_uri}`, + ) + } + } + + break + } + + case isPrivateUseUriScheme(url): { + // https://datatracker.ietf.org/doc/html/rfc8252#section-7.1 + // + // > When choosing a URI scheme to associate with the app, apps MUST + // > use a URI scheme based on a domain name under their control, + // > expressed in reverse order, as recommended by Section 3.8 of + // > [RFC7595] for private-use URI schemes. + + if (metadata.application_type !== 'native') { + throw new InvalidRedirectUriError( + `Private-Use URI Scheme redirect URI are only allowed for native apps`, + ) + } + + const redirectUriDomain = parseDomain( + reverseDomain(url.protocol.slice(0, -1)), + ) + + if (!redirectUriDomain) { + throw new InvalidRedirectUriError( + `Private-use URI Scheme redirect URI must be based on a valid domain name`, + ) + } + + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4 + // + // > In addition to the collision-resistant properties, requiring a + // > URI scheme based on a domain name that is under the control of + // > the app can help to prove ownership in the event of a dispute + // > where two apps claim the same private-use URI scheme (where one + // > app is acting maliciously). + if (!clientUriParsed) { + throw new InvalidClientMetadataError( + 'client_uri is required for native apps using private-use URI Scheme redirect URIs', + ) + } else { + if (redirectUriDomain.domain !== clientUriParsed.domain) { + throw new InvalidRedirectUriError( + `Private-Use URI Scheme redirect URI ${url} must be under the same domain as client_uri ${metadata.client_uri}`, + ) + } + } + + // https://datatracker.ietf.org/doc/html/rfc8252#section-7.1 + // + // > Following the requirements of Section 3.2 of [RFC3986], as there + // > is no naming authority for private-use URI scheme redirects, only + // > a single slash ("/") appears after the scheme component. + if ( + url.href.startsWith(`${url.protocol}//`) || + url.username || + url.password || + url.hostname || + url.port + ) { + throw new InvalidRedirectUriError( + `Private-Use URI Scheme must be in the form ${url.protocol}/`, + ) + } + + break + } + + default: + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4 + // + // > At a minimum, any private-use URI scheme that doesn't contain a + // > period character (".") SHOULD be rejected. + throw new InvalidRedirectUriError( + `Invalid redirect URI scheme "${url.protocol}"`, + ) + } + } + + if (isOAuthClientIdLoopback(clientId)) { + return this.validateLoopbackClientMetadata(clientId, metadata) + } else if (isOAuthClientIdDiscoverable(clientId)) { + return this.validateDiscoverableClientMetadata(clientId, metadata) + } else { + return metadata + } + } + + validateLoopbackClientMetadata( + clientId: OAuthClientIdLoopback, + metadata: OAuthClientMetadata, + ): OAuthClientMetadata { + if (metadata.client_uri) { + throw new InvalidClientMetadataError( + 'client_uri is not allowed for loopback clients', + ) + } + + if (metadata.application_type !== 'native') { + throw new InvalidClientMetadataError( + 'Loopback clients must have application_type "native"', + ) + } + + if ( + !ALLOW_LOOPBACK_CLIENT_REFRESH_TOKEN && + metadata.grant_types.includes('refresh_token') + ) { + throw new InvalidClientMetadataError( + 'Loopback clients are not allowed to use the "refresh_token" grant type', + ) + } + + for (const endpoint of OAUTH_AUTHENTICATED_ENDPOINT_NAMES) { + const method = + metadata[`${endpoint}_endpoint_auth_method`] || + metadata[`token_endpoint_auth_method`] + + if (method !== 'none') { + throw new InvalidClientMetadataError( + `Loopback clients are not allowed to use "${endpoint}_endpoint_auth_method" ${method}`, + ) + } + } + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + + if (url.protocol !== 'http:') { + throw new InvalidRedirectUriError( + `Loopback clients must use HTTP redirect URIs`, + ) + } + + if (!isLoopbackHost(url.hostname)) { + throw new InvalidRedirectUriError( + `Loopback clients must use loopback redirect URIs`, + ) + } + } + + return metadata + } + + validateDiscoverableClientMetadata( + clientId: OAuthClientIdDiscoverable, + metadata: OAuthClientMetadata, + ): OAuthClientMetadata { + if (!metadata.client_id) { + // https://drafts.aaronpk.com/draft-parecki-oauth-client-id-metadata-document/draft-parecki-oauth-client-id-metadata-document.html + throw new InvalidClientMetadataError( + `client_id is required for discoverable clients`, + ) + } + + const clientIdUrl = parseDiscoverableClientId(clientId) + + if (metadata.client_uri) { + // https://drafts.aaronpk.com/draft-parecki-oauth-client-id-metadata-document/draft-parecki-oauth-client-id-metadata-document.html + // + // The client_uri must be a parent of the client_id URL. This might be + // relaxed in the future. + + const clientUriUrl = new URL(metadata.client_uri) + + if (clientUriUrl.origin !== clientIdUrl.origin) { + throw new InvalidClientMetadataError( + `client_uri must have the same origin as the client_id`, + ) + } + + if (clientIdUrl.pathname !== clientUriUrl.pathname) { + if ( + !clientIdUrl.pathname.startsWith( + clientUriUrl.pathname.endsWith('/') + ? clientUriUrl.pathname + : `${clientUriUrl.pathname}/`, + ) + ) { + throw new InvalidClientMetadataError( + `client_uri must be a parent URL of the client_id`, + ) + } + } + } + + for (const endpoint of OAUTH_AUTHENTICATED_ENDPOINT_NAMES) { + const method = metadata[`${endpoint}_endpoint_auth_method`] + switch (method) { + case 'client_secret_post': + case 'client_secret_basic': + case 'client_secret_jwt': + throw new InvalidClientMetadataError( + `Client authentication method "${method}" is not allowed for discoverable clients`, + ) + } + } + + for (const redirectUri of metadata.redirect_uris) { + const url = parseRedirectUri(redirectUri) + + if (isPrivateUseUriScheme(url)) { + // https://drafts.aaronpk.com/draft-parecki-oauth-client-id-metadata-document/draft-parecki-oauth-client-id-metadata-document.html + // + // Fully qualified domain name (FQDN) of the client_id, in reverse + // order. This could be relaxed to allow same apex domain names, or + // parent domains, but for now we require an exact match. + const protocol = `${reverseDomain(clientIdUrl.hostname)}:` + if (url.protocol !== protocol) { + throw new InvalidRedirectUriError( + `Private-Use URI Scheme redirect URI, for discoverable client metadata, must be the fully qualified domain name (FQDN) of the client_id, in reverse order (${protocol})`, + ) + } + } + } + + return metadata + } +} + +function reverseDomain(domain: string) { + return domain.split('.').reverse().join('.') +} + +function isPrivateUseUriScheme(uri: URL) { + return uri.protocol.includes('.') +} + +function buildJsonGetRequest(uri: string, options?: GetCachedOptions) { + const headers = new Headers([['accept', 'application/json']]) + if (options?.noCache) headers.set('cache-control', 'no-cache') + return new Request(uri, { + headers, + signal: options?.signal, + redirect: 'error', + }) +} diff --git a/packages/oauth/oauth-provider/src/client/client-store.ts b/packages/oauth/oauth-provider/src/client/client-store.ts new file mode 100644 index 00000000000..fdf05d7d0ed --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-store.ts @@ -0,0 +1,38 @@ +import { OAuthClientMetadata } from '@atproto/oauth-types' + +import { Awaitable } from '../lib/util/type.js' +import { ClientId } from './client-id.js' + +// Export all types needed to implement the ClientStore interface +export * from './client-data.js' +export * from './client-id.js' +export type { Awaitable } + +export interface ClientStore { + findClient(clientId: ClientId): Awaitable +} + +export function isClientStore( + implementation: Record & Partial, +): implementation is Record & ClientStore { + return typeof implementation.findClient === 'function' +} + +export function ifClientStore( + implementation?: Record & Partial, +): ClientStore | undefined { + if (implementation && isClientStore(implementation)) { + return implementation + } + + return undefined +} + +export function asClientStore( + implementation?: Record & Partial, +): ClientStore { + const store = ifClientStore(implementation) + if (store) return store + + throw new Error('Invalid ClientStore implementation') +} diff --git a/packages/oauth/oauth-provider/src/client/client-utils.ts b/packages/oauth/oauth-provider/src/client/client-utils.ts new file mode 100644 index 00000000000..60f4c178d19 --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client-utils.ts @@ -0,0 +1,43 @@ +import { + OAuthClientIdDiscoverable, + OAuthClientIdLoopback, + parseOAuthLoopbackClientId, + parseOAuthDiscoverableClientId, +} from '@atproto/oauth-types' + +import { InvalidClientIdError } from '../errors/invalid-client-id-error.js' +import { InvalidRedirectUriError } from '../errors/invalid-redirect-uri-error.js' +import { isInternetHost } from '../lib/util/hostname.js' + +export function parseRedirectUri(redirectUri: string): URL { + try { + return new URL(redirectUri) + } catch (err) { + throw InvalidRedirectUriError.from(err) + } +} + +export function parseDiscoverableClientId( + clientId: OAuthClientIdDiscoverable, +): URL { + try { + const url = parseOAuthDiscoverableClientId(clientId) + + // Extra validation, prevent usage of invalid internet domain names. + if (!isInternetHost(url.hostname)) { + throw new InvalidClientIdError('ClientID is not a valid internet address') + } + + return url + } catch (err) { + throw InvalidClientIdError.from(err) + } +} + +export function parseLoopbackClientId(clientId: OAuthClientIdLoopback): URL { + try { + return parseOAuthLoopbackClientId(clientId) + } catch (err) { + throw InvalidClientIdError.from(err) + } +} diff --git a/packages/oauth/oauth-provider/src/client/client.ts b/packages/oauth/oauth-provider/src/client/client.ts new file mode 100644 index 00000000000..76fbf899020 --- /dev/null +++ b/packages/oauth/oauth-provider/src/client/client.ts @@ -0,0 +1,231 @@ +import { Jwks } from '@atproto/jwk' +import { + CLIENT_ASSERTION_TYPE_JWT_BEARER, + OAuthClientIdentification, + OAuthClientMetadata, + OAuthEndpointName, +} from '@atproto/oauth-types' +import { + UnsecuredJWT, + createLocalJWKSet, + createRemoteJWKSet, + jwtVerify, + type JWTPayload, + type JWTVerifyGetKey, + type JWTVerifyOptions, + type JWTVerifyResult, + type KeyLike, + type ResolvedKey, + type UnsecuredResult, +} from 'jose' +import { JOSEError } from 'jose/errors' + +import { CLIENT_ASSERTION_MAX_AGE, JAR_MAX_AGE } from '../constants.js' +import { InvalidClientError } from '../errors/invalid-client-error.js' +import { InvalidClientMetadataError } from '../errors/invalid-client-metadata-error.js' +import { InvalidRequestError } from '../errors/invalid-request-error.js' +import { ClientAuth, authJwkThumbprint } from './client-auth.js' +import { ClientId } from './client-id.js' +import { ClientInfo } from './client-info.js' + +export class Client { + /** + * @see {@link https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml#token-endpoint-auth-method} + */ + static readonly AUTH_METHODS_SUPPORTED = ['none', 'private_key_jwt'] as const + + private readonly keyGetter: JWTVerifyGetKey + + constructor( + public readonly id: ClientId, + public readonly metadata: OAuthClientMetadata, + public readonly jwks: undefined | Jwks = metadata.jwks, + public readonly info: ClientInfo, + ) { + // If the remote JWKS content is provided, we don't need to fetch it again. + this.keyGetter = + jwks || !metadata.jwks_uri + ? createLocalJWKSet(jwks || { keys: [] }) + : createRemoteJWKSet(new URL(metadata.jwks_uri), {}) + } + + public async decodeRequestObject(jar: string) { + try { + switch (this.metadata.request_object_signing_alg) { + case 'none': + return await this.jwtVerifyUnsecured(jar, { + maxTokenAge: JAR_MAX_AGE / 1000, + }) + case undefined: + // https://openid.net/specs/openid-connect-registration-1_0.html#rfc.section.2 + // > The default, if omitted, is that any algorithm supported by the OP + // > and the RP MAY be used. + return await this.jwtVerify(jar, { + maxTokenAge: JAR_MAX_AGE / 1000, + }) + default: + return await this.jwtVerify(jar, { + maxTokenAge: JAR_MAX_AGE / 1000, + algorithms: [this.metadata.request_object_signing_alg], + }) + } + } catch (err) { + const message = + err instanceof JOSEError + ? `Invalid "request" object: ${err.message}` + : `Invalid "request" object` + + throw new InvalidRequestError(message, err) + } + } + + protected async jwtVerifyUnsecured( + token: string, + options?: Omit, + ): Promise> { + return UnsecuredJWT.decode(token, { + ...options, + issuer: this.id, + }) + } + + protected async jwtVerify( + token: string, + options?: Omit, + ): Promise & ResolvedKey> { + return jwtVerify(token, this.keyGetter, { + ...options, + issuer: this.id, + }) + } + + protected getAuthMethod(endpoint: OAuthEndpointName) { + return ( + this.metadata[`${endpoint}_endpoint_auth_method`] || + this.metadata[`token_endpoint_auth_method`] + ) + } + + /** + * @see {@link https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1} + * @see {@link https://datatracker.ietf.org/doc/html/draft-ietf-oauth-jwt-bearer-11#section-3} + * @see {@link https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml#token-endpoint-auth-method} + */ + public async verifyCredentials( + input: OAuthClientIdentification, + endpoint: OAuthEndpointName, + checks: { + audience: string + }, + ): Promise<{ + clientAuth: ClientAuth + // for replay protection + nonce?: string + }> { + const method = this.getAuthMethod(endpoint) + + if (method === 'none') { + const clientAuth: ClientAuth = { method: 'none' } + return { clientAuth } + } + + if (method === 'private_key_jwt') { + if (!('client_assertion_type' in input) || !input.client_assertion_type) { + throw new InvalidRequestError( + `client_assertion_type required for "${method}"`, + ) + } else if (!input.client_assertion) { + throw new InvalidRequestError( + `client_assertion required for "${method}"`, + ) + } + + if (input.client_assertion_type === CLIENT_ASSERTION_TYPE_JWT_BEARER) { + const result = await this.jwtVerify<{ + jti: string + }>(input.client_assertion, { + audience: checks.audience, + subject: this.id, + maxTokenAge: CLIENT_ASSERTION_MAX_AGE / 1000, + }).catch((err) => { + if (err instanceof JOSEError) { + const msg = `Validation of "client_assertion" failed: ${err.message}` + throw new InvalidClientError(msg, err) + } + + throw err + }) + + if (!result.protectedHeader.kid) { + throw new InvalidClientError(`"kid" required in client_assertion`) + } + + if (!result.payload.jti) { + throw new InvalidClientError(`"jti" required in client_assertion`) + } + + const clientAuth: ClientAuth = { + method: CLIENT_ASSERTION_TYPE_JWT_BEARER, + jkt: await authJwkThumbprint(result.key), + alg: result.protectedHeader.alg, + kid: result.protectedHeader.kid, + } + + return { clientAuth, nonce: result.payload.jti } + } + + throw new InvalidClientError( + `Unsupported client_assertion_type "${input.client_assertion_type}"`, + ) + } + + // @ts-expect-error Ensure to keep Client.AUTH_METHODS_SUPPORTED in sync + // with the implementation of this function. + if (Client.AUTH_METHODS_SUPPORTED.includes(method)) { + throw new Error( + `verifyCredentials() should implement all of ${[ + Client.AUTH_METHODS_SUPPORTED, + ]}`, + ) + } + + throw new InvalidClientMetadataError( + `Unsupported ${endpoint}_endpoint_auth_method "${method}"`, + ) + } + + /** + * Ensures that a {@link ClientAuth} generated in the past is still valid wrt + * the current client metadata & jwks. This is used to invalidate tokens when + * the client stops advertising the key that it used to authenticate itself + * during the initial token request. + */ + public async validateClientAuth(clientAuth: ClientAuth): Promise { + if (clientAuth.method === 'none') { + return this.getAuthMethod('token') === 'none' + } + + if (clientAuth.method === CLIENT_ASSERTION_TYPE_JWT_BEARER) { + if (this.getAuthMethod('token') !== 'private_key_jwt') { + return false + } + try { + const key = await this.keyGetter( + { + kid: clientAuth.kid, + alg: clientAuth.alg, + }, + { payload: '', signature: '' }, + ) + const jtk = await authJwkThumbprint(key) + + return jtk === clientAuth.jkt + } catch (e) { + return false + } + } + + // @ts-expect-error + throw new Error(`Invalid method "${clientAuth.method}"`) + } +} diff --git a/packages/oauth/oauth-provider/src/constants.ts b/packages/oauth/oauth-provider/src/constants.ts new file mode 100644 index 00000000000..3804e790485 --- /dev/null +++ b/packages/oauth/oauth-provider/src/constants.ts @@ -0,0 +1,69 @@ +// The purpose of the prefix is to provide type safety + +export const DEVICE_ID_PREFIX = 'dev-' +export const DEVICE_ID_BYTES_LENGTH = 16 // 128 bits + +export const SESSION_ID_PREFIX = 'ses-' +export const SESSION_ID_BYTES_LENGTH = 16 // 128 bits - only valid if device id is valid + +export const REFRESH_TOKEN_PREFIX = 'ref-' +export const REFRESH_TOKEN_BYTES_LENGTH = 32 // 256 bits + +export const TOKEN_ID_PREFIX = 'tok-' +export const TOKEN_ID_BYTES_LENGTH = 16 // 128 bits - used as `jti` in JWTs (cannot be forged) + +export const REQUEST_ID_PREFIX = 'req-' +export const REQUEST_ID_BYTES_LENGTH = 16 // 128 bits + +export const CODE_PREFIX = 'cod-' +export const CODE_BYTES_LENGTH = 32 + +export const ALLOW_LOOPBACK_CLIENT_REFRESH_TOKEN = true + +const SECOND = 1e3 +const MINUTE = 60 * SECOND +const HOUR = 60 * MINUTE +const DAY = 24 * HOUR +const WEEK = 7 * DAY +const YEAR = 365.25 * DAY +const MONTH = YEAR / 12 + +/** 7 days */ +export const AUTHENTICATION_MAX_AGE = 7 * DAY + +/** 60 minutes */ +export const TOKEN_MAX_AGE = 60 * MINUTE + +/** 5 minutes */ +export const AUTHORIZATION_INACTIVITY_TIMEOUT = 5 * MINUTE + +/** 1 months */ +export const AUTHENTICATED_REFRESH_INACTIVITY_TIMEOUT = 1 * MONTH + +/** 2 days */ +export const UNAUTHENTICATED_REFRESH_INACTIVITY_TIMEOUT = 2 * DAY + +/** 1 week */ +export const UNAUTHENTICATED_REFRESH_LIFETIME = 1 * WEEK + +/** 1 year */ +export const AUTHENTICATED_REFRESH_LIFETIME = 1 * YEAR + +/** 5 minutes */ +export const PAR_EXPIRES_IN = 5 * MINUTE + +/** + * 59 seconds (should be less than a minute) + * + * @see {@link https://datatracker.ietf.org/doc/html/rfc9101#section-10.2} + */ +export const JAR_MAX_AGE = 59 * SECOND + +/** 1 minute */ +export const CLIENT_ASSERTION_MAX_AGE = 1 * MINUTE + +/** 3 minutes */ +export const DPOP_NONCE_MAX_AGE = 3 * MINUTE + +/** 5 seconds */ +export const SESSION_FIXATION_MAX_AGE = 5 * SECOND diff --git a/packages/oauth/oauth-provider/src/device/device-data.ts b/packages/oauth/oauth-provider/src/device/device-data.ts new file mode 100644 index 00000000000..864211f7ae1 --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/device-data.ts @@ -0,0 +1,11 @@ +import { z } from 'zod' + +import { deviceDetailsSchema } from './device-details.js' +import { sessionIdSchema } from './session-id.js' + +export const deviceDataSchema = deviceDetailsSchema.extend({ + sessionId: sessionIdSchema, + lastSeenAt: z.date(), +}) + +export type DeviceData = z.infer diff --git a/packages/oauth/oauth-provider/src/device/device-details.ts b/packages/oauth/oauth-provider/src/device/device-details.ts new file mode 100644 index 00000000000..8673736146b --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/device-details.ts @@ -0,0 +1,43 @@ +import { IncomingMessage } from 'node:http' + +import { z } from 'zod' + +export const deviceDetailsSchema = z.object({ + userAgent: z.string().nullable(), + ipAddress: z.string(), +}) +export type DeviceDetails = z.infer + +export function extractDeviceDetails( + req: IncomingMessage, + trustProxy: boolean, +): DeviceDetails { + const userAgent = req.headers['user-agent'] || null + const ipAddress = extractIpAddress(req, trustProxy) || null + + if (!ipAddress) { + throw new Error('Could not determine IP address') + } + + return { userAgent, ipAddress } +} + +export function extractIpAddress( + req: IncomingMessage, + trustProxy: boolean, +): string | undefined { + // Express app compatibility + if ('ip' in req && typeof req.ip === 'string') { + return req.ip + } + + if (trustProxy) { + const forwardedFor = req.headers['x-forwarded-for'] + if (typeof forwardedFor === 'string') { + const firstForward = forwardedFor.split(',')[0]!.trim() + if (firstForward) return firstForward + } + } + + return req.socket.remoteAddress +} diff --git a/packages/oauth/oauth-provider/src/device/device-id.ts b/packages/oauth/oauth-provider/src/device/device-id.ts new file mode 100644 index 00000000000..ced0c1a416d --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/device-id.ts @@ -0,0 +1,23 @@ +import { z } from 'zod' + +import { DEVICE_ID_BYTES_LENGTH, DEVICE_ID_PREFIX } from '../constants.js' +import { randomHexId } from '../lib/util/crypto.js' + +export const DEVICE_ID_LENGTH = + DEVICE_ID_PREFIX.length + DEVICE_ID_BYTES_LENGTH * 2 // hex encoding + +export const deviceIdSchema = z + .string() + .length(DEVICE_ID_LENGTH) + .refine( + (v): v is `${typeof DEVICE_ID_PREFIX}${string}` => + v.startsWith(DEVICE_ID_PREFIX), + { + message: `Invalid device ID format`, + }, + ) + +export type DeviceId = z.infer +export const generateDeviceId = async (): Promise => { + return `${DEVICE_ID_PREFIX}${await randomHexId(DEVICE_ID_BYTES_LENGTH)}` +} diff --git a/packages/oauth/oauth-provider/src/device/device-manager.ts b/packages/oauth/oauth-provider/src/device/device-manager.ts new file mode 100644 index 00000000000..c2d69a9f86e --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/device-manager.ts @@ -0,0 +1,287 @@ +import { IncomingMessage, ServerResponse } from 'node:http' + +import { serialize as serializeCookie } from 'cookie' +import type Keygrip from 'keygrip' +import { z } from 'zod' + +import { appendHeader, parseHttpCookies } from '../lib/http/index.js' + +import { SESSION_FIXATION_MAX_AGE } from '../constants.js' +import { DeviceData } from './device-data.js' +import { extractDeviceDetails } from './device-details.js' +import { DeviceId, deviceIdSchema, generateDeviceId } from './device-id.js' +import { DeviceStore } from './device-store.js' +import { generateSessionId, sessionIdSchema } from './session-id.js' + +export const DEFAULT_OPTIONS = { + /** + * Controls whether the IP address is read from the `X-Forwarded-For` header + * (if `true`), or from the `req.socket.remoteAddress` property (if `false`). + * + * @default true // (nowadays, most requests are proxied) + */ + trustProxy: true, + + /** + * Amount of time (in ms) after which session IDs will be rotated + * + * @default 300e3 // (5 minutes) + */ + rotationRate: 5 * 60e3, + + /** + * Cookie options + */ + cookie: { + keys: undefined as undefined | Keygrip, + + /** + * Name of the cookie used to identify the device + * + * @default 'session-id' + */ + device: 'device-id', + + /** + * Name of the cookie used to identify the session + * + * @default 'session-id' + */ + session: 'session-id', + + /** + * Url path for the cookie + * + * @default '/oauth/authorize' + */ + path: '/oauth/authorize', + + /** + * Amount of time (in ms) after which the session cookie will expire. + * If set to `null`, the cookie will be a session cookie (deleted when the + * browser is closed). + * + * @default 10 * 365.2 * 24 * 60 * 60e3 // 10 years (in ms) + */ + age: (10 * 365.2 * 24 * 60 * 60e3), + + /** + * Controls whether the cookie is only sent over HTTPS (if `true`), or also + * over HTTP (if `false`). This should **NOT** be set to `false` in + * production. + */ + secure: true, + + /** + * Controls whether the cookie is sent along with cross-site requests. + * + * @default 'lax' + */ + sameSite: 'lax' as 'lax' | 'strict', + }, +} + +export type DeviceDeviceManagerOptions = typeof DEFAULT_OPTIONS + +const cookieValueSchema = z.tuple([deviceIdSchema, sessionIdSchema]) +type CookieValue = z.infer + +/** + * This class provides an abstraction for keeping track of DEVICE sessions. It + * relies on a {@link DeviceStore} to persist session data and a cookie to + * identify the session. + */ +export class DeviceManager { + constructor( + private readonly store: DeviceStore, + private readonly options: DeviceDeviceManagerOptions = DEFAULT_OPTIONS, + ) {} + + public async load( + req: IncomingMessage, + res: ServerResponse, + ): Promise<{ deviceId: DeviceId }> { + const cookie = await this.getCookie(req) + if (cookie) { + return this.refresh(req, res, cookie.value, cookie.mustRotate) + } else { + return this.create(req, res) + } + } + + private async create( + req: IncomingMessage, + res: ServerResponse, + ): Promise<{ deviceId: DeviceId }> { + const { userAgent, ipAddress } = this.getDeviceDetails(req) + + const [deviceId, sessionId] = await Promise.all([ + generateDeviceId(), + generateSessionId(), + ] as const) + + await this.store.createDevice(deviceId, { + sessionId, + lastSeenAt: new Date(), + userAgent, + ipAddress, + }) + + this.setCookie(res, [deviceId, sessionId]) + + return { deviceId } + } + + private async refresh( + req: IncomingMessage, + res: ServerResponse, + [deviceId, sessionId]: CookieValue, + forceRotate = false, + ): Promise<{ deviceId: DeviceId }> { + const data = await this.store.readDevice(deviceId) + if (!data) return this.create(req, res) + + const lastSeenAt = new Date(data.lastSeenAt) + const age = Date.now() - lastSeenAt.getTime() + + if (sessionId !== data.sessionId) { + if (age <= SESSION_FIXATION_MAX_AGE) { + // The cookie was probably rotated by a concurrent request. Let's + // update the cookie with the new sessionId. + forceRotate = true + } else { + // Something's wrong. Let's create a new session. + await this.store.deleteDevice(deviceId) + return this.create(req, res) + } + } + + const details = this.getDeviceDetails(req) + + if ( + forceRotate || + details.ipAddress !== data.ipAddress || + details.userAgent !== data.userAgent || + age > this.options.rotationRate + ) { + await this.rotate(req, res, deviceId, { + ipAddress: details.ipAddress, + userAgent: details.userAgent || data.userAgent, + }) + } + + return { deviceId } + } + + public async rotate( + req: IncomingMessage, + res: ServerResponse, + deviceId: DeviceId, + data?: Partial>, + ): Promise { + const sessionId = await generateSessionId() + + await this.store.updateDevice(deviceId, { + ...data, + sessionId, + lastSeenAt: new Date(), + }) + + this.setCookie(res, [deviceId, sessionId]) + } + + private async getCookie( + req: IncomingMessage, + ): Promise<{ value: CookieValue; mustRotate: boolean } | null> { + const cookies = parseHttpCookies(req) + if (!cookies) return null + + const device = this.parseCookie( + cookies, + this.options.cookie.device, + deviceIdSchema, + ) + const session = this.parseCookie( + cookies, + this.options.cookie.session, + sessionIdSchema, + ) + + // Silently ignore invalid cookies + if (!device || !session) { + // If the device cookie is valid, let's cleanup the DB + if (device) await this.store.deleteDevice(device.value) + + return null + } + + return { + value: [device.value, session.value], + mustRotate: device.mustRotate || session.mustRotate, + } + } + + private parseCookie( + cookies: Record, + name: string, + schema: z.ZodType | z.ZodEffects, + ): null | { value: T; mustRotate: boolean } { + const result = schema.safeParse(cookies[name], { path: ['cookie', name] }) + if (!result.success) return null + + const value = result.data + + if (this.options.cookie.keys) { + const hash = cookies[`${name}:hash`] + if (!hash) return null + + const idx = this.options.cookie.keys.index(value, hash) + if (idx < 0) return null + + return { value, mustRotate: idx !== 0 } + } + + return { value, mustRotate: false } + } + + private setCookie(res: ServerResponse, cookieValue: null | CookieValue) { + this.writeCookie(res, this.options.cookie.device, cookieValue?.[0]) + this.writeCookie(res, this.options.cookie.session, cookieValue?.[1]) + } + + private writeCookie(res: ServerResponse, name: string, value?: string) { + const cookieOptions = { + maxAge: value + ? this.options.cookie.age == null + ? undefined + : this.options.cookie.age / 1000 + : 0, + httpOnly: true, + path: this.options.cookie.path, + secure: this.options.cookie.secure !== false, + sameSite: this.options.cookie.sameSite === 'lax' ? 'lax' : 'strict', + } as const + + appendHeader( + res, + 'Set-Cookie', + serializeCookie(name, value || '', cookieOptions), + ) + + if (this.options.cookie.keys) { + appendHeader( + res, + 'Set-Cookie', + serializeCookie( + `${name}:hash`, + value ? this.options.cookie.keys.sign(value) : '', + cookieOptions, + ), + ) + } + } + + private getDeviceDetails(req: IncomingMessage) { + return extractDeviceDetails(req, this.options.trustProxy) + } +} diff --git a/packages/oauth/oauth-provider/src/device/device-store.ts b/packages/oauth/oauth-provider/src/device/device-store.ts new file mode 100644 index 00000000000..a25c32382a6 --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/device-store.ts @@ -0,0 +1,35 @@ +import { Awaitable } from '../lib/util/type.js' +import { DeviceData } from './device-data.js' +import { DeviceId } from './device-id.js' + +// Export all types needed to implement the DeviceStore interface +export * from './device-id.js' +export * from './device-data.js' +export * from './session-id.js' + +export interface DeviceStore { + createDevice(deviceId: DeviceId, data: DeviceData): Awaitable + readDevice(deviceId: DeviceId): Awaitable + updateDevice(deviceId: DeviceId, data: Partial): Awaitable + deleteDevice(deviceId: DeviceId): Awaitable +} + +export function isDeviceStore( + implementation: Record & Partial, +): implementation is Record & DeviceStore { + return ( + typeof implementation.createDevice === 'function' && + typeof implementation.readDevice === 'function' && + typeof implementation.updateDevice === 'function' && + typeof implementation.deleteDevice === 'function' + ) +} + +export function asDeviceStore( + implementation?: Record & Partial, +): DeviceStore { + if (!implementation || !isDeviceStore(implementation)) { + throw new Error('Invalid DeviceStore implementation') + } + return implementation +} diff --git a/packages/oauth/oauth-provider/src/device/session-id.ts b/packages/oauth/oauth-provider/src/device/session-id.ts new file mode 100644 index 00000000000..e063ebdb50e --- /dev/null +++ b/packages/oauth/oauth-provider/src/device/session-id.ts @@ -0,0 +1,22 @@ +import { z } from 'zod' + +import { SESSION_ID_BYTES_LENGTH, SESSION_ID_PREFIX } from '../constants.js' +import { randomHexId } from '../lib/util/crypto.js' + +export const SESSION_ID_LENGTH = + SESSION_ID_PREFIX.length + SESSION_ID_BYTES_LENGTH * 2 // hex encoding + +export const sessionIdSchema = z + .string() + .length(SESSION_ID_LENGTH) + .refine( + (v): v is `${typeof SESSION_ID_PREFIX}${string}` => + v.startsWith(SESSION_ID_PREFIX), + { + message: `Invalid session ID format`, + }, + ) +export type SessionId = z.infer +export const generateSessionId = async (): Promise => { + return `${SESSION_ID_PREFIX}${await randomHexId(SESSION_ID_BYTES_LENGTH)}` +} diff --git a/packages/oauth/oauth-provider/src/dpop/dpop-manager.ts b/packages/oauth/oauth-provider/src/dpop/dpop-manager.ts new file mode 100644 index 00000000000..5e6d0d28865 --- /dev/null +++ b/packages/oauth/oauth-provider/src/dpop/dpop-manager.ts @@ -0,0 +1,147 @@ +import { createHash } from 'node:crypto' + +import { EmbeddedJWK, calculateJwkThumbprint, jwtVerify } from 'jose' +import { JOSEError } from 'jose/errors' + +import { DPOP_NONCE_MAX_AGE } from '../constants.js' +import { InvalidDpopProofError } from '../errors/invalid-dpop-proof-error.js' +import { UseDpopNonceError } from '../errors/use-dpop-nonce-error.js' +import { DpopNonce, DpopNonceInput } from './dpop-nonce.js' + +export { DpopNonce, type DpopNonceInput } +export type DpopManagerOptions = { + /** + * Set this to `false` to disable the use of nonces in DPoP proofs. Set this + * to a secret Uint8Array or hex encoded string to use a predictable seed for + * all nonces (typically useful when multiple instances are running). Leave + * undefined to generate a random seed at startup. + */ + dpopSecret?: false | DpopNonceInput + dpopStep?: number +} + +export class DpopManager { + protected readonly dpopNonce?: DpopNonce + + constructor({ dpopSecret, dpopStep }: DpopManagerOptions = {}) { + this.dpopNonce = + dpopSecret === false ? undefined : DpopNonce.from(dpopSecret, dpopStep) + } + + nextNonce(): string | undefined { + return this.dpopNonce?.next() + } + + /** + * @see {@link https://datatracker.ietf.org/doc/html/rfc9449#section-4.3} + */ + async checkProof( + proof: unknown, + htm: string, // HTTP Method + htu: string | URL, // HTTP URL + accessToken?: string, // Access Token + ) { + if (Array.isArray(proof) && proof.length === 1) { + proof = proof[0] + } + + if (!proof || typeof proof !== 'string') { + throw new InvalidDpopProofError('DPoP proof required') + } + + const { protectedHeader, payload } = await jwtVerify<{ + iat: number + exp: number + jti: string + }>(proof, EmbeddedJWK, { + typ: 'dpop+jwt', + maxTokenAge: 10, + clockTolerance: DPOP_NONCE_MAX_AGE / 1e3, + requiredClaims: ['iat', 'exp', 'jti'], + }).catch((err) => { + const message = + err instanceof JOSEError + ? `Invalid DPoP proof (${err.message})` + : 'Invalid DPoP proof' + throw new InvalidDpopProofError(message, err) + }) + + if (!payload.jti || typeof payload.jti !== 'string') { + throw new InvalidDpopProofError('Invalid or missing jti property') + } + + if (payload.exp - payload.iat > DPOP_NONCE_MAX_AGE / 3 / 1e3) { + throw new InvalidDpopProofError('DPoP proof validity too long') + } + + // Note rfc9110#section-9.1 states that the method name is case-sensitive + if (!htm || htm !== payload['htm']) { + throw new InvalidDpopProofError('DPoP htm mismatch') + } + + if ( + payload['nonce'] !== undefined && + typeof payload['nonce'] !== 'string' + ) { + throw new InvalidDpopProofError('DPoP nonce must be a string') + } + + if (!payload['nonce'] && this.dpopNonce) { + throw new UseDpopNonceError() + } + + if (payload['nonce'] && !this.dpopNonce?.check(payload['nonce'])) { + throw new UseDpopNonceError() + } + + const htuNorm = normalizeHtu(htu) + if (!htuNorm || htuNorm !== normalizeHtu(payload['htu'])) { + throw new InvalidDpopProofError('DPoP htu mismatch') + } + + if (accessToken) { + const athBuffer = createHash('sha256').update(accessToken).digest() + if (payload['ath'] !== athBuffer.toString('base64url')) { + throw new InvalidDpopProofError('DPoP ath mismatch') + } + } else if (payload['ath']) { + throw new InvalidDpopProofError('DPoP ath not allowed') + } + + try { + return { + protectedHeader, + payload, + jkt: await calculateJwkThumbprint(protectedHeader['jwk']!, 'sha256'), // EmbeddedJWK + } + } catch (err) { + const message = + err instanceof JOSEError ? err.message : 'Failed to calculate jkt' + throw new InvalidDpopProofError(message, err) + } + } +} + +/** + * @note + * > The htu claim matches the HTTP URI value for the HTTP request in which the + * > JWT was received, ignoring any query and fragment parts. + * + * > To reduce the likelihood of false negatives, servers SHOULD employ + * > syntax-based normalization (Section 6.2.2 of [RFC3986]) and scheme-based + * > normalization (Section 6.2.3 of [RFC3986]) before comparing the htu claim. + * @see {@link https://datatracker.ietf.org/doc/html/rfc9449#section-4.3 | RFC9449 section 4.3. Checking DPoP Proofs} + */ +function normalizeHtu(htu: unknown): string | null { + // Optimization + if (!htu) return null + + try { + const url = new URL(String(htu)) + url.hash = '' + url.search = '' + return url.href + } catch { + return null + } +} diff --git a/packages/oauth/oauth-provider/src/dpop/dpop-nonce.ts b/packages/oauth/oauth-provider/src/dpop/dpop-nonce.ts new file mode 100644 index 00000000000..4a74e8af294 --- /dev/null +++ b/packages/oauth/oauth-provider/src/dpop/dpop-nonce.ts @@ -0,0 +1,104 @@ +import { createHmac, randomBytes } from 'node:crypto' + +import { DPOP_NONCE_MAX_AGE } from '../constants.js' + +function numTo64bits(num: number) { + const arr = new Uint8Array(8) + arr[7] = (num = num | 0) & 0xff + arr[6] = (num >>= 8) & 0xff + arr[5] = (num >>= 8) & 0xff + arr[4] = (num >>= 8) & 0xff + arr[3] = (num >>= 8) & 0xff + arr[2] = (num >>= 8) & 0xff + arr[1] = (num >>= 8) & 0xff + arr[0] = (num >>= 8) & 0xff + return arr +} + +export type DpopNonceInput = string | Uint8Array | DpopNonce + +export class DpopNonce { + #secret: Uint8Array + #counter: number + + #prev: string + #now: string + #next: string + + constructor( + protected readonly secret: Uint8Array, + protected readonly step: number, + ) { + if (secret.length !== 32) throw new TypeError('Expected 32 bytes') + if (this.step < 0 || this.step > DPOP_NONCE_MAX_AGE / 3) { + throw new TypeError('Invalid step') + } + + this.#secret = Uint8Array.from(secret) + this.#counter = (Date.now() / step) | 0 + + this.#prev = this.compute(this.#counter - 1) + this.#now = this.compute(this.#counter) + this.#next = this.compute(this.#counter + 1) + } + + protected rotate() { + const counter = (Date.now() / this.step) | 0 + switch (counter - this.#counter) { + case 0: + // counter === this.#counter => nothing to do + return + case 1: + // Optimization: avoid recomputing #prev & #now + this.#prev = this.#now + this.#now = this.#next + this.#next = this.compute(counter + 1) + break + case 2: + // Optimization: avoid recomputing #prev + this.#prev = this.#next + this.#now = this.compute(counter) + this.#next = this.compute(counter + 1) + break + default: + // All nonces are outdated, so we recompute all of them + this.#prev = this.compute(counter - 1) + this.#now = this.compute(counter) + this.#next = this.compute(counter + 1) + break + } + this.#counter = counter + } + + protected compute(counter: number) { + return createHmac('sha256', this.#secret) + .update(numTo64bits(counter)) + .digest() + .toString('base64url') + } + + public next() { + this.rotate() + return this.#next + } + + public check(nonce: string) { + return this.#next === nonce || this.#now === nonce || this.#prev === nonce + } + + static from( + input: DpopNonceInput = randomBytes(32), + step = DPOP_NONCE_MAX_AGE / 3, + ): DpopNonce { + if (input instanceof DpopNonce) { + return input + } + if (input instanceof Uint8Array) { + return new DpopNonce(input, step) + } + if (typeof input === 'string') { + return new DpopNonce(Buffer.from(input, 'hex'), step) + } + return new DpopNonce(input, step) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/access-denied-error.ts b/packages/oauth/oauth-provider/src/errors/access-denied-error.ts new file mode 100644 index 00000000000..19fd3dadac8 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/access-denied-error.ts @@ -0,0 +1,26 @@ +import { OAuthAuthenticationRequestParameters } from '@atproto/oauth-types' +import { buildErrorPayload } from '../output/build-error-payload.js' +import { OAuthError } from './oauth-error.js' + +export class AccessDeniedError extends OAuthError { + constructor( + public readonly parameters: OAuthAuthenticationRequestParameters, + error_description: string, + error = 'access_denied', + cause?: unknown, + ) { + super(error, error_description, 400, cause) + } + + static from( + parameters: OAuthAuthenticationRequestParameters, + cause?: unknown, + ) { + if (cause && cause instanceof AccessDeniedError) { + return cause + } + + const { error, error_description } = buildErrorPayload(cause) + return new AccessDeniedError(parameters, error_description, error, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/account-selection-required-error.ts b/packages/oauth/oauth-provider/src/errors/account-selection-required-error.ts new file mode 100644 index 00000000000..f7bafb1d102 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/account-selection-required-error.ts @@ -0,0 +1,12 @@ +import { OAuthAuthenticationRequestParameters } from '@atproto/oauth-types' +import { AccessDeniedError } from './access-denied-error.js' + +export class AccountSelectionRequiredError extends AccessDeniedError { + constructor( + parameters: OAuthAuthenticationRequestParameters, + error_description = 'Account selection required', + cause?: unknown, + ) { + super(parameters, error_description, 'account_selection_required', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/consent-required-error.ts b/packages/oauth/oauth-provider/src/errors/consent-required-error.ts new file mode 100644 index 00000000000..8883e863b06 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/consent-required-error.ts @@ -0,0 +1,12 @@ +import { OAuthAuthenticationRequestParameters } from '@atproto/oauth-types' +import { AccessDeniedError } from './access-denied-error.js' + +export class ConsentRequiredError extends AccessDeniedError { + constructor( + parameters: OAuthAuthenticationRequestParameters, + error_description = 'User consent required', + cause?: unknown, + ) { + super(parameters, error_description, 'consent_required', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-authorization-details-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-authorization-details-error.ts new file mode 100644 index 00000000000..89d9b0732dc --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-authorization-details-error.ts @@ -0,0 +1,22 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc9396#section-14.6 | RFC 9396 - OAuth Dynamic Client Registration Metadata Registration Error} + * + * The AS MUST refuse to process any unknown authorization details type or + * authorization details not conforming to the respective type definition. The + * AS MUST abort processing and respond with an error + * invalid_authorization_details to the client if any of the following are true + * of the objects in the authorization_details structure: + * - contains an unknown authorization details type value, + * - is an object of known type but containing unknown fields, + * - contains fields of the wrong type for the authorization details type, + * - contains fields with invalid values for the authorization details type, or + * - is missing required fields for the authorization details type. + */ +export class InvalidAuthorizationDetailsError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_authorization_details', error_description, 400, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-client-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-client-error.ts new file mode 100644 index 00000000000..328d67b2175 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-client-error.ts @@ -0,0 +1,20 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 | RFC6749 - Issuing an Access Token } + * + * Client authentication failed (e.g., unknown client, no client authentication + * included, or unsupported authentication method). The authorization server MAY + * return an HTTP 401 (Unauthorized) status code to indicate which HTTP + * authentication schemes are supported. If the client attempted to + * authenticate via the "Authorization" request header field, the authorization + * server MUST respond with an HTTP 401 (Unauthorized) status code and include + * the "WWW-Authenticate" response header field matching the authentication + * scheme used by the client. + */ +export class InvalidClientError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_client', error_description, 400, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-client-id-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-client-id-error.ts new file mode 100644 index 00000000000..9a400406451 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-client-id-error.ts @@ -0,0 +1,20 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2 | RFC7591 - Client Registration Error Response} + * + * The value of one of the client metadata fields is invalid and the server has + * rejected this request. Note that an authorization server MAY choose to + * substitute a valid value for any requested parameter of a client's metadata. + */ +export class InvalidClientIdError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_client_id', error_description, 400, cause) + } + + static from(err: unknown): InvalidClientIdError { + if (err instanceof InvalidClientIdError) return err + if (err instanceof TypeError) return new InvalidClientIdError(err.message) + return new InvalidClientIdError('Invalid client identifier', err) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-client-metadata-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-client-metadata-error.ts new file mode 100644 index 00000000000..3ba9b204363 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-client-metadata-error.ts @@ -0,0 +1,19 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2 | RFC7591 - Client Registration Error Response} + * + * The value of one of the client metadata fields is invalid and the server has + * rejected this request. Note that an authorization server MAY choose to + * substitute a valid value for any requested parameter of a client's metadata. + */ +export class InvalidClientMetadataError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_client_metadata', error_description, 400, cause) + } + + static from(cause: unknown): InvalidClientMetadataError { + if (cause instanceof InvalidClientMetadataError) return cause + return new InvalidClientMetadataError('Invalid client configuration', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-dpop-key-binding-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-dpop-key-binding-error.ts new file mode 100644 index 00000000000..38cd8bac489 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-dpop-key-binding-error.ts @@ -0,0 +1,21 @@ +import { WWWAuthenticateError } from './www-authenticate-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 | RFC6750 - The WWW-Authenticate Response Header Field} + * + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc9449#name-the-dpop-authentication-sch | RFC9449 - The DPoP Authentication Scheme} + */ +export class InvalidDpopKeyBindingError extends WWWAuthenticateError { + constructor(cause?: unknown) { + const error = 'invalid_token' + const error_description = 'Invalid DPoP key binding' + super( + error, + error_description, + { DPoP: { error, error_description } }, + cause, + ) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-dpop-proof-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-dpop-proof-error.ts new file mode 100644 index 00000000000..d8826fd6a90 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-dpop-proof-error.ts @@ -0,0 +1,13 @@ +import { WWWAuthenticateError } from './www-authenticate-error.js' + +export class InvalidDpopProofError extends WWWAuthenticateError { + constructor(error_description: string, cause?: unknown) { + const error = 'invalid_dpop_proof' + super( + error, + error_description, + { DPoP: { error, error_description } }, + cause, + ) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-grant-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-grant-error.ts new file mode 100644 index 00000000000..a4e3838db38 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-grant-error.ts @@ -0,0 +1,16 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 | RFC6749 - Issuing an Access Token } + * + * The provided authorization grant (e.g., authorization code, resource owner + * credentials) or refresh token is invalid, expired, revoked, does not match + * the redirection URI used in the authorization request, or was issued to + * another client. + */ +export class InvalidGrantError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_grant', error_description, 400, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-parameters-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-parameters-error.ts new file mode 100644 index 00000000000..564e109eb06 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-parameters-error.ts @@ -0,0 +1,12 @@ +import { OAuthAuthenticationRequestParameters } from '@atproto/oauth-types' +import { AccessDeniedError } from './access-denied-error.js' + +export class InvalidParametersError extends AccessDeniedError { + constructor( + parameters: OAuthAuthenticationRequestParameters, + error_description: string, + cause?: unknown, + ) { + super(parameters, error_description, 'invalid_request', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-redirect-uri-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-redirect-uri-error.ts new file mode 100644 index 00000000000..bd10e4e7da3 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-redirect-uri-error.ts @@ -0,0 +1,17 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see {@link https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.2 | RFC7591} + * + * The value of one or more redirection URIs is invalid. + */ +export class InvalidRedirectUriError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_redirect_uri', error_description, 400, cause) + } + + static from(cause?: unknown): InvalidRedirectUriError { + if (cause instanceof InvalidRedirectUriError) return cause + return new InvalidRedirectUriError('Invalid redirect URI', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-request-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-request-error.ts new file mode 100644 index 00000000000..fd1dabbbf95 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-request-error.ts @@ -0,0 +1,30 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 | RFC6749 - Issuing an Access Token } + * + * The request is missing a required parameter, includes an unsupported + * parameter value (other than grant type), repeats a parameter, includes + * multiple credentials, utilizes more than one mechanism for authenticating the + * client, or is otherwise malformed. + * + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 | RFC6749 - Authorization Code Grant, Authorization Request} + * + * The request is missing a required parameter, includes an invalid parameter + * value, includes a parameter more than once, or is otherwise malformed. + * + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 | RFC6750 - The WWW-Authenticate Response Header Field } + * + * The request is missing a required parameter, includes an unsupported + * parameter or parameter value, repeats the same parameter, uses more than one + * method for including an access token, or is otherwise malformed. The resource + * server SHOULD respond with the HTTP 400 (Bad Request) status code. + */ +export class InvalidRequestError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('invalid_request', error_description, 400, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/invalid-token-error.ts b/packages/oauth/oauth-provider/src/errors/invalid-token-error.ts new file mode 100644 index 00000000000..91c973f3cdc --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/invalid-token-error.ts @@ -0,0 +1,59 @@ +import { JwtVerifyError } from '@atproto/jwk' +import { JOSEError } from 'jose/errors' +import { ZodError } from 'zod' + +import { OAuthError } from './oauth-error.js' +import { WWWAuthenticateError } from './www-authenticate-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 | RFC6750 - The WWW-Authenticate Response Header Field } + * + * The access token provided is expired, revoked, malformed, or invalid for + * other reasons. The resource SHOULD respond with the HTTP 401 (Unauthorized) + * status code. The client MAY request a new access token and retry the + * protected resource request. + */ +export class InvalidTokenError extends WWWAuthenticateError { + static from( + err: unknown, + tokenType: string, + fallbackMessage = 'Invalid token', + ): InvalidTokenError { + if (err instanceof InvalidTokenError) { + return err + } + + if (err instanceof OAuthError) { + return new InvalidTokenError(tokenType, err.error_description, err) + } + + if (err instanceof JOSEError) { + return new InvalidTokenError(tokenType, err.message, err) + } + + if (err instanceof JwtVerifyError) { + return new InvalidTokenError(tokenType, err.message, err) + } + + if (err instanceof ZodError) { + return new InvalidTokenError(tokenType, err.message, err) + } + + return new InvalidTokenError(tokenType, fallbackMessage, err) + } + + constructor( + readonly tokenType: string, + error_description: string, + cause?: unknown, + ) { + const error = 'invalid_token' + super( + error, + error_description, + { [tokenType]: { error, error_description } }, + cause, + ) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/login-required-error.ts b/packages/oauth/oauth-provider/src/errors/login-required-error.ts new file mode 100644 index 00000000000..4c2ea382efa --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/login-required-error.ts @@ -0,0 +1,12 @@ +import { OAuthAuthenticationRequestParameters } from '@atproto/oauth-types' +import { AccessDeniedError } from './access-denied-error.js' + +export class LoginRequiredError extends AccessDeniedError { + constructor( + parameters: OAuthAuthenticationRequestParameters, + error_description = 'Login is required', + cause?: unknown, + ) { + super(parameters, error_description, 'login_required', cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/oauth-error.ts b/packages/oauth/oauth-provider/src/errors/oauth-error.ts new file mode 100644 index 00000000000..9b14de44a5b --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/oauth-error.ts @@ -0,0 +1,28 @@ +export class OAuthError extends Error { + public expose: boolean + + constructor( + public readonly error: string, + public readonly error_description: string, + public readonly status = 400, + cause?: unknown, + ) { + super(error_description, { cause }) + + Error.captureStackTrace?.(this, this.constructor) + + this.name = this.constructor.name + this.expose = status < 500 + } + + get statusCode() { + return this.status + } + + toJSON() { + return { + error: this.error, + error_description: this.error_description, + } as const + } +} diff --git a/packages/oauth/oauth-provider/src/errors/unauthorized-client-error.ts b/packages/oauth/oauth-provider/src/errors/unauthorized-client-error.ts new file mode 100644 index 00000000000..6f28e5870ec --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/unauthorized-client-error.ts @@ -0,0 +1,20 @@ +import { OAuthError } from './oauth-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 | RFC6749 - Issuing an Access Token } + * + * The authenticated client is not authorized to use this authorization grant + * type. + * + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 | RFC6749 - Authorization Code Grant, Authorization Request} + * + * The client is not authorized to request an authorization code using this + * method. + */ +export class UnauthorizedClientError extends OAuthError { + constructor(error_description: string, cause?: unknown) { + super('unauthorized_client', error_description, 400, cause) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/use-dpop-nonce-error.ts b/packages/oauth/oauth-provider/src/errors/use-dpop-nonce-error.ts new file mode 100644 index 00000000000..034ace4825e --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/use-dpop-nonce-error.ts @@ -0,0 +1,32 @@ +import { OAuthError } from './oauth-error.js' +import { WWWAuthenticateError } from './www-authenticate-error.js' + +/** + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc9449#section-8 | RFC9449 - Section 8. Authorization Server-Provided Nonce} + */ +export class UseDpopNonceError extends OAuthError { + constructor( + error_description = 'Authorization server requires nonce in DPoP proof', + cause?: unknown, + ) { + super('use_dpop_nonce', error_description, 400, cause) + } + + /** + * Convert this error into an error meant to be used as "Resource + * Server-Provided Nonce" error. + * + * @see + * {@link https://datatracker.ietf.org/doc/html/rfc9449#section-9 | RFC9449 - Section 9. Resource Server-Provided Nonce} + */ + toWwwAuthenticateError(): WWWAuthenticateError { + const { error, error_description } = this + return new WWWAuthenticateError( + error, + error_description, + { DPoP: { error, error_description } }, + this, + ) + } +} diff --git a/packages/oauth/oauth-provider/src/errors/www-authenticate-error.ts b/packages/oauth/oauth-provider/src/errors/www-authenticate-error.ts new file mode 100644 index 00000000000..c781d6212e2 --- /dev/null +++ b/packages/oauth/oauth-provider/src/errors/www-authenticate-error.ts @@ -0,0 +1,65 @@ +import { VERIFY_ALGOS } from '../lib/util/crypto.js' + +import { OAuthError } from './oauth-error.js' + +export type WWWAuthenticateParams = Record +export type WWWAuthenticate = Record + +export class WWWAuthenticateError extends OAuthError { + public readonly wwwAuthenticate: WWWAuthenticate + + constructor( + error: string, + error_description: string, + wwwAuthenticate: WWWAuthenticate, + cause?: unknown, + ) { + super(error, error_description, 401, cause) + + this.wwwAuthenticate = + wwwAuthenticate['DPoP'] != null + ? { + ...wwwAuthenticate, + DPoP: { algs: VERIFY_ALGOS.join(' '), ...wwwAuthenticate['DPoP'] }, + } + : wwwAuthenticate + } + + get wwwAuthenticateHeader() { + return formatWWWAuthenticateHeader(this.wwwAuthenticate) + } +} + +function formatWWWAuthenticateHeader(wwwAuthenticate: WWWAuthenticate): string { + return Object.entries(wwwAuthenticate) + .filter(isWWWAuthenticateEntry) + .map(wwwAuthenticateEntryToString) + .join(', ') +} + +type WWWAuthenticateEntry = [type: string, params: WWWAuthenticateParams] +function isWWWAuthenticateEntry( + entry: [string, unknown], +): entry is WWWAuthenticateEntry { + const [, value] = entry + return value != null && typeof value === 'object' +} + +function wwwAuthenticateEntryToString([type, params]: WWWAuthenticateEntry) { + const paramsEnc = Object.entries(params) + .filter(isParamEntry) + .map(paramEntryToString) + + return paramsEnc.length ? `${type} ${paramsEnc.join(', ')}` : type +} + +type ParamEntry = [name: string, value: string] + +function isParamEntry(entry: [string, unknown]): entry is ParamEntry { + const [, value] = entry + return typeof value === 'string' && value !== '' && !value.includes('"') +} + +function paramEntryToString([name, value]: ParamEntry): string { + return `${name}="${value}"` +} diff --git a/packages/oauth/oauth-provider/src/index.ts b/packages/oauth/oauth-provider/src/index.ts new file mode 100644 index 00000000000..4a152bbb2e2 --- /dev/null +++ b/packages/oauth/oauth-provider/src/index.ts @@ -0,0 +1,15 @@ +// Avoid having to explicitly depend sub dependencies +export * from '@atproto-labs/fetch' +export * from '@atproto-labs/fetch-node' +export * from '@atproto/jwk' +export * from '@atproto/jwk-jose' +export * from '@atproto/oauth-types' + +export * from './constants.js' +export * from './oauth-client.js' +export * from './oauth-dpop.js' +export * from './oauth-errors.js' +export * from './oauth-hooks.js' +export * from './oauth-provider.js' +export * from './oauth-store.js' +export * from './oauth-verifier.js' diff --git a/packages/oauth/oauth-provider/src/lib/html/README.md b/packages/oauth/oauth-provider/src/lib/html/README.md new file mode 100644 index 00000000000..2020c57e13d --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/README.md @@ -0,0 +1,9 @@ +# Safe HTML generation and concatenation utility + +This library provides a safe way to generate and concatenate HTML strings. + +This code _could_ be used as a standalone library, but the Bluesky dev team does +not want to maintain it as such. As it is currently only used by the +`@atproto/oauth-provider` package, it is included here. Future development +should aim to keep this library independent of the rest of the +`@atproto/oauth-provider` package, so that it can be extracted and published. diff --git a/packages/oauth/oauth-provider/src/lib/html/build-document.ts b/packages/oauth/oauth-provider/src/lib/html/build-document.ts new file mode 100644 index 00000000000..f3ddfeb81b2 --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/build-document.ts @@ -0,0 +1,98 @@ +import { HtmlValue } from './escapers.js' +import { Html } from './html.js' +import { html } from './tags.js' + +export type AssetRef = { + url: string + sha256: string +} + +export type Attrs = Record +export type LinkAttrs = { href: string } & Attrs +export type MetaAttrs = + | { name: string; content: string } + | { 'http-equiv': string; content: string } + +const defaultViewport = html`` + +export type BuildDocumentOptions = { + htmlAttrs?: Attrs + base?: URL + meta?: readonly MetaAttrs[] + links?: readonly LinkAttrs[] + head?: HtmlValue + title?: HtmlValue + scripts?: readonly (Html | AssetRef)[] + styles?: readonly (Html | AssetRef)[] + body: HtmlValue + bodyAttrs?: Attrs +} + +export const buildDocument = ({ + htmlAttrs, + head, + title, + body, + bodyAttrs, + base, + meta, + links, + scripts, + styles, +}: BuildDocumentOptions) => html` + + + + ${title && html`${title}`} + ${base && html``} + ${meta?.some(isViewportMeta) ? null : defaultViewport} + ${meta?.map(metaToHtml)} + ${links?.map(linkToHtml)} + ${head} ${styles?.map(styleToHtml)} + + + ${body} ${scripts?.map(scriptToHtml)} + +` + +function isViewportMeta( + attrs: T, +): attrs is T & { name: 'viewport' } { + return 'name' in attrs && attrs.name === 'viewport' +} + +function* linkToHtml(attrs: LinkAttrs) { + yield html`` +} + +function* metaToHtml(attrs: MetaAttrs) { + yield html`` +} + +function* attrsToHtml(attrs?: Attrs) { + if (attrs) { + for (const [name, value] of Object.entries(attrs)) { + if (value == null) continue + else if (value === false) continue + else if (value === true) yield html` ${name}` + else yield html` ${name}="${value}"` + } + } +} + +function* scriptToHtml(script: Html | AssetRef) { + yield script instanceof Html + ? // prettier-ignore + html`` // hash validity requires no space around the content + : html`` +} + +function* styleToHtml(style: Html | AssetRef) { + yield style instanceof Html + ? // prettier-ignore + html`` // hash validity requires no space around the content + : html`` +} diff --git a/packages/oauth/oauth-provider/src/lib/html/escapers.ts b/packages/oauth/oauth-provider/src/lib/html/escapers.ts new file mode 100644 index 00000000000..ec3cb6c546f --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/escapers.ts @@ -0,0 +1,66 @@ +import { Html } from './html.js' +import { NestedIterable, stringReplacer } from './util.js' + +export function* javascriptEscaper(code: string) { + // "" can only appear in javascript strings, so we can safely escape + // the "<" without breaking the javascript. + yield* stringReplacer(code, '', '\\u003c/script>') +} + +export function* jsonEscaper(value: unknown) { + // https://redux.js.org/usage/server-rendering#security-considerations + const json = JSON.stringify(value) + if (json === undefined) throw new TypeError('Cannot serialize to JSON') + // "<" can only appear in JSON strings, so we can safely escape it without + // breaking the JSON. + yield* stringReplacer(json, '<', '\\u003c') +} + +export function* cssEscaper(css: string) { + yield* stringReplacer(css, '', '\\u003c/style>') +} + +export type HtmlVariable = Html | string | number | null | undefined +export type HtmlValue = NestedIterable + +export function* htmlEscaper( + htmlFragments: TemplateStringsArray, + values: readonly HtmlValue[], +): Generator { + for (let i = 0; i < htmlFragments.length; i++) { + yield htmlFragments[i]! + + const value = values[i] + if (value != null) yield* htmlVariableToFragments(value) + } +} + +function* htmlVariableToFragments( + value: HtmlValue, +): Generator { + if (value == null) { + return + } else if (typeof value === 'number') { + yield String(value) + } else if (typeof value === 'string') { + yield encode(value) + } else if (value instanceof Html) { + yield value + } else { + // Will throw if the value is not an iterable + for (const v of value) yield* htmlVariableToFragments(v) + } +} + +const specialCharRegExp = /[<>"'&]/g +const specialCharMap = new Map([ + ['<', '<'], + ['>', '>'], + ['"', '"'], + ["'", '''], + ['&', '&'], +]) +const specialCharMapGet = (c: string) => specialCharMap.get(c)! +function encode(value: string): string { + return value.replace(specialCharRegExp, specialCharMapGet) +} diff --git a/packages/oauth/oauth-provider/src/lib/html/html.ts b/packages/oauth/oauth-provider/src/lib/html/html.ts new file mode 100644 index 00000000000..c490d181dbd --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/html.ts @@ -0,0 +1,61 @@ +import { isString } from './util' + +const symbol = Symbol('Html.dangerouslyCreate') + +/** + * This class represents trusted HTML that can be safely embedded in a web page, + * or used as fragments to build a larger HTML document. + */ +export class Html { + #fragments: Iterable + + private constructor(fragments: Iterable, guard: symbol) { + if (guard !== symbol) { + // Force developers to use `Html.dangerouslyCreate` to create an Html + // instance, to make it clear that the content needs to be trusted. + throw new TypeError( + 'Use Html.dangerouslyCreate() to create an Html instance', + ) + } + + this.#fragments = fragments + } + + toString(): string { + // Lazily compute & join the fragments when they are used, to avoid + // unnecessary intermediate strings when concatenating multiple Html as + // fragments. + if ( + !Array.isArray(this.#fragments) || + this.#fragments.length > 1 || + !this.#fragments.every(isString) + ) { + // Will call `toString` recursively, as well as generating iterator + // results. + const fragment = Array.from(this.#fragments, String).join('') + this.#fragments = [fragment] // Cache result for future calls + return fragment + } + + return this.#fragments.join('') + } + + [Symbol.toPrimitive](hint): string { + switch (hint) { + case 'string': + case 'default': + return this.toString() + default: + throw new TypeError(`Cannot convert Html to a ${hint}`) + } + } + + *[Symbol.iterator](): IterableIterator { + // Using toString() here to use the optimized path for string concatenation + yield this.toString() + } + + static dangerouslyCreate(fragments: Iterable): Html { + return new Html(fragments, symbol) + } +} diff --git a/packages/oauth/oauth-provider/src/lib/html/index.ts b/packages/oauth/oauth-provider/src/lib/html/index.ts new file mode 100644 index 00000000000..bd308d8dd99 --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/index.ts @@ -0,0 +1,5 @@ +export * from './html.js' +export * from './tags.js' + +// Extra util +export * from './build-document.js' diff --git a/packages/oauth/oauth-provider/src/lib/html/tags.ts b/packages/oauth/oauth-provider/src/lib/html/tags.ts new file mode 100644 index 00000000000..bbdcd2027e8 --- /dev/null +++ b/packages/oauth/oauth-provider/src/lib/html/tags.ts @@ -0,0 +1,58 @@ +import { + HtmlValue, + cssEscaper, + htmlEscaper, + javascriptEscaper, + jsonEscaper, +} from './escapers.js' +import { Html } from './html.js' + +export { type HtmlValue } +export const html = ( + tpl: TemplateStringsArray, + ...val: readonly HtmlValue[] +) => + tpl.length === 1 && val.length === 0 + ? // Optimization for static HTML, avoid creating an iterable + Html.dangerouslyCreate(tpl) + : Html.dangerouslyCreate(htmlEscaper(tpl, val)) + +/** + * Escapes code to use as a JavaScript string inside a `