12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093 |
- // Copyright (c) 2020 Shivaram Lingamneni
- // released under the MIT license
-
- package mysql
-
- import (
- "context"
- "database/sql"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "runtime/debug"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/ergochat/ergo/irc/history"
- "github.com/ergochat/ergo/irc/logger"
- "github.com/ergochat/ergo/irc/utils"
- _ "github.com/go-sql-driver/mysql"
- )
-
// ErrDisallowed is returned by DeleteMsgid when the caller's account does not own
// the message being deleted.
var ErrDisallowed = errors.New("disallowed")
-
// MaxTargetLength is the maximum length in bytes of any message target (nickname or
// channel name) in its canonicalized (i.e., casefolded) state. It sizes the
// VARBINARY target/correspondent/account columns.
const MaxTargetLength = 64

// Schema bookkeeping stored in the metadata table.
const (
	// major schema version of the db; a mismatch is treated as incompatible
	latestDbSchema   = "2"
	keySchemaVersion = "db.version"

	// the minor version tracks rollback-safe upgrades, i.e., you can downgrade
	// ergo (formerly oragono) and everything will keep working
	latestDbMinorVersion  = "2"
	keySchemaMinorVersion = "db.minorversion"
)

// Tuning for the background cleanup and forget routines.
const (
	// maximum number of rows examined/deleted per batch
	cleanupRowLimit = 50
	// pause between cleanup passes; also used as the per-batch query timeout
	cleanupPauseTime = 10 * time.Minute
)
-
- type e struct{}
-
- type MySQL struct {
- timeout int64
- trackAccountMessages uint32
- db *sql.DB
- logger *logger.Manager
-
- insertHistory *sql.Stmt
- insertSequence *sql.Stmt
- insertConversation *sql.Stmt
- insertCorrespondent *sql.Stmt
- insertAccountMessage *sql.Stmt
-
- stateMutex sync.Mutex
- config Config
-
- wakeForgetter chan e
- }
-
- func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
- mysql.logger = logger
- mysql.wakeForgetter = make(chan e, 1)
- mysql.SetConfig(config)
- }
-
- func (mysql *MySQL) SetConfig(config Config) {
- atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
- var trackAccountMessages uint32
- if config.TrackAccountMessages {
- trackAccountMessages = 1
- }
- atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
- mysql.stateMutex.Lock()
- mysql.config = config
- mysql.stateMutex.Unlock()
- }
-
- func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
- mysql.stateMutex.Lock()
- expireTime = mysql.config.ExpireTime
- mysql.stateMutex.Unlock()
- return
- }
-
- func (m *MySQL) Open() (err error) {
- var address string
- if m.config.SocketPath != "" {
- address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
- } else if m.config.Port != 0 {
- address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
- }
-
- m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
- if err != nil {
- return err
- }
-
- if m.config.MaxConns != 0 {
- m.db.SetMaxOpenConns(m.config.MaxConns)
- m.db.SetMaxIdleConns(m.config.MaxConns)
- }
- if m.config.ConnMaxLifetime != 0 {
- m.db.SetConnMaxLifetime(m.config.ConnMaxLifetime)
- }
-
- err = m.fixSchemas()
- if err != nil {
- return err
- }
-
- err = m.prepareStatements()
- if err != nil {
- return err
- }
-
- go m.cleanupLoop()
- go m.forgetLoop()
-
- return nil
- }
-
- func (mysql *MySQL) fixSchemas() (err error) {
- _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
- key_name VARCHAR(32) primary key,
- value VARCHAR(32) NOT NULL
- ) CHARSET=ascii COLLATE=ascii_bin;`)
- if err != nil {
- return err
- }
-
- var schema string
- err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
- if err == sql.ErrNoRows {
- err = mysql.createTables()
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
- if err != nil {
- return
- }
- return
- } else if err == nil && schema != latestDbSchema {
- // TODO figure out what to do about schema changes
- return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
- } else if err != nil {
- return err
- }
-
- var minorVersion string
- err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
- if err == sql.ErrNoRows {
- // XXX for now, the only minor version upgrade is the account tracking tables
- err = mysql.createComplianceTables()
- if err != nil {
- return
- }
- err = mysql.createCorrespondentsTable()
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
- if err != nil {
- return
- }
- } else if err == nil && minorVersion == "1" {
- // upgrade from 2.1 to 2.2: create the correspondents table
- err = mysql.createCorrespondentsTable()
- if err != nil {
- return
- }
- _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
- if err != nil {
- return
- }
- } else if err == nil && minorVersion != latestDbMinorVersion {
- // TODO: if minorVersion < latestDbMinorVersion, upgrade,
- // if latestDbMinorVersion < minorVersion, ignore because backwards compatible
- }
- return
- }
-
- func (mysql *MySQL) createTables() (err error) {
- _, err = mysql.db.Exec(`CREATE TABLE history (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- data BLOB NOT NULL,
- msgid BINARY(16) NOT NULL,
- KEY (msgid(4))
- ) CHARSET=ascii COLLATE=ascii_bin;`)
- if err != nil {
- return err
- }
-
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE sequence (
- history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- KEY (target, nanotime)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
- /* XXX: this table used to be:
- CREATE TABLE sequence (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- history_id BIGINT NOT NULL,
- KEY (target, nanotime),
- KEY (history_id)
- ) CHARSET=ascii COLLATE=ascii_bin;
- Some users may still be using the old schema.
- */
-
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE conversations (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- correspondent VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- history_id BIGINT NOT NULL,
- KEY (target, correspondent, nanotime),
- KEY (history_id)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
-
- err = mysql.createCorrespondentsTable()
- if err != nil {
- return err
- }
-
- err = mysql.createComplianceTables()
- if err != nil {
- return err
- }
-
- return nil
- }
-
- func (mysql *MySQL) createCorrespondentsTable() (err error) {
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- target VARBINARY(%[1]d) NOT NULL,
- correspondent VARBINARY(%[1]d) NOT NULL,
- nanotime BIGINT UNSIGNED NOT NULL,
- UNIQUE KEY (target, correspondent),
- KEY (target, nanotime),
- KEY (nanotime)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- return
- }
-
- func (mysql *MySQL) createComplianceTables() (err error) {
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
- history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
- account VARBINARY(%[1]d) NOT NULL,
- KEY (account, history_id)
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
-
- _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
- account VARBINARY(%[1]d) NOT NULL
- ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
- if err != nil {
- return err
- }
-
- return nil
- }
-
- func (mysql *MySQL) cleanupLoop() {
- defer func() {
- if r := recover(); r != nil {
- mysql.logger.Error("mysql",
- fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack()))
- time.Sleep(cleanupPauseTime)
- go mysql.cleanupLoop()
- }
- }()
-
- for {
- expireTime := mysql.getExpireTime()
- if expireTime != 0 {
- for {
- startTime := time.Now()
- rowsDeleted, err := mysql.doCleanup(expireTime)
- elapsed := time.Now().Sub(startTime)
- mysql.logError("error during row cleanup", err)
- // keep going as long as we're accomplishing significant work
- // (don't busy-wait on small numbers of rows expiring):
- if rowsDeleted < (cleanupRowLimit / 10) {
- break
- }
- // crude backpressure mechanism: if the database is slow,
- // give it time to process other queries
- time.Sleep(elapsed)
- }
- }
- time.Sleep(cleanupPauseTime)
- }
- }
-
- func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
- defer cancel()
-
- ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
- if len(ids) == 0 {
- mysql.logger.Debug("mysql", "found no rows to clean up")
- return
- }
-
- mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
-
- if maxNanotime != 0 {
- mysql.deleteCorrespondents(ctx, maxNanotime)
- }
-
- return len(ids), mysql.deleteHistoryIDs(ctx, ids)
- }
-
- func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
- // can't use ? binding for a variable number of arguments, build the IN clause manually
- var inBuf strings.Builder
- inBuf.WriteByte('(')
- for i, id := range ids {
- if i != 0 {
- inBuf.WriteRune(',')
- }
- fmt.Fprintf(&inBuf, "%d", id)
- }
- inBuf.WriteRune(')')
- inClause := inBuf.String()
-
- _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
- if err != nil {
- return
- }
- _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
- if err != nil {
- return
- }
- if mysql.isTrackingAccountMessages() {
- _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
- if err != nil {
- return
- }
- }
- _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
- if err != nil {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
- rows, err := mysql.db.QueryContext(ctx, `
- SELECT history.id, sequence.nanotime, conversations.nanotime
- FROM history
- LEFT JOIN sequence ON history.id = sequence.history_id
- LEFT JOIN conversations on history.id = conversations.history_id
- ORDER BY history.id LIMIT ?;`, cleanupRowLimit)
- if err != nil {
- return
- }
- defer rows.Close()
-
- idset := make(map[uint64]struct{}, cleanupRowLimit)
- threshold := time.Now().Add(-age).UnixNano()
- for rows.Next() {
- var id uint64
- var seqNano, convNano sql.NullInt64
- err = rows.Scan(&id, &seqNano, &convNano)
- if err != nil {
- return
- }
- nanotime := extractNanotime(seqNano, convNano)
- // returns 0 if not found; in that case the data is inconsistent
- // and we should delete the entry
- if nanotime < threshold {
- idset[id] = struct{}{}
- if nanotime > maxNanotime {
- maxNanotime = nanotime
- }
- }
- }
- ids = make([]uint64, len(idset))
- i := 0
- for id := range idset {
- ids[i] = id
- i++
- }
- return
- }
-
- func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
- result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
- if err != nil {
- mysql.logError("error deleting correspondents", err)
- } else {
- count, err := result.RowsAffected()
- if !mysql.logError("error deleting correspondents", err) {
- mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
- }
- }
- }
-
- // wait for forget queue items and process them one by one
- func (mysql *MySQL) forgetLoop() {
- defer func() {
- if r := recover(); r != nil {
- mysql.logger.Error("mysql",
- fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
- time.Sleep(cleanupPauseTime)
- go mysql.forgetLoop()
- }
- }()
-
- for {
- for {
- found, err := mysql.doForget()
- mysql.logError("error processing forget", err)
- if err != nil {
- time.Sleep(cleanupPauseTime)
- }
- if !found {
- break
- }
- }
-
- <-mysql.wakeForgetter
- }
- }
-
- // dequeue an item from the forget queue and process it
- func (mysql *MySQL) doForget() (found bool, err error) {
- id, account, err := func() (id int64, account string, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
- defer cancel()
-
- row := mysql.db.QueryRowContext(ctx,
- `SELECT forget.id, forget.account FROM forget LIMIT 1;`)
- err = row.Scan(&id, &account)
- if err == sql.ErrNoRows {
- return 0, "", nil
- }
- return
- }()
-
- if err != nil || account == "" {
- return false, err
- }
-
- found = true
-
- var count int
- for {
- start := time.Now()
- count, err = mysql.doForgetIteration(account)
- elapsed := time.Since(start)
- if err != nil {
- return true, err
- }
- if count == 0 {
- break
- }
- time.Sleep(elapsed)
- }
-
- mysql.logger.Debug("mysql", "forget complete for account", account)
-
- ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
- defer cancel()
- _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
- return
- }
-
- func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
- defer cancel()
-
- rows, err := mysql.db.QueryContext(ctx, `
- SELECT account_messages.history_id
- FROM account_messages
- WHERE account_messages.account = ?
- LIMIT ?;`, account, cleanupRowLimit)
- if err != nil {
- return
- }
- defer rows.Close()
-
- var ids []uint64
- for rows.Next() {
- var id uint64
- err = rows.Scan(&id)
- if err != nil {
- return
- }
- ids = append(ids, id)
- }
-
- if len(ids) == 0 {
- return
- }
-
- mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
- err = mysql.deleteHistoryIDs(ctx, ids)
- return len(ids), err
- }
-
- func (mysql *MySQL) prepareStatements() (err error) {
- mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
- (data, msgid) VALUES (?, ?);`)
- if err != nil {
- return
- }
- mysql.insertSequence, err = mysql.db.Prepare(`INSERT INTO sequence
- (target, nanotime, history_id) VALUES (?, ?, ?);`)
- if err != nil {
- return
- }
- mysql.insertConversation, err = mysql.db.Prepare(`INSERT INTO conversations
- (target, correspondent, nanotime, history_id) VALUES (?, ?, ?, ?);`)
- if err != nil {
- return
- }
- mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
- (target, correspondent, nanotime) VALUES (?, ?, ?)
- ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
- if err != nil {
- return
- }
- mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
- (history_id, account) VALUES (?, ?);`)
- if err != nil {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) getTimeout() time.Duration {
- return time.Duration(atomic.LoadInt64(&mysql.timeout))
- }
-
- func (mysql *MySQL) isTrackingAccountMessages() bool {
- return atomic.LoadUint32(&mysql.trackAccountMessages) != 0
- }
-
- func (mysql *MySQL) logError(context string, err error) (quit bool) {
- if err != nil {
- mysql.logger.Error("mysql", context, err.Error())
- return true
- }
- return false
- }
-
- func (mysql *MySQL) Forget(account string) {
- if mysql.db == nil || account == "" {
- return
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
- if mysql.logError("can't insert into forget table", err) {
- return
- }
-
- // wake up the forget goroutine if it's blocked:
- select {
- case mysql.wakeForgetter <- e{}:
- default:
- }
- }
-
- func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
- if mysql.db == nil {
- return
- }
-
- if target == "" {
- return utils.ErrInvalidParams
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- id, err := mysql.insertBase(ctx, item)
- if err != nil {
- return
- }
-
- err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
- if err != nil {
- return
- }
-
- err = mysql.insertAccountMessageEntry(ctx, id, account)
- if err != nil {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
- _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
- mysql.logError("could not insert sequence entry", err)
- return
- }
-
- func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
- _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
- mysql.logError("could not insert conversations entry", err)
- return
- }
-
- func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
- _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
- mysql.logError("could not insert conversations entry", err)
- return
- }
-
- func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
- value, err := marshalItem(&item)
- if mysql.logError("could not marshal item", err) {
- return
- }
-
- msgidBytes, err := decodeMsgid(item.Message.Msgid)
- if mysql.logError("could not decode msgid", err) {
- return
- }
-
- result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
- if mysql.logError("could not insert item", err) {
- return
- }
- id, err = result.LastInsertId()
- if mysql.logError("could not insert item", err) {
- return
- }
-
- return
- }
-
- func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
- if account == "" || !mysql.isTrackingAccountMessages() {
- return
- }
- _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
- mysql.logError("could not insert account-message entry", err)
- return
- }
-
- func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
- if mysql.db == nil {
- return
- }
-
- if senderAccount == "" && recipientAccount == "" {
- return
- }
-
- if sender == "" || recipient == "" {
- return utils.ErrInvalidParams
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- id, err := mysql.insertBase(ctx, item)
- if err != nil {
- return
- }
-
- nanotime := item.Message.Time.UnixNano()
-
- if senderAccount != "" {
- err = mysql.insertConversationEntry(ctx, senderAccount, recipient, nanotime, id)
- if err != nil {
- return
- }
- err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
- if err != nil {
- return
- }
- }
-
- if recipientAccount != "" && sender != recipient {
- err = mysql.insertConversationEntry(ctx, recipientAccount, sender, nanotime, id)
- if err != nil {
- return
- }
- err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
- if err != nil {
- return
- }
- }
-
- err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
- if err != nil {
- return
- }
-
- return
- }
-
- // note that accountName is the unfolded name
- func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
- if mysql.db == nil {
- return nil
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
- if err != nil {
- return
- }
-
- if accountName != "*" {
- var item history.Item
- err = unmarshalItem(data, &item)
- // delete if the entry is corrupt
- if err == nil && item.AccountName != accountName {
- return ErrDisallowed
- }
- }
-
- err = mysql.deleteHistoryIDs(ctx, []uint64{id})
- mysql.logError("couldn't delete msgid", err)
- return
- }
-
- func (mysql *MySQL) Export(account string, writer io.Writer) {
- if mysql.db == nil {
- return
- }
-
- var err error
- var lastSeen uint64
- for {
- rows := func() (count int) {
- ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
- defer cancel()
-
- rows, rowsErr := mysql.db.QueryContext(ctx, `
- SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
- INNER JOIN history ON history.id = account_messages.history_id
- INNER JOIN sequence ON account_messages.history_id = sequence.history_id
- WHERE account_messages.account = ? AND account_messages.history_id > ?
- LIMIT ?`, account, lastSeen, cleanupRowLimit)
- if rowsErr != nil {
- err = rowsErr
- return
- }
- defer rows.Close()
- for rows.Next() {
- var id uint64
- var blob, jsonBlob []byte
- var target string
- var item history.Item
- err = rows.Scan(&id, &blob, &target)
- if err != nil {
- return
- }
- err = unmarshalItem(blob, &item)
- if err != nil {
- return
- }
- item.CfCorrespondent = target
- jsonBlob, err = json.Marshal(item)
- if err != nil {
- return
- }
- count++
- if lastSeen < id {
- lastSeen = id
- }
- writer.Write(jsonBlob)
- writer.Write([]byte{'\n'})
- }
- return
- }()
- if rows == 0 || err != nil {
- break
- }
- }
-
- mysql.logError("could not export history", err)
- return
- }
-
- func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
- decoded, err := decodeMsgid(msgid)
- if err != nil {
- return
- }
- cols := `sequence.nanotime, conversations.nanotime`
- if includeData {
- cols = `sequence.nanotime, conversations.nanotime, history.id, history.data`
- }
- row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
- SELECT %s FROM history
- LEFT JOIN sequence ON history.id = sequence.history_id
- LEFT JOIN conversations ON history.id = conversations.history_id
- WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
- var nanoSeq, nanoConv sql.NullInt64
- if !includeData {
- err = row.Scan(&nanoSeq, &nanoConv)
- } else {
- err = row.Scan(&nanoSeq, &nanoConv, &id, &data)
- }
- if err != sql.ErrNoRows {
- mysql.logError("could not resolve msgid to time", err)
- }
- if err != nil {
- return
- }
- nanotime := extractNanotime(nanoSeq, nanoConv)
- if nanotime == 0 {
- err = sql.ErrNoRows
- return
- }
- result = time.Unix(0, nanotime).UTC()
- return
- }
-
// extractNanotime picks a message timestamp from the two LEFT JOINed columns: the
// sequence value if present, else the conversations value, else 0 (no timestamp).
func extractNanotime(seq, conv sql.NullInt64) (result int64) {
	switch {
	case seq.Valid:
		result = seq.Int64
	case conv.Valid:
		result = conv.Int64
	}
	return
}
-
- func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
- rows, err := mysql.db.QueryContext(ctx, query, args...)
- if mysql.logError("could not select history items", err) {
- return
- }
-
- defer rows.Close()
-
- for rows.Next() {
- var blob []byte
- var item history.Item
- err = rows.Scan(&blob)
- if mysql.logError("could not scan history item", err) {
- return
- }
- err = unmarshalItem(blob, &item)
- if mysql.logError("could not unmarshal history item", err) {
- return
- }
- results = append(results, item)
- }
- return
- }
-
- func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
- useSequence := correspondent == ""
- table := "sequence"
- if !useSequence {
- table = "conversations"
- }
-
- after, before, ascending := history.MinMaxAsc(after, before, cutoff)
- direction := "ASC"
- if !ascending {
- direction = "DESC"
- }
-
- var queryBuf strings.Builder
-
- args := make([]interface{}, 0, 6)
- fmt.Fprintf(&queryBuf,
- "SELECT history.data from history INNER JOIN %[1]s ON history.id = %[1]s.history_id WHERE", table)
- if useSequence {
- fmt.Fprintf(&queryBuf, " sequence.target = ?")
- args = append(args, target)
- } else {
- fmt.Fprintf(&queryBuf, " conversations.target = ? AND conversations.correspondent = ?")
- args = append(args, target)
- args = append(args, correspondent)
- }
- if !after.IsZero() {
- fmt.Fprintf(&queryBuf, " AND %s.nanotime > ?", table)
- args = append(args, after.UnixNano())
- }
- if !before.IsZero() {
- fmt.Fprintf(&queryBuf, " AND %s.nanotime < ?", table)
- args = append(args, before.UnixNano())
- }
- fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
- args = append(args, limit)
-
- results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
- if err == nil && !ascending {
- history.Reverse(results)
- }
- return
- }
-
- func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) {
- after, before, ascending := history.MinMaxAsc(after, before, cutoff)
- direction := "ASC"
- if !ascending {
- direction = "DESC"
- }
-
- var queryBuf strings.Builder
- args := make([]interface{}, 0, 4)
- queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
- WHERE target = ?`)
- args = append(args, target)
- if !after.IsZero() {
- queryBuf.WriteString(" AND correspondents.nanotime > ?")
- args = append(args, after.UnixNano())
- }
- if !before.IsZero() {
- queryBuf.WriteString(" AND correspondents.nanotime < ?")
- args = append(args, before.UnixNano())
- }
- fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
- args = append(args, limit)
- query := queryBuf.String()
-
- rows, err := mysql.db.QueryContext(ctx, query, args...)
- if err != nil {
- return
- }
- defer rows.Close()
- var correspondent string
- var nanotime int64
- for rows.Next() {
- err = rows.Scan(&correspondent, &nanotime)
- if err != nil {
- return
- }
- results = append(results, history.TargetListing{
- CfName: correspondent,
- Time: time.Unix(0, nanotime),
- })
- }
-
- if !ascending {
- history.ReverseCorrespondents(results)
- }
-
- return
- }
-
- func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
- if mysql.db == nil {
- return
- }
-
- if len(cfchannels) == 0 {
- return
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
- defer cancel()
-
- var queryBuf strings.Builder
- args := make([]interface{}, 0, len(results))
- // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html
- // this should be a "loose index scan"
- queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence
- WHERE sequence.target IN (`)
- for i, chname := range cfchannels {
- if i != 0 {
- queryBuf.WriteString(", ")
- }
- queryBuf.WriteByte('?')
- args = append(args, chname)
- }
- queryBuf.WriteString(") GROUP BY sequence.target;")
-
- rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
- if mysql.logError("could not query channel listings", err) {
- return
- }
- defer rows.Close()
-
- var target string
- var nanotime int64
- for rows.Next() {
- err = rows.Scan(&target, &nanotime)
- if mysql.logError("could not scan channel listings", err) {
- return
- }
- results = append(results, history.TargetListing{
- CfName: target,
- Time: time.Unix(0, nanotime),
- })
- }
- return
- }
-
- func (mysql *MySQL) Close() {
- // closing the database will close our prepared statements as well
- if mysql.db != nil {
- mysql.db.Close()
- }
- mysql.db = nil
- }
-
- // implements history.Sequence, emulating a single history buffer (for a channel,
- // a single user's DMs, or a DM conversation)
- type mySQLHistorySequence struct {
- mysql *MySQL
- target string
- correspondent string
- cutoff time.Time
- }
-
- func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
- defer cancel()
-
- startTime := start.Time
- if start.Msgid != "" {
- startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
- if err != nil {
- return nil, err
- }
- }
- endTime := end.Time
- if end.Msgid != "" {
- endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
- if err != nil {
- return nil, err
- }
- }
-
- results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
- return results, err
- }
-
- func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
- return history.GenericAround(s, start, limit)
- }
-
- func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
- ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
- defer cancel()
-
- // TODO accept msgids here?
- startTime := start.Time
- endTime := end.Time
-
- results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
- seq.mysql.logError("could not read correspondents", err)
- return
- }
-
- func (seq *mySQLHistorySequence) Cutoff() time.Time {
- return seq.cutoff
- }
-
- func (seq *mySQLHistorySequence) Ephemeral() bool {
- return false
- }
-
- func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
- return &mySQLHistorySequence{
- target: target,
- correspondent: correspondent,
- mysql: mysql,
- cutoff: cutoff,
- }
- }
|