blob: 47d3796ca705b6d7ead2cfd59a13ec2de9010631 [file] [log] [blame]
package sqlmock
import (
"fmt"
"regexp"
"strings"
)
var re = regexp.MustCompile("\\s+")
// strip out new lines and trim spaces
func stripQuery(q string) (s string) {
return strings.TrimSpace(re.ReplaceAllString(q, " "))
}
// QueryMatcher is an SQL query string matcher interface,
// which can be used to customize validation of SQL query strings.
// As an example, external library could be used to build
// and validate SQL ast, columns selected.
//
// sqlmock can be customized to implement a different QueryMatcher
// configured through an option when sqlmock.New or sqlmock.NewWithDSN
// is called, default QueryMatcher is QueryMatcherRegexp.
type QueryMatcher interface {
// Match expected SQL query string without whitespace to
// actual SQL.
Match(expectedSQL, actualSQL string) error
}
// QueryMatcherFunc type is an adapter to allow the use of
// ordinary functions as QueryMatcher. If f is a function
// with the appropriate signature, QueryMatcherFunc(f) is a
// QueryMatcher that calls f.
type QueryMatcherFunc func(expectedSQL, actualSQL string) error
// Match implements the QueryMatcher
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
return f(expectedSQL, actualSQL)
}
// QueryMatcherRegexp is the default SQL query matcher
// used by sqlmock. It parses expectedSQL to a regular
// expression and attempts to match actualSQL.
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
expect := stripQuery(expectedSQL)
actual := stripQuery(actualSQL)
re, err := regexp.Compile(expect)
if err != nil {
return err
}
if !re.MatchString(actual) {
return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String())
}
return nil
})
// QueryMatcherEqual is the SQL query matcher
// which simply tries a case sensitive match of
// expected and actual SQL strings without whitespace.
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
expect := stripQuery(expectedSQL)
actual := stripQuery(actualSQL)
if actual != expect {
return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
}
return nil
})