Skip to content

Commit

Permalink
Basic x-fowarded-* headers microsoft#13 (microsoft#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher authored May 1, 2020
1 parent 561879f commit ce6b064
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 15 deletions.
17 changes: 12 additions & 5 deletions samples/ReverseProxy.Sample/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ReverseProxy.Core.Middleware;

namespace Microsoft.ReverseProxy.Sample
{
Expand Down Expand Up @@ -47,15 +48,21 @@ public void Configure(IApplicationBuilder app)
endpoints.MapControllers();
endpoints.MapReverseProxy(proxyPipeline =>
{
proxyPipeline.UseProxyLoadBalancing();
// Customize the request before forwarding
// Custom endpoint selection
proxyPipeline.Use((context, next) =>
{
var connection = context.Connection;
context.Request.Headers.AppendCommaSeparatedValues("X-Forwarded-For",
new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort).ToString());
var someCriteria = false; // MeetsCriteria(context);
if (someCriteria)
{
var availableEndpointsFeature = context.Features.Get<IAvailableBackendEndpointsFeature>();
var endpoint = availableEndpointsFeature.Endpoints[0]; // PickEndpoint(availableEndpointsFeature.Endpoints);
// Load balancing will no-op if we've already reduced the list of available endpoints to 1.
availableEndpointsFeature.Endpoints = new[] { endpoint };
}

return next();
});
proxyPipeline.UseProxyLoadBalancing();
});
});
}
Expand Down
22 changes: 18 additions & 4 deletions src/ReverseProxy.Core/Service/Proxy/HttpProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.HttpOverrides;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -163,7 +164,7 @@ private async Task NormalProxyAsync(

// :::::::::::::::::::::::::::::::::::::::::::::
// :: Step 3: Copy request headers Downstream --► Proxy --► Upstream
CopyHeadersToUpstream(context.Request.Headers, upstreamRequest);
CopyHeadersToUpstream(context, upstreamRequest);

// :::::::::::::::::::::::::::::::::::::::::::::
// :: Step 4: Send the outgoing request using HttpClient
Expand All @@ -183,6 +184,7 @@ private async Task NormalProxyAsync(
// cause us to wait forever in step 9, so fail fast here.
if (bodyToUpstreamContent != null && !bodyToUpstreamContent.Started)
{
// TODO: bodyToUpstreamContent is never null. HttpClient might would not need to read the body in some scenarios, such as an early auth failure with Expect: 100-continue.
throw new InvalidOperationException("Proxying the downstream request body to the upstream server hasn't started. This is a coding defect.");
}

Expand Down Expand Up @@ -280,7 +282,7 @@ private async Task UpgradableProxyAsync(

// :::::::::::::::::::::::::::::::::::::::::::::
// :: Step 2: Copy request headers Downstream --► Proxy --► Upstream
CopyHeadersToUpstream(context.Request.Headers, upstreamRequest);
CopyHeadersToUpstream(context, upstreamRequest);

// :::::::::::::::::::::::::::::::::::::::::::::
// :: Step 3: Send the outgoing request using HttpMessageInvoker
Expand Down Expand Up @@ -342,6 +344,7 @@ private async Task UpgradableProxyAsync(
private StreamCopyHttpContent SetupCopyBodyUpstream(Stream source, HttpRequestMessage upstreamRequest, in ProxyTelemetryContext proxyTelemetryContext, bool isStreamingRequest, CancellationToken cancellation)
{
StreamCopyHttpContent contentToUpstream = null;
// TODO: the request body is never null.
if (source != null)
{
////this.logger.LogInformation($" Setting up downstream --> Proxy --> upstream body proxying");
Expand Down Expand Up @@ -373,9 +376,9 @@ private StreamCopyHttpContent SetupCopyBodyUpstream(Stream source, HttpRequestMe
return contentToUpstream;
}

private void CopyHeadersToUpstream(IHeaderDictionary source, HttpRequestMessage destination)
private void CopyHeadersToUpstream(HttpContext context, HttpRequestMessage destination)
{
foreach (var header in source)
foreach (var header in context.Request.Headers)
{
var headerValueCount = header.Value.Count;
if (headerValueCount == 0)
Expand Down Expand Up @@ -417,6 +420,17 @@ private void CopyHeadersToUpstream(IHeaderDictionary source, HttpRequestMessage
}
}
}

// Add common forwarders
// TODO: these need to be customizable
// https://github.com/microsoft/reverse-proxy/issues/13
// https://github.com/microsoft/reverse-proxy/issues/21
destination.Headers.TryAddWithoutValidation(ForwardedHeadersDefaults.XForwardedProtoHeaderName, context.Request.Scheme);
destination.Headers.TryAddWithoutValidation(ForwardedHeadersDefaults.XForwardedHostHeaderName, context.Request.Host.ToString());
if (context.Connection.RemoteIpAddress != null)
{
destination.Headers.TryAddWithoutValidation(ForwardedHeadersDefaults.XForwardedForHeaderName, context.Connection.RemoteIpAddress.ToString());
}
}

private void CopyHeadersToDownstream(HttpResponseMessage source, IHeaderDictionary destination)
Expand Down
81 changes: 75 additions & 6 deletions test/ReverseProxy.Core.Tests/Service/Proxy/HttpProxyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
Expand Down Expand Up @@ -39,14 +40,15 @@ public async Task ProxyAsync_NormalRequest_Works()
// Arrange
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "POST";
httpContext.Request.Scheme = "https";
httpContext.Request.Host = new HostString("example.com");
httpContext.Request.Scheme = "http";
httpContext.Request.Host = new HostString("example.com:3456");
httpContext.Request.Path = "/api/test";
httpContext.Request.QueryString = new QueryString("?a=b&c=d");
httpContext.Request.Headers.Add(":host", "example.com");
httpContext.Request.Headers.Add(":authority", "example.com:3456");
httpContext.Request.Headers.Add("x-ms-request-test", "request");
httpContext.Request.Headers.Add("Content-Language", "requestLanguage");
httpContext.Request.Body = StringToStream("request content");
httpContext.Connection.RemoteIpAddress = IPAddress.Loopback;

var proxyResponseStream = new MemoryStream();
httpContext.Response.Body = proxyResponseStream;
Expand All @@ -62,6 +64,11 @@ public async Task ProxyAsync_NormalRequest_Works()
Assert.Equal(HttpMethod.Post, request.Method);
Assert.Equal(targetUri, request.RequestUri);
Assert.Contains("request", request.Headers.GetValues("x-ms-request-test"));
Assert.Null(request.Headers.Host);
Assert.False(request.Headers.TryGetValues(":authority", out var value));
Assert.Equal("127.0.0.1", request.Headers.GetValues("x-forwarded-for").Single());
Assert.Equal("example.com:3456", request.Headers.GetValues("x-forwarded-host").Single());
Assert.Equal("http", request.Headers.GetValues("x-forwarded-proto").Single());

Assert.NotNull(request.Content);
Assert.Contains("requestLanguage", request.Content.Headers.GetValues("Content-Language"));
Expand Down Expand Up @@ -104,19 +111,76 @@ public async Task ProxyAsync_NormalRequest_Works()
Assert.Equal("response content", proxyResponseText);
}

[Fact]
public async Task ProxyAsync_NormalRequestWithExistingForwarders_Appends()
{
// Arrange
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "GET";
httpContext.Request.Scheme = "http";
httpContext.Request.Host = new HostString("example.com:3456");
httpContext.Request.Path = "/api/test";
httpContext.Request.QueryString = new QueryString("?a=b&c=d");
httpContext.Request.Headers.Add(":authority", "example.com:3456");
httpContext.Request.Headers.Add("x-forwarded-for", "::1");
httpContext.Request.Headers.Add("x-forwarded-proto", "https");
httpContext.Request.Headers.Add("x-forwarded-host", "some.other.host:4567");
httpContext.Connection.RemoteIpAddress = IPAddress.Loopback;

var proxyResponseStream = new MemoryStream();
httpContext.Response.Body = proxyResponseStream;

var targetUri = new Uri("https://localhost:123/a/b/api/test");
var sut = Create<HttpProxy>();
var client = MockHttpHandler.CreateClient(
async (HttpRequestMessage request, CancellationToken cancellationToken) =>
{
await Task.Yield();

Assert.Equal(new Version(2, 0), request.Version);
Assert.Equal(HttpMethod.Get, request.Method);
Assert.Equal(targetUri, request.RequestUri);
Assert.Equal(new[] { "::1", "127.0.0.1" }, request.Headers.GetValues("x-forwarded-for"));
Assert.Equal(new[] { "https", "http" }, request.Headers.GetValues("x-forwarded-proto"));
Assert.Equal(new[] { "some.other.host:4567", "example.com:3456" }, request.Headers.GetValues("x-forwarded-host"));
Assert.Null(request.Headers.Host);
Assert.False(request.Headers.TryGetValues(":authority", out var value));

// The proxy throws if the request body is not read.
await request.Content.CopyToAsync(Stream.Null);

var response = new HttpResponseMessage((HttpStatusCode)234);
return response;
});
var factoryMock = new Mock<IProxyHttpClientFactory>();
factoryMock.Setup(f => f.CreateNormalClient()).Returns(client);

var proxyTelemetryContext = new ProxyTelemetryContext(
backendId: "be1",
routeId: "rt1",
endpointId: "ep1");

// Act
await sut.ProxyAsync(httpContext, targetUri, factoryMock.Object, proxyTelemetryContext, CancellationToken.None, CancellationToken.None);

// Assert
Assert.Equal(234, httpContext.Response.StatusCode);
}

// Tests proxying an upgradable request.
[Fact]
public async Task ProxyAsync_UpgradableRequest_Works()
{
// Arrange
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "GET";
httpContext.Request.Scheme = "https";
httpContext.Request.Host = new HostString("example.com");
httpContext.Request.Scheme = "http";
httpContext.Request.Host = new HostString("example.com:3456");
httpContext.Request.Path = "/api/test";
httpContext.Request.QueryString = new QueryString("?a=b&c=d");
httpContext.Request.Headers.Add(":host", "example.com");
httpContext.Request.Headers.Add(":authority", "example.com:3456");
httpContext.Request.Headers.Add("x-ms-request-test", "request");
httpContext.Connection.RemoteIpAddress = IPAddress.Loopback;

var downstreamStream = new DuplexStream(
readStream: StringToStream("request content"),
Expand All @@ -139,6 +203,11 @@ public async Task ProxyAsync_UpgradableRequest_Works()
Assert.Equal(HttpMethod.Get, request.Method);
Assert.Equal(targetUri, request.RequestUri);
Assert.Contains("request", request.Headers.GetValues("x-ms-request-test"));
Assert.Null(request.Headers.Host);
Assert.False(request.Headers.TryGetValues(":authority", out var value));
Assert.Equal("127.0.0.1", request.Headers.GetValues("x-forwarded-for").Single());
Assert.Equal("example.com:3456", request.Headers.GetValues("x-forwarded-host").Single());
Assert.Equal("http", request.Headers.GetValues("x-forwarded-proto").Single());

Assert.Null(request.Content);

Expand Down

0 comments on commit ce6b064

Please sign in to comment.