Skip to content

Commit

Permalink
Config/Engine: Add data directory to config.json (thrasher-corp#549)
Browse files Browse the repository at this point in the history
* add data directory to config.json

* fix quality check issues

* adjust data directory only when explicitly set

* unexport ValidateSettings
* process flags earlier so they can also be used when loading config
* fix test depends on flags

* rename config.DataDir to DataDirectory

* also don't omit in JSON if empty

* datadir flag induces dry run

* log warning
* enable parallel for sub tests
* leave data dir empty in example config

* remove parallel for loadConfigWithSettings

* create a new config object instead of using a shared one

* remove a test that potentially reads user file

* rename test methods to MixedCaps

* clean up test dir after engine tests

* use global config variable
  • Loading branch information
Rots authored Sep 17, 2020
1 parent 25466b2 commit a67b5cf
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 16 deletions.
17 changes: 14 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ func (c *Config) CheckLoggerConfig() error {
log.GlobalLogConfig = &c.Logging
log.RWM.Unlock()

logPath := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "logs")
logPath := c.GetDataPath("logs")
err := common.CreateDir(logPath)
if err != nil {
return err
Expand All @@ -1325,7 +1325,7 @@ func (c *Config) checkGCTScriptConfig() error {
c.GCTScript.MaxVirtualMachines = gctscript.DefaultMaxVirtualMachines
}

scriptPath := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "scripts")
scriptPath := c.GetDataPath("scripts")
err := common.CreateDir(scriptPath)
if err != nil {
return err
Expand Down Expand Up @@ -1362,7 +1362,7 @@ func (c *Config) checkDatabaseConfig() error {
}

if c.Database.Driver == database.DBSQLite || c.Database.Driver == database.DBSQLite3 {
databaseDir := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "database")
databaseDir := c.GetDataPath("database")
err := common.CreateDir(databaseDir)
if err != nil {
return err
Expand Down Expand Up @@ -1845,3 +1845,14 @@ func (c *Config) AssetTypeEnabled(a asset.Item, exch string) (bool, error) {
}
return true, nil
}

// GetDataPath gets the data path for the given subpath
func (c *Config) GetDataPath(elem ...string) string {
var baseDir string
if c.DataDirectory != "" {
baseDir = c.DataDirectory
} else {
baseDir = common.GetDefaultDataDir(runtime.GOOS)
}
return filepath.Join(append([]string{baseDir}, elem...)...)
}
42 changes: 42 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package config

import (
"path/filepath"
"runtime"
"strings"
"testing"

Expand Down Expand Up @@ -2034,3 +2036,43 @@ func TestRemoveExchange(t *testing.T) {
t.Fatal("exchange shouldn't exist")
}
}

func TestGetDataPath(t *testing.T) {
tests := []struct {
name string
dir string
elem []string
want string
}{
{
name: "empty",
dir: "",
elem: []string{},
want: common.GetDefaultDataDir(runtime.GOOS),
},
{
name: "empty a b",
dir: "",
elem: []string{"a", "b"},
want: filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "a", "b"),
},
{
name: "target",
dir: "target",
elem: []string{"a", "b"},
want: filepath.Join("target", "a", "b"),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c := &Config{
DataDirectory: tt.dir,
}
if got := c.GetDataPath(tt.elem...); got != tt.want {
t.Errorf("Config.GetDataPath() = %v, want %v", got, tt.want)
}
})
}
}
1 change: 1 addition & 0 deletions config/config_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ var (
// Exchanges
type Config struct {
Name string `json:"name"`
DataDirectory string `json:"dataDirectory"`
EncryptConfig int `json:"encryptConfig"`
GlobalHTTPTimeout time.Duration `json:"globalHTTPTimeout"`
Database database.Config `json:"database"`
Expand Down
1 change: 1 addition & 0 deletions config_example.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"name": "Skynet",
"dataDirectory": "",
"encryptConfig": 0,
"globalHTTPTimeout": 15000000000,
"database": {
Expand Down
47 changes: 34 additions & 13 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,13 @@ func NewFromSettings(settings *Settings) (*Engine, error) {
if settings == nil {
return nil, errors.New("engine: settings is nil")
}
// collect flags
flag.Visit(func(f *flag.Flag) { flagSet[f.Name] = true })

var b Engine
b.Config = &config.Cfg
filePath, err := config.GetFilePath(settings.ConfigFile)
if err != nil {
return nil, err
}
var err error

log.Printf("Loading config file %s..\n", filePath)
err = b.Config.LoadConfig(filePath, settings.EnableDryRun)
b.Config, err = loadConfigWithSettings(settings)
if err != nil {
return nil, fmt.Errorf("failed to load config. Err: %s", err)
}
Expand All @@ -95,23 +92,47 @@ func NewFromSettings(settings *Settings) (*Engine, error) {
gctlog.Infoln(gctlog.Global, "Logger initialised.")
}

b.Settings.ConfigFile = filePath
b.Settings.DataDir = settings.DataDir
b.Settings.ConfigFile = settings.ConfigFile
b.Settings.DataDir = b.Config.GetDataPath()
b.Settings.CheckParamInteraction = settings.CheckParamInteraction

err = utils.AdjustGoMaxProcs(settings.GoMaxProcs)
if err != nil {
return nil, fmt.Errorf("unable to adjust runtime GOMAXPROCS value. Err: %s", err)
}

ValidateSettings(&b, settings)
validateSettings(&b, settings)
return &b, nil
}

// ValidateSettings validates and sets all bot settings
func ValidateSettings(b *Engine, s *Settings) {
flag.Visit(func(f *flag.Flag) { flagSet[f.Name] = true })
// loadConfigWithSettings creates configuration based on the provided settings
func loadConfigWithSettings(settings *Settings) (*config.Config, error) {
filePath, err := config.GetFilePath(settings.ConfigFile)
if err != nil {
return nil, err
}
log.Printf("Loading config file %s..\n", filePath)

conf := &config.Cfg
err = conf.ReadConfig(filePath, settings.EnableDryRun)
if err != nil {
return nil, fmt.Errorf(config.ErrFailureOpeningConfig, filePath, err)
}
// Apply overrides from settings
if flagSet["datadir"] {
// warn if dryrun isn't enabled
if !settings.EnableDryRun {
log.Println("Command line argument '-datadir' induces dry run mode.")
}
settings.EnableDryRun = true
conf.DataDirectory = settings.DataDir
}

return conf, conf.CheckConfig()
}

// validateSettings validates and sets all bot settings
func validateSettings(b *Engine, s *Settings) {
b.Settings.Verbose = s.Verbose
b.Settings.EnableDryRun = s.EnableDryRun
b.Settings.EnableAllExchanges = s.EnableAllExchanges
Expand Down
73 changes: 73 additions & 0 deletions engine/engine_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package engine

import (
"os"
"testing"

"github.com/thrasher-corp/gocryptotrader/config"
)

func TestLoadConfigWithSettings(t *testing.T) {
empty := ""
somePath := "somePath"
// Clean up after the tests
defer os.RemoveAll(somePath)
tests := []struct {
name string
flags []string
settings *Settings
want *string
wantErr bool
}{
{
name: "invalid file",
settings: &Settings{
ConfigFile: "nonExistent.json",
},
wantErr: true,
},
{
name: "test file",
settings: &Settings{
ConfigFile: config.TestFile,
EnableDryRun: true,
},
want: &empty,
wantErr: false,
},
{
name: "data dir in settings overrides config data dir",
flags: []string{"datadir"},
settings: &Settings{
ConfigFile: config.TestFile,
DataDir: somePath,
EnableDryRun: true,
},
want: &somePath,
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
// prepare the 'flags'
flagSet = make(map[string]bool)
for _, v := range tt.flags {
flagSet[v] = true
}
// Run the test
got, err := loadConfigWithSettings(tt.settings)
if (err != nil) != tt.wantErr {
t.Errorf("loadConfigWithSettings() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil || tt.want != nil {
if (got == nil && tt.want != nil) || (got != nil && tt.want == nil) {
t.Errorf("loadConfigWithSettings() = is nil %v, want nil %v", got == nil, tt.want == nil)
} else if got.DataDirectory != *tt.want {
t.Errorf("loadConfigWithSettings() = %v, want %v", got.DataDirectory, *tt.want)
}
}
})
}
}

0 comments on commit a67b5cf

Please sign in to comment.