blob: 6371f7da24bc42487c0b85d0bfe08ba47327d2cf [file] [log] [blame]
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MockGen generates mock implementations of Go interfaces.
package main
// TODO: This does not support recursive embedded interfaces.
// TODO: This does not support embedding package-local interfaces in a separate file.
import (
	"flag"
	"fmt"
	"go/token"
	"io"
	"log"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"github.com/golang/mock/mockgen/model"
)
// gomockImportPath is the import path of the gomock runtime package. Generate
// adds it to the import set of every generated file.
const gomockImportPath = "github.com/golang/mock/gomock"
// Input selection: a non-empty -source switches mockgen into source mode;
// otherwise reflect mode reads its inputs from the positional arguments.
var source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")

// Output control for the generated mock file.
var (
	destination = flag.String("destination", "", "Output file; defaults to stdout.")
	packageOut  = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.")
	selfPackage = flag.String("self_package", "", "If set, the package this mock will be part of.")
)

// Debugging: when set, main prints the parsed model and generates nothing.
var debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
// main loads the interfaces to mock and writes the generated code.
// With -source set, the interfaces come from that Go source file (source
// mode). Otherwise two positional arguments are required: an import path
// and a comma-separated list of interface names (reflect mode). Output goes
// to -destination, or to stdout when that flag is empty.
func main() {
	flag.Usage = usage
	flag.Parse()

	var pkg *model.Package
	var err error
	if *source != "" {
		pkg, err = ParseFile(*source)
	} else {
		if flag.NArg() != 2 {
			log.Fatal("Expected exactly two arguments")
		}
		pkg, err = Reflect(flag.Arg(0), strings.Split(flag.Arg(1), ","))
	}
	if err != nil {
		log.Fatalf("Loading input failed: %v", err)
	}

	if *debugParser {
		pkg.Print(os.Stdout)
		return
	}

	dst := os.Stdout
	if len(*destination) > 0 {
		f, err := os.Create(*destination)
		if err != nil {
			log.Fatalf("Failed opening destination file: %v", err)
		}
		dst = f
	}

	packageName := *packageOut
	if packageName == "" {
		// pkg.Name in reflect mode is the base name of the import path,
		// which might have characters that are illegal to have in package names.
		packageName = "mock_" + sanitize(pkg.Name)
	}

	g := generator{
		w: dst,
	}
	if *source != "" {
		g.filename = *source
	} else {
		g.srcPackage = flag.Arg(0)
		g.srcInterfaces = flag.Arg(1)
	}
	if err := g.Generate(pkg, packageName); err != nil {
		log.Fatalf("Failed generating mock: %v", err)
	}

	// Close the destination explicitly rather than via defer. Close can
	// report errors from earlier writes (notably on some network
	// filesystems), and a deferred Close would silently discard them,
	// leaving a truncated mock file with a zero exit status.
	// Stdout is left open.
	if dst != os.Stdout {
		if err := dst.Close(); err != nil {
			log.Fatalf("Failed closing destination file: %v", err)
		}
	}
}
func usage() {
io.WriteString(os.Stderr, usageText)
flag.PrintDefaults()
}
// usageText is the explanatory preamble that usage writes to stderr before
// the flag defaults. It describes the two operating modes: source mode,
// selected by -source, and reflect mode, selected by two positional
// arguments.
const usageText = `mockgen has two modes of operation: source and reflect.
Source mode generates mock interfaces from a source file.
It is enabled by using the -source flag. Other flags that
may be useful in this mode are -imports and -aux_files.
Example:
mockgen -source=foo.go [other options]
Reflect mode generates mock interfaces by building a program
that uses reflection to understand interfaces. It is enabled
by passing two non-flag arguments: an import path, and a
comma-separated list of symbols.
Example:
mockgen database/sql/driver Conn,Driver
`
// generator holds the state used while emitting one mock source file: the
// output sink, the current indentation prefix, the provenance shown in the
// file header, and the import-path-to-local-name table built by Generate.
type generator struct {
	w      io.Writer // destination for the generated source text
	indent string    // prefix written before every line, one tab per level

	filename      string // may be empty
	srcPackage    string // may be empty
	srcInterfaces string // may be empty

	packageMap map[string]string // map from import path to package name
}

