Browse Source

Merge pull request #224 from slingamn/channelkeys.1

Updates to channel persistence
tags/v0.12.0
Daniel Oaks 6 years ago
parent
commit
c75d2c91c5
No account linked to committer's email address
8 changed files with 290 additions and 109 deletions
  1. 50
    20
      irc/channel.go
  2. 1
    1
      irc/channelmanager.go
  3. 92
    35
      irc/channelreg.go
  4. 1
    1
      irc/chanserv.go
  5. 137
    40
      irc/database.go
  6. 1
    1
      irc/getters.go
  7. 7
    10
      irc/handlers.go
  8. 1
    1
      oragono.go

+ 50
- 20
irc/channel.go View File

@@ -6,6 +6,7 @@
6 6
 package irc
7 7
 
8 8
 import (
9
+	"crypto/subtle"
9 10
 	"fmt"
10 11
 	"strconv"
11 12
 	"time"
@@ -36,11 +37,12 @@ type Channel struct {
36 37
 	topicSetBy        string
37 38
 	topicSetTime      time.Time
38 39
 	userLimit         uint64
40
+	accountToUMode    map[string]modes.Mode
39 41
 }
40 42
 
41 43
 // NewChannel creates a new channel from a `Server` and a `name`
42 44
 // string, which must be unique on the server.
43
-func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *RegisteredChannel) *Channel {
45
+func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
44 46
 	casefoldedName, err := CasefoldChannel(name)
45 47
 	if err != nil {
46 48
 		s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
@@ -59,18 +61,17 @@ func NewChannel(s *Server, name string, addDefaultModes bool, regInfo *Registere
59 61
 		name:           name,
60 62
 		nameCasefolded: casefoldedName,
61 63
 		server:         s,
64
+		accountToUMode: make(map[string]modes.Mode),
62 65
 	}
63 66
 
64
-	if addDefaultModes {
67
+	if regInfo != nil {
68
+		channel.applyRegInfo(regInfo)
69
+	} else {
65 70
 		for _, mode := range s.DefaultChannelModes() {
66 71
 			channel.flags[mode] = true
67 72
 		}
68 73
 	}
69 74
 
70
-	if regInfo != nil {
71
-		channel.applyRegInfo(regInfo)
72
-	}
73
-
74 75
 	return channel
75 76
 }
76 77
 
@@ -83,6 +84,11 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
83 84
 	channel.topicSetTime = chanReg.TopicSetTime
84 85
 	channel.name = chanReg.Name
85 86
 	channel.createdTime = chanReg.RegisteredAt
87
+	channel.key = chanReg.Key
88
+
89
+	for _, mode := range chanReg.Modes {
90
+		channel.flags[mode] = true
91
+	}
86 92
 	for _, mask := range chanReg.Banlist {
87 93
 		channel.lists[modes.BanMask].Add(mask)
88 94
 	}
@@ -92,21 +98,34 @@ func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
92 98
 	for _, mask := range chanReg.Invitelist {
93 99
 		channel.lists[modes.InviteMask].Add(mask)
94 100
 	}
101
+	for account, mode := range chanReg.AccountToUMode {
102
+		channel.accountToUMode[account] = mode
103
+	}
95 104
 }
96 105
 
97 106
 // obtain a consistent snapshot of the channel state that can be persisted to the DB
98
-func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredChannel) {
107
+func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) {
99 108
 	channel.stateMutex.RLock()
100 109
 	defer channel.stateMutex.RUnlock()
101 110
 
102 111
 	info.Name = channel.name
103
-	info.Topic = channel.topic
104
-	info.TopicSetBy = channel.topicSetBy
105
-	info.TopicSetTime = channel.topicSetTime
106 112
 	info.Founder = channel.registeredFounder
107 113
 	info.RegisteredAt = channel.registeredTime
108 114
 
109
-	if includeLists {
115
+	if includeFlags&IncludeTopic != 0 {
116
+		info.Topic = channel.topic
117
+		info.TopicSetBy = channel.topicSetBy
118
+		info.TopicSetTime = channel.topicSetTime
119
+	}
120
+
121
+	if includeFlags&IncludeModes != 0 {
122
+		info.Key = channel.key
123
+		for mode := range channel.flags {
124
+			info.Modes = append(info.Modes, mode)
125
+		}
126
+	}
127
+
128
+	if includeFlags&IncludeLists != 0 {
110 129
 		for mask := range channel.lists[modes.BanMask].masks {
111 130
 			info.Banlist = append(info.Banlist, mask)
112 131
 		}
@@ -116,6 +135,10 @@ func (channel *Channel) ExportRegistration(includeLists bool) (info RegisteredCh
116 135
 		for mask := range channel.lists[modes.InviteMask].masks {
117 136
 			info.Invitelist = append(info.Invitelist, mask)
118 137
 		}
138
+		info.AccountToUMode = make(map[string]modes.Mode)
139
+		for account, mode := range channel.accountToUMode {
140
+			info.AccountToUMode[account] = mode
141
+		}
119 142
 	}
120 143
 
121 144
 	return
@@ -131,6 +154,7 @@ func (channel *Channel) SetRegistered(founder string) error {
131 154
 	}
132 155
 	channel.registeredFounder = founder
133 156
 	channel.registeredTime = time.Now()
157
+	channel.accountToUMode[founder] = modes.ChannelFounder
134 158
 	return nil
135 159
 }
136 160
 
@@ -338,7 +362,12 @@ func (channel *Channel) IsFull() bool {
338 362
 
339 363
 // CheckKey returns true if the key is not set or matches the given key.
340 364
 func (channel *Channel) CheckKey(key string) bool {
341
-	return (channel.key == "") || (channel.key == key)
365
+	chkey := channel.Key()
366
+	if chkey == "" {
367
+		return true
368
+	}
369
+
370
+	return subtle.ConstantTimeCompare([]byte(key), []byte(chkey)) == 1
342 371
 }
343 372
 
344 373
 func (channel *Channel) IsEmpty() bool {
@@ -404,21 +433,22 @@ func (channel *Channel) Join(client *Client, key string, rb *ResponseBuffer) {
404 433
 
405 434
 	client.addChannel(channel)
406 435
 
436
+	account := client.Account()
437
+
407 438
 	// give channel mode if necessary
408
-	newChannel := firstJoin && !channel.IsRegistered()
439
+	channel.stateMutex.Lock()
440
+	newChannel := firstJoin && channel.registeredFounder == ""
441
+	mode, persistentModeExists := channel.accountToUMode[account]
409 442
 	var givenMode *modes.Mode
410
-	account := client.Account()
411
-	cffounder, _ := CasefoldName(channel.registeredFounder)
412
-	if account != "" && account == cffounder {
413
-		givenMode = &modes.ChannelFounder
443
+	if persistentModeExists {
444
+		givenMode = &mode
414 445
 	} else if newChannel {
415 446
 		givenMode = &modes.ChannelOperator
416 447
 	}
417 448
 	if givenMode != nil {
418
-		channel.stateMutex.Lock()
419 449
 		channel.members[client][*givenMode] = true
420
-		channel.stateMutex.Unlock()
421 450
 	}
451
+	channel.stateMutex.Unlock()
422 452
 
423 453
 	if client.capabilities.Has(caps.ExtendedJoin) {
424 454
 		rb.Add(nil, client.nickMaskString, "JOIN", channel.name, client.AccountName(), client.realname)
@@ -513,7 +543,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
513 543
 		}
514 544
 	}
515 545
 
516
-	go channel.server.channelRegistry.StoreChannel(channel, false)
546
+	go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
517 547
 }
518 548
 
519 549
 // CanSpeak returns true if the client can speak on this channel.

+ 1
- 1
irc/channelmanager.go View File

@@ -65,7 +65,7 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, rb *Resp
65 65
 		entry = cm.chans[casefoldedName]
66 66
 		if entry == nil {
67 67
 			entry = &channelManagerEntry{
68
-				channel:      NewChannel(server, name, true, info),
68
+				channel:      NewChannel(server, name, info),
69 69
 				pendingJoins: 0,
70 70
 			}
71 71
 			cm.chans[casefoldedName] = entry

+ 92
- 35
irc/channelreg.go View File

@@ -6,11 +6,13 @@ package irc
6 6
 import (
7 7
 	"fmt"
8 8
 	"strconv"
9
+	"strings"
9 10
 	"sync"
10 11
 	"time"
11 12
 
12 13
 	"encoding/json"
13 14
 
15
+	"github.com/oragono/oragono/irc/modes"
14 16
 	"github.com/tidwall/buntdb"
15 17
 )
16 18
 
@@ -18,16 +20,19 @@ import (
18 20
 // channel creation/tracking/destruction is in channelmanager.go
19 21
 
20 22
 const (
21
-	keyChannelExists       = "channel.exists %s"
22
-	keyChannelName         = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
23
-	keyChannelRegTime      = "channel.registered.time %s"
24
-	keyChannelFounder      = "channel.founder %s"
25
-	keyChannelTopic        = "channel.topic %s"
26
-	keyChannelTopicSetBy   = "channel.topic.setby %s"
27
-	keyChannelTopicSetTime = "channel.topic.settime %s"
28
-	keyChannelBanlist      = "channel.banlist %s"
29
-	keyChannelExceptlist   = "channel.exceptlist %s"
30
-	keyChannelInvitelist   = "channel.invitelist %s"
23
+	keyChannelExists         = "channel.exists %s"
24
+	keyChannelName           = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped
25
+	keyChannelRegTime        = "channel.registered.time %s"
26
+	keyChannelFounder        = "channel.founder %s"
27
+	keyChannelTopic          = "channel.topic %s"
28
+	keyChannelTopicSetBy     = "channel.topic.setby %s"
29
+	keyChannelTopicSetTime   = "channel.topic.settime %s"
30
+	keyChannelBanlist        = "channel.banlist %s"
31
+	keyChannelExceptlist     = "channel.exceptlist %s"
32
+	keyChannelInvitelist     = "channel.invitelist %s"
33
+	keyChannelPassword       = "channel.key %s"
34
+	keyChannelModes          = "channel.modes %s"
35
+	keyChannelAccountToUMode = "channel.accounttoumode %s"
31 36
 )
32 37
 
33 38
 var (
@@ -42,9 +47,26 @@ var (
42 47
 		keyChannelBanlist,
43 48
 		keyChannelExceptlist,
44 49
 		keyChannelInvitelist,
50
+		keyChannelPassword,
51
+		keyChannelModes,
52
+		keyChannelAccountToUMode,
45 53
 	}
46 54
 )
47 55
 
56
+// these are bit flags indicating what part of the channel status is "dirty"
57
+// and needs to be read from memory and written to the db
58
+const (
59
+	IncludeInitial uint = 1 << iota
60
+	IncludeTopic
61
+	IncludeModes
62
+	IncludeLists
63
+)
64
+
65
+// this is an OR of all possible flags
66
+const (
67
+	IncludeAllChannelAttrs = ^uint(0)
68
+)
69
+
48 70
 // RegisteredChannel holds details about a given registered channel.
49 71
 type RegisteredChannel struct {
50 72
 	// Name of the channel.
@@ -59,6 +81,12 @@ type RegisteredChannel struct {
59 81
 	TopicSetBy string
60 82
 	// TopicSetTime represents the time the topic was set.
61 83
 	TopicSetTime time.Time
84
+	// Modes represents the channel modes
85
+	Modes []modes.Mode
86
+	// Key represents the channel key / password
87
+	Key string
88
+	// AccountToUMode maps user accounts to their persistent channel modes (e.g., +q, +h)
89
+	AccountToUMode map[string]modes.Mode
62 90
 	// Banlist represents the bans set on the channel.
63 91
 	Banlist []string
64 92
 	// Exceptlist represents the exceptions set on the channel.
@@ -87,7 +115,7 @@ func NewChannelRegistry(server *Server) *ChannelRegistry {
87 115
 }
88 116
 
89 117
 // StoreChannel obtains a consistent view of a channel, then persists it to the store.
90
-func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
118
+func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
91 119
 	if !reg.server.ChannelRegistrationEnabled() {
92 120
 		return
93 121
 	}
@@ -96,14 +124,14 @@ func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeLists bool) {
96 124
 	defer reg.Unlock()
97 125
 
98 126
 	key := channel.NameCasefolded()
99
-	info := channel.ExportRegistration(includeLists)
127
+	info := channel.ExportRegistration(includeFlags)
100 128
 	if info.Founder == "" {
101 129
 		// sanity check, don't try to store an unregistered channel
102 130
 		return
103 131
 	}
104 132
 
105 133
 	reg.server.store.Update(func(tx *buntdb.Tx) error {
106
-		reg.saveChannel(tx, key, info, includeLists)
134
+		reg.saveChannel(tx, key, info, includeFlags)
107 135
 		return nil
108 136
 	})
109 137
 }
@@ -132,9 +160,17 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
132 160
 		topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey))
133 161
 		topicSetTime, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey))
134 162
 		topicSetTimeInt, _ := strconv.ParseInt(topicSetTime, 10, 64)
163
+		password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey))
164
+		modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey))
135 165
 		banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey))
