From cbc8666401ded1e1d4638ea218d307eff9c01ae2 Mon Sep 17 00:00:00 2001
From: hunsh <hunsh.cn@gmail.com>
Date: Wed, 19 Jan 2022 13:14:09 +0800
Subject: [PATCH] feat: Optimize 302 processing
 https://github.com/hunshcn/gh-proxy/issues/47

---
 app/main.py | 96 ++++++++++++++++++++++++++++++-----------------------
 index.js    | 51 ++++++++++++++++------------
 2 files changed, 84 insertions(+), 63 deletions(-)

diff --git a/app/main.py b/app/main.py
index f6149829..68046529 100644
--- a/app/main.py
+++ b/app/main.py
@@ -105,30 +105,36 @@ def generate():
     return chunks
 
 
+def check_url(u):
+    for exp in (exp1, exp2, exp3, exp4, exp5):
+        m = exp.match(u)
+        if m:
+            return m
+    return False
+
+
 @app.route('/<path:u>', methods=['GET', 'POST'])
-def proxy(u):
+def handler(u):
     u = u if u.startswith('http') else 'https://' + u
     if u.rfind('://', 3, 9) == -1:
         u = u.replace('s:/', 's://', 1)  # uwsgi会将//传递为/
     pass_by = False
-    for exp in (exp1, exp2, exp3, exp4, exp5):
-        m = exp.match(u)
-        if m:
-            m = tuple(m.groups())
-            if white_list:
-                for i in white_list:
-                    if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]:
-                        break
-                else:
-                    return Response('Forbidden by white list.', status=403)
-            for i in black_list:
-                if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]:
-                    return Response('Forbidden by black list.', status=403)
-            for i in pass_list:
+    m = check_url(u)
+    if m:
+        m = tuple(m.groups())
+        if white_list:
+            for i in white_list:
                 if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]:
-                    pass_by = True
                     break
-            break
+            else:
+                return Response('Forbidden by white list.', status=403)
+        for i in black_list:
+            if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]:
+                return Response('Forbidden by black list.', status=403)
+        for i in pass_list:
+            if m[:len(i)] == i or i[0] == '*' and len(m) == 2 and m[1] == i[1]:
+                pass_by = True
+                break
     else:
         return Response('Invalid input.', status=403)
 
@@ -146,31 +152,39 @@ def proxy(u):
     else:
         if exp2.match(u):
             u = u.replace('/blob/', '/raw/', 1)
-        headers = {}
-        r_headers = dict(request.headers)
-        if 'Host' in r_headers:
-            r_headers.pop('Host')
-        try:
-            url = u + request.url.replace(request.base_url, '', 1)
-            if url.startswith('https:/') and not url.startswith('https://'):
-                url = 'https://' + url[7:]
-            r = requests.request(method=request.method, url=url, data=request.data, headers=r_headers, stream=True)
-            headers = dict(r.headers)
-
-            if 'Content-length' in r.headers and int(r.headers['Content-length']) > size_limit:
-                return redirect(u + request.url.replace(request.base_url, '', 1))
-
-            def generate():
-                for chunk in iter_content(r, chunk_size=CHUNK_SIZE):
-                    yield chunk
-
-            return Response(generate(), headers=headers, status=r.status_code)
-        except Exception as e:
-            headers['content-type'] = 'text/html; charset=UTF-8'
-            return Response('server error ' + str(e), status=500, headers=headers)
-    # else:
-    #     return Response('Illegal input', status=403, mimetype='text/html; charset=UTF-8')
+        return proxy(u)
+
+
+def proxy(u, allow_redirects=False):
+    headers = {}
+    r_headers = dict(request.headers)
+    if 'Host' in r_headers:
+        r_headers.pop('Host')
+    try:
+        url = u + request.url.replace(request.base_url, '', 1)
+        if url.startswith('https:/') and not url.startswith('https://'):
+            url = 'https://' + url[7:]
+        r = requests.request(method=request.method, url=url, data=request.data, headers=r_headers, stream=True, allow_redirects=allow_redirects)
+        headers = dict(r.headers)
+
+        if 'Content-length' in r.headers and int(r.headers['Content-length']) > size_limit:
+            return redirect(u + request.url.replace(request.base_url, '', 1))
+
+        def generate():
+            for chunk in iter_content(r, chunk_size=CHUNK_SIZE):
+                yield chunk
 
