From cbc8666401ded1e1d4638ea218d307eff9c01ae2 Mon Sep 17 00:00:00 2001 From: hunsh <hunsh.cn@gmail.com> Date: Wed, 19 Jan 2022 13:14:09 +0800 Subject: [PATCH] feat: Optimize 302 processing https://github.com/hunshcn/gh-proxy/issues/47 --- app/main.py | 96 ++++++++++++++++++++++++++++++----------------------- index.js | 51 ++++++++++++++++------------ 2 files changed, 84 insertions(+), 63 deletions(-) diff --git a/app/main.py b/app/main.py index f6149829..68046529 100644 --- a/app/main.py +++ b/app/main.py @@ -105,30 +105,36 @@ def generate(): return chunks +def check_url(u): + for exp in (exp1, exp2, exp3, exp4, exp5): + m = exp.match(u) + if m: + return m + return False + + @app.route('/<path:u>', methods=['GET', 'POST']) -def proxy(u): +def handler(u): u = u if u.startswith('http') else 'https://' + u if u.rfind('://', 3, 9) == -1: u = u.replace('s:/', 's://', 1) # uwsgi会将//传递为/ pass_by = False - for exp in (exp1, exp2, exp3, exp4, exp5): - m = exp.match(u) - if m: - m = tuple(m.groups()) - if white_list: - for i in white_list: - if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]: - break - else: - return Response('Forbidden by white list.', status=403) - for i in black_list: - if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]: - return Response('Forbidden by black list.', status=403) - for i in pass_list: + m = check_url(u) + if m: + m = tuple(m.groups()) + if white_list: + for i in white_list: if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]: - pass_by = True break - break + else: + return Response('Forbidden by white list.', status=403) + for i in black_list: + if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]: + return Response('Forbidden by black list.', status=403) + for i in pass_list: + if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]: + pass_by = True + break else: return Response('Invalid input.', status=403) @@ -146,31 +152,39 @@ def proxy(u): else: if exp2.match(u): u = u.replace('/blob/', '/raw/', 1) - headers = {} - r_headers = dict(request.headers) - if 'Host' in r_headers: - r_headers.pop('Host') - try: - url = u + request.url.replace(request.base_url, '', 1) - if url.startswith('https:/') and not url.startswith('https://'): - url = 'https://' + url[7:] - r = requests.request(method=request.method, url=url, data=request.data, headers=r_headers, stream=True) - headers = dict(r.headers) - - if 'Content-length' in r.headers and int(r.headers['Content-length']) > size_limit: - return redirect(u + request.url.replace(request.base_url, '', 1)) - - def generate(): - for chunk in iter_content(r, chunk_size=CHUNK_SIZE): - yield chunk - - return Response(generate(), headers=headers, status=r.status_code) - except Exception as e: - headers['content-type'] = 'text/html; charset=UTF-8' - return Response('server error ' + str(e), status=500, headers=headers) - # else: - # return Response('Illegal input', status=403, mimetype='text/html; charset=UTF-8') + return proxy(u) + + +def proxy(u, allow_redirects=False): + headers = {} + r_headers = dict(request.headers) + if 'Host' in r_headers: + r_headers.pop('Host') + try: + url = u + request.url.replace(request.base_url, '', 1) + if url.startswith('https:/') and not url.startswith('https://'): + url = 'https://' + url[7:] + r = requests.request(method=request.method, url=url, data=request.data, headers=r_headers, stream=True, allow_redirects=allow_redirects) + headers = dict(r.headers) + + if 'Content-length' in r.headers and int(r.headers['Content-length']) > size_limit: + return redirect(u + request.url.replace(request.base_url, '', 1)) + + def generate(): + for chunk in iter_content(r, chunk_size=CHUNK_SIZE): + yield chunk + if 'Location' in r.headers: + _location = r.headers.get('Location') + if check_url(_location): + headers['Location'] = '/' + _location + else: + return proxy(_location, True) + + return Response(generate(), headers=headers, status=r.status_code) + except Exception as e: + headers['content-type'] = 'text/html; charset=UTF-8' + return Response('server error ' + str(e), status=500, headers=headers) app.debug = True if __name__ == '__main__': diff --git a/index.js b/index.js index cad96dea..71338b51 100644 --- a/index.js +++ b/index.js @@ -22,6 +22,13 @@ const PREFLIGHT_INIT = { }), } + +const exp1 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:releases|archive|suites)\/.*$/i +const exp2 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:blob|raw)\/.*$/i +const exp3 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:info|git-).*$/i +const exp4 = /^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+?\/.+$/i +const exp5 = /^(?:https?:\/\/)?gist\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+$/i + /** * @param {any} body * @param {number} status @@ -52,6 +59,15 @@ addEventListener('fetch', e => { }) +function checkUrl(u) { + for (let i of [exp1, exp2, exp3, exp4, exp5, ]) { + if (u.search(i) === 0) { + return true + } + } + return false +} + /** * @param {FetchEvent} e */ @@ -65,18 +81,13 @@ async function fetchHandler(e) { } // cfworker 会把路径中的 `//` 合并成 `/` path = urlObj.href.substr(urlObj.origin.length + PREFIX.length).replace(/^https?:\/+/, 'https://') - const exp1 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:releases|archive|suites)\/.*$/i - const exp2 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:blob|raw)\/.*$/i - const exp3 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:info|git-).*$/i - const exp4 = /^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+?\/.+$/i - const exp5 = /^(?:https?:\/\/)?gist\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+$/i if (path.search(exp1) === 0 || path.search(exp5) === 0 || !Config.cnpmjs && (path.search(exp3) === 0 || path.search(exp4) === 0)) { return httpHandler(req, path) } else if (path.search(exp2) === 0) { - if (Config.jsdelivr){ + if (Config.jsdelivr) { const newUrl = path.replace('/blob/', '@').replace(/^(?:https?:\/\/)?github\.com/, 'https://cdn.jsdelivr.net/gh') return Response.redirect(newUrl, 302) - }else{ + } else { path = path.replace('/blob/', '/raw/') return httpHandler(req, path) } @@ -106,8 +117,6 @@ function httpHandler(req, pathname) { return new Response(null, PREFLIGHT_INIT) } - let rawLen = '' - const reqHdrNew = new Headers(reqHdrRaw) let urlStr = pathname @@ -120,10 +129,10 @@ function httpHandler(req, pathname) { const reqInit = { method: req.method, headers: reqHdrNew, - redirect: 'follow', + redirect: 'manual', body: req.body } - return proxy(urlObj, reqInit, rawLen, 0) + return proxy(urlObj, reqInit) } @@ -132,24 +141,22 @@ function httpHandler(req, pathname) { * @param {URL} urlObj * @param {RequestInit} reqInit */ -async function proxy(urlObj, reqInit, rawLen) { +async function proxy(urlObj, reqInit) { const res = await fetch(urlObj.href, reqInit) const resHdrOld = res.headers const resHdrNew = new Headers(resHdrOld) - // verify - if (rawLen) { - const newLen = resHdrOld.get('content-length') || '' - const badLen = (rawLen !== newLen) + const status = res.status - if (badLen) { - return makeRes(res.body, 400, { - '--error': `bad len: ${newLen}, except: ${rawLen}`, - 'access-control-expose-headers': '--error', - }) + if (resHdrNew.has('location')) { + let _location = resHdrNew.get('location') + if (checkUrl(_location)) + resHdrNew.set('location', PREFIX + _location) + else { + reqInit.redirect = 'follow' + return proxy(newUrl(_location), reqInit) } } - const status = res.status resHdrNew.set('access-control-expose-headers', '*') resHdrNew.set('access-control-allow-origin', '*')