Skip to content

Commit

Permalink
credentials: Fix issue with external package usage (thrasher-corp#1250)
Browse files Browse the repository at this point in the history
* credentials: Fix issue with external package usage

* Add shazberterino's suggestion

* credentials: Revert and expand coverage
  • Loading branch information
thrasher- authored Jul 4, 2023
1 parent 5af9b65 commit 1388b17
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 68 deletions.
13 changes: 7 additions & 6 deletions exchanges/account/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ var (

// Credentials define parameters that allow for an authenticated request.
type Credentials struct {
Key string
Secret string
ClientID string // TODO: Implement with exchange orders functionality
PEMKey string
SubAccount string
OneTimePassword string
Key string
Secret string
ClientID string // TODO: Implement with exchange orders functionality
PEMKey string
SubAccount string
OneTimePassword string
SecretBase64Decoded bool
// TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities.
}

Expand Down
8 changes: 6 additions & 2 deletions exchanges/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,15 @@ func (b *Base) VerifyAPICredentials(creds *account.Credentials) error {
return fmt.Errorf("%s %w", b.Name, errRequiresAPIClientID)
}

if b.API.CredentialsValidator.RequiresBase64DecodeSecret && !b.LoadedByConfig {
_, err := crypto.Base64Decode(creds.Secret)
if b.API.CredentialsValidator.RequiresBase64DecodeSecret && !creds.SecretBase64Decoded {
decodedResult, err := crypto.Base64Decode(creds.Secret)
if err != nil {
return fmt.Errorf("%s API secret %w: %s", b.Name, errBase64DecodeFailure, err)
}
creds.Secret = string(decodedResult)
creds.SecretBase64Decoded = true
}

return nil
}

Expand Down Expand Up @@ -218,6 +221,7 @@ func (b *Base) SetCredentials(apiKey, apiSecret, clientID, subaccount, pemKey, o
return
}
b.API.credentials.Secret = string(result)
b.API.credentials.SecretBase64Decoded = true
} else {
b.API.credentials.Secret = apiSecret
}
Expand Down
214 changes: 154 additions & 60 deletions exchanges/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,35 @@ func TestGetCredentials(t *testing.T) {
t.Fatalf("received: %v but expected: %v", err, errRequiresAPISecret)
}

b.API.CredentialsValidator.RequiresBase64DecodeSecret = true
ctx = account.DeployCredentialsToContext(context.Background(), &account.Credentials{
Key: "meow",
Secret: "invalidb64",
})
if _, err = b.GetCredentials(ctx); !errors.Is(err, errBase64DecodeFailure) {
t.Fatalf("received: %v but expected: %v", err, errBase64DecodeFailure)
}

const expectedBase64DecodedOutput = "hello world"
ctx = account.DeployCredentialsToContext(context.Background(), &account.Credentials{
Key: "meow",
Secret: "aGVsbG8gd29ybGQ=",
})
creds, err := b.GetCredentials(ctx)
if !errors.Is(err, nil) {
t.Fatalf("received: %v but expected: %v", err, nil)
}
if creds.Secret != expectedBase64DecodedOutput {
t.Fatalf("received: %v but expected: %v", creds.Secret, expectedBase64DecodedOutput)
}

ctx = context.WithValue(context.Background(), account.ContextCredentialsFlag, "pewpew")
_, err = b.GetCredentials(ctx)
if !errors.Is(err, errContextCredentialsFailure) {
t.Fatalf("received: %v but expected: %v", err, errContextCredentialsFailure)
}

b.API.CredentialsValidator.RequiresBase64DecodeSecret = false
fullCred := &account.Credentials{
Key: "superkey",
Secret: "supersecret",
Expand All @@ -47,7 +70,7 @@ func TestGetCredentials(t *testing.T) {
}

ctx = account.DeployCredentialsToContext(context.Background(), fullCred)
creds, err := b.GetCredentials(ctx)
creds, err = b.GetCredentials(ctx)
if !errors.Is(err, nil) {
t.Fatalf("received: %v but expected: %v", err, nil)
}
Expand Down Expand Up @@ -131,9 +154,13 @@ func TestVerifyAPICredentials(t *testing.T) {
RequiresSecret bool
RequiresClientID bool
RequiresBase64DecodeSecret bool
UseSetCredentials bool
CheckBase64DecodedOutput bool
Expected error
}

const expectedBase64DecodedOutput = "hello world"

testCases := []tester{
// Empty credentials
{Expected: ErrCredentialsAreEmpty},
Expand All @@ -152,31 +179,45 @@ func TestVerifyAPICredentials(t *testing.T) {
// test requires base64 decode secret
{RequiresBase64DecodeSecret: true, RequiresSecret: true, Expected: errRequiresAPISecret, Key: "bruh"},
{RequiresBase64DecodeSecret: true, Secret: "%%", Expected: errBase64DecodeFailure},
{RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ="},
{RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ=", CheckBase64DecodedOutput: true},
{RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ=", UseSetCredentials: true, CheckBase64DecodedOutput: true},
}

setupBase := func(tData *tester) *Base {
b := &Base{}
b.API.SetKey(tData.Key)
b.API.SetSecret(tData.Secret)
b.API.SetClientID(tData.ClientID)
b.API.SetPEMKey(tData.PEMKey)
b.API.CredentialsValidator.RequiresKey = tData.RequiresKey
b.API.CredentialsValidator.RequiresSecret = tData.RequiresSecret
b.API.CredentialsValidator.RequiresPEM = tData.RequiresPEM
b.API.CredentialsValidator.RequiresClientID = tData.RequiresClientID
b.API.CredentialsValidator.RequiresBase64DecodeSecret = tData.RequiresBase64DecodeSecret
b := &Base{
API: API{
CredentialsValidator: CredentialsValidator{
RequiresKey: tData.RequiresKey,
RequiresSecret: tData.RequiresSecret,
RequiresClientID: tData.RequiresClientID,
RequiresPEM: tData.RequiresPEM,
RequiresBase64DecodeSecret: tData.RequiresBase64DecodeSecret,
},
},
}
if tData.UseSetCredentials {
b.SetCredentials(tData.Key, tData.Secret, tData.ClientID, "", tData.PEMKey, "")
} else {
b.API.SetKey(tData.Key)
b.API.SetSecret(tData.Secret)
b.API.SetClientID(tData.ClientID)
b.API.SetPEMKey(tData.PEMKey)
}
return b
}

for x := range testCases {
testData := &testCases[x]
x := x
for x, tc := range testCases {
x, tc := x, tc
t.Run("", func(t *testing.T) {
t.Parallel()
b := setupBase(testData)
if err := b.VerifyAPICredentials(b.API.credentials); !errors.Is(err, testData.Expected) {
t.Errorf("Test %d: expected: %v: got %v", x+1, testData.Expected, err)
b := setupBase(&tc)
if err := b.VerifyAPICredentials(b.API.credentials); !errors.Is(err, tc.Expected) {
t.Errorf("Test %d: expected: %v: got %v", x+1, tc.Expected, err)
}
if tc.CheckBase64DecodedOutput {
if b.API.credentials.Secret != expectedBase64DecodedOutput {
t.Errorf("Test %d: expected: %v: got %v", x+1, expectedBase64DecodedOutput, b.API.credentials.Secret)
}
}
})
}
Expand All @@ -185,50 +226,103 @@ func TestVerifyAPICredentials(t *testing.T) {
func TestCheckCredentials(t *testing.T) {
t.Parallel()

b := Base{
SkipAuthCheck: true,
API: API{credentials: &account.Credentials{}},
}

// Test SkipAuthCheck
err := b.CheckCredentials(&account.Credentials{}, false)
if !errors.Is(err, nil) {
t.Errorf("received '%v' expected '%v'", err, nil)
}

// Test credentials failure
b.SkipAuthCheck = false
b.API.CredentialsValidator.RequiresKey = true
b.API.credentials.OneTimePassword = "wow"
err = b.CheckCredentials(b.API.credentials, false)
if !errors.Is(err, errRequiresAPIKey) {
t.Errorf("received '%v' expected '%v'", err, errRequiresAPIKey)
}
b.API.credentials.OneTimePassword = ""

// Test bot usage with authenticated API support disabled, but with
// valid credentials
b.LoadedByConfig = true
b.API.credentials.Key = "k3y"
err = b.CheckCredentials(b.API.credentials, false)
if !errors.Is(err, ErrAuthenticationSupportNotEnabled) {
t.Errorf("received '%v' expected '%v'", err, ErrAuthenticationSupportNotEnabled)
}

// Test enabled authenticated API support and loaded by config
// but invalid credentials
b.API.AuthenticatedSupport = true
b.API.credentials.Key = ""
err = b.CheckCredentials(b.API.credentials, false)
if !errors.Is(err, ErrCredentialsAreEmpty) {
t.Errorf("received '%v' expected '%v'", err, ErrCredentialsAreEmpty)
testCases := []struct {
name string
base *Base
checkBase64Output bool
expectedErr error
}{
{
name: "Test SkipAuthCheck",
base: &Base{
SkipAuthCheck: true,
API: API{credentials: &account.Credentials{}},
},
expectedErr: nil,
},
{
name: "Test credentials failure",
base: &Base{
API: API{
CredentialsValidator: CredentialsValidator{RequiresKey: true},
credentials: &account.Credentials{OneTimePassword: "wow"},
},
},
expectedErr: errRequiresAPIKey,
},
{
name: "Test exchange usage with authenticated API support disabled, but with valid credentials",
base: &Base{
LoadedByConfig: true,
API: API{
CredentialsValidator: CredentialsValidator{RequiresKey: true},
credentials: &account.Credentials{Key: "k3y"},
},
},
expectedErr: ErrAuthenticationSupportNotEnabled,
},
{
name: "Test enabled authenticated API support and loaded by config but invalid credentials",
base: &Base{
LoadedByConfig: true,
API: API{
AuthenticatedSupport: true,
CredentialsValidator: CredentialsValidator{RequiresKey: true},
credentials: &account.Credentials{},
},
},
expectedErr: ErrCredentialsAreEmpty,
},
{
name: "Test base64 decoded invalid credentials",
base: &Base{
API: API{
CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true},
credentials: &account.Credentials{Secret: "invalid"},
},
},
expectedErr: errBase64DecodeFailure,
},
{
name: "Test base64 decoded valid credentials",
base: &Base{
API: API{
CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true},
credentials: &account.Credentials{Secret: "aGVsbG8gd29ybGQ="},
},
},
checkBase64Output: true,
expectedErr: nil,
},
{
name: "Test valid credentials",
base: &Base{
API: API{
AuthenticatedSupport: true,
CredentialsValidator: CredentialsValidator{RequiresKey: true},
credentials: &account.Credentials{Key: "k3y"},
},
},
expectedErr: nil,
},
}

// Finally a valid one
b.API.credentials.Key = "k3y"
err = b.CheckCredentials(b.API.credentials, false)
if !errors.Is(err, nil) {
t.Errorf("received '%v' expected '%v'", err, nil)
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if err := tc.base.CheckCredentials(tc.base.API.credentials, false); !errors.Is(err, tc.expectedErr) {
t.Errorf("%s: received '%v' but expected '%v'", tc.name, err, tc.expectedErr)
}
if tc.checkBase64Output {
if tc.base.API.credentials.SecretBase64Decoded != true {
t.Errorf("%s: expected secret to be base64 decoded", tc.name)
}
if tc.base.API.credentials.Secret != "hello world" {
t.Errorf("%s: expected %q but received %q", "hello world", tc.name, tc.base.API.credentials.Secret)
}
}
})
}
}

Expand Down

0 comments on commit 1388b17

Please sign in to comment.