Browse Source

persist and load channel mask lists

tags/v0.1.0
Jeremy Latt 10 years ago
parent
commit
cf76d2bd77
5 changed files with 61 additions and 11 deletions
  1. 8
    1
      ergonomadic.go
  2. 5
    3
      irc/channel.go
  3. 10
    0
      irc/client_lookup_set.go
  4. 20
    4
      irc/database.go
  5. 18
    3
      irc/server.go

+ 8
- 1
ergonomadic.go View File

@@ -12,6 +12,7 @@ import (
12 12
 func main() {
13 13
 	conf := flag.String("conf", "ergonomadic.conf", "ergonomadic config file")
14 14
 	initdb := flag.Bool("initdb", false, "initialize database")
15
+	upgradedb := flag.Bool("upgradedb", false, "update database")
15 16
 	passwd := flag.String("genpasswd", "", "bcrypt a password")
16 17
 	flag.Parse()
17 18
 
@@ -35,7 +36,13 @@ func main() {
35 36
 
36 37
 	if *initdb {
37 38
 		irc.InitDB(config.Server.Database)
38
-		log.Println("database initialized: " + config.Server.Database)
39
+		log.Println("database initialized: ", config.Server.Database)
40
+		return
41
+	}
42
+
43
+	if *upgradedb {
44
+		irc.UpgradeDB(config.Server.Database)
45
+		log.Println("database upgraded: ", config.Server.Database)
39 46
 		return
40 47
 	}
41 48
 

+ 5
- 3
irc/channel.go View File

@@ -443,10 +443,12 @@ func (channel *Channel) Persist() (err error) {
443 443
 	if channel.flags[Persistent] {
444 444
 		_, err = channel.server.db.Exec(`
445 445
             INSERT OR REPLACE INTO channel
446
-              (name, flags, key, topic, user_limit)
447
-              VALUES (?, ?, ?, ?, ?)`,
446
+              (name, flags, key, topic, user_limit, ban_list, except_list,
447
+               invite_list)
448
+              VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
448 449
 			channel.name, channel.flags.String(), channel.key, channel.topic,
449
-			channel.userLimit)
450
+			channel.userLimit, channel.lists[BanMask].String(),
451
+			channel.lists[ExceptMask].String(), channel.lists[InviteMask].String())
450 452
 	} else {
451 453
 		_, err = channel.server.db.Exec(`
452 454
             DELETE FROM channel WHERE name = ?`, channel.name)

+ 10
- 0
irc/client_lookup_set.go View File

@@ -219,6 +219,16 @@ func (set *UserMaskSet) Match(userhost string) bool {
219 219
 	return set.regexp.MatchString(userhost)
220 220
 }
221 221
 
222
+func (set *UserMaskSet) String() string {
223
+	masks := make([]string, len(set.masks))
224
+	index := 0
225
+	for mask := range set.masks {
226
+		masks[index] = mask
227
+		index += 1
228
+	}
229
+	return strings.Join(masks, " ")
230
+}
231
+
222 232
 func (set *UserMaskSet) setRegexp() {
223 233
 	if len(set.masks) == 0 {
224 234
 		set.regexp = nil

+ 20
- 4
irc/database.go View File

@@ -2,6 +2,7 @@ package irc
2 2
 
3 3
 import (
4 4
 	"database/sql"
5
+	"fmt"
5 6
 	_ "github.com/mattn/go-sqlite3"
6 7
 	"log"
7 8
 	"os"
@@ -14,15 +15,30 @@ func InitDB(path string) {
14 15
 	_, err := db.Exec(`
15 16
         CREATE TABLE channel (
16 17
           name TEXT NOT NULL UNIQUE,
17
-          flags TEXT NOT NULL,
18
-          key TEXT NOT NULL,
19
-          topic TEXT NOT NULL,
20
-          user_limit INTEGER DEFAULT 0)`)
18
+          flags TEXT DEFAULT '',
19
+          key TEXT DEFAULT '',
20
+          topic TEXT DEFAULT '',
21
+          user_limit INTEGER DEFAULT 0,
22
+          ban_list TEXT DEFAULT '',
23
+          except_list TEXT DEFAULT '',
24
+          invite_list TEXT DEFAULT '')`)
21 25
 	if err != nil {
22 26
 		log.Fatal("initdb error: ", err)
23 27
 	}
24 28
 }
25 29
 
30
+func UpgradeDB(path string) {
31
+	db := OpenDB(path)
32
+	alter := `ALTER TABLE channel ADD COLUMN %s TEXT DEFAULT ''`
33
+	cols := []string{"ban_list", "except_list", "invite_list"}
34
+	for _, col := range cols {
35
+		_, err := db.Exec(fmt.Sprintf(alter, col))
36
+		if err != nil {
37
+			log.Fatal("updatedb error: ", err)
38
+		}
39
+	}
40
+}
41
+
26 42
 func OpenDB(path string) *sql.DB {
27 43
 	db, err := sql.Open("sqlite3", path)
28 44
 	if err != nil {

+ 18
- 3
irc/server.go View File

@@ -64,9 +64,19 @@ func NewServer(config *Config) *Server {
64 64
 	return server
65 65
 }
66 66
 
67
+func loadChannelList(channel *Channel, list string, maskMode ChannelMode) {
68
+	if list == "" {
69
+		return
70
+	}
71
+	for _, mask := range strings.Split(list, " ") {
72
+		channel.lists[maskMode].Add(mask)
73
+	}
74
+}
75
+
67 76
 func (server *Server) loadChannels() {
68 77
 	rows, err := server.db.Query(`
69
-        SELECT name, flags, key, topic, user_limit
78
+        SELECT name, flags, key, topic, user_limit, ban_list, except_list,
79
+               invite_list
70 80
           FROM channel`)
71 81
 	if err != nil {
72 82
 		log.Fatal("error loading channels: ", err)
@@ -74,9 +84,11 @@ func (server *Server) loadChannels() {
74 84
 	for rows.Next() {
75 85
 		var name, flags, key, topic string
76 86
 		var userLimit uint64
77
-		err = rows.Scan(&name, &flags, &key, &topic, &userLimit)
87
+		var banList, exceptList, inviteList string
88
+		err = rows.Scan(&name, &flags, &key, &topic, &userLimit, &banList,
89
+			&exceptList, &inviteList)
78 90
 		if err != nil {
79
-			log.Println(err)
91
+			log.Println("Server.loadChannels:", err)
80 92
 			continue
81 93
 		}
82 94
 
@@ -87,6 +99,9 @@ func (server *Server) loadChannels() {
87 99
 		channel.key = key
88 100
 		channel.topic = topic
89 101
 		channel.userLimit = userLimit
102
+		loadChannelList(channel, banList, BanMask)
103
+		loadChannelList(channel, exceptList, ExceptMask)
104
+		loadChannelList(channel, inviteList, InviteMask)
90 105
 	}
91 106
 }
92 107
 

Loading…
Cancel
Save