|
@@ -1,11 +1,16 @@
|
|
1
|
+// Copyright (c) 2020 Shivaram Lingamneni
|
|
2
|
+// released under the MIT license
|
|
3
|
+
|
1
|
4
|
package mysql
|
2
|
5
|
|
3
|
6
|
import (
|
4
|
7
|
"bytes"
|
|
8
|
+ "context"
|
5
|
9
|
"database/sql"
|
6
|
10
|
"fmt"
|
7
|
11
|
"runtime/debug"
|
8
|
12
|
"sync"
|
|
13
|
+ "sync/atomic"
|
9
|
14
|
"time"
|
10
|
15
|
|
11
|
16
|
_ "github.com/go-sql-driver/mysql"
|
|
@@ -27,58 +32,59 @@ const (
|
27
|
32
|
)
|
28
|
33
|
|
29
|
34
|
type MySQL struct {
|
30
|
|
- db *sql.DB
|
31
|
|
- logger *logger.Manager
|
|
35
|
+ timeout int64
|
|
36
|
+ db *sql.DB
|
|
37
|
+ logger *logger.Manager
|
32
|
38
|
|
33
|
39
|
insertHistory *sql.Stmt
|
34
|
40
|
insertSequence *sql.Stmt
|
35
|
41
|
insertConversation *sql.Stmt
|
36
|
42
|
|
37
|
43
|
stateMutex sync.Mutex
|
38
|
|
- expireTime time.Duration
|
|
44
|
+ config Config
|
39
|
45
|
}
|
40
|
46
|
|
41
|
|
-func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
|
|
47
|
+func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
|
42
|
48
|
mysql.logger = logger
|
43
|
|
- mysql.expireTime = expireTime
|
|
49
|
+ mysql.SetConfig(config)
|
44
|
50
|
}
|
45
|
51
|
|
46
|
|
-func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
|
|
52
|
+func (mysql *MySQL) SetConfig(config Config) {
|
|
53
|
+ atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
|
47
|
54
|
mysql.stateMutex.Lock()
|
48
|
|
- mysql.expireTime = expireTime
|
|
55
|
+ mysql.config = config
|
49
|
56
|
mysql.stateMutex.Unlock()
|
50
|
57
|
}
|
51
|
58
|
|
52
|
59
|
func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
|
53
|
60
|
mysql.stateMutex.Lock()
|
54
|
|
- expireTime = mysql.expireTime
|
|
61
|
+ expireTime = mysql.config.ExpireTime
|
55
|
62
|
mysql.stateMutex.Unlock()
|
56
|
63
|
return
|
57
|
64
|
}
|
58
|
65
|
|
59
|
|
-func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
|
60
|
|
- // TODO: timeouts!
|
|
66
|
+func (m *MySQL) Open() (err error) {
|
61
|
67
|
var address string
|
62
|
|
- if port != 0 {
|
63
|
|
- address = fmt.Sprintf("tcp(%s:%d)", host, port)
|
|
68
|
+ if m.config.Port != 0 {
|
|
69
|
+ address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
|
64
|
70
|
}
|
65
|
71
|
|
66
|
|
- mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
|
|
72
|
+ m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
|
67
|
73
|
if err != nil {
|
68
|
74
|
return err
|
69
|
75
|
}
|
70
|
76
|
|
71
|
|
- err = mysql.fixSchemas()
|
|
77
|
+ err = m.fixSchemas()
|
72
|
78
|
if err != nil {
|
73
|
79
|
return err
|
74
|
80
|
}
|
75
|
81
|
|
76
|
|
- err = mysql.prepareStatements()
|
|
82
|
+ err = m.prepareStatements()
|
77
|
83
|
if err != nil {
|
78
|
84
|
return err
|
79
|
85
|
}
|
80
|
86
|
|
81
|
|
- go mysql.cleanupLoop()
|
|
87
|
+ go m.cleanupLoop()
|
82
|
88
|
|
83
|
89
|
return nil
|
84
|
90
|
}
|
|
@@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) {
|
280
|
286
|
return
|
281
|
287
|
}
|
282
|
288
|
|
|
289
|
+func (mysql *MySQL) getTimeout() time.Duration {
|
|
290
|
+ return time.Duration(atomic.LoadInt64(&mysql.timeout))
|
|
291
|
+}
|
|
292
|
+
|
283
|
293
|
func (mysql *MySQL) logError(context string, err error) (quit bool) {
|
284
|
294
|
if err != nil {
|
285
|
295
|
mysql.logger.Error("mysql", context, err.Error())
|
|
@@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
|
297
|
307
|
return utils.ErrInvalidParams
|
298
|
308
|
}
|
299
|
309
|
|
300
|
|
- id, err := mysql.insertBase(item)
|
|
310
|
+ ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
|
311
|
+ defer cancel()
|
|
312
|
+
|
|
313
|
+ id, err := mysql.insertBase(ctx, item)
|
301
|
314
|
if err != nil {
|
302
|
315
|
return
|
303
|
316
|
}
|
304
|
317
|
|
305
|
|
- err = mysql.insertSequenceEntry(target, item.Message.Time, id)
|
|
318
|
+ err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
|
306
|
319
|
return
|
307
|
320
|
}
|
308
|
321
|
|
309
|
|
-func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
|
310
|
|
- _, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
|
|
322
|
+func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
|
|
323
|
+ _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
|
311
|
324
|
mysql.logError("could not insert sequence entry", err)
|
312
|
325
|
return
|
313
|
326
|
}
|
314
|
327
|
|
315
|
|
-func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
|
|
328
|
+func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
|
316
|
329
|
lower, higher := stringMinMax(sender, recipient)
|
317
|
|
- _, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
|
|
330
|
+ _, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
|
318
|
331
|
mysql.logError("could not insert conversations entry", err)
|
319
|
332
|
return
|
320
|
333
|
}
|
321
|
334
|
|
322
|
|
-func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
|
|
335
|
+func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
|
323
|
336
|
value, err := marshalItem(&item)
|
324
|
337
|
if mysql.logError("could not marshal item", err) {
|
325
|
338
|
return
|
|
@@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
|
330
|
343
|
return
|
331
|
344
|
}
|
332
|
345
|
|
333
|
|
- result, err := mysql.insertHistory.Exec(value, msgidBytes)
|
|
346
|
+ result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
|
334
|
347
|
if mysql.logError("could not insert item", err) {
|
335
|
348
|
return
|
336
|
349
|
}
|
|
@@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
|
363
|
376
|
return utils.ErrInvalidParams
|
364
|
377
|
}
|
365
|
378
|
|
366
|
|
- id, err := mysql.insertBase(item)
|
|
379
|
+ ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
|
380
|
+ defer cancel()
|
|
381
|
+
|
|
382
|
+ id, err := mysql.insertBase(ctx, item)
|
367
|
383
|
if err != nil {
|
368
|
384
|
return
|
369
|
385
|
}
|
370
|
386
|
|
371
|
387
|
if senderPersistent {
|
372
|
|
- mysql.insertSequenceEntry(sender, item.Message.Time, id)
|
|
388
|
+ mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
|
373
|
389
|
if err != nil {
|
374
|
390
|
return
|
375
|
391
|
}
|
376
|
392
|
}
|
377
|
393
|
|
378
|
394
|
if recipientPersistent && sender != recipient {
|
379
|
|
- err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
|
|
395
|
+ err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
|
380
|
396
|
if err != nil {
|
381
|
397
|
return
|
382
|
398
|
}
|
383
|
399
|
}
|
384
|
400
|
|
385
|
|
- err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
|
|
401
|
+ err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
|
386
|
402
|
|
387
|
403
|
return
|
388
|
404
|
}
|
389
|
405
|
|
390
|
|
-func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
|
406
|
+func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
|
391
|
407
|
// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
|
392
|
408
|
// sequence.nanotime > (
|
393
|
409
|
// SELECT sequence.nanotime FROM sequence, history
|
|
@@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
400
|
416
|
if err != nil {
|
401
|
417
|
return
|
402
|
418
|
}
|
403
|
|
- row := mysql.db.QueryRow(`
|
|
419
|
+ row := mysql.db.QueryRowContext(ctx, `
|
404
|
420
|
SELECT sequence.nanotime FROM sequence
|
405
|
421
|
INNER JOIN history ON history.id = sequence.history_id
|
406
|
422
|
WHERE history.msgid = ? LIMIT 1;`, decoded)
|
|
@@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
|
413
|
429
|
return
|
414
|
430
|
}
|
415
|
431
|
|
416
|
|
-func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
|
417
|
|
- rows, err := mysql.db.Query(query, args...)
|
|
432
|
+func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
|
|
433
|
+ rows, err := mysql.db.QueryContext(ctx, query, args...)
|
418
|
434
|
if mysql.logError("could not select history items", err) {
|
419
|
435
|
return
|
420
|
436
|
}
|
|
@@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi
|
437
|
453
|
return
|
438
|
454
|
}
|
439
|
455
|
|
440
|
|
-func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
|
|
456
|
+func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
|
441
|
457
|
useSequence := true
|
442
|
458
|
var lowerTarget, upperTarget string
|
443
|
459
|
if sender != "" {
|
|
@@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c
|
480
|
496
|
fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
|
481
|
497
|
args = append(args, limit)
|
482
|
498
|
|
483
|
|
- results, err = mysql.selectItems(queryBuf.String(), args...)
|
|
499
|
+ results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
|
484
|
500
|
if err == nil && !ascending {
|
485
|
501
|
history.Reverse(results)
|
486
|
502
|
}
|
|
@@ -505,22 +521,25 @@ type mySQLHistorySequence struct {
|
505
|
521
|
}
|
506
|
522
|
|
507
|
523
|
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
|
|
524
|
+ ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
|
|
525
|
+ defer cancel()
|
|
526
|
+
|
508
|
527
|
startTime := start.Time
|
509
|
528
|
if start.Msgid != "" {
|
510
|
|
- startTime, err = s.mysql.msgidToTime(start.Msgid)
|
|
529
|
+ startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
|
511
|
530
|
if err != nil {
|
512
|
531
|
return nil, false, err
|
513
|
532
|
}
|
514
|
533
|
}
|
515
|
534
|
endTime := end.Time
|
516
|
535
|
if end.Msgid != "" {
|
517
|
|
- endTime, err = s.mysql.msgidToTime(end.Msgid)
|
|
536
|
+ endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
|
518
|
537
|
if err != nil {
|
519
|
538
|
return nil, false, err
|
520
|
539
|
}
|
521
|
540
|
}
|
522
|
541
|
|
523
|
|
- results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
|
|
542
|
+ results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
|
524
|
543
|
return results, (err == nil), err
|
525
|
544
|
}
|
526
|
545
|
|