Adding feature to allow repeatable expectations
diff --git a/sqlmock.go b/sqlmock.go
index fa7f624..e013164 100644
--- a/sqlmock.go
+++ b/sqlmock.go
@@ -73,13 +73,26 @@
// in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order
MatchExpectationsInOrder(bool)
+
+ // AllowRepeatedExpectationMatching gives an option whether or not to
+ // allow expectations to be matched more than once.
+ //
+ // By default it is set to - false.
+ //
+ // This option may be turned on anytime during tests. As soon
+ // as it is switched to true, expectations will be allowed to match
+ // regardless of whether they have been matched previously.
+ //
+ // When setting this to true, consider whether you also need to call MatchExpectationsInOrder(false).
+ AllowRepeatedExpectationMatching(bool)
}
type sqlmock struct {
- ordered bool
- dsn string
- opened int
- drv *mockDriver
+ ordered bool
+ repeatable bool // when true, already-fulfilled expectations stay eligible for matching
+ dsn string
+ opened int
+ drv *mockDriver
expected []expectation
}
@@ -102,6 +115,10 @@
c.ordered = b
}
+func (c *sqlmock) AllowRepeatedExpectationMatching(b bool) {
+ c.repeatable = b // read by the matching loops to decide whether fulfilled expectations are skipped
+}
+
// Close a mock database driver connection. It may or may not
// be called depending on the sircumstances, but if it is called
// there must be an *ExpectedClose expectation satisfied.
@@ -121,9 +138,11 @@
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable { // repeats off: release and skip; repeats on: fall through still holding the lock
+ next.Unlock()
+ continue
+ }
}
if expected, ok = next.(*ExpectedClose); ok {
@@ -185,7 +204,9 @@
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable {
+ next.Unlock() // unlock only when skipping; a repeated match must keep next locked
+ continue
+ }
}
if expected, ok = next.(*ExpectedBegin); ok {
@@ -246,7 +267,9 @@
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable {
+ next.Unlock() // unlock only when skipping; a repeated match must keep next locked
+ continue
+ }
}
if c.ordered {
@@ -322,9 +345,12 @@
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ // repeats off: release and skip; repeats on: fall through still holding the lock
+ if !c.repeatable {
+ next.Unlock()
+ continue
+ }
}
if c.ordered {
@@ -401,9 +427,11 @@
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable { // repeats off: release and skip; repeats on: fall through still holding the lock
+ next.Unlock()
+ continue
+ }
}
if c.ordered {
@@ -448,6 +476,14 @@
if expected.rows == nil {
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
+
+ // rewind the row cursors only when repeated matching is on, so a re-matched expectation serves its rows from the start
+ if rs, ok := expected.rows.(*rowSets); ok && c.repeatable {
+ rs.pos = 0
+ for _, set := range rs.sets { // NOTE(review): only effective if sets holds pointers — confirm
+ set.pos = 0
+ }
+ }
return expected, nil
}
@@ -481,7 +517,9 @@
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable {
+ next.Unlock() // unlock only when skipping; a repeated match must keep next locked
+ continue
+ }
}
if expected, ok = next.(*ExpectedCommit); ok {
@@ -516,7 +554,9 @@
if next.fulfilled() {
- next.Unlock()
fulfilled++
- continue
+ if !c.repeatable {
+ next.Unlock() // unlock only when skipping; a repeated match must keep next locked
+ continue
+ }
}
if expected, ok = next.(*ExpectedRollback); ok {
diff --git a/sqlmock_test.go b/sqlmock_test.go
index 9c48d3d..51ce90f 100644
--- a/sqlmock_test.go
+++ b/sqlmock_test.go
@@ -354,6 +354,57 @@
}
}
+func TestPreparedQueryMultipleExecutions(t *testing.T) {
+ t.Parallel()
+ db, mock, err := New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer db.Close()
+ // Unordered and repeatable: the single expectation chain below must serve both statements.
+ mock.MatchExpectationsInOrder(false)
+ mock.AllowRepeatedExpectationMatching(true)
+
+ rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
+ mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").ExpectQuery().
+ WithArgs(5).
+ WillReturnRows(rs)
+ // Prepare the same statement twice; stop on failure because the statements are used below.
+ stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
+ if err != nil {
+ t.Fatalf("error '%s' was not expected while creating a prepared statement", err)
+ }
+
+ stmt2, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
+ if err != nil {
+ t.Fatalf("error '%s' was not expected while creating a prepared statement", err)
+ }
+
+ var id1, id2 int
+ var title1, title2 string
+ err = stmt.QueryRow(5).Scan(&id1, &title1)
+ if err != nil {
+ t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
+ }
+
+ err = stmt2.QueryRow(5).Scan(&id2, &title2)
+ if err != nil {
+ t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
+ }
+
+ if id1 != 5 || id2 != 5 {
+ t.Errorf("expected both mocked ids to be 5, but got %d and %d instead", id1, id2)
+ }
+
+ if title1 != "hello world" || title2 != "hello world" {
+ t.Errorf("expected both mocked titles to be 'hello world', but got '%s' and '%s' instead", title1, title2)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
func TestUnorderedPreparedQueryExecutions(t *testing.T) {
t.Parallel()
db, mock, err := New()