Skip to content

Commit

Permalink
engine/gRPC proxy: Fix mux regression and add test coverage (thrasher…
Browse files Browse the repository at this point in the history
…-corp#1456)

* engine/gRPC proxy: Fix mux regression and enhance test coverage

* Use a temp dir for TLS creds and add credentials test tables

* Update GetRPCEndpoints grpcProxyName ListenAddr field

* Log unauthorised access attempts
  • Loading branch information
thrasher- authored Feb 5, 2024
1 parent e0c6e11 commit d57fefb
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 6 deletions.
2 changes: 1 addition & 1 deletion engine/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (bot *Engine) GetRPCEndpoints() (map[string]RPCEndpoint, error) {
},
grpcProxyName: {
Started: bot.Settings.EnableGRPCProxy,
ListenAddr: "http://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
ListenAddr: "https://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
},
DeprecatedName: {
Started: bot.Settings.EnableDeprecatedRPC,
Expand Down
27 changes: 22 additions & 5 deletions engine/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,14 @@ func StartRPCServer(engine *Engine) {

// StartRPCRESTProxy starts a gRPC proxy
func (s *RPCServer) StartRPCRESTProxy() {
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on http://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on https://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)

targetDir := utils.GetTLSDir(s.Settings.DataDir)
creds, err := credentials.NewClientTLSFromFile(filepath.Join(targetDir, "cert.pem"), "")
certFile := filepath.Join(targetDir, "cert.pem")
keyFile := filepath.Join(targetDir, "key.pem")
creds, err := credentials.NewClientTLSFromFile(certFile, "")
if err != nil {
log.Errorf(log.GRPCSys, "Unabled to start gRPC proxy. Err: %s\n", err)
log.Errorf(log.GRPCSys, "Unable to start gRPC proxy. Err: %s\n", err)
return
}

Expand All @@ -200,16 +202,31 @@ func (s *RPCServer) StartRPCRESTProxy() {
Addr: s.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
ReadHeaderTimeout: time.Minute,
ReadTimeout: time.Minute,
Handler: s.authClient(mux),
}

if err = server.ListenAndServe(); err != nil {
log.Errorf(log.GRPCSys, "GRPC proxy failed to server: %s\n", err)
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
log.Errorf(log.GRPCSys, "gRPC proxy server failed to serve: %s\n", err)
return
}
}()

log.Debugln(log.GRPCSys, "gRPC proxy server started!")
}

func (s *RPCServer) authClient(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username != s.Config.RemoteControl.Username || password != s.Config.RemoteControl.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`)
http.Error(w, "Access denied", http.StatusUnauthorized)
log.Warnf(log.GRPCSys, "gRPC proxy server unauthorised access attempt. IP: %s Path: %s\n", r.RemoteAddr, r.URL.Path)
return
}
handler.ServeHTTP(w, r)
})
}

// GetInfo returns info about the current GoCryptoTrader session
func (s *RPCServer) GetInfo(_ context.Context, _ *gctrpc.GetInfoRequest) (*gctrpc.GetInfoResponse, error) {
rpcEndpoints, err := s.getRPCEndpoints()
Expand Down
151 changes: 151 additions & 0 deletions engine/rpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@ package engine

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
"testing"
Expand All @@ -16,6 +24,7 @@ import (
"github.com/gofrs/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/convert"
"github.com/thrasher-corp/gocryptotrader/common/key"
Expand Down Expand Up @@ -4139,3 +4148,145 @@ func TestGetOpenInterest(t *testing.T) {
_, err = s.GetOpenInterest(context.Background(), req)
assert.NoError(t, err)
}

func TestStartRPCRESTProxy(t *testing.T) {
t.Parallel()

tempDir := filepath.Join(os.TempDir(), "gct-grpc-proxy-test")
tempDirTLS := filepath.Join(tempDir, "tls")

t.Cleanup(func() {
assert.NoErrorf(t, os.RemoveAll(tempDir), "RemoveAll should not error, manual directory deletion required for TempDir: %s", tempDir)
})

if !assert.NoError(t, genCert(tempDirTLS), "genCert should not error") {
t.FailNow()
}

gRPCPort := rand.Intn(65535-42069) + 42069 //nolint:gosec // Don't require crypto/rand usage here
gRPCProxyPort := gRPCPort + 1

e := &Engine{
Config: &config.Config{
RemoteControl: config.RemoteControlConfig{
Username: "bobmarley",
Password: "Sup3rdup3rS3cr3t",
GRPC: config.GRPCConfig{
Enabled: true,
ListenAddress: "localhost:" + strconv.Itoa(gRPCPort),
GRPCProxyListenAddress: "localhost:" + strconv.Itoa(gRPCProxyPort),
},
},
},
Settings: Settings{
DataDir: tempDir,
CoreSettings: CoreSettings{EnableGRPCProxy: true},
},
}

fakeTime := time.Now().Add(-time.Hour)
e.uptime = fakeTime

StartRPCServer(e)

// Give the proxy time to start
time.Sleep(time.Millisecond * 500)

certFile := filepath.Join(tempDirTLS, "cert.pem")
caCert, err := os.ReadFile(certFile)
require.NoError(t, err, "ReadFile should not error")
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert)
require.True(t, ok, "AppendCertsFromPEM should return true")
client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}}}

for _, creds := range []struct {
testDescription string
username string
password string
}{
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
{"Valid username but invalid password", "bobmarley", "wrongpass"},
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
} {
creds := creds
t.Run(creds.testDescription, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://localhost:"+strconv.Itoa(gRPCProxyPort)+"/v1/getinfo", http.NoBody)
require.NoError(t, err, "NewRequestWithContext should not error")
req.SetBasicAuth(creds.username, creds.password)
resp, err := client.Do(req)
require.NoError(t, err, "Do should not error")
defer resp.Body.Close()

if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
var info gctrpc.GetInfoResponse
err = json.NewDecoder(resp.Body).Decode(&info)
require.NoError(t, err, "Decode should not error")

uptimeDuration, err := time.ParseDuration(info.Uptime)
require.NoError(t, err, "ParseDuration should not error")
assert.InDelta(t, time.Since(fakeTime).Seconds(), uptimeDuration.Seconds(), 1.0, "Uptime should be within 1 second of the expected duration")
} else {
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err, "ReadAll should not error")
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "HTTP status code should be 401")
assert.Equal(t, "Access denied\n", string(respBody), "Response body should be 'Access denied\n'")
}
})
}
}

func TestRPCProxyAuthClient(t *testing.T) {
t.Parallel()

s := new(RPCServer)
s.Engine = &Engine{
Config: &config.Config{
RemoteControl: config.RemoteControlConfig{
Username: "bobmarley",
Password: "Sup3rdup3rS3cr3t",
},
},
}

dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("MEOW"))
require.NoError(t, err, "Write should not error")
})

handler := s.authClient(dummyHandler)

for _, creds := range []struct {
testDescription string
username string
password string
}{
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
{"Valid username but invalid password", "bobmarley", "wrongpass"},
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
} {
creds := creds
t.Run(creds.testDescription, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", http.NoBody)
require.NoError(t, err, "NewRequestWithContext should not error")
req.SetBasicAuth(creds.username, creds.password)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
assert.Equal(t, http.StatusOK, rr.Code, "HTTP status code should be 200")
assert.Equal(t, "MEOW", rr.Body.String(), "Response body should be 'MEOW'")
} else {
assert.Equal(t, http.StatusUnauthorized, rr.Code, "HTTP status code should be 401")
assert.Equal(t, "Access denied\n", rr.Body.String(), "Response body should be 'Access denied\n'")
}
})
}
}

0 comments on commit d57fefb

Please sign in to comment.