Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

history.go 28KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "context"
  6. "database/sql"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "runtime/debug"
  12. "slices"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "time"
  17. "github.com/ergochat/ergo/irc/history"
  18. "github.com/ergochat/ergo/irc/logger"
  19. "github.com/ergochat/ergo/irc/utils"
  20. _ "github.com/go-sql-driver/mysql"
  21. )
  22. var (
  23. ErrDisallowed = errors.New("disallowed")
  24. )
  25. const (
  26. // maximum length in bytes of any message target (nickname or channel name) in its
  27. // canonicalized (i.e., casefolded) state:
  28. MaxTargetLength = 64
  29. // latest schema of the db
  30. latestDbSchema = "2"
  31. keySchemaVersion = "db.version"
  32. // minor version indicates rollback-safe upgrades, i.e.,
  33. // you can downgrade oragono and everything will work
  34. latestDbMinorVersion = "2"
  35. keySchemaMinorVersion = "db.minorversion"
  36. cleanupRowLimit = 50
  37. cleanupPauseTime = 10 * time.Minute
  38. )
// e is an empty signal type: values of it are sent on MySQL.wakeForgetter
// purely to wake the forget goroutine and carry no data.
type e struct{}
  40. type MySQL struct {
  41. db *sql.DB
  42. logger *logger.Manager
  43. insertHistory *sql.Stmt
  44. insertSequence *sql.Stmt
  45. insertConversation *sql.Stmt
  46. insertCorrespondent *sql.Stmt
  47. insertAccountMessage *sql.Stmt
  48. stateMutex sync.Mutex
  49. config Config
  50. wakeForgetter chan e
  51. timeout atomic.Uint64
  52. trackAccountMessages atomic.Uint32
  53. }
  54. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  55. mysql.logger = logger
  56. mysql.wakeForgetter = make(chan e, 1)
  57. mysql.SetConfig(config)
  58. }
  59. func (mysql *MySQL) SetConfig(config Config) {
  60. mysql.timeout.Store(uint64(config.Timeout))
  61. var trackAccountMessages uint32
  62. if config.TrackAccountMessages {
  63. trackAccountMessages = 1
  64. }
  65. mysql.trackAccountMessages.Store(trackAccountMessages)
  66. mysql.stateMutex.Lock()
  67. mysql.config = config
  68. mysql.stateMutex.Unlock()
  69. }
  70. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  71. mysql.stateMutex.Lock()
  72. expireTime = mysql.config.ExpireTime
  73. mysql.stateMutex.Unlock()
  74. return
  75. }
  76. func (m *MySQL) Open() (err error) {
  77. var address string
  78. if m.config.SocketPath != "" {
  79. address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
  80. } else if m.config.Port != 0 {
  81. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  82. }
  83. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  84. if err != nil {
  85. return err
  86. }
  87. if m.config.MaxConns != 0 {
  88. m.db.SetMaxOpenConns(m.config.MaxConns)
  89. m.db.SetMaxIdleConns(m.config.MaxConns)
  90. }
  91. if m.config.ConnMaxLifetime != 0 {
  92. m.db.SetConnMaxLifetime(m.config.ConnMaxLifetime)
  93. }
  94. err = m.fixSchemas()
  95. if err != nil {
  96. return err
  97. }
  98. err = m.prepareStatements()
  99. if err != nil {
  100. return err
  101. }
  102. go m.cleanupLoop()
  103. go m.forgetLoop()
  104. return nil
  105. }
  106. func (mysql *MySQL) fixSchemas() (err error) {
  107. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  108. key_name VARCHAR(32) primary key,
  109. value VARCHAR(32) NOT NULL
  110. ) CHARSET=ascii COLLATE=ascii_bin;`)
  111. if err != nil {
  112. return err
  113. }
  114. var schema string
  115. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  116. if err == sql.ErrNoRows {
  117. err = mysql.createTables()
  118. if err != nil {
  119. return
  120. }
  121. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  122. if err != nil {
  123. return
  124. }
  125. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  126. if err != nil {
  127. return
  128. }
  129. return
  130. } else if err == nil && schema != latestDbSchema {
  131. // TODO figure out what to do about schema changes
  132. return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
  133. } else if err != nil {
  134. return err
  135. }
  136. var minorVersion string
  137. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
  138. if err == sql.ErrNoRows {
  139. // XXX for now, the only minor version upgrade is the account tracking tables
  140. err = mysql.createComplianceTables()
  141. if err != nil {
  142. return
  143. }
  144. err = mysql.createCorrespondentsTable()
  145. if err != nil {
  146. return
  147. }
  148. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
  149. if err != nil {
  150. return
  151. }
  152. } else if err == nil && minorVersion == "1" {
  153. // upgrade from 2.1 to 2.2: create the correspondents table
  154. err = mysql.createCorrespondentsTable()
  155. if err != nil {
  156. return
  157. }
  158. _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
  159. if err != nil {
  160. return
  161. }
  162. } else if err == nil && minorVersion != latestDbMinorVersion {
  163. // TODO: if minorVersion < latestDbMinorVersion, upgrade,
  164. // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
  165. }
  166. return
  167. }
  168. func (mysql *MySQL) createTables() (err error) {
  169. _, err = mysql.db.Exec(`CREATE TABLE history (
  170. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  171. data BLOB NOT NULL,
  172. msgid BINARY(16) NOT NULL,
  173. KEY (msgid(4))
  174. ) CHARSET=ascii COLLATE=ascii_bin;`)
  175. if err != nil {
  176. return err
  177. }
  178. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  179. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  180. target VARBINARY(%[1]d) NOT NULL,
  181. nanotime BIGINT UNSIGNED NOT NULL,
  182. KEY (target, nanotime)
  183. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  184. if err != nil {
  185. return err
  186. }
  187. /* XXX: this table used to be:
  188. CREATE TABLE sequence (
  189. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  190. target VARBINARY(%[1]d) NOT NULL,
  191. nanotime BIGINT UNSIGNED NOT NULL,
  192. history_id BIGINT NOT NULL,
  193. KEY (target, nanotime),
  194. KEY (history_id)
  195. ) CHARSET=ascii COLLATE=ascii_bin;
  196. Some users may still be using the old schema.
  197. */
  198. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  199. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  200. target VARBINARY(%[1]d) NOT NULL,
  201. correspondent VARBINARY(%[1]d) NOT NULL,
  202. nanotime BIGINT UNSIGNED NOT NULL,
  203. history_id BIGINT NOT NULL,
  204. KEY (target, correspondent, nanotime),
  205. KEY (history_id)
  206. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  207. if err != nil {
  208. return err
  209. }
  210. err = mysql.createCorrespondentsTable()
  211. if err != nil {
  212. return err
  213. }
  214. err = mysql.createComplianceTables()
  215. if err != nil {
  216. return err
  217. }
  218. return nil
  219. }
  220. func (mysql *MySQL) createCorrespondentsTable() (err error) {
  221. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
  222. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  223. target VARBINARY(%[1]d) NOT NULL,
  224. correspondent VARBINARY(%[1]d) NOT NULL,
  225. nanotime BIGINT UNSIGNED NOT NULL,
  226. UNIQUE KEY (target, correspondent),
  227. KEY (target, nanotime),
  228. KEY (nanotime)
  229. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  230. return
  231. }
  232. func (mysql *MySQL) createComplianceTables() (err error) {
  233. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
  234. history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
  235. account VARBINARY(%[1]d) NOT NULL,
  236. KEY (account, history_id)
  237. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  238. if err != nil {
  239. return err
  240. }
  241. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
  242. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  243. account VARBINARY(%[1]d) NOT NULL
  244. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  245. if err != nil {
  246. return err
  247. }
  248. return nil
  249. }
  250. func (mysql *MySQL) cleanupLoop() {
  251. defer func() {
  252. if r := recover(); r != nil {
  253. mysql.logger.Error("mysql",
  254. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  255. time.Sleep(cleanupPauseTime)
  256. go mysql.cleanupLoop()
  257. }
  258. }()
  259. for {
  260. expireTime := mysql.getExpireTime()
  261. if expireTime != 0 {
  262. for {
  263. startTime := time.Now()
  264. rowsDeleted, err := mysql.doCleanup(expireTime)
  265. elapsed := time.Now().Sub(startTime)
  266. mysql.logError("error during row cleanup", err)
  267. // keep going as long as we're accomplishing significant work
  268. // (don't busy-wait on small numbers of rows expiring):
  269. if rowsDeleted < (cleanupRowLimit / 10) {
  270. break
  271. }
  272. // crude backpressure mechanism: if the database is slow,
  273. // give it time to process other queries
  274. time.Sleep(elapsed)
  275. }
  276. }
  277. time.Sleep(cleanupPauseTime)
  278. }
  279. }
  280. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  281. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  282. defer cancel()
  283. ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
  284. if len(ids) == 0 {
  285. mysql.logger.Debug("mysql", "found no rows to clean up")
  286. return
  287. }
  288. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  289. if maxNanotime != 0 {
  290. mysql.deleteCorrespondents(ctx, maxNanotime)
  291. }
  292. return len(ids), mysql.deleteHistoryIDs(ctx, ids)
  293. }
  294. func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
  295. // can't use ? binding for a variable number of arguments, build the IN clause manually
  296. var inBuf strings.Builder
  297. inBuf.WriteByte('(')
  298. for i, id := range ids {
  299. if i != 0 {
  300. inBuf.WriteRune(',')
  301. }
  302. fmt.Fprintf(&inBuf, "%d", id)
  303. }
  304. inBuf.WriteRune(')')
  305. inClause := inBuf.String()
  306. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
  307. if err != nil {
  308. return
  309. }
  310. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
  311. if err != nil {
  312. return
  313. }
  314. if mysql.isTrackingAccountMessages() {
  315. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
  316. if err != nil {
  317. return
  318. }
  319. }
  320. _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
  321. if err != nil {
  322. return
  323. }
  324. return
  325. }
  326. func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  327. rows, err := mysql.db.QueryContext(ctx, `
  328. SELECT history.id, sequence.nanotime, conversations.nanotime
  329. FROM history
  330. LEFT JOIN sequence ON history.id = sequence.history_id
  331. LEFT JOIN conversations on history.id = conversations.history_id
  332. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  333. if err != nil {
  334. return
  335. }
  336. defer rows.Close()
  337. idset := make(map[uint64]struct{}, cleanupRowLimit)
  338. threshold := time.Now().Add(-age).UnixNano()
  339. for rows.Next() {
  340. var id uint64
  341. var seqNano, convNano sql.NullInt64
  342. err = rows.Scan(&id, &seqNano, &convNano)
  343. if err != nil {
  344. return
  345. }
  346. nanotime := extractNanotime(seqNano, convNano)
  347. // returns 0 if not found; in that case the data is inconsistent
  348. // and we should delete the entry
  349. if nanotime < threshold {
  350. idset[id] = struct{}{}
  351. if nanotime > maxNanotime {
  352. maxNanotime = nanotime
  353. }
  354. }
  355. }
  356. ids = make([]uint64, len(idset))
  357. i := 0
  358. for id := range idset {
  359. ids[i] = id
  360. i++
  361. }
  362. return
  363. }
  364. func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
  365. result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
  366. if err != nil {
  367. mysql.logError("error deleting correspondents", err)
  368. } else {
  369. count, err := result.RowsAffected()
  370. if !mysql.logError("error deleting correspondents", err) {
  371. mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
  372. }
  373. }
  374. }
  375. // wait for forget queue items and process them one by one
  376. func (mysql *MySQL) forgetLoop() {
  377. defer func() {
  378. if r := recover(); r != nil {
  379. mysql.logger.Error("mysql",
  380. fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
  381. time.Sleep(cleanupPauseTime)
  382. go mysql.forgetLoop()
  383. }
  384. }()
  385. for {
  386. for {
  387. found, err := mysql.doForget()
  388. mysql.logError("error processing forget", err)
  389. if err != nil {
  390. time.Sleep(cleanupPauseTime)
  391. }
  392. if !found {
  393. break
  394. }
  395. }
  396. <-mysql.wakeForgetter
  397. }
  398. }
  399. // dequeue an item from the forget queue and process it
  400. func (mysql *MySQL) doForget() (found bool, err error) {
  401. id, account, err := func() (id int64, account string, err error) {
  402. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  403. defer cancel()
  404. row := mysql.db.QueryRowContext(ctx,
  405. `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
  406. err = row.Scan(&id, &account)
  407. if err == sql.ErrNoRows {
  408. return 0, "", nil
  409. }
  410. return
  411. }()
  412. if err != nil || account == "" {
  413. return false, err
  414. }
  415. found = true
  416. var count int
  417. for {
  418. start := time.Now()
  419. count, err = mysql.doForgetIteration(account)
  420. elapsed := time.Since(start)
  421. if err != nil {
  422. return true, err
  423. }
  424. if count == 0 {
  425. break
  426. }
  427. time.Sleep(elapsed)
  428. }
  429. mysql.logger.Debug("mysql", "forget complete for account", account)
  430. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  431. defer cancel()
  432. _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
  433. return
  434. }
  435. func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
  436. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  437. defer cancel()
  438. rows, err := mysql.db.QueryContext(ctx, `
  439. SELECT account_messages.history_id
  440. FROM account_messages
  441. WHERE account_messages.account = ?
  442. LIMIT ?;`, account, cleanupRowLimit)
  443. if err != nil {
  444. return
  445. }
  446. defer rows.Close()
  447. var ids []uint64
  448. for rows.Next() {
  449. var id uint64
  450. err = rows.Scan(&id)
  451. if err != nil {
  452. return
  453. }
  454. ids = append(ids, id)
  455. }
  456. if len(ids) == 0 {
  457. return
  458. }
  459. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
  460. err = mysql.deleteHistoryIDs(ctx, ids)
  461. return len(ids), err
  462. }
  463. func (mysql *MySQL) prepareStatements() (err error) {
  464. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  465. (data, msgid) VALUES (?, ?);`)
  466. if err != nil {
  467. return
  468. }
  469. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  470. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  471. if err != nil {
  472. return
  473. }
  474. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  475. (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  476. if err != nil {
  477. return
  478. }
  479. mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
  480. (target, correspondent, nanotime) VALUES (?, ?, ?)
  481. ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
  482. if err != nil {
  483. return
  484. }
  485. mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
  486. (history_id, account) VALUES (?, ?);`)
  487. if err != nil {
  488. return
  489. }
  490. return
  491. }
  492. func (mysql *MySQL) getTimeout() time.Duration {
  493. return time.Duration(mysql.timeout.Load())
  494. }
  495. func (mysql *MySQL) isTrackingAccountMessages() bool {
  496. return mysql.trackAccountMessages.Load() != 0
  497. }
  498. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  499. if err != nil {
  500. mysql.logger.Error("mysql", context, err.Error())
  501. return true
  502. }
  503. return false
  504. }
  505. func (mysql *MySQL) Forget(account string) {
  506. if mysql.db == nil || account == "" {
  507. return
  508. }
  509. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  510. defer cancel()
  511. _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
  512. if mysql.logError("can't insert into forget table", err) {
  513. return
  514. }
  515. // wake up the forget goroutine if it's blocked:
  516. select {
  517. case mysql.wakeForgetter <- e{}:
  518. default:
  519. }
  520. }
  521. func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
  522. if mysql.db == nil {
  523. return
  524. }
  525. if target == "" {
  526. return utils.ErrInvalidParams
  527. }
  528. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  529. defer cancel()
  530. id, err := mysql.insertBase(ctx, item)
  531. if err != nil {
  532. return
  533. }
  534. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
  535. if err != nil {
  536. return
  537. }
  538. err = mysql.insertAccountMessageEntry(ctx, id, account)
  539. if err != nil {
  540. return
  541. }
  542. return
  543. }
  544. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
  545. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
  546. mysql.logError("could not insert sequence entry", err)
  547. return
  548. }
  549. func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
  550. _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
  551. mysql.logError("could not insert conversations entry", err)
  552. return
  553. }
  554. func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
  555. _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
  556. mysql.logError("could not insert conversations entry", err)
  557. return
  558. }
  559. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  560. value, err := marshalItem(&item)
  561. if mysql.logError("could not marshal item", err) {
  562. return
  563. }
  564. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  565. if mysql.logError("could not decode msgid", err) {
  566. return
  567. }
  568. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  569. if mysql.logError("could not insert item", err) {
  570. return
  571. }
  572. id, err = result.LastInsertId()
  573. if mysql.logError("could not insert item", err) {
  574. return
  575. }
  576. return
  577. }
  578. func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
  579. if account == "" || !mysql.isTrackingAccountMessages() {
  580. return
  581. }
  582. _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
  583. mysql.logError("could not insert account-message entry", err)
  584. return
  585. }
  586. func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
  587. if mysql.db == nil {
  588. return
  589. }
  590. if senderAccount == "" && recipientAccount == "" {
  591. return
  592. }
  593. if sender == "" || recipient == "" {
  594. return utils.ErrInvalidParams
  595. }
  596. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  597. defer cancel()
  598. id, err := mysql.insertBase(ctx, item)
  599. if err != nil {
  600. return
  601. }
  602. nanotime := item.Message.Time.UnixNano()
  603. if senderAccount != "" {
  604. err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
  605. if err != nil {
  606. return
  607. }
  608. err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
  609. if err != nil {
  610. return
  611. }
  612. }
  613. if recipientAccount != "" && sender != recipient {
  614. err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
  615. if err != nil {
  616. return
  617. }
  618. err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
  619. if err != nil {
  620. return
  621. }
  622. }
  623. err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
  624. if err != nil {
  625. return
  626. }
  627. return
  628. }
  629. // note that accountName is the unfolded name
  630. func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
  631. if mysql.db == nil {
  632. return nil
  633. }
  634. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  635. defer cancel()
  636. _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
  637. if err != nil {
  638. return
  639. }
  640. if accountName != "*" {
  641. var item history.Item
  642. err = unmarshalItem(data, &item)
  643. // delete if the entry is corrupt
  644. if err == nil && item.AccountName != accountName {
  645. return ErrDisallowed
  646. }
  647. }
  648. err = mysql.deleteHistoryIDs(ctx, []uint64{id})
  649. mysql.logError("couldn't delete msgid", err)
  650. return
  651. }
  652. func (mysql *MySQL) Export(account string, writer io.Writer) {
  653. if mysql.db == nil {
  654. return
  655. }
  656. var err error
  657. var lastSeen uint64
  658. for {
  659. rows := func() (count int) {
  660. ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
  661. defer cancel()
  662. rows, rowsErr := mysql.db.QueryContext(ctx, `
  663. SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
  664. INNER JOIN history ON history.id = account_messages.history_id
  665. INNER JOIN sequence ON account_messages.history_id = sequence.history_id
  666. WHERE account_messages.account = ? AND account_messages.history_id > ?
  667. LIMIT ?`, account, lastSeen, cleanupRowLimit)
  668. if rowsErr != nil {
  669. err = rowsErr
  670. return
  671. }
  672. defer rows.Close()
  673. for rows.Next() {
  674. var id uint64
  675. var blob, jsonBlob []byte
  676. var target string
  677. var item history.Item
  678. err = rows.Scan(&id, &blob, &target)
  679. if err != nil {
  680. return
  681. }
  682. err = unmarshalItem(blob, &item)
  683. if err != nil {
  684. return
  685. }
  686. item.CfCorrespondent = target
  687. jsonBlob, err = json.Marshal(item)
  688. if err != nil {
  689. return
  690. }
  691. count++
  692. if lastSeen < id {
  693. lastSeen = id
  694. }
  695. writer.Write(jsonBlob)
  696. writer.Write([]byte{'\n'})
  697. }
  698. return
  699. }()
  700. if rows == 0 || err != nil {
  701. break
  702. }
  703. }
  704. mysql.logError("could not export history", err)
  705. return
  706. }
  707. func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
  708. decoded, err := decodeMsgid(msgid)
  709. if err != nil {
  710. return
  711. }
  712. cols := `sequence.nanotime, conversations.nanotime`
  713. if includeData {
  714. cols = `sequence.nanotime, conversations.nanotime, history.id, history.data`
  715. }
  716. row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
  717. SELECT %s FROM history
  718. LEFT JOIN sequence ON history.id = sequence.history_id
  719. LEFT JOIN conversations ON history.id = conversations.history_id
  720. WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
  721. var nanoSeq, nanoConv sql.NullInt64
  722. if !includeData {
  723. err = row.Scan(&nanoSeq, &nanoConv)
  724. } else {
  725. err = row.Scan(&nanoSeq, &nanoConv, &id, &data)
  726. }
  727. if err != sql.ErrNoRows {
  728. mysql.logError("could not resolve msgid to time", err)
  729. }
  730. if err != nil {
  731. return
  732. }
  733. nanotime := extractNanotime(nanoSeq, nanoConv)
  734. if nanotime == 0 {
  735. err = sql.ErrNoRows
  736. return
  737. }
  738. result = time.Unix(0, nanotime).UTC()
  739. return
  740. }
  741. func extractNanotime(seq, conv sql.NullInt64) (result int64) {
  742. if seq.Valid {
  743. return seq.Int64
  744. } else if conv.Valid {
  745. return conv.Int64
  746. }
  747. return
  748. }
  749. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  750. rows, err := mysql.db.QueryContext(ctx, query, args...)
  751. if mysql.logError("could not select history items", err) {
  752. return
  753. }
  754. defer rows.Close()
  755. for rows.Next() {
  756. var blob []byte
  757. var item history.Item
  758. err = rows.Scan(&blob)
  759. if mysql.logError("could not scan history item", err) {
  760. return
  761. }
  762. err = unmarshalItem(blob, &item)
  763. if mysql.logError("could not unmarshal history item", err) {
  764. return
  765. }
  766. results = append(results, item)
  767. }
  768. return
  769. }
  770. func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  771. useSequence := correspondent == ""
  772. table := "sequence"
  773. if !useSequence {
  774. table = "conversations"
  775. }
  776. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  777. direction := "ASC"
  778. if !ascending {
  779. direction = "DESC"
  780. }
  781. var queryBuf strings.Builder
  782. args := make([]interface{}, 0, 6)
  783. fmt.Fprintf(&queryBuf,
  784. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  785. if useSequence {
  786. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  787. args = append(args, target)
  788. } else {
  789. fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
  790. args = append(args, target)
  791. args = append(args, correspondent)
  792. }
  793. if !after.IsZero() {
  794. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  795. args = append(args, after.UnixNano())
  796. }
  797. if !before.IsZero() {
  798. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  799. args = append(args, before.UnixNano())
  800. }
  801. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  802. args = append(args, limit)
  803. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  804. if err == nil && !ascending {
  805. slices.Reverse(results)
  806. }
  807. return
  808. }
  809. func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
  810. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  811. direction := "ASC"
  812. if !ascending {
  813. direction = "DESC"
  814. }
  815. var queryBuf strings.Builder
  816. args := make([]interface{}, 0, 4)
  817. queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
  818. WHERE target = ?`)
  819. args = append(args, target)
  820. if !after.IsZero() {
  821. queryBuf.WriteString(" AND correspondents.nanotime > ?")
  822. args = append(args, after.UnixNano())
  823. }
  824. if !before.IsZero() {
  825. queryBuf.WriteString(" AND correspondents.nanotime < ?")
  826. args = append(args, before.UnixNano())
  827. }
  828. fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
  829. args = append(args, limit)
  830. query := queryBuf.String()
  831. rows, err := mysql.db.QueryContext(ctx, query, args...)
  832. if err != nil {
  833. return
  834. }
  835. defer rows.Close()
  836. var correspondent string
  837. var nanotime int64
  838. for rows.Next() {
  839. err = rows.Scan(&correspondent, &nanotime)
  840. if err != nil {
  841. return
  842. }
  843. results = append(results, history.TargetListing{
  844. CfName: correspondent,
  845. Time: time.Unix(0, nanotime),
  846. })
  847. }
  848. if !ascending {
  849. slices.Reverse(results)
  850. }
  851. return
  852. }
  853. func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
  854. if mysql.db == nil {
  855. return
  856. }
  857. if len(cfchannels) == 0 {
  858. return
  859. }
  860. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  861. defer cancel()
  862. var queryBuf strings.Builder
  863. args := make([]interface{}, 0, len(results))
  864. // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
  865. // this should be a "loose index scan"
  866. queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
  867. WHERE sequence.target IN (`)
  868. for i, chname := range cfchannels {
  869. if i != 0 {
  870. queryBuf.WriteString(", ")
  871. }
  872. queryBuf.WriteByte('?')
  873. args = append(args, chname)
  874. }
  875. queryBuf.WriteString(") GROUP BY sequence.target;")
  876. rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
  877. if mysql.logError("could not query channel listings", err) {
  878. return
  879. }
  880. defer rows.Close()
  881. var target string
  882. var nanotime int64
  883. for rows.Next() {
  884. err = rows.Scan(&target, &nanotime)
  885. if mysql.logError("could not scan channel listings", err) {
  886. return
  887. }
  888. results = append(results, history.TargetListing{
  889. CfName: target,
  890. Time: time.Unix(0, nanotime),
  891. })
  892. }
  893. return
  894. }
  895. func (mysql *MySQL) Close() {
  896. // closing the database will close our prepared statements as well
  897. if mysql.db != nil {
  898. mysql.db.Close()
  899. }
  900. mysql.db = nil
  901. }
  902. // implements history.Sequence, emulating a single history buffer (for a channel,
  903. // a single user's DMs, or a DM conversation)
  904. type mySQLHistorySequence struct {
  905. mysql *MySQL
  906. target string
  907. correspondent string
  908. cutoff time.Time
  909. }
  910. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
  911. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  912. defer cancel()
  913. startTime := start.Time
  914. if start.Msgid != "" {
  915. startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
  916. if err != nil {
  917. if err == sql.ErrNoRows {
  918. return nil, nil
  919. } else {
  920. return nil, err
  921. }
  922. }
  923. }
  924. endTime := end.Time
  925. if end.Msgid != "" {
  926. endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
  927. if err != nil {
  928. if err == sql.ErrNoRows {
  929. return nil, nil
  930. } else {
  931. return nil, err
  932. }
  933. }
  934. }
  935. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  936. return results, err
  937. }
  938. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  939. return history.GenericAround(s, start, limit)
  940. }
  941. func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
  942. ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
  943. defer cancel()
  944. // TODO accept msgids here?
  945. startTime := start.Time
  946. endTime := end.Time
  947. results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
  948. seq.mysql.logError("could not read correspondents", err)
  949. return
  950. }
  951. func (seq *mySQLHistorySequence) Cutoff() time.Time {
  952. return seq.cutoff
  953. }
  954. func (seq *mySQLHistorySequence) Ephemeral() bool {
  955. return false
  956. }
  957. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
  958. return &mySQLHistorySequence{
  959. target: target,
  960. correspondent: correspondent,
  961. mysql: mysql,
  962. cutoff: cutoff,
  963. }
  964. }