Skip to content

Commit

Permalink
Merge pull request DavidHuie#17 from veqryn/master
Browse files Browse the repository at this point in the history
add a logging interface, allow a user-specified logger
  • Loading branch information
DavidHuie committed Apr 8, 2016
2 parents eb6ae24 + 2de185a commit 1b90a57
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 24 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
*.test
*.test

/.idea
/gomigrate.iml
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ and a directory to migration files, create a migrator:
migrator, _ := gomigrate.NewMigrator(db, gomigrate.Postgres{}, "./migrations")
```

You may also specify a specific logger to use, such as logrus:

```go
migrator, _ := gomigrate.NewMigratorWithLogger(db, gomigrate.Postgres{}, "./migrations", logrus.New())
```

To migrate the database, run:

```go
Expand Down
61 changes: 38 additions & 23 deletions gomigrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
)
Expand All @@ -32,6 +33,14 @@ type Migrator struct {
MigrationsPath string
dbAdapter Migratable
migrations map[uint64]*Migration
logger Logger
}

type Logger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
Println(v ...interface{})
Fatalf(format string, v ...interface{})
}

// Returns true if the migration table already exists.
Expand All @@ -40,45 +49,51 @@ func (m *Migrator) MigrationTableExists() (bool, error) {
var tableName string
err := row.Scan(&tableName)
if err == sql.ErrNoRows {
log.Print("Migrations table not found")
m.logger.Print("Migrations table not found")
return false, nil
}
if err != nil {
log.Printf("Error checking for migration table: %v", err)
m.logger.Printf("Error checking for migration table: %v", err)
return false, err
}
log.Print("Migrations table found")
m.logger.Print("Migrations table found")
return true, nil
}

// Creates the migrations table if it doesn't exist.
func (m *Migrator) CreateMigrationsTable() error {
_, err := m.DB.Exec(m.dbAdapter.CreateMigrationTableSql())
if err != nil {
log.Fatalf("Error creating migrations table: %v", err)
m.logger.Fatalf("Error creating migrations table: %v", err)
}

log.Printf("Created migrations table: %s", migrationTableName)
m.logger.Printf("Created migrations table: %s", migrationTableName)

return nil
}

// Returns a new migrator.
func NewMigrator(db *sql.DB, adapter Migratable, migrationsPath string) (*Migrator, error) {
return NewMigratorWithLogger(db, adapter, migrationsPath, log.New(os.Stderr, "[gomigrate] ", log.LstdFlags))
}

// Returns a new migrator with the specified logger.
func NewMigratorWithLogger(db *sql.DB, adapter Migratable, migrationsPath string, logger Logger) (*Migrator, error) {
// Normalize the migrations path.
path := []byte(migrationsPath)
pathLength := len(path)
if path[pathLength-1] != '/' {
path = append(path, '/')
}

log.Printf("Migrations path: %s", path)
logger.Printf("Migrations path: %s", path)

migrator := Migrator{
db,
string(path),
adapter,
make(map[uint64]*Migration),
logger,
}

// Create the migrations table if it doesn't exist.
Expand Down Expand Up @@ -109,17 +124,17 @@ func (m *Migrator) fetchMigrations() error {

matches, err := filepath.Glob(string(pathGlob))
if err != nil {
log.Fatalf("Error while globbing migrations: %v", err)
m.logger.Fatalf("Error while globbing migrations: %v", err)
}

for _, match := range matches {
num, migrationType, name, err := parseMigrationPath(match)
if err != nil {
log.Printf("Invalid migration file found: %s", match)
m.logger.Printf("Invalid migration file found: %s", match)
continue
}

log.Printf("Migration file found: %s", match)
m.logger.Printf("Migration file found: %s", match)

migration, ok := m.migrations[num]
if !ok {
Expand All @@ -140,12 +155,12 @@ func (m *Migrator) fetchMigrations() error {
if path == "" {
path = migration.DownPath
}
log.Printf("Invalid migration pair for path: %s", path)
m.logger.Printf("Invalid migration pair for path: %s", path)
return InvalidMigrationPair
}
}

log.Printf("Migrations file pairs found: %v", len(m.migrations))
m.logger.Printf("Migrations file pairs found: %v", len(m.migrations))

return nil
}
Expand All @@ -161,7 +176,7 @@ func (m *Migrator) getMigrationStatuses() error {
continue
}
if err != nil {
log.Printf(
m.logger.Printf(
"Error getting migration status for %s: %v",
migration.Name,
err,
Expand Down Expand Up @@ -205,16 +220,16 @@ func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) err
return InvalidMigrationType
}

log.Printf("Applying migration: %s", path)
m.logger.Printf("Applying migration: %s", path)

sql, err := ioutil.ReadFile(path)
if err != nil {
log.Printf("Error reading migration: %s", path)
m.logger.Printf("Error reading migration: %s", path)
return err
}
transaction, err := m.DB.Begin()
if err != nil {
log.Printf("Error opening transaction: %v", err)
m.logger.Printf("Error opening transaction: %v", err)
return err
}

Expand All @@ -225,23 +240,23 @@ func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) err
for _, cmd := range commands {
result, err := transaction.Exec(cmd)
if err != nil {
log.Printf("Error executing migration: %v", err)
m.logger.Printf("Error executing migration: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
m.logger.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
}
if result != nil {
if rowsAffected, err := result.RowsAffected(); err != nil {
log.Printf("Error getting rows affected: %v", err)
m.logger.Printf("Error getting rows affected: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
m.logger.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
} else {
log.Printf("Rows affected: %v", rowsAffected)
m.logger.Printf("Rows affected: %v", rowsAffected)
}
}
}
Expand All @@ -259,17 +274,17 @@ func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) err
)
}
if err != nil {
log.Printf("Error logging migration: %v", err)
m.logger.Printf("Error logging migration: %v", err)
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
log.Printf("Error rolling back transaction: %v", rollbackErr)
m.logger.Printf("Error rolling back transaction: %v", rollbackErr)
return rollbackErr
}
return err
}

// Commit and update the struct status.
if err := transaction.Commit(); err != nil {
log.Printf("Error commiting transaction: %v", err)
m.logger.Printf("Error commiting transaction: %v", err)
return err
}
if mType == upMigration {
Expand Down

0 comments on commit 1b90a57

Please sign in to comment.