blob: a133f03e54f8f612dbaf553463aca05bee1e4269 [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package main
import (
"bufio"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
"unicode"
"github.com/mattes/migrate"
"github.com/mattes/migrate/database/mysql"
_ "github.com/mattes/migrate/source/file"
"go.chromium.org/luci/common/logging"
)
// inputStream is the reader CreateEmptyMigration reads the migration name
// from. It defaults to standard input; tests replace it with their own reader.
var inputStream = io.Reader(os.Stdin)
// migrateLogger implements the Printf/Verbose logger interface used by the
// 'migrate' package, forwarding its messages to LUCI logging.
type migrateLogger struct {
	ctx context.Context // context whose logger receives the messages
}

// Printf logs a message from 'migrate' at Info level, prefixed with "migrate: ".
func (l migrateLogger) Printf(format string, args ...interface{}) {
	logging.Infof(l.ctx, "migrate: "+format, args...)
}

// Verbose always returns false, i.e. verbose logging is never requested.
func (l migrateLogger) Verbose() bool {
	return false
}
// DefaultMigrationsPath returns the absolute path of the "migrations"
// directory, resolved relative to the current working directory.
//
// It panics if the path cannot be made absolute.
func DefaultMigrationsPath() string {
	abs, err := filepath.Abs("migrations")
	if err == nil {
		return abs
	}
	panic(err)
}
// CreateEmptyMigration asks the user for a migration title and creates two
// empty, appropriately named *.sql files in migrationsPath: one for the 'up'
// migration and one for the 'down' migration.
//
// The title is read as a single line from inputStream and normalized to lower
// snake case: letters are lowercased and every other character that is not a
// letter or digit (spaces, '-', path separators, ...) becomes '_'. An empty
// title is rejected before the sequence number in 'last_version' is bumped.
func CreateEmptyMigration(migrationsPath string) error {
	if err := os.MkdirAll(migrationsPath, 0777); err != nil {
		return err
	}

	// Read one line. io.EOF is tolerated so that a final line without a
	// trailing '\n' is still accepted; any other read error is reported.
	fmt.Print("Enter a name for the new migration:\n> ")
	line, err := bufio.NewReader(inputStream).ReadString('\n')
	if err != nil && err != io.EOF {
		return err
	}

	// Restricting the name to letters, digits and '_' keeps it a single,
	// portable file name component (e.g. "../x" can't escape migrationsPath).
	name := strings.Map(func(r rune) rune {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			return unicode.ToLower(r)
		}
		return '_'
	}, strings.TrimSpace(line))
	if name == "" {
		// Validate before bumping the sequence so a bad input doesn't burn a number.
		return fmt.Errorf("migration name must not be empty")
	}

	// Grab new migration number and put last migration name into the file, so
	// that if multiple CLs with same sequence number are committed, there'll be
	// a merge conflict in 'last_version' file.
	seq, err := bumpSequenceFile(filepath.Join(migrationsPath, "last_version"), name)
	if err != nil {
		return err
	}

	base := filepath.Join(migrationsPath, fmt.Sprintf("%03d_%s", seq, name))
	for _, f := range []string{base + ".up.sql", base + ".down.sql"} {
		fd, err := os.Create(f)
		if err != nil {
			return err
		}
		// Report Close errors instead of silently dropping them.
		if err := fd.Close(); err != nil {
			return err
		}
		fmt.Printf("Created %s\n", f)
	}

	fmt.Println(
		"Populate these files with SQL statements to migrate schema up (for roll-forwards)\n" +
			"and down (for roll-backs). Test locally that migrations apply in both directions!")
	return nil
}
// bumpSequenceFile non-atomically increments the integer in given file and
// returns its new value.
//
// The file holds the sequence number followed by the name of the migration
// that claimed it (written as "<seq> <migration>\n"); only the first
// whitespace-separated field is parsed back. If the file doesn't exist, it is
// created and the integer is set to 1.
//
// "Non-atomically" means the file is read and then rewritten in place with
// ioutil.WriteFile, so concurrent callers can race with each other.
func bumpSequenceFile(path, migration string) (uint64, error) {
	var seq uint64 // stays 0 when the file does not exist yet

	buf, err := ioutil.ReadFile(path)
	switch {
	case os.IsNotExist(err):
		// First migration: start counting from 0 (bumped to 1 below).
	case err != nil:
		return 0, err
	default:
		fields := strings.Fields(string(buf))
		if len(fields) == 0 {
			return 0, fmt.Errorf("malformed last_version file %q", path)
		}
		if seq, err = strconv.ParseUint(fields[0], 10, 32); err != nil {
			// Name the offending file; the bare strconv error doesn't.
			return 0, fmt.Errorf("malformed last_version file %q: %s", path, err)
		}
	}

	seq++
	str := fmt.Sprintf("%d %s\n", seq, migration)
	if err := ioutil.WriteFile(path, []byte(str), 0644); err != nil {
		return 0, err
	}
	return seq, nil
}
// WithMigrate opens a DB connection described by conf and socket, builds a
// migrate.Migrate instance that reads migration files from migrationsPath,
// and invokes body with it. The Migrate instance, its mysql driver and the DB
// are closed (ignoring close errors) after body returns; body's error, if
// any, is returned unchanged.
func WithMigrate(ctx context.Context, migrationsPath string, conf *DBConfig, socket string, body func(m *migrate.Migrate) error) error {
	db, err := OpenDB(ctx, socket, conf, false)
	if err != nil {
		return err
	}
	defer db.Close()

	drv, err := mysql.WithInstance(db, &mysql.Config{})
	if err != nil {
		return err
	}
	defer drv.Close()

	// The "file://" source URL uses forward slashes, hence ToSlash (a no-op
	// on systems whose separator is already '/').
	sourceURL := "file://" + filepath.ToSlash(migrationsPath)
	mig, err := migrate.NewWithDatabaseInstance(sourceURL, "mysql", drv)
	if err != nil {
		return err
	}
	defer mig.Close()

	mig.Log = migrateLogger{ctx: ctx}
	return body(mig)
}
// ReportVersion logs the current schema version, as fetched from the DB
// itself via m.Version().
//
// If migrate reports no version at all (migrate.ErrNilVersion), "none" is
// logged. A dirty version is logged as a warning, and a failure to fetch the
// version is logged as an error. The error is only logged, never printed to
// stdout or returned.
func ReportVersion(ctx context.Context, m *migrate.Migrate) {
	switch ver, dirty, err := m.Version(); {
	case err == migrate.ErrNilVersion:
		logging.Infof(ctx, "Current version: none")
	case err != nil:
		logging.Errorf(ctx, "Current version: unknown (%s)", err)
	case dirty:
		logging.Warningf(ctx, "Current version: %d (dirty!)", ver)
	default:
		logging.Infof(ctx, "Current version: %d", ver)
	}
}