You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

history.go 28KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "context"
  6. "database/sql"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "runtime/debug"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/ergochat/ergo/irc/history"
  17. "github.com/ergochat/ergo/irc/logger"
  18. "github.com/ergochat/ergo/irc/utils"
  19. _ "github.com/go-sql-driver/mysql"
  20. )
  21. var (
  22. ErrDisallowed = errors.New("disallowed")
  23. )
  24. const (
  25. // maximum length in bytes of any message target (nickname or channel name) in its
  26. // canonicalized (i.e., casefolded) state:
  27. MaxTargetLength = 64
  28. // latest schema of the db
  29. latestDbSchema = "2"
  30. keySchemaVersion = "db.version"
  31. // minor version indicates rollback-safe upgrades, i.e.,
  32. // you can downgrade oragono and everything will work
  33. latestDbMinorVersion = "2"
  34. keySchemaMinorVersion = "db.minorversion"
  35. cleanupRowLimit = 50
  36. cleanupPauseTime = 10 * time.Minute
  37. )
  // e is an empty (zero-size) struct type used only as the element type of
// the wakeForgetter signal channel; its values carry no data.
type e struct{}
  // MySQL stores message history (channel messages, direct-message
// conversations, correspondents, and optional per-account tracking) in a
// MySQL database.
type MySQL struct {
	// timeout is the per-operation timeout in nanoseconds, read and written
	// with 64-bit sync/atomic operations (SetConfig, getTimeout).
	// NOTE(review): it is presumably kept as the first field because 64-bit
	// atomics require 64-bit alignment on 32-bit platforms, which Go only
	// guarantees for the first word of an allocated struct; don't reorder
	// without checking.
	timeout int64
	// trackAccountMessages is 1 when account message tracking is enabled and
	// 0 otherwise; accessed atomically (SetConfig, isTrackingAccountMessages).
	trackAccountMessages uint32
	// db is nil until Open assigns it; the public methods return early when
	// it is nil.
	db     *sql.DB
	logger *logger.Manager
	// prepared insert statements, created by prepareStatements
	insertHistory        *sql.Stmt
	insertSequence       *sql.Stmt
	insertConversation   *sql.Stmt
	insertCorrespondent  *sql.Stmt
	insertAccountMessage *sql.Stmt
	// stateMutex guards config.
	stateMutex sync.Mutex
	config     Config
	// wakeForgetter is created with capacity 1 in Initialize; Forget does a
	// non-blocking send on it to wake forgetLoop.
	wakeForgetter chan e
}
  53. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  54. mysql.logger = logger
  55. mysql.wakeForgetter = make(chan e, 1)
  56. mysql.SetConfig(config)
  57. }
  58. func (mysql *MySQL) SetConfig(config Config) {
  59. atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
  60. var trackAccountMessages uint32
  61. if config.TrackAccountMessages {
  62. trackAccountMessages = 1
  63. }
  64. atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
  65. mysql.stateMutex.Lock()
  66. mysql.config = config
  67. mysql.stateMutex.Unlock()
  68. }
  // getExpireTime returns the configured history retention period, read under
// stateMutex. A zero value makes cleanupLoop skip deletion entirely.
func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
	mysql.stateMutex.Lock()
	defer mysql.stateMutex.Unlock()
	return mysql.config.ExpireTime
}
  // Open connects to the configured database, creates or upgrades the schema,
// prepares the insert statements, and starts the background cleanup and
// forget goroutines.
//
// The address is unix(SocketPath) when SocketPath is set, otherwise
// tcp(Host:Port) when Port is set; if neither is set, an empty address is
// passed to the driver.
//
// If schema setup or statement preparation fails, the connection pool is
// closed and db is reset to nil. The public methods then return early
// (they all check for a nil db) instead of executing statements that were
// never prepared.
func (mysql *MySQL) Open() (err error) {
	// snapshot the config under the same lock SetConfig writes it with
	mysql.stateMutex.Lock()
	config := mysql.config
	mysql.stateMutex.Unlock()

	var address string
	if config.SocketPath != "" {
		address = fmt.Sprintf("unix(%s)", config.SocketPath)
	} else if config.Port != 0 {
		address = fmt.Sprintf("tcp(%s:%d)", config.Host, config.Port)
	}

	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", config.User, config.Password, address, config.HistoryDatabase))
	if err != nil {
		return err
	}
	if config.MaxConns != 0 {
		db.SetMaxOpenConns(config.MaxConns)
		db.SetMaxIdleConns(config.MaxConns)
	}
	if config.ConnMaxLifetime != 0 {
		db.SetConnMaxLifetime(config.ConnMaxLifetime)
	}
	mysql.db = db

	err = mysql.fixSchemas()
	if err == nil {
		err = mysql.prepareStatements()
	}
	if err != nil {
		// release the pool; the setup error is the one worth reporting
		db.Close()
		mysql.db = nil
		return err
	}

	go mysql.cleanupLoop()
	go mysql.forgetLoop()
	return nil
}
  // fixSchemas makes sure the database schema exists and is one this code
