Skip to content

Commit

Permalink
Fix go migration up
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Sep 29, 2016
1 parent 96680a8 commit f49670b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 144 deletions.
18 changes: 7 additions & 11 deletions down.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
package goose

import (
"database/sql"
"fmt"
)
import "database/sql"

func Down(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
if err != nil {
return err
}

previous, err := GetPreviousDBVersion(dir, current)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
if err != nil {
if err == ErrNoPreviousVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
}
return err
}
return err
}
migrations.Sort(false) // descending, Next will be Previous

previous, err := migrations.Next(current)
if err != nil {
return err
}

Expand Down
7 changes: 3 additions & 4 deletions goose.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ func checkVersionDuplicates(dir string) error {
return err
}

// Try sorting all migrations, so we get panic on any duplicates.
ms := migrationSorter(migrations)
ms.Sort(true)
ms.Sort(false)
// try both directions
migrations.Sort(false)
migrations.Sort(true)
return nil
}

Expand Down
169 changes: 51 additions & 118 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"sort"
Expand All @@ -18,7 +17,10 @@ import (
var (
ErrNoPreviousVersion = errors.New("no previous version found")
ErrNoNextVersion = errors.New("no next version found")
goMigrations []*Migration

MaxVersion = 9223372036854775807 // max(int64)

goMigrations []*Migration
)

type MigrationRecord struct {
Expand All @@ -36,18 +38,59 @@ type Migration struct {
Down func(*sql.Tx) error // Down go migration function
}

type migrationSorter []*Migration
type Migrations []*Migration

// helpers so we can use pkg sort
func (ms migrationSorter) Len() int { return len(ms) }
func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms migrationSorter) Less(i, j int) bool {
func (ms Migrations) Len() int { return len(ms) }
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms Migrations) Less(i, j int) bool {
if ms[i].Version == ms[j].Version {
log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source)
}
return ms[i].Version < ms[j].Version
}

func (ms Migrations) Sort(up bool) {

// sort ascending or descending by version
if up {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}

// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}

func (ms Migrations) Last() (int64, error) {
if len(ms) == 0 {
return -1, ErrNoNextVersion
}

return ms[len(ms)-1].Version, nil
}

func (ms Migrations) Next(current int64) (int64, error) {
exceptLast := ms[:len(ms)-1]

for i, migration := range exceptLast {
if migration.Version == current {
return ms[i+1].Version, nil
}
}

return -1, ErrNoNextVersion
}

func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
v, _ := NumericComponent(filename)
Expand Down Expand Up @@ -77,7 +120,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
return nil
}

ms := migrationSorter(migrations)
ms := Migrations(migrations)
direction := current < target
ms.Sort(direction)

Expand Down Expand Up @@ -122,7 +165,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) {

// collect all the valid looking migration scripts in the
// migrations folder and go func registry, and key them by version
func CollectMigrations(dirpath string, current, target int64) (m []*Migration, err error) {
func CollectMigrations(dirpath string, current, target int64) (m Migrations, err error) {

// extract the numeric component of each migration,
// filter out any uninteresting files,
Expand Down Expand Up @@ -169,27 +212,6 @@ func versionFilter(v, current, target int64) bool {
return false
}

func (ms migrationSorter) Sort(direction bool) {

// sort ascending or descending by version
if direction {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}

// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}

// look for migration scripts with names in the form:
// XXX_descriptivename.ext
// where XXX specifies the version number
Expand Down Expand Up @@ -298,95 +320,6 @@ func GetDBVersion(db *sql.DB) (int64, error) {
return version, nil
}

func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) {

previous = -1
sawGivenVersion := false

filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {

if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > previous && v < version {
previous = v
}
if v == version {
sawGivenVersion = true
}
}
}

return nil
})

if previous == -1 {
if sawGivenVersion {
// the given version is (likely) valid but we didn't find
// anything before it.
// 'previous' must reflect that no migrations have been applied.
previous = 0
} else {
err = ErrNoPreviousVersion
}
}

return
}

func GetNextDBVersion(dirpath string, version int64) (next int64, err error) {

next = 9223372036854775807 // max(int64)

filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {

if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v < next && v > version {
next = v
}
}
}

return nil
})

if next == 9223372036854775807 {
next = version
err = ErrNoNextVersion
}

return
}

// helper to identify the most recent possible version
// within a folder of migration scripts
func GetMostRecentDBVersion(dirpath string) (version int64, err error) {

version = -1

filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if walkerr != nil {
return walkerr
}

if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > version {
version = v
}
}
}

return nil
})

if version == -1 {
err = errors.New("no valid version found")
}

return
}

func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {

if migrationType != "go" && migrationType != "sql" {
Expand Down
6 changes: 3 additions & 3 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func newMigration(v int64, src string) *Migration {

func TestMigrationMapSortUp(t *testing.T) {

ms := migrationSorter{}
ms := Migrations{}

// insert in any order
ms = append(ms, newMigration(20120000, "test"))
Expand All @@ -27,7 +27,7 @@ func TestMigrationMapSortUp(t *testing.T) {

func TestMigrationMapSortDown(t *testing.T) {

ms := migrationSorter{}
ms := Migrations{}

// insert in any order
ms = append(ms, newMigration(20120000, "test"))
Expand All @@ -42,7 +42,7 @@ func TestMigrationMapSortDown(t *testing.T) {
validateMigrationSort(t, ms, sorted)
}

func validateMigrationSort(t *testing.T, ms migrationSorter, sorted []int64) {
func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) {

for i, m := range ms {
if sorted[i] != m.Version {
Expand Down
8 changes: 7 additions & 1 deletion redo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ func Redo(db *sql.DB, dir string) error {
return err
}

previous, err := GetPreviousDBVersion(dir, current)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(false) // descending, Next will be Previous

previous, err := migrations.Next(current)
if err != nil {
return err
}
Expand Down
8 changes: 3 additions & 5 deletions status.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ func Status(db *sql.DB, dir string) error {
if err != nil {
return err
}

ms := migrationSorter(migrations)
ms.Sort(true)
migrations.Sort(true)

// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
Expand All @@ -26,8 +24,8 @@ func Status(db *sql.DB, dir string) error {
fmt.Println("goose: status")
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, m := range ms {
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
for _, migration := range migrations {
printMigrationStatus(db, migration.Version, filepath.Base(migration.Source))
}

return nil
Expand Down
16 changes: 14 additions & 2 deletions up.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import (
)

func Up(db *sql.DB, dir string) error {
target, err := GetMostRecentDBVersion(dir)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)

target, err := migrations.Last()
if err != nil {
return err
}
Expand All @@ -18,12 +24,18 @@ func Up(db *sql.DB, dir string) error {
}

func UpByOne(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)

current, err := GetDBVersion(db)
if err != nil {
return err
}

next, err := GetNextDBVersion(dir, current)
next, err := migrations.Next(current)
if err != nil {
if err == ErrNoNextVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
Expand Down

0 comments on commit f49670b

Please sign in to comment.