Skip to content

Commit

Permalink
Config: refactor config file loaders (thrasher-corp#577)
Browse files Browse the repository at this point in the history
* Config: fix don't create empty dir when resolving path

* Config: refactor config file loaders

* add a layer of abstraction so that config can be loaded from non-files
* use io.Reader / io.Writer abstraction to separate data operations from
file operations
* remove dryrun option from SaveConfig - now it always saves

* rename read and save methods to mention file operations

* log error when encryption prompt fails

* as the user didn't make a choice, we'd prompt again next time the file
is loaded
* add file.Writer tests
* skip permissions test for windows

* defer creating the writer on save to the last moment

* this avoids truncating file when there is error with password prompt
* add a test

* tests with StdIn cannot run in parallel
  • Loading branch information
Rots authored Nov 4, 2020
1 parent 80bc8c7 commit ee55ae5
Show file tree
Hide file tree
Showing 9 changed files with 390 additions and 121 deletions.
2 changes: 1 addition & 1 deletion cmd/exchange_template/exchange_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func makeExchange(exch *exchange) error {
}

configTestFile.Exchanges = append(configTestFile.Exchanges, newExchConfig)
err = configTestFile.SaveConfig(exchangeConfigPath, false)
err = configTestFile.SaveConfigToFile(exchangeConfigPath)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/exchange_template/exchange_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestNewExchange(t *testing.T) {
t.Fatalf("unable to remove exchange config for %s, manual removal required\n",
testExchangeName)
}
if err := cfg.SaveConfig(exchangeConfigPath, false); err != nil {
if err := cfg.SaveConfigToFile(exchangeConfigPath); err != nil {
t.Fatal(err)
}
}
13 changes: 13 additions & 0 deletions common/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ func Write(file string, data []byte) error {
return ioutil.WriteFile(file, data, 0770)
}

// Writer creates a writer to a file or returns an error if it fails. This
// func also ensures that all files are set to this permission (only rw access
// for the running user and the group the user is a member of)
func Writer(file string) (*os.File, error) {
basePath := filepath.Dir(file)
if !Exists(basePath) {
if err := os.MkdirAll(basePath, 0770); err != nil {
return nil, err
}
}
return os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0770)
}

// Move moves a file from a source path to a destination path
// This must be used across the codebase for compatibility with Docker volumes
// and Golang (fixes Invalid cross-device link when using os.Rename)
Expand Down
89 changes: 89 additions & 0 deletions common/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,92 @@ func TestWriteAsCSV(t *testing.T) {
}
}
}

func TestWriter(t *testing.T) {
type args struct {
file string
}
tmp, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmp)

testData := `data`

tests := []struct {
name string
args args
want *os.File
wantErr bool
}{
{
name: "invalid",
args: args{"//invalid-nofile\\"},
wantErr: true,
},
{
name: "empty",
args: args{""},
wantErr: true,
},
{
name: "relative newfile",
args: args{"newfile"},
},
{
name: "deep file",
args: args{filepath.Join(tmp, "new", "file", "multiple", "sub", "paths")},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
got, err := Writer(tt.args.file)
if err != nil {
if (err != nil) != tt.wantErr {
t.Errorf("Writer() error = %v, wantErr %v", err, tt.wantErr)
}
return
}
defer os.Remove(got.Name())
fileInfo, err := os.Stat(got.Name())
if err != nil {
t.Fatal(err)
}
if !fileInfo.Mode().IsRegular() {
t.Fatalf("Writer() error = expected to get a file %s", got.Name())
}
_, err = got.WriteString(testData)
if err != nil {
t.Fatal(err)
}
err = got.Close()
if err != nil {
t.Fatal(err)
}
if data, err := ioutil.ReadFile(got.Name()); err != nil || string(data) != testData {
t.Errorf("Could not write the file, or contents were wrong: expected = %s, got =%s", testData, string(data))
}
})
}
}

func TestWriterNoPermissionFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skip file permissions")
}
temp, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(temp)
err = os.Chmod(temp, 0555)
if err != nil {
t.Fatal(err)
}
_, err = Writer(filepath.Join(temp, "path", "to", "somefile"))
if err == nil {
t.Error("Expected to fail when no permissions, but writer succeeded")
}
}
168 changes: 108 additions & 60 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package config

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strconv"
Expand Down Expand Up @@ -1545,94 +1547,133 @@ func migrateConfig(configFile, targetDir string) (string, error) {
return target, nil
}

// ReadConfig verifies and checks for encryption and verifies the unencrypted
// file contains JSON.
// Prompts for decryption key, if target file is encrypted
func (c *Config) ReadConfig(configPath string, dryrun bool) error {
// ReadConfigFromFile reads the configuration from the given file
// if target file is encrypted, prompts for encryption key
// Also - if not in dryrun mode - it checks if the configuration needs to be encrypted
// and stores the file as encrypted, if necessary (prompting for enryption key)
func (c *Config) ReadConfigFromFile(configPath string, dryrun bool) error {
defaultPath, _, err := GetFilePath(configPath)
if err != nil {
return err
}

fileData, err := ioutil.ReadFile(defaultPath)
confFile, err := os.Open(defaultPath)
if err != nil {
return err
}
defer confFile.Close()
result, wasEncrypted, err := ReadConfig(confFile, func() ([]byte, error) { return PromptForConfigKey(false) })
if err != nil {
return fmt.Errorf("error reading config %w", err)
}
// Override values in the current config
*c = *result

if !ConfirmECS(fileData) {
err = json.Unmarshal(fileData, c)
if dryrun || wasEncrypted || c.EncryptConfig == fileEncryptionDisabled {
return nil
}

if c.EncryptConfig == fileEncryptionPrompt {
confirm, err := promptForConfigEncryption()
if err != nil {
return err
log.Errorf(log.ConfigMgr, "The encryption prompt failed, ignoring for now, next time we will prompt again. Error: %s\n", err)
return nil
}
if confirm {
c.EncryptConfig = fileEncryptionEnabled
return c.SaveConfigToFile(defaultPath)
}

if c.EncryptConfig == fileEncryptionDisabled {
return nil
c.EncryptConfig = fileEncryptionDisabled
err = c.SaveConfigToFile(defaultPath)
if err != nil {
log.Errorf(log.ConfigMgr, "Cannot save config. Error: %s\n", err)
}
}
return nil
}

if c.EncryptConfig == fileEncryptionPrompt {
confirm, err := promptForConfigEncryption()
if err == nil {
if confirm {
c.EncryptConfig = fileEncryptionEnabled
return c.SaveConfig(defaultPath, dryrun)
}
// ReadConfig verifies and checks for encryption and loads the config from a JSON object.
// Prompts for decryption key, if target data is encrypted.
// Returns the loaded configuration and whether it was encrypted.
func ReadConfig(configReader io.Reader, keyProvider func() ([]byte, error)) (*Config, bool, error) {
reader := bufio.NewReader(configReader)

c.EncryptConfig = fileEncryptionDisabled
err := c.SaveConfig(configPath, dryrun)
if err != nil {
log.Errorf(log.ConfigMgr, "Cannot save config. Error: %s\n", err)
}
}
}
return nil
pref, err := reader.Peek(len(EncryptConfirmString))
if err != nil {
return nil, false, err
}

errCounter := 0
for {
if errCounter >= maxAuthFailures {
return errors.New("failed to decrypt config after 3 attempts")
}
key, err := PromptForConfigKey(false)
if err != nil {
log.Errorf(log.ConfigMgr, "PromptForConfigKey err: %s", err)
errCounter++
continue
}
if !ConfirmECS(pref) {
// Read unencrypted configuration
decoder := json.NewDecoder(reader)
c := &Config{}
err = decoder.Decode(c)
return c, false, err
}

conf, err := readEncryptedConfWithKey(reader, keyProvider)
return conf, true, err
}

var f []byte
f = append(f, fileData...)
data, err := c.decryptConfigData(f, key)
// readEncryptedConf reads encrypted configuration and requests key from provider
func readEncryptedConfWithKey(reader *bufio.Reader, keyProvider func() ([]byte, error)) (*Config, error) {
fileData, err := ioutil.ReadAll(reader)
if err != nil {
return nil, err
}
for errCounter := 0; errCounter < maxAuthFailures; errCounter++ {
key, err := keyProvider()
if err != nil {
log.Errorf(log.ConfigMgr, "decryptConfigData err: %s", err)
errCounter++
log.Errorf(log.ConfigMgr, "PromptForConfigKey err: %s", err)
continue
}

err = json.Unmarshal(data, c)
var c *Config
c, err = readEncryptedConf(bytes.NewReader(fileData), key)
if err != nil {
if errCounter < maxAuthFailures {
log.Error(log.ConfigMgr, "Invalid password.")
}
errCounter++
log.Error(log.ConfigMgr, "Could not decrypt and deserialise data with given key. Invalid password?", err)
continue
}
break
return c, nil
}
return nil
return nil, errors.New("failed to decrypt config after 3 attempts")
}

// SaveConfig saves your configuration to your desired path
// prompts for encryption key, if necessary
func (c *Config) SaveConfig(configPath string, dryrun bool) error {
if dryrun {
return nil
func readEncryptedConf(reader io.Reader, key []byte) (*Config, error) {
c := &Config{}
data, err := c.decryptConfigData(reader, key)
if err != nil {
return nil, err
}

err = json.Unmarshal(data, c)
return c, err
}

// SaveConfigToFile saves your configuration to your desired path as a JSON object.
// The function encrypts the data and prompts for encryption key, if necessary
func (c *Config) SaveConfigToFile(configPath string) error {
defaultPath, _, err := GetFilePath(configPath)
if err != nil {
return err
}
var writer *os.File
provider := func() (io.Writer, error) {
writer, err = file.Writer(defaultPath)
return writer, err
}
defer func() {
if writer != nil {
writer.Close()
}
}()
return c.Save(provider, func() ([]byte, error) { return PromptForConfigKey(true) })
}

// Save saves your configuration to the writer as a JSON object
// with encryption, if configured
// If there is an error when preparing the data to store, the writer is never requested
func (c *Config) Save(writerProvider func() (io.Writer, error), keyProvider func() ([]byte, error)) error {
payload, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
Expand All @@ -1642,7 +1683,7 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error {
// Ensure we have the key from session or from user
if len(c.sessionDK) == 0 {
var key []byte
key, err = PromptForConfigKey(true)
key, err = keyProvider()
if err != nil {
return err
}
Expand All @@ -1658,7 +1699,12 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error {
return err
}
}
return file.Write(defaultPath, payload)
configWriter, err := writerProvider()
if err != nil {
return err
}
_, err = io.Copy(configWriter, bytes.NewReader(payload))
return err
}

// CheckRemoteControlConfig checks to see if the old c.Webserver field is used
Expand Down Expand Up @@ -1759,7 +1805,7 @@ func (c *Config) CheckConfig() error {

// LoadConfig loads your configuration file into your configuration object
func (c *Config) LoadConfig(configPath string, dryrun bool) error {
err := c.ReadConfig(configPath, dryrun)
err := c.ReadConfigFromFile(configPath, dryrun)
if err != nil {
return fmt.Errorf(ErrFailureOpeningConfig, configPath, err)
}
Expand All @@ -1783,9 +1829,11 @@ func (c *Config) UpdateConfig(configPath string, newCfg *Config, dryrun bool) er
c.Webserver = newCfg.Webserver
c.Exchanges = newCfg.Exchanges

err = c.SaveConfig(configPath, dryrun)
if err != nil {
return err
if !dryrun {
err = c.SaveConfigToFile(configPath)
if err != nil {
return err
}
}

return c.LoadConfig(configPath, dryrun)
Expand Down
Loading

0 comments on commit ee55ae5

Please sign in to comment.