Skip to content

Commit

Permalink
nghttpx: Fix bug that causes 400 response after upgrade failure
Browse files Browse the repository at this point in the history
  • Loading branch information
tatsuhiro-t committed Feb 26, 2023
1 parent 1c62a2a commit bb024e3
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
118 changes: 118 additions & 0 deletions integration-tests/nghttpx_http1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,76 @@ func TestH1H1ReqPhaseReturn(t *testing.T) {
}
}

// TestH1H1ReqPhaseReturnCONNECTMethod tests that mruby request phase
// hook resets llhttp HPE_PAUSED_UPGRADE.
func TestH1H1ReqPhaseReturnCONNECTMethod(t *testing.T) {
opts := options{
args: []string{"--mruby-file=" + testDir + "/req-return.rb"},
handler: func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("request should not be forwarded")
},
}
st := newServerTester(t, opts)
defer st.Close()

if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}

resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}

defer resp.Body.Close()

if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}

hdCheck := func() {
hdtests := []struct {
k, v string
}{
{"content-length", "20"},
{"from", "mruby"},
}

for _, tt := range hdtests {
if got, want := resp.Header.Get(tt.k), tt.v; got != want {
t.Errorf("%v = %v; want %v", tt.k, got, want)
}
}

if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}
}

hdCheck()

if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}

resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}

defer resp.Body.Close()

if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}

hdCheck()

if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}
}

// TestH1H1RespPhaseSetHeader tests mruby response phase hook modifies
// response header fields.
func TestH1H1RespPhaseSetHeader(t *testing.T) {
Expand Down Expand Up @@ -737,6 +807,54 @@ func TestH1H1POSTRequests(t *testing.T) {
}
}

// TestH1H1CONNECTMethodFailure tests that CONNECT method failure
// resets llhttp HPE_PAUSED_UPGRADE.
func TestH1H1CONNECTMethodFailure(t *testing.T) {
opts := options{
handler: func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("required-header") == "" {
w.WriteHeader(http.StatusNotFound)
}
},
}
st := newServerTester(t, opts)
defer st.Close()

if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}

resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}

defer resp.Body.Close()

if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}

if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}

if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\nrequired-header: foo\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}

resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}

defer resp.Body.Close()

if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("status: %v; want %v", got, want)
}
}

// // TestH1H2ConnectFailure tests that server handles the situation that
// // connection attempt to HTTP/2 backend failed.
// func TestH1H2ConnectFailure(t *testing.T) {
Expand Down
9 changes: 9 additions & 0 deletions src/shrpx_https_upstream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,15 @@ int HttpsUpstream::on_read() {
auto htperr = llhttp_execute(&htp_, reinterpret_cast<const char *>(rb->pos()),
rb->rleft());

if (htperr == HPE_PAUSED_UPGRADE &&
rb->pos() ==
reinterpret_cast<const uint8_t *>(llhttp_get_error_pos(&htp_))) {
llhttp_resume_after_upgrade(&htp_);

htperr = llhttp_execute(&htp_, reinterpret_cast<const char *>(rb->pos()),
rb->rleft());
}

auto nread =
htperr == HPE_OK
? rb->rleft()
Expand Down

0 comments on commit bb024e3

Please sign in to comment.