// understands.
//
// The metadata table stores the major schema version (keySchemaVersion) and
// the minor version (keySchemaMinorVersion):
//   - no major version: all tables are created and both versions recorded;
//   - a different major version: an "incompatible schema" error is returned;
//   - no minor version: the account-tracking tables and the correspondents
//     table are created and the minor version recorded;
//   - minor version "1": only the correspondents table is added;
//   - any other minor version is left alone.
func (mysql *MySQL) fixSchemas() (err error) {
	if _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
key_name VARCHAR(32) primary key,
value VARCHAR(32) NOT NULL
) CHARSET=ascii COLLATE=ascii_bin;`); err != nil {
		return err
	}

	var schema string
	err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
	switch {
	case err == sql.ErrNoRows:
		// fresh database: build everything at the latest versions
		if err = mysql.createTables(); err != nil {
			return err
		}
		if _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema); err != nil {
			return err
		}
		_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
		return err
	case err != nil:
		return err
	case schema != latestDbSchema:
		// TODO figure out what to do about schema changes
		return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
	}

	var minorVersion string
	err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
	switch {
	case err == sql.ErrNoRows:
		// XXX for now, the only minor version upgrade is the account tracking tables
		if err = mysql.createComplianceTables(); err != nil {
			return err
		}
		if err = mysql.createCorrespondentsTable(); err != nil {
			return err
		}
		_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
		return err
	case err != nil:
		return err
	case minorVersion == "1":
		// upgrade from 2.1 to 2.2: create the correspondents table
		if err = mysql.createCorrespondentsTable(); err != nil {
			return err
		}
		_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
		return err
	}

	// Any other minor version: an older one would need an upgrade path
	// (TODO); a newer one is ignored because minor versions are backwards
	// compatible.
	return nil
}
  // createTables creates the complete latest schema on an empty database.
//
// It creates three tables:
//   - history: the stored item blob and a binary msgid, indexed on the
//     first 4 bytes of msgid;
//   - sequence: (target, nanotime) entries keyed by history id;
//   - conversations: (target, correspondent, nanotime) entries that
//     reference a history id.
//
// It then adds the correspondents and compliance tables through their
// helpers. It stops at the first failing statement.
func (mysql *MySQL) createTables() (err error) {
	historyDDL := `CREATE TABLE history (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
data BLOB NOT NULL,
msgid BINARY(16) NOT NULL,
KEY (msgid(4))
) CHARSET=ascii COLLATE=ascii_bin;`

	sequenceDDL := fmt.Sprintf(`CREATE TABLE sequence (
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
target VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
KEY (target, nanotime)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)
	// XXX: the sequence table used to be:
	//   CREATE TABLE sequence (
	//   id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
	//   target VARBINARY(%[1]d) NOT NULL,
	//   nanotime BIGINT UNSIGNED NOT NULL,
	//   history_id BIGINT NOT NULL,
	//   KEY (target, nanotime),
	//   KEY (history_id)
	//   ) CHARSET=ascii COLLATE=ascii_bin;
	// Some users may still be using the old schema.

	conversationsDDL := fmt.Sprintf(`CREATE TABLE conversations (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
target VARBINARY(%[1]d) NOT NULL,
correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
history_id BIGINT NOT NULL,
KEY (target, correspondent, nanotime),
KEY (history_id)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)

	for _, ddl := range []string{historyDDL, sequenceDDL, conversationsDDL} {
		if _, err = mysql.db.Exec(ddl); err != nil {
			return err
		}
	}

	if err = mysql.createCorrespondentsTable(); err != nil {
		return err
	}
	return mysql.createComplianceTables()
}
  // createCorrespondentsTable creates the correspondents table. It holds at
// most one row per (target, correspondent) pair (UNIQUE KEY), with the
// pair's message timestamp in nanotime.
func (mysql *MySQL) createCorrespondentsTable() (err error) {
	ddl := fmt.Sprintf(`CREATE TABLE correspondents (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
target VARBINARY(%[1]d) NOT NULL,
correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
UNIQUE KEY (target, correspondent),
KEY (target, nanotime),
KEY (nanotime)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)
	if _, err = mysql.db.Exec(ddl); err != nil {
		return err
	}
	return nil
}
  // createComplianceTables creates the two tables used for per-account
// tracking:
//   - account_messages maps a history id to an account name;
//   - forget is a queue of account names whose messages should be deleted.
//
// It stops at the first failing statement.
func (mysql *MySQL) createComplianceTables() (err error) {
	ddls := [...]string{
		fmt.Sprintf(`CREATE TABLE account_messages (
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
account VARBINARY(%[1]d) NOT NULL,
KEY (account, history_id)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength),
		fmt.Sprintf(`CREATE TABLE forget (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
account VARBINARY(%[1]d) NOT NULL
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength),
	}
	for _, ddl := range ddls {
		if _, err = mysql.db.Exec(ddl); err != nil {
			return err
		}
	}
	return nil
}
  // cleanupLoop runs forever, deleting history older than the configured
// expire time (a zero expire time disables deletion).
//
// Each round calls doCleanup in a loop until one of two things happens:
//   - a batch deletes fewer than cleanupRowLimit/10 rows;
//   - a batch fails.
//
// Between batches it sleeps for as long as the previous batch took, and it
// sleeps cleanupPauseTime between rounds. Breaking on failure keeps a broken
// database from being retried in a tight loop: a failed delete still reports
// its batch size, which would otherwise look like "significant work".
// A panic is logged and the loop is restarted after cleanupPauseTime.
func (mysql *MySQL) cleanupLoop() {
	defer func() {
		if r := recover(); r != nil {
			mysql.logger.Error("mysql",
				fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
			time.Sleep(cleanupPauseTime)
			go mysql.cleanupLoop()
		}
	}()

	for {
		if expireTime := mysql.getExpireTime(); expireTime != 0 {
			for {
				startTime := time.Now()
				rowsDeleted, err := mysql.doCleanup(expireTime)
				elapsed := time.Since(startTime)
				if mysql.logError("error during row cleanup", err) {
					// wait for the next round rather than hammering a
					// failing database
					break
				}
				// keep going as long as we're accomplishing significant work
				// (don't busy-wait on small numbers of rows expiring):
				if rowsDeleted < (cleanupRowLimit / 10) {
					break
				}
				// crude backpressure mechanism: if the database is slow,
				// give it time to process other queries
				time.Sleep(elapsed)
			}
		}
		time.Sleep(cleanupPauseTime)
	}
}
  // doCleanup runs one cleanup batch, bounded by a cleanupPauseTime timeout.
//
// Steps:
//  1. select up to cleanupRowLimit lowest-id history rows whose message time
//     is older than age;
//  2. prune correspondents entries up to the newest selected message time;
//  3. delete the selected history rows.
//
// It returns the number of history ids it tried to delete. A failed
// selection is returned as an error rather than logged as "found no rows to
// clean up".
func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
	defer cancel()

	ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
	if err != nil {
		return 0, err
	}
	if len(ids) == 0 {
		mysql.logger.Debug("mysql", "found no rows to clean up")
		return 0, nil
	}

	mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
	// a maxNanotime of 0 means none of the selected rows had a timestamp,
	// so there is no meaningful correspondents cutoff
	if maxNanotime != 0 {
		mysql.deleteCorrespondents(ctx, maxNanotime)
	}
	return len(ids), mysql.deleteHistoryIDs(ctx, ids)
}
  // deleteHistoryIDs removes the given history rows together with the rows
// that reference them in conversations, sequence and account_messages.
// It stops at the first failing statement.
//
// An empty ids slice is a no-op; it would otherwise produce the invalid SQL
// fragment "IN ()".
//
// account_messages is cleaned regardless of the current tracking setting.
// The table always exists (createTables and the minor-version upgrade in
// fixSchemas both create it). If rows were left behind after tracking was
// switched off, doForgetIteration would keep re-selecting the same ids and
// the forget loop would never finish.
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
	if len(ids) == 0 {
		return nil
	}

	// can't use ? binding for a variable number of arguments, so build the
	// IN clause manually; the values are formatted integers, not user text
	var inBuf strings.Builder
	inBuf.WriteByte('(')
	for i, id := range ids {
		if i != 0 {
			inBuf.WriteByte(',')
		}
		fmt.Fprintf(&inBuf, "%d", id)
	}
	inBuf.WriteByte(')')
	inClause := inBuf.String()

	// delete referencing rows first, the history rows themselves last
	queries := []string{
		fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause),
		fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause),
		fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause),
		fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause),
	}
	for _, query := range queries {
		if _, err = mysql.db.ExecContext(ctx, query); err != nil {
			return err
		}
	}
	return nil
}
  // selectCleanupIDs scans the cleanupRowLimit lowest-id history rows, joined
// with their sequence/conversations entries. It returns the distinct ids
// whose message time is older than now-age, plus the largest such message
// time.
//
// A row with neither a sequence nor a conversations entry has a time of 0
// and is therefore always selected: the data is inconsistent and should be
// deleted. On error, ids and maxNanotime are zero values.
func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
	rows, err := mysql.db.QueryContext(ctx, `
SELECT history.id, sequence.nanotime, conversations.nanotime
FROM history
LEFT JOIN sequence ON history.id = sequence.history_id
LEFT JOIN conversations on history.id = conversations.history_id
ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	// a direct message can have two conversations rows (sender and recipient
	// side) for the same history id, so deduplicate through a set
	idset := make(map[uint64]struct{}, cleanupRowLimit)
	threshold := time.Now().Add(-age).UnixNano()
	for rows.Next() {
		var id uint64
		var seqNano, convNano sql.NullInt64
		if err = rows.Scan(&id, &seqNano, &convNano); err != nil {
			return nil, 0, err
		}
		nanotime := extractNanotime(seqNano, convNano)
		if nanotime < threshold {
			idset[id] = struct{}{}
			if nanotime > maxNanotime {
				maxNanotime = nanotime
			}
		}
	}
	// rows.Next also returns false when iteration fails; without this check
	// a truncated result set would be treated as complete
	if err = rows.Err(); err != nil {
		return nil, 0, err
	}

	ids = make([]uint64, 0, len(idset))
	for id := range idset {
		ids = append(ids, id)
	}
	return ids, maxNanotime, nil
}
  // deleteCorrespondents removes correspondents entries whose nanotime is at
// or before threshold (a UnixNano timestamp). Errors are logged, not
// returned.
func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
	result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
	if mysql.logError("error deleting correspondents", err) {
		return
	}
	count, err := result.RowsAffected()
	if !mysql.logError("error deleting correspondents", err) {
		// Debug takes the log type first, like every other call in this file;
		// the old call passed the message as the log type
		mysql.logger.Debug("mysql", fmt.Sprintf("deleted %d correspondents entries", count))
	}
}
  // forgetLoop waits for forget queue items and processes them one by one.
//
// It drains the queue with doForget, then blocks until Forget signals
// wakeForgetter. After any error it sleeps cleanupPauseTime and retries.
// This matters even when nothing was dequeued: a failed read of the queue
// must not leave already-queued items stranded until the next Forget call.
// A panic is logged and the loop is restarted after cleanupPauseTime.
func (mysql *MySQL) forgetLoop() {
	defer func() {
		if r := recover(); r != nil {
			mysql.logger.Error("mysql",
				fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
			time.Sleep(cleanupPauseTime)
			go mysql.forgetLoop()
		}
	}()

	for {
		// drain the queue
		for {
			found, err := mysql.doForget()
			if mysql.logError("error processing forget", err) {
				time.Sleep(cleanupPauseTime)
				continue
			}
			if !found {
				break
			}
		}
		// queue is empty: block until Forget enqueues something
		<-mysql.wakeForgetter
	}
}
  398. // dequeue an item from the forget queue and process it
  399. func (mysql *MySQL) doForget() (found bool, err error) {
  400. id, account, err := func() (id int64, account string, err error) {
  401. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  402. defer cancel()
  403. row := mysql.db.QueryRowContext(ctx,
  404. `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
  405. err = row.Scan(&id, &account)
  406. if err == sql.ErrNoRows {
  407. return 0, "", nil
  408. }
  409. return
  410. }()
  411. if err != nil || account == "" {
  412. return false, err
  413. }
  414. found = true
  415. var count int
  416. for {
  417. start := time.Now()
  418. count, err = mysql.doForgetIteration(account)
  419. elapsed := time.Since(start)
  420. if err != nil {
  421. return true, err
  422. }
  423. if count == 0 {
  424. break
  425. }
  426. time.Sleep(elapsed)
  427. }
  428. mysql.logger.Debug("mysql", "forget complete for account", account)
  429. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  430. defer cancel()
  431. _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
  432. return
  433. }
  // doForgetIteration deletes one batch (at most cleanupRowLimit) of the
// history rows linked to account in account_messages, under a
// cleanupPauseTime timeout. It returns how many ids the batch contained;
// 0 with a nil error means nothing is left to delete. A batch cut short by
// an iteration error is reported instead of being acted on.
func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
	defer cancel()

	rows, err := mysql.db.QueryContext(ctx, `
SELECT account_messages.history_id
FROM account_messages
WHERE account_messages.account = ?
LIMIT ?;`, account, cleanupRowLimit)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	ids := make([]uint64, 0, cleanupRowLimit)
	for rows.Next() {
		var id uint64
		if err = rows.Scan(&id); err != nil {
			return 0, err
		}
		ids = append(ids, id)
	}
	if err = rows.Err(); err != nil {
		return 0, err
	}
	if len(ids) == 0 {
		return 0, nil
	}

	mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
	err = mysql.deleteHistoryIDs(ctx, ids)
	return len(ids), err
}
  // prepareStatements prepares the insert statements used on the write path
// and stores them on mysql. It stops at, and returns, the first failure.
func (mysql *MySQL) prepareStatements() (err error) {
	statements := []struct {
		dest  **sql.Stmt
		query string
	}{
		{&mysql.insertHistory, `INSERT INTO history
(data, msgid) VALUES (?, ?);`},
		{&mysql.insertSequence, `INSERT INTO sequence
(target, nanotime, history_id) VALUES (?, ?, ?);`},
		{&mysql.insertConversation, `INSERT INTO conversations
(target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`},
		// upsert: keep the newer of the stored and incoming timestamps
		{&mysql.insertCorrespondent, `INSERT INTO correspondents
(target, correspondent, nanotime) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`},
		{&mysql.insertAccountMessage, `INSERT INTO account_messages
(history_id, account) VALUES (?, ?);`},
	}
	for _, s := range statements {
		if *s.dest, err = mysql.db.Prepare(s.query); err != nil {
			return err
		}
	}
	return nil
}
  // getTimeout returns the per-operation timeout most recently stored by
// SetConfig, read atomically.
func (mysql *MySQL) getTimeout() time.Duration {
	nanos := atomic.LoadInt64(&mysql.timeout)
	return time.Duration(nanos)
}
  // isTrackingAccountMessages reports whether the last SetConfig enabled
// TrackAccountMessages, read atomically.
func (mysql *MySQL) isTrackingAccountMessages() bool {
	flag := atomic.LoadUint32(&mysql.trackAccountMessages)
	return flag != 0
}
  497. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  498. if err != nil {
  499. mysql.logger.Error("mysql", context, err.Error())
  500. return true
  501. }
  502. return false
  503. }
  // Forget enqueues account in the forget table and wakes the forget
// goroutine. It does nothing when the database is not open or account is
// empty; a failed insert is logged and no wakeup is sent.
func (mysql *MySQL) Forget(account string) {
	if mysql.db == nil || account == "" {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
	defer cancel()

	_, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
	if !mysql.logError("can't insert into forget table", err) {
		// non-blocking send: if a wakeup is already pending, one is enough
		select {
		case mysql.wakeForgetter <- e{}:
		default:
		}
	}
}
  // AddChannelItem stores item in history and records it under target in the
// sequence table. When account is non-empty and tracking is enabled, it is
// also linked to account in account_messages.
//
// It is a no-op when the database is not open, and returns
// utils.ErrInvalidParams for an empty target.
func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
	switch {
	case mysql.db == nil:
		return nil
	case target == "":
		return utils.ErrInvalidParams
	}

	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
	defer cancel()

	id, err := mysql.insertBase(ctx, item)
	if err != nil {
		return err
	}
	if err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id); err != nil {
		return err
	}
	return mysql.insertAccountMessageEntry(ctx, id, account)
}
  // insertSequenceEntry records history row id at messageTime under target in
// the sequence table. A failure is logged before being returned.
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
	if _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id); err != nil {
		mysql.logError("could not insert sequence entry", err)
	}
	return err
}
  // insertConversationEntry records history row id at messageTime under the
// (target, correspondent) pair in the conversations table. A failure is
// logged before being returned.
func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
	if _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id); err != nil {
		mysql.logError("could not insert conversations entry", err)
	}
	return err
}
  // insertCorrespondentsEntry upserts the (target, correspondent) row in the
// correspondents table, keeping the greater of the stored and new
// messageTime. historyId is currently unused. A failure is logged (with a
// message naming the correspondents table, not conversations) before being
// returned.
func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
	// messageTime is bound twice: for the INSERT values and for the
	// ON DUPLICATE KEY UPDATE ... GREATEST(nanotime, ?) clause
	_, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
	mysql.logError("could not insert correspondents entry", err)
	return err
}
  // insertBase serializes item and inserts it, together with its decoded
// msgid, into the history table. It returns the new row's auto-increment id.
// Every failure is logged before being returned.
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
	payload, err := marshalItem(&item)
	if mysql.logError("could not marshal item", err) {
		return 0, err
	}
	rawMsgid, err := decodeMsgid(item.Message.Msgid)
	if mysql.logError("could not decode msgid", err) {
		return 0, err
	}
	res, err := mysql.insertHistory.ExecContext(ctx, payload, rawMsgid)
	if mysql.logError("could not insert item", err) {
		return 0, err
	}
	id, err = res.LastInsertId()
	mysql.logError("could not insert item", err)
	return id, err
}
  // insertAccountMessageEntry links history row id to account in
// account_messages. It does nothing when account is empty or tracking is
// disabled; a failure is logged before being returned.
func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
	if account != "" && mysql.isTrackingAccountMessages() {
		_, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
		mysql.logError("could not insert account-message entry", err)
	}
	return err
}
  // AddDirectMessage stores a direct message. Nothing is stored unless the
// database is open and at least one side has an account. An empty sender
// or recipient yields utils.ErrInvalidParams.
//
// For each side with an account, a conversations row and a correspondents
// row are written under that account, with the other party as
// correspondent. The recipient side is skipped when sender == recipient.
// The message is linked to senderAccount for account tracking.
func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
	if mysql.db == nil || (senderAccount == "" && recipientAccount == "") {
		return nil
	}
	if sender == "" || recipient == "" {
		return utils.ErrInvalidParams
	}

	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
	defer cancel()

	id, err := mysql.insertBase(ctx, item)
	if err != nil {
		return err
	}
	nanotime := item.Message.Time.UnixNano()

	// record the exchange from owner's point of view, with other as the
	// correspondent
	record := func(owner, other string) error {
		if err := mysql.insertConversationEntry(ctx, owner, other, nanotime, id); err != nil {
			return err
		}
		return mysql.insertCorrespondentsEntry(ctx, owner, other, nanotime, id)
	}

	if senderAccount != "" {
		if err = record(senderAccount, recipient); err != nil {
			return err
		}
	}
	if recipientAccount != "" && sender != recipient {
		if err = record(recipientAccount, sender); err != nil {
			return err
		}
	}
	return mysql.insertAccountMessageEntry(ctx, id, senderAccount)
}
  // DeleteMsgid deletes the message identified by msgid.
//
// accountName is the unfolded account name. "*" skips the ownership check;
// any other value must equal the stored item's AccountName, or ErrDisallowed
// is returned. An entry whose data cannot be unmarshaled is deleted
// regardless. It is a no-op when the database is not open.
func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
	if mysql.db == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
	defer cancel()

	_, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
	if err != nil {
		return err
	}

	if accountName != "*" {
		var item history.Item
		// delete if the entry is corrupt: only a successfully decoded item
		// owned by someone else blocks the deletion
		if unmarshalItem(data, &item) == nil && item.AccountName != accountName {
			return ErrDisallowed
		}
	}

	err = mysql.deleteHistoryIDs(ctx, []uint64{id})
	mysql.logError("couldn't delete msgid", err)
	return err
}
  // Export writes every message linked to account in account_messages to
// writer as JSON, one item per line. CfCorrespondent is set to the item's
// sequence target.
//
// Items are read in batches of cleanupRowLimit using keyset pagination on
// history_id. The query must therefore ORDER BY history_id: without it, a
// batch can come back in any order, and advancing lastSeen to the batch
// maximum would skip lower ids that were not returned. Query, scan,
// unmarshal, iteration and write errors stop the export and are logged.
//
// NOTE(review): the INNER JOIN on sequence means items stored only in
// conversations (direct messages) are not exported; confirm this is intended.
func (mysql *MySQL) Export(account string, writer io.Writer) {
	if mysql.db == nil {
		return
	}

	var err error
	var lastSeen uint64
	for {
		rows := func() (count int) {
			ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
			defer cancel()
			rows, rowsErr := mysql.db.QueryContext(ctx, `
SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
INNER JOIN history ON history.id = account_messages.history_id
INNER JOIN sequence ON account_messages.history_id = sequence.history_id
WHERE account_messages.account = ? AND account_messages.history_id > ?
ORDER BY account_messages.history_id
LIMIT ?`, account, lastSeen, cleanupRowLimit)
			if rowsErr != nil {
				err = rowsErr
				return
			}
			defer rows.Close()

			for rows.Next() {
				var id uint64
				var blob, jsonBlob []byte
				var target string
				var item history.Item
				if err = rows.Scan(&id, &blob, &target); err != nil {
					return
				}
				if err = unmarshalItem(blob, &item); err != nil {
					return
				}
				item.CfCorrespondent = target
				if jsonBlob, err = json.Marshal(item); err != nil {
					return
				}
				count++
				if lastSeen < id {
					lastSeen = id
				}
				// one Write per line, so a failing writer stops the export
				// instead of being ignored
				if _, err = writer.Write(append(jsonBlob, '\n')); err != nil {
					return
				}
			}
			err = rows.Err()
			return
		}()
		if rows == 0 || err != nil {
			break
		}
	}
	mysql.logError("could not export history", err)
}
  // lookupMsgid resolves msgid to the message's timestamp, in UTC. With
// includeData it also returns the history row id and the raw stored data.
//
// It returns sql.ErrNoRows when no row matches, or when the row has neither
// a sequence nor a conversations timestamp. Query errors other than
// sql.ErrNoRows are logged.
func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
	decoded, err := decodeMsgid(msgid)
	if err != nil {
		return
	}

	var nanoSeq, nanoConv sql.NullInt64
	cols := `sequence.nanotime, conversations.nanotime`
	dest := []interface{}{&nanoSeq, &nanoConv}
	if includeData {
		cols = `sequence.nanotime, conversations.nanotime, history.id, history.data`
		dest = append(dest, &id, &data)
	}

	query := fmt.Sprintf(`
SELECT %s FROM history
LEFT JOIN sequence ON history.id = sequence.history_id
LEFT JOIN conversations ON history.id = conversations.history_id
WHERE history.msgid = ? LIMIT 1;`, cols)
	err = mysql.db.QueryRowContext(ctx, query, decoded).Scan(dest...)
	if err != nil {
		if err != sql.ErrNoRows {
			mysql.logError("could not resolve msgid to time", err)
		}
		return
	}

	nanotime := extractNanotime(nanoSeq, nanoConv)
	if nanotime == 0 {
		err = sql.ErrNoRows
		return
	}
	result = time.Unix(0, nanotime).UTC()
	return
}
  // extractNanotime returns the message timestamp from a history join.
// The sequence value takes precedence over the conversations value; 0 is
// returned when neither is valid.
func extractNanotime(seq, conv sql.NullInt64) (result int64) {
	switch {
	case seq.Valid:
		result = seq.Int64
	case conv.Valid:
		result = conv.Int64
	}
	return result
}
  // selectItems runs query, which must select a single history.data column,
// and unmarshals each row into a history.Item. Failures (including an
// iteration error that ends the result set early) are logged and returned.
// When err is non-nil, results may be partial.
func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
	rows, err := mysql.db.QueryContext(ctx, query, args...)
	if mysql.logError("could not select history items", err) {
		return
	}
	defer rows.Close()

	for rows.Next() {
		var blob []byte
		var item history.Item
		err = rows.Scan(&blob)
		if mysql.logError("could not scan history item", err) {
			return
		}
		err = unmarshalItem(blob, &item)
		if mysql.logError("could not unmarshal history item", err) {
			return
		}
		results = append(results, item)
	}
	// without this check a truncated result set is returned as success
	err = rows.Err()
	mysql.logError("could not iterate history items", err)
	return
}
  // betweenTimestamps returns up to limit items for target whose time lies
// strictly between after and before. The bounds are normalized (together
// with cutoff) by history.MinMaxAsc; a zero bound is left out of the query.
//
// With an empty correspondent, the sequence table is queried. Otherwise the
// conversations table is queried for the (target, correspondent) pair.
// Descending queries are reversed, so results are always in ascending time
// order.
func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
	after, before, ascending := history.MinMaxAsc(after, before, cutoff)

	table, direction := "sequence", "ASC"
	if correspondent != "" {
		table = "conversations"
	}
	if !ascending {
		direction = "DESC"
	}

	var query strings.Builder
	args := make([]interface{}, 0, 6)
	fmt.Fprintf(&query,
		"SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
	if correspondent == "" {
		query.WriteString(" sequence.target = ?")
		args = append(args, target)
	} else {
		query.WriteString(" conversations.target = ? AND conversations.correspondent = ?")
		args = append(args, target, correspondent)
	}
	if !after.IsZero() {
		fmt.Fprintf(&query, " AND %s.nanotime > ?", table)
		args = append(args, after.UnixNano())
	}
	if !before.IsZero() {
		fmt.Fprintf(&query, " AND %s.nanotime < ?", table)
		args = append(args, before.UnixNano())
	}
	fmt.Fprintf(&query, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
	args = append(args, limit)

	if results, err = mysql.selectItems(ctx, query.String(), args...); err == nil && !ascending {
		history.Reverse(results)
	}
	return
}
  // listCorrespondentsInternal lists up to limit correspondents of target
// whose stored nanotime lies strictly between after and before. The bounds
// are normalized with cutoff by history.MinMaxAsc; a zero bound is omitted.
//
// Results are returned in ascending time order. An iteration error that ends
// the result set early is returned instead of a silently truncated listing.
func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
	after, before, ascending := history.MinMaxAsc(after, before, cutoff)
	direction := "ASC"
	if !ascending {
		direction = "DESC"
	}

	var queryBuf strings.Builder
	args := make([]interface{}, 0, 4)
	queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
WHERE target = ?`)
	args = append(args, target)
	if !after.IsZero() {
		queryBuf.WriteString(" AND correspondents.nanotime > ?")
		args = append(args, after.UnixNano())
	}
	if !before.IsZero() {
		queryBuf.WriteString(" AND correspondents.nanotime < ?")
		args = append(args, before.UnixNano())
	}
	fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
	args = append(args, limit)

	rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
	if err != nil {
		return
	}
	defer rows.Close()

	for rows.Next() {
		var correspondent string
		var nanotime int64
		if err = rows.Scan(&correspondent, &nanotime); err != nil {
			return
		}
		results = append(results, history.TargetListing{
			CfName: correspondent,
			Time:   time.Unix(0, nanotime),
		})
	}
	if err = rows.Err(); err != nil {
		return
	}

	if !ascending {
		history.ReverseCorrespondents(results)
	}
	return
}
  852. func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
  853. if mysql.db == nil {
  854. return
  855. }
  856. if len(cfchannels) == 0 {
  857. return
  858. }
  859. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  860. defer cancel()
  861. var queryBuf strings.Builder
  862. args := make([]interface{}, 0, len(results))
  863. // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
  864. // this should be a "loose index scan"
  865. queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
  866. WHERE sequence.target IN (`)
  867. for i, chname := range cfchannels {
  868. if i != 0 {
  869. queryBuf.WriteString(", ")
  870. }
  871. queryBuf.WriteByte('?')
  872. args = append(args, chname)
  873. }
  874. queryBuf.WriteString(") GROUP BY sequence.target;")
  875. rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
  876. if mysql.logError("could not query channel listings", err) {
  877. return
  878. }
  879. defer rows.Close()
  880. var target string
  881. var nanotime int64
  882. for rows.Next() {
  883. err = rows.Scan(&target, &nanotime)
  884. if mysql.logError("could not scan channel listings", err) {
  885. return
  886. }
  887. results = append(results, history.TargetListing{
  888. CfName: target,
  889. Time: time.Unix(0, nanotime),
  890. })
  891. }
  892. return
  893. }
  894. func (mysql *MySQL) Close() {
  895. // closing the database will close our prepared statements as well
  896. if mysql.db != nil {
  897. mysql.db.Close()
  898. }
  899. mysql.db = nil
  900. }
  901. // implements history.Sequence, emulating a single history buffer (for a channel,
  902. // a single user's DMs, or a DM conversation)
  903. type mySQLHistorySequence struct {
  904. mysql *MySQL
  905. target string
  906. correspondent string
  907. cutoff time.Time
  908. }
  909. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
  910. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  911. defer cancel()
  912. startTime := start.Time
  913. if start.Msgid != "" {
  914. startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
  915. if err != nil {
  916. return nil, err
  917. }
  918. }
  919. endTime := end.Time
  920. if end.Msgid != "" {
  921. endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
  922. if err != nil {
  923. return nil, err
  924. }
  925. }
  926. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  927. return results, err
  928. }
  929. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  930. return history.GenericAround(s, start, limit)
  931. }
  932. func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
  933. ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
  934. defer cancel()
  935. // TODO accept msgids here?
  936. startTime := start.Time
  937. endTime := end.Time
  938. results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
  939. seq.mysql.logError("could not read correspondents", err)
  940. return
  941. }
  942. func (seq *mySQLHistorySequence) Cutoff() time.Time {
  943. return seq.cutoff
  944. }
  945. func (seq *mySQLHistorySequence) Ephemeral() bool {
  946. return false
  947. }
  948. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
  949. return &mySQLHistorySequence{
  950. target: target,
  951. correspondent: correspondent,
  952. mysql: mysql,
  953. cutoff: cutoff,
  954. }
  955. }