Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

history.go 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "bytes"
  6. "context"
  7. "database/sql"
  8. "fmt"
  9. "runtime/debug"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. _ "github.com/go-sql-driver/mysql"
  14. "github.com/oragono/oragono/irc/history"
  15. "github.com/oragono/oragono/irc/logger"
  16. "github.com/oragono/oragono/irc/utils"
  17. )
  // Limits, schema bookkeeping and cleanup tuning for the MySQL history backend.
  const (
  	// MaxTargetLength is the maximum length in bytes of any message target
  	// (nickname or channel name) in its canonicalized (i.e., casefolded) state.
  	// It sizes the VARBINARY target columns created by createTables.
  	MaxTargetLength = 64

  	// latestDbSchema is the latest schema version of the db.
  	latestDbSchema = "1"
  	// keySchemaVersion is the metadata key under which the schema version is stored.
  	keySchemaVersion = "db.version"

  	// cleanupRowLimit is the LIMIT applied to each cleanup selection query.
  	cleanupRowLimit = 50
  	// cleanupPauseTime is how long the cleanup loop sleeps between rounds
  	// (and before restarting after a panic).
  	cleanupPauseTime = 10 * time.Minute
  )
  28. type MySQL struct {
  29. timeout int64
  30. db *sql.DB
  31. logger *logger.Manager
  32. insertHistory *sql.Stmt
  33. insertSequence *sql.Stmt
  34. insertConversation *sql.Stmt
  35. stateMutex sync.Mutex
  36. config Config
  37. }
  38. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  39. mysql.logger = logger
  40. mysql.SetConfig(config)
  41. }
  42. func (mysql *MySQL) SetConfig(config Config) {
  43. atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
  44. mysql.stateMutex.Lock()
  45. mysql.config = config
  46. mysql.stateMutex.Unlock()
  47. }
  48. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  49. mysql.stateMutex.Lock()
  50. expireTime = mysql.config.ExpireTime
  51. mysql.stateMutex.Unlock()
  52. return
  53. }
  54. func (m *MySQL) Open() (err error) {
  55. var address string
  56. if m.config.Port != 0 {
  57. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  58. }
  59. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  60. if err != nil {
  61. return err
  62. }
  63. err = m.fixSchemas()
  64. if err != nil {
  65. return err
  66. }
  67. err = m.prepareStatements()
  68. if err != nil {
  69. return err
  70. }
  71. go m.cleanupLoop()
  72. return nil
  73. }
  74. func (mysql *MySQL) fixSchemas() (err error) {
  75. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  76. key_name VARCHAR(32) primary key,
  77. value VARCHAR(32) NOT NULL
  78. ) CHARSET=ascii COLLATE=ascii_bin;`)
  79. if err != nil {
  80. return err
  81. }
  82. var schema string
  83. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  84. if err == sql.ErrNoRows {
  85. err = mysql.createTables()
  86. if err != nil {
  87. return
  88. }
  89. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  90. if err != nil {
  91. return
  92. }
  93. } else if err == nil && schema != latestDbSchema {
  94. // TODO figure out what to do about schema changes
  95. return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
  96. } else {
  97. return err
  98. }
  99. return nil
  100. }
  101. func (mysql *MySQL) createTables() (err error) {
  102. _, err = mysql.db.Exec(`CREATE TABLE history (
  103. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  104. data BLOB NOT NULL,
  105. msgid BINARY(16) NOT NULL,
  106. KEY (msgid(4))
  107. ) CHARSET=ascii COLLATE=ascii_bin;`)
  108. if err != nil {
  109. return err
  110. }
  111. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  112. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  113. target VARBINARY(%[1]d) NOT NULL,
  114. nanotime BIGINT UNSIGNED NOT NULL,
  115. history_id BIGINT NOT NULL,
  116. KEY (target, nanotime),
  117. KEY (history_id)
  118. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  119. if err != nil {
  120. return err
  121. }
  122. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  123. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  124. lower_target VARBINARY(%[1]d) NOT NULL,
  125. upper_target VARBINARY(%[1]d) NOT NULL,
  126. nanotime BIGINT UNSIGNED NOT NULL,
  127. history_id BIGINT NOT NULL,
  128. KEY (lower_target, upper_target, nanotime),
  129. KEY (history_id)
  130. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  131. if err != nil {
  132. return err
  133. }
  134. return nil
  135. }
  136. func (mysql *MySQL) cleanupLoop() {
  137. defer func() {
  138. if r := recover(); r != nil {
  139. mysql.logger.Error("mysql",
  140. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  141. time.Sleep(cleanupPauseTime)
  142. go mysql.cleanupLoop()
  143. }
  144. }()
  145. for {
  146. expireTime := mysql.getExpireTime()
  147. if expireTime != 0 {
  148. for {
  149. startTime := time.Now()
  150. rowsDeleted, err := mysql.doCleanup(expireTime)
  151. elapsed := time.Now().Sub(startTime)
  152. mysql.logError("error during row cleanup", err)
  153. // keep going as long as we're accomplishing significant work
  154. // (don't busy-wait on small numbers of rows expiring):
  155. if rowsDeleted < (cleanupRowLimit / 10) {
  156. break
  157. }
  158. // crude backpressure mechanism: if the database is slow,
  159. // give it time to process other queries
  160. time.Sleep(elapsed)
  161. }
  162. }
  163. time.Sleep(cleanupPauseTime)
  164. }
  165. }
  166. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  167. ids, maxNanotime, err := mysql.selectCleanupIDs(age)
  168. if len(ids) == 0 {
  169. mysql.logger.Debug("mysql", "found no rows to clean up")
  170. return
  171. }
  172. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  173. // can't use ? binding for a variable number of arguments, build the IN clause manually
  174. var inBuf bytes.Buffer
  175. inBuf.WriteByte('(')
  176. for i, id := range ids {
  177. if i != 0 {
  178. inBuf.WriteRune(',')
  179. }
  180. fmt.Fprintf(&inBuf, "%d", id)
  181. }
  182. inBuf.WriteRune(')')
  183. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
  184. if err != nil {
  185. return
  186. }
  187. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
  188. if err != nil {
  189. return
  190. }
  191. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
  192. if err != nil {
  193. return
  194. }
  195. count = len(ids)
  196. return
  197. }
  198. func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  199. rows, err := mysql.db.Query(`
  200. SELECT history.id, sequence.nanotime
  201. FROM history
  202. LEFT JOIN sequence ON history.id = sequence.history_id
  203. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  204. if err != nil {
  205. return
  206. }
  207. defer rows.Close()
  208. // a history ID may have 0-2 rows in sequence: 1 for a channel entry,
  209. // 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
  210. // and delete anything that doesn't have a sequence entry:
  211. idset := make(map[uint64]struct{}, cleanupRowLimit)
  212. threshold := time.Now().Add(-age).UnixNano()
  213. for rows.Next() {
  214. var id uint64
  215. var nanotime sql.NullInt64
  216. err = rows.Scan(&id, &nanotime)
  217. if err != nil {
  218. return
  219. }
  220. if !nanotime.Valid || nanotime.Int64 < threshold {
  221. idset[id] = struct{}{}
  222. if nanotime.Valid && nanotime.Int64 > maxNanotime {
  223. maxNanotime = nanotime.Int64
  224. }
  225. }
  226. }
  227. ids = make([]uint64, len(idset))
  228. i := 0
  229. for id := range idset {
  230. ids[i] = id
  231. i++
  232. }
  233. return
  234. }
  235. func (mysql *MySQL) prepareStatements() (err error) {
  236. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  237. (data, msgid) VALUES (?, ?);`)
  238. if err != nil {
  239. return
  240. }
  241. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  242. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  243. if err != nil {
  244. return
  245. }
  246. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  247. (lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  248. if err != nil {
  249. return
  250. }
  251. return
  252. }
  253. func (mysql *MySQL) getTimeout() time.Duration {
  254. return time.Duration(atomic.LoadInt64(&mysql.timeout))
  255. }
  256. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  257. if err != nil {
  258. mysql.logger.Error("mysql", context, err.Error())
  259. return true
  260. }
  261. return false
  262. }
  263. func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
  264. if mysql.db == nil {
  265. return
  266. }
  267. if target == "" {
  268. return utils.ErrInvalidParams
  269. }
  270. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  271. defer cancel()
  272. id, err := mysql.insertBase(ctx, item)
  273. if err != nil {
  274. return
  275. }
  276. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
  277. return
  278. }
  279. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
  280. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
  281. mysql.logError("could not insert sequence entry", err)
  282. return
  283. }
  284. func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
  285. lower, higher := stringMinMax(sender, recipient)
  286. _, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
  287. mysql.logError("could not insert conversations entry", err)
  288. return
  289. }
  290. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  291. value, err := marshalItem(&item)
  292. if mysql.logError("could not marshal item", err) {
  293. return
  294. }
  295. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  296. if mysql.logError("could not decode msgid", err) {
  297. return
  298. }
  299. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  300. if mysql.logError("could not insert item", err) {
  301. return
  302. }
  303. id, err = result.LastInsertId()
  304. if mysql.logError("could not insert item", err) {
  305. return
  306. }
  307. return
  308. }
  // stringMinMax returns its two arguments in ascending order, compared with
  // the built-in string `<` operator. Equal strings come back unchanged.
  func stringMinMax(first, second string) (min, max string) {
  	min, max = second, first
  	if first < second {
  		min, max = first, second
  	}
  	return min, max
  }
  316. func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) {
  317. if mysql.db == nil {
  318. return
  319. }
  320. if !(senderPersistent || recipientPersistent) {
  321. return
  322. }
  323. if sender == "" || recipient == "" {
  324. return utils.ErrInvalidParams
  325. }
  326. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  327. defer cancel()
  328. id, err := mysql.insertBase(ctx, item)
  329. if err != nil {
  330. return
  331. }
  332. if senderPersistent {
  333. mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
  334. if err != nil {
  335. return
  336. }
  337. }
  338. if recipientPersistent && sender != recipient {
  339. err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
  340. if err != nil {
  341. return
  342. }
  343. }
  344. err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
  345. return
  346. }
  347. func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
  348. // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
  349. // sequence.nanotime > (
  350. // SELECT sequence.nanotime FROM sequence, history
  351. // WHERE sequence.history_id = history.id AND history.msgid = ?
  352. // LIMIT 1)
  353. // however, this doesn't handle the BETWEEN case with one or two msgids, where we
  354. // don't initially know whether the interval is going forwards or backwards. to simplify
  355. // the logic, resolve msgids to timestamps "manually" in all cases, using a separate query.
  356. decoded, err := decodeMsgid(msgid)
  357. if err != nil {
  358. return
  359. }
  360. row := mysql.db.QueryRowContext(ctx, `
  361. SELECT sequence.nanotime FROM sequence
  362. INNER JOIN history ON history.id = sequence.history_id
  363. WHERE history.msgid = ? LIMIT 1;`, decoded)
  364. var nanotime int64
  365. err = row.Scan(&nanotime)
  366. if mysql.logError("could not resolve msgid to time", err) {
  367. return
  368. }
  369. result = time.Unix(0, nanotime).UTC()
  370. return
  371. }
  372. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  373. rows, err := mysql.db.QueryContext(ctx, query, args...)
  374. if mysql.logError("could not select history items", err) {
  375. return
  376. }
  377. defer rows.Close()
  378. for rows.Next() {
  379. var blob []byte
  380. var item history.Item
  381. err = rows.Scan(&blob)
  382. if mysql.logError("could not scan history item", err) {
  383. return
  384. }
  385. err = unmarshalItem(blob, &item)
  386. if mysql.logError("could not unmarshal history item", err) {
  387. return
  388. }
  389. results = append(results, item)
  390. }
  391. return
  392. }
  393. func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  394. useSequence := true
  395. var lowerTarget, upperTarget string
  396. if sender != "" {
  397. lowerTarget, upperTarget = stringMinMax(sender, recipient)
  398. useSequence = false
  399. }
  400. table := "sequence"
  401. if !useSequence {
  402. table = "conversations"
  403. }
  404. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  405. direction := "ASC"
  406. if !ascending {
  407. direction = "DESC"
  408. }
  409. var queryBuf bytes.Buffer
  410. args := make([]interface{}, 0, 6)
  411. fmt.Fprintf(&queryBuf,
  412. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  413. if useSequence {
  414. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  415. args = append(args, recipient)
  416. } else {
  417. fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?")
  418. args = append(args, lowerTarget)
  419. args = append(args, upperTarget)
  420. }
  421. if !after.IsZero() {
  422. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  423. args = append(args, after.UnixNano())
  424. }
  425. if !before.IsZero() {
  426. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  427. args = append(args, before.UnixNano())
  428. }
  429. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  430. args = append(args, limit)
  431. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  432. if err == nil && !ascending {
  433. history.Reverse(results)
  434. }
  435. return
  436. }
  437. func (mysql *MySQL) Close() {
  438. // closing the database will close our prepared statements as well
  439. if mysql.db != nil {
  440. mysql.db.Close()
  441. }
  442. mysql.db = nil
  443. }
  444. // implements history.Sequence, emulating a single history buffer (for a channel,
  445. // a single user's DMs, or a DM conversation)
  446. type mySQLHistorySequence struct {
  447. mysql *MySQL
  448. sender string
  449. recipient string
  450. cutoff time.Time
  451. }
  452. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
  453. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  454. defer cancel()
  455. startTime := start.Time
  456. if start.Msgid != "" {
  457. startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
  458. if err != nil {
  459. return nil, false, err
  460. }
  461. }
  462. endTime := end.Time
  463. if end.Msgid != "" {
  464. endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
  465. if err != nil {
  466. return nil, false, err
  467. }
  468. }
  469. results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
  470. return results, (err == nil), err
  471. }
  472. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  473. return history.GenericAround(s, start, limit)
  474. }
  475. func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence {
  476. return &mySQLHistorySequence{
  477. sender: sender,
  478. recipient: recipient,
  479. mysql: mysql,
  480. cutoff: cutoff,
  481. }
  482. }