// p writes a single line to g.w: the current indentation, then format
// expanded with args as by fmt.Fprintf, then a newline. Write errors are
// ignored.
func (g *generator) p(format string, args ...interface{}) {
	line := g.indent + format + "\n"
	fmt.Fprintf(g.w, line, args...)
}

// in indents subsequent lines by one more tab.
func (g *generator) in() {
	g.indent = g.indent + "\t"
}

// out removes one level of indentation. At the top level it does nothing.
func (g *generator) out() {
	if g.indent == "" {
		return
	}
	g.indent = g.indent[:len(g.indent)-1]
}
// removeDot returns s with a single trailing '.' removed, if there is one.
// All other input is returned unchanged.
func removeDot(s string) string {
	return strings.TrimSuffix(s, ".")
}
// sanitize cleans up a string to make a suitable package name.
//
// Every rune that may not appear at its position in a Go identifier is
// replaced by '_'. Letters and '_' may start the name, and letters, digits
// and '_' may follow. Invalid UTF-8 bytes are therefore replaced as well.
// A result that is exactly the blank identifier "_" becomes "x", because
// "_" is not a usable package name. The empty string is returned unchanged.
func sanitize(s string) string {
	// Build the result in a rune slice rather than by repeated string
	// concatenation, which is quadratic in the input length.
	out := make([]rune, 0, len(s))
	for i, r := range s {
		// Each input rune produces exactly one output rune, so i == 0 marks
		// the first rune of the result.
		valid := unicode.IsLetter(r) || r == '_' || (i > 0 && unicode.IsDigit(r))
		if !valid {
			r = '_'
		}
		out = append(out, r)
	}
	t := string(out)
	if t == "_" {
		return "x"
	}
	return t
}
// Generate writes a complete mock source file for pkg to g.w. The file
// contains a header naming the source, a package clause for pkgName, an
// import block, and one mock per interface in pkg.Interfaces.
//
// It also fills g.packageMap (import path -> local package name). The mock
// methods pass that map to the type-string formatter. The first error from
// GenerateMockInterface is returned.
func (g *generator) Generate(pkg *model.Package, pkgName string) error {
	g.p("// Automatically generated by MockGen. DO NOT EDIT!")
	if g.filename != "" {
		g.p("// Source: %v", g.filename)
	} else {
		g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces)
	}
	g.p("")

	// Get all required imports, and generate unique names for them all.
	im := pkg.Imports()
	im[gomockImportPath] = true

	// Map iteration order is random. Process import paths in sorted order so
	// that the local names chosen for colliding packages (e.g. "template"
	// vs "template0") and the order of the emitted import block are the
	// same on every run.
	paths := make([]string, 0, len(im))
	for pth := range im {
		paths = append(paths, pth)
	}
	sort.Strings(paths)

	g.packageMap = make(map[string]string, len(im))
	localNames := make(map[string]bool, len(im))
	for _, pth := range paths {
		base := sanitize(path.Base(pth))
		// Local names for an imported package can usually be the basename of the import path.
		// A couple of situations don't permit that, such as duplicate local names
		// (e.g. importing "html/template" and "text/template"), or where the basename is
		// a keyword (e.g. "foo/case").
		// try base0, base1, ...
		localName := base
		for i := 0; localNames[localName] || token.Lookup(localName).IsKeyword(); i++ {
			localName = base + strconv.Itoa(i)
		}
		g.packageMap[pth] = localName
		localNames[localName] = true
	}

	g.p("package %v", pkgName)
	g.p("")
	g.p("import (")
	g.in()
	for _, pth := range paths {
		// The package the mock itself lives in must not import itself.
		if pth == *selfPackage {
			continue
		}
		g.p("%v %q", g.packageMap[pth], pth)
	}
	for _, dotPath := range pkg.DotImports {
		g.p(". %q", dotPath)
	}
	g.out()
	g.p(")")

	for _, intf := range pkg.Interfaces {
		if err := g.GenerateMockInterface(intf); err != nil {
			return err
		}
	}
	return nil
}
// mockName returns the name of the generated mock type for the interface
// named typeName: the interface name with a "Mock" prefix.
func mockName(typeName string) string {
	const prefix = "Mock"
	return prefix + typeName
}
// GenerateMockInterface emits the mock for intf:
//   - the mock struct, named by mockName(intf.Name);
//   - its unexported recorder type;
//   - the New<mock> constructor;
//   - the EXPECT accessor;
//   - one mock method and one recorder method per interface method.
//
// It returns the first error reported while generating those methods.
func (g *generator) GenerateMockInterface(intf *model.Interface) error {
	mockType := mockName(intf.Name)

	g.p("")
	g.p("// Mock of %v interface", intf.Name)
	g.p("type %v struct {", mockType)
	g.in()
	g.p("ctrl *gomock.Controller")
	g.p("recorder *_%vRecorder", mockType)
	g.out()
	g.p("}")
	g.p("")

	g.p("// Recorder for %v (not exported)", mockType)
	g.p("type _%vRecorder struct {", mockType)
	g.in()
	g.p("mock *%v", mockType)
	g.out()
	g.p("}")
	g.p("")

	// TODO: Re-enable this if we can import the interface reliably.
	//g.p("// Verify that the mock satisfies the interface at compile time.")
	//g.p("var _ %v = (*%v)(nil)", typeName, mockType)
	//g.p("")

	g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
	g.in()
	g.p("mock := &%v{ctrl: ctrl}", mockType)
	g.p("mock.recorder = &_%vRecorder{mock}", mockType)
	g.p("return mock")
	g.out()
	g.p("}")
	g.p("")

	// XXX: possible name collision here if someone has EXPECT in their interface.
	g.p("func (_m *%v) EXPECT() *_%vRecorder {", mockType, mockType)
	g.in()
	g.p("return _m.recorder")
	g.out()
	g.p("}")

	return g.GenerateMockMethods(mockType, intf, *selfPackage)
}