+        if 'Location' in r.headers:
+            _location = r.headers.get('Location')
+            if check_url(_location):
+                headers['Location'] = '/' + _location
+            else:
+                return proxy(_location, True)
+
+        return Response(generate(), headers=headers, status=r.status_code)
+    except Exception as e:
+        headers['content-type'] = 'text/html; charset=UTF-8'
+        return Response('server error ' + str(e), status=500, headers=headers)
 
 app.debug = True
 if __name__ == '__main__':
diff --git a/index.js b/index.js
index cad96dea..71338b51 100644
--- a/index.js
+++ b/index.js
@@ -22,6 +22,13 @@ const PREFLIGHT_INIT = {
     }),
 }
 
+
+const exp1 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:releases|archive|suites)\/.*$/i
+const exp2 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:blob|raw)\/.*$/i
+const exp3 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:info|git-).*$/i
+const exp4 = /^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+?\/.+$/i
+const exp5 = /^(?:https?:\/\/)?gist\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+$/i
+
 /**
  * @param {any} body
  * @param {number} status
@@ -52,6 +59,15 @@ addEventListener('fetch', e => {
 })
 
 
+function checkUrl(u) {
+    for (let i of [exp1, exp2, exp3, exp4, exp5, ]) {
+        if (u.search(i) === 0) {
+            return true
+        }
+    }
+    return false
+}
+
 /**
  * @param {FetchEvent} e
  */
@@ -65,18 +81,13 @@ async function fetchHandler(e) {
     }
     // cfworker 会把路径中的 `//` 合并成 `/`
     path = urlObj.href.substr(urlObj.origin.length + PREFIX.length).replace(/^https?:\/+/, 'https://')
-    const exp1 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:releases|archive|suites)\/.*$/i
-    const exp2 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:blob|raw)\/.*$/i
-    const exp3 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:info|git-).*$/i
-    const exp4 = /^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+?\/.+$/i
-    const exp5 = /^(?:https?:\/\/)?gist\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+$/i
     if (path.search(exp1) === 0 || path.search(exp5) === 0 || !Config.cnpmjs && (path.search(exp3) === 0 || path.search(exp4) === 0)) {
         return httpHandler(req, path)
     } else if (path.search(exp2) === 0) {
-        if (Config.jsdelivr){
+        if (Config.jsdelivr) {
             const newUrl = path.replace('/blob/', '@').replace(/^(?:https?:\/\/)?github\.com/, 'https://cdn.jsdelivr.net/gh')
             return Response.redirect(newUrl, 302)
-        }else{
+        } else {
             path = path.replace('/blob/', '/raw/')
             return httpHandler(req, path)
         }
@@ -106,8 +117,6 @@ function httpHandler(req, pathname) {
         return new Response(null, PREFLIGHT_INIT)
     }
 
-    let rawLen = ''
-
     const reqHdrNew = new Headers(reqHdrRaw)
 
     let urlStr = pathname
@@ -120,10 +129,10 @@ function httpHandler(req, pathname) {
     const reqInit = {
         method: req.method,
         headers: reqHdrNew,
-        redirect: 'follow',
+        redirect: 'manual',
         body: req.body
     }
-    return proxy(urlObj, reqInit, rawLen, 0)
+    return proxy(urlObj, reqInit)
 }
 
 
@@ -132,24 +141,22 @@ function httpHandler(req, pathname) {
  * @param {URL} urlObj
  * @param {RequestInit} reqInit
  */
-async function proxy(urlObj, reqInit, rawLen) {
+async function proxy(urlObj, reqInit) {
     const res = await fetch(urlObj.href, reqInit)
     const resHdrOld = res.headers
     const resHdrNew = new Headers(resHdrOld)
 
-    // verify
-    if (rawLen) {
-        const newLen = resHdrOld.get('content-length') || ''
-        const badLen = (rawLen !== newLen)
+    const status = res.status
 
-        if (badLen) {
-            return makeRes(res.body, 400, {
-                '--error': `bad len: ${newLen}, except: ${rawLen}`,
-                'access-control-expose-headers': '--error',
-            })
+    if (resHdrNew.has('location')) {
+        let _location = resHdrNew.get('location')
+        if (checkUrl(_location))
+            resHdrNew.set('location', PREFIX + _location)
+        else {
+            reqInit.redirect = 'follow'
+            return proxy(newUrl(_location), reqInit)
         }
     }
-    const status = res.status
     resHdrNew.set('access-control-expose-headers', '*')
     resHdrNew.set('access-control-allow-origin', '*')