Merge pull request #340 from DATA-DOG/feature/316
Add WithTxOptions expectation to ExpectBegin
diff --git a/expectations.go b/expectations.go
index 8a6cd44..28d11ac 100644
--- a/expectations.go
+++ b/expectations.go
@@ -1,6 +1,7 @@
package sqlmock
import (
+ "database/sql"
"database/sql/driver"
"fmt"
"strings"
@@ -53,7 +54,8 @@
// returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct {
commonExpectation
- delay time.Duration
+ delay time.Duration // delay applied by Begin/BeginTx before they return
+ txOpts *driver.TxOptions // set by WithTxOptions; when non-nil, begin compares it against the options it is called with
}
// WillReturnError allows to set an error for *sql.DB.Begin action
@@ -65,6 +67,9 @@
// String returns string representation
func (e *ExpectedBegin) String() string {
msg := "ExpectedBegin => expecting database transaction Begin"
+ if e.txOpts != nil {
+ msg += fmt.Sprintf(", with tx options: %+v", e.txOpts)
+ }
if e.err != nil {
msg += fmt.Sprintf(", which should return error: %s", e.err)
}
@@ -78,6 +83,15 @@
return e
}
+// WithTxOptions records the isolation level and read-only flag that the
+// expected Begin/BeginTx call is checked against, and returns e for chaining.
+func (e *ExpectedBegin) WithTxOptions(opts sql.TxOptions) *ExpectedBegin {
+ var want driver.TxOptions
+ want.Isolation, want.ReadOnly = driver.IsolationLevel(opts.Isolation), opts.ReadOnly
+ e.txOpts = &want
+ return e
+}
+
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct {
diff --git a/sqlmock.go b/sqlmock.go
index 3ee1256..a4f8f35 100644
--- a/sqlmock.go
+++ b/sqlmock.go
@@ -213,7 +213,7 @@
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) {
- ex, err := c.begin()
+ ex, err := c.begin(driver.TxOptions{})
if ex != nil {
time.Sleep(ex.delay)
}
@@ -224,7 +224,7 @@
return c, nil
}
-func (c *sqlmock) begin() (*ExpectedBegin, error) {
+func (c *sqlmock) begin(opts driver.TxOptions) (*ExpectedBegin, error) {
var expected *ExpectedBegin
var ok bool
var fulfilled int
@@ -252,9 +252,16 @@
}
return nil, fmt.Errorf(msg)
}
+ // Release the lock on every path, including the option-mismatch return below.
+ defer expected.Unlock()
+ // Options mismatch when EITHER the isolation level or the read-only flag differs.
+ if expected.txOpts != nil &&
+ (expected.txOpts.Isolation != opts.Isolation ||
+ expected.txOpts.ReadOnly != opts.ReadOnly) {
+ return nil, fmt.Errorf("expected transaction options do not match: %+v, got: %+v", expected.txOpts, opts)
+ }
expected.triggered = true
- expected.Unlock()
return expected, expected.err
}
diff --git a/sqlmock_go18.go b/sqlmock_go18.go
index f268900..9644958 100644
--- a/sqlmock_go18.go
+++ b/sqlmock_go18.go
@@ -1,3 +1,4 @@
+//go:build go1.8
// +build go1.8
package sqlmock
@@ -66,7 +67,7 @@
// Implement the "ConnBeginTx" interface
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
- ex, err := c.begin()
+ ex, err := c.begin(opts)
if ex != nil {
select {
case <-time.After(ex.delay):
diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go
index 6267f38..ddc7306 100644
--- a/sqlmock_go18_test.go
+++ b/sqlmock_go18_test.go
@@ -360,6 +360,65 @@
}
}
+// TestContextBeginWithTxOptions checks that BeginTx succeeds when it is called
+// with exactly the options registered via ExpectBegin().WithTxOptions.
+func TestContextBeginWithTxOptions(t *testing.T) {
+ t.Parallel()
+ db, mock, err := New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ mock.ExpectBegin().WithTxOptions(sql.TxOptions{
+ Isolation: sql.LevelReadCommitted,
+ ReadOnly: true,
+ })
+
+ // Every option, including ReadOnly, must equal the expected one for Begin to succeed.
+ _, err = db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true})
+ if err != nil {
+ t.Errorf("error was not expected, but got: %v", err)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+// TestContextBeginWithTxOptionsMismatch checks that BeginTx fails, and the Begin
+// expectation stays unfulfilled, when any single option differs from the ones
+// registered via ExpectBegin().WithTxOptions.
+func TestContextBeginWithTxOptionsMismatch(t *testing.T) {
+ t.Parallel()
+ expected := sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true}
+ cases := map[string]sql.TxOptions{
+ "isolation differs": {Isolation: sql.LevelDefault, ReadOnly: true},
+ "read-only differs": {Isolation: sql.LevelReadCommitted, ReadOnly: false},
+ "both differ": {Isolation: sql.LevelDefault, ReadOnly: false},
+ }
+ for name, got := range cases {
+ t.Run(name, func(t *testing.T) {
+ db, mock, err := New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+
+ mock.ExpectBegin().WithTxOptions(expected)
+
+ // Taking &got is safe: this subtest is not parallel, so it finishes before the loop advances.
+ if _, err := db.BeginTx(context.Background(), &got); err == nil {
+ t.Error("error was expected, but there was none")
+ }
+
+ if err := mock.ExpectationsWereMet(); err == nil {
+ t.Errorf("was expecting an error, as the tx options did not match, but there wasn't one")
+ }
+ })
+ }
+}
+
func TestContextPrepareCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()