123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550 |
- // Copyright (c) 2020 Shivaram Lingamneni
- // released under the MIT license
-
- package mysql
-
- import (
- "bytes"
- "context"
- "database/sql"
- "fmt"
- "runtime/debug"
- "sync"
- "sync/atomic"
- "time"
-
- _ "github.com/go-sql-driver/mysql"
- "github.com/oragono/oragono/irc/history"
- "github.com/oragono/oragono/irc/logger"
- "github.com/oragono/oragono/irc/utils"
- )
-
// Exported limits.
const (
	// MaxTargetLength is the maximum length, in bytes, of any message target
	// (nickname or channel name) in its canonicalized (i.e., casefolded) state.
	// It sizes the VARBINARY target/correspondent columns created by createTables.
	MaxTargetLength = 64
)

// Internal schema and cleanup tuning.
const (
	// latestDbSchema is the schema version this code creates and requires.
	latestDbSchema = "2"
	// keySchemaVersion is the metadata key under which the schema version is stored.
	keySchemaVersion = "db.version"
	// cleanupRowLimit is the LIMIT applied to each cleanup selection query.
	cleanupRowLimit = 50
	// cleanupPauseTime is how long the cleanup goroutine sleeps between rounds
	// (and before restarting itself after a panic).
	cleanupPauseTime = 10 * time.Minute
)
-
- type MySQL struct {
- timeout int64
- db *sql.DB
- logger *logger.Manager
-
- insertHistory *sql.Stmt
- insertSequence *sql.Stmt
- insertConversation *sql.Stmt
-
- stateMutex sync.Mutex
- config Config
- }
-
- func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
- mysql.logger = logger
- mysql.SetConfig(config)
- }
-
- func (mysql *MySQL) SetConfig(config Config) {
- atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
- mysql.stateMutex.Lock()
- mysql.config = config
- mysql.stateMutex.Unlock()
- }
-
- func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
- mysql.stateMutex.Lock()
- expireTime = mysql.config.ExpireTime
- mysql.stateMutex.Unlock()
- return
- }
-
- func (m *MySQL) Open() (err error) {
- var address string
- if m.config.Port != 0 {
- address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
- }
-
- m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
- if err != nil {
- return err
- }
-
- err = m.fixSchemas()
- if err != nil {
- return err
- }
-
- err = m.prepareStatements()
- if err != nil {
- return err
- }
-
- go m.cleanupLoop()
-
- return nil
- }
-
- func (mysql *MySQL) fixSchemas() (err error) {
- _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
- key_name VARCHAR(32) primary key,
- value VARCHAR(32) NOT NULL
- ) CHARSET=ascii COLLATE=ascii_bin;`)
- if err != nil {
- return err
- }
-
- var schema string
- err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
- if err == sql.ErrNoRows {
- err = mysql.createTables()
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
- if err != nil {
- return
- }
- } else if err == nil && schema != latestDbSchema {
- // TODO figure out what to do about schema changes
- return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
- } else {
- return err
- }
-
- return nil
- }
-
- func (mysql *MySQL) createTables() (err error) {
- _, err = mysql.db.Exec(`CREATE TABLE history (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- data BLOB NOT NULL,
- msgid BINARY(16) NOT NULL,
- KEY (msgid(4))
- ) CHARSET=ascii COLLATE=ascii_bin;`)
- if err != nil {
- return err
- }
-
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- history_id BIGINT NOT NULL,
- KEY (target, nanotime),
- KEY (history_id)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
-
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- correspondent VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- history_id BIGINT NOT NULL,
- KEY (target, correspondent, nanotime),
- KEY (history_id)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
-
- return nil
- }
-
- func (mysql *MySQL) cleanupLoop() {
- defer func() {
- if r := recover(); r != nil {
- mysql.logger.Error("mysql",
- fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
- time.Sleep(cleanupPauseTime)
- go mysql.cleanupLoop()
- }
- }()
-
- for {
- expireTime := mysql.getExpireTime()
- if expireTime != 0 {
- for {
- startTime := time.Now()
- rowsDeleted, err := mysql.doCleanup(expireTime)
- elapsed := time.Now().Sub(startTime)
- mysql.logError("error during row cleanup", err)
- // keep going as long as we're accomplishing significant work
- // (don't busy-wait on small numbers of rows expiring):
- if rowsDeleted < (cleanupRowLimit / 10) {
- break
- }
- // crude backpressure mechanism: if the database is slow,
- // give it time to process other queries
- time.Sleep(elapsed)
- }
- }
- time.Sleep(cleanupPauseTime)
- }
- }
-
- func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
- ids, maxNanotime, err := mysql.selectCleanupIDs(age)
- if len(ids) == 0 {
- mysql.logger.Debug("mysql", "found no rows to clean up")
- return
- }
-
- mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
-
- // can't use ? binding for a variable number of arguments, build the IN clause manually
- var inBuf bytes.Buffer
- inBuf.WriteByte('(')
- for i, id := range ids {
- if i != 0 {
- inBuf.WriteRune(',')
- }
- fmt.Fprintf(&inBuf, "%d", id)
- }
- inBuf.WriteRune(')')
-
- _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
- if err != nil {
- return
- }
-
- count = len(ids)
- return
- }
-
- func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
- rows, err := mysql.db.Query(`
- SELECT history.id, sequence.nanotime
- FROM history
- LEFT JOIN sequence ON history.id = sequence.history_id
- ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
- if err != nil {
- return
- }
- defer rows.Close()
-
- // a history ID may have 0-2 rows in sequence: 1 for a channel entry,
- // 2 for a DM, 0 if the data is inconsistent. therefore, deduplicate
- // and delete anything that doesn't have a sequence entry:
- idset := make(map[uint64]struct{}, cleanupRowLimit)
- threshold := time.Now().Add(-age).UnixNano()
- for rows.Next() {
- var id uint64
- var nanotime sql.NullInt64
- err = rows.Scan(&id, &nanotime)
- if err != nil {
- return
- }
- if !nanotime.Valid || nanotime.Int64 < threshold {
- idset[id] = struct{}{}
- if nanotime.Valid && nanotime.Int64 > maxNanotime {
- maxNanotime = nanotime.Int64
- }
- }
- }
- ids = make([]uint64, len(idset))
- i := 0
- for id := range idset {
- ids[i] = id
- i++
- }
- return
- }
-
- func (mysql *MySQL) prepareStatements() (err error) {
- mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
- (data, msgid) VALUES (?, ?);`)
- if err != nil {
- return
- }
- mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
- (target, nanotime, history_id) VALUES (?, ?, ?);`)
- if err != nil {
- return
- }
- mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
- (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
- if err != nil {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) getTimeout() time.Duration {
- return time.Duration(atomic.LoadInt64(&mysql.timeout))
- }
-
- func (mysql *MySQL) logError(context string, err error) (quit bool) {
- if err != nil {
- mysql.logger.Error("mysql", context, err.Error())
- return true
- }
- return false
- }
-
- func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
- if mysql.db == nil {
- return
- }
-
- if target == "" {
- return utils.ErrInvalidParams
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- id, err := mysql.insertBase(ctx, item)
- if err != nil {
- return
- }
-
- err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
- return
- }
-
- func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
- _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
- mysql.logError("could not insert sequence entry", err)
- return
- }
-
- func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
- _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
- mysql.logError("could not insert conversations entry", err)
- return
- }
-
- func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
- value, err := marshalItem(&item)
- if mysql.logError("could not marshal item", err) {
- return
- }
-
- msgidBytes, err := decodeMsgid(item.Message.Msgid)
- if mysql.logError("could not decode msgid", err) {
- return
- }
-
- result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
- if mysql.logError("could not insert item", err) {
- return
- }
- id, err = result.LastInsertId()
- if mysql.logError("could not insert item", err) {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
- if mysql.db == nil {
- return
- }
-
- if senderAccount == "" && recipientAccount == "" {
- return
- }
-
- if sender == "" || recipient == "" {
- return utils.ErrInvalidParams
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- id, err := mysql.insertBase(ctx, item)
- if err != nil {
- return
- }
-
- nanotime := item.Message.Time.UnixNano()
-
- if senderAccount != "" {
- err = mysql.insertSequenceEntry(ctx, senderAccount, nanotime, id)
- if err != nil {
- return
- }
- err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
- if err != nil {
- return
- }
- }
-
- if recipientAccount != "" && sender != recipient {
- err = mysql.insertSequenceEntry(ctx, recipientAccount, nanotime, id)
- if err != nil {
- return
- }
- err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
- if err != nil {
- return
- }
- }
-
- return
- }
-
- func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
- // in theory, we could optimize out a roundtrip to the database by using a subquery instead:
- // sequence.nanotime > (
- // SELECT sequence.nanotime FROM sequence, history
- // WHERE sequence.history_id = history.id AND history.msgid = ?
- // LIMIT 1)
- // however, this doesn't handle the BETWEEN case with one or two msgids, where we
- // don't initially know whether the interval is going forwards or backwards. to simplify
- // the logic, resolve msgids to timestamps "manually" in all cases, using a separate query.
- decoded, err := decodeMsgid(msgid)
- if err != nil {
- return
- }
- row := mysql.db.QueryRowContext(ctx, `
- SELECT sequence.nanotime FROM sequence
- INNER JOIN history ON history.id = sequence.history_id
- WHERE history.msgid = ? LIMIT 1;`, decoded)
- var nanotime int64
- err = row.Scan(&nanotime)
- if mysql.logError("could not resolve msgid to time", err) {
- return
- }
- result = time.Unix(0, nanotime).UTC()
- return
- }
-
- func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
- rows, err := mysql.db.QueryContext(ctx, query, args...)
- if mysql.logError("could not select history items", err) {
- return
- }
-
- defer rows.Close()
-
- for rows.Next() {
- var blob []byte
- var item history.Item
- err = rows.Scan(&blob)
- if mysql.logError("could not scan history item", err) {
- return
- }
- err = unmarshalItem(blob, &item)
- if mysql.logError("could not unmarshal history item", err) {
- return
- }
- results = append(results, item)
- }
- return
- }
-
- func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
- useSequence := correspondent == ""
- table := "sequence"
- if !useSequence {
- table = "conversations"
- }
-
- after, before, ascending := history.MinMaxAsc(after, before, cutoff)
- direction := "ASC"
- if !ascending {
- direction = "DESC"
- }
-
- var queryBuf bytes.Buffer
-
- args := make([]interface{}, 0, 6)
- fmt.Fprintf(&queryBuf,
- "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
- if useSequence {
- fmt.Fprintf(&queryBuf, " sequence.target = ?")
- args = append(args, target)
- } else {
- fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
- args = append(args, target)
- args = append(args, correspondent)
- }
- if !after.IsZero() {
- fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
- args = append(args, after.UnixNano())
- }
- if !before.IsZero() {
- fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
- args = append(args, before.UnixNano())
- }
- fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
- args = append(args, limit)
-
- results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
- if err == nil && !ascending {
- history.Reverse(results)
- }
- return
- }
-
- func (mysql *MySQL) Close() {
- // closing the database will close our prepared statements as well
- if mysql.db != nil {
- mysql.db.Close()
- }
- mysql.db = nil
- }
-
- // implements history.Sequence, emulating a single history buffer (for a channel,
- // a single user's DMs, or a DM conversation)
- type mySQLHistorySequence struct {
- mysql *MySQL
- target string
- correspondent string
- cutoff time.Time
- }
-
- func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
- defer cancel()
-
- startTime := start.Time
- if start.Msgid != "" {
- startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
- if err != nil {
- return nil, false, err
- }
- }
- endTime := end.Time
- if end.Msgid != "" {
- endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
- if err != nil {
- return nil, false, err
- }
- }
-
- results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
- return results, (err == nil), err
- }
-
- func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
- return history.GenericAround(s, start, limit)
- }
-
- func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
- return &mySQLHistorySequence{
- target: target,
- correspondent: correspondent,
- mysql: mysql,
- cutoff: cutoff,
- }
- }
|