| package source |
| |
import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"
)
| |
// IsUpdate reports whether golden values should be updated. It returns true
// when the Update variable is true or when the -update command line flag is
// set, which indicates the user running the tests would like to update any
// golden values.
//
// If no -update flag is currently registered on the default flag set, or the
// registered flag does not hold a bool, IsUpdate returns false instead of
// panicking.
func IsUpdate() bool {
	if Update {
		return true
	}

	f := flag.Lookup("update")
	if f == nil {
		// The flag can be missing when the default flag set was replaced
		// after this package registered it.
		return false
	}
	getter, ok := f.Value.(flag.Getter)
	if !ok {
		return false
	}
	// A failed assertion leaves enabled at its zero value, false.
	enabled, _ := getter.Get().(bool)
	return enabled
}

// Update is a shim for testing, and for compatibility with the old -update-golden
// flag.
var Update bool
| |
// init makes sure a boolean -update flag exists on the default flag set.
//
// When no flag named "update" is registered yet, it registers one defaulting
// to false. When another package already registered the name, the existing
// flag is reused provided it exposes a bool through flag.Getter; any other
// kind of flag causes a panic.
func init() {
	existing := flag.Lookup("update")
	if existing == nil {
		flag.Bool("update", false, "update golden values")
		return
	}

	const incompatible = "some other package defined an incompatible -update flag, expected a flag.Bool"

	getter, isGetter := existing.Value.(flag.Getter)
	if !isGetter {
		panic(incompatible)
	}
	if _, isBool := getter.Get().(bool); !isBool {
		panic(incompatible)
	}
}
| |
// ErrNotFound indicates that the variable holding a golden value could not be
// located or is not in a form that can be updated, for example because the
// argument is not a plain identifier, it does not resolve to a var or const
// declaration, or the new value is not a string.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
| |
| // UpdateExpectedValue looks for a package-level variable with a name that |
| // starts with expected in the arguments to the caller. If the variable is |
| // found, the value of the variable will be updated to value of the other |
| // argument to the caller. |
| func UpdateExpectedValue(stackIndex int, x, y interface{}) error { |
| _, filename, line, ok := runtime.Caller(stackIndex + 1) |
| if !ok { |
| return errors.New("failed to get call stack") |
| } |
| debug("call stack position: %s:%d", filename, line) |
| |
| fileset := token.NewFileSet() |
| astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) |
| if err != nil { |
| return fmt.Errorf("failed to parse source file %s: %w", filename, err) |
| } |
| |
| expr, err := getCallExprArgs(fileset, astFile, line) |
| if err != nil { |
| return fmt.Errorf("call from %s:%d: %w", filename, line, err) |
| } |
| |
| if len(expr) < 3 { |
| debug("not enough arguments %d: %v", |
| len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}}) |
| return ErrNotFound |
| } |
| |
| argIndex, ident := getIdentForExpectedValueArg(expr) |
| if argIndex < 0 || ident == nil { |
| debug("no arguments started with the word 'expected': %v", |
| debugFormatNode{Node: &ast.CallExpr{Args: expr}}) |
| return ErrNotFound |
| } |
| |
| value := x |
| if argIndex == 1 { |
| value = y |
| } |
| |
| strValue, ok := value.(string) |
| if !ok { |
| debug("value must be type string, got %T", value) |
| return ErrNotFound |
| } |
| return UpdateVariable(filename, fileset, astFile, ident, strValue) |
| } |
| |
| // UpdateVariable writes to filename the contents of astFile with the value of |
| // the variable updated to value. |
| func UpdateVariable( |
| filename string, |
| fileset *token.FileSet, |
| astFile *ast.File, |
| ident *ast.Ident, |
| value string, |
| ) error { |
| obj := ident.Obj |
| if obj == nil { |
| return ErrNotFound |
| } |
| if obj.Kind != ast.Con && obj.Kind != ast.Var { |
| debug("can only update var and const, found %v", obj.Kind) |
| return ErrNotFound |
| } |
| |
| switch decl := obj.Decl.(type) { |
| case *ast.ValueSpec: |
| if len(decl.Names) != 1 { |
| debug("more than one name in ast.ValueSpec") |
| return ErrNotFound |
| } |
| |
| decl.Values[0] = &ast.BasicLit{ |
| Kind: token.STRING, |
| Value: "`" + value + "`", |
| } |
| |
| case *ast.AssignStmt: |
| if len(decl.Lhs) != 1 { |
| debug("more than one name in ast.AssignStmt") |
| return ErrNotFound |
| } |
| |
| decl.Rhs[0] = &ast.BasicLit{ |
| Kind: token.STRING, |
| Value: "`" + value + "`", |
| } |
| |
| default: |
| debug("can only update *ast.ValueSpec, found %T", obj.Decl) |
| return ErrNotFound |
| } |
| |
| var buf bytes.Buffer |
| if err := format.Node(&buf, fileset, astFile); err != nil { |
| return fmt.Errorf("failed to format file after update: %w", err) |
| } |
| |
| fh, err := os.Create(filename) |
| if err != nil { |
| return fmt.Errorf("failed to open file %v: %w", filename, err) |
| } |
| if _, err = fh.Write(buf.Bytes()); err != nil { |
| return fmt.Errorf("failed to write file %v: %w", filename, err) |
| } |
| if err := fh.Sync(); err != nil { |
| return fmt.Errorf("failed to sync file %v: %w", filename, err) |
| } |
| return nil |
| } |
| |
// getIdentForExpectedValueArg looks at the call arguments at positions 1 and
// 2 and returns the position and identifier of the first one that is a plain
// identifier whose name starts with "expected", compared case-insensitively.
// The argument at position 0 is never considered.
//
// It returns -1 and nil when neither argument qualifies, including when expr
// has fewer than three elements.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	// Bound by len(expr) so short argument lists do not index out of range.
	for i := 1; i < 3 && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if !ok {
			continue
		}
		if strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}