You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

connection.go 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "context"
  11. "database/sql"
  12. "database/sql/driver"
  13. "io"
  14. "net"
  15. "strconv"
  16. "strings"
  17. "time"
  18. )
  19. type mysqlConn struct {
  20. buf buffer
  21. netConn net.Conn
  22. rawConn net.Conn // underlying connection when netConn is TLS connection.
  23. affectedRows uint64
  24. insertId uint64
  25. cfg *Config
  26. maxAllowedPacket int
  27. maxWriteSize int
  28. writeTimeout time.Duration
  29. flags clientFlag
  30. status statusFlag
  31. sequence uint8
  32. parseTime bool
  33. reset bool // set when the Go SQL package calls ResetSession
  34. // for context support (Go 1.8+)
  35. watching bool
  36. watcher chan<- context.Context
  37. closech chan struct{}
  38. finished chan<- struct{}
  39. canceled atomicError // set non-nil if conn is canceled
  40. closed atomicBool // set when conn is closed, before closech is closed
  41. }
  42. // Handles parameters set in DSN after the connection is established
  43. func (mc *mysqlConn) handleParams() (err error) {
  44. for param, val := range mc.cfg.Params {
  45. switch param {
  46. // Charset
  47. case "charset":
  48. charsets := strings.Split(val, ",")
  49. for i := range charsets {
  50. // ignore errors here - a charset may not exist
  51. err = mc.exec("SET NAMES " + charsets[i])
  52. if err == nil {
  53. break
  54. }
  55. }
  56. if err != nil {
  57. return
  58. }
  59. // System Vars
  60. default:
  61. err = mc.exec("SET " + param + "=" + val + "")
  62. if err != nil {
  63. return
  64. }
  65. }
  66. }
  67. return
  68. }
  69. func (mc *mysqlConn) markBadConn(err error) error {
  70. if mc == nil {
  71. return err
  72. }
  73. if err != errBadConnNoWrite {
  74. return err
  75. }
  76. return driver.ErrBadConn
  77. }
  78. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  79. return mc.begin(false)
  80. }
  81. func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
  82. if mc.closed.IsSet() {
  83. errLog.Print(ErrInvalidConn)
  84. return nil, driver.ErrBadConn
  85. }
  86. var q string
  87. if readOnly {
  88. q = "START TRANSACTION READ ONLY"
  89. } else {
  90. q = "START TRANSACTION"
  91. }
  92. err := mc.exec(q)
  93. if err == nil {
  94. return &mysqlTx{mc}, err
  95. }
  96. return nil, mc.markBadConn(err)
  97. }
  98. func (mc *mysqlConn) Close() (err error) {
  99. // Makes Close idempotent
  100. if !mc.closed.IsSet() {
  101. err = mc.writeCommandPacket(comQuit)
  102. }
  103. mc.cleanup()
  104. return
  105. }
  106. // Closes the network connection and unsets internal variables. Do not call this
  107. // function after successfully authentication, call Close instead. This function
  108. // is called before auth or on auth failure because MySQL will have already
  109. // closed the network connection.
  110. func (mc *mysqlConn) cleanup() {
  111. if !mc.closed.TrySet(true) {
  112. return
  113. }
  114. // Makes cleanup idempotent
  115. close(mc.closech)
  116. if mc.netConn == nil {
  117. return
  118. }
  119. if err := mc.netConn.Close(); err != nil {
  120. errLog.Print(err)
  121. }
  122. }
  123. func (mc *mysqlConn) error() error {
  124. if mc.closed.IsSet() {
  125. if err := mc.canceled.Value(); err != nil {
  126. return err
  127. }
  128. return ErrInvalidConn
  129. }
  130. return nil
  131. }
  132. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  133. if mc.closed.IsSet() {
  134. errLog.Print(ErrInvalidConn)
  135. return nil, driver.ErrBadConn
  136. }
  137. // Send command
  138. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  139. if err != nil {
  140. // STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
  141. errLog.Print(err)
  142. return nil, driver.ErrBadConn
  143. }
  144. stmt := &mysqlStmt{
  145. mc: mc,
  146. }
  147. // Read Result
  148. columnCount, err := stmt.readPrepareResultPacket()
  149. if err == nil {
  150. if stmt.paramCount > 0 {
  151. if err = mc.readUntilEOF(); err != nil {
  152. return nil, err
  153. }
  154. }
  155. if columnCount > 0 {
  156. err = mc.readUntilEOF()
  157. }
  158. }
  159. return stmt, err
  160. }
  161. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  162. // Number of ? should be same to len(args)
  163. if strings.Count(query, "?") != len(args) {
  164. return "", driver.ErrSkip
  165. }
  166. buf, err := mc.buf.takeCompleteBuffer()
  167. if err != nil {
  168. // can not take the buffer. Something must be wrong with the connection
  169. errLog.Print(err)
  170. return "", ErrInvalidConn
  171. }
  172. buf = buf[:0]
  173. argPos := 0
  174. for i := 0; i < len(query); i++ {
  175. q := strings.IndexByte(query[i:], '?')
  176. if q == -1 {
  177. buf = append(buf, query[i:]...)
  178. break
  179. }
  180. buf = append(buf, query[i:i+q]...)
  181. i += q
  182. arg := args[argPos]
  183. argPos++
  184. if arg == nil {
  185. buf = append(buf, "NULL"...)
  186. continue
  187. }
  188. switch v := arg.(type) {
  189. case int64:
  190. buf = strconv.AppendInt(buf, v, 10)
  191. case uint64:
  192. // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
  193. buf = strconv.AppendUint(buf, v, 10)
  194. case float64:
  195. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  196. case bool:
  197. if v {
  198. buf = append(buf, '1')
  199. } else {
  200. buf = append(buf, '0')
  201. }
  202. case time.Time:
  203. if v.IsZero() {
  204. buf = append(buf, "'0000-00-00'"...)
  205. } else {
  206. v := v.In(mc.cfg.Loc)
  207. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  208. year := v.Year()
  209. year100 := year / 100
  210. year1 := year % 100
  211. month := v.Month()
  212. day := v.Day()
  213. hour := v.Hour()
  214. minute := v.Minute()
  215. second := v.Second()
  216. micro := v.Nanosecond() / 1000
  217. buf = append(buf, []byte{
  218. '\'',
  219. digits10[year100], digits01[year100],
  220. digits10[year1], digits01[year1],
  221. '-',
  222. digits10[month], digits01[month],
  223. '-',
  224. digits10[day], digits01[day],
  225. ' ',
  226. digits10[hour], digits01[hour],
  227. ':',
  228. digits10[minute], digits01[minute],
  229. ':',
  230. digits10[second], digits01[second],
  231. }...)
  232. if micro != 0 {
  233. micro10000 := micro / 10000
  234. micro100 := micro / 100 % 100
  235. micro1 := micro % 100
  236. buf = append(buf, []byte{
  237. '.',
  238. digits10[micro10000], digits01[micro10000],
  239. digits10[micro100], digits01[micro100],
  240. digits10[micro1], digits01[micro1],
  241. }...)
  242. }
  243. buf = append(buf, '\'')
  244. }
  245. case []byte:
  246. if v == nil {
  247. buf = append(buf, "NULL"...)
  248. } else {
  249. buf = append(buf, "_binary'"...)
  250. if mc.status&statusNoBackslashEscapes == 0 {
  251. buf = escapeBytesBackslash(buf, v)
  252. } else {
  253. buf = escapeBytesQuotes(buf, v)
  254. }
  255. buf = append(buf, '\'')
  256. }
  257. case string:
  258. buf = append(buf, '\'')
  259. if mc.status&statusNoBackslashEscapes == 0 {
  260. buf = escapeStringBackslash(buf, v)
  261. } else {
  262. buf = escapeStringQuotes(buf, v)
  263. }
  264. buf = append(buf, '\'')
  265. default:
  266. return "", driver.ErrSkip
  267. }
  268. if len(buf)+4 > mc.maxAllowedPacket {
  269. return "", driver.ErrSkip
  270. }
  271. }
  272. if argPos != len(args) {
  273. return "", driver.ErrSkip
  274. }
  275. return string(buf), nil
  276. }
  277. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  278. if mc.closed.IsSet() {
  279. errLog.Print(ErrInvalidConn)
  280. return nil, driver.ErrBadConn
  281. }
  282. if len(args) != 0 {
  283. if !mc.cfg.InterpolateParams {
  284. return nil, driver.ErrSkip
  285. }
  286. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  287. prepared, err := mc.interpolateParams(query, args)
  288. if err != nil {
  289. return nil, err
  290. }
  291. query = prepared
  292. }
  293. mc.affectedRows = 0
  294. mc.insertId = 0
  295. err := mc.exec(query)
  296. if err == nil {
  297. return &mysqlResult{
  298. affectedRows: int64(mc.affectedRows),
  299. insertId: int64(mc.insertId),
  300. }, err
  301. }
  302. return nil, mc.markBadConn(err)
  303. }
  304. // Internal function to execute commands
  305. func (mc *mysqlConn) exec(query string) error {
  306. // Send command
  307. if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
  308. return mc.markBadConn(err)
  309. }
  310. // Read Result
  311. resLen, err := mc.readResultSetHeaderPacket()
  312. if err != nil {
  313. return err
  314. }
  315. if resLen > 0 {
  316. // columns
  317. if err := mc.readUntilEOF(); err != nil {
  318. return err
  319. }
  320. // rows
  321. if err := mc.readUntilEOF(); err != nil {
  322. return err
  323. }
  324. }
  325. return mc.discardResults()
  326. }
  327. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  328. return mc.query(query, args)
  329. }
  330. func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
  331. if mc.closed.IsSet() {
  332. errLog.Print(ErrInvalidConn)
  333. return nil, driver.ErrBadConn
  334. }
  335. if len(args) != 0 {
  336. if !mc.cfg.InterpolateParams {
  337. return nil, driver.ErrSkip
  338. }
  339. // try client-side prepare to reduce roundtrip
  340. prepared, err := mc.interpolateParams(query, args)
  341. if err != nil {
  342. return nil, err
  343. }
  344. query = prepared
  345. }
  346. // Send command
  347. err := mc.writeCommandPacketStr(comQuery, query)
  348. if err == nil {
  349. // Read Result
  350. var resLen int
  351. resLen, err = mc.readResultSetHeaderPacket()
  352. if err == nil {
  353. rows := new(textRows)
  354. rows.mc = mc
  355. if resLen == 0 {
  356. rows.rs.done = true
  357. switch err := rows.NextResultSet(); err {
  358. case nil, io.EOF:
  359. return rows, nil
  360. default:
  361. return nil, err
  362. }
  363. }
  364. // Columns
  365. rows.rs.columns, err = mc.readColumns(resLen)
  366. return rows, err
  367. }
  368. }
  369. return nil, mc.markBadConn(err)
  370. }
  371. // Gets the value of the given MySQL System Variable
  372. // The returned byte slice is only valid until the next read
  373. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  374. // Send command
  375. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  376. return nil, err
  377. }
  378. // Read Result
  379. resLen, err := mc.readResultSetHeaderPacket()
  380. if err == nil {
  381. rows := new(textRows)
  382. rows.mc = mc
  383. rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  384. if resLen > 0 {
  385. // Columns
  386. if err := mc.readUntilEOF(); err != nil {
  387. return nil, err
  388. }
  389. }
  390. dest := make([]driver.Value, resLen)
  391. if err = rows.readRow(dest); err == nil {
  392. return dest[0].([]byte), mc.readUntilEOF()
  393. }
  394. }
  395. return nil, err
  396. }
  397. // finish is called when the query has canceled.
  398. func (mc *mysqlConn) cancel(err error) {
  399. mc.canceled.Set(err)
  400. mc.cleanup()
  401. }
  402. // finish is called when the query has succeeded.
  403. func (mc *mysqlConn) finish() {
  404. if !mc.watching || mc.finished == nil {
  405. return
  406. }
  407. select {
  408. case mc.finished <- struct{}{}:
  409. mc.watching = false
  410. case <-mc.closech:
  411. }
  412. }
  413. // Ping implements driver.Pinger interface
  414. func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
  415. if mc.closed.IsSet() {
  416. errLog.Print(ErrInvalidConn)
  417. return driver.ErrBadConn
  418. }
  419. if err = mc.watchCancel(ctx); err != nil {
  420. return
  421. }
  422. defer mc.finish()
  423. if err = mc.writeCommandPacket(comPing); err != nil {
  424. return mc.markBadConn(err)
  425. }
  426. return mc.readResultOK()
  427. }
  428. // BeginTx implements driver.ConnBeginTx interface
  429. func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  430. if err := mc.watchCancel(ctx); err != nil {
  431. return nil, err
  432. }
  433. defer mc.finish()
  434. if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
  435. level, err := mapIsolationLevel(opts.Isolation)
  436. if err != nil {
  437. return nil, err
  438. }
  439. err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
  440. if err != nil {
  441. return nil, err
  442. }
  443. }
  444. return mc.begin(opts.ReadOnly)
  445. }
  446. func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
  447. dargs, err := namedValueToValue(args)
  448. if err != nil {
  449. return nil, err
  450. }
  451. if err := mc.watchCancel(ctx); err != nil {
  452. return nil, err
  453. }
  454. rows, err := mc.query(query, dargs)
  455. if err != nil {
  456. mc.finish()
  457. return nil, err
  458. }
  459. rows.finish = mc.finish
  460. return rows, err
  461. }
  462. func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
  463. dargs, err := namedValueToValue(args)
  464. if err != nil {
  465. return nil, err
  466. }
  467. if err := mc.watchCancel(ctx); err != nil {
  468. return nil, err
  469. }
  470. defer mc.finish()
  471. return mc.Exec(query, dargs)
  472. }
  473. func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  474. if err := mc.watchCancel(ctx); err != nil {
  475. return nil, err
  476. }
  477. stmt, err := mc.Prepare(query)
  478. mc.finish()
  479. if err != nil {
  480. return nil, err
  481. }
  482. select {
  483. default:
  484. case <-ctx.Done():
  485. stmt.Close()
  486. return nil, ctx.Err()
  487. }
  488. return stmt, nil
  489. }
  490. func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  491. dargs, err := namedValueToValue(args)
  492. if err != nil {
  493. return nil, err
  494. }
  495. if err := stmt.mc.watchCancel(ctx); err != nil {
  496. return nil, err
  497. }
  498. rows, err := stmt.query(dargs)
  499. if err != nil {
  500. stmt.mc.finish()
  501. return nil, err
  502. }
  503. rows.finish = stmt.mc.finish
  504. return rows, err
  505. }
  506. func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  507. dargs, err := namedValueToValue(args)
  508. if err != nil {
  509. return nil, err
  510. }
  511. if err := stmt.mc.watchCancel(ctx); err != nil {
  512. return nil, err
  513. }
  514. defer stmt.mc.finish()
  515. return stmt.Exec(dargs)
  516. }
  517. func (mc *mysqlConn) watchCancel(ctx context.Context) error {
  518. if mc.watching {
  519. // Reach here if canceled,
  520. // so the connection is already invalid
  521. mc.cleanup()
  522. return nil
  523. }
  524. // When ctx is already cancelled, don't watch it.
  525. if err := ctx.Err(); err != nil {
  526. return err
  527. }
  528. // When ctx is not cancellable, don't watch it.
  529. if ctx.Done() == nil {
  530. return nil
  531. }
  532. // When watcher is not alive, can't watch it.
  533. if mc.watcher == nil {
  534. return nil
  535. }
  536. mc.watching = true
  537. mc.watcher <- ctx
  538. return nil
  539. }
  540. func (mc *mysqlConn) startWatcher() {
  541. watcher := make(chan context.Context, 1)
  542. mc.watcher = watcher
  543. finished := make(chan struct{})
  544. mc.finished = finished
  545. go func() {
  546. for {
  547. var ctx context.Context
  548. select {
  549. case ctx = <-watcher:
  550. case <-mc.closech:
  551. return
  552. }
  553. select {
  554. case <-ctx.Done():
  555. mc.cancel(ctx.Err())
  556. case <-finished:
  557. case <-mc.closech:
  558. return
  559. }
  560. }
  561. }()
  562. }
  563. func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
  564. nv.Value, err = converter{}.ConvertValue(nv.Value)
  565. return
  566. }
  567. // ResetSession implements driver.SessionResetter.
  568. // (From Go 1.10)
  569. func (mc *mysqlConn) ResetSession(ctx context.Context) error {
  570. if mc.closed.IsSet() {
  571. return driver.ErrBadConn
  572. }
  573. mc.reset = true
  574. return nil
  575. }