Merge pull request #429 from snigle/reflectx
fix reflectx dominant field
diff --git a/.gitignore b/.gitignore
index 529841c..b2be23c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@
# Folders
_obj
_test
+.idea
# Architecture specific extensions/prefixes
*.[568vq]
diff --git a/.travis.yml b/.travis.yml
index 6bc68d6..d728152 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -18,9 +18,9 @@
# go versions to test
go:
- - "1.8"
- - "1.9"
- "1.10.x"
+ - "1.11.x"
+ - "1.12.x"
# run tests w/ coverage
script:
diff --git a/README.md b/README.md
index c0db7f7..0d71592 100644
--- a/README.md
+++ b/README.md
@@ -15,28 +15,34 @@
* `Get` and `Select` to go quickly from query to struct/slice
In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx),
-there is also some [standard documentation](http://jmoiron.github.io/sqlx/) that
+there is also some [user documentation](http://jmoiron.github.io/sqlx/) that
explains how to use `database/sql` along with sqlx.
## Recent Changes
-* sqlx/types.JsonText has been renamed to JSONText to follow Go naming conventions.
+1.3.0:
-This breaks backwards compatibility, but it's in a way that is trivially fixable
-(`s/JsonText/JSONText/g`). The `types` package is both experimental and not in
-active development currently.
+* `sqlx.DB.Connx(context.Context) (*sqlx.Conn, error)`
+* `sqlx.BindDriver(driverName, bindType)`
+* support for `[]map[string]interface{}` to do "batch" insertions
+* allocation & perf improvements for `sqlx.In`
-* Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905).
+DB.Connx returns an `sqlx.Conn`, which is an `sql.Conn`-alike consistent with
+sqlx's wrapping of other types.
+
+`BindDriver` allows users to control the bindvars that sqlx will use for drivers,
+and add new drivers at runtime. This results in a very slight performance hit
+when resolving the driver into a bind type (~40ns per call), but it allows users
+to specify what bindtype their driver uses even when sqlx has not been updated
+to know about it by default.
### Backwards Compatibility
-There is no Go1-like promise of absolute stability, but I take the issue seriously
-and will maintain the library in a compatible state unless vital bugs prevent me
-from doing so. Since [#59](https://github.com/jmoiron/sqlx/issues/59) and
-[#60](https://github.com/jmoiron/sqlx/issues/60) necessitated breaking behavior,
-a wider API cleanup was done at the time of fixing. It's possible this will happen
-in future; if it does, a git tag will be provided for users requiring the old
-behavior to continue to use it until such a time as they can migrate.
+Compatibility with the most recent two versions of Go is a requirement for any
+new changes. Compatibility beyond that is not guaranteed.
+
+Versioning is done with Go modules. Breaking changes (e.g. removing deprecated APIs)
+will get major version number bumps.
## install
@@ -100,7 +106,7 @@
}
func main() {
- // this Pings the database trying to connect, panics on error
+ // this Pings the database trying to connect
// use sqlx.Open() for sql.Open() semantics
db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable")
if err != nil {
@@ -180,6 +186,28 @@
// as the name -> db mapping, so struct fields are lowercased and the `db` tag
// is taken into consideration.
rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason)
+
+
+ // batch insert
+
+ // batch insert with structs
+ personStructs := []Person{
+ {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"},
+ {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"},
+ {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"},
+ }
+
+ _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
+ VALUES (:first_name, :last_name, :email)`, personStructs)
+
+ // batch insert with maps
+ personMaps := []map[string]interface{}{
+ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"},
+ {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"},
+ {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"},
+ }
+
+ _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
+ VALUES (:first_name, :last_name, :email)`, personMaps)
}
```
-
diff --git a/bind.go b/bind.go
index 0fdc443..e521503 100644
--- a/bind.go
+++ b/bind.go
@@ -2,10 +2,12 @@
import (
"bytes"
+ "database/sql/driver"
"errors"
"reflect"
"strconv"
"strings"
+ "sync"
"github.com/jmoiron/sqlx/reflectx"
)
@@ -16,21 +18,39 @@
QUESTION
DOLLAR
NAMED
+ AT
)
+var defaultBinds = map[int][]string{
+ DOLLAR: []string{"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres"},
+ QUESTION: []string{"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
+ NAMED: []string{"oci8", "ora", "goracle"},
+ AT: []string{"sqlserver"},
+}
+
+var binds sync.Map
+
+func init() {
+ for bind, drivers := range defaultBinds {
+ for _, driver := range drivers {
+ BindDriver(driver, bind)
+ }
+ }
+
+}
+
// BindType returns the bindtype for a given database given a drivername.
func BindType(driverName string) int {
- switch driverName {
- case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
- return DOLLAR
- case "mysql":
- return QUESTION
- case "sqlite3":
- return QUESTION
- case "oci8", "ora", "goracle":
- return NAMED
+ itype, ok := binds.Load(driverName)
+ if !ok {
+ return UNKNOWN
}
- return UNKNOWN
+ return itype.(int)
+}
+
+// BindDriver sets the BindType for driverName to bindType.
+func BindDriver(driverName string, bindType int) {
+ binds.Store(driverName, bindType)
}
// FIXME: this should be able to be tolerant of escaped ?'s in queries without
@@ -56,6 +76,8 @@
rqb = append(rqb, '$')
case NAMED:
rqb = append(rqb, ':', 'a', 'r', 'g')
+ case AT:
+ rqb = append(rqb, '@', 'p')
}
j++
@@ -92,6 +114,28 @@
return rqb.String()
}
+func asSliceForIn(i interface{}) (v reflect.Value, ok bool) {
+ if i == nil {
+ return reflect.Value{}, false
+ }
+
+ v = reflect.ValueOf(i)
+ t := reflectx.Deref(v.Type())
+
+ // Only expand slices
+ if t.Kind() != reflect.Slice {
+ return reflect.Value{}, false
+ }
+
+ // []byte is a driver.Value type so it should not be expanded
+ if t == reflect.TypeOf([]byte{}) {
+ return reflect.Value{}, false
+
+ }
+
+ return v, true
+}
+
// In expands slice values in args, returning the modified query string
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
@@ -107,14 +151,25 @@
var flatArgsCount int
var anySlices bool
- meta := make([]argMeta, len(args))
+ var stackMeta [32]argMeta
+
+ var meta []argMeta
+ if len(args) <= len(stackMeta) {
+ meta = stackMeta[:len(args)]
+ } else {
+ meta = make([]argMeta, len(args))
+ }
for i, arg := range args {
- v := reflect.ValueOf(arg)
- t := reflectx.Deref(v.Type())
+ if a, ok := arg.(driver.Valuer); ok {
+ var err error
+ arg, err = a.Value()
+ if err != nil {
+ return "", nil, err
+ }
+ }
- // []byte is a driver.Value type so it should not be expanded
- if t.Kind() == reflect.Slice && t != reflect.TypeOf([]byte{}) {
+ if v, ok := asSliceForIn(arg); ok {
meta[i].length = v.Len()
meta[i].v = v
@@ -137,7 +192,9 @@
}
newArgs := make([]interface{}, 0, flatArgsCount)
- buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount))
+
+ var buf strings.Builder
+ buf.Grow(len(query) + len(", ?")*flatArgsCount)
var arg, offset int
@@ -192,11 +249,11 @@
args = append(args, val...)
case []int:
for i := range val {
- args = append(args, val[i])
+ args = append(args, &val[i])
}
case []string:
for i := range val {
- args = append(args, val[i])
+ args = append(args, &val[i])
}
default:
for si := 0; si < vlen; si++ {
diff --git a/bind_test.go b/bind_test.go
new file mode 100644
index 0000000..dfa590e
--- /dev/null
+++ b/bind_test.go
@@ -0,0 +1,79 @@
+package sqlx
+
+import (
+ "math/rand"
+ "testing"
+)
+
+func oldBindType(driverName string) int {
+ switch driverName {
+ case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql":
+ return DOLLAR
+ case "mysql":
+ return QUESTION
+ case "sqlite3":
+ return QUESTION
+ case "oci8", "ora", "goracle", "godror":
+ return NAMED
+ case "sqlserver":
+ return AT
+ }
+ return UNKNOWN
+}
+
+/*
+sync.Map implementation:
+
+goos: linux
+goarch: amd64
+pkg: github.com/jmoiron/sqlx
+BenchmarkBindSpeed/old-4 100000000 11.0 ns/op
+BenchmarkBindSpeed/new-4 24575726 50.8 ns/op
+
+
+atomic.Value map implementation:
+
+goos: linux
+goarch: amd64
+pkg: github.com/jmoiron/sqlx
+BenchmarkBindSpeed/old-4 100000000 11.0 ns/op
+BenchmarkBindSpeed/new-4 42535839 27.5 ns/op
+*/
+
+func BenchmarkBindSpeed(b *testing.B) {
+ testDrivers := []string{
+ "postgres", "pgx", "mysql", "sqlite3", "ora", "sqlserver",
+ }
+
+ b.Run("old", func(b *testing.B) {
+ b.StopTimer()
+ var seq []int
+ for i := 0; i < b.N; i++ {
+ seq = append(seq, rand.Intn(len(testDrivers)))
+ }
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ s := oldBindType(testDrivers[seq[i]])
+ if s == UNKNOWN {
+ b.Error("unknown driver")
+ }
+ }
+
+ })
+
+ b.Run("new", func(b *testing.B) {
+ b.StopTimer()
+ var seq []int
+ for i := 0; i < b.N; i++ {
+ seq = append(seq, rand.Intn(len(testDrivers)))
+ }
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ s := BindType(testDrivers[seq[i]])
+ if s == UNKNOWN {
+ b.Error("unknown driver")
+ }
+ }
+
+ })
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..53ef975
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,9 @@
+module github.com/jmoiron/sqlx
+
+go 1.10
+
+require (
+ github.com/go-sql-driver/mysql v1.5.0
+ github.com/lib/pq v1.2.0
+ github.com/mattn/go-sqlite3 v1.14.6
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..4db3f25
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,6 @@
+github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
+github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
+github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
+github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
+github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
diff --git a/named.go b/named.go
index 69eb954..276ed56 100644
--- a/named.go
+++ b/named.go
@@ -12,10 +12,12 @@
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
+ "bytes"
"database/sql"
"errors"
"fmt"
"reflect"
+ "regexp"
"strconv"
"unicode"
@@ -144,8 +146,22 @@
}, nil
}
+// convertMapStringInterface attempts to convert v to map[string]interface{}.
+// Unlike v.(map[string]interface{}), this function works on named types that
+// are convertible to map[string]interface{} as well.
+func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) {
+ var m map[string]interface{}
+ mtype := reflect.TypeOf(m)
+ t := reflect.TypeOf(v)
+ if !t.ConvertibleTo(mtype) {
+ return nil, false
+ }
+ return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true
+
+}
+
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
- if maparg, ok := arg.(map[string]interface{}); ok {
+ if maparg, ok := convertMapStringInterface(arg); ok {
return bindMapArgs(names, maparg)
}
return bindArgs(names, arg, m)
@@ -200,7 +216,7 @@
return "", []interface{}{}, err
}
- arglist, err := bindArgs(names, arg, m)
+ arglist, err := bindAnyArgs(names, arg, m)
if err != nil {
return "", []interface{}{}, err
}
@@ -208,6 +224,56 @@
return bound, arglist, nil
}
+var valueBracketReg = regexp.MustCompile(`\([^(]*.[^(]\)$`)
+
+func fixBound(bound string, loop int) string {
+ loc := valueBracketReg.FindStringIndex(bound)
+ if len(loc) != 2 {
+ return bound
+ }
+ var buffer bytes.Buffer
+
+ buffer.WriteString(bound[0:loc[1]])
+ for i := 0; i < loop-1; i++ {
+ buffer.WriteString(",")
+ buffer.WriteString(bound[loc[0]:loc[1]])
+ }
+ buffer.WriteString(bound[loc[1]:])
+ return buffer.String()
+}
+
+// bindArray binds a named parameter query with fields from an array or slice
+// argument whose elements are structs or maps.
+func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
+ // do the initial binding with QUESTION; if bindType is not question,
+ // we can rebind it at the end.
+ bound, names, err := compileNamedQuery([]byte(query), QUESTION)
+ if err != nil {
+ return "", []interface{}{}, err
+ }
+ arrayValue := reflect.ValueOf(arg)
+ arrayLen := arrayValue.Len()
+ if arrayLen == 0 {
+ return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg)
+ }
+ var arglist = make([]interface{}, 0, len(names)*arrayLen)
+ for i := 0; i < arrayLen; i++ {
+ elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m)
+ if err != nil {
+ return "", []interface{}{}, err
+ }
+ arglist = append(arglist, elemArglist...)
+ }
+ if arrayLen > 1 {
+ bound = fixBound(bound, arrayLen)
+ }
+ // adjust binding type if we weren't on question
+ if bindType != QUESTION {
+ bound = Rebind(bindType, bound)
+ }
+ return bound, arglist, nil
+}
+
// bindMap binds a named parameter query with a map of arguments.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
@@ -259,6 +325,10 @@
}
inName = true
name = []byte{}
+ } else if inName && i > 0 && b == '=' && len(name) == 0 {
+ rebound = append(rebound, ':', '=')
+ inName = false
+ continue
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
// append the byte to the name if we are in a name and not on the last byte
@@ -287,6 +357,12 @@
rebound = append(rebound, byte(b))
}
currentVar++
+ case AT:
+ rebound = append(rebound, '@', 'p')
+ for _, b := range strconv.Itoa(currentVar) {
+ rebound = append(rebound, byte(b))
+ }
+ currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
@@ -317,10 +393,20 @@
}
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
- if maparg, ok := arg.(map[string]interface{}); ok {
- return bindMap(bindType, query, maparg)
+ t := reflect.TypeOf(arg)
+ k := t.Kind()
+ switch {
+ case k == reflect.Map && t.Key().Kind() == reflect.String:
+ m, ok := convertMapStringInterface(arg)
+ if !ok {
+ return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg)
+ }
+ return bindMap(bindType, query, m)
+ case k == reflect.Array || k == reflect.Slice:
+ return bindArray(bindType, query, arg, m)
+ default:
+ return bindStruct(bindType, query, arg, m)
}
- return bindStruct(bindType, query, arg, m)
}
// NamedQuery binds a named query and then runs Query on the result using the
@@ -336,7 +422,7 @@
// NamedExec uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
-// or the query excution itself.
+// or the query execution itself.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
diff --git a/named_context.go b/named_context.go
index 9405007..07ad216 100644
--- a/named_context.go
+++ b/named_context.go
@@ -122,7 +122,7 @@
// NamedExecContext uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
-// or the query excution itself.
+// or the query execution itself.
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
diff --git a/named_context_test.go b/named_context_test.go
index 87e94ac..fd1d851 100644
--- a/named_context_test.go
+++ b/named_context_test.go
@@ -9,7 +9,7 @@
)
func TestNamedContextQueries(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
test := Test{t}
var ns *NamedStmt
diff --git a/named_test.go b/named_test.go
index d3459a8..24b725b 100644
--- a/named_test.go
+++ b/named_test.go
@@ -2,19 +2,21 @@
import (
"database/sql"
+ "fmt"
"testing"
)
func TestCompileQuery(t *testing.T) {
table := []struct {
- Q, R, D, N string
- V []string
+ Q, R, D, T, N string
+ V []string
}{
// basic test for named parameters, invalid char ',' terminating
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
+ T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
V: []string{"name", "age", "first", "last"},
},
@@ -23,6 +25,7 @@
Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT * FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`,
+ T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
@@ -30,6 +33,7 @@
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
+ T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
@@ -37,9 +41,18 @@
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
+ T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
+ {
+ Q: `SELECT @name := "name", :age, :first, :last`,
+ R: `SELECT @name := "name", ?, ?, ?`,
+ D: `SELECT @name := "name", $1, $2, $3`,
+ N: `SELECT @name := "name", :age, :first, :last`,
+ T: `SELECT @name := "name", @p1, @p2, @p3`,
+ V: []string{"age", "first", "last"},
+ },
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
* We could certainly iterate by Rune instead, though it's a great deal slower,
* it's probably the RightWay(tm)
@@ -74,6 +87,11 @@
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd)
}
+ qt, _, _ := compileNamedQuery([]byte(test.Q), AT)
+ if qt != test.T {
+ t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt)
+ }
+
qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
if qq != test.N {
t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
@@ -101,8 +119,18 @@
}
}
+func TestEscapedColons(t *testing.T) {
+ t.Skip("not sure it is possible to support this in general case without an SQL parser")
+ var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND
+ (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id`
+ _, _, err := compileNamedQuery([]byte(qs), DOLLAR)
+ if err != nil {
+ t.Error("Didn't handle colons correctly when inside a string")
+ }
+}
+
func TestNamedQueries(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
test := Test{t}
var ns *NamedStmt
@@ -167,6 +195,49 @@
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
}
+ // test struct batch inserts
+ sls := []Person{
+ {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"},
+ {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"},
+ {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"},
+ }
+
+ insert := fmt.Sprintf("INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)", now)
+ _, err = db.NamedExec(insert, sls)
+ test.Error(err)
+
+ // test map batch inserts
+ slsMap := []map[string]interface{}{
+ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"},
+ {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"},
+ {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"},
+ }
+
+ _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
+ VALUES (:first_name, :last_name, :email)`, slsMap)
+ test.Error(err)
+
+ type A map[string]interface{}
+
+ typedMap := []A{
+ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"},
+ {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"},
+ {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"},
+ }
+
+ _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
+ VALUES (:first_name, :last_name, :email)`, typedMap)
+ test.Error(err)
+
+ for _, p := range sls {
+ dest := Person{}
+ err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=?"), p.Email)
+ test.Error(err)
+ if dest.Email != p.Email {
+ t.Errorf("expected %s, got %s", p.Email, dest.Email)
+ }
+ }
+
// test Exec
ns, err = db.PrepareNamed(`
INSERT INTO person (first_name, last_name, email)
diff --git a/reflectx/reflect.go b/reflectx/reflect.go
index 05c3abb..0b10994 100644
--- a/reflectx/reflect.go
+++ b/reflectx/reflect.go
@@ -269,9 +269,7 @@
// A copying append that creates a new slice each time.
func apnd(is []int, i int) []int {
x := make([]int, len(is)+1)
- for p, n := range is {
- x[p] = n
- }
+ copy(x, is)
x[len(x)-1] = i
return x
}
diff --git a/sqlx.go b/sqlx.go
index 4385c3f..112ef70 100644
--- a/sqlx.go
+++ b/sqlx.go
@@ -64,11 +64,7 @@
// it's not important that we use the right mapper for this particular object,
// we're only concerned on how many exported fields this struct has
- m := mapper()
- if len(m.TypeMap(t).Index) == 0 {
- return true
- }
- return false
+ return len(mapper().TypeMap(t).Index) == 0
}
// ColScanner is an interface used by MapScan and SliceScan
@@ -149,15 +145,15 @@
}
func mapperFor(i interface{}) *reflectx.Mapper {
- switch i.(type) {
+ switch i := i.(type) {
case DB:
- return i.(DB).Mapper
+ return i.Mapper
case *DB:
- return i.(*DB).Mapper
+ return i.Mapper
case Tx:
- return i.(Tx).Mapper
+ return i.Mapper
case *Tx:
- return i.(*Tx).Mapper
+ return i.Mapper
default:
return mapper()
}
@@ -380,6 +376,14 @@
return prepareNamed(db, query)
}
+// Conn is a wrapper around sql.Conn with extra functionality
+type Conn struct {
+ *sql.Conn
+ driverName string
+ unsafe bool
+ Mapper *reflectx.Mapper
+}
+
// Tx is an sqlx wrapper around sql.Tx with extra functionality
type Tx struct {
*sql.Tx
@@ -471,8 +475,6 @@
s = v.Stmt
case *Stmt:
s = v.Stmt
- case sql.Stmt:
- s = &v
case *sql.Stmt:
s = v
default:
diff --git a/sqlx_context.go b/sqlx_context.go
index d58ff33..7aa4dd0 100644
--- a/sqlx_context.go
+++ b/sqlx_context.go
@@ -208,6 +208,74 @@
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
+// Connx returns an *sqlx.Conn instead of an *sql.Conn.
+func (db *DB) Connx(ctx context.Context) (*Conn, error) {
+ conn, err := db.DB.Conn(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil
+}
+
+// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an
+// *sql.Tx.
+//
+// The provided context is used until the transaction is committed or rolled
+// back. If the context is canceled, the sql package will roll back the
+// transaction. Tx.Commit will return an error if the context provided to
+// BeginTxx is canceled.
+func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
+ tx, err := c.Conn.BeginTx(ctx, opts)
+ if err != nil {
+ return nil, err
+ }
+ return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err
+}
+
+// SelectContext using this Conn.
+// Any placeholder parameters are replaced with supplied args.
+func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ return SelectContext(ctx, c, dest, query, args...)
+}
+
+// GetContext using this Conn.
+// Any placeholder parameters are replaced with supplied args.
+// An error is returned if the result set is empty.
+func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ return GetContext(ctx, c, dest, query, args...)
+}
+
+// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
+//
+// The provided context is used for the preparation of the statement, not for
+// the execution of the statement.
+func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
+ return PreparexContext(ctx, c, query)
+}
+
+// QueryxContext queries the database and returns an *sqlx.Rows.
+// Any placeholder parameters are replaced with supplied args.
+func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ r, err := c.Conn.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+ return &Rows{Rows: r, unsafe: c.unsafe, Mapper: c.Mapper}, err
+}
+
+// QueryRowxContext queries the database and returns an *sqlx.Row.
+// Any placeholder parameters are replaced with supplied args.
+func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := c.Conn.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper}
+}
+
+// Rebind a query within a Conn's bindvar type.
+func (c *Conn) Rebind(query string) string {
+ return Rebind(BindType(c.driverName), query)
+}
+
// StmtxContext returns a version of the prepared statement which runs within a
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt.
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt {
@@ -217,8 +285,6 @@
s = v.Stmt
case *Stmt:
s = v.Stmt
- case sql.Stmt:
- s = &v
case *sql.Stmt:
s = v
default:
diff --git a/sqlx_context_test.go b/sqlx_context_test.go
index 85e112b..e49ab8b 100644
--- a/sqlx_context_test.go
+++ b/sqlx_context_test.go
@@ -42,7 +42,7 @@
}
func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) {
- runner := func(ctx context.Context, db *DB, t *testing.T, create, drop string) {
+ runner := func(ctx context.Context, db *DB, t *testing.T, create, drop, now string) {
defer func() {
MultiExecContext(ctx, db, drop)
}()
@@ -52,16 +52,16 @@
}
if TestPostgres {
- create, drop := schema.Postgres()
- runner(ctx, pgdb, t, create, drop)
+ create, drop, now := schema.Postgres()
+ runner(ctx, pgdb, t, create, drop, now)
}
if TestSqlite {
- create, drop := schema.Sqlite3()
- runner(ctx, sldb, t, create, drop)
+ create, drop, now := schema.Sqlite3()
+ runner(ctx, sldb, t, create, drop, now)
}
if TestMysql {
- create, drop := schema.MySQL()
- runner(ctx, mysqldb, t, create, drop)
+ create, drop, now := schema.MySQL()
+ runner(ctx, mysqldb, t, create, drop, now)
}
}
@@ -984,7 +984,7 @@
person := &Person{}
err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist")
if err == nil {
- t.Fatal("Should have got an error for Get on non-existant row.")
+ t.Fatal("Should have got an error for Get on non-existent row.")
}
// lets test prepared statements some more
@@ -1342,3 +1342,85 @@
}
})
}
+
+func TestConn(t *testing.T) {
+ var schema = Schema{
+ create: `
+ CREATE TABLE tt_conn (
+ id integer,
+ value text NULL DEFAULT NULL
+ );`,
+ drop: "drop table tt_conn;",
+ }
+
+ RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
+ conn, err := db.Connx(ctx)
+ defer conn.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = conn.ExecContext(ctx, conn.Rebind(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`), 1, "a", 2, "b")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type s struct {
+ ID int `db:"id"`
+ Value string `db:"value"`
+ }
+
+ v := []s{}
+
+ err = conn.SelectContext(ctx, &v, "SELECT * FROM tt_conn ORDER BY id ASC")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if v[0].ID != 1 {
+ t.Errorf("Expecting ID of 1, got %d", v[0].ID)
+ }
+
+ v1 := s{}
+ err = conn.GetContext(ctx, &v1, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"), 1)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ if v1.ID != 1 {
+ t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
+ }
+
+ stmt, err := conn.PreparexContext(ctx, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ v1 = s{}
+ tx, err := conn.BeginTxx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tstmt := tx.Stmtx(stmt)
+ row := tstmt.QueryRowx(1)
+ err = row.StructScan(&v1)
+ if err != nil {
+ t.Error(err)
+ }
+ tx.Commit()
+ if v1.ID != 1 {
+ t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
+ }
+
+ rows, err := conn.QueryxContext(ctx, "SELECT * FROM tt_conn")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for rows.Next() {
+ err = rows.StructScan(&v1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ })
+}
diff --git a/sqlx_test.go b/sqlx_test.go
index e26c980..5ea754a 100644
--- a/sqlx_test.go
+++ b/sqlx_test.go
@@ -98,16 +98,16 @@
drop string
}
-func (s Schema) Postgres() (string, string) {
- return s.create, s.drop
+func (s Schema) Postgres() (string, string, string) {
+ return s.create, s.drop, `now()`
}
-func (s Schema) MySQL() (string, string) {
- return strings.Replace(s.create, `"`, "`", -1), s.drop
+func (s Schema) MySQL() (string, string, string) {
+ return strings.Replace(s.create, `"`, "`", -1), s.drop, `now()`
}
-func (s Schema) Sqlite3() (string, string) {
- return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop
+func (s Schema) Sqlite3() (string, string, string) {
+ return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop, `CURRENT_TIMESTAMP`
}
var defaultSchema = Schema{
@@ -218,27 +218,27 @@
}
}
-func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T)) {
- runner := func(db *DB, t *testing.T, create, drop string) {
+func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T, now string)) {
+ runner := func(db *DB, t *testing.T, create, drop, now string) {
defer func() {
MultiExec(db, drop)
}()
MultiExec(db, create)
- test(db, t)
+ test(db, t, now)
}
if TestPostgres {
- create, drop := schema.Postgres()
- runner(pgdb, t, create, drop)
+ create, drop, now := schema.Postgres()
+ runner(pgdb, t, create, drop, now)
}
if TestSqlite {
- create, drop := schema.Sqlite3()
- runner(sldb, t, create, drop)
+ create, drop, now := schema.Sqlite3()
+ runner(sldb, t, create, drop, now)
}
if TestMysql {
- create, drop := schema.MySQL()
- runner(mysqldb, t, create, drop)
+ create, drop, now := schema.MySQL()
+ runner(mysqldb, t, create, drop, now)
}
}
@@ -263,7 +263,7 @@
// Test a new backwards compatible feature, that missing scan destinations
// will silently scan into sql.RawText rather than failing/panicing
func TestMissingNames(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
type PersonPlus struct {
FirstName string `db:"first_name"`
@@ -383,7 +383,7 @@
type Loop2 struct{ Loop1 }
type Loop3 struct{ Loop2 }
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
peopleAndPlaces := []PersonPlace{}
err := db.Select(
@@ -476,7 +476,7 @@
}
type Boss Employee
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
var employees []struct {
@@ -512,7 +512,7 @@
}
type Boss Employee
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
var employees []struct {
@@ -544,7 +544,7 @@
}
func TestSelectSliceMapTime(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
rows, err := db.Queryx("SELECT * FROM person")
if err != nil {
@@ -573,7 +573,7 @@
}
func TestNilReceiver(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
var p *Person
err := db.Get(p, "SELECT * FROM person LIMIT 1")
@@ -619,7 +619,7 @@
`,
}
- RunWithSchema(schema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
type Person struct {
FirstName sql.NullString `db:"first_name"`
LastName sql.NullString `db:"last_name"`
@@ -829,7 +829,7 @@
drop: "drop table tt;",
}
- RunWithSchema(schema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
type TT struct {
ID int
Value *string
@@ -874,7 +874,7 @@
drop: `drop table kv;`,
}
- RunWithSchema(schema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
type WrongTypes struct {
K int
V string
@@ -898,11 +898,22 @@
})
}
+// TestMultiInsert rebinds a single INSERT statement carrying two value tuples
+// for the database under test and executes it against the employees table.
+// The test makes no assertions beyond the MustExec call itself.
+func TestMultiInsert(t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
+ loadDefaultFixture(db, t)
+ q := db.Rebind(`INSERT INTO employees (name, id) VALUES (?, ?), (?, ?);`)
+ db.MustExec(q,
+ "Name1", 400,
+ "name2", 500,
+ )
+ })
+}
+
// FIXME: this function is kinda big but it slows things down to be constantly
// loading and reloading the schema..
func TestUsage(t *testing.T) {
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
slicemembers := []SliceMember{}
err := db.Select(&slicemembers, "SELECT * FROM place ORDER BY telcode ASC")
@@ -1158,7 +1169,7 @@
person := &Person{}
err = db.Get(person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist")
if err == nil {
- t.Fatal("Should have got an error for Get on non-existant row.")
+ t.Fatal("Should have got an error for Get on non-existent row.")
}
// let's test prepared statements some more
@@ -1320,6 +1331,17 @@
t.Errorf("q2 failed")
}
+ s1 = Rebind(AT, q1)
+ s2 = Rebind(AT, q2)
+
+ if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)` {
+ t.Errorf("q1 failed")
+ }
+
+ if s2 != `INSERT INTO foo (a, b, c) VALUES (@p1, @p2, "foo"), ("Hi", @p3, @p4)` {
+ t.Errorf("q2 failed")
+ }
+
s1 = Rebind(NAMED, q1)
s2 = Rebind(NAMED, q2)
@@ -1410,7 +1432,7 @@
drop: `drop table message;`,
}
- RunWithSchema(schema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
messages := []Message{
{"Hello, World", PropertyMap{"one": "1", "two": "2"}},
{"Thanks, Joy", PropertyMap{"pull": "request"}},
@@ -1454,7 +1476,7 @@
type Var struct{ Raw json.RawMessage }
type Var2 struct{ Raw []byte }
type Var3 struct{ Raw mybyte }
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
var err error
var v, q Var
if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil {
@@ -1500,6 +1522,9 @@
{"SELECT * FROM foo WHERE x = ? AND y in (?)",
[]interface{}{[]byte("foo"), []int{0, 5, 3}},
4},
+ {"SELECT * FROM foo WHERE x = ? AND y IN (?)",
+ []interface{}{sql.NullString{Valid: false}, []string{"a", "b"}},
+ 3},
}
for _, test := range tests {
q, a, err := In(test.q, test.args...)
@@ -1551,7 +1576,7 @@
t.Error("Expected an error, but got nil.")
}
}
- RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) {
loadDefaultFixture(db, t)
//tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1")
//tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852")
@@ -1675,7 +1700,7 @@
drop: `drop table x;`,
}
- RunWithSchema(schema, t, func(db *DB, t *testing.T) {
+ RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
type t1 struct {
K *string
}
@@ -1724,6 +1749,33 @@
}
}
+// TestBindNamedMapper exercises bindNamedMapper with map arguments. A named
+// map type whose underlying type is map[string]interface{} must bind `:x` to
+// `$1` and yield its value as the only argument, while a map[string]string
+// must be rejected with an error mentioning "unsupported map type".
+func TestBindNamedMapper(t *testing.T) {
+ type A map[string]interface{}
+ m := reflectx.NewMapperFunc("db", NameMapper)
+ query, args, err := bindNamedMapper(DOLLAR, `select :x`, A{
+ "x": "X!",
+ }, m)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Compare the query and the argument list together as one formatted string.
+ got := fmt.Sprintf("%s %s", query, args)
+ want := `select $1 [X!]`
+ if got != want {
+ t.Errorf("\ngot: %q\nwant: %q", got, want)
+ }
+
+ // A map whose value type is not interface{} is expected to fail.
+ _, _, err = bindNamedMapper(DOLLAR, `select :x`, map[string]string{
+ "x": "X!",
+ }, m)
+ if err == nil {
+ t.Fatal("err is nil")
+ }
+ if !strings.Contains(err.Error(), "unsupported map type") {
+ t.Errorf("wrong error: %s", err)
+ }
+}
+
func BenchmarkBindMap(b *testing.B) {
b.StopTimer()
q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`
diff --git a/types/types.go b/types/types.go
index 7b014c1..e824750 100644
--- a/types/types.go
+++ b/types/types.go
@@ -30,11 +30,11 @@
// the wire and storing the raw result in the GzippedText.
func (g *GzippedText) Scan(src interface{}) error {
var source []byte
- switch src.(type) {
+ switch src := src.(type) {
case string:
- source = []byte(src.(string))
+ source = []byte(src)
case []byte:
- source = src.([]byte)
+ source = src
default:
return errors.New("Incompatible type for GzippedText")
}