| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package |
| // |
| // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. |
| // |
| // This Source Code Form is subject to the terms of the Mozilla Public |
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| // You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| package mysql |
| |
| import ( |
| "database/sql/driver" |
| "io" |
| "net" |
| "strconv" |
| "strings" |
| "time" |
| ) |
| |
| type mysqlConn struct { |
| buf buffer |
| netConn net.Conn |
| affectedRows uint64 |
| insertId uint64 |
| cfg *Config |
| maxAllowedPacket int |
| maxWriteSize int |
| writeTimeout time.Duration |
| flags clientFlag |
| status statusFlag |
| sequence uint8 |
| parseTime bool |
| strict bool |
| } |
| |
| // Handles parameters set in DSN after the connection is established |
| func (mc *mysqlConn) handleParams() (err error) { |
| for param, val := range mc.cfg.Params { |
| switch param { |
| // Charset |
| case "charset": |
| charsets := strings.Split(val, ",") |
| for i := range charsets { |
| // ignore errors here - a charset may not exist |
| err = mc.exec("SET NAMES " + charsets[i]) |
| if err == nil { |
| break |
| } |
| } |
| if err != nil { |
| return |
| } |
| |
| // System Vars |
| default: |
| err = mc.exec("SET " + param + "=" + val + "") |
| if err != nil { |
| return |
| } |
| } |
| } |
| |
| return |
| } |
| |
| func (mc *mysqlConn) Begin() (driver.Tx, error) { |
| if mc.netConn == nil { |
| errLog.Print(ErrInvalidConn) |
| return nil, driver.ErrBadConn |
| } |
| err := mc.exec("START TRANSACTION") |
| if err == nil { |
| return &mysqlTx{mc}, err |
| } |
| |
| return nil, err |
| } |
| |
| func (mc *mysqlConn) Close() (err error) { |
| // Makes Close idempotent |
| if mc.netConn != nil { |
| err = mc.writeCommandPacket(comQuit) |
| } |
| |
| mc.cleanup() |
| |
| return |
| } |
| |
| // Closes the network connection and unsets internal variables. Do not call this |
| // function after successfully authentication, call Close instead. This function |
| // is called before auth or on auth failure because MySQL will have already |
| // closed the network connection. |
| func (mc *mysqlConn) cleanup() { |
| // Makes cleanup idempotent |
| if mc.netConn != nil { |
| if err := mc.netConn.Close(); err != nil { |
| errLog.Print(err) |
| } |
| mc.netConn = nil |
| } |
| mc.cfg = nil |
| mc.buf.nc = nil |
| } |
| |
| func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { |
| if mc.netConn == nil { |
| errLog.Print(ErrInvalidConn) |
| return nil, driver.ErrBadConn |
| } |
| // Send command |
| err := mc.writeCommandPacketStr(comStmtPrepare, query) |
| if err != nil { |
| return nil, err |
| } |
| |
| stmt := &mysqlStmt{ |
| mc: mc, |
| } |
| |
| // Read Result |
| columnCount, err := stmt.readPrepareResultPacket() |
| if err == nil { |
| if stmt.paramCount > 0 { |
| if err = mc.readUntilEOF(); err != nil { |
| return nil, err |
| } |
| } |
| |
| if columnCount > 0 { |
| err = mc.readUntilEOF() |
| } |
| } |
| |
| return stmt, err |
| } |
| |
| func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { |
| // Number of ? should be same to len(args) |
| if strings.Count(query, "?") != len(args) { |
| return "", driver.ErrSkip |
| } |
| |
| buf := mc.buf.takeCompleteBuffer() |
| if buf == nil { |
| // can not take the buffer. Something must be wrong with the connection |
| errLog.Print(ErrBusyBuffer) |
| return "", driver.ErrBadConn |
| } |
| buf = buf[:0] |
| argPos := 0 |
| |
| for i := 0; i < len(query); i++ { |
| q := strings.IndexByte(query[i:], '?') |
| if q == -1 { |
| buf = append(buf, query[i:]...) |
| break |
| } |
| buf = append(buf, query[i:i+q]...) |
| i += q |
| |
| arg := args[argPos] |
| argPos++ |
| |
| if arg == nil { |
| buf = append(buf, "NULL"...) |
| continue |
| } |
| |
| switch v := arg.(type) { |
| case int64: |
| buf = strconv.AppendInt(buf, v, 10) |
| case float64: |
| buf = strconv.AppendFloat(buf, v, 'g', -1, 64) |
| case bool: |
| if v { |
| buf = append(buf, '1') |
| } else { |
| buf = append(buf, '0') |
| } |
| case time.Time: |
| if v.IsZero() { |
| buf = append(buf, "'0000-00-00'"...) |
| } else { |
| v := v.In(mc.cfg.Loc) |
| v = v.Add(time.Nanosecond * 500) // To round under microsecond |
| year := v.Year() |
| year100 := year / 100 |
| year1 := year % 100 |
| month := v.Month() |
| day := v.Day() |
| hour := v.Hour() |
| minute := v.Minute() |
| second := v.Second() |
| micro := v.Nanosecond() / 1000 |
| |
| buf = append(buf, []byte{ |
| '\'', |
| digits10[year100], digits01[year100], |
| digits10[year1], digits01[year1], |
| '-', |
| digits10[month], digits01[month], |
| '-', |
| digits10[day], digits01[day], |
| ' ', |
| digits10[hour], digits01[hour], |
| ':', |
| digits10[minute], digits01[minute], |
| ':', |
| digits10[second], digits01[second], |
| }...) |
| |
| if micro != 0 { |
| micro10000 := micro / 10000 |
| micro100 := micro / 100 % 100 |
| micro1 := micro % 100 |
| buf = append(buf, []byte{ |
| '.', |
| digits10[micro10000], digits01[micro10000], |
| digits10[micro100], digits01[micro100], |
| digits10[micro1], digits01[micro1], |
| }...) |
| } |
| buf = append(buf, '\'') |
| } |
| case []byte: |
| if v == nil { |
| buf = append(buf, "NULL"...) |
| } else { |
| buf = append(buf, "_binary'"...) |
| if mc.status&statusNoBackslashEscapes == 0 { |
| buf = escapeBytesBackslash(buf, v) |
| } else { |
| buf = escapeBytesQuotes(buf, v) |
| } |
| buf = append(buf, '\'') |
| } |
| case string: |
| buf = append(buf, '\'') |
| if mc.status&statusNoBackslashEscapes == 0 { |
| buf = escapeStringBackslash(buf, v) |
| } else { |
| buf = escapeStringQuotes(buf, v) |
| } |
| buf = append(buf, '\'') |
| default: |
| return "", driver.ErrSkip |
| } |
| |
| if len(buf)+4 > mc.maxAllowedPacket { |
| return "", driver.ErrSkip |
| } |
| } |
| if argPos != len(args) { |
| return "", driver.ErrSkip |
| } |
| return string(buf), nil |
| } |
| |
| func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { |
| if mc.netConn == nil { |
| errLog.Print(ErrInvalidConn) |
| return nil, driver.ErrBadConn |
| } |
| if len(args) != 0 { |
| if !mc.cfg.InterpolateParams { |
| return nil, driver.ErrSkip |
| } |
| // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement |
| prepared, err := mc.interpolateParams(query, args) |
| if err != nil { |
| return nil, err |
| } |
| query = prepared |
| } |
| mc.affectedRows = 0 |
| mc.insertId = 0 |
| |
| err := mc.exec(query) |
| if err == nil { |
| return &mysqlResult{ |
| affectedRows: int64(mc.affectedRows), |
| insertId: int64(mc.insertId), |
| }, err |
| } |
| return nil, err |
| } |
| |
| // Internal function to execute commands |
| func (mc *mysqlConn) exec(query string) error { |
| // Send command |
| if err := mc.writeCommandPacketStr(comQuery, query); err != nil { |
| return err |
| } |
| |
| // Read Result |
| resLen, err := mc.readResultSetHeaderPacket() |
| if err != nil { |
| return err |
| } |
| |
| if resLen > 0 { |
| // columns |
| if err := mc.readUntilEOF(); err != nil { |
| return err |
| } |
| |
| // rows |
| if err := mc.readUntilEOF(); err != nil { |
| return err |
| } |
| } |
| |
| return mc.discardResults() |
| } |
| |
| func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { |
| if mc.netConn == nil { |
| errLog.Print(ErrInvalidConn) |
| return nil, driver.ErrBadConn |
| } |
| if len(args) != 0 { |
| if !mc.cfg.InterpolateParams { |
| return nil, driver.ErrSkip |
| } |
| // try client-side prepare to reduce roundtrip |
| prepared, err := mc.interpolateParams(query, args) |
| if err != nil { |
| return nil, err |
| } |
| query = prepared |
| } |
| // Send command |
| err := mc.writeCommandPacketStr(comQuery, query) |
| if err == nil { |
| // Read Result |
| var resLen int |
| resLen, err = mc.readResultSetHeaderPacket() |
| if err == nil { |
| rows := new(textRows) |
| rows.mc = mc |
| |
| if resLen == 0 { |
| rows.rs.done = true |
| |
| switch err := rows.NextResultSet(); err { |
| case nil, io.EOF: |
| return rows, nil |
| default: |
| return nil, err |
| } |
| } |
| // Columns |
| rows.rs.columns, err = mc.readColumns(resLen) |
| return rows, err |
| } |
| } |
| return nil, err |
| } |
| |
| // Gets the value of the given MySQL System Variable |
| // The returned byte slice is only valid until the next read |
| func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { |
| // Send command |
| if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { |
| return nil, err |
| } |
| |
| // Read Result |
| resLen, err := mc.readResultSetHeaderPacket() |
| if err == nil { |
| rows := new(textRows) |
| rows.mc = mc |
| rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} |
| |
| if resLen > 0 { |
| // Columns |
| if err := mc.readUntilEOF(); err != nil { |
| return nil, err |
| } |
| } |
| |
| dest := make([]driver.Value, resLen) |
| if err = rows.readRow(dest); err == nil { |
| return dest[0].([]byte), mc.readUntilEOF() |
| } |
| } |
| return nil, err |
| } |