// GenerateMockMethods emits the mock method and the recorder method for every
// method of intf, in the order of intf.Methods. If non-empty, pkgOverride is
// the package in which unqualified types reside. Generation stops at the
// first error, which is returned.
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) error {
	for _, m := range intf.Methods {
		g.p("")
		if err := g.GenerateMockMethod(mockType, m, pkgOverride); err != nil {
			return err
		}
		g.p("")
		if err := g.GenerateMockRecorderMethod(mockType, m); err != nil {
			return err
		}
	}
	return nil
}
// GenerateMockMethod generates a mock method implementation.
// If non-empty, pkgOverride is the package in which unqualified types reside.
//
// The generated method passes every argument to ctrl.Call and converts each
// returned value back to its declared result type.
//
// Because arguments are referenced by name, parameters that are blank ("_")
// or unnamed get the name _paramN.
//
// The generated body introduces its own identifiers: the receiver _m, the
// scratch variables _s and _x, and the results ret and retN. If a parameter
// already uses one of these names, the generated identifier gets extra
// leading underscores, so it can neither redeclare nor shadow a parameter.
// When nothing collides, the output matches the historical form exactly.
//
// The returned error is currently always nil.
//
// TODO(review): a parameter named like an imported package's local name
// (e.g. "time") would presumably still shadow that package inside the
// result type assertions; this is not handled here.
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
	// taken holds every identifier already declared in the generated
	// method's scope. unique returns base, prefixed with as many underscores
	// as needed to make it free, and claims the result.
	taken := make(map[string]bool)
	unique := func(base string) string {
		name := base
		for taken[name] {
			name = "_" + name
		}
		taken[name] = true
		return name
	}
	// paramName picks the identifier for the i'th parameter. A blank name
	// cannot be read, so it is treated like a missing one.
	paramName := func(name string, i int) string {
		if name == "" || name == "_" {
			name = fmt.Sprintf("_param%d", i)
		}
		return unique(name)
	}

	args := make([]string, 0, len(m.In)+1)
	argNames := make([]string, 0, len(m.In)+1)
	for i, p := range m.In {
		name := paramName(p.Name, i)
		args = append(args, name+" "+p.Type.String(g.packageMap, pkgOverride))
		argNames = append(argNames, name)
	}
	if m.Variadic != nil {
		name := paramName(m.Variadic.Name, len(m.In))
		args = append(args, name+" ..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
		argNames = append(argNames, name)
	}
	argString := strings.Join(args, ", ")

	rets := make([]string, len(m.Out))
	for i, p := range m.Out {
		rets[i] = p.Type.String(g.packageMap, pkgOverride)
	}
	retString := strings.Join(rets, ", ")
	if len(rets) > 1 {
		retString = "(" + retString + ")"
	}
	if retString != "" {
		retString = " " + retString
	}

	recv := unique("_m")
	g.p("func (%v *%v) %v(%v)%v {", recv, mockType, m.Name, argString, retString)
	g.in()
	callArgs := strings.Join(argNames, ", ")
	if callArgs != "" {
		callArgs = ", " + callArgs
	}
	if m.Variadic != nil {
		// Non-trivial. The generated code must build a []interface{},
		// but the variadic argument may be any type.
		s := unique("_s")
		x := unique("_x")
		g.p("%s := []interface{}{%s}", s, strings.Join(argNames[:len(argNames)-1], ", "))
		g.p("for _, %s := range %s {", x, argNames[len(argNames)-1])
		g.in()
		g.p("%s = append(%s, %s)", s, s, x)
		g.out()
		g.p("}")
		callArgs = ", " + s + "..."
	}
	if len(m.Out) == 0 {
		g.p(`%v.ctrl.Call(%v, "%v"%v)`, recv, recv, m.Name, callArgs)
	} else {
		ret := unique("ret")
		g.p(`%v := %v.ctrl.Call(%v, "%v"%v)`, ret, recv, recv, m.Name, callArgs)
		// Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
		// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
		// Happily, this coincides with the semantics we want here.
		retNames := make([]string, len(rets))
		for i, t := range rets {
			retNames[i] = unique(fmt.Sprintf("ret%d", i))
			g.p("%s, _ := %s[%d].(%s)", retNames[i], ret, i, t)
		}
		// Pass the names as an argument rather than splicing them into the
		// format string.
		g.p("return %s", strings.Join(retNames, ", "))
	}
	g.out()
	g.p("}")
	return nil
}
// GenerateMockRecorderMethod emits the recorder method for m on the
// _<mockType>Recorder type.
//
// The parameters are named arg0, arg1, ... and are all typed interface{};
// a variadic tail becomes ...interface{}. The generated body forwards them,
// together with the mock and the method name, to ctrl.RecordCall and
// returns the resulting *gomock.Call. The returned error is always nil.
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
	fixed := len(m.In)
	names := make([]string, 0, fixed)
	for i := range m.In {
		names = append(names, "arg"+strconv.Itoa(i))
	}
	joined := strings.Join(names, ", ")
	variadic := m.Variadic != nil
	tail := "arg" + strconv.Itoa(fixed)

	// Parameter list, e.g. "arg0, arg1 interface{}, arg2 ...interface{}".
	params := joined
	if fixed > 0 {
		params += " interface{}"
	}
	if variadic {
		if fixed > 0 {
			params += ", "
		}
		params += tail + " ...interface{}"
	}

	g.p("func (_mr *_%vRecorder) %v(%v) *gomock.Call {", mockType, m.Name, params)
	g.in()
	var forward string
	switch {
	case variadic && fixed == 0:
		// The variadic slice alone can be spread straight through.
		forward = ", " + tail + "..."
	case variadic:
		// Fixed and variadic arguments must be merged into one slice first.
		g.p("_s := append([]interface{}{%s}, arg%d...)", joined, fixed)
		forward = ", _s..."
	case fixed > 0:
		forward = ", " + joined
	}
	g.p(`return _mr.mock.ctrl.RecordCall(_mr.mock, "%v"%v)`, m.Name, forward)
	g.out()
	g.p("}")
	return nil
}