| package sqlmock |
| |
| import ( |
| "bytes" |
| "database/sql/driver" |
| "encoding/csv" |
| "fmt" |
| "io" |
| "strings" |
| ) |
| |
| const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ " |
| |
| // CSVColumnParser is a function which converts trimmed csv |
| // column string to a []byte representation. Currently |
| // transforms NULL to nil |
| var CSVColumnParser = func(s string) []byte { |
| switch { |
| case strings.ToLower(s) == "null": |
| return nil |
| } |
| return []byte(s) |
| } |
| |
| type rowSets struct { |
| sets []*Rows |
| pos int |
| ex *ExpectedQuery |
| raw [][]byte |
| } |
| |
| func (rs *rowSets) Columns() []string { |
| return rs.sets[rs.pos].cols |
| } |
| |
| func (rs *rowSets) Close() error { |
| rs.invalidateRaw() |
| rs.ex.rowsWereClosed = true |
| return rs.sets[rs.pos].closeErr |
| } |
| |
| // advances to next row |
| func (rs *rowSets) Next(dest []driver.Value) error { |
| r := rs.sets[rs.pos] |
| r.pos++ |
| rs.invalidateRaw() |
| if r.pos > len(r.rows) { |
| return io.EOF // per interface spec |
| } |
| |
| for i, col := range r.rows[r.pos-1] { |
| if b, ok := rawBytes(col); ok { |
| rs.raw = append(rs.raw, b) |
| dest[i] = b |
| continue |
| } |
| dest[i] = col |
| } |
| |
| return r.nextErr[r.pos-1] |
| } |
| |
| // transforms to debuggable printable string |
| func (rs *rowSets) String() string { |
| if rs.empty() { |
| return "with empty rows" |
| } |
| |
| msg := "should return rows:\n" |
| if len(rs.sets) == 1 { |
| for n, row := range rs.sets[0].rows { |
| msg += fmt.Sprintf(" row %d - %+v\n", n, row) |
| } |
| return strings.TrimSpace(msg) |
| } |
| for i, set := range rs.sets { |
| msg += fmt.Sprintf(" result set: %d\n", i) |
| for n, row := range set.rows { |
| msg += fmt.Sprintf(" row %d - %+v\n", n, row) |
| } |
| } |
| return strings.TrimSpace(msg) |
| } |
| |
| func (rs *rowSets) empty() bool { |
| for _, set := range rs.sets { |
| if len(set.rows) > 0 { |
| return false |
| } |
| } |
| return true |
| } |
| |
| func rawBytes(col driver.Value) (_ []byte, ok bool) { |
| val, ok := col.([]byte) |
| if !ok || len(val) == 0 { |
| return nil, false |
| } |
| // Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later |
| // This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close() |
| b := make([]byte, len(val)) |
| copy(b, val) |
| return b, true |
| } |
| |
| // Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close. |
| // If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes |
| func (rs *rowSets) invalidateRaw() { |
| // Replace the content of slices previously returned |
| b := []byte(invalidate) |
| for _, r := range rs.raw { |
| copy(r, bytes.Repeat(b, len(r)/len(b)+1)) |
| } |
| // Start with new slices for the next scan |
| rs.raw = nil |
| } |
| |
| // Rows is a mocked collection of rows to |
| // return for Query result |
| type Rows struct { |
| converter driver.ValueConverter |
| cols []string |
| rows [][]driver.Value |
| pos int |
| nextErr map[int]error |
| closeErr error |
| } |
| |
| // NewRows allows Rows to be created from a |
| // sql driver.Value slice or from the CSV string and |
| // to be used as sql driver.Rows. |
| // Use Sqlmock.NewRows instead if using a custom converter |
| func NewRows(columns []string) *Rows { |
| return &Rows{ |
| cols: columns, |
| nextErr: make(map[int]error), |
| converter: driver.DefaultParameterConverter, |
| } |
| } |
| |
| // CloseError allows to set an error |
| // which will be returned by rows.Close |
| // function. |
| // |
| // The close error will be triggered only in cases |
| // when rows.Next() EOF was not yet reached, that is |
| // a default sql library behavior |
| func (r *Rows) CloseError(err error) *Rows { |
| r.closeErr = err |
| return r |
| } |
| |
| // RowError allows to set an error |
| // which will be returned when a given |
| // row number is read |
| func (r *Rows) RowError(row int, err error) *Rows { |
| r.nextErr[row] = err |
| return r |
| } |
| |
| // AddRow composed from database driver.Value slice |
| // return the same instance to perform subsequent actions. |
| // Note that the number of values must match the number |
| // of columns |
| func (r *Rows) AddRow(values ...driver.Value) *Rows { |
| if len(values) != len(r.cols) { |
| panic("Expected number of values to match number of columns") |
| } |
| |
| row := make([]driver.Value, len(r.cols)) |
| for i, v := range values { |
| // Convert user-friendly values (such as int or driver.Valuer) |
| // to database/sql native value (driver.Value such as int64) |
| var err error |
| v, err = r.converter.ConvertValue(v) |
| if err != nil { |
| panic(fmt.Errorf( |
| "row #%d, column #%d (%q) type %T: %s", |
| len(r.rows)+1, i, r.cols[i], values[i], err, |
| )) |
| } |
| |
| row[i] = v |
| } |
| |
| r.rows = append(r.rows, row) |
| return r |
| } |
| |
| // FromCSVString build rows from csv string. |
| // return the same instance to perform subsequent actions. |
| // Note that the number of values must match the number |
| // of columns |
| func (r *Rows) FromCSVString(s string) *Rows { |
| res := strings.NewReader(strings.TrimSpace(s)) |
| csvReader := csv.NewReader(res) |
| |
| for { |
| res, err := csvReader.Read() |
| if err != nil || res == nil { |
| break |
| } |
| |
| row := make([]driver.Value, len(r.cols)) |
| for i, v := range res { |
| row[i] = CSVColumnParser(strings.TrimSpace(v)) |
| } |
| r.rows = append(r.rows, row) |
| } |
| return r |
| } |