package cli import ( "errors" "fmt" "os" "path/filepath" "strconv" "strings" "time" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( errInvalidSequenceWidth = errors.New("Digits must be positive") errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") errInvalidTimeFormat = errors.New("Time format may not be empty") ) func nextSeqVersion(matches []string, seqDigits int) (string, error) { if seqDigits <= 0 { return "", errInvalidSequenceWidth } nextSeq := uint64(1) if len(matches) > 0 { filename := matches[len(matches)-1] matchSeqStr := filepath.Base(filename) idx := strings.Index(matchSeqStr, "_") if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit return "", fmt.Errorf("Malformed migration filename: %s", filename) } var err error matchSeqStr = matchSeqStr[0:idx] nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) if err != nil { return "", err } nextSeq++ } version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) if len(version) > seqDigits { return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) } return version, nil } func timeVersion(startTime time.Time, format string) (version string, err error) { switch format { case "": err = errInvalidTimeFormat case "unix": version = strconv.FormatInt(startTime.Unix(), 10) case "unixNano": version = strconv.FormatInt(startTime.UnixNano(), 10) default: version = startTime.Format(format) } return } // createCmd (meant to be called via a CLI command) creates a new migration func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error { if seq && format != defaultTimeFormat { return errIncompatibleSeqAndFormat } var version string var err error dir = filepath.Clean(dir) ext = "." + strings.TrimPrefix(ext, ".") if seq { matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) if err != nil { return err } version, err = nextSeqVersion(matches, seqDigits) if err != nil { return err } } else { version, err = timeVersion(startTime, format) if err != nil { return err } } versionGlob := filepath.Join(dir, version+"_*"+ext) matches, err := filepath.Glob(versionGlob) if err != nil { return err } if len(matches) > 0 { return fmt.Errorf("duplicate migration version: %s", version) } if err = os.MkdirAll(dir, os.ModePerm); err != nil { return err } for _, direction := range []string{"up", "down"} { basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) filename := filepath.Join(dir, basename) if err = createFile(filename); err != nil { return err } if print { absPath, _ := filepath.Abs(filename) log.Println(absPath) } } return nil } func createFile(filename string) error { // create exclusive (fails if file already exists) // os.Create() specifies 0666 as the FileMode, so we're doing the same f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) if err != nil { return err } return f.Close() } func gotoCmd(m *migrate.Migrate, v uint) error { if err := m.Migrate(v); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } return nil } func upCmd(m *migrate.Migrate, limit int) error { if limit >= 0 { if err := m.Steps(limit); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } else { if err := m.Up(); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } return nil } func downCmd(m *migrate.Migrate, limit int) error { if limit >= 0 { if err := m.Steps(-limit); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } else { if err := m.Down(); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } return nil } func dropCmd(m *migrate.Migrate) error { if err := m.Drop(); err != nil { return err } return nil } func forceCmd(m *migrate.Migrate, v int) error { if err := m.Force(v); err != nil { return err } return nil } func versionCmd(m *migrate.Migrate) error { v, dirty, err := m.Version() if err != nil { return err } if dirty { log.Printf("%v (dirty)\n", v) } else { log.Println(v) } return nil } // numDownMigrationsFromArgs returns an int for number of migrations to apply // and a bool indicating if we need a confirm before applying func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { if applyAll { if len(args) > 0 { return 0, false, errors.New("-all cannot be used with other arguments") } return -1, false, nil } switch len(args) { case 0: return -1, true, nil case 1: downValue := args[0] n, err := strconv.ParseUint(downValue, 10, 64) if err != nil { return 0, false, errors.New("can't read limit argument N") } return int(n), false, nil default: return 0, false, errors.New("too many arguments") } }