您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

history.go 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. // Copyright (c) 2020 Shivaram Lingamneni
  2. // released under the MIT license
  3. package mysql
  4. import (
  5. "bytes"
  6. "context"
  7. "database/sql"
  8. "fmt"
  9. "runtime/debug"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. _ "github.com/go-sql-driver/mysql"
  14. "github.com/oragono/oragono/irc/history"
  15. "github.com/oragono/oragono/irc/logger"
  16. "github.com/oragono/oragono/irc/utils"
  17. )
const (
	// maximum length in bytes of any message target (nickname or channel name) in its
	// canonicalized (i.e., casefolded) state; used as the VARBINARY width of the
	// target and correspondent columns:
	MaxTargetLength = 64

	// latest schema of the db; an existing database recording any other
	// version is rejected by fixSchemas
	latestDbSchema = "2"

	// key in the metadata table under which the schema version is stored
	keySchemaVersion = "db.version"

	// LIMIT applied to the row-selection query of each cleanup pass; the
	// cleanup loop keeps iterating while a pass deletes at least a tenth of this
	cleanupRowLimit = 50

	// how long the cleanup loop sleeps between rounds of cleanup, and before
	// restarting itself after a panic
	cleanupPauseTime = 10 * time.Minute
)
// MySQL stores and retrieves history items in a MySQL database through
// database/sql.
type MySQL struct {
	// query timeout as an int64 count of nanoseconds (a time.Duration);
	// written with atomic.StoreInt64 in SetConfig, read in getTimeout
	timeout int64

	// database handle; assigned in Open and reset to nil by Close
	db     *sql.DB
	logger *logger.Manager

	// prepared INSERT statements, created by prepareStatements
	insertHistory      *sql.Stmt
	insertSequence     *sql.Stmt
	insertConversation *sql.Stmt

	// stateMutex is held by SetConfig and getExpireTime while accessing config
	stateMutex sync.Mutex
	config     Config
}
  38. func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
  39. mysql.logger = logger
  40. mysql.SetConfig(config)
  41. }
  42. func (mysql *MySQL) SetConfig(config Config) {
  43. atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
  44. mysql.stateMutex.Lock()
  45. mysql.config = config
  46. mysql.stateMutex.Unlock()
  47. }
  48. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  49. mysql.stateMutex.Lock()
  50. expireTime = mysql.config.ExpireTime
  51. mysql.stateMutex.Unlock()
  52. return
  53. }
  54. func (m *MySQL) Open() (err error) {
  55. var address string
  56. if m.config.Port != 0 {
  57. address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
  58. }
  59. m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
  60. if err != nil {
  61. return err
  62. }
  63. err = m.fixSchemas()
  64. if err != nil {
  65. return err
  66. }
  67. err = m.prepareStatements()
  68. if err != nil {
  69. return err
  70. }
  71. go m.cleanupLoop()
  72. return nil
  73. }
  74. func (mysql *MySQL) fixSchemas() (err error) {
  75. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  76. key_name VARCHAR(32) primary key,
  77. value VARCHAR(32) NOT NULL
  78. ) CHARSET=ascii COLLATE=ascii_bin;`)
  79. if err != nil {
  80. return err
  81. }
  82. var schema string
  83. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  84. if err == sql.ErrNoRows {
  85. err = mysql.createTables()
  86. if err != nil {
  87. return
  88. }
  89. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  90. if err != nil {
  91. return
  92. }
  93. } else if err == nil && schema != latestDbSchema {
  94. // TODO figure out what to do about schema changes
  95. return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
  96. } else {
  97. return err
  98. }
  99. return nil
  100. }
  101. func (mysql *MySQL) createTables() (err error) {
  102. _, err = mysql.db.Exec(`CREATE TABLE history (
  103. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  104. data BLOB NOT NULL,
  105. msgid BINARY(16) NOT NULL,
  106. KEY (msgid(4))
  107. ) CHARSET=ascii COLLATE=ascii_bin;`)
  108. if err != nil {
  109. return err
  110. }
  111. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
  112. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  113. target VARBINARY(%[1]d) NOT NULL,
  114. nanotime BIGINT UNSIGNED NOT NULL,
  115. history_id BIGINT NOT NULL,
  116. KEY (target, nanotime),
  117. KEY (history_id)
  118. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  119. if err != nil {
  120. return err
  121. }
  122. _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
  123. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  124. target VARBINARY(%[1]d) NOT NULL,
  125. correspondent VARBINARY(%[1]d) NOT NULL,
  126. nanotime BIGINT UNSIGNED NOT NULL,
  127. history_id BIGINT NOT NULL,
  128. KEY (target, correspondent, nanotime),
  129. KEY (history_id)
  130. ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
  131. if err != nil {
  132. return err
  133. }
  134. return nil
  135. }
  136. func (mysql *MySQL) cleanupLoop() {
  137. defer func() {
  138. if r := recover(); r != nil {
  139. mysql.logger.Error("mysql",
  140. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  141. time.Sleep(cleanupPauseTime)
  142. go mysql.cleanupLoop()
  143. }
  144. }()
  145. for {
  146. expireTime := mysql.getExpireTime()
  147. if expireTime != 0 {
  148. for {
  149. startTime := time.Now()
  150. rowsDeleted, err := mysql.doCleanup(expireTime)
  151. elapsed := time.Now().Sub(startTime)
  152. mysql.logError("error during row cleanup", err)
  153. // keep going as long as we're accomplishing significant work
  154. // (don't busy-wait on small numbers of rows expiring):
  155. if rowsDeleted < (cleanupRowLimit / 10) {
  156. break
  157. }
  158. // crude backpressure mechanism: if the database is slow,
  159. // give it time to process other queries
  160. time.Sleep(elapsed)
  161. }
  162. }
  163. time.Sleep(cleanupPauseTime)
  164. }
  165. }
  166. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  167. ids, maxNanotime, err := mysql.selectCleanupIDs(age)
  168. if len(ids) == 0 {
  169. mysql.logger.Debug("mysql", "found no rows to clean up")
  170. return
  171. }
  172. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  173. // can't use ? binding for a variable number of arguments, build the IN clause manually
  174. var inBuf bytes.Buffer
  175. inBuf.WriteByte('(')
  176. for i, id := range ids {
  177. if i != 0 {
  178. inBuf.WriteRune(',')
  179. }
  180. fmt.Fprintf(&inBuf, "%d", id)
  181. }
  182. inBuf.WriteRune(')')
  183. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
  184. if err != nil {
  185. return
  186. }
  187. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
  188. if err != nil {
  189. return
  190. }
  191. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
  192. if err != nil {
  193. return
  194. }
  195. count = len(ids)
  196. return
  197. }
  198. func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  199. rows, err := mysql.db.Query(`
  200. SELECT history.id, sequence.nanotime
  201. FROM history
  202. LEFT JOIN sequence ON history.id = sequence.history_id
  203. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  204. if err != nil {
  205. return
  206. }
  207. defer rows.Close()
  208. // a history ID may have 0-2 rows in sequence: 1 for a channel entry,
  209. // 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
  210. // and delete anything that doesn't have a sequence entry:
  211. idset := make(map[uint64]struct{}, cleanupRowLimit)
  212. threshold := time.Now().Add(-age).UnixNano()
  213. for rows.Next() {
  214. var id uint64
  215. var nanotime sql.NullInt64
  216. err = rows.Scan(&id, &nanotime)
  217. if err != nil {
  218. return
  219. }
  220. if !nanotime.Valid || nanotime.Int64 < threshold {
  221. idset[id] = struct{}{}
  222. if nanotime.Valid && nanotime.Int64 > maxNanotime {
  223. maxNanotime = nanotime.Int64
  224. }
  225. }
  226. }
  227. ids = make([]uint64, len(idset))
  228. i := 0
  229. for id := range idset {
  230. ids[i] = id
  231. i++
  232. }
  233. return
  234. }
  235. func (mysql *MySQL) prepareStatements() (err error) {
  236. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  237. (data, msgid) VALUES (?, ?);`)
  238. if err != nil {
  239. return
  240. }
  241. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  242. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  243. if err != nil {
  244. return
  245. }
  246. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  247. (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  248. if err != nil {
  249. return
  250. }
  251. return
  252. }
  253. func (mysql *MySQL) getTimeout() time.Duration {
  254. return time.Duration(atomic.LoadInt64(&mysql.timeout))
  255. }
  256. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  257. if err != nil {
  258. mysql.logger.Error("mysql", context, err.Error())
  259. return true
  260. }
  261. return false
  262. }
  263. func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
  264. if mysql.db == nil {
  265. return
  266. }
  267. if target == "" {
  268. return utils.ErrInvalidParams
  269. }
  270. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  271. defer cancel()
  272. id, err := mysql.insertBase(ctx, item)
  273. if err != nil {
  274. return
  275. }
  276. err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
  277. return
  278. }
  279. func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
  280. _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
  281. mysql.logError("could not insert sequence entry", err)
  282. return
  283. }
  284. func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
  285. _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
  286. mysql.logError("could not insert conversations entry", err)
  287. return
  288. }
  289. func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
  290. value, err := marshalItem(&item)
  291. if mysql.logError("could not marshal item", err) {
  292. return
  293. }
  294. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  295. if mysql.logError("could not decode msgid", err) {
  296. return
  297. }
  298. result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
  299. if mysql.logError("could not insert item", err) {
  300. return
  301. }
  302. id, err = result.LastInsertId()
  303. if mysql.logError("could not insert item", err) {
  304. return
  305. }
  306. return
  307. }
  308. func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
  309. if mysql.db == nil {
  310. return
  311. }
  312. if senderAccount == "" && recipientAccount == "" {
  313. return
  314. }
  315. if sender == "" || recipient == "" {
  316. return utils.ErrInvalidParams
  317. }
  318. ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
  319. defer cancel()
  320. id, err := mysql.insertBase(ctx, item)
  321. if err != nil {
  322. return
  323. }
  324. nanotime := item.Message.Time.UnixNano()
  325. if senderAccount != "" {
  326. err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id)
  327. if err != nil {
  328. return
  329. }
  330. err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
  331. if err != nil {
  332. return
  333. }
  334. }
  335. if recipientAccount != "" && sender != recipient {
  336. err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id)
  337. if err != nil {
  338. return
  339. }
  340. err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
  341. if err != nil {
  342. return
  343. }
  344. }
  345. return
  346. }
  347. func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
  348. // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
  349. // sequence.nanotime > (
  350. // SELECT sequence.nanotime FROM sequence, history
  351. // WHERE sequence.history_id = history.id AND history.msgid = ?
  352. // LIMIT 1)
  353. // however, this doesn't handle the BETWEEN case with one or two msgids, where we
  354. // don't initially know whether the interval is going forwards or backwards. to simplify
  355. // the logic, resolve msgids to timestamps "manually" in all cases, using a separate query.
  356. decoded, err := decodeMsgid(msgid)
  357. if err != nil {
  358. return
  359. }
  360. row := mysql.db.QueryRowContext(ctx, `
  361. SELECT sequence.nanotime FROM sequence
  362. INNER JOIN history ON history.id = sequence.history_id
  363. WHERE history.msgid = ? LIMIT 1;`, decoded)
  364. var nanotime int64
  365. err = row.Scan(&nanotime)
  366. if mysql.logError("could not resolve msgid to time", err) {
  367. return
  368. }
  369. result = time.Unix(0, nanotime).UTC()
  370. return
  371. }
  372. func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
  373. rows, err := mysql.db.QueryContext(ctx, query, args...)
  374. if mysql.logError("could not select history items", err) {
  375. return
  376. }
  377. defer rows.Close()
  378. for rows.Next() {
  379. var blob []byte
  380. var item history.Item
  381. err = rows.Scan(&blob)
  382. if mysql.logError("could not scan history item", err) {
  383. return
  384. }
  385. err = unmarshalItem(blob, &item)
  386. if mysql.logError("could not unmarshal history item", err) {
  387. return
  388. }
  389. results = append(results, item)
  390. }
  391. return
  392. }
  393. func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  394. useSequence := correspondent == ""
  395. table := "sequence"
  396. if !useSequence {
  397. table = "conversations"
  398. }
  399. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  400. direction := "ASC"
  401. if !ascending {
  402. direction = "DESC"
  403. }
  404. var queryBuf bytes.Buffer
  405. args := make([]interface{}, 0, 6)
  406. fmt.Fprintf(&queryBuf,
  407. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  408. if useSequence {
  409. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  410. args = append(args, target)
  411. } else {
  412. fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
  413. args = append(args, target)
  414. args = append(args, correspondent)
  415. }
  416. if !after.IsZero() {
  417. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  418. args = append(args, after.UnixNano())
  419. }
  420. if !before.IsZero() {
  421. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  422. args = append(args, before.UnixNano())
  423. }
  424. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  425. args = append(args, limit)
  426. results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
  427. if err == nil && !ascending {
  428. history.Reverse(results)
  429. }
  430. return
  431. }
  432. func (mysql *MySQL) Close() {
  433. // closing the database will close our prepared statements as well
  434. if mysql.db != nil {
  435. mysql.db.Close()
  436. }
  437. mysql.db = nil
  438. }
// implements history.Sequence, emulating a single history buffer (for a channel,
// a single user's DMs, or a DM conversation)
type mySQLHistorySequence struct {
	// backend the queries are run against
	mysql *MySQL
	// target whose history is queried
	target string
	// when non-empty, restricts queries to the conversation between target
	// and this correspondent
	correspondent string
	// passed to history.MinMaxAsc together with the query bounds
	cutoff time.Time
}
  447. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
  448. ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
  449. defer cancel()
  450. startTime := start.Time
  451. if start.Msgid != "" {
  452. startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
  453. if err != nil {
  454. return nil, false, err
  455. }
  456. }
  457. endTime := end.Time
  458. if end.Msgid != "" {
  459. endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
  460. if err != nil {
  461. return nil, false, err
  462. }
  463. }
  464. results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
  465. return results, (err == nil), err
  466. }
  467. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  468. return history.GenericAround(s, start, limit)
  469. }
  470. func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
  471. return &mySQLHistorySequence{
  472. target: target,
  473. correspondent: correspondent,
  474. mysql: mysql,
  475. cutoff: cutoff,
  476. }
  477. }