You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

history.go 28KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "context"
  6. "database/sql"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "runtime/debug"
  12. "slices"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "time"
  17. "github.com/ergochat/ergo/irc/history"
  18. "github.com/ergochat/ergo/irc/logger"
  19. "github.com/ergochat/ergo/irc/utils"
  20. _ "github.com/go-sql-driver/mysql"
  21. )
  22. var (
  23. ErrDisallowed = errors.New("disallowed")
  24. ErrDBIsNil = errors.New("db == nil")
  25. )
  26. const (
  27. // maximum length in bytes of any message target (nickname or channel name) in its
  28. // canonicalized (i.e., casefolded) state:
  29. MaxTargetLength = 64
  30. // latest schema of the db
  31. latestDbSchema = "2"
  32. keySchemaVersion = "db.version"
  33. // minor version indicates rollback-safe upgrades, i.e.,
  34. // you can downgrade oragono and everything will work
  35. latestDbMinorVersion = "2"
  36. keySchemaMinorVersion = "db.minorversion"
  37. cleanupRowLimit = 50
  38. cleanupPauseTime = 10 * time.Minute
  39. )
// e is an empty struct type used as the element type of signaling channels
// (see wakeForgetter).
type e struct{}

// MySQL stores persistent message history in a MySQL database.
//
// Initialize sets up the logger, the wakeup channel and the configuration;
// Open then connects, prepares the statements below, and starts the
// background cleanup and forget goroutines.
type MySQL struct {
	db     *sql.DB
	logger *logger.Manager

	// prepared INSERT statements, created by prepareStatements
	insertHistory        *sql.Stmt
	insertSequence       *sql.Stmt
	insertConversation   *sql.Stmt
	insertCorrespondent  *sql.Stmt
	insertAccountMessage *sql.Stmt

	// stateMutex guards config (see SetConfig and getExpireTime)
	stateMutex sync.Mutex
	config     Config

	// wakeForgetter wakes forgetLoop after Forget enqueues an account;
	// Initialize creates it with capacity 1 and Forget sends without blocking
	wakeForgetter chan e

	// copies of frequently read settings, stored atomically by SetConfig so
	// they can be read without taking stateMutex
	timeout              atomic.Uint64 // holds a time.Duration (see getTimeout)
	trackAccountMessages atomic.Uint32 // 1 if account tracking is enabled, else 0
}
  55. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  56. mysql.logger = logger
  57. mysql.wakeForgetter = make(chan e, 1)
  58. mysql.SetConfig(config)
  59. }
  60. func (mysql *MySQL) SetConfig(config Config) {
  61. mysql.timeout.Store(uint64(config.Timeout))
  62. var trackAccountMessages uint32
  63. if config.TrackAccountMessages {
  64. trackAccountMessages = 1
  65. }
  66. mysql.trackAccountMessages.Store(trackAccountMessages)
  67. mysql.stateMutex.Lock()
  68. mysql.config = config
  69. mysql.stateMutex.Unlock()
  70. }
  71. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  72. mysql.stateMutex.Lock()
  73. expireTime = mysql.config.ExpireTime
  74. mysql.stateMutex.Unlock()
  75. return
  76. }
  77. func (m *MySQL) Open() (err error) {
  78. var address string
  79. if m.config.SocketPath != "" {
  80. address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
  81. } else if m.config.Port != 0 {
  82. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  83. }
  84. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  85. if err != nil {
  86. return err
  87. }
  88. if m.config.MaxConns != 0 {
  89. m.db.SetMaxOpenConns(m.config.MaxConns)
  90. m.db.SetMaxIdleConns(m.config.MaxConns)
  91. }
  92. if m.config.ConnMaxLifetime != 0 {
  93. m.db.SetConnMaxLifetime(m.config.ConnMaxLifetime)
  94. }
  95. err = m.fixSchemas()
  96. if err != nil {
  97. return err
  98. }
  99. err = m.prepareStatements()
  100. if err != nil {
  101. return err
  102. }
  103. go m.cleanupLoop()
  104. go m.forgetLoop()
  105. return nil
  106. }
  107. func (mysql *MySQL) fixSchemas() (err error) {
  108. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  109. key_name VARCHAR(32) primary key,
  110. value VARCHAR(32) NOT NULL
  111. ) CHARSET=ascii COLLATE=ascii_bin;`)
  112. if err != nil {
  113. return err
  114. }
  115. var schema string
  116. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  117. if err == sql.ErrNoRows {
  118. err = mysql.createTables()
  119. if err != nil {
  120. return
  121. }
  122. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  123. if err != nil {
  124. return
  125. }
  126. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  127. if err != nil {
  128. return
  129. }
  130. return
  131. } else if err == nil && schema != latestDbSchema {
  132. // TODO figure out what to do about schema changes
  133. return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
  134. } else if err != nil {
  135. return err
  136. }
  137. var minorVersion string
  138. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
  139. if err == sql.ErrNoRows {
  140. // XXX for now, the only minor version upgrade is the account tracking tables
  141. err = mysql.createComplianceTables()
  142. if err != nil {
  143. return
  144. }
  145. err = mysql.createCorrespondentsTable()
  146. if err != nil {
  147. return
  148. }
  149. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  150. if err != nil {
  151. return
  152. }
  153. } else if err == nil && minorVersion == "1" {
  154. // upgrade from 2.1 to 2.2: create the correspondents table
  155. err = mysql.createCorrespondentsTable()
  156. if err != nil {
  157. return
  158. }
  159. _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
  160. if err != nil {
  161. return
  162. }
  163. } else if err == nil && minorVersion != latestDbMinorVersion {
  164. // TODO: if minorVersion < latestDbMinorVersion, upgrade,
  165. // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
  166. }
  167. return
  168. }
  169. func (mysql *MySQL) createTables() (err error) {
  170. _, err = mysql.db.Exec(`CREATE TABLE history (
  171. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  172. data BLOB NOT NULL,
  173. msgid BINARY(16) NOT NULL,
  174. KEY (msgid(4))
  175. ) CHARSET=ascii COLLATE=ascii_bin;`)
  176. if err != nil {
  177. return err
  178. }
  179. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  180. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  181. target VARBINARY(%[1]d) NOT NULL,
  182. nanotime BIGINT UNSIGNED NOT NULL,
  183. KEY (target, nanotime)
  184. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  185. if err != nil {
  186. return err
  187. }
  188. /* XXX: this table used to be:
  189. CREATE TABLE sequence (
  190. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  191. target VARBINARY(%[1]d) NOT NULL,
  192. nanotime BIGINT UNSIGNED NOT NULL,
  193. history_id BIGINT NOT NULL,
  194. KEY (target, nanotime),
  195. KEY (history_id)
  196. ) CHARSET=ascii COLLATE=ascii_bin;
  197. Some users may still be using the old schema.
  198. */
  199. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  200. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  201. target VARBINARY(%[1]d) NOT NULL,
  202. correspondent VARBINARY(%[1]d) NOT NULL,
  203. nanotime BIGINT UNSIGNED NOT NULL,
  204. history_id BIGINT NOT NULL,
  205. KEY (target, correspondent, nanotime),
  206. KEY (history_id)
  207. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  208. if err != nil {
  209. return err
  210. }
  211. err = mysql.createCorrespondentsTable()
  212. if err != nil {
  213. return err
  214. }
  215. err = mysql.createComplianceTables()
  216. if err != nil {
  217. return err
  218. }
  219. return nil
  220. }
  221. func (mysql *MySQL) createCorrespondentsTable() (err error) {
  222. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
  223. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  224. target VARBINARY(%[1]d) NOT NULL,
  225. correspondent VARBINARY(%[1]d) NOT NULL,
  226. nanotime BIGINT UNSIGNED NOT NULL,
  227. UNIQUE KEY (target, correspondent),
  228. KEY (target, nanotime),
  229. KEY (nanotime)
  230. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  231. return
  232. }
  233. func (mysql *MySQL) createComplianceTables() (err error) {
  234. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
  235. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  236. account VARBINARY(%[1]d) NOT NULL,
  237. KEY (account, history_id)
  238. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  239. if err != nil {
  240. return err
  241. }
  242. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
  243. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  244. account VARBINARY(%[1]d) NOT NULL
  245. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  246. if err != nil {
  247. return err
  248. }
  249. return nil
  250. }
  251. func (mysql *MySQL) cleanupLoop() {
  252. defer func() {
  253. if r := recover(); r != nil {
  254. mysql.logger.Error("mysql",
  255. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  256. time.Sleep(cleanupPauseTime)
  257. go mysql.cleanupLoop()
  258. }
  259. }()
  260. for {
  261. expireTime := mysql.getExpireTime()
  262. if expireTime != 0 {
  263. for {
  264. startTime := time.Now()
  265. rowsDeleted, err := mysql.doCleanup(expireTime)
  266. elapsed := time.Now().Sub(startTime)
  267. mysql.logError("error during row cleanup", err)
  268. // keep going as long as we're accomplishing significant work
  269. // (don't busy-wait on small numbers of rows expiring):
  270. if rowsDeleted < (cleanupRowLimit / 10) {
  271. break
  272. }
  273. // crude backpressure mechanism: if the database is slow,
  274. // give it time to process other queries
  275. time.Sleep(elapsed)
  276. }
  277. }
  278. time.Sleep(cleanupPauseTime)
  279. }
  280. }
  281. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  282. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  283. defer cancel()
  284. ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
  285. if len(ids) == 0 {
  286. mysql.logger.Debug("mysql", "found no rows to clean up")
  287. return
  288. }
  289. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  290. if maxNanotime != 0 {
  291. mysql.deleteCorrespondents(ctx, maxNanotime)
  292. }
  293. return len(ids), mysql.deleteHistoryIDs(ctx, ids)
  294. }
  295. func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
  296. // can't use ? binding for a variable number of arguments, build the IN clause manually
  297. var inBuf strings.Builder
  298. inBuf.WriteByte('(')
  299. for i, id := range ids {
  300. if i != 0 {
  301. inBuf.WriteRune(',')
  302. }
  303. fmt.Fprintf(&inBuf, "%d", id)
  304. }
  305. inBuf.WriteRune(')')
  306. inClause := inBuf.String()
  307. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
  308. if err != nil {
  309. return
  310. }
  311. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
  312. if err != nil {
  313. return
  314. }
  315. if mysql.isTrackingAccountMessages() {
  316. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
  317. if err != nil {
  318. return
  319. }
  320. }
  321. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
  322. if err != nil {
  323. return
  324. }
  325. return
  326. }
  327. func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  328. rows, err := mysql.db.QueryContext(ctx, `
  329. SELECT history.id, sequence.nanotime, conversations.nanotime
  330. FROM history
  331. LEFT JOIN sequence ON history.id = sequence.history_id
  332. LEFT JOIN conversations on history.id = conversations.history_id
  333. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  334. if err != nil {
  335. return
  336. }
  337. defer rows.Close()
  338. idset := make(map[uint64]struct{}, cleanupRowLimit)
  339. threshold := time.Now().Add(-age).UnixNano()
  340. for rows.Next() {
  341. var id uint64
  342. var seqNano, convNano sql.NullInt64
  343. err = rows.Scan(&id, &seqNano, &convNano)
  344. if err != nil {
  345. return
  346. }
  347. nanotime := extractNanotime(seqNano, convNano)
  348. // returns 0 if not found; in that case the data is inconsistent
  349. // and we should delete the entry
  350. if nanotime < threshold {
  351. idset[id] = struct{}{}
  352. if nanotime > maxNanotime {
  353. maxNanotime = nanotime
  354. }
  355. }
  356. }
  357. ids = make([]uint64, len(idset))
  358. i := 0
  359. for id := range idset {
  360. ids[i] = id
  361. i++
  362. }
  363. return
  364. }
  365. func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
  366. result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
  367. if err != nil {
  368. mysql.logError("error deleting correspondents", err)
  369. } else {
  370. count, err := result.RowsAffected()
  371. if !mysql.logError("error deleting correspondents", err) {
  372. mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
  373. }
  374. }
  375. }
  376. // wait for forget queue items and process them one by one
  377. func (mysql *MySQL) forgetLoop() {
  378. defer func() {
  379. if r := recover(); r != nil {
  380. mysql.logger.Error("mysql",
  381. fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
  382. time.Sleep(cleanupPauseTime)
  383. go mysql.forgetLoop()
  384. }
  385. }()
  386. for {
  387. for {
  388. found, err := mysql.doForget()
  389. mysql.logError("error processing forget", err)
  390. if err != nil {
  391. time.Sleep(cleanupPauseTime)
  392. }
  393. if !found {
  394. break
  395. }
  396. }
  397. <-mysql.wakeForgetter
  398. }
  399. }
  400. // dequeue an item from the forget queue and process it
  401. func (mysql *MySQL) doForget() (found bool, err error) {
  402. id, account, err := func() (id int64, account string, err error) {
  403. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  404. defer cancel()
  405. row := mysql.db.QueryRowContext(ctx,
  406. `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
  407. err = row.Scan(&id, &account)
  408. if err == sql.ErrNoRows {
  409. return 0, "", nil
  410. }
  411. return
  412. }()
  413. if err != nil || account == "" {
  414. return false, err
  415. }
  416. found = true
  417. var count int
  418. for {
  419. start := time.Now()
  420. count, err = mysql.doForgetIteration(account)
  421. elapsed := time.Since(start)
  422. if err != nil {
  423. return true, err
  424. }
  425. if count == 0 {
  426. break
  427. }
  428. time.Sleep(elapsed)
  429. }
  430. mysql.logger.Debug("mysql", "forget complete for account", account)
  431. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  432. defer cancel()
  433. _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
  434. return
  435. }
  436. func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
  437. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  438. defer cancel()
  439. rows, err := mysql.db.QueryContext(ctx, `
  440. SELECT account_messages.history_id
  441. FROM account_messages
  442. WHERE account_messages.account = ?
  443. LIMIT ?;`, account, cleanupRowLimit)
  444. if err != nil {
  445. return
  446. }
  447. defer rows.Close()
  448. var ids []uint64
  449. for rows.Next() {
  450. var id uint64
  451. err = rows.Scan(&id)
  452. if err != nil {
  453. return
  454. }
  455. ids = append(ids, id)
  456. }
  457. if len(ids) == 0 {
  458. return
  459. }
  460. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
  461. err = mysql.deleteHistoryIDs(ctx, ids)
  462. return len(ids), err
  463. }
  464. func (mysql *MySQL) prepareStatements() (err error) {
  465. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  466. (data, msgid) VALUES (?, ?);`)
  467. if err != nil {
  468. return
  469. }
  470. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  471. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  472. if err != nil {
  473. return
  474. }
  475. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  476. (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  477. if err != nil {
  478. return
  479. }
  480. mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
  481. (target, correspondent, nanotime) VALUES (?, ?, ?)
  482. ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
  483. if err != nil {
  484. return
  485. }
  486. mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
  487. (history_id, account) VALUES (?, ?);`)
  488. if err != nil {
  489. return
  490. }
  491. return
  492. }
  493. func (mysql *MySQL) getTimeout() time.Duration {
  494. return time.Duration(mysql.timeout.Load())
  495. }
  496. func (mysql *MySQL) isTrackingAccountMessages() bool {
  497. return mysql.trackAccountMessages.Load() != 0
  498. }
  499. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  500. if err != nil {
  501. mysql.logger.Error("mysql", context, err.Error())
  502. return true
  503. }
  504. return false
  505. }
  506. func (mysql *MySQL) Forget(account string) {
  507. if mysql.db == nil || account == "" {
  508. return
  509. }
  510. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  511. defer cancel()
  512. _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
  513. if mysql.logError("can't insert into forget table", err) {
  514. return
  515. }
  516. // wake up the forget goroutine if it's blocked:
  517. select {
  518. case mysql.wakeForgetter <- e{}:
  519. default:
  520. }
  521. }
  522. func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
  523. if mysql.db == nil {
  524. return
  525. }
  526. if target == "" {
  527. return utils.ErrInvalidParams
  528. }
  529. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  530. defer cancel()
  531. id, err := mysql.insertBase(ctx, item)
  532. if err != nil {
  533. return
  534. }
  535. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
  536. if err != nil {
  537. return
  538. }
  539. err = mysql.insertAccountMessageEntry(ctx, id, account)
  540. if err != nil {
  541. return
  542. }
  543. return
  544. }
  545. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
  546. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
  547. mysql.logError("could not insert sequence entry", err)
  548. return
  549. }
  550. func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
  551. _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
  552. mysql.logError("could not insert conversations entry", err)
  553. return
  554. }
  555. func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
  556. _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
  557. mysql.logError("could not insert conversations entry", err)
  558. return
  559. }
  560. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  561. value, err := marshalItem(&item)
  562. if mysql.logError("could not marshal item", err) {
  563. return
  564. }
  565. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  566. if mysql.logError("could not decode msgid", err) {
  567. return
  568. }
  569. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  570. if mysql.logError("could not insert item", err) {
  571. return
  572. }
  573. id, err = result.LastInsertId()
  574. if mysql.logError("could not insert item", err) {
  575. return
  576. }
  577. return
  578. }
  579. func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
  580. if account == "" || !mysql.isTrackingAccountMessages() {
  581. return
  582. }
  583. _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
  584. mysql.logError("could not insert account-message entry", err)
  585. return
  586. }
  587. func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
  588. if mysql.db == nil {
  589. return
  590. }
  591. if senderAccount == "" && recipientAccount == "" {
  592. return
  593. }
  594. if sender == "" || recipient == "" {
  595. return utils.ErrInvalidParams
  596. }
  597. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  598. defer cancel()
  599. id, err := mysql.insertBase(ctx, item)
  600. if err != nil {
  601. return
  602. }
  603. nanotime := item.Message.Time.UnixNano()
  604. if senderAccount != "" {
  605. err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
  606. if err != nil {
  607. return
  608. }
  609. err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
  610. if err != nil {
  611. return
  612. }
  613. }
  614. if recipientAccount != "" && sender != recipient {
  615. err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
  616. if err != nil {
  617. return
  618. }
  619. err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
  620. if err != nil {
  621. return
  622. }
  623. }
  624. err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
  625. if err != nil {
  626. return
  627. }
  628. return
  629. }
  630. // note that accountName is the unfolded name
  631. func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
  632. if mysql.db == nil {
  633. return ErrDBIsNil
  634. }
  635. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  636. defer cancel()
  637. _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
  638. if err != nil {
  639. return
  640. }
  641. if accountName != "*" {
  642. var item history.Item
  643. err = unmarshalItem(data, &item)
  644. // delete if the entry is corrupt
  645. if err == nil && item.AccountName != accountName {
  646. return ErrDisallowed
  647. }
  648. }
  649. err = mysql.deleteHistoryIDs(ctx, []uint64{id})
  650. mysql.logError("couldn't delete msgid", err)
  651. return
  652. }
  653. func (mysql *MySQL) Export(account string, writer io.Writer) {
  654. if mysql.db == nil {
  655. return
  656. }
  657. var err error
  658. var lastSeen uint64
  659. for {
  660. rows := func() (count int) {
  661. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  662. defer cancel()
  663. rows, rowsErr := mysql.db.QueryContext(ctx, `
  664. SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
  665. INNER JOIN history ON history.id = account_messages.history_id
  666. INNER JOIN sequence ON account_messages.history_id = sequence.history_id
  667. WHERE account_messages.account = ? AND account_messages.history_id > ?
  668. LIMIT ?`, account, lastSeen, cleanupRowLimit)
  669. if rowsErr != nil {
  670. err = rowsErr
  671. return
  672. }
  673. defer rows.Close()
  674. for rows.Next() {
  675. var id uint64
  676. var blob, jsonBlob []byte
  677. var target string
  678. var item history.Item
  679. err = rows.Scan(&id, &blob, &target)
  680. if err != nil {
  681. return
  682. }
  683. err = unmarshalItem(blob, &item)
  684. if err != nil {
  685. return
  686. }
  687. item.CfCorrespondent = target
  688. jsonBlob, err = json.Marshal(item)
  689. if err != nil {
  690. return
  691. }
  692. count++
  693. if lastSeen < id {
  694. lastSeen = id
  695. }
  696. writer.Write(jsonBlob)
  697. writer.Write([]byte{'\n'})
  698. }
  699. return
  700. }()
  701. if rows == 0 || err != nil {
  702. break
  703. }
  704. }
  705. mysql.logError("could not export history", err)
  706. return
  707. }
  708. func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
  709. decoded, err := decodeMsgid(msgid)
  710. if err != nil {
  711. return
  712. }
  713. cols := `sequence.nanotime, conversations.nanotime`
  714. if includeData {
  715. cols = `sequence.nanotime, conversations.nanotime, history.id, history.data`
  716. }
  717. row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
  718. SELECT %s FROM history
  719. LEFT JOIN sequence ON history.id = sequence.history_id
  720. LEFT JOIN conversations ON history.id = conversations.history_id
  721. WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
  722. var nanoSeq, nanoConv sql.NullInt64
  723. if !includeData {
  724. err = row.Scan(&nanoSeq, &nanoConv)
  725. } else {
  726. err = row.Scan(&nanoSeq, &nanoConv, &id, &data)
  727. }
  728. if err != sql.ErrNoRows {
  729. mysql.logError("could not resolve msgid to time", err)
  730. }
  731. if err != nil {
  732. return
  733. }
  734. nanotime := extractNanotime(nanoSeq, nanoConv)
  735. if nanotime == 0 {
  736. err = sql.ErrNoRows
  737. return
  738. }
  739. result = time.Unix(0, nanotime).UTC()
  740. return
  741. }
  742. func extractNanotime(seq, conv sql.NullInt64) (result int64) {
  743. if seq.Valid {
  744. return seq.Int64
  745. } else if conv.Valid {
  746. return conv.Int64
  747. }
  748. return
  749. }
  750. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  751. rows, err := mysql.db.QueryContext(ctx, query, args...)
  752. if mysql.logError("could not select history items", err) {
  753. return
  754. }
  755. defer rows.Close()
  756. for rows.Next() {
  757. var blob []byte
  758. var item history.Item
  759. err = rows.Scan(&blob)
  760. if mysql.logError("could not scan history item", err) {
  761. return
  762. }
  763. err = unmarshalItem(blob, &item)
  764. if mysql.logError("could not unmarshal history item", err) {
  765. return
  766. }
  767. results = append(results, item)
  768. }
  769. return
  770. }
  771. func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  772. useSequence := correspondent == ""
  773. table := "sequence"
  774. if !useSequence {
  775. table = "conversations"
  776. }
  777. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  778. direction := "ASC"
  779. if !ascending {
  780. direction = "DESC"
  781. }
  782. var queryBuf strings.Builder
  783. args := make([]interface{}, 0, 6)
  784. fmt.Fprintf(&queryBuf,
  785. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  786. if useSequence {
  787. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  788. args = append(args, target)
  789. } else {
  790. fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
  791. args = append(args, target)
  792. args = append(args, correspondent)
  793. }
  794. if !after.IsZero() {
  795. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  796. args = append(args, after.UnixNano())
  797. }
  798. if !before.IsZero() {
  799. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  800. args = append(args, before.UnixNano())
  801. }
  802. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  803. args = append(args, limit)
  804. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  805. if err == nil && !ascending {
  806. slices.Reverse(results)
  807. }
  808. return
  809. }
  810. func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
  811. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  812. direction := "ASC"
  813. if !ascending {
  814. direction = "DESC"
  815. }
  816. var queryBuf strings.Builder
  817. args := make([]interface{}, 0, 4)
  818. queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
  819. WHERE target = ?`)
  820. args = append(args, target)
  821. if !after.IsZero() {
  822. queryBuf.WriteString(" AND correspondents.nanotime > ?")
  823. args = append(args, after.UnixNano())
  824. }
  825. if !before.IsZero() {
  826. queryBuf.WriteString(" AND correspondents.nanotime < ?")
  827. args = append(args, before.UnixNano())
  828. }
  829. fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
  830. args = append(args, limit)
  831. query := queryBuf.String()
  832. rows, err := mysql.db.QueryContext(ctx, query, args...)
  833. if err != nil {
  834. return
  835. }
  836. defer rows.Close()
  837. var correspondent string
  838. var nanotime int64
  839. for rows.Next() {
  840. err = rows.Scan(&correspondent, &nanotime)
  841. if err != nil {
  842. return
  843. }
  844. results = append(results, history.TargetListing{
  845. CfName: correspondent,
  846. Time: time.Unix(0, nanotime),
  847. })
  848. }
  849. if !ascending {
  850. slices.Reverse(results)
  851. }
  852. return
  853. }
  854. func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
  855. if mysql.db == nil {
  856. return
  857. }
  858. if len(cfchannels) == 0 {
  859. return
  860. }
  861. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  862. defer cancel()
  863. var queryBuf strings.Builder
  864. args := make([]interface{}, 0, len(results))
  865. // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
  866. // this should be a "loose index scan"
  867. queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
  868. WHERE sequence.target IN (`)
  869. for i, chname := range cfchannels {
  870. if i != 0 {
  871. queryBuf.WriteString(", ")
  872. }
  873. queryBuf.WriteByte('?')
  874. args = append(args, chname)
  875. }
  876. queryBuf.WriteString(") GROUP BY sequence.target;")
  877. rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
  878. if mysql.logError("could not query channel listings", err) {
  879. return
  880. }
  881. defer rows.Close()
  882. var target string
  883. var nanotime int64
  884. for rows.Next() {
  885. err = rows.Scan(&target, &nanotime)
  886. if mysql.logError("could not scan channel listings", err) {
  887. return
  888. }
  889. results = append(results, history.TargetListing{
  890. CfName: target,
  891. Time: time.Unix(0, nanotime),
  892. })
  893. }
  894. return
  895. }
  896. func (mysql *MySQL) Close() {
  897. // closing the database will close our prepared statements as well
  898. if mysql.db != nil {
  899. mysql.db.Close()
  900. }
  901. mysql.db = nil
  902. }
  903. // implements history.Sequence, emulating a single history buffer (for a channel,
  904. // a single user's DMs, or a DM conversation)
  905. type mySQLHistorySequence struct {
  906. mysql *MySQL
  907. target string
  908. correspondent string
  909. cutoff time.Time
  910. }
  911. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
  912. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  913. defer cancel()
  914. startTime := start.Time
  915. if start.Msgid != "" {
  916. startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
  917. if err != nil {
  918. if err == sql.ErrNoRows {
  919. return nil, nil
  920. } else {
  921. return nil, err
  922. }
  923. }
  924. }
  925. endTime := end.Time
  926. if end.Msgid != "" {
  927. endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
  928. if err != nil {
  929. if err == sql.ErrNoRows {
  930. return nil, nil
  931. } else {
  932. return nil, err
  933. }
  934. }
  935. }
  936. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  937. return results, err
  938. }
  939. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  940. return history.GenericAround(s, start, limit)
  941. }
  942. func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
  943. ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
  944. defer cancel()
  945. // TODO accept msgids here?
  946. startTime := start.Time
  947. endTime := end.Time
  948. results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
  949. seq.mysql.logError("could not read correspondents", err)
  950. return
  951. }
  952. func (seq *mySQLHistorySequence) Cutoff() time.Time {
  953. return seq.cutoff
  954. }
  955. func (seq *mySQLHistorySequence) Ephemeral() bool {
  956. return false
  957. }
  958. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
  959. return &mySQLHistorySequence{
  960. target: target,
  961. correspondent: correspondent,
  962. mysql: mysql,
  963. cutoff: cutoff,
  964. }
  965. }