Skip to content

Commit

Permalink
Restrict metadata headers in error propagation (connectrpc#711)
Browse files Browse the repository at this point in the history
This PR addresses issues when propagating errors from a client back to a
handler. On the client side connect errors will contain all response
headers: transport (`Content-Type`, `Content-Length`, etc), protocol and
application headers. These could break the transport when trying to
re-encode the error or leak sensitive information between services. For
any wire errors (errors decoded from a client response) we now disable
meta propagation. For other errors we now also restrict the headers
propagated.
  • Loading branch information
emcfarlane authored Mar 19, 2024
1 parent fbcf0ff commit 7b3b344
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 10 deletions.
132 changes: 132 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"connectrpc.com/connect/internal/memhttp/memhttptest"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/wrapperspb"
)

const errorMessage = "oh no"
Expand Down Expand Up @@ -542,6 +543,137 @@ func TestConcurrentStreams(t *testing.T) {
done.Wait()
}

func TestErrorHeaderPropagation(t *testing.T) {
t.Parallel()
newError := func(testname string, isWire bool) *connect.Error {
err := connect.NewError(connect.CodeInvalidArgument, errors.New(testname))
if isWire {
err = connect.NewWireError(connect.CodeInvalidArgument, errors.New(testname))
}
msgDetail := &wrapperspb.StringValue{Value: "server details"}
errDetail, derr := connect.NewErrorDetail(msgDetail)
if assert.Nil(t, derr) {
err.AddDetail(errDetail)
}
err.Meta().Set("Content-Length", "1337")
err.Meta().Set("Content-Type", "application/xml")
err.Meta().Set("Accept-Encoding", "bogus")
err.Meta().Set("Date", "Thu, 01 Jan 1970 00:00:00 GMT")
err.Meta().Set("Grpc-Status", "0")
// Set custom headers.
err.Meta().Set("X-Test", testname)
err.Meta()["x-test-case"] = []string{testname}
return err
}
pingServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
return nil, newError(request.Header().Get("X-Test"), request.Header().Get("X-Test-Is-Wire") == "true")
},
cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error {
return newError(stream.RequestHeader().Get("X-Test"), stream.RequestHeader().Get("X-Test-Is-Wire") == "true")
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)

assertError := func(t *testing.T, err error, allowCustomHeaders bool) {
t.Helper()
var connectErr *connect.Error
if !assert.True(t, errors.As(err, &connectErr)) {
return
}
assert.Equal(t, connectErr.Code(), connect.CodeInvalidArgument)
assert.Equal(t, connectErr.Message(), t.Name())
details := connectErr.Details()
if assert.Equal(t, len(details), 1) {
detailMsg, err := details[0].Value()
if !assert.Nil(t, err) {
return
}
serverDetails, ok := detailMsg.(*wrapperspb.StringValue)
if !assert.True(t, ok) {
return
}
assert.Equal(t, serverDetails.Value, "server details")
}
meta := connectErr.Meta()
assert.NotEqual(t, meta.Values("Content-Length"), []string{"1337"})
assert.NotEqual(t, meta.Values("Accept-Encoding"), []string{"bogus"})
assert.NotEqual(t, meta.Values("Content-Type"), []string{"application/xml"})
assert.NotEqual(t, meta.Values("Content-Length"), []string{"1337"})
assert.NotEqual(t, meta.Values("Date"), []string{"Thu, 01 Jan 1970 00:00:00 GMT"})
if allowCustomHeaders {
assert.Equal(t, meta.Values("x-test-case"), []string{t.Name()})
assert.Equal(t, meta.Values("X-Test"), []string{t.Name()})
} else {
assert.Equal(t, meta.Values("x-test-case"), []string(nil))
assert.Equal(t, meta.Values("X-Test"), []string(nil))
}
}
testServices := func(t *testing.T, client pingv1connect.PingServiceClient) {
t.Helper()
t.Run("unary", func(t *testing.T) {
request := connect.NewRequest(&pingv1.PingRequest{})
request.Header().Set("X-Test", t.Name())
_, err := client.Ping(context.Background(), request)
if !assert.NotNil(t, err) {
return
}
assertError(t, err, true /* allowCustomHeaders */)
t.Run("wire", func(t *testing.T) {
request := connect.NewRequest(&pingv1.PingRequest{})
request.Header().Set("X-Test", t.Name())
request.Header().Set("X-Test-Is-Wire", "true")
_, err := client.Ping(context.Background(), request)
if !assert.NotNil(t, err) {
return
}
assertError(t, err, false /* allowCustomHeaders */)
})
})
t.Run("bidi", func(t *testing.T) {
stream := client.CumSum(context.Background())
stream.RequestHeader().Set("X-Test", t.Name())
if err := stream.Send(nil); err != nil {
t.Fatal(err)
}
_, err := stream.Receive()
if !assert.NotNil(t, err) {
return
}
assertError(t, err, true /* allowCustomHeaders */)
t.Run("wire", func(t *testing.T) {
stream := client.CumSum(context.Background())
stream.RequestHeader().Set("X-Test", t.Name())
stream.RequestHeader().Set("X-Test-Is-Wire", "true")
if err := stream.Send(nil); err != nil {
t.Fatal(err)
}
_, err := stream.Receive()
if !assert.NotNil(t, err) {
return
}
})
})
}
t.Run("connect", func(t *testing.T) {
t.Parallel()
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL())
testServices(t, client)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC())
testServices(t, client)
})
t.Run("grpc-web", func(t *testing.T) {
t.Parallel()
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb())
testServices(t, client)
})
}

func TestHeaderBasic(t *testing.T) {
t.Parallel()
const (
Expand Down
9 changes: 9 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ func NewWireError(c Code, underlying error) *Error {
// Clients may find this useful when deciding how to propagate errors. For
// example, an RPC-to-HTTP proxy might expose a server-sent CodeUnknown as an
// HTTP 500 but a client-synthesized CodeUnknown as a 503.
//
// Handlers will strip [Error.Meta] headers propagated from wire errors to avoid
// leaking response headers. To propagate headers recreate the error as a
// non-wire error.
func IsWireError(err error) bool {
se := new(Error)
if !errors.As(err, &se) {
Expand Down Expand Up @@ -229,6 +233,11 @@ func (e *Error) AddDetail(d *ErrorDetail) {
// or a block of in-body metadata, depending on the protocol in use and whether
// or not the handler has already written messages to the stream.
//
// Protocol-specific headers and trailers may be removed to avoid breaking
// protocol semantics. For example, Content-Length and Content-Type headers
// won't be propagated. See the documentation for each protocol for more
// datails.
//
// When clients receive errors, the metadata contains the union of the HTTP
// headers and the protocol-specific trailers (either HTTP trailers or in-body
// metadata).
Expand Down
4 changes: 2 additions & 2 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request,
}

func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error {
if connectErr, ok := asError(err); ok {
mergeHeaders(response.Header(), connectErr.meta)
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(response.Header(), connectErr.meta)
}
response.WriteHeader(connectCodeToHTTP(CodeOf(err)))
data, marshalErr := json.Marshal(newConnectWireError(err))
Expand Down
40 changes: 40 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,46 @@ func mergeHeaders(into, from http.Header) {
}
}

// mergeMetdataHeaders merges the metadata headers from the "from" header into
// the "into" header. It skips over non metadata headers that should not be
// propagated from the server to the client.
func mergeMetadataHeaders(into, from http.Header) {
for key, vals := range from {
if len(vals) == 0 {
// For response trailers, net/http will pre-populate entries
// with nil values based on the "Trailer" header. But if there
// are no actual values for those keys, we skip them.
continue
}
switch http.CanonicalHeaderKey(key) {
case headerContentType,
headerContentLength,
headerContentEncoding,
headerHost,
headerUserAgent,
headerTrailer,
headerDate:
// HTTP headers.
case connectUnaryHeaderAcceptCompression,
connectUnaryTrailerPrefix,
connectStreamingHeaderCompression,
connectStreamingHeaderAcceptCompression,
connectHeaderTimeout,
connectHeaderProtocolVersion:
// Connect headers.
case grpcHeaderCompression,
grpcHeaderAcceptCompression,
grpcHeaderTimeout,
grpcHeaderStatus,
grpcHeaderMessage,
grpcHeaderDetails:
// gRPC headers.
default:
into[key] = append(into[key], vals...)
}
}
}

// getHeaderCanonical is a shortcut for Header.Get() which
// bypasses the CanonicalMIMEHeaderKey operation when we
// know the key is already in canonical form.
Expand Down
1 change: 1 addition & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const (
headerHost = "Host"
headerUserAgent = "User-Agent"
headerTrailer = "Trailer"
headerDate = "Date"

discardLimit = 1024 * 1024 * 4 // 4MiB
)
Expand Down
13 changes: 7 additions & 6 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,12 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
cc.compressionPools.CommaSeparatedNames(),
)
}
cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression)
if response.StatusCode != http.StatusOK {
unmarshaler := connectUnaryUnmarshaler{
ctx: cc.unmarshaler.ctx,
reader: response.Body,
compressionPool: cc.compressionPools.Get(compression),
compressionPool: cc.unmarshaler.compressionPool,
bufferPool: cc.bufferPool,
}
var wireErr connectWireError
Expand All @@ -559,7 +561,6 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
mergeHeaders(serverErr.meta, cc.responseTrailer)
return serverErr
}
cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression)
return nil
}

Expand Down Expand Up @@ -765,8 +766,8 @@ func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) {
header[headerVary] = append(header[headerVary], connectUnaryHeaderAcceptCompression)
}
if err != nil {
if connectErr, ok := asError(err); ok {
mergeHeaders(header, connectErr.meta)
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(header, connectErr.meta)
}
}
for k, v := range hc.responseTrailer {
Expand Down Expand Up @@ -850,8 +851,8 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea
end := &connectEndStreamMessage{Trailer: trailer}
if err != nil {
end.Error = newConnectWireError(err)
if connectErr, ok := asError(err); ok {
mergeHeaders(end.Trailer, connectErr.meta)
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(end.Trailer, connectErr.meta)
}
}
data, marshalErr := json.Marshal(end)
Expand Down
4 changes: 2 additions & 2 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,8 +859,8 @@ func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) {
)
return
}
if connectErr, ok := asError(err); ok {
mergeHeaders(trailer, connectErr.meta)
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(trailer, connectErr.meta)
}
setHeaderCanonical(trailer, grpcHeaderStatus, code)
setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(status.GetMessage()))
Expand Down

0 comments on commit 7b3b344

Please sign in to comment.