| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package |
| // |
| // Copyright 2012 Julien Schmidt. All rights reserved. |
| // http://www.julienschmidt.com |
| // |
| // This Source Code Form is subject to the terms of the Mozilla Public |
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| // You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| package mysql |
| |
| import ( |
| "crypto/sha1" |
| "crypto/tls" |
| "database/sql/driver" |
| "encoding/binary" |
| "fmt" |
| "io" |
| "log" |
| "os" |
| "regexp" |
| "strings" |
| "time" |
| ) |
| |
// NullTime is a time.Time that may be NULL.
//
// It has a Scan method (and a Value method), so a *NullTime can be passed
// directly as a scan destination:
//
//	var nt NullTime
//	err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
//	...
//	if nt.Valid {
//		// use nt.Time
//	} else {
//		// NULL value
//	}
//
// Nothing in this type depends on the MySQL driver itself.
type NullTime struct {
	Time  time.Time // the scanned time; only meaningful when Valid is true
	Valid bool      // Valid is true if Time is not NULL
}
| |
| // Scan implements the Scanner interface. |
| // The value type must be time.Time or string / []byte (formatted time-string), |
| // otherwise Scan fails. |
| func (nt *NullTime) Scan(value interface{}) (err error) { |
| if value == nil { |
| nt.Time, nt.Valid = time.Time{}, false |
| return |
| } |
| |
| switch v := value.(type) { |
| case time.Time: |
| nt.Time, nt.Valid = v, true |
| return |
| case []byte: |
| nt.Time, err = parseDateTime(string(v), time.UTC) |
| nt.Valid = (err == nil) |
| return |
| case string: |
| nt.Time, err = parseDateTime(v, time.UTC) |
| nt.Valid = (err == nil) |
| return |
| } |
| |
| nt.Valid = false |
| return fmt.Errorf("Can't convert %T to time.Time", value) |
| } |
| |
| // Value implements the driver Valuer interface. |
| func (nt NullTime) Value() (driver.Value, error) { |
| if !nt.Valid { |
| return nil, nil |
| } |
| return nt.Time, nil |
| } |
| |
// tlsConfigMap holds the custom TLS configurations added with
// RegisterTLSConfig, keyed by the name used in a DSN's tls= parameter.
// It is lazily allocated on first registration.
// NOTE(review): the map is not guarded by a mutex; register configs before
// opening connections rather than concurrently with them.
var tlsConfigMap map[string]*tls.Config

// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// Keys that the DSN parser interprets itself (such as "true", "1" or
// "skip-verify") never select a registered config.
//
//	rootCertPool := x509.NewCertPool()
//	{
//		pem, err := ioutil.ReadFile("/path/ca-cert.pem")
//		if err != nil {
//			log.Fatal(err)
//		}
//		if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
//			log.Fatal("Failed to append PEM.")
//		}
//	}
//	clientCert := make([]tls.Certificate, 0, 1)
//	{
//		certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
//		if err != nil {
//			log.Fatal(err)
//		}
//		clientCert = append(clientCert, certs)
//	}
//	mysql.RegisterTLSConfig("custom", &tls.Config{
//		RootCAs:      rootCertPool,
//		Certificates: clientCert,
//	})
//	db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
func RegisterTLSConfig(key string, config *tls.Config) {
	if tlsConfigMap == nil {
		tlsConfigMap = make(map[string]*tls.Config)
	}
	tlsConfigMap[key] = config
}

// DeregisterTLSConfig removes the tls.Config associated with key.
// Removing a key that was never registered is a no-op.
func DeregisterTLSConfig(key string) {
	// delete on a nil map is a no-op, so no allocation check is needed.
	delete(tlsConfigMap, key)
}
| |
// errLog writes driver log output to stderr, prefixed with "[MySQL] " and
// stamped with date, time and file:line of the call site.
var errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

// dsnPattern is the Data Source Name grammar used by parseDSN. Its named
// groups user, passwd, net, addr, dbname and params capture the DSN parts.
// It is compiled once, at package initialization, as a package-level var
// rather than inside init().
var dsnPattern = regexp.MustCompile(
	`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
		`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
		`\/(?P<dbname>.*?)` + // /dbname
		`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
| |
| func parseDSN(dsn string) (cfg *config, err error) { |
| cfg = new(config) |
| cfg.params = make(map[string]string) |
| |
| matches := dsnPattern.FindStringSubmatch(dsn) |
| names := dsnPattern.SubexpNames() |
| |
| for i, match := range matches { |
| switch names[i] { |
| case "user": |
| cfg.user = match |
| case "passwd": |
| cfg.passwd = match |
| case "net": |
| cfg.net = match |
| case "addr": |
| cfg.addr = match |
| case "dbname": |
| cfg.dbname = match |
| case "params": |
| for _, v := range strings.Split(match, "&") { |
| param := strings.SplitN(v, "=", 2) |
| if len(param) != 2 { |
| continue |
| } |
| |
| // cfg params |
| switch value := param[1]; param[0] { |
| |
| // Disable INFILE whitelist / enable all files |
| case "allowAllFiles": |
| cfg.allowAllFiles = readBool(value) |
| |
| // Switch "rowsAffected" mode |
| case "clientFoundRows": |
| cfg.clientFoundRows = readBool(value) |
| |
| // Time Location |
| case "loc": |
| cfg.loc, err = time.LoadLocation(value) |
| if err != nil { |
| return |
| } |
| |
| // Dial Timeout |
| case "timeout": |
| cfg.timeout, err = time.ParseDuration(value) |
| if err != nil { |
| return |
| } |
| |
| // TLS-Encryption |
| case "tls": |
| if readBool(value) { |
| cfg.tls = &tls.Config{} |
| } else if strings.ToLower(value) == "skip-verify" { |
| cfg.tls = &tls.Config{InsecureSkipVerify: true} |
| // TODO: Check for Boolean false |
| } else if tlsConfig, ok := tlsConfigMap[value]; ok { |
| cfg.tls = tlsConfig |
| } |
| |
| default: |
| cfg.params[param[0]] = value |
| } |
| } |
| } |
| } |
| |
| // Set default network if empty |
| if cfg.net == "" { |
| cfg.net = "tcp" |
| } |
| |
| // Set default adress if empty |
| if cfg.addr == "" { |
| cfg.addr = "127.0.0.1:3306" |
| } |
| |
| // Set default location if not set |
| if cfg.loc == nil { |
| cfg.loc = time.UTC |
| } |
| |
| return |
| } |
| |
// scramblePassword computes the MySQL 4.1+ ("mysql_native_password")
// authentication token for password, salted with the server's scramble:
//
//	token = SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password yields a nil token.
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
func scramblePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	// digest returns SHA1 over the concatenation of parts, reusing one hasher.
	h := sha1.New()
	digest := func(parts ...[]byte) []byte {
		h.Reset()
		for _, p := range parts {
			h.Write(p)
		}
		return h.Sum(nil)
	}

	stage1 := digest(password)
	// The inner digest(stage1) is evaluated before the outer call resets h.
	token := digest(scramble, digest(stage1))

	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
| |
| func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { |
| switch len(str) { |
| case 10: // YYYY-MM-DD |
| if str == "0000-00-00" { |
| return |
| } |
| t, err = time.Parse(timeFormat[:10], str) |
| case 19: // YYYY-MM-DD HH:MM:SS |
| if str == "0000-00-00 00:00:00" { |
| return |
| } |
| t, err = time.Parse(timeFormat, str) |
| default: |
| err = fmt.Errorf("Invalid Time-String: %s", str) |
| return |
| } |
| |
| // Adjust location |
| if err == nil && loc != time.UTC { |
| y, mo, d := t.Date() |
| h, mi, s := t.Clock() |
| t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil |
| } |
| |
| return |
| } |
| |
// parseBinaryDateTime decodes a binary-protocol DATE/DATETIME/TIMESTAMP
// payload into a time.Time in loc.
//
// num is the payload length and data holds the payload bytes:
// year (uint16 little-endian), month, day, then for num >= 7 hour, minute,
// second, and for num == 11 microseconds (uint32 little-endian).
// A length of 0 yields the zero time.Time. Any other length than 0, 4, 7
// or 11, or a data slice shorter than num, yields an error.
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
	switch num {
	case 0:
		return time.Time{}, nil
	case 4, 7, 11:
		// decoded below
	default:
		return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
	}
	// Guard against truncated packets, which would otherwise panic with an
	// index out of range.
	if uint64(len(data)) < num {
		return nil, fmt.Errorf("Invalid DATETIME-packet: length %d exceeds %d available bytes", num, len(data))
	}

	var hour, minute, second, nsec int
	if num >= 7 {
		hour, minute, second = int(data[4]), int(data[5]), int(data[6])
	}
	if num == 11 {
		nsec = int(binary.LittleEndian.Uint32(data[7:11])) * 1000 // µs -> ns
	}

	return time.Date(
		int(binary.LittleEndian.Uint16(data[:2])), // year
		time.Month(data[2]),                       // month
		int(data[3]),                              // day
		hour, minute, second, nsec,
		loc,
	), nil
}
| |
// formatBinaryDate renders a binary-protocol DATE payload as the text
// "YYYY-MM-DD".
//
// num is the payload length: 0 yields "0000-00-00"; 4 decodes year
// (uint16 little-endian), month and day from data. Any other length, or a
// data slice shorter than num, yields an error.
func formatBinaryDate(num uint64, data []byte) (driver.Value, error) {
	switch num {
	case 0:
		return []byte("0000-00-00"), nil
	case 4:
		// Guard against truncated packets, which would otherwise panic.
		if uint64(len(data)) < num {
			return nil, fmt.Errorf("Invalid DATE-packet: length %d exceeds %d available bytes", num, len(data))
		}
		return []byte(fmt.Sprintf(
			"%04d-%02d-%02d",
			binary.LittleEndian.Uint16(data[:2]),
			data[2],
			data[3],
		)), nil
	}
	return nil, fmt.Errorf("Invalid DATE-packet length %d", num)
}
| |
// formatBinaryDateTime renders a binary-protocol DATETIME/TIMESTAMP payload
// as text of the form "YYYY-MM-DD HH:MM:SS[.ffffff]".
//
// num is the payload length: 0 yields "0000-00-00 00:00:00"; 4 carries only
// the date (time rendered as 00:00:00); 7 adds hour, minute and second; 11
// additionally carries microseconds (uint32 little-endian). Any other length,
// or a data slice shorter than num, yields an error.
func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
	switch num {
	case 0:
		return []byte("0000-00-00 00:00:00"), nil
	case 4, 7, 11:
		// decoded below
	default:
		return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
	}
	// Guard against truncated packets, which would otherwise panic.
	if uint64(len(data)) < num {
		return nil, fmt.Errorf("Invalid DATETIME-packet: length %d exceeds %d available bytes", num, len(data))
	}

	// A 4-byte payload has no time of day, so hour/minute/second stay 0.
	var hh, mm, ss byte
	if num >= 7 {
		hh, mm, ss = data[4], data[5], data[6]
	}
	s := fmt.Sprintf(
		"%04d-%02d-%02d %02d:%02d:%02d",
		binary.LittleEndian.Uint16(data[:2]),
		data[2],
		data[3],
		hh, mm, ss,
	)
	if num == 11 {
		s += fmt.Sprintf(".%06d", binary.LittleEndian.Uint32(data[7:11]))
	}
	return []byte(s), nil
}
| |
// readBool reports whether value spells a true boolean: "1" or "true" in
// any letter case. Everything else, including "", is false.
func readBool(value string) bool {
	lower := strings.ToLower(value)
	return lower == "true" || lower == "1"
}
| |
| /****************************************************************************** |
| * Convert from and to bytes * |
| ******************************************************************************/ |
| |
// uint64ToBytes encodes n as 8 bytes in little-endian order, least
// significant byte first.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}
| |
// uint64ToString returns the decimal ASCII representation of n as a byte
// slice. Digits are written right-to-left into a 20-byte buffer, which is
// enough for the largest uint64 (18446744073709551615).
func uint64ToString(n uint64) []byte {
	var buf [20]byte
	pos := len(buf)
	for {
		pos--
		buf[pos] = byte('0' + n%10)
		n /= 10
		if n == 0 {
			break
		}
	}
	return buf[pos:]
}
| |
// stringToInt interprets b as the decimal digits of a non-negative integer
// and returns its value. The input is not validated: non-digit bytes
// produce a meaningless result, and an empty slice yields 0.
func stringToInt(b []byte) int {
	n := 0
	for _, digit := range b {
		n = n*10 + int(digit-'0')
	}
	return n
}
| |
| func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { |
| // Get length |
| num, isNull, n := readLengthEncodedInteger(b) |
| if num < 1 { |
| return nil, isNull, n, nil |
| } |
| |
| n += int(num) |
| |
| // Check data length |
| if len(b) >= n { |
| return b[n-int(num) : n], false, n, nil |
| } |
| return nil, false, n, io.EOF |
| } |
| |
| func skipLengthEnodedString(b []byte) (int, error) { |
| // Get length |
| num, _, n := readLengthEncodedInteger(b) |
| if num < 1 { |
| return n, nil |
| } |
| |
| n += int(num) |
| |
| // Check data length |
| if len(b) >= n { |
| return n, nil |
| } |
| return n, io.EOF |
| } |
| |
// readLengthEncodedInteger decodes a MySQL length-encoded integer from the
// start of b and returns its value, whether it was the NULL marker, and the
// number of bytes consumed:
//
//	0x00-0xfa: the byte itself (1 byte)
//	0xfb:      NULL (1 byte)
//	0xfc:      following 2 bytes, little-endian (3 bytes)
//	0xfd:      following 3 bytes, little-endian (4 bytes)
//	0xfe:      following 8 bytes, little-endian (9 bytes)
//
// b must contain the complete encoding; a short slice panics.
func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
	switch b[0] {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
	case 0xfe:
		// Decode all 8 bytes in place. The previous hand-written form shifted
		// the top byte by 54 instead of 56, corrupting values >= 2^48.
		return binary.LittleEndian.Uint64(b[1:9]), false, 9
	}
	return uint64(b[0]), false, 1
}
| |
// lengthEncodedIntegerToBytes encodes n as a MySQL length-encoded integer:
// a single byte for values up to 250, otherwise a 0xfc/0xfd/0xfe prefix
// followed by 2, 3 or 8 little-endian bytes.
func lengthEncodedIntegerToBytes(n uint64) []byte {
	switch {
	case n <= 250:
		return []byte{byte(n)}

	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}

	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	}
	// Values above 0xffffff use the 8-byte form. This case previously
	// returned nil, which produced a malformed packet.
	return []byte{
		0xfe,
		byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
		byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56),
	}
}