You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

history.go 13KB


  1. package mysql
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "fmt"
  6. "runtime/debug"
  7. "strconv"
  8. "sync"
  9. "time"
  10. _ "github.com/go-sql-driver/mysql"
  11. "github.com/oragono/oragono/irc/history"
  12. "github.com/oragono/oragono/irc/logger"
  13. "github.com/oragono/oragono/irc/utils"
  14. )
  // Schema bookkeeping for the history database.
  const (
  	// latestDbSchema is the schema version written by, and required by, this code.
  	latestDbSchema = "1"
  	// keySchemaVersion is the metadata key under which the schema version is stored.
  	keySchemaVersion = "db.version"
  )

  // Settings for the background cleanup goroutine.
  const (
  	// cleanupRowLimit is the LIMIT applied to the query that picks cleanup candidates.
  	cleanupRowLimit = 50
  	// cleanupPauseTime is how long the cleanup loop sleeps between rounds
  	// (and before restarting after a panic).
  	cleanupPauseTime = 10 * time.Minute
  )
  22. type MySQL struct {
  23. db *sql.DB
  24. logger *logger.Manager
  25. insertHistory *sql.Stmt
  26. insertSequence *sql.Stmt
  27. insertConversation *sql.Stmt
  28. stateMutex sync.Mutex
  29. expireTime time.Duration
  30. }
  31. func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
  32. mysql.logger = logger
  33. mysql.expireTime = expireTime
  34. }
  35. func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
  36. mysql.stateMutex.Lock()
  37. mysql.expireTime = expireTime
  38. mysql.stateMutex.Unlock()
  39. }
  40. func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
  41. mysql.stateMutex.Lock()
  42. expireTime = mysql.expireTime
  43. mysql.stateMutex.Unlock()
  44. return
  45. }
  46. func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
  47. // TODO: timeouts!
  48. var address string
  49. if port != 0 {
  50. address = fmt.Sprintf("tcp(%s:%d)", host, port)
  51. }
  52. mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
  53. if err != nil {
  54. return err
  55. }
  56. err = mysql.fixSchemas()
  57. if err != nil {
  58. return err
  59. }
  60. err = mysql.prepareStatements()
  61. if err != nil {
  62. return err
  63. }
  64. go mysql.cleanupLoop()
  65. return nil
  66. }
  67. func (mysql *MySQL) fixSchemas() (err error) {
  68. _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
  69. key_name VARCHAR(32) primary key,
  70. value VARCHAR(32) NOT NULL
  71. ) CHARSET=ascii COLLATE=ascii_bin;`)
  72. if err != nil {
  73. return err
  74. }
  75. var schema string
  76. err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
  77. if err == sql.ErrNoRows {
  78. err = mysql.createTables()
  79. if err != nil {
  80. return
  81. }
  82. _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
  83. if err != nil {
  84. return
  85. }
  86. } else if err == nil && schema != latestDbSchema {
  87. // TODO figure out what to do about schema changes
  88. return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
  89. } else {
  90. return err
  91. }
  92. return nil
  93. }
  94. func (mysql *MySQL) createTables() (err error) {
  95. _, err = mysql.db.Exec(`CREATE TABLE history (
  96. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  97. data BLOB NOT NULL,
  98. msgid BINARY(16) NOT NULL,
  99. KEY (msgid(4))
  100. ) CHARSET=ascii COLLATE=ascii_bin;`)
  101. if err != nil {
  102. return err
  103. }
  104. _, err = mysql.db.Exec(`CREATE TABLE sequence (
  105. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  106. target VARBINARY(64) NOT NULL,
  107. nanotime BIGINT UNSIGNED NOT NULL,
  108. history_id BIGINT NOT NULL,
  109. KEY (target, nanotime),
  110. KEY (history_id)
  111. ) CHARSET=ascii COLLATE=ascii_bin;`)
  112. if err != nil {
  113. return err
  114. }
  115. _, err = mysql.db.Exec(`CREATE TABLE conversations (
  116. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
  117. lower_target VARBINARY(64) NOT NULL,
  118. upper_target VARBINARY(64) NOT NULL,
  119. nanotime BIGINT UNSIGNED NOT NULL,
  120. history_id BIGINT NOT NULL,
  121. KEY (lower_target, upper_target, nanotime),
  122. KEY (history_id)
  123. ) CHARSET=ascii COLLATE=ascii_bin;`)
  124. if err != nil {
  125. return err
  126. }
  127. return nil
  128. }
  129. func (mysql *MySQL) cleanupLoop() {
  130. defer func() {
  131. if r := recover(); r != nil {
  132. mysql.logger.Error("mysql",
  133. fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
  134. time.Sleep(cleanupPauseTime)
  135. go mysql.cleanupLoop()
  136. }
  137. }()
  138. for {
  139. expireTime := mysql.getExpireTime()
  140. if expireTime != 0 {
  141. for {
  142. startTime := time.Now()
  143. rowsDeleted, err := mysql.doCleanup(expireTime)
  144. elapsed := time.Now().Sub(startTime)
  145. mysql.logError("error during row cleanup", err)
  146. // keep going as long as we're accomplishing significant work
  147. // (don't busy-wait on small numbers of rows expiring):
  148. if rowsDeleted < (cleanupRowLimit / 10) {
  149. break
  150. }
  151. // crude backpressure mechanism: if the database is slow,
  152. // give it time to process other queries
  153. time.Sleep(elapsed)
  154. }
  155. }
  156. time.Sleep(cleanupPauseTime)
  157. }
  158. }
  159. func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
  160. ids, maxNanotime, err := mysql.selectCleanupIDs(age)
  161. if len(ids) == 0 {
  162. mysql.logger.Debug("mysql", "found no rows to clean up")
  163. return
  164. }
  165. mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
  166. // can't use ? binding for a variable number of arguments, build the IN clause manually
  167. var inBuf bytes.Buffer
  168. inBuf.WriteByte('(')
  169. for i, id := range ids {
  170. if i != 0 {
  171. inBuf.WriteRune(',')
  172. }
  173. inBuf.WriteString(strconv.FormatInt(int64(id), 10))
  174. }
  175. inBuf.WriteRune(')')
  176. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
  177. if err != nil {
  178. return
  179. }
  180. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
  181. if err != nil {
  182. return
  183. }
  184. _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
  185. if err != nil {
  186. return
  187. }
  188. count = len(ids)
  189. return
  190. }
  191. func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
  192. rows, err := mysql.db.Query(`
  193. SELECT history.id, sequence.nanotime
  194. FROM history
  195. LEFT JOIN sequence ON history.id = sequence.history_id
  196. ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
  197. if err != nil {
  198. return
  199. }
  200. defer rows.Close()
  201. // a history ID may have 0-2 rows in sequence: 1 for a channel entry,
  202. // 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
  203. // and delete anything that doesn't have a sequence entry:
  204. idset := make(map[uint64]struct{}, cleanupRowLimit)
  205. threshold := time.Now().Add(-age).UnixNano()
  206. for rows.Next() {
  207. var id uint64
  208. var nanotime sql.NullInt64
  209. err = rows.Scan(&id, &nanotime)
  210. if err != nil {
  211. return
  212. }
  213. if !nanotime.Valid || nanotime.Int64 < threshold {
  214. idset[id] = struct{}{}
  215. if nanotime.Valid && nanotime.Int64 > maxNanotime {
  216. maxNanotime = nanotime.Int64
  217. }
  218. }
  219. }
  220. ids = make([]uint64, len(idset))
  221. i := 0
  222. for id := range idset {
  223. ids[i] = id
  224. i++
  225. }
  226. return
  227. }
  228. func (mysql *MySQL) prepareStatements() (err error) {
  229. mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
  230. (data, msgid) VALUES (?, ?);`)
  231. if err != nil {
  232. return
  233. }
  234. mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
  235. (target, nanotime, history_id) VALUES (?, ?, ?);`)
  236. if err != nil {
  237. return
  238. }
  239. mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
  240. (lower_target, upper_target, nanotime, history_id) VALUES (?, ?, ?, ?);`)
  241. if err != nil {
  242. return
  243. }
  244. return
  245. }
  246. func (mysql *MySQL) logError(context string, err error) (quit bool) {
  247. if err != nil {
  248. mysql.logger.Error("mysql", context, err.Error())
  249. return true
  250. }
  251. return false
  252. }
  253. func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
  254. if mysql.db == nil {
  255. return
  256. }
  257. if target == "" {
  258. return utils.ErrInvalidParams
  259. }
  260. id, err := mysql.insertBase(item)
  261. if err != nil {
  262. return
  263. }
  264. err = mysql.insertSequenceEntry(target, item.Message.Time, id)
  265. return
  266. }
  267. func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
  268. _, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
  269. mysql.logError("could not insert sequence entry", err)
  270. return
  271. }
  272. func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
  273. lower, higher := stringMinMax(sender, recipient)
  274. _, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
  275. mysql.logError("could not insert conversations entry", err)
  276. return
  277. }
  278. func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
  279. value, err := marshalItem(&item)
  280. if mysql.logError("could not marshal item", err) {
  281. return
  282. }
  283. msgidBytes, err := decodeMsgid(item.Message.Msgid)
  284. if mysql.logError("could not decode msgid", err) {
  285. return
  286. }
  287. result, err := mysql.insertHistory.Exec(value, msgidBytes)
  288. if mysql.logError("could not insert item", err) {
  289. return
  290. }
  291. id, err = result.LastInsertId()
  292. if mysql.logError("could not insert item", err) {
  293. return
  294. }
  295. return
  296. }
  // stringMinMax returns its two arguments in ascending order, compared with
  // Go's string `<` operator.
  func stringMinMax(first, second string) (min, max string) {
  	min, max = second, first
  	if first < second {
  		min, max = first, second
  	}
  	return min, max
  }
  304. func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent, recipientPersistent bool, item history.Item) (err error) {
  305. if mysql.db == nil {
  306. return
  307. }
  308. if !(senderPersistent || recipientPersistent) {
  309. return
  310. }
  311. if sender == "" || recipient == "" {
  312. return utils.ErrInvalidParams
  313. }
  314. id, err := mysql.insertBase(item)
  315. if err != nil {
  316. return
  317. }
  318. if senderPersistent {
  319. mysql.insertSequenceEntry(sender, item.Message.Time, id)
  320. if err != nil {
  321. return
  322. }
  323. }
  324. if recipientPersistent && sender != recipient {
  325. err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
  326. if err != nil {
  327. return
  328. }
  329. }
  330. err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
  331. return
  332. }
  333. func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
  334. // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
  335. // sequence.nanotime > (
  336. // SELECT sequence.nanotime FROM sequence, history
  337. // WHERE sequence.history_id = history.id AND history.msgid = ?
  338. // LIMIT 1)
  339. // however, this doesn't handle the BETWEEN case with one or two msgids, where we
  340. // don't initially know whether the interval is going forwards or backwards. to simplify
  341. // the logic, resolve msgids to timestamps "manually" in all cases, using a separate query.
  342. decoded, err := decodeMsgid(msgid)
  343. if err != nil {
  344. return
  345. }
  346. row := mysql.db.QueryRow(`
  347. SELECT sequence.nanotime FROM sequence
  348. INNER JOIN history ON history.id = sequence.history_id
  349. WHERE history.msgid = ? LIMIT 1;`, decoded)
  350. var nanotime int64
  351. err = row.Scan(&nanotime)
  352. if mysql.logError("could not resolve msgid to time", err) {
  353. return
  354. }
  355. result = time.Unix(0, nanotime)
  356. return
  357. }
  358. func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
  359. rows, err := mysql.db.Query(query, args...)
  360. if mysql.logError("could not select history items", err) {
  361. return
  362. }
  363. defer rows.Close()
  364. for rows.Next() {
  365. var blob []byte
  366. var item history.Item
  367. err = rows.Scan(&blob)
  368. if mysql.logError("could not scan history item", err) {
  369. return
  370. }
  371. err = unmarshalItem(blob, &item)
  372. if mysql.logError("could not unmarshal history item", err) {
  373. return
  374. }
  375. results = append(results, item)
  376. }
  377. return
  378. }
  379. func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
  380. useSequence := true
  381. var lowerTarget, upperTarget string
  382. if sender != "" {
  383. lowerTarget, upperTarget = stringMinMax(sender, recipient)
  384. useSequence = false
  385. }
  386. table := "sequence"
  387. if !useSequence {
  388. table = "conversations"
  389. }
  390. after, before, ascending := history.MinMaxAsc(after, before, cutoff)
  391. direction := "ASC"
  392. if !ascending {
  393. direction = "DESC"
  394. }
  395. var queryBuf bytes.Buffer
  396. args := make([]interface{}, 0, 6)
  397. fmt.Fprintf(&queryBuf,
  398. "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
  399. if useSequence {
  400. fmt.Fprintf(&queryBuf, " sequence.target = ?")
  401. args = append(args, recipient)
  402. } else {
  403. fmt.Fprintf(&queryBuf, " conversations.lower_target = ? AND conversations.upper_target = ?")
  404. args = append(args, lowerTarget)
  405. args = append(args, upperTarget)
  406. }
  407. if !after.IsZero() {
  408. fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
  409. args = append(args, after.UnixNano())
  410. }
  411. if !before.IsZero() {
  412. fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
  413. args = append(args, before.UnixNano())
  414. }
  415. fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
  416. args = append(args, limit)
  417. results, err = mysql.selectItems(queryBuf.String(), args...)
  418. if err == nil && !ascending {
  419. history.Reverse(results)
  420. }
  421. return
  422. }
  423. func (mysql *MySQL) Close() {
  424. // closing the database will close our prepared statements as well
  425. if mysql.db != nil {
  426. mysql.db.Close()
  427. }
  428. mysql.db = nil
  429. }
  430. // implements history.Sequence, emulating a single history buffer (for a channel,
  431. // a single user's DMs, or a DM conversation)
  432. type mySQLHistorySequence struct {
  433. mysql *MySQL
  434. sender string
  435. recipient string
  436. cutoff time.Time
  437. }
  438. func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
  439. startTime := start.Time
  440. if start.Msgid != "" {
  441. startTime, err = s.mysql.msgidToTime(start.Msgid)
  442. if err != nil {
  443. return nil, false, err
  444. }
  445. }
  446. endTime := end.Time
  447. if end.Msgid != "" {
  448. endTime, err = s.mysql.msgidToTime(end.Msgid)
  449. if err != nil {
  450. return nil, false, err
  451. }
  452. }
  453. results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
  454. return results, (err == nil), err
  455. }
  456. func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
  457. return history.GenericAround(s, start, limit)
  458. }
  459. func (mysql *MySQL) MakeSequence(sender, recipient string, cutoff time.Time) history.Sequence {
  460. return &mySQLHistorySequence{
  461. sender: sender,
  462. recipient: recipient,
  463. mysql: mysql,
  464. cutoff: cutoff,
  465. }
  466. }