You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "context"
  6. "database/sql"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "runtime/debug"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/ergochat/ergo/irc/history"
  17. "github.com/ergochat/ergo/irc/logger"
  18. "github.com/ergochat/ergo/irc/utils"
  19. _ "github.com/go-sql-driver/mysql"
  20. )
  // ErrDisallowed is returned by DeleteMsgid when the requesting account is
  // not the author of the message it asked to delete.
  var ErrDisallowed = errors.New("disallowed")

  // MaxTargetLength is the maximum length in bytes of any message target
  // (nickname or channel name) in its canonicalized (i.e., casefolded) state.
  // It sizes the VARBINARY target/correspondent/account columns.
  const MaxTargetLength = 64

  // Schema versioning, persisted as key/value rows in the metadata table.
  const (
  	// latestDbSchema is the major schema version; fixSchemas rejects a
  	// database that reports any other major version.
  	latestDbSchema   = "2"
  	keySchemaVersion = "db.version"

  	// The minor version tracks rollback-safe upgrades, i.e.,
  	// you can downgrade ergo and everything will still work.
  	latestDbMinorVersion  = "2"
  	keySchemaMinorVersion = "db.minorversion"
  )

  // Tuning for the background cleanup and forget routines.
  const (
  	// cleanupRowLimit is the LIMIT applied to each cleanup/forget batch query.
  	cleanupRowLimit = 50
  	// cleanupPauseTime is the pause between cleanup sweeps, and also the
  	// timeout applied to the maintenance queries.
  	cleanupPauseTime = 10 * time.Minute
  	// if we don't fill the pagination window due to exclusions,
  	// retry with an expanded window at most this many times
  	maxPaginationRetries = 3
  )
  41. type e struct{}
  42. type MySQL struct {
  43. timeout int64
  44. trackAccountMessages uint32
  45. db *sql.DB
  46. logger *logger.Manager
  47. insertHistory *sql.Stmt
  48. insertSequence *sql.Stmt
  49. insertConversation *sql.Stmt
  50. insertCorrespondent *sql.Stmt
  51. insertAccountMessage *sql.Stmt
  52. stateMutex sync.Mutex
  53. config Config
  54. wakeForgetter chan e
  55. }
  56. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  57. mysql.logger = logger
  58. mysql.wakeForgetter = make(chan e, 1)
  59. mysql.SetConfig(config)
  60. }
  61. func (mysql *MySQL) SetConfig(config Config) {
  62. atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
  63. var trackAccountMessages uint32
  64. if config.TrackAccountMessages {
  65. trackAccountMessages = 1
  66. }
  67. atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
  68. mysql.stateMutex.Lock()
  69. mysql.config = config
  70. mysql.stateMutex.Unlock()
  71. }
  72. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  73. mysql.stateMutex.Lock()
  74. expireTime = mysql.config.ExpireTime
  75. mysql.stateMutex.Unlock()
  76. return
  77. }
  78. func (m *MySQL) Open() (err error) {
  79. var address string
  80. if m.config.SocketPath != "" {
  81. address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
  82. } else if m.config.Port != 0 {
  83. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  84. }
  85. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  86. if err != nil {
  87. return err
  88. }
  89. if m.config.MaxConns != 0 {
  90. m.db.SetMaxOpenConns(m.config.MaxConns)
  91. m.db.SetMaxIdleConns(m.config.MaxConns)
  92. }
  93. if m.config.ConnMaxLifetime != 0 {
  94. m.db.SetConnMaxLifetime(m.config.ConnMaxLifetime)
  95. }
  96. err = m.fixSchemas()
  97. if err != nil {
  98. return err
  99. }
  100. err = m.prepareStatements()
  101. if err != nil {
  102. return err
  103. }
  104. go m.cleanupLoop()
  105. go m.forgetLoop()
  106. return nil
  107. }
  108. func (mysql *MySQL) fixSchemas() (err error) {
  109. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  110. key_name VARCHAR(32) primary key,
  111. value VARCHAR(32) NOT NULL
  112. ) CHARSET=ascii COLLATE=ascii_bin;`)
  113. if err != nil {
  114. return err
  115. }
  116. var schema string
  117. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  118. if err == sql.ErrNoRows {
  119. err = mysql.createTables()
  120. if err != nil {
  121. return
  122. }
  123. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  124. if err != nil {
  125. return
  126. }
  127. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  128. if err != nil {
  129. return
  130. }
  131. return
  132. } else if err == nil && schema != latestDbSchema {
  133. // TODO figure out what to do about schema changes
  134. return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
  135. } else if err != nil {
  136. return err
  137. }
  138. var minorVersion string
  139. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
  140. if err == sql.ErrNoRows {
  141. // XXX for now, the only minor version upgrade is the account tracking tables
  142. err = mysql.createComplianceTables()
  143. if err != nil {
  144. return
  145. }
  146. err = mysql.createCorrespondentsTable()
  147. if err != nil {
  148. return
  149. }
  150. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  151. if err != nil {
  152. return
  153. }
  154. } else if err == nil && minorVersion == "1" {
  155. // upgrade from 2.1 to 2.2: create the correspondents table
  156. err = mysql.createCorrespondentsTable()
  157. if err != nil {
  158. return
  159. }
  160. _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
  161. if err != nil {
  162. return
  163. }
  164. } else if err == nil && minorVersion != latestDbMinorVersion {
  165. // TODO: if minorVersion < latestDbMinorVersion, upgrade,
  166. // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
  167. }
  168. return
  169. }
  170. func (mysql *MySQL) createTables() (err error) {
  171. _, err = mysql.db.Exec(`CREATE TABLE history (
  172. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  173. data BLOB NOT NULL,
  174. msgid BINARY(16) NOT NULL,
  175. KEY (msgid(4))
  176. ) CHARSET=ascii COLLATE=ascii_bin;`)
  177. if err != nil {
  178. return err
  179. }
  180. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  181. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  182. target VARBINARY(%[1]d) NOT NULL,
  183. nanotime BIGINT UNSIGNED NOT NULL,
  184. KEY (target, nanotime)
  185. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  186. if err != nil {
  187. return err
  188. }
  189. /* XXX: this table used to be:
  190. CREATE TABLE sequence (
  191. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  192. target VARBINARY(%[1]d) NOT NULL,
  193. nanotime BIGINT UNSIGNED NOT NULL,
  194. history_id BIGINT NOT NULL,
  195. KEY (target, nanotime),
  196. KEY (history_id)
  197. ) CHARSET=ascii COLLATE=ascii_bin;
  198. Some users may still be using the old schema.
  199. */
  200. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  201. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  202. target VARBINARY(%[1]d) NOT NULL,
  203. correspondent VARBINARY(%[1]d) NOT NULL,
  204. nanotime BIGINT UNSIGNED NOT NULL,
  205. history_id BIGINT NOT NULL,
  206. KEY (target, correspondent, nanotime),
  207. KEY (history_id)
  208. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  209. if err != nil {
  210. return err
  211. }
  212. err = mysql.createCorrespondentsTable()
  213. if err != nil {
  214. return err
  215. }
  216. err = mysql.createComplianceTables()
  217. if err != nil {
  218. return err
  219. }
  220. return nil
  221. }
  222. func (mysql *MySQL) createCorrespondentsTable() (err error) {
  223. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
  224. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  225. target VARBINARY(%[1]d) NOT NULL,
  226. correspondent VARBINARY(%[1]d) NOT NULL,
  227. nanotime BIGINT UNSIGNED NOT NULL,
  228. UNIQUE KEY (target, correspondent),
  229. KEY (target, nanotime),
  230. KEY (nanotime)
  231. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  232. return
  233. }
  234. func (mysql *MySQL) createComplianceTables() (err error) {
  235. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
  236. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  237. account VARBINARY(%[1]d) NOT NULL,
  238. KEY (account, history_id)
  239. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  240. if err != nil {
  241. return err
  242. }
  243. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
  244. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  245. account VARBINARY(%[1]d) NOT NULL
  246. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  247. if err != nil {
  248. return err
  249. }
  250. return nil
  251. }
  252. func (mysql *MySQL) cleanupLoop() {
  253. defer func() {
  254. if r := recover(); r != nil {
  255. mysql.logger.Error("mysql",
  256. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  257. time.Sleep(cleanupPauseTime)
  258. go mysql.cleanupLoop()
  259. }
  260. }()
  261. for {
  262. expireTime := mysql.getExpireTime()
  263. if expireTime != 0 {
  264. for {
  265. startTime := time.Now()
  266. rowsDeleted, err := mysql.doCleanup(expireTime)
  267. elapsed := time.Now().Sub(startTime)
  268. mysql.logError("error during row cleanup", err)
  269. // keep going as long as we're accomplishing significant work
  270. // (don't busy-wait on small numbers of rows expiring):
  271. if rowsDeleted < (cleanupRowLimit / 10) {
  272. break
  273. }
  274. // crude backpressure mechanism: if the database is slow,
  275. // give it time to process other queries
  276. time.Sleep(elapsed)
  277. }
  278. }
  279. time.Sleep(cleanupPauseTime)
  280. }
  281. }
  282. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  283. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  284. defer cancel()
  285. ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
  286. if len(ids) == 0 {
  287. mysql.logger.Debug("mysql", "found no rows to clean up")
  288. return
  289. }
  290. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  291. if maxNanotime != 0 {
  292. mysql.deleteCorrespondents(ctx, maxNanotime)
  293. }
  294. return len(ids), mysql.deleteHistoryIDs(ctx, ids)
  295. }
  296. func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
  297. // can't use ? binding for a variable number of arguments, build the IN clause manually
  298. var inBuf strings.Builder
  299. inBuf.WriteByte('(')
  300. for i, id := range ids {
  301. if i != 0 {
  302. inBuf.WriteRune(',')
  303. }
  304. fmt.Fprintf(&inBuf, "%d", id)
  305. }
  306. inBuf.WriteRune(')')
  307. inClause := inBuf.String()
  308. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
  309. if err != nil {
  310. return
  311. }
  312. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
  313. if err != nil {
  314. return
  315. }
  316. if mysql.isTrackingAccountMessages() {
  317. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
  318. if err != nil {
  319. return
  320. }
  321. }
  322. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
  323. if err != nil {
  324. return
  325. }
  326. return
  327. }
  328. func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  329. rows, err := mysql.db.QueryContext(ctx, `
  330. SELECT history.id, sequence.nanotime, conversations.nanotime
  331. FROM history
  332. LEFT JOIN sequence ON history.id = sequence.history_id
  333. LEFT JOIN conversations on history.id = conversations.history_id
  334. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  335. if err != nil {
  336. return
  337. }
  338. defer rows.Close()
  339. idset := make(map[uint64]struct{}, cleanupRowLimit)
  340. threshold := time.Now().Add(-age).UnixNano()
  341. for rows.Next() {
  342. var id uint64
  343. var seqNano, convNano sql.NullInt64
  344. err = rows.Scan(&id, &seqNano, &convNano)
  345. if err != nil {
  346. return
  347. }
  348. nanotime := extractNanotime(seqNano, convNano)
  349. // returns 0 if not found; in that case the data is inconsistent
  350. // and we should delete the entry
  351. if nanotime < threshold {
  352. idset[id] = struct{}{}
  353. if nanotime > maxNanotime {
  354. maxNanotime = nanotime
  355. }
  356. }
  357. }
  358. ids = make([]uint64, len(idset))
  359. i := 0
  360. for id := range idset {
  361. ids[i] = id
  362. i++
  363. }
  364. return
  365. }
  366. func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
  367. result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
  368. if err != nil {
  369. mysql.logError("error deleting correspondents", err)
  370. } else {
  371. count, err := result.RowsAffected()
  372. if !mysql.logError("error deleting correspondents", err) {
  373. mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
  374. }
  375. }
  376. }
  377. // wait for forget queue items and process them one by one
  378. func (mysql *MySQL) forgetLoop() {
  379. defer func() {
  380. if r := recover(); r != nil {
  381. mysql.logger.Error("mysql",
  382. fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
  383. time.Sleep(cleanupPauseTime)
  384. go mysql.forgetLoop()
  385. }
  386. }()
  387. for {
  388. for {
  389. found, err := mysql.doForget()
  390. mysql.logError("error processing forget", err)
  391. if err != nil {
  392. time.Sleep(cleanupPauseTime)
  393. }
  394. if !found {
  395. break
  396. }
  397. }
  398. <-mysql.wakeForgetter
  399. }
  400. }
  401. // dequeue an item from the forget queue and process it
  402. func (mysql *MySQL) doForget() (found bool, err error) {
  403. id, account, err := func() (id int64, account string, err error) {
  404. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  405. defer cancel()
  406. row := mysql.db.QueryRowContext(ctx,
  407. `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
  408. err = row.Scan(&id, &account)
  409. if err == sql.ErrNoRows {
  410. return 0, "", nil
  411. }
  412. return
  413. }()
  414. if err != nil || account == "" {
  415. return false, err
  416. }
  417. found = true
  418. var count int
  419. for {
  420. start := time.Now()
  421. count, err = mysql.doForgetIteration(account)
  422. elapsed := time.Since(start)
  423. if err != nil {
  424. return true, err
  425. }
  426. if count == 0 {
  427. break
  428. }
  429. time.Sleep(elapsed)
  430. }
  431. mysql.logger.Debug("mysql", "forget complete for account", account)
  432. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  433. defer cancel()
  434. _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
  435. return
  436. }
  437. func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
  438. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  439. defer cancel()
  440. rows, err := mysql.db.QueryContext(ctx, `
  441. SELECT account_messages.history_id
  442. FROM account_messages
  443. WHERE account_messages.account = ?
  444. LIMIT ?;`, account, cleanupRowLimit)
  445. if err != nil {
  446. return
  447. }
  448. defer rows.Close()
  449. var ids []uint64
  450. for rows.Next() {
  451. var id uint64
  452. err = rows.Scan(&id)
  453. if err != nil {
  454. return
  455. }
  456. ids = append(ids, id)
  457. }
  458. if len(ids) == 0 {
  459. return
  460. }
  461. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
  462. err = mysql.deleteHistoryIDs(ctx, ids)
  463. return len(ids), err
  464. }
  465. func (mysql *MySQL) prepareStatements() (err error) {
  466. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  467. (data, msgid) VALUES (?, ?);`)
  468. if err != nil {
  469. return
  470. }
  471. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  472. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  473. if err != nil {
  474. return
  475. }
  476. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  477. (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  478. if err != nil {
  479. return
  480. }
  481. mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
  482. (target, correspondent, nanotime) VALUES (?, ?, ?)
  483. ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
  484. if err != nil {
  485. return
  486. }
  487. mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
  488. (history_id, account) VALUES (?, ?);`)
  489. if err != nil {
  490. return
  491. }
  492. return
  493. }
  494. func (mysql *MySQL) getTimeout() time.Duration {
  495. return time.Duration(atomic.LoadInt64(&mysql.timeout))
  496. }
  497. func (mysql *MySQL) isTrackingAccountMessages() bool {
  498. return atomic.LoadUint32(&mysql.trackAccountMessages) != 0
  499. }
  500. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  501. if err != nil {
  502. mysql.logger.Error("mysql", context, err.Error())
  503. return true
  504. }
  505. return false
  506. }
  507. func (mysql *MySQL) Forget(account string) {
  508. if mysql.db == nil || account == "" {
  509. return
  510. }
  511. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  512. defer cancel()
  513. _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
  514. if mysql.logError("can't insert into forget table", err) {
  515. return
  516. }
  517. // wake up the forget goroutine if it's blocked:
  518. select {
  519. case mysql.wakeForgetter <- e{}:
  520. default:
  521. }
  522. }
  523. func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
  524. if mysql.db == nil {
  525. return
  526. }
  527. if target == "" {
  528. return utils.ErrInvalidParams
  529. }
  530. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  531. defer cancel()
  532. id, err := mysql.insertBase(ctx, item)
  533. if err != nil {
  534. return
  535. }
  536. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
  537. if err != nil {
  538. return
  539. }
  540. err = mysql.insertAccountMessageEntry(ctx, id, account)
  541. if err != nil {
  542. return
  543. }
  544. return
  545. }
  546. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
  547. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
  548. mysql.logError("could not insert sequence entry", err)
  549. return
  550. }
  551. func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
  552. _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
  553. mysql.logError("could not insert conversations entry", err)
  554. return
  555. }
  556. func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
  557. _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
  558. mysql.logError("could not insert conversations entry", err)
  559. return
  560. }
  561. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  562. value, err := marshalItem(&item)
  563. if mysql.logError("could not marshal item", err) {
  564. return
  565. }
  566. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  567. if mysql.logError("could not decode msgid", err) {
  568. return
  569. }
  570. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  571. if mysql.logError("could not insert item", err) {
  572. return
  573. }
  574. id, err = result.LastInsertId()
  575. if mysql.logError("could not insert item", err) {
  576. return
  577. }
  578. return
  579. }
  580. func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
  581. if account == "" || !mysql.isTrackingAccountMessages() {
  582. return
  583. }
  584. _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
  585. mysql.logError("could not insert account-message entry", err)
  586. return
  587. }
  588. func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
  589. if mysql.db == nil {
  590. return
  591. }
  592. if senderAccount == "" && recipientAccount == "" {
  593. return
  594. }
  595. if sender == "" || recipient == "" {
  596. return utils.ErrInvalidParams
  597. }
  598. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  599. defer cancel()
  600. id, err := mysql.insertBase(ctx, item)
  601. if err != nil {
  602. return
  603. }
  604. nanotime := item.Message.Time.UnixNano()
  605. if senderAccount != "" {
  606. err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
  607. if err != nil {
  608. return
  609. }
  610. err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
  611. if err != nil {
  612. return
  613. }
  614. }
  615. if recipientAccount != "" && sender != recipient {
  616. err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
  617. if err != nil {
  618. return
  619. }
  620. err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
  621. if err != nil {
  622. return
  623. }
  624. }
  625. err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
  626. if err != nil {
  627. return
  628. }
  629. return
  630. }
  631. // note that accountName is the unfolded name
  632. func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
  633. if mysql.db == nil {
  634. return nil
  635. }
  636. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  637. defer cancel()
  638. _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
  639. if err != nil {
  640. return
  641. }
  642. if accountName != "*" {
  643. var item history.Item
  644. err = unmarshalItem(data, &item)
  645. // delete if the entry is corrupt
  646. if err == nil && item.AccountName != accountName {
  647. return ErrDisallowed
  648. }
  649. }
  650. err = mysql.deleteHistoryIDs(ctx, []uint64{id})
  651. mysql.logError("couldn't delete msgid", err)
  652. return
  653. }
  654. func (mysql *MySQL) Export(account string, writer io.Writer) {
  655. if mysql.db == nil {
  656. return
  657. }
  658. var err error
  659. var lastSeen uint64
  660. for {
  661. rows := func() (count int) {
  662. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  663. defer cancel()
  664. rows, rowsErr := mysql.db.QueryContext(ctx, `
  665. SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
  666. INNER JOIN history ON history.id = account_messages.history_id
  667. INNER JOIN sequence ON account_messages.history_id = sequence.history_id
  668. WHERE account_messages.account = ? AND account_messages.history_id > ?
  669. LIMIT ?`, account, lastSeen, cleanupRowLimit)
  670. if rowsErr != nil {
  671. err = rowsErr
  672. return
  673. }
  674. defer rows.Close()
  675. for rows.Next() {
  676. var id uint64
  677. var blob, jsonBlob []byte
  678. var target string
  679. var item history.Item
  680. err = rows.Scan(&id, &blob, &target)
  681. if err != nil {
  682. return
  683. }
  684. err = unmarshalItem(blob, &item)
  685. if err != nil {
  686. return
  687. }
  688. item.CfCorrespondent = target
  689. jsonBlob, err = json.Marshal(item)
  690. if err != nil {
  691. return
  692. }
  693. count++
  694. if lastSeen < id {
  695. lastSeen = id
  696. }
  697. writer.Write(jsonBlob)
  698. writer.Write([]byte{'\n'})
  699. }
  700. return
  701. }()
  702. if rows == 0 || err != nil {
  703. break
  704. }
  705. }
  706. mysql.logError("could not export history", err)
  707. return
  708. }
  709. func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
  710. decoded, err := decodeMsgid(msgid)
  711. if err != nil {
  712. return
  713. }
  714. cols := `sequence.nanotime, conversations.nanotime`
  715. if includeData {
  716. cols = `sequence.nanotime, conversations.nanotime, history.id, history.data`
  717. }
  718. row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
  719. SELECT %s FROM history
  720. LEFT JOIN sequence ON history.id = sequence.history_id
  721. LEFT JOIN conversations ON history.id = conversations.history_id
  722. WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
  723. var nanoSeq, nanoConv sql.NullInt64
  724. if !includeData {
  725. err = row.Scan(&nanoSeq, &nanoConv)
  726. } else {
  727. err = row.Scan(&nanoSeq, &nanoConv, &id, &data)
  728. }
  729. if err != sql.ErrNoRows {
  730. mysql.logError("could not resolve msgid to time", err)
  731. }
  732. if err != nil {
  733. return
  734. }
  735. nanotime := extractNanotime(nanoSeq, nanoConv)
  736. if nanotime == 0 {
  737. err = sql.ErrNoRows
  738. return
  739. }
  740. result = time.Unix(0, nanotime).UTC()
  741. return
  742. }
  // extractNanotime picks a message timestamp from the two LEFT JOIN columns.
  // It prefers the sequence timestamp, falls back to the conversations
  // timestamp, and returns 0 when neither is present. Callers treat 0 as
  // "not found / inconsistent".
  func extractNanotime(seq, conv sql.NullInt64) (result int64) {
  	switch {
  	case seq.Valid:
  		result = seq.Int64
  	case conv.Valid:
  		result = conv.Int64
  	}
  	return result
  }
  751. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  752. rows, err := mysql.db.QueryContext(ctx, query, args...)
  753. if mysql.logError("could not select history items", err) {
  754. return
  755. }
  756. defer rows.Close()
  757. for rows.Next() {
  758. var blob []byte
  759. var item history.Item
  760. err = rows.Scan(&blob)
  761. if mysql.logError("could not scan history item", err) {
  762. return
  763. }
  764. err = unmarshalItem(blob, &item)
  765. if mysql.logError("could not unmarshal history item", err) {
  766. return
  767. }
  768. results = append(results, item)
  769. }
  770. return
  771. }
  772. func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  773. useSequence := correspondent == ""
  774. table := "sequence"
  775. if !useSequence {
  776. table = "conversations"
  777. }
  778. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  779. direction := "ASC"
  780. if !ascending {
  781. direction = "DESC"
  782. }
  783. var queryBuf strings.Builder
  784. args := make([]interface{}, 0, 6)
  785. fmt.Fprintf(&queryBuf,
  786. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  787. if useSequence {
  788. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  789. args = append(args, target)
  790. } else {
  791. fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
  792. args = append(args, target)
  793. args = append(args, correspondent)
  794. }
  795. if !after.IsZero() {
  796. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  797. args = append(args, after.UnixNano())
  798. }
  799. if !before.IsZero() {
  800. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  801. args = append(args, before.UnixNano())
  802. }
  803. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  804. args = append(args, limit)
  805. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  806. if err == nil && !ascending {
  807. history.Reverse(results)
  808. }
  809. return
  810. }
  811. func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
  812. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  813. direction := "ASC"
  814. if !ascending {
  815. direction = "DESC"
  816. }
  817. var queryBuf strings.Builder
  818. args := make([]interface{}, 0, 4)
  819. queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
  820. WHERE target = ?`)
  821. args = append(args, target)
  822. if !after.IsZero() {
  823. queryBuf.WriteString(" AND correspondents.nanotime > ?")
  824. args = append(args, after.UnixNano())
  825. }
  826. if !before.IsZero() {
  827. queryBuf.WriteString(" AND correspondents.nanotime < ?")
  828. args = append(args, before.UnixNano())
  829. }
  830. fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
  831. args = append(args, limit)
  832. query := queryBuf.String()
  833. rows, err := mysql.db.QueryContext(ctx, query, args...)
  834. if err != nil {
  835. return
  836. }
  837. defer rows.Close()
  838. var correspondent string
  839. var nanotime int64
  840. for rows.Next() {
  841. err = rows.Scan(&correspondent, &nanotime)
  842. if err != nil {
  843. return
  844. }
  845. results = append(results, history.TargetListing{
  846. CfName: correspondent,
  847. Time: time.Unix(0, nanotime),
  848. })
  849. }
  850. if !ascending {
  851. history.ReverseCorrespondents(results)
  852. }
  853. return
  854. }
  855. func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
  856. if mysql.db == nil {
  857. return
  858. }
  859. if len(cfchannels) == 0 {
  860. return
  861. }
  862. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  863. defer cancel()
  864. var queryBuf strings.Builder
  865. args := make([]interface{}, 0, len(results))
  866. // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
  867. // this should be a "loose index scan"
  868. queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
  869. WHERE sequence.target IN (`)
  870. for i, chname := range cfchannels {
  871. if i != 0 {
  872. queryBuf.WriteString(", ")
  873. }
  874. queryBuf.WriteByte('?')
  875. args = append(args, chname)
  876. }
  877. queryBuf.WriteString(") GROUP BY sequence.target;")
  878. rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
  879. if mysql.logError("could not query channel listings", err) {
  880. return
  881. }
  882. defer rows.Close()
  883. var target string
  884. var nanotime int64
  885. for rows.Next() {
  886. err = rows.Scan(&target, &nanotime)
  887. if mysql.logError("could not scan channel listings", err) {
  888. return
  889. }
  890. results = append(results, history.TargetListing{
  891. CfName: target,
  892. Time: time.Unix(0, nanotime),
  893. })
  894. }
  895. return
  896. }
  897. func (mysql *MySQL) Close() {
  898. // closing the database will close our prepared statements as well
  899. if mysql.db != nil {
  900. mysql.db.Close()
  901. }
  902. mysql.db = nil
  903. }
  904. // implements history.Sequence, emulating a single history buffer (for a channel,
  905. // a single user's DMs, or a DM conversation)
  906. type mySQLHistorySequence struct {
  907. mysql *MySQL
  908. target string
  909. correspondent string
  910. cutoff time.Time
  911. excludeFlags history.ExcludeFlags
  912. }
  913. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
  914. if s.excludeFlags == 0 {
  915. return s.baseBetween(start, end, limit)
  916. } else {
  917. return s.betweenWithRetries(start, end, limit)
  918. }
  919. }
  920. func (s *mySQLHistorySequence) baseBetween(start, end history.Selector, limit int) (results []history.Item, err error) {
  921. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  922. defer cancel()
  923. startTime := start.Time
  924. if start.Msgid != "" {
  925. startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
  926. if err != nil {
  927. return nil, err
  928. }
  929. }
  930. endTime := end.Time
  931. if end.Msgid != "" {
  932. endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
  933. if err != nil {
  934. return nil, err
  935. }
  936. }
  937. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  938. return results, err
  939. }
  940. func (s *mySQLHistorySequence) betweenWithRetries(start, end history.Selector, limit int) (results []history.Item, err error) {
  941. applyExclusions := func(currentResults []history.Item, excludeFlags history.ExcludeFlags, trueLimit int) (filteredResults []history.Item) {
  942. filteredResults = make([]history.Item, 0, len(currentResults))
  943. for _, item := range currentResults {
  944. if !item.IsExcluded(excludeFlags) {
  945. filteredResults = append(filteredResults, item)
  946. }
  947. if len(filteredResults) == trueLimit {
  948. break
  949. }
  950. }
  951. return
  952. }
  953. i := 1
  954. for {
  955. currentLimit := limit * i
  956. currentResults, err := s.baseBetween(start, end, currentLimit)
  957. if err != nil {
  958. return nil, err
  959. }
  960. results = applyExclusions(currentResults, s.excludeFlags, limit)
  961. // we're done in any of these three cases:
  962. // (1) we filled the window (2) we ran out of results on the backend (3) we can't retry anymore
  963. if len(results) == limit || len(currentResults) < currentLimit || i == maxPaginationRetries {
  964. return results, nil
  965. }
  966. i++
  967. }
  968. }
  969. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  970. // temporarily clear the exclude flags when running GenericAround, since we don't care about
  971. // the exactness of the paging window at all
  972. oldExcludeFlags := s.excludeFlags
  973. s.excludeFlags = 0
  974. defer func() {
  975. s.excludeFlags = oldExcludeFlags
  976. }()
  977. return history.GenericAround(s, start, limit)
  978. }
  979. func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
  980. ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
  981. defer cancel()
  982. // TODO accept msgids here?
  983. startTime := start.Time
  984. endTime := end.Time
  985. results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
  986. seq.mysql.logError("could not read correspondents", err)
  987. return
  988. }
  989. func (seq *mySQLHistorySequence) Cutoff() time.Time {
  990. return seq.cutoff
  991. }
  992. func (seq *mySQLHistorySequence) Ephemeral() bool {
  993. return false
  994. }
  995. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time, excludeFlags history.ExcludeFlags) history.Sequence {
  996. return &mySQLHistorySequence{
  997. target: target,
  998. correspondent: correspondent,
  999. mysql: mysql,
  1000. cutoff: cutoff,
  1001. excludeFlags: excludeFlags,
  1002. }
  1003. }