Browse Source

add mysql timeouts

tags/v2.0.0-rc1
Shivaram Lingamneni 4 years ago
parent
commit
98a7b45d96
5 changed files with 86 additions and 49 deletions
  1. 3
    8
      irc/config.go
  2. 22
    0
      irc/mysql/config.go
  3. 56
    37
      irc/mysql/history.go
  4. 4
    4
      irc/server.go
  5. 1
    0
      oragono.yaml

+ 3
- 8
irc/config.go View File

@@ -504,14 +504,7 @@ type Config struct {
504 504
 	Datastore struct {
505 505
 		Path        string
506 506
 		AutoUpgrade bool
507
-		MySQL       struct {
508
-			Enabled         bool
509
-			Host            string
510
-			Port            int
511
-			User            string
512
-			Password        string
513
-			HistoryDatabase string `yaml:"history-database"`
514
-		}
507
+		MySQL       mysql.Config
515 508
 	}
516 509
 
517 510
 	Accounts AccountConfig
@@ -1069,6 +1062,8 @@ func LoadConfig(filename string) (config *Config, err error) {
1069 1062
 		config.History.ZNCMax = config.History.ChathistoryMax
1070 1063
 	}
1071 1064
 
1065
+	config.Datastore.MySQL.ExpireTime = time.Duration(config.History.Restrictions.ExpireTime)
1066
+
1072 1067
 	config.Server.Cloaks.Initialize()
1073 1068
 	if config.Server.Cloaks.Enabled {
1074 1069
 		if config.Server.Cloaks.Secret == "" || config.Server.Cloaks.Secret == "siaELnk6Kaeo65K3RCrwJjlWaZ-Bt3WuZ2L8MXLbNb4" {

+ 22
- 0
irc/mysql/config.go View File

@@ -0,0 +1,22 @@
1
+// Copyright (c) 2020 Shivaram Lingamneni
2
+// released under the MIT license
3
+
4
+package mysql
5
+
6
+import (
7
+	"time"
8
+)
9
+
10
+type Config struct {
11
+	// these are intended to be written directly into the config file:
12
+	Enabled         bool
13
+	Host            string
14
+	Port            int
15
+	User            string
16
+	Password        string
17
+	HistoryDatabase string `yaml:"history-database"`
18
+	Timeout         time.Duration
19
+
20
+	// XXX these are copied from elsewhere in the config:
21
+	ExpireTime time.Duration
22
+}

+ 56
- 37
irc/mysql/history.go View File

@@ -1,11 +1,16 @@
1
+// Copyright (c) 2020 Shivaram Lingamneni
2
+// released under the MIT license
3
+
1 4
 package mysql
2 5
 
3 6
 import (
4 7
 	"bytes"
8
+	"context"
5 9
 	"database/sql"
6 10
 	"fmt"
7 11
 	"runtime/debug"
8 12
 	"sync"
13
+	"sync/atomic"
9 14
 	"time"
10 15
 
11 16
 	_ "github.com/go-sql-driver/mysql"
@@ -27,58 +32,59 @@ const (
27 32
 )
28 33
 
29 34
 type MySQL struct {
30
-	db     *sql.DB
31
-	logger *logger.Manager
35
+	timeout int64
36
+	db      *sql.DB
37
+	logger  *logger.Manager
32 38
 
33 39
 	insertHistory      *sql.Stmt
34 40
 	insertSequence     *sql.Stmt
35 41
 	insertConversation *sql.Stmt
36 42
 
37 43
 	stateMutex sync.Mutex
38
-	expireTime time.Duration
44
+	config     Config
39 45
 }
40 46
 
41
-func (mysql *MySQL) Initialize(logger *logger.Manager, expireTime time.Duration) {
47
+func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
42 48
 	mysql.logger = logger
43
-	mysql.expireTime = expireTime
49
+	mysql.SetConfig(config)
44 50
 }
45 51
 
46
-func (mysql *MySQL) SetExpireTime(expireTime time.Duration) {
52
+func (mysql *MySQL) SetConfig(config Config) {
53
+	atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
47 54
 	mysql.stateMutex.Lock()
48
-	mysql.expireTime = expireTime
55
+	mysql.config = config
49 56
 	mysql.stateMutex.Unlock()
50 57
 }
51 58
 
52 59
 func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
53 60
 	mysql.stateMutex.Lock()
54
-	expireTime = mysql.expireTime
61
+	expireTime = mysql.config.ExpireTime
55 62
 	mysql.stateMutex.Unlock()
56 63
 	return
57 64
 }
58 65
 
59
-func (mysql *MySQL) Open(username, password, host string, port int, database string) (err error) {
60
-	// TODO: timeouts!
66
+func (m *MySQL) Open() (err error) {
61 67
 	var address string
62
-	if port != 0 {
63
-		address = fmt.Sprintf("tcp(%s:%d)", host, port)
68
+	if m.config.Port != 0 {
69
+		address = fmt.Sprintf("tcp(%s:%d)", m.config.Host, m.config.Port)
64 70
 	}
65 71
 
66
-	mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", username, password, address, database))
72
+	m.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", m.config.User, m.config.Password, address, m.config.HistoryDatabase))
67 73
 	if err != nil {
68 74
 		return err
69 75
 	}
70 76
 
71
-	err = mysql.fixSchemas()
77
+	err = m.fixSchemas()
72 78
 	if err != nil {
73 79
 		return err
74 80
 	}
75 81
 
76
-	err = mysql.prepareStatements()
82
+	err = m.prepareStatements()
77 83
 	if err != nil {
78 84
 		return err
79 85
 	}
80 86
 
81
-	go mysql.cleanupLoop()
87
+	go m.cleanupLoop()
82 88
 
83 89
 	return nil
84 90
 }
@@ -280,6 +286,10 @@ func (mysql *MySQL) prepareStatements() (err error) {
280 286
 	return
281 287
 }
282 288
 
289
+func (mysql *MySQL) getTimeout() time.Duration {
290
+	return time.Duration(atomic.LoadInt64(&mysql.timeout))
291
+}
292
+
283 293
 func (mysql *MySQL) logError(context string, err error) (quit bool) {
284 294
 	if err != nil {
285 295
 		mysql.logger.Error("mysql", context, err.Error())
@@ -297,29 +307,32 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
297 307
 		return utils.ErrInvalidParams
298 308
 	}
299 309
 
300
-	id, err := mysql.insertBase(item)
310
+	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
311
+	defer cancel()
312
+
313
+	id, err := mysql.insertBase(ctx, item)
301 314
 	if err != nil {
302 315
 		return
303 316
 	}
304 317
 
305
-	err = mysql.insertSequenceEntry(target, item.Message.Time, id)
318
+	err = mysql.insertSequenceEntry(ctx, target, item.Message.Time, id)
306 319
 	return
307 320
 }
308 321
 
309
-func (mysql *MySQL) insertSequenceEntry(target string, messageTime time.Time, id int64) (err error) {
310
-	_, err = mysql.insertSequence.Exec(target, messageTime.UnixNano(), id)
322
+func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime time.Time, id int64) (err error) {
323
+	_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime.UnixNano(), id)
311 324
 	mysql.logError("could not insert sequence entry", err)
312 325
 	return
313 326
 }
314 327
 
315
-func (mysql *MySQL) insertConversationEntry(sender, recipient string, messageTime time.Time, id int64) (err error) {
328
+func (mysql *MySQL) insertConversationEntry(ctx context.Context, sender, recipient string, messageTime time.Time, id int64) (err error) {
316 329
 	lower, higher := stringMinMax(sender, recipient)
317
-	_, err = mysql.insertConversation.Exec(lower, higher, messageTime.UnixNano(), id)
330
+	_, err = mysql.insertConversation.ExecContext(ctx, lower, higher, messageTime.UnixNano(), id)
318 331
 	mysql.logError("could not insert conversations entry", err)
319 332
 	return
320 333
 }
321 334
 
322
-func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
335
+func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
323 336
 	value, err := marshalItem(&item)
324 337
 	if mysql.logError("could not marshal item", err) {
325 338
 		return
@@ -330,7 +343,7 @@ func (mysql *MySQL) insertBase(item history.Item) (id int64, err error) {
330 343
 		return
331 344
 	}
332 345
 
333
-	result, err := mysql.insertHistory.Exec(value, msgidBytes)
346
+	result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
334 347
 	if mysql.logError("could not insert item", err) {
335 348
 		return
336 349
 	}
@@ -363,31 +376,34 @@ func (mysql *MySQL) AddDirectMessage(sender, recipient string, senderPersistent,
363 376
 		return utils.ErrInvalidParams
364 377
 	}
365 378
 
366
-	id, err := mysql.insertBase(item)
379
+	ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
380
+	defer cancel()
381
+
382
+	id, err := mysql.insertBase(ctx, item)
367 383
 	if err != nil {
368 384
 		return
369 385
 	}
370 386
 
371 387
 	if senderPersistent {
372
-		mysql.insertSequenceEntry(sender, item.Message.Time, id)
388
+		mysql.insertSequenceEntry(ctx, sender, item.Message.Time, id)
373 389
 		if err != nil {
374 390
 			return
375 391
 		}
376 392
 	}
377 393
 
378 394
 	if recipientPersistent && sender != recipient {
379
-		err = mysql.insertSequenceEntry(recipient, item.Message.Time, id)
395
+		err = mysql.insertSequenceEntry(ctx, recipient, item.Message.Time, id)
380 396
 		if err != nil {
381 397
 			return
382 398
 		}
383 399
 	}
384 400
 
385
-	err = mysql.insertConversationEntry(sender, recipient, item.Message.Time, id)
401
+	err = mysql.insertConversationEntry(ctx, sender, recipient, item.Message.Time, id)
386 402
 
387 403
 	return
388 404
 }
389 405
 
390
-func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
406
+func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
391 407
 	// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
392 408
 	// sequence.nanotime > (
393 409
 	//     SELECT sequence.nanotime FROM sequence, history
@@ -400,7 +416,7 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
400 416
 	if err != nil {
401 417
 		return
402 418
 	}
403
-	row := mysql.db.QueryRow(`
419
+	row := mysql.db.QueryRowContext(ctx, `
404 420
 		SELECT sequence.nanotime FROM sequence
405 421
 		INNER JOIN history ON history.id = sequence.history_id
406 422
 		WHERE history.msgid = ? LIMIT 1;`, decoded)
@@ -413,8 +429,8 @@ func (mysql *MySQL) msgidToTime(msgid string) (result time.Time, err error) {
413 429
 	return
414 430
 }
415 431
 
416
-func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []history.Item, err error) {
417
-	rows, err := mysql.db.Query(query, args...)
432
+func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
433
+	rows, err := mysql.db.QueryContext(ctx, query, args...)
418 434
 	if mysql.logError("could not select history items", err) {
419 435
 		return
420 436
 	}
@@ -437,7 +453,7 @@ func (mysql *MySQL) selectItems(query string, args ...interface{}) (results []hi
437 453
 	return
438 454
 }
439 455
 
440
-func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
456
+func (mysql *MySQL) betweenTimestamps(ctx context.Context, sender, recipient string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) {
441 457
 	useSequence := true
442 458
 	var lowerTarget, upperTarget string
443 459
 	if sender != "" {
@@ -480,7 +496,7 @@ func (mysql *MySQL) BetweenTimestamps(sender, recipient string, after, before, c
480 496
 	fmt.Fprintf(&queryBuf, " ORDER BY %[1]s.nanotime %[2]s LIMIT ?;", table, direction)
481 497
 	args = append(args, limit)
482 498
 
483
-	results, err = mysql.selectItems(queryBuf.String(), args...)
499
+	results, err = mysql.selectItems(ctx, queryBuf.String(), args...)
484 500
 	if err == nil && !ascending {
485 501
 		history.Reverse(results)
486 502
 	}
@@ -505,22 +521,25 @@ type mySQLHistorySequence struct {
505 521
 }
506 522
 
507 523
 func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
524
+	ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
525
+	defer cancel()
526
+
508 527
 	startTime := start.Time
509 528
 	if start.Msgid != "" {
510
-		startTime, err = s.mysql.msgidToTime(start.Msgid)
529
+		startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
511 530
 		if err != nil {
512 531
 			return nil, false, err
513 532
 		}
514 533
 	}
515 534
 	endTime := end.Time
516 535
 	if end.Msgid != "" {
517
-		endTime, err = s.mysql.msgidToTime(end.Msgid)
536
+		endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
518 537
 		if err != nil {
519 538
 			return nil, false, err
520 539
 		}
521 540
 	}
522 541
 
523
-	results, err = s.mysql.BetweenTimestamps(s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
542
+	results, err = s.mysql.betweenTimestamps(ctx, s.sender, s.recipient, startTime, endTime, s.cutoff, limit)
524 543
 	return results, (err == nil), err
525 544
 }
526 545
 

+ 4
- 4
irc/server.go View File

@@ -669,8 +669,8 @@ func (server *Server) applyConfig(config *Config) (err error) {
669 669
 			return err
670 670
 		}
671 671
 	} else {
672
-		if config.Datastore.MySQL.Enabled {
673
-			server.historyDB.SetExpireTime(time.Duration(config.History.Restrictions.ExpireTime))
672
+		if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL {
673
+			server.historyDB.SetConfig(config.Datastore.MySQL)
674 674
 		}
675 675
 	}
676 676
 
@@ -793,8 +793,8 @@ func (server *Server) loadDatastore(config *Config) error {
793 793
 	server.accounts.Initialize(server)
794 794
 
795 795
 	if config.Datastore.MySQL.Enabled {
796
-		server.historyDB.Initialize(server.logger, time.Duration(config.History.Restrictions.ExpireTime))
797
-		err = server.historyDB.Open(config.Datastore.MySQL.User, config.Datastore.MySQL.Password, config.Datastore.MySQL.Host, config.Datastore.MySQL.Port, config.Datastore.MySQL.HistoryDatabase)
796
+		server.historyDB.Initialize(server.logger, config.Datastore.MySQL)
797
+		err = server.historyDB.Open()
798 798
 		if err != nil {
799 799
 			server.logger.Error("internal", "could not connect to mysql", err.Error())
800 800
 			return err

+ 1
- 0
oragono.yaml View File

@@ -608,6 +608,7 @@ datastore:
608 608
         user: "oragono"
609 609
         password: "KOHw8WSaRwaoo-avo0qVpQ"
610 610
         history-database: "oragono_history"
611
+        timeout: 3s
611 612
 
612 613
 # languages config
613 614
 languages:

Loading…
Cancel
Save