136 166
 		exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey))
137 167
 		invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey))
168
+		accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey))
169
+
170
+		modeSlice := make([]modes.Mode, len(modeString))
171
+		for i, mode := range modeString {
172
+			modeSlice[i] = modes.Mode(mode)
173
+		}
138 174
 
139 175
 		var banlist []string
140 176
 		_ = json.Unmarshal([]byte(banlistString), &banlist)
@@ -142,17 +178,22 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
142 178
 		_ = json.Unmarshal([]byte(exceptlistString), &exceptlist)
143 179
 		var invitelist []string
144 180
 		_ = json.Unmarshal([]byte(invitelistString), &invitelist)
181
+		accountToUMode := make(map[string]modes.Mode)
182
+		_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
145 183
 
146 184
 		info = &RegisteredChannel{
147
-			Name:         name,
148
-			RegisteredAt: time.Unix(regTimeInt, 0),
149
-			Founder:      founder,
150
-			Topic:        topic,
151
-			TopicSetBy:   topicSetBy,
152
-			TopicSetTime: time.Unix(topicSetTimeInt, 0),
153
-			Banlist:      banlist,
154
-			Exceptlist:   exceptlist,
155
-			Invitelist:   invitelist,
185
+			Name:           name,
186
+			RegisteredAt:   time.Unix(regTimeInt, 0),
187
+			Founder:        founder,
188
+			Topic:          topic,
189
+			TopicSetBy:     topicSetBy,
190
+			TopicSetTime:   time.Unix(topicSetTimeInt, 0),
191
+			Key:            password,
192
+			Modes:          modeSlice,
193
+			Banlist:        banlist,
194
+			Exceptlist:     exceptlist,
195
+			Invitelist:     invitelist,
196
+			AccountToUMode: accountToUMode,
156 197
 		}
157 198
 		return nil
158 199
 	})
@@ -170,17 +211,17 @@ func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
170 211
 	reg.Lock()
171 212
 	defer reg.Unlock()
172 213
 
173
-	includeLists := true
214
+	includeFlags := IncludeAllChannelAttrs
174 215
 	oldKey := casefoldedOldName
175 216
 	key := channel.NameCasefolded()
176
-	info := channel.ExportRegistration(includeLists)
217
+	info := channel.ExportRegistration(includeFlags)
177 218
 	if info.Founder == "" {
178 219
 		return
179 220
 	}
180 221
 
181 222
 	reg.server.store.Update(func(tx *buntdb.Tx) error {
182 223
 		reg.deleteChannel(tx, oldKey, info)
183
-		reg.saveChannel(tx, key, info, includeLists)
224
+		reg.saveChannel(tx, key, info, includeFlags)
184 225
 		return nil
185 226
 	})
186 227
 }
@@ -204,21 +245,37 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
204 245
 }
205 246
 
206 247
 // saveChannel saves a channel to the store.
207
-func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeLists bool) {
208
-	tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
209
-	tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
210
-	tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
211
-	tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
212
-	tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
213
-	tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
214
-	tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
215
-
216
-	if includeLists {
248
+func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
249
+	if includeFlags&IncludeInitial != 0 {
250
+		tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil)
251
+		tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil)
252
+		tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.Unix(), 10), nil)
253
+		tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil)
254
+	}
255
+
256
+	if includeFlags&IncludeTopic != 0 {
257
+		tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil)
258
+		tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), strconv.FormatInt(channelInfo.TopicSetTime.Unix(), 10), nil)
259
+		tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil)
260
+	}
261
+
262
+	if includeFlags&IncludeModes != 0 {
263
+		tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil)
264
+		modeStrings := make([]string, len(channelInfo.Modes))
265
+		for i, mode := range channelInfo.Modes {
266
+			modeStrings[i] = string(mode)
267
+		}
268
+		tx.Set(fmt.Sprintf(keyChannelModes, channelKey), strings.Join(modeStrings, ""), nil)
269
+	}
270
+
271
+	if includeFlags&IncludeLists != 0 {
217 272
 		banlistString, _ := json.Marshal(channelInfo.Banlist)
218 273
 		tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil)
219 274
 		exceptlistString, _ := json.Marshal(channelInfo.Exceptlist)
220 275
 		tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil)
221 276
 		invitelistString, _ := json.Marshal(channelInfo.Invitelist)
222 277
 		tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil)
278
+		accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode)
279
+		tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil)
223 280
 	}
224 281
 }

+ 1
- 1
irc/chanserv.go View File

@@ -252,7 +252,7 @@ func csRegisterHandler(server *Server, client *Client, command, params string, r
252 252
 	}
253 253
 
254 254
 	// registration was successful: make the database reflect it
255
-	go server.channelRegistry.StoreChannel(channelInfo, true)
255
+	go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
256 256
 
257 257
 	csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
258 258
 

+ 137
- 40
irc/database.go View File

@@ -6,11 +6,13 @@ package irc
6 6
 
7 7
 import (
8 8
 	"encoding/base64"
9
+	"encoding/json"
9 10
 	"fmt"
10 11
 	"log"
11 12
 	"os"
12 13
 	"strings"
13 14
 
15
+	"github.com/oragono/oragono/irc/modes"
14 16
 	"github.com/oragono/oragono/irc/passwd"
15 17
 
16 18
 	"github.com/tidwall/buntdb"
@@ -20,11 +22,22 @@ const (
20 22
 	// 'version' of the database schema
21 23
 	keySchemaVersion = "db.version"
22 24
 	// latest schema of the db
23
-	latestDbSchema = "2"
25
+	latestDbSchema = "3"
24 26
 	// key for the primary salt used by the ircd
25 27
 	keySalt = "crypto.salt"
26 28
 )
27 29
 
30
+type SchemaChanger func(*Config, *buntdb.Tx) error
31
+
32
+type SchemaChange struct {
33
+	InitialVersion string // the change will take this version
34
+	TargetVersion  string // and transform it into this version
35
+	Changer        SchemaChanger
36
+}
37
+
38
+// maps an initial version to a schema change capable of upgrading it
39
+var schemaChanges map[string]SchemaChange
40
+
28 41
 // InitDB creates the database.
29 42
 func InitDB(path string) {
30 43
 	// prepare kvstore db
@@ -46,7 +59,7 @@ func InitDB(path string) {
46 59
 		tx.Set(keySalt, encodedSalt, nil)
47 60
 
48 61
 		// set schema version
49
-		tx.Set(keySchemaVersion, "2", nil)
62
+		tx.Set(keySchemaVersion, latestDbSchema, nil)
50 63
 		return nil
51 64
 	})
52 65
 
@@ -82,58 +95,142 @@ func OpenDatabase(path string) (*buntdb.DB, error) {
82 95
 }
83 96
 
84 97
 // UpgradeDB upgrades the datastore to the latest schema.
85
-func UpgradeDB(path string) {
86
-	store, err := buntdb.Open(path)
98
+func UpgradeDB(config *Config) {
99
+	store, err := buntdb.Open(config.Datastore.Path)
87 100
 	if err != nil {
88 101
 		log.Fatal(fmt.Sprintf("Failed to open datastore: %s", err.Error()))
89 102
 	}
90 103
 	defer store.Close()
91 104
 
105
+	var version string
92 106
 	err = store.Update(func(tx *buntdb.Tx) error {
93
-		version, _ := tx.Get(keySchemaVersion)
94
-
95
-		// == version 1 -> 2 ==
96
-		// account key changes and account.verified key bugfix.
97
-		if version == "1" {
98
-			log.Println("Updating store v1 to v2")
99
-
100
-			var keysToRemove []string
101
-			newKeys := make(map[string]string)
102
-
103
-			tx.AscendKeys("account *", func(key, value string) bool {
104
-				keysToRemove = append(keysToRemove, key)
105
-				splitkey := strings.Split(key, " ")
106
-
107
-				// work around bug
108
-				if splitkey[2] == "exists" {
109
-					// manually create new verified key
110
-					newVerifiedKey := fmt.Sprintf("%s.verified %s", splitkey[0], splitkey[1])
111
-					newKeys[newVerifiedKey] = "1"
112
-				} else if splitkey[1] == "%s" {
113
-					return true
114
-				}
115
-
116
-				newKey := fmt.Sprintf("%s.%s %s", splitkey[0], splitkey[2], splitkey[1])
117
-				newKeys[newKey] = value
118
-
119
-				return true
120
-			})
121
-
122
-			for _, key := range keysToRemove {
123
-				tx.Delete(key)
107
+		for {
108
+			version, _ = tx.Get(keySchemaVersion)
109
+			change, schemaNeedsChange := schemaChanges[version]
110
+			if !schemaNeedsChange {
111
+				break
124 112
 			}
125
-			for key, value := range newKeys {
126
-				tx.Set(key, value, nil)
113
+			log.Println("attempting to update store from version " + version)
114
+			err := change.Changer(config, tx)
115
+			if err != nil {
116
+				return err
127 117
 			}
128
-
129
-			tx.Set(keySchemaVersion, "2", nil)
118
+			_, _, err = tx.Set(keySchemaVersion, change.TargetVersion, nil)
119
+			if err != nil {
120
+				return err
121
+			}
122
+			log.Println("successfully updated store to version " + change.TargetVersion)
130 123
 		}
131
-
132 124
 		return nil
133 125
 	})
126
+
134 127
 	if err != nil {
135 128
 		log.Fatal("Could not update datastore:", err.Error())
136 129
 	}
137 130
 
138 131
 	return
139 132
 }
133
+
134
+func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error {
135
+	// == version 1 -> 2 ==
136
+	// account key changes and account.verified key bugfix.
137
+
138
+	var keysToRemove []string
139
+	newKeys := make(map[string]string)
140
+
141
+	tx.AscendKeys("account *", func(key, value string) bool {
142
+		keysToRemove = append(keysToRemove, key)
143
+		splitkey := strings.Split(key, " ")
144
+
145
+		// work around bug
146
+		if splitkey[2] == "exists" {
147
+			// manually create new verified key
148
+			newVerifiedKey := fmt.Sprintf("%s.verified %s", splitkey[0], splitkey[1])
149
+			newKeys[newVerifiedKey] = "1"
150
+		} else if splitkey[1] == "%s" {
151
+			return true
152
+		}
153
+
154
+		newKey := fmt.Sprintf("%s.%s %s", splitkey[0], splitkey[2], splitkey[1])
155
+		newKeys[newKey] = value
156
+
157
+		return true
158
+	})
159
+
160
+	for _, key := range keysToRemove {
161
+		tx.Delete(key)
162
+	}
163
+	for key, value := range newKeys {
164
+		tx.Set(key, value, nil)
165
+	}
166
+
167
+	return nil
168
+}
169
+
170
+// 1. channel founder names should be casefolded
171
+// 2. founder should be explicitly granted the ChannelFounder user mode
172
+// 3. explicitly initialize stored channel modes to the server default values
173
+func schemaChangeV2ToV3(config *Config, tx *buntdb.Tx) error {
174
+	var channels []string
175
+	prefix := "channel.exists "
176
+	tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
177
+		if !strings.HasPrefix(key, prefix) {
178
+			return false
179
+		}
180
+		chname := strings.TrimPrefix(key, prefix)
181
+		channels = append(channels, chname)
182
+		return true
183
+	})
184
+
185
+	// founder names should be casefolded
186
+	// founder should be explicitly granted the ChannelFounder user mode
187
+	for _, channel := range channels {
188
+		founderKey := "channel.founder " + channel
189
+		founder, _ := tx.Get(founderKey)
190
+		if founder != "" {
191
+			founder, err := CasefoldName(founder)
192
+			if err == nil {
193
+				tx.Set(founderKey, founder, nil)
194
+				accountToUmode := map[string]modes.Mode{
195
+					founder: modes.ChannelFounder,
196
+				}
197
+				atustr, _ := json.Marshal(accountToUmode)
198
+				tx.Set("channel.accounttoumode "+channel, string(atustr), nil)
199
+			}
200
+		}
201
+	}
202
+
203
+	// explicitly store the channel modes
204
+	defaultModes := ParseDefaultChannelModes(config)
205
+	modeStrings := make([]string, len(defaultModes))
206
+	for i, mode := range defaultModes {
207
+		modeStrings[i] = string(mode)
208
+	}
209
+	defaultModeString := strings.Join(modeStrings, "")
210
+	for _, channel := range channels {
211
+		tx.Set("channel.modes "+channel, defaultModeString, nil)
212
+	}
213
+
214
+	return nil
215
+}
216
+
217
+func init() {
218
+	allChanges := []SchemaChange{
219
+		SchemaChange{
220
+			InitialVersion: "1",
221
+			TargetVersion:  "2",
222
+			Changer:        schemaChangeV1toV2,
223
+		},
224
+		SchemaChange{
225
+			InitialVersion: "2",
226
+			TargetVersion:  "3",
227
+			Changer:        schemaChangeV2ToV3,
228
+		},
229
+	}
230
+
231
+	// build the index
232
+	schemaChanges = make(map[string]SchemaChange)
233
+	for _, change := range allChanges {
234
+		schemaChanges[change.InitialVersion] = change
235
+	}
236
+}

+ 1
- 1
irc/getters.go View File

@@ -256,8 +256,8 @@ func (channel *Channel) Key() string {
256 256
 
257 257
 func (channel *Channel) setKey(key string) {
258 258
 	channel.stateMutex.Lock()
259
+	defer channel.stateMutex.Unlock()
259 260
 	channel.key = key
260
-	channel.stateMutex.Unlock()
261 261
 }
262 262
 
263 263
 func (channel *Channel) HasMode(mode modes.Mode) bool {

+ 7
- 10
irc/handlers.go View File

@@ -1351,20 +1351,17 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
1351 1351
 		applied = channel.ApplyChannelModeChanges(client, msg.Command == "SAMODE", changes, rb)
1352 1352
 	}
1353 1353
 
1354
-	// save changes to banlist/exceptlist/invexlist
1355
-	var banlistUpdated, exceptlistUpdated, invexlistUpdated bool
1354
+	// save changes
1355
+	var includeFlags uint
1356 1356
 	for _, change := range applied {
1357
-		if change.Mode == modes.BanMask {
1358
-			banlistUpdated = true
1359
-		} else if change.Mode == modes.ExceptMask {
1360
-			exceptlistUpdated = true
1361
-		} else if change.Mode == modes.InviteMask {
1362
-			invexlistUpdated = true
1357
+		includeFlags |= IncludeModes
1358
+		if change.Mode == modes.BanMask || change.Mode == modes.ExceptMask || change.Mode == modes.InviteMask {
1359
+			includeFlags |= IncludeLists
1363 1360
 		}
1364 1361
 	}
1365 1362
 
1366
-	if (banlistUpdated || exceptlistUpdated || invexlistUpdated) && channel.IsRegistered() {
1367
-		go server.channelRegistry.StoreChannel(channel, true)
1363
+	if channel.IsRegistered() && includeFlags != 0 {
1364
+		go server.channelRegistry.StoreChannel(channel, includeFlags)
1368 1365
 	}
1369 1366
 
1370 1367
 	// send out changes

+ 1
- 1
oragono.go View File

@@ -84,7 +84,7 @@ Options:
84 84
 			log.Println("database initialized: ", config.Datastore.Path)
85 85
 		}
86 86
 	} else if arguments["upgradedb"].(bool) {
87
-		irc.UpgradeDB(config.Datastore.Path)
87
+		irc.UpgradeDB(config)
88 88
 		if !arguments["--quiet"].(bool) {
89 89
 			log.Println("database upgraded: ", config.Datastore.Path)
90 90
 		}

Loading…
Cancel
Save