Browse Source

refactor channel registration

tags/v1.1.0-rc1
Shivaram Lingamneni 5 years ago
parent
commit
63029e2ff5

+ 9
- 17
irc/accounts.go View File

@@ -59,18 +59,15 @@ type AccountManager struct {
59 59
 	accountToMethod   map[string]NickReservationMethod
60 60
 }
61 61
 
62
-func NewAccountManager(server *Server) *AccountManager {
63
-	am := AccountManager{
64
-		accountToClients:  make(map[string][]*Client),
65
-		nickToAccount:     make(map[string]string),
66
-		skeletonToAccount: make(map[string]string),
67
-		accountToMethod:   make(map[string]NickReservationMethod),
68
-		server:            server,
69
-	}
62
+func (am *AccountManager) Initialize(server *Server) {
63
+	am.accountToClients = make(map[string][]*Client)
64
+	am.nickToAccount = make(map[string]string)
65
+	am.skeletonToAccount = make(map[string]string)
66
+	am.accountToMethod = make(map[string]NickReservationMethod)
67
+	am.server = server
70 68
 
71 69
 	am.buildNickToAccountIndex()
72 70
 	am.initVHostRequestQueue()
73
-	return &am
74 71
 }
75 72
 
76 73
 func (am *AccountManager) buildNickToAccountIndex() {
@@ -855,6 +852,7 @@ func (am *AccountManager) Unregister(account string) error {
855 852
 	verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount)
856 853
 	verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount)
857 854
 	nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
855
+	enforcementKey := fmt.Sprintf(keyAccountEnforcement, casefoldedAccount)
858 856
 	vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
859 857
 	vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
860 858
 	channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
@@ -865,14 +863,7 @@ func (am *AccountManager) Unregister(account string) error {
865 863
 	// on our way out, unregister all the account's channels and delete them from the db
866 864
 	defer func() {
867 865
 		for _, channelName := range registeredChannels {
868
-			info := am.server.channelRegistry.LoadChannel(channelName)
869
-			if info != nil && info.Founder == casefoldedAccount {
870
-				am.server.channelRegistry.Delete(channelName, *info)
871
-			}
872
-			channel := am.server.channels.Get(channelName)
873
-			if channel != nil {
874
-				channel.SetUnregistered(casefoldedAccount)
875
-			}
866
+			am.server.channels.SetUnregistered(channelName, casefoldedAccount)
876 867
 		}
877 868
 	}()
878 869
 
@@ -892,6 +883,7 @@ func (am *AccountManager) Unregister(account string) error {
892 883
 		tx.Delete(registeredTimeKey)
893 884
 		tx.Delete(callbackKey)
894 885
 		tx.Delete(verificationCodeKey)
886
+		tx.Delete(enforcementKey)
895 887
 		rawNicks, _ = tx.Get(nicksKey)
896 888
 		tx.Delete(nicksKey)
897 889
 		credText, err = tx.Get(credentialsKey)

+ 0
- 1
irc/caps/set.go View File

@@ -16,7 +16,6 @@ type Set [bitsetLen]uint64
16 16
 // NewSet returns a new Set, with the given capabilities enabled.
17 17
 func NewSet(capabs ...Capability) *Set {
18 18
 	var newSet Set
19
-	utils.BitsetInitialize(newSet[:])
20 19
 	newSet.Enable(capabs...)
21 20
 	return &newSet
22 21
 }

+ 147
- 13
irc/channel.go View File

@@ -22,7 +22,7 @@ import (
22 22
 
23 23
 // Channel represents a channel that clients can join.
24 24
 type Channel struct {
25
-	flags             *modes.ModeSet
25
+	flags             modes.ModeSet
26 26
 	lists             map[modes.Mode]*UserMaskSet
27 27
 	key               string
28 28
 	members           MemberSet
@@ -33,19 +33,22 @@ type Channel struct {
33 33
 	createdTime       time.Time
34 34
 	registeredFounder string
35 35
 	registeredTime    time.Time
36
-	stateMutex        sync.RWMutex // tier 1
37
-	joinPartMutex     sync.Mutex   // tier 3
38 36
 	topic             string
39 37
 	topicSetBy        string
40 38
 	topicSetTime      time.Time
41 39
 	userLimit         int
42 40
 	accountToUMode    map[string]modes.Mode
43 41
 	history           history.Buffer
42
+	stateMutex        sync.RWMutex // tier 1
43
+	writerSemaphore   Semaphore    // tier 1.5
44
+	joinPartMutex     sync.Mutex   // tier 3
45
+	ensureLoaded      utils.Once   // manages loading stored registration info from the database
46
+	dirtyBits         uint
44 47
 }
45 48
 
46 49
 // NewChannel creates a new channel from a `Server` and a `name`
47 50
 // string, which must be unique on the server.
48
-func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
51
+func NewChannel(s *Server, name string, registered bool) *Channel {
49 52
 	casefoldedName, err := CasefoldChannel(name)
50 53
 	if err != nil {
51 54
 		s.logger.Error("internal", "Bad channel name", name, err.Error())
@@ -54,7 +57,6 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
54 57
 
55 58
 	channel := &Channel{
56 59
 		createdTime: time.Now(), // may be overwritten by applyRegInfo
57
-		flags:       modes.NewModeSet(),
58 60
 		lists: map[modes.Mode]*UserMaskSet{
59 61
 			modes.BanMask:    NewUserMaskSet(),
60 62
 			modes.ExceptMask: NewUserMaskSet(),
@@ -69,21 +71,43 @@ func NewChannel(s *Server, name string, regInfo *RegisteredChannel) *Channel {
69 71
 
70 72
 	config := s.Config()
71 73
 
72
-	if regInfo != nil {
73
-		channel.applyRegInfo(regInfo)
74
-	} else {
74
+	channel.writerSemaphore.Initialize(1)
75
+	channel.history.Initialize(config.History.ChannelLength)
76
+
77
+	if !registered {
75 78
 		for _, mode := range config.Channels.defaultModes {
76 79
 			channel.flags.SetMode(mode, true)
77 80
 		}
78
-	}
79
-
80
-	channel.history.Initialize(config.History.ChannelLength)
81
+		// no loading to do, so "mark" the load operation as "done":
82
+		channel.ensureLoaded.Do(func() {})
83
+	} // else: modes will be loaded before first join
81 84
 
82 85
 	return channel
83 86
 }
84 87
 
88
+// EnsureLoaded blocks until the channel's registration info has been loaded
89
+// from the database.
90
+func (channel *Channel) EnsureLoaded() {
91
+	channel.ensureLoaded.Do(func() {
92
+		nmc := channel.NameCasefolded()
93
+		info, err := channel.server.channelRegistry.LoadChannel(nmc)
94
+		if err == nil {
95
+			channel.applyRegInfo(info)
96
+		} else {
97
+			channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error())
98
+		}
99
+	})
100
+}
101
+
102
+func (channel *Channel) IsLoaded() bool {
103
+	return channel.ensureLoaded.Done()
104
+}
105
+
85 106
 // read in channel state that was persisted in the DB
86
-func (channel *Channel) applyRegInfo(chanReg *RegisteredChannel) {
107
+func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
108
+	channel.stateMutex.Lock()
109
+	defer channel.stateMutex.Unlock()
110
+
87 111
 	channel.registeredFounder = chanReg.Founder
88 112
 	channel.registeredTime = chanReg.RegisteredAt
89 113
 	channel.topic = chanReg.Topic
@@ -116,6 +140,7 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
116 140
 	defer channel.stateMutex.RUnlock()
117 141
 
118 142
 	info.Name = channel.name
143
+	info.NameCasefolded = channel.nameCasefolded
119 144
 	info.Founder = channel.registeredFounder
120 145
 	info.RegisteredAt = channel.registeredTime
121 146
 
@@ -149,6 +174,115 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
149 174
 	return
150 175
 }
151 176
 
177
+// begin: asynchronous database writeback implementation, modeled on irc/socket.go
178
+
179
+// MarkDirty marks part (or all) of a channel's data as needing to be written back
180
+// to the database, then starts a writer goroutine if necessary.
181
+// This is the equivalent of Socket.Write().
182
+func (channel *Channel) MarkDirty(dirtyBits uint) {
183
+	channel.stateMutex.Lock()
184
+	isRegistered := channel.registeredFounder != ""
185
+	channel.dirtyBits = channel.dirtyBits | dirtyBits
186
+	channel.stateMutex.Unlock()
187
+	if !isRegistered {
188
+		return
189
+	}
190
+
191
+	channel.wakeWriter()
192
+}
193
+
194
+// IsClean returns whether a channel can be safely removed from the server.
195
+// To avoid the obvious TOCTOU race condition, it must be called while holding
196
+// ChannelManager's lock (that way, no one can join and make the channel dirty again
197
+// between this method exiting and the actual deletion).
198
+func (channel *Channel) IsClean() bool {
199
+	if !channel.writerSemaphore.TryAcquire() {
200
+		// a database write (which may fail) is in progress, the channel cannot be cleaned up
201
+		return false
202
+	}
203
+	defer channel.writerSemaphore.Release()
204
+
205
+	channel.stateMutex.RLock()
206
+	defer channel.stateMutex.RUnlock()
207
+	// the channel must be empty, and either be unregistered or fully written to the DB
208
+	return len(channel.members) == 0 && (channel.registeredFounder == "" || channel.dirtyBits == 0)
209
+}
210
+
211
+func (channel *Channel) wakeWriter() {
212
+	if channel.writerSemaphore.TryAcquire() {
213
+		go channel.writeLoop()
214
+	}
215
+}
216
+
217
+// equivalent of Socket.send()
218
+func (channel *Channel) writeLoop() {
219
+	for {
220
+		// TODO(#357) check the error value of this and implement timed backoff
221
+		channel.performWrite(0)
222
+		channel.writerSemaphore.Release()
223
+
224
+		channel.stateMutex.RLock()
225
+		isDirty := channel.dirtyBits != 0
226
+		isEmpty := len(channel.members) == 0
227
+		channel.stateMutex.RUnlock()
228
+
229
+		if !isDirty {
230
+			if isEmpty {
231
+				channel.server.channels.Cleanup(channel)
232
+			}
233
+			return // nothing to do
234
+		} // else: isDirty, so we need to write again
235
+
236
+		if !channel.writerSemaphore.TryAcquire() {
237
+			return
238
+		}
239
+	}
240
+}
241
+
242
+// Store writes part (or all) of the channel's data back to the database,
243
+// blocking until the write is complete. This is the equivalent of
244
+// Socket.BlockingWrite.
245
+func (channel *Channel) Store(dirtyBits uint) (err error) {
246
+	defer func() {
247
+		channel.stateMutex.Lock()
248
+		isDirty := channel.dirtyBits != 0
249
+		isEmpty := len(channel.members) == 0
250
+		channel.stateMutex.Unlock()
251
+
252
+		if isDirty {
253
+			channel.wakeWriter()
254
+		} else if isEmpty {
255
+			channel.server.channels.Cleanup(channel)
256
+		}
257
+	}()
258
+
259
+	channel.writerSemaphore.Acquire()
260
+	defer channel.writerSemaphore.Release()
261
+	return channel.performWrite(dirtyBits)
262
+}
263
+
264
+// do an individual write; equivalent of Socket.send()
265
+func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) {
266
+	channel.stateMutex.Lock()
267
+	dirtyBits := channel.dirtyBits | additionalDirtyBits
268
+	channel.dirtyBits = 0
269
+	isRegistered := channel.registeredFounder != ""
270
+	channel.stateMutex.Unlock()
271
+
272
+	if !isRegistered || dirtyBits == 0 {
273
+		return
274
+	}
275
+
276
+	info := channel.ExportRegistration(dirtyBits)
277
+	err = channel.server.channelRegistry.StoreChannel(info, dirtyBits)
278
+	if err != nil {
279
+		channel.stateMutex.Lock()
280
+		channel.dirtyBits = channel.dirtyBits | dirtyBits
281
+		channel.stateMutex.Unlock()
282
+	}
283
+	return
284
+}
285
+
152 286
 // SetRegistered registers the channel, returning an error if it was already registered.
153 287
 func (channel *Channel) SetRegistered(founder string) error {
154 288
 	channel.stateMutex.Lock()
@@ -698,7 +832,7 @@ func (channel *Channel) SetTopic(client *Client, topic string, rb *ResponseBuffe
698 832
 		}
699 833
 	}
700 834
 
701
-	go channel.server.channelRegistry.StoreChannel(channel, IncludeTopic)
835
+	channel.MarkDirty(IncludeTopic)
702 836
 }
703 837
 
704 838
 // CanSpeak returns true if the client can speak on this channel.

+ 119
- 43
irc/channelmanager.go View File

@@ -19,25 +19,38 @@ type channelManagerEntry struct {
19 19
 // providing synchronization for creation of new channels on first join,
20 20
 // cleanup of empty channels on last part, and renames.
21 21
 type ChannelManager struct {
22
-	sync.RWMutex // tier 2
23
-	chans        map[string]*channelManagerEntry
22
+	sync.RWMutex       // tier 2
23
+	chans              map[string]*channelManagerEntry
24
+	registeredChannels map[string]bool
25
+	server             *Server
24 26
 }
25 27
 
26 28
 // NewChannelManager returns a new ChannelManager.
27
-func NewChannelManager() *ChannelManager {
28
-	return &ChannelManager{
29
-		chans: make(map[string]*channelManagerEntry),
29
+func (cm *ChannelManager) Initialize(server *Server) {
30
+	cm.chans = make(map[string]*channelManagerEntry)
31
+	cm.server = server
32
+
33
+	if server.Config().Channels.Registration.Enabled {
34
+		cm.loadRegisteredChannels()
30 35
 	}
31 36
 }
32 37
 
38
+func (cm *ChannelManager) loadRegisteredChannels() {
39
+	registeredChannels := cm.server.channelRegistry.AllChannels()
40
+	cm.Lock()
41
+	defer cm.Unlock()
42
+	cm.registeredChannels = registeredChannels
43
+}
44
+
33 45
 // Get returns an existing channel with name equivalent to `name`, or nil
34
-func (cm *ChannelManager) Get(name string) *Channel {
46
+func (cm *ChannelManager) Get(name string) (channel *Channel) {
35 47
 	name, err := CasefoldChannel(name)
36 48
 	if err == nil {
37 49
 		cm.RLock()
38 50
 		defer cm.RUnlock()
39 51
 		entry := cm.chans[name]
40
-		if entry != nil {
52
+		// if the channel is still loading, pretend we don't have it
53
+		if entry != nil && entry.channel.IsLoaded() {
41 54
 			return entry.channel
42 55
 		}
43 56
 	}
@@ -55,28 +68,21 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin
55 68
 	cm.Lock()
56 69
 	entry := cm.chans[casefoldedName]
57 70
 	if entry == nil {
58
-		// XXX give up the lock to check for a registration, then check again
59
-		// to see if we need to create the channel. we could solve this by doing LoadChannel
60
-		// outside the lock initially on every join, so this is best thought of as an
61
-		// optimization to avoid that.
62
-		cm.Unlock()
63
-		info := client.server.channelRegistry.LoadChannel(casefoldedName)
64
-		cm.Lock()
65
-		entry = cm.chans[casefoldedName]
66
-		if entry == nil {
67
-			entry = &channelManagerEntry{
68
-				channel:      NewChannel(server, name, info),
69
-				pendingJoins: 0,
70
-			}
71
-			cm.chans[casefoldedName] = entry
71
+		registered := cm.registeredChannels[casefoldedName]
72
+		entry = &channelManagerEntry{
73
+			channel:      NewChannel(server, name, registered),
74
+			pendingJoins: 0,
72 75
 		}
76
+		cm.chans[casefoldedName] = entry
73 77
 	}
74 78
 	entry.pendingJoins += 1
79
+	channel := entry.channel
75 80
 	cm.Unlock()
76 81
 
77
-	entry.channel.Join(client, key, isSajoin, rb)
82
+	channel.EnsureLoaded()
83
+	channel.Join(client, key, isSajoin, rb)
78 84
 
79
-	cm.maybeCleanup(entry.channel, true)
85
+	cm.maybeCleanup(channel, true)
80 86
 
81 87
 	return nil
82 88
 }
@@ -85,7 +91,8 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
85 91
 	cm.Lock()
86 92
 	defer cm.Unlock()
87 93
 
88
-	entry := cm.chans[channel.NameCasefolded()]
94
+	nameCasefolded := channel.NameCasefolded()
95
+	entry := cm.chans[nameCasefolded]
89 96
 	if entry == nil || entry.channel != channel {
90 97
 		return
91 98
 	}
@@ -93,23 +100,15 @@ func (cm *ChannelManager) maybeCleanup(channel *Channel, afterJoin bool) {
93 100
 	if afterJoin {
94 101
 		entry.pendingJoins -= 1
95 102
 	}
96
-	// TODO(slingamn) right now, registered channels cannot be cleaned up.
97
-	// this is because once ChannelManager becomes the source of truth about a channel,
98
-	// we can't move the source of truth back to the database unless we do an ACID
99
-	// store while holding the ChannelManager's Lock(). This is pending more decisions
100
-	// about where the database transaction lock fits into the overall lock model.
101
-	if !entry.channel.IsRegistered() && entry.channel.IsEmpty() && entry.pendingJoins == 0 {
102
-		// reread the name, handling the case where the channel was renamed
103
-		casefoldedName := entry.channel.NameCasefolded()
104
-		delete(cm.chans, casefoldedName)
105
-		// invalidate the entry (otherwise, a subsequent cleanup attempt could delete
106
-		// a valid, distinct entry under casefoldedName):
107
-		entry.channel = nil
103
+	if entry.pendingJoins == 0 && entry.channel.IsClean() {
104
+		delete(cm.chans, nameCasefolded)
108 105
 	}
109 106
 }
110 107
 
111 108
 // Part parts `client` from the channel named `name`, deleting it if it's empty.
112 109
 func (cm *ChannelManager) Part(client *Client, name string, message string, rb *ResponseBuffer) error {
110
+	var channel *Channel
111
+
113 112
 	casefoldedName, err := CasefoldChannel(name)
114 113
 	if err != nil {
115 114
 		return errNoSuchChannel
@@ -117,12 +116,15 @@ func (cm *ChannelManager) Part(client *Client, name string, message string, rb *
117 116
 
118 117
 	cm.RLock()
119 118
 	entry := cm.chans[casefoldedName]
119
+	if entry != nil {
120
+		channel = entry.channel
121
+	}
120 122
 	cm.RUnlock()
121 123
 
122
-	if entry == nil {
124
+	if channel == nil {
123 125
 		return errNoSuchChannel
124 126
 	}
125
-	entry.channel.Part(client, message, rb)
127
+	channel.Part(client, message, rb)
126 128
 	return nil
127 129
 }
128 130
 
@@ -130,8 +132,68 @@ func (cm *ChannelManager) Cleanup(channel *Channel) {
130 132
 	cm.maybeCleanup(channel, false)
131 133
 }
132 134
 
135
+func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) {
136
+	var channel *Channel
137
+	cfname, err := CasefoldChannel(channelName)
138
+	if err != nil {
139
+		return err
140
+	}
141
+
142
+	var entry *channelManagerEntry
143
+
144
+	defer func() {
145
+		if err == nil && channel != nil {
146
+			// registration was successful: make the database reflect it
147
+			err = channel.Store(IncludeAllChannelAttrs)
148
+		}
149
+	}()
150
+
151
+	cm.Lock()
152
+	defer cm.Unlock()
153
+	entry = cm.chans[cfname]
154
+	if entry == nil {
155
+		return errNoSuchChannel
156
+	}
157
+	channel = entry.channel
158
+	err = channel.SetRegistered(account)
159
+	if err != nil {
160
+		return err
161
+	}
162
+	cm.registeredChannels[cfname] = true
163
+	return nil
164
+}
165
+
166
+func (cm *ChannelManager) SetUnregistered(channelName string, account string) (err error) {
167
+	cfname, err := CasefoldChannel(channelName)
168
+	if err != nil {
169
+		return err
170
+	}
171
+
172
+	var info RegisteredChannel
173
+
174
+	defer func() {
175
+		if err == nil {
176
+			err = cm.server.channelRegistry.Delete(info)
177
+		}
178
+	}()
179
+
180
+	cm.Lock()
181
+	defer cm.Unlock()
182
+	entry := cm.chans[cfname]
183
+	if entry == nil {
184
+		return errNoSuchChannel
185
+	}
186
+	info = entry.channel.ExportRegistration(0)
187
+	if info.Founder != account {
188
+		return errChannelNotOwnedByAccount
189
+	}
190
+	entry.channel.SetUnregistered(account)
191
+	delete(cm.registeredChannels, cfname)
192
+	return nil
193
+}
194
+
133 195
 // Rename renames a channel (but does not notify the members)
134
-func (cm *ChannelManager) Rename(name string, newname string) error {
196
+func (cm *ChannelManager) Rename(name string, newname string) (err error) {
135 197
 	cfname, err := CasefoldChannel(name)
136 198
 	if err != nil {
137 199
 		return errNoSuchChannel
@@ -142,6 +204,17 @@ func (cm *ChannelManager) Rename(name string, newname string) error {
142 204
 		return errInvalidChannelName
143 205
 	}
144 206
 
207
+	var channel *Channel
208
+	var info RegisteredChannel
209
+	defer func() {
210
+		if channel != nil && info.Founder != "" {
211
+			channel.Store(IncludeAllChannelAttrs)
212
+			// we just flushed the channel under its new name, therefore this delete
213
+			// cannot be overwritten by a write to the old name:
214
+			cm.server.channelRegistry.Delete(info)
215
+		}
216
+	}()
217
+
145 218
 	cm.Lock()
146 219
 	defer cm.Unlock()
147 220
 
@@ -152,12 +225,12 @@ func (cm *ChannelManager) Rename(name string, newname string) error {
152 225
 	if entry == nil {
153 226
 		return errNoSuchChannel
154 227
 	}
228
+	channel = entry.channel
229
+	info = channel.ExportRegistration(IncludeInitial)
155 230
 	delete(cm.chans, cfname)
156 231
 	cm.chans[cfnewname] = entry
157
-	entry.channel.setName(newname)
158
-	entry.channel.setNameCasefolded(cfnewname)
232
+	entry.channel.Rename(newname, cfnewname)
159 233
 	return nil
160
-
161 234
 }
162 235
 
163 236
 // Len returns the number of channels
@@ -171,8 +244,11 @@ func (cm *ChannelManager) Len() int {
171 244
 func (cm *ChannelManager) Channels() (result []*Channel) {
172 245
 	cm.RLock()
173 246
 	defer cm.RUnlock()
247
+	result = make([]*Channel, 0, len(cm.chans))
174 248
 	for _, entry := range cm.chans {
175
-		result = append(result, entry.channel)
249
+		if entry.channel.IsLoaded() {
250
+			result = append(result, entry.channel)
251
+		}
176 252
 	}
177 253
 	return
178 254
 }

+ 43
- 59
irc/channelreg.go View File

@@ -7,7 +7,6 @@ import (
7 7
 	"fmt"
8 8
 	"strconv"
9 9
 	"strings"
10
-	"sync"
11 10
 	"time"
12 11
 
13 12
 	"encoding/json"
@@ -71,6 +70,8 @@ const (
71 70
 type RegisteredChannel struct {
72 71
 	// Name of the channel.
73 72
 	Name string
73
+	// Casefolded name of the channel.
74
+	NameCasefolded string
74 75
 	// RegisteredAt represents the time that the channel was registered.
75 76
 	RegisteredAt time.Time
76 77
 	// Founder indicates the founder of the channel.
@@ -97,58 +98,65 @@ type RegisteredChannel struct {
97 98
 
98 99
 // ChannelRegistry manages registered channels.
99 100
 type ChannelRegistry struct {
100
-	// This serializes operations of the form (read channel state, synchronously persist it);
101
-	// this is enough to guarantee eventual consistency of the database with the
102
-	// ChannelManager and Channel objects, which are the source of truth.
103
-	//
104
-	// We could use the buntdb RW transaction lock for this purpose but we share
105
-	// that with all the other modules, so let's not.
106
-	sync.Mutex // tier 2
107
-	server     *Server
101
+	server *Server
108 102
 }
109 103
 
110 104
 // NewChannelRegistry returns a new ChannelRegistry.
111
-func NewChannelRegistry(server *Server) *ChannelRegistry {
112
-	return &ChannelRegistry{
113
-		server: server,
114
-	}
105
+func (reg *ChannelRegistry) Initialize(server *Server) {
106
+	reg.server = server
107
+}
108
+
109
+func (reg *ChannelRegistry) AllChannels() (result map[string]bool) {
110
+	result = make(map[string]bool)
111
+
112
+	prefix := fmt.Sprintf(keyChannelExists, "")
113
+	reg.server.store.View(func(tx *buntdb.Tx) error {
114
+		return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool {
115
+			if !strings.HasPrefix(key, prefix) {
116
+				return false
117
+			}
118
+			channel := strings.TrimPrefix(key, prefix)
119
+			result[channel] = true
120
+			return true
121
+		})
122
+	})
123
+
124
+	return
115 125
 }
116 126
 
117 127
 // StoreChannel obtains a consistent view of a channel, then persists it to the store.
118
-func (reg *ChannelRegistry) StoreChannel(channel *Channel, includeFlags uint) {
128
+func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) {
119 129
 	if !reg.server.ChannelRegistrationEnabled() {
120 130
 		return
121 131
 	}
122 132
 
123
-	reg.Lock()
124
-	defer reg.Unlock()
125
-
126
-	key := channel.NameCasefolded()
127
-	info := channel.ExportRegistration(includeFlags)
128 133
 	if info.Founder == "" {
129 134
 		// sanity check, don't try to store an unregistered channel
130 135
 		return
131 136
 	}
132 137
 
133 138
 	reg.server.store.Update(func(tx *buntdb.Tx) error {
134
-		reg.saveChannel(tx, key, info, includeFlags)
139
+		reg.saveChannel(tx, info, includeFlags)
135 140
 		return nil
136 141
 	})
142
+
143
+	return nil
137 144
 }
138 145
 
139 146
 // LoadChannel loads a channel from the store.
140
-func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *RegisteredChannel) {
147
+func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) {
141 148
 	if !reg.server.ChannelRegistrationEnabled() {
142
-		return nil
149
+		err = errFeatureDisabled
150
+		return
143 151
 	}
144 152
 
145 153
 	channelKey := nameCasefolded
146 154
 	// nice to have: do all JSON (de)serialization outside of the buntdb transaction
147
-	reg.server.store.View(func(tx *buntdb.Tx) error {
148
-		_, err := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
149
-		if err == buntdb.ErrNotFound {
155
+	err = reg.server.store.View(func(tx *buntdb.Tx) error {
156
+		_, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey))
157
+		if dberr == buntdb.ErrNotFound {
150 158
 			// chan does not already exist, return
151
-			return nil
159
+			return errNoSuchChannel
152 160
 		}
153 161
 
154 162
 		// channel exists, load it
@@ -181,7 +189,7 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
181 189
 		accountToUMode := make(map[string]modes.Mode)
182 190
 		_ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode)
183 191
 
184
-		info = &RegisteredChannel{
192
+		info = RegisteredChannel{
185 193
 			Name:           name,
186 194
 			RegisteredAt:   time.Unix(regTimeInt, 0),
187 195
 			Founder:        founder,
@@ -198,46 +206,21 @@ func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info *Registered
198 206
 		return nil
199 207
 	})
200 208
 
201
-	return info
202
-}
203
-
204
-func (reg *ChannelRegistry) Delete(casefoldedName string, info RegisteredChannel) {
205
-	if !reg.server.ChannelRegistrationEnabled() {
206
-		return
207
-	}
208
-
209
-	reg.Lock()
210
-	defer reg.Unlock()
211
-
212
-	reg.server.store.Update(func(tx *buntdb.Tx) error {
213
-		reg.deleteChannel(tx, casefoldedName, info)
214
-		return nil
215
-	})
209
+	return
216 210
 }
217 211
 
218
-// Rename handles the persistence part of a channel rename: the channel is
219
-// persisted under its new name, and the old name is cleaned up if necessary.
220
-func (reg *ChannelRegistry) Rename(channel *Channel, casefoldedOldName string) {
212
+// Delete deletes a channel corresponding to `info`. If no such channel
213
+// is present in the database, no error is returned.
214
+func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) {
221 215
 	if !reg.server.ChannelRegistrationEnabled() {
222 216
 		return
223 217
 	}
224 218
 
225
-	reg.Lock()
226
-	defer reg.Unlock()
227
-
228
-	includeFlags := IncludeAllChannelAttrs
229
-	oldKey := casefoldedOldName
230
-	key := channel.NameCasefolded()
231
-	info := channel.ExportRegistration(includeFlags)
232
-	if info.Founder == "" {
233
-		return
234
-	}
235
-
236 219
 	reg.server.store.Update(func(tx *buntdb.Tx) error {
237
-		reg.deleteChannel(tx, oldKey, info)
238
-		reg.saveChannel(tx, key, info, includeFlags)
220
+		reg.deleteChannel(tx, info.NameCasefolded, info)
239 221
 		return nil
240 222
 	})
223
+	return nil
241 224
 }
242 225
 
243 226
 // delete a channel, unless it was overwritten by another registration of the same channel
@@ -274,7 +257,8 @@ func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info Regist
274 257
 }
275 258
 
276 259
 // saveChannel saves a channel to the store.
277
-func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelKey string, channelInfo RegisteredChannel, includeFlags uint) {
260
+func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) {
261
+	channelKey := channelInfo.NameCasefolded
278 262
 	// maintain the mapping of account -> registered channels
279 263
 	chanExistsKey := fmt.Sprintf(keyChannelExists, channelKey)
280 264
 	_, existsErr := tx.Get(chanExistsKey)

+ 2
- 6
irc/chanserv.go View File

@@ -232,15 +232,12 @@ func csRegisterHandler(server *Server, client *Client, command string, params []
232 232
 	}
233 233
 
234 234
 	// this provides the synchronization that allows exactly one registration of the channel:
235
-	err = channelInfo.SetRegistered(account)
235
+	err = server.channels.SetRegistered(channelKey, account)
236 236
 	if err != nil {
237 237
 		csNotice(rb, err.Error())
238 238
 		return
239 239
 	}
240 240
 
241
-	// registration was successful: make the database reflect it
242
-	go server.channelRegistry.StoreChannel(channelInfo, IncludeAllChannelAttrs)
243
-
244 241
 	csNotice(rb, fmt.Sprintf(client.t("Channel %s successfully registered"), channelName))
245 242
 
246 243
 	server.logger.Info("services", fmt.Sprintf("Client %s registered channel %s", client.nick, channelName))
@@ -297,8 +294,7 @@ func csUnregisterHandler(server *Server, client *Client, command string, params
297 294
 		return
298 295
 	}
299 296
 
300
-	channel.SetUnregistered(founder)
301
-	server.channelRegistry.Delete(channelKey, info)
297
+	server.channels.SetUnregistered(channelKey, founder)
302 298
 	csNotice(rb, fmt.Sprintf(client.t("Channel %s is now unregistered"), channelKey))
303 299
 }
304 300
 

+ 10
- 13
irc/client.go View File

@@ -50,7 +50,7 @@ type Client struct {
50 50
 	accountName        string // display name of the account: uncasefolded, '*' if not logged in
51 51
 	atime              time.Time
52 52
 	awayMessage        string
53
-	capabilities       *caps.Set
53
+	capabilities       caps.Set
54 54
 	capState           caps.State
55 55
 	capVersion         caps.Version
56 56
 	certfp             string
@@ -58,7 +58,7 @@ type Client struct {
58 58
 	ctime              time.Time
59 59
 	exitedSnomaskSent  bool
60 60
 	fakelag            Fakelag
61
-	flags              *modes.ModeSet
61
+	flags              modes.ModeSet
62 62
 	hasQuit            bool
63 63
 	hops               int
64 64
 	hostname           string
@@ -125,15 +125,13 @@ func RunNewClient(server *Server, conn clientConn) {
125 125
 	// give them 1k of grace over the limit:
126 126
 	socket := NewSocket(conn.Conn, fullLineLenLimit+1024, config.Server.MaxSendQBytes)
127 127
 	client := &Client{
128
-		atime:        now,
129
-		capabilities: caps.NewSet(),
130
-		capState:     caps.NoneState,
131
-		capVersion:   caps.Cap301,
132
-		channels:     make(ChannelSet),
133
-		ctime:        now,
134
-		flags:        modes.NewModeSet(),
135
-		isTor:        conn.IsTor,
136
-		languages:    server.Languages().Default(),
128
+		atime:      now,
129
+		capState:   caps.NoneState,
130
+		capVersion: caps.Cap301,
131
+		channels:   make(ChannelSet),
132
+		ctime:      now,
133
+		isTor:      conn.IsTor,
134
+		languages:  server.Languages().Default(),
137 135
 		loginThrottle: connection_limits.GenericThrottle{
138 136
 			Duration: config.Accounts.LoginThrottling.Duration,
139 137
 			Limit:    config.Accounts.LoginThrottling.MaxAttempts,
@@ -546,7 +544,6 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I
546 544
 // copy applicable state from oldClient to client as part of a resume
547 545
 func (client *Client) copyResumeData(oldClient *Client) {
548 546
 	oldClient.stateMutex.RLock()
549
-	flags := oldClient.flags
550 547
 	history := oldClient.history
551 548
 	nick := oldClient.nick
552 549
 	nickCasefolded := oldClient.nickCasefolded
@@ -560,7 +557,7 @@ func (client *Client) copyResumeData(oldClient *Client) {
560 557
 	// resume over plaintext)
561 558
 	hasTLS := client.flags.HasMode(modes.TLS)
562 559
 	temp := modes.NewModeSet()
563
-	temp.Copy(flags)
560
+	temp.Copy(&oldClient.flags)
564 561
 	temp.SetMode(modes.TLS, hasTLS)
565 562
 	client.flags.Copy(temp)
566 563
 

+ 4
- 6
irc/client_lookup_set.go View File

@@ -37,12 +37,10 @@ type ClientManager struct {
37 37
 	bySkeleton   map[string]*Client
38 38
 }
39 39
 
40
-// NewClientManager returns a new ClientManager.
41
-func NewClientManager() *ClientManager {
42
-	return &ClientManager{
43
-		byNick:     make(map[string]*Client),
44
-		bySkeleton: make(map[string]*Client),
45
-	}
40
+// Initialize initializes a ClientManager.
41
+func (clients *ClientManager) Initialize() {
42
+	clients.byNick = make(map[string]*Client)
43
+	clients.bySkeleton = make(map[string]*Client)
46 44
 }
47 45
 
48 46
 // Count returns how many clients are in the manager.

+ 2
- 0
irc/errors.go View File

@@ -27,6 +27,8 @@ var (
27 27
 	errAccountMustHoldNick            = errors.New(`You must hold that nickname in order to register it`)
28 28
 	errCallbackFailed                 = errors.New("Account verification could not be sent")
29 29
 	errCertfpAlreadyExists            = errors.New(`An account already exists for your certificate fingerprint`)
30
+	errChannelNotOwnedByAccount       = errors.New("Channel not owned by the specified account")
31
+	errChannelDoesNotExist            = errors.New("Channel does not exist")
30 32
 	errChannelAlreadyRegistered       = errors.New("Channel is already registered")
31 33
 	errChannelNameInUse               = errors.New(`Channel name in use`)
32 34
 	errInvalidChannelName             = errors.New(`Invalid channel name`)

+ 15
- 8
irc/getters.go View File

@@ -4,6 +4,8 @@
4 4
 package irc
5 5
 
6 6
 import (
7
+	"time"
8
+
7 9
 	"github.com/oragono/oragono/irc/isupport"
8 10
 	"github.com/oragono/oragono/irc/languages"
9 11
 	"github.com/oragono/oragono/irc/modes"
@@ -267,22 +269,20 @@ func (channel *Channel) Name() string {
267 269
 	return channel.name
268 270
 }
269 271
 
270
-func (channel *Channel) setName(name string) {
271
-	channel.stateMutex.Lock()
272
-	defer channel.stateMutex.Unlock()
273
-	channel.name = name
274
-}
275
-
276 272
 func (channel *Channel) NameCasefolded() string {
277 273
 	channel.stateMutex.RLock()
278 274
 	defer channel.stateMutex.RUnlock()
279 275
 	return channel.nameCasefolded
280 276
 }
281 277
 
282
-func (channel *Channel) setNameCasefolded(nameCasefolded string) {
278
+func (channel *Channel) Rename(name, nameCasefolded string) {
283 279
 	channel.stateMutex.Lock()
284
-	defer channel.stateMutex.Unlock()
280
+	channel.name = name
285 281
 	channel.nameCasefolded = nameCasefolded
282
+	if channel.registeredFounder != "" {
283
+		channel.registeredTime = time.Now()
284
+	}
285
+	channel.stateMutex.Unlock()
286 286
 }
287 287
 
288 288
 func (channel *Channel) Members() (result []*Client) {
@@ -314,3 +314,10 @@ func (channel *Channel) Founder() string {
314 314
 	defer channel.stateMutex.RUnlock()
315 315
 	return channel.registeredFounder
316 316
 }
317
+
318
+func (channel *Channel) DirtyBits() (dirtyBits uint) {
319
+	channel.stateMutex.Lock()
320
+	dirtyBits = channel.dirtyBits
321
+	channel.stateMutex.Unlock()
322
+	return
323
+}

+ 2
- 6
irc/handlers.go View File

@@ -1607,8 +1607,8 @@ func cmodeHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res
1607 1607
 		}
1608 1608
 	}
1609 1609
 
1610
-	if channel.IsRegistered() && includeFlags != 0 {
1611
-		go server.channelRegistry.StoreChannel(channel, includeFlags)
1610
+	if includeFlags != 0 {
1611
+		channel.MarkDirty(includeFlags)
1612 1612
 	}
1613 1613
 
1614 1614
 	// send out changes
@@ -2215,7 +2215,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
2215 2215
 		rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), oldName, client.t("No such channel"))
2216 2216
 		return false
2217 2217
 	}
2218
-	casefoldedOldName := channel.NameCasefolded()
2219 2218
 	if !(channel.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("chanreg")) {
2220 2219
 		rb.Add(nil, server.name, ERR_CHANOPRIVSNEEDED, client.Nick(), oldName, client.t("You're not a channel operator"))
2221 2220
 		return false
@@ -2240,9 +2239,6 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Re
2240 2239
 		return false
2241 2240
 	}
2242 2241
 
2243
-	// rename succeeded, persist it
2244
-	go server.channelRegistry.Rename(channel, casefoldedOldName)
2245
-
2246 2242
 	// send RENAME messages
2247 2243
 	clientPrefix := client.NickMaskString()
2248 2244
 	for _, mcl := range channel.Members() {

+ 2
- 2
irc/modes.go View File

@@ -290,14 +290,14 @@ func (channel *Channel) ProcessAccountToUmodeChange(client *Client, change modes
290 290
 	case modes.Add:
291 291
 		if targetModeNow != targetModeAfter {
292 292
 			channel.accountToUMode[change.Arg] = change.Mode
293
-			go client.server.channelRegistry.StoreChannel(channel, IncludeLists)
293
+			channel.MarkDirty(IncludeLists)
294 294
 			return []modes.ModeChange{change}, nil
295 295
 		}
296 296
 		return nil, nil
297 297
 	case modes.Remove:
298 298
 		if targetModeNow == change.Mode {
299 299
 			delete(channel.accountToUMode, change.Arg)
300
-			go client.server.channelRegistry.StoreChannel(channel, IncludeLists)
300
+			channel.MarkDirty(IncludeLists)
301 301
 			return []modes.ModeChange{change}, nil
302 302
 		}
303 303
 		return nil, nil

+ 0
- 1
irc/modes/modes.go View File

@@ -335,7 +335,6 @@ const (
335 335
 // returns a pointer to a new ModeSet
336 336
 func NewModeSet() *ModeSet {
337 337
 	var set ModeSet
338
-	utils.BitsetInitialize(set[:])
339 338
 	return &set
340 339
 }
341 340
 

+ 3
- 4
irc/semaphores.go View File

@@ -32,14 +32,13 @@ type ServerSemaphores struct {
32 32
 	ClientDestroy Semaphore
33 33
 }
34 34
 
35
-// NewServerSemaphores creates a new ServerSemaphores.
36
-func NewServerSemaphores() (result *ServerSemaphores) {
35
+// Initialize initializes a set of server semaphores.
36
+func (serversem *ServerSemaphores) Initialize() {
37 37
 	capacity := runtime.NumCPU()
38 38
 	if capacity > MaxServerSemaphoreCapacity {
39 39
 		capacity = MaxServerSemaphoreCapacity
40 40
 	}
41
-	result = new(ServerSemaphores)
42
-	result.ClientDestroy.Initialize(capacity)
41
+	serversem.ClientDestroy.Initialize(capacity)
43 42
 	return
44 43
 }
45 44
 

+ 19
- 15
irc/server.go View File

@@ -61,10 +61,10 @@ type ListenerWrapper struct {
61 61
 
62 62
 // Server is the main Oragono server.
63 63
 type Server struct {
64
-	accounts               *AccountManager
65
-	channels               *ChannelManager
66
-	channelRegistry        *ChannelRegistry
67
-	clients                *ClientManager
64
+	accounts               AccountManager
65
+	channels               ChannelManager
66
+	channelRegistry        ChannelRegistry
67
+	clients                ClientManager
68 68
 	config                 *Config
69 69
 	configFilename         string
70 70
 	configurableStateMutex sync.RWMutex // tier 1; generic protection for server state modified by rehash()
@@ -89,9 +89,9 @@ type Server struct {
89 89
 	snomasks               *SnoManager
90 90
 	store                  *buntdb.DB
91 91
 	torLimiter             connection_limits.TorLimiter
92
-	whoWas                 *WhoWasList
93
-	stats                  *Stats
94
-	semaphores             *ServerSemaphores
92
+	whoWas                 WhoWasList
93
+	stats                  Stats
94
+	semaphores             ServerSemaphores
95 95
 }
96 96
 
97 97
 var (
@@ -113,8 +113,6 @@ type clientConn struct {
113 113
 func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
114 114
 	// initialize data structures
115 115
 	server := &Server{
116
-		channels:            NewChannelManager(),
117
-		clients:             NewClientManager(),
118 116
 		connectionLimiter:   connection_limits.NewLimiter(),
119 117
 		connectionThrottler: connection_limits.NewThrottler(),
120 118
 		listeners:           make(map[string]*ListenerWrapper),
@@ -123,12 +121,12 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
123 121
 		rehashSignal:        make(chan os.Signal, 1),
124 122
 		signals:             make(chan os.Signal, len(ServerExitSignals)),
125 123
 		snomasks:            NewSnoManager(),
126
-		whoWas:              NewWhoWasList(config.Limits.WhowasEntries),
127
-		stats:               NewStats(),
128
-		semaphores:          NewServerSemaphores(),
129 124
 	}
130 125
 
126
+	server.clients.Initialize()
127
+	server.semaphores.Initialize()
131 128
 	server.resumeManager.Initialize(server)
129
+	server.whoWas.Initialize(config.Limits.WhowasEntries)
132 130
 
133 131
 	if err := server.applyConfig(config, true); err != nil {
134 132
 		return nil, err
@@ -697,6 +695,12 @@ func (server *Server) applyConfig(config *Config, initial bool) (err error) {
697 695
 		server.accounts.initVHostRequestQueue()
698 696
 	}
699 697
 
698
+	chanRegPreviouslyDisabled := oldConfig != nil && !oldConfig.Channels.Registration.Enabled
699
+	chanRegNowEnabled := config.Channels.Registration.Enabled
700
+	if chanRegPreviouslyDisabled && chanRegNowEnabled {
701
+		server.channels.loadRegisteredChannels()
702
+	}
703
+
700 704
 	// MaxLine
701 705
 	if config.Limits.LineLen.Rest != 512 {
702 706
 		SupportedCapabilities.Enable(caps.MaxLine)
@@ -922,9 +926,9 @@ func (server *Server) loadDatastore(config *Config) error {
922 926
 	server.loadDLines()
923 927
 	server.loadKLines()
924 928
 
925
-	server.channelRegistry = NewChannelRegistry(server)
926
-
927
-	server.accounts = NewAccountManager(server)
929
+	server.channelRegistry.Initialize(server)
930
+	server.channels.Initialize(server)
931
+	server.accounts.Initialize(server)
928 932
 
929 933
 	return nil
930 934
 }

+ 0
- 11
irc/stats.go View File

@@ -13,17 +13,6 @@ type Stats struct {
13 13
 	Operators int
14 14
 }
15 15
 
16
-// NewStats creates a new instance of Stats
17
-func NewStats() *Stats {
18
-	serverStats := &Stats{
19
-		Total:     0,
20
-		Invisible: 0,
21
-		Operators: 0,
22
-	}
23
-
24
-	return serverStats
25
-}
26
-
27 16
 // ChangeTotal increments the total user count on server
28 17
 func (s *Stats) ChangeTotal(i int) {
29 18
 	s.Lock()

+ 0
- 11
irc/utils/bitset.go View File

@@ -9,17 +9,6 @@ import "sync/atomic"
9 9
 // For examples of use, see caps.Set and modes.ModeSet; the array has to be converted to a
10 10
 // slice to use these functions.
11 11
 
12
-// BitsetInitialize initializes a bitset.
13
-func BitsetInitialize(set []uint64) {
14
-	// XXX re-zero the bitset using atomic stores. it's unclear whether this is required,
15
-	// however, golang issue #5045 suggests that you shouldn't mix atomic operations
16
-	// with non-atomic operations (such as the runtime's automatic zero-initialization) on
17
-	// the same word
18
-	for i := 0; i < len(set); i++ {
19
-		atomic.StoreUint64(&set[i], 0)
20
-	}
21
-}
22
-
23 12
 // BitsetGet returns whether a given bit of the bitset is set.
24 13
 func BitsetGet(set []uint64, position uint) bool {
25 14
 	idx := position / 64

+ 0
- 2
irc/utils/bitset_test.go View File

@@ -10,7 +10,6 @@ type testBitset [2]uint64
10 10
 func TestSets(t *testing.T) {
11 11
 	var t1 testBitset
12 12
 	t1s := t1[:]
13
-	BitsetInitialize(t1s)
14 13
 
15 14
 	if BitsetGet(t1s, 0) || BitsetGet(t1s, 63) || BitsetGet(t1s, 64) || BitsetGet(t1s, 127) {
16 15
 		t.Error("no bits should be set in a newly initialized bitset")
@@ -47,7 +46,6 @@ func TestSets(t *testing.T) {
47 46
 
48 47
 	var t2 testBitset
49 48
 	t2s := t2[:]
50
-	BitsetInitialize(t2s)
51 49
 
52 50
 	for i = 0; i < 128; i++ {
53 51
 		if i%2 == 1 {

+ 35
- 0
irc/utils/sync.go View File

@@ -0,0 +1,35 @@
1
+// Copyright 2009 The Go Authors. All rights reserved.
2
+// Use of this source code is governed by a BSD-style
3
+// license that can be found in the LICENSE file.
4
+
5
+package utils
6
+
7
+import (
8
+	"sync"
9
+	"sync/atomic"
10
+)
11
+
12
+// Once is a fork of sync.Once to expose a Done() method.
13
+type Once struct {
14
+	done uint32
15
+	m    sync.Mutex
16
+}
17
+
18
+func (o *Once) Do(f func()) {
19
+	if atomic.LoadUint32(&o.done) == 0 {
20
+		o.doSlow(f)
21
+	}
22
+}
23
+
24
+func (o *Once) doSlow(f func()) {
25
+	o.m.Lock()
26
+	defer o.m.Unlock()
27
+	if o.done == 0 {
28
+		defer atomic.StoreUint32(&o.done, 1)
29
+		f()
30
+	}
31
+}
32
+
33
+func (o *Once) Done() bool {
34
+	return atomic.LoadUint32(&o.done) == 1
35
+}

+ 4
- 6
irc/whowas.go View File

@@ -23,12 +23,10 @@ type WhoWasList struct {
23 23
 }
24 24
 
25 25
 // NewWhoWasList returns a new WhoWasList
26
-func NewWhoWasList(size int) *WhoWasList {
27
-	return &WhoWasList{
28
-		buffer: make([]WhoWas, size),
29
-		start:  -1,
30
-		end:    -1,
31
-	}
26
+func (list *WhoWasList) Initialize(size int) {
27
+	list.buffer = make([]WhoWas, size)
28
+	list.start = -1
29
+	list.end = -1
32 30
 }
33 31
 
34 32
 // Append adds an entry to the WhoWasList.

+ 4
- 2
irc/whowas_test.go View File

@@ -23,7 +23,8 @@ func makeTestWhowas(nick string) WhoWas {
23 23
 
24 24
 func TestWhoWas(t *testing.T) {
25 25
 	var results []WhoWas
26
-	wwl := NewWhoWasList(3)
26
+	var wwl WhoWasList
27
+	wwl.Initialize(3)
27 28
 	// test Find on empty list
28 29
 	results = wwl.Find("nobody", 10)
29 30
 	if len(results) != 0 {
@@ -88,7 +89,8 @@ func TestWhoWas(t *testing.T) {
88 89
 
89 90
 func TestEmptyWhoWas(t *testing.T) {
90 91
 	// stupid edge case; setting an empty whowas buffer should not panic
91
-	wwl := NewWhoWasList(0)
92
+	var wwl WhoWasList
93
+	wwl.Initialize(0)
92 94
 	results := wwl.Find("slingamn", 10)
93 95
 	if len(results) != 0 {
94 96
 		t.Fatalf("incorrect whowas results: %v", results)

Loading…
Cancel
Save