You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

history.go 23KB


  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "bytes"
  6. "context"
  7. "database/sql"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "runtime/debug"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. _ "github.com/go-sql-driver/mysql"
  17. "github.com/oragono/oragono/irc/history"
  18. "github.com/oragono/oragono/irc/logger"
  19. "github.com/oragono/oragono/irc/utils"
  20. )
  // ErrDisallowed is the sentinel error returned by DeleteMsgid when the
  // stored message belongs to a different account than the one requesting
  // the deletion.
  var ErrDisallowed = errors.New("disallowed")
  // Column sizing.
  const (
  	// MaxTargetLength is the maximum length, in bytes, of any message target
  	// (nickname or channel name) once it has been canonicalized (i.e.,
  	// casefolded). It sizes the VARBINARY columns created by this package.
  	MaxTargetLength = 64
  )

  // Schema versioning, stored as rows of the metadata table.
  const (
  	// latestDbSchema is the major schema version this code expects.
  	latestDbSchema = "2"
  	// keySchemaVersion is the metadata key holding the major version.
  	keySchemaVersion = "db.version"

  	// latestDbMinorVersion marks rollback-safe upgrades: a database at a
  	// newer minor version still works after downgrading oragono.
  	latestDbMinorVersion = "1"
  	// keySchemaMinorVersion is the metadata key holding the minor version.
  	keySchemaMinorVersion = "db.minorversion"
  )

  // Background maintenance tuning.
  const (
  	// cleanupRowLimit bounds how many rows a single maintenance query handles.
  	cleanupRowLimit = 50
  	// cleanupPauseTime is the idle pause between maintenance passes; it is
  	// also used as the timeout for maintenance queries.
  	cleanupPauseTime = 10 * time.Minute
  )
  // e is an empty marker value. It is the element type of the wake-up
  // channel, where only the arrival of a value matters, never its contents.
  type e struct{}
  39. type MySQL struct {
  40. timeout int64
  41. trackAccountMessages uint32
  42. db *sql.DB
  43. logger *logger.Manager
  44. insertHistory *sql.Stmt
  45. insertSequence *sql.Stmt
  46. insertConversation *sql.Stmt
  47. insertAccountMessage *sql.Stmt
  48. stateMutex sync.Mutex
  49. config Config
  50. wakeForgetter chan e
  51. }
  52. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  53. mysql.logger = logger
  54. mysql.wakeForgetter = make(chan e, 1)
  55. mysql.SetConfig(config)
  56. }
  57. func (mysql *MySQL) SetConfig(config Config) {
  58. atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
  59. var trackAccountMessages uint32
  60. if config.TrackAccountMessages {
  61. trackAccountMessages = 1
  62. }
  63. atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
  64. mysql.stateMutex.Lock()
  65. mysql.config = config
  66. mysql.stateMutex.Unlock()
  67. }
  68. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  69. mysql.stateMutex.Lock()
  70. expireTime = mysql.config.ExpireTime
  71. mysql.stateMutex.Unlock()
  72. return
  73. }
  74. func (m *MySQL) Open() (err error) {
  75. var address string
  76. if m.config.SocketPath != "" {
  77. address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
  78. } else if m.config.Port != 0 {
  79. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  80. }
  81. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  82. if err != nil {
  83. return err
  84. }
  85. err = m.fixSchemas()
  86. if err != nil {
  87. return err
  88. }
  89. err = m.prepareStatements()
  90. if err != nil {
  91. return err
  92. }
  93. go m.cleanupLoop()
  94. go m.forgetLoop()
  95. return nil
  96. }
  97. func (mysql *MySQL) fixSchemas() (err error) {
  98. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  99. key_name VARCHAR(32) primary key,
  100. value VARCHAR(32) NOT NULL
  101. ) CHARSET=ascii COLLATE=ascii_bin;`)
  102. if err != nil {
  103. return err
  104. }
  105. var schema string
  106. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  107. if err == sql.ErrNoRows {
  108. err = mysql.createTables()
  109. if err != nil {
  110. return
  111. }
  112. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  113. if err != nil {
  114. return
  115. }
  116. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  117. if err != nil {
  118. return
  119. }
  120. return
  121. } else if err == nil && schema != latestDbSchema {
  122. // TODO figure out what to do about schema changes
  123. return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
  124. } else if err != nil {
  125. return err
  126. }
  127. var minorVersion string
  128. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
  129. if err == sql.ErrNoRows {
  130. // XXX for now, the only minor version upgrade is the account tracking tables
  131. err = mysql.createComplianceTables()
  132. if err != nil {
  133. return
  134. }
  135. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  136. if err != nil {
  137. return
  138. }
  139. } else if err == nil && minorVersion != latestDbMinorVersion {
  140. // TODO: if minorVersion < latestDbMinorVersion, upgrade,
  141. // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
  142. }
  143. return
  144. }
  145. func (mysql *MySQL) createTables() (err error) {
  146. _, err = mysql.db.Exec(`CREATE TABLE history (
  147. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  148. data BLOB NOT NULL,
  149. msgid BINARY(16) NOT NULL,
  150. KEY (msgid(4))
  151. ) CHARSET=ascii COLLATE=ascii_bin;`)
  152. if err != nil {
  153. return err
  154. }
  155. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  156. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  157. target VARBINARY(%[1]d) NOT NULL,
  158. nanotime BIGINT UNSIGNED NOT NULL,
  159. history_id BIGINT NOT NULL,
  160. KEY (target, nanotime),
  161. KEY (history_id)
  162. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  163. if err != nil {
  164. return err
  165. }
  166. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  167. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  168. target VARBINARY(%[1]d) NOT NULL,
  169. correspondent VARBINARY(%[1]d) NOT NULL,
  170. nanotime BIGINT UNSIGNED NOT NULL,
  171. history_id BIGINT NOT NULL,
  172. KEY (target, correspondent, nanotime),
  173. KEY (history_id)
  174. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  175. if err != nil {
  176. return err
  177. }
  178. err = mysql.createComplianceTables()
  179. if err != nil {
  180. return err
  181. }
  182. return nil
  183. }
  184. func (mysql *MySQL) createComplianceTables() (err error) {
  185. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
  186. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  187. account VARBINARY(%[1]d) NOT NULL,
  188. KEY (account, history_id)
  189. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  190. if err != nil {
  191. return err
  192. }
  193. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
  194. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  195. account VARBINARY(%[1]d) NOT NULL
  196. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  197. if err != nil {
  198. return err
  199. }
  200. return nil
  201. }
  202. func (mysql *MySQL) cleanupLoop() {
  203. defer func() {
  204. if r := recover(); r != nil {
  205. mysql.logger.Error("mysql",
  206. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  207. time.Sleep(cleanupPauseTime)
  208. go mysql.cleanupLoop()
  209. }
  210. }()
  211. for {
  212. expireTime := mysql.getExpireTime()
  213. if expireTime != 0 {
  214. for {
  215. startTime := time.Now()
  216. rowsDeleted, err := mysql.doCleanup(expireTime)
  217. elapsed := time.Now().Sub(startTime)
  218. mysql.logError("error during row cleanup", err)
  219. // keep going as long as we're accomplishing significant work
  220. // (don't busy-wait on small numbers of rows expiring):
  221. if rowsDeleted < (cleanupRowLimit / 10) {
  222. break
  223. }
  224. // crude backpressure mechanism: if the database is slow,
  225. // give it time to process other queries
  226. time.Sleep(elapsed)
  227. }
  228. }
  229. time.Sleep(cleanupPauseTime)
  230. }
  231. }
  232. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  233. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  234. defer cancel()
  235. ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
  236. if len(ids) == 0 {
  237. mysql.logger.Debug("mysql", "found no rows to clean up")
  238. return
  239. }
  240. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  241. return len(ids), mysql.deleteHistoryIDs(ctx, ids)
  242. }
  243. func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
  244. // can't use ? binding for a variable number of arguments, build the IN clause manually
  245. var inBuf bytes.Buffer
  246. inBuf.WriteByte('(')
  247. for i, id := range ids {
  248. if i != 0 {
  249. inBuf.WriteRune(',')
  250. }
  251. fmt.Fprintf(&inBuf, "%d", id)
  252. }
  253. inBuf.WriteRune(')')
  254. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
  255. if err != nil {
  256. return
  257. }
  258. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
  259. if err != nil {
  260. return
  261. }
  262. if mysql.isTrackingAccountMessages() {
  263. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes()))
  264. if err != nil {
  265. return
  266. }
  267. }
  268. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
  269. if err != nil {
  270. return
  271. }
  272. return
  273. }
  274. func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  275. rows, err := mysql.db.QueryContext(ctx, `
  276. SELECT history.id, sequence.nanotime
  277. FROM history
  278. LEFT JOIN sequence ON history.id = sequence.history_id
  279. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  280. if err != nil {
  281. return
  282. }
  283. defer rows.Close()
  284. // a history ID may have 0-2 rows in sequence: 1 for a channel entry,
  285. // 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
  286. // and delete anything that doesn't have a sequence entry:
  287. idset := make(map[uint64]struct{}, cleanupRowLimit)
  288. threshold := time.Now().Add(-age).UnixNano()
  289. for rows.Next() {
  290. var id uint64
  291. var nanotime sql.NullInt64
  292. err = rows.Scan(&id, &nanotime)
  293. if err != nil {
  294. return
  295. }
  296. if !nanotime.Valid || nanotime.Int64 < threshold {
  297. idset[id] = struct{}{}
  298. if nanotime.Valid && nanotime.Int64 > maxNanotime {
  299. maxNanotime = nanotime.Int64
  300. }
  301. }
  302. }
  303. ids = make([]uint64, len(idset))
  304. i := 0
  305. for id := range idset {
  306. ids[i] = id
  307. i++
  308. }
  309. return
  310. }
  311. // wait for forget queue items and process them one by one
  312. func (mysql *MySQL) forgetLoop() {
  313. defer func() {
  314. if r := recover(); r != nil {
  315. mysql.logger.Error("mysql",
  316. fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
  317. time.Sleep(cleanupPauseTime)
  318. go mysql.forgetLoop()
  319. }
  320. }()
  321. for {
  322. for {
  323. found, err := mysql.doForget()
  324. mysql.logError("error processing forget", err)
  325. if err != nil {
  326. time.Sleep(cleanupPauseTime)
  327. }
  328. if !found {
  329. break
  330. }
  331. }
  332. <-mysql.wakeForgetter
  333. }
  334. }
  335. // dequeue an item from the forget queue and process it
  336. func (mysql *MySQL) doForget() (found bool, err error) {
  337. id, account, err := func() (id int64, account string, err error) {
  338. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  339. defer cancel()
  340. row := mysql.db.QueryRowContext(ctx,
  341. `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
  342. err = row.Scan(&id, &account)
  343. if err == sql.ErrNoRows {
  344. return 0, "", nil
  345. }
  346. return
  347. }()
  348. if err != nil || account == "" {
  349. return false, err
  350. }
  351. found = true
  352. var count int
  353. for {
  354. start := time.Now()
  355. count, err = mysql.doForgetIteration(account)
  356. elapsed := time.Since(start)
  357. if err != nil {
  358. return true, err
  359. }
  360. if count == 0 {
  361. break
  362. }
  363. time.Sleep(elapsed)
  364. }
  365. mysql.logger.Debug("mysql", "forget complete for account", account)
  366. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  367. defer cancel()
  368. _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
  369. return
  370. }
  371. func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
  372. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  373. defer cancel()
  374. rows, err := mysql.db.QueryContext(ctx, `
  375. SELECT account_messages.history_id
  376. FROM account_messages
  377. WHERE account_messages.account = ?
  378. LIMIT ?;`, account, cleanupRowLimit)
  379. if err != nil {
  380. return
  381. }
  382. defer rows.Close()
  383. var ids []uint64
  384. for rows.Next() {
  385. var id uint64
  386. err = rows.Scan(&id)
  387. if err != nil {
  388. return
  389. }
  390. ids = append(ids, id)
  391. }
  392. if len(ids) == 0 {
  393. return
  394. }
  395. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
  396. err = mysql.deleteHistoryIDs(ctx, ids)
  397. return len(ids), err
  398. }
  399. func (mysql *MySQL) prepareStatements() (err error) {
  400. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  401. (data, msgid) VALUES (?, ?);`)
  402. if err != nil {
  403. return
  404. }
  405. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  406. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  407. if err != nil {
  408. return
  409. }
  410. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  411. (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  412. if err != nil {
  413. return
  414. }
  415. mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
  416. (history_id, account) VALUES (?, ?);`)
  417. if err != nil {
  418. return
  419. }
  420. return
  421. }
  422. func (mysql *MySQL) getTimeout() time.Duration {
  423. return time.Duration(atomic.LoadInt64(&mysql.timeout))
  424. }
  425. func (mysql *MySQL) isTrackingAccountMessages() bool {
  426. return atomic.LoadUint32(&mysql.trackAccountMessages) != 0
  427. }
  428. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  429. if err != nil {
  430. mysql.logger.Error("mysql", context, err.Error())
  431. return true
  432. }
  433. return false
  434. }
  435. func (mysql *MySQL) Forget(account string) {
  436. if mysql.db == nil || account == "" {
  437. return
  438. }
  439. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  440. defer cancel()
  441. _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
  442. if mysql.logError("can't insert into forget table", err) {
  443. return
  444. }
  445. // wake up the forget goroutine if it's blocked:
  446. select {
  447. case mysql.wakeForgetter <- e{}:
  448. default:
  449. }
  450. }
  451. func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
  452. if mysql.db == nil {
  453. return
  454. }
  455. if target == "" {
  456. return utils.ErrInvalidParams
  457. }
  458. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  459. defer cancel()
  460. id, err := mysql.insertBase(ctx, item)
  461. if err != nil {
  462. return
  463. }
  464. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
  465. if err != nil {
  466. return
  467. }
  468. err = mysql.insertAccountMessageEntry(ctx, id, account)
  469. if err != nil {
  470. return
  471. }
  472. return
  473. }
  474. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
  475. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
  476. mysql.logError("could not insert sequence entry", err)
  477. return
  478. }
  479. func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
  480. _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
  481. mysql.logError("could not insert conversations entry", err)
  482. return
  483. }
  484. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  485. value, err := marshalItem(&item)
  486. if mysql.logError("could not marshal item", err) {
  487. return
  488. }
  489. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  490. if mysql.logError("could not decode msgid", err) {
  491. return
  492. }
  493. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  494. if mysql.logError("could not insert item", err) {
  495. return
  496. }
  497. id, err = result.LastInsertId()
  498. if mysql.logError("could not insert item", err) {
  499. return
  500. }
  501. return
  502. }
  503. func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
  504. if account == "" || !mysql.isTrackingAccountMessages() {
  505. return
  506. }
  507. _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
  508. mysql.logError("could not insert account-message entry", err)
  509. return
  510. }
  511. func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
  512. if mysql.db == nil {
  513. return
  514. }
  515. if senderAccount == "" && recipientAccount == "" {
  516. return
  517. }
  518. if sender == "" || recipient == "" {
  519. return utils.ErrInvalidParams
  520. }
  521. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  522. defer cancel()
  523. id, err := mysql.insertBase(ctx, item)
  524. if err != nil {
  525. return
  526. }
  527. nanotime := item.Message.Time.UnixNano()
  528. if senderAccount != "" {
  529. err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id)
  530. if err != nil {
  531. return
  532. }
  533. err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
  534. if err != nil {
  535. return
  536. }
  537. }
  538. if recipientAccount != "" && sender != recipient {
  539. err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id)
  540. if err != nil {
  541. return
  542. }
  543. err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
  544. if err != nil {
  545. return
  546. }
  547. }
  548. err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
  549. if err != nil {
  550. return
  551. }
  552. return
  553. }
  554. // note that accountName is the unfolded name
  555. func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
  556. if mysql.db == nil {
  557. return nil
  558. }
  559. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  560. defer cancel()
  561. _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
  562. if err != nil {
  563. return
  564. }
  565. if accountName != "*" {
  566. var item history.Item
  567. err = unmarshalItem(data, &item)
  568. // delete if the entry is corrupt
  569. if err == nil && item.AccountName != accountName {
  570. return ErrDisallowed
  571. }
  572. }
  573. err = mysql.deleteHistoryIDs(ctx, []uint64{id})
  574. mysql.logError("couldn't delete msgid", err)
  575. return
  576. }
  577. func (mysql *MySQL) Export(account string, writer io.Writer) {
  578. if mysql.db == nil {
  579. return
  580. }
  581. var err error
  582. var lastSeen uint64
  583. for {
  584. rows := func() (count int) {
  585. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  586. defer cancel()
  587. rows, rowsErr := mysql.db.QueryContext(ctx, `
  588. SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
  589. INNER JOIN history ON history.id = account_messages.history_id
  590. INNER JOIN sequence ON account_messages.history_id = sequence.history_id
  591. WHERE account_messages.account = ? AND account_messages.history_id > ?
  592. LIMIT ?`, account, lastSeen, cleanupRowLimit)
  593. if rowsErr != nil {
  594. err = rowsErr
  595. return
  596. }
  597. defer rows.Close()
  598. for rows.Next() {
  599. var id uint64
  600. var blob, jsonBlob []byte
  601. var target string
  602. var item history.Item
  603. err = rows.Scan(&id, &blob, &target)
  604. if err != nil {
  605. return
  606. }
  607. err = unmarshalItem(blob, &item)
  608. if err != nil {
  609. return
  610. }
  611. item.CfCorrespondent = target
  612. jsonBlob, err = json.Marshal(item)
  613. if err != nil {
  614. return
  615. }
  616. count++
  617. if lastSeen < id {
  618. lastSeen = id
  619. }
  620. writer.Write(jsonBlob)
  621. writer.Write([]byte{'\n'})
  622. }
  623. return
  624. }()
  625. if rows == 0 || err != nil {
  626. break
  627. }
  628. }
  629. mysql.logError("could not export history", err)
  630. return
  631. }
  632. func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
  633. // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
  634. // sequence.nanotime > (
  635. // SELECT sequence.nanotime FROM sequence, history
  636. // WHERE sequence.history_id = history.id AND history.msgid = ?
  637. // LIMIT 1)
  638. // however, this doesn't handle the BETWEEN case with one or two msgids, where we
  639. // don't initially know whether the interval is going forwards or backwards. to simplify
  640. // the logic, resolve msgids to timestamps "manually" in all cases, using a separate query.
  641. decoded, err := decodeMsgid(msgid)
  642. if err != nil {
  643. return
  644. }
  645. cols := `sequence.nanotime`
  646. if includeData {
  647. cols = `sequence.nanotime, sequence.history_id, history.data`
  648. }
  649. row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
  650. SELECT %s FROM sequence
  651. INNER JOIN history ON history.id = sequence.history_id
  652. WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
  653. var nanotime int64
  654. if !includeData {
  655. err = row.Scan(&nanotime)
  656. } else {
  657. err = row.Scan(&nanotime, &id, &data)
  658. }
  659. if err != sql.ErrNoRows {
  660. mysql.logError("could not resolve msgid to time", err)
  661. }
  662. if err != nil {
  663. return
  664. }
  665. result = time.Unix(0, nanotime).UTC()
  666. return
  667. }
  668. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  669. rows, err := mysql.db.QueryContext(ctx, query, args...)
  670. if mysql.logError("could not select history items", err) {
  671. return
  672. }
  673. defer rows.Close()
  674. for rows.Next() {
  675. var blob []byte
  676. var item history.Item
  677. err = rows.Scan(&blob)
  678. if mysql.logError("could not scan history item", err) {
  679. return
  680. }
  681. err = unmarshalItem(blob, &item)
  682. if mysql.logError("could not unmarshal history item", err) {
  683. return
  684. }
  685. results = append(results, item)
  686. }
  687. return
  688. }
  689. func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  690. useSequence := correspondent == ""
  691. table := "sequence"
  692. if !useSequence {
  693. table = "conversations"
  694. }
  695. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  696. direction := "ASC"
  697. if !ascending {
  698. direction = "DESC"
  699. }
  700. var queryBuf bytes.Buffer
  701. args := make([]interface{}, 0, 6)
  702. fmt.Fprintf(&queryBuf,
  703. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  704. if useSequence {
  705. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  706. args = append(args, target)
  707. } else {
  708. fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
  709. args = append(args, target)
  710. args = append(args, correspondent)
  711. }
  712. if !after.IsZero() {
  713. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  714. args = append(args, after.UnixNano())
  715. }
  716. if !before.IsZero() {
  717. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  718. args = append(args, before.UnixNano())
  719. }
  720. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  721. args = append(args, limit)
  722. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  723. if err == nil && !ascending {
  724. history.Reverse(results)
  725. }
  726. return
  727. }
  728. func (mysql *MySQL) Close() {
  729. // closing the database will close our prepared statements as well
  730. if mysql.db != nil {
  731. mysql.db.Close()
  732. }
  733. mysql.db = nil
  734. }
  735. // implements history.Sequence, emulating a single history buffer (for a channel,
  736. // a single user's DMs, or a DM conversation)
  737. type mySQLHistorySequence struct {
  738. mysql *MySQL
  739. target string
  740. correspondent string
  741. cutoff time.Time
  742. }
  743. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
  744. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  745. defer cancel()
  746. startTime := start.Time
  747. if start.Msgid != "" {
  748. startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
  749. if err != nil {
  750. return nil, false, err
  751. }
  752. }
  753. endTime := end.Time
  754. if end.Msgid != "" {
  755. endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
  756. if err != nil {
  757. return nil, false, err
  758. }
  759. }
  760. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  761. return results, (err == nil), err
  762. }
  763. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  764. return history.GenericAround(s, start, limit)
  765. }
  766. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
  767. return &mySQLHistorySequence{
  768. target: target,
  769. correspondent: correspondent,
  770. mysql: mysql,
  771. cutoff: cutoff,
  772. }
  773. }