瀏覽代碼

caps: Move most capability-handling types into the caps package

tags/v0.9.2-alpha
Daniel Oaks 6 年之前
父節點
當前提交
275449e6cc
共有 10 個檔案被更改,包括 259 行新增138 行删除
  1. 12
    60
      irc/capability.go
  2. 2
    1
      irc/caps/constants.go
  3. 115
    0
      irc/caps/set.go
  4. 42
    0
      irc/caps/values.go
  5. 11
    0
      irc/caps/version.go
  6. 10
    10
      irc/channel.go
  7. 16
    16
      irc/client.go
  8. 3
    3
      irc/client_lookup_set.go
  9. 2
    2
      irc/roleplay.go
  10. 46
    46
      irc/server.go

+ 12
- 60
irc/capability.go 查看文件

@@ -13,28 +13,12 @@ import (
13 13
 
14 14
 var (
15 15
 	// SupportedCapabilities are the caps we advertise.
16
-	SupportedCapabilities = CapabilitySet{
17
-		caps.AccountTag:    true,
18
-		caps.AccountNotify: true,
19
-		caps.AwayNotify:    true,
20
-		caps.CapNotify:     true,
21
-		caps.ChgHost:       true,
22
-		caps.EchoMessage:   true,
23
-		caps.ExtendedJoin:  true,
24
-		caps.InviteNotify:  true,
25
-		// MaxLine is set during server startup
26
-		caps.MessageTags: true,
27
-		caps.MultiPrefix: true,
28
-		caps.Rename:      true,
29
-		// SASL is set during server startup
30
-		caps.ServerTime: true,
31
-		// STS is set during server startup
32
-		caps.UserhostInNames: true,
33
-	}
16
+	// MaxLine, SASL and STS are set during server startup.
17
+	SupportedCapabilities = caps.NewSet(caps.AccountTag, caps.AccountNotify, caps.AwayNotify, caps.CapNotify, caps.ChgHost, caps.EchoMessage, caps.ExtendedJoin, caps.InviteNotify, caps.MessageTags, caps.MultiPrefix, caps.Rename, caps.ServerTime, caps.UserhostInNames)
18
+
34 19
 	// CapValues are the actual values we advertise to v3.2 clients.
35
-	CapValues = map[caps.Capability]string{
36
-		caps.SASL: "PLAIN,EXTERNAL",
37
-	}
20
+	// actual values are set during server startup.
21
+	CapValues = caps.NewValues()
38 22
 )
39 23
 
40 24
 // CapState shows whether we're negotiating caps, finished, etc for connection registration.
@@ -49,40 +33,10 @@ const (
49 33
 	CapNegotiated CapState = iota
50 34
 )
51 35
 
52
-// CapVersion is used to select which max version of CAP the client supports.
53
-type CapVersion uint
54
-
55
-const (
56
-	// Cap301 refers to the base CAP spec.
57
-	Cap301 CapVersion = 301
58
-	// Cap302 refers to the IRCv3.2 CAP spec.
59
-	Cap302 CapVersion = 302
60
-)
61
-
62
-// CapabilitySet is used to track supported, enabled, and existing caps.
63
-type CapabilitySet map[caps.Capability]bool
64
-
65
-func (set CapabilitySet) String(version CapVersion) string {
66
-	strs := make([]string, len(set))
67
-	index := 0
68
-	for capability := range set {
69
-		capString := string(capability)
70
-		if version == Cap302 {
71
-			val, exists := CapValues[capability]
72
-			if exists {
73
-				capString += "=" + val
74
-			}
75
-		}
76
-		strs[index] = capString
77
-		index++
78
-	}
79
-	return strings.Join(strs, " ")
80
-}
81
-
82 36
 // CAP <subcmd> [<caps>]
83 37
 func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
84 38
 	subCommand := strings.ToUpper(msg.Params[0])
85
-	capabilities := make(CapabilitySet)
39
+	capabilities := caps.NewSet()
86 40
 	var capString string
87 41
 
88 42
 	if len(msg.Params) > 1 {
@@ -90,7 +44,7 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
90 44
 		strs := strings.Split(capString, " ")
91 45
 		for _, str := range strs {
92 46
 			if len(str) > 0 {
93
-				capabilities[caps.Capability(str)] = true
47
+				capabilities.Enable(caps.Capability(str))
94 48
 			}
95 49
 		}
96 50
 	}
@@ -107,22 +61,20 @@ func capHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
107 61
 		// the server.name source... otherwise it doesn't respond to the CAP message with
108 62
 		// anything and just hangs on connection.
109 63
 		//TODO(dan): limit number of caps and send it multiline in 3.2 style as appropriate.
110
-		client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion))
64
+		client.Send(nil, server.name, "CAP", client.nick, subCommand, SupportedCapabilities.String(client.capVersion, CapValues))
111 65
 
112 66
 	case "LIST":
113
-		client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(Cap301)) // values not sent on LIST so force 3.1
67
+		client.Send(nil, server.name, "CAP", client.nick, subCommand, client.capabilities.String(caps.Cap301, CapValues)) // values not sent on LIST so force 3.1
114 68
 
115 69
 	case "REQ":
116 70
 		// make sure all capabilities actually exist
117
-		for capability := range capabilities {
118
-			if !SupportedCapabilities[capability] {
71
+		for _, capability := range capabilities.List() {
72
+			if !SupportedCapabilities.Has(capability) {
119 73
 				client.Send(nil, server.name, "CAP", client.nick, "NAK", capString)
120 74
 				return false
121 75
 			}
122 76
 		}
123
-		for capability := range capabilities {
124
-			client.capabilities[capability] = true
125
-		}
77
+		client.capabilities.Enable(capabilities.List()...)
126 78
 		client.Send(nil, server.name, "CAP", client.nick, "ACK", capString)
127 79
 
128 80
 	case "END":

+ 2
- 1
irc/caps/constants.go 查看文件

@@ -45,6 +45,7 @@ const (
45 45
 	UserhostInNames Capability = "userhost-in-names"
46 46
 )
47 47
 
48
-func (capability Capability) String() string {
48
+// Name returns the name of the given capability.
49
+func (capability Capability) Name() string {
49 50
 	return string(capability)
50 51
 }

+ 115
- 0
irc/caps/set.go 查看文件

@@ -0,0 +1,115 @@
1
+// Package caps holds capabilities.
2
+package caps
3
+
4
+import (
5
+	"sort"
6
+	"strings"
7
+	"sync"
8
+)
9
+
10
+// Set holds a set of enabled capabilities.
11
+type Set struct {
12
+	sync.RWMutex
13
+	// capabilities holds the capabilities this manager has.
14
+	capabilities map[Capability]bool
15
+}
16
+
17
+// NewSet returns a new Set, with the given capabilities enabled.
18
+func NewSet(capabs ...Capability) *Set {
19
+	newSet := Set{
20
+		capabilities: make(map[Capability]bool),
21
+	}
22
+	newSet.Enable(capabs...)
23
+
24
+	return &newSet
25
+}
26
+
27
+// Enable enables the given capabilities.
28
+func (s *Set) Enable(capabs ...Capability) {
29
+	s.Lock()
30
+	defer s.Unlock()
31
+
32
+	for _, capab := range capabs {
33
+		s.capabilities[capab] = true
34
+	}
35
+}
36
+
37
+// Disable disables the given capabilities.
38
+func (s *Set) Disable(capabs ...Capability) {
39
+	s.Lock()
40
+	defer s.Unlock()
41
+
42
+	for _, capab := range capabs {
43
+		delete(s.capabilities, capab)
44
+	}
45
+}
46
+
47
+// Add adds the given capabilities to this set.
48
+// this is just a wrapper to allow more clear use.
49
+func (s *Set) Add(capabs ...Capability) {
50
+	s.Enable(capabs...)
51
+}
52
+
53
+// Remove removes the given capabilities from this set.
54
+// this is just a wrapper to allow more clear use.
55
+func (s *Set) Remove(capabs ...Capability) {
56
+	s.Disable(capabs...)
57
+}
58
+
59
+// Has returns true if this set has the given capabilities.
60
+func (s *Set) Has(caps ...Capability) bool {
61
+	s.RLock()
62
+	defer s.RUnlock()
63
+
64
+	for _, cap := range caps {
65
+		if !s.capabilities[cap] {
66
+			return false
67
+		}
68
+	}
69
+	return true
70
+}
71
+
72
+// List return a list of our enabled capabilities.
73
+func (s *Set) List() []Capability {
74
+	s.RLock()
75
+	defer s.RUnlock()
76
+
77
+	var allCaps []Capability
78
+	for capab := range s.capabilities {
79
+		allCaps = append(allCaps, capab)
80
+	}
81
+
82
+	return allCaps
83
+}
84
+
85
+// Count returns how many enabled caps this set has.
86
+func (s *Set) Count() int {
87
+	s.RLock()
88
+	defer s.RUnlock()
89
+
90
+	return len(s.capabilities)
91
+}
92
+
93
+// String returns all of our enabled capabilities as a string.
94
+func (s *Set) String(version Version, values *Values) string {
95
+	s.RLock()
96
+	defer s.RUnlock()
97
+
98
+	var strs sort.StringSlice
99
+
100
+	for capability := range s.capabilities {
101
+		capString := capability.Name()
102
+		if version == Cap302 {
103
+			val, exists := values.Get(capability)
104
+			if exists {
105
+				capString += "=" + val
106
+			}
107
+		}
108
+		strs = append(strs, capString)
109
+	}
110
+
111
+	// sort the cap string before we send it out
112
+	sort.Sort(strs)
113
+
114
+	return strings.Join(strs, " ")
115
+}

+ 42
- 0
irc/caps/values.go 查看文件

@@ -0,0 +1,42 @@
1
+package caps
2
+
3
+import "sync"
4
+
5
+// Values holds capability values.
6
+type Values struct {
7
+	sync.RWMutex
8
+	// values holds our actual capability values.
9
+	values map[Capability]string
10
+}
11
+
12
+// NewValues returns a new Values.
13
+func NewValues() *Values {
14
+	return &Values{
15
+		values: make(map[Capability]string),
16
+	}
17
+}
18
+
19
+// Set sets the value for the given capability.
20
+func (v *Values) Set(capab Capability, value string) {
21
+	v.Lock()
22
+	defer v.Unlock()
23
+
24
+	v.values[capab] = value
25
+}
26
+
27
+// Unset removes the value for the given capability, if it exists.
28
+func (v *Values) Unset(capab Capability) {
29
+	v.Lock()
30
+	defer v.Unlock()
31
+
32
+	delete(v.values, capab)
33
+}
34
+
35
+// Get returns the value of the given capability, and whether one exists.
36
+func (v *Values) Get(capab Capability) (string, bool) {
37
+	v.RLock()
38
+	defer v.RUnlock()
39
+
40
+	value, exists := v.values[capab]
41
+	return value, exists
42
+}

+ 11
- 0
irc/caps/version.go 查看文件

@@ -0,0 +1,11 @@
1
+package caps
2
+
3
+// Version is used to select which max version of CAP the client supports.
4
+type Version uint
5
+
6
+const (
7
+	// Cap301 refers to the base CAP spec.
8
+	Cap301 Version = 301
9
+	// Cap302 refers to the IRCv3.2 CAP spec.
10
+	Cap302 Version = 302
11
+)

+ 10
- 10
irc/channel.go 查看文件

@@ -165,8 +165,8 @@ func (modes ModeSet) Prefixes(isMultiPrefix bool) string {
165 165
 }
166 166
 
167 167
 func (channel *Channel) nicksNoMutex(target *Client) []string {
168
-	isMultiPrefix := (target != nil) && target.capabilities[caps.MultiPrefix]
169
-	isUserhostInNames := (target != nil) && target.capabilities[caps.UserhostInNames]
168
+	isMultiPrefix := (target != nil) && target.capabilities.Has(caps.MultiPrefix)
169
+	isUserhostInNames := (target != nil) && target.capabilities.Has(caps.UserhostInNames)
170 170
 	nicks := make([]string, len(channel.members))
171 171
 	i := 0
172 172
 	for client, modes := range channel.members {
@@ -262,7 +262,7 @@ func (channel *Channel) Join(client *Client, key string) {
262 262
 	client.server.logger.Debug("join", fmt.Sprintf("%s joined channel %s", client.nick, channel.name))
263 263
 
264 264
 	for member := range channel.members {
265
-		if member.capabilities[caps.ExtendedJoin] {
265
+		if member.capabilities.Has(caps.ExtendedJoin) {
266 266
 			member.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname)
267 267
 		} else {
268 268
 			member.Send(nil, client.nickMaskString, "JOIN", channel.name)
@@ -314,7 +314,7 @@ func (channel *Channel) Join(client *Client, key string) {
314 314
 		return nil
315 315
 	})
316 316
 
317
-	if client.capabilities[caps.ExtendedJoin] {
317
+	if client.capabilities.Has(caps.ExtendedJoin) {
318 318
 		client.Send(nil, client.nickMaskString, "JOIN", channel.name, client.account.Name, client.realname)
319 319
 	} else {
320 320
 		client.Send(nil, client.nickMaskString, "JOIN", channel.name)
@@ -465,13 +465,13 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab
465 465
 			// STATUSMSG
466 466
 			continue
467 467
 		}
468
-		if member == client && !client.capabilities[caps.EchoMessage] {
468
+		if member == client && !client.capabilities.Has(caps.EchoMessage) {
469 469
 			continue
470 470
 		}
471 471
 
472 472
 		canReceive := true
473 473
 		for _, capName := range requiredCaps {
474
-			if !member.capabilities[capName] {
474
+			if !member.capabilities.Has(capName) {
475 475
 				canReceive = false
476 476
 			}
477 477
 		}
@@ -480,7 +480,7 @@ func (channel *Channel) sendMessage(msgid, cmd string, requiredCaps []caps.Capab
480 480
 		}
481 481
 
482 482
 		var messageTagsToUse *map[string]ircmsg.TagValue
483
-		if member.capabilities[caps.MessageTags] {
483
+		if member.capabilities.Has(caps.MessageTags) {
484 484
 			messageTagsToUse = clientOnlyTags
485 485
 		}
486 486
 
@@ -521,11 +521,11 @@ func (channel *Channel) sendSplitMessage(msgid, cmd string, minPrefix *Mode, cli
521 521
 			// STATUSMSG
522 522
 			continue
523 523
 		}
524
-		if member == client && !client.capabilities[caps.EchoMessage] {
524
+		if member == client && !client.capabilities.Has(caps.EchoMessage) {
525 525
 			continue
526 526
 		}
527 527
 		var tagsToUse *map[string]ircmsg.TagValue
528
-		if member.capabilities[caps.MessageTags] {
528
+		if member.capabilities.Has(caps.MessageTags) {
529 529
 			tagsToUse = clientOnlyTags
530 530
 		}
531 531
 
@@ -729,7 +729,7 @@ func (channel *Channel) Invite(invitee *Client, inviter *Client) {
729 729
 
730 730
 	// send invite-notify
731 731
 	for member := range channel.members {
732
-		if member.capabilities[caps.InviteNotify] && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) {
732
+		if member.capabilities.Has(caps.InviteNotify) && member != inviter && member != invitee && channel.ClientIsAtLeast(member, Halfop) {
733 733
 			member.Send(nil, inviter.nickMaskString, "INVITE", invitee.nick, channel.name)
734 734
 		}
735 735
 	}

+ 16
- 16
irc/client.go 查看文件

@@ -45,9 +45,9 @@ type Client struct {
45 45
 	atime              time.Time
46 46
 	authorized         bool
47 47
 	awayMessage        string
48
-	capabilities       CapabilitySet
48
+	capabilities       *caps.Set
49 49
 	capState           CapState
50
-	capVersion         CapVersion
50
+	capVersion         caps.Version
51 51
 	certfp             string
52 52
 	channels           ChannelSet
53 53
 	class              *OperClass
@@ -95,9 +95,9 @@ func NewClient(server *Server, conn net.Conn, isTLS bool) *Client {
95 95
 	client := &Client{
96 96
 		atime:          now,
97 97
 		authorized:     server.password == nil,
98
-		capabilities:   make(CapabilitySet),
98
+		capabilities:   caps.NewSet(),
99 99
 		capState:       CapNone,
100
-		capVersion:     Cap301,
100
+		capVersion:     caps.Cap301,
101 101
 		channels:       make(ChannelSet),
102 102
 		ctime:          now,
103 103
 		flags:          make(map[Mode]bool),
@@ -178,10 +178,10 @@ func (client *Client) IPString() string {
178 178
 func (client *Client) maxlens() (int, int) {
179 179
 	maxlenTags := 512
180 180
 	maxlenRest := 512
181
-	if client.capabilities[caps.MessageTags] {
181
+	if client.capabilities.Has(caps.MessageTags) {
182 182
 		maxlenTags = 4096
183 183
 	}
184
-	if client.capabilities[caps.MaxLine] {
184
+	if client.capabilities.Has(caps.MaxLine) {
185 185
 		if client.server.limits.LineLen.Tags > maxlenTags {
186 186
 			maxlenTags = client.server.limits.LineLen.Tags
187 187
 		}
@@ -357,13 +357,13 @@ func (client *Client) ModeString() (str string) {
357 357
 }
358 358
 
359 359
 // Friends refers to clients that share a channel with this client.
360
-func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet {
360
+func (client *Client) Friends(capabs ...caps.Capability) ClientSet {
361 361
 	friends := make(ClientSet)
362 362
 
363 363
 	// make sure that I have the right caps
364 364
 	hasCaps := true
365
-	for _, Cap := range Capabilities {
366
-		if !client.capabilities[Cap] {
365
+	for _, capab := range capabs {
366
+		if !client.capabilities.Has(capab) {
367 367
 			hasCaps = false
368 368
 			break
369 369
 		}
@@ -377,8 +377,8 @@ func (client *Client) Friends(Capabilities ...caps.Capability) ClientSet {
377 377
 		for member := range channel.members {
378 378
 			// make sure they have all the required caps
379 379
 			hasCaps = true
380
-			for _, Cap := range Capabilities {
381
-				if !member.capabilities[Cap] {
380
+			for _, capab := range capabs {
381
+				if !member.capabilities.Has(capab) {
382 382
 					hasCaps = false
383 383
 					break
384 384
 				}
@@ -580,7 +580,7 @@ func (client *Client) destroy() {
580 580
 // SendSplitMsgFromClient sends an IRC PRIVMSG/NOTICE coming from a specific client.
581 581
 // Adds account-tag to the line as well.
582 582
 func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command, target string, message SplitMessage) {
583
-	if client.capabilities[caps.MaxLine] {
583
+	if client.capabilities.Has(caps.MaxLine) {
584 584
 		client.SendFromClient(msgid, from, tags, command, target, message.ForMaxLine)
585 585
 	} else {
586 586
 		for _, str := range message.For512 {
@@ -593,7 +593,7 @@ func (client *Client) SendSplitMsgFromClient(msgid string, from *Client, tags *m
593 593
 // Adds account-tag to the line as well.
594 594
 func (client *Client) SendFromClient(msgid string, from *Client, tags *map[string]ircmsg.TagValue, command string, params ...string) error {
595 595
 	// attach account-tag
596
-	if client.capabilities[caps.AccountTag] && from.account != &NoAccount {
596
+	if client.capabilities.Has(caps.AccountTag) && from.account != &NoAccount {
597 597
 		if tags == nil {
598 598
 			tags = ircmsg.MakeTags("account", from.account.Name)
599 599
 		} else {
@@ -601,7 +601,7 @@ func (client *Client) SendFromClient(msgid string, from *Client, tags *map[strin
601 601
 		}
602 602
 	}
603 603
 	// attach message-id
604
-	if len(msgid) > 0 && client.capabilities[caps.MessageTags] {
604
+	if len(msgid) > 0 && client.capabilities.Has(caps.MessageTags) {
605 605
 		if tags == nil {
606 606
 			tags = ircmsg.MakeTags("draft/msgid", msgid)
607 607
 		} else {
@@ -628,7 +628,7 @@ var (
628 628
 // Send sends an IRC line to the client.
629 629
 func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, command string, params ...string) error {
630 630
 	// attach server-time
631
-	if client.capabilities[caps.ServerTime] {
631
+	if client.capabilities.Has(caps.ServerTime) {
632 632
 		t := time.Now().UTC().Format("2006-01-02T15:04:05.999Z")
633 633
 		if tags == nil {
634 634
 			tags = ircmsg.MakeTags("time", t)
@@ -678,7 +678,7 @@ func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, comm
678 678
 // Notice sends the client a notice from the server.
679 679
 func (client *Client) Notice(text string) {
680 680
 	limit := 400
681
-	if client.capabilities[caps.MaxLine] {
681
+	if client.capabilities.Has(caps.MaxLine) {
682 682
 		limit = client.server.limits.LineLen.Rest - 110
683 683
 	}
684 684
 	lines := wordWrap(text, limit)

+ 3
- 3
irc/client_lookup_set.go 查看文件

@@ -156,7 +156,7 @@ func (clients *ClientLookupSet) Replace(oldNick, newNick string, client *Client)
156 156
 }
157 157
 
158 158
 // AllWithCaps returns all clients with the given capabilities.
159
-func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set ClientSet) {
159
+func (clients *ClientLookupSet) AllWithCaps(capabs ...caps.Capability) (set ClientSet) {
160 160
 	set = make(ClientSet)
161 161
 
162 162
 	clients.ByNickMutex.RLock()
@@ -164,8 +164,8 @@ func (clients *ClientLookupSet) AllWithCaps(caps ...caps.Capability) (set Client
164 164
 	var client *Client
165 165
 	for _, client = range clients.ByNick {
166 166
 		// make sure they have all the required caps
167
-		for _, Cap := range caps {
168
-			if !client.capabilities[Cap] {
167
+		for _, capab := range capabs {
168
+			if !client.capabilities.Has(capab) {
169 169
 				continue
170 170
 			}
171 171
 		}

+ 2
- 2
irc/roleplay.go 查看文件

@@ -90,7 +90,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt
90 90
 
91 91
 		channel.membersMutex.RLock()
92 92
 		for member := range channel.members {
93
-			if member == client && !client.capabilities[caps.EchoMessage] {
93
+			if member == client && !client.capabilities.Has(caps.EchoMessage) {
94 94
 				continue
95 95
 			}
96 96
 			member.Send(nil, source, "PRIVMSG", channel.name, message)
@@ -110,7 +110,7 @@ func sendRoleplayMessage(server *Server, client *Client, source string, targetSt
110 110
 		}
111 111
 
112 112
 		user.Send(nil, source, "PRIVMSG", user.nick, message)
113
-		if client.capabilities[caps.EchoMessage] {
113
+		if client.capabilities.Has(caps.EchoMessage) {
114 114
 			client.Send(nil, source, "PRIVMSG", user.nick, message)
115 115
 		}
116 116
 		if user.flags[Away] {

+ 46
- 46
irc/server.go 查看文件

@@ -642,11 +642,11 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
642 642
 
643 643
 	// send RENAME messages
644 644
 	for mcl := range channel.members {
645
-		if mcl.capabilities[caps.Rename] {
645
+		if mcl.capabilities.Has(caps.Rename) {
646 646
 			mcl.Send(nil, client.nickMaskString, "RENAME", oldName, newName, reason)
647 647
 		} else {
648 648
 			mcl.Send(nil, mcl.nickMaskString, "PART", oldName, fmt.Sprintf("Channel renamed: %s", reason))
649
-			if mcl.capabilities[caps.ExtendedJoin] {
649
+			if mcl.capabilities.Has(caps.ExtendedJoin) {
650 650
 				accountName := "*"
651 651
 				if mcl.account != nil {
652 652
 					accountName = mcl.account.Name
@@ -825,7 +825,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
825 825
 	message := msg.Params[1]
826 826
 
827 827
 	// split privmsg
828
-	splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine])
828
+	splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine))
829 829
 
830 830
 	for i, targetString := range targets {
831 831
 		// max of four targets per privmsg
@@ -869,7 +869,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
869 869
 				}
870 870
 				continue
871 871
 			}
872
-			if !user.capabilities[caps.MessageTags] {
872
+			if !user.capabilities.Has(caps.MessageTags) {
873 873
 				clientOnlyTags = nil
874 874
 			}
875 875
 			msgid := server.generateMessageID()
@@ -878,7 +878,7 @@ func privmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool
878 878
 			if !user.flags[RegisteredOnly] || client.registered {
879 879
 				user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg)
880 880
 			}
881
-			if client.capabilities[caps.EchoMessage] {
881
+			if client.capabilities.Has(caps.EchoMessage) {
882 882
 				client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "PRIVMSG", user.nick, splitMsg)
883 883
 			}
884 884
 			if user.flags[Away] {
@@ -939,11 +939,11 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
939 939
 			msgid := server.generateMessageID()
940 940
 
941 941
 			// end user can't receive tagmsgs
942
-			if !user.capabilities[caps.MessageTags] {
942
+			if !user.capabilities.Has(caps.MessageTags) {
943 943
 				continue
944 944
 			}
945 945
 			user.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick)
946
-			if client.capabilities[caps.EchoMessage] {
946
+			if client.capabilities.Has(caps.EchoMessage) {
947 947
 				client.SendFromClient(msgid, client, clientOnlyTags, "TAGMSG", user.nick)
948 948
 			}
949 949
 			if user.flags[Away] {
@@ -957,7 +957,7 @@ func tagmsgHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
957 957
 
958 958
 // WhoisChannelsNames returns the common channel names between two users.
959 959
 func (client *Client) WhoisChannelsNames(target *Client) []string {
960
-	isMultiPrefix := target.capabilities[caps.MultiPrefix]
960
+	isMultiPrefix := target.capabilities.Has(caps.MultiPrefix)
961 961
 	var chstrs []string
962 962
 	index := 0
963 963
 	for channel := range client.channels {
@@ -1062,7 +1062,7 @@ func (target *Client) RplWhoReplyNoMutex(channel *Channel, client *Client) {
1062 1062
 	}
1063 1063
 
1064 1064
 	if channel != nil {
1065
-		flags += channel.members[client].Prefixes(target.capabilities[caps.MultiPrefix])
1065
+		flags += channel.members[client].Prefixes(target.capabilities.Has(caps.MultiPrefix))
1066 1066
 		channelName = channel.name
1067 1067
 	}
1068 1068
 	target.Send(nil, target.server.name, RPL_WHOREPLY, target.nick, channelName, client.username, client.hostname, client.server.name, client.nick, flags, strconv.Itoa(client.hops)+" "+client.realname)
@@ -1288,66 +1288,66 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
1288 1288
 	server.connectionLimitsMutex.Unlock()
1289 1289
 
1290 1290
 	// setup new and removed caps
1291
-	addedCaps := make(CapabilitySet)
1292
-	removedCaps := make(CapabilitySet)
1293
-	updatedCaps := make(CapabilitySet)
1291
+	addedCaps := caps.NewSet()
1292
+	removedCaps := caps.NewSet()
1293
+	updatedCaps := caps.NewSet()
1294 1294
 
1295 1295
 	// SASL
1296 1296
 	if config.Accounts.AuthenticationEnabled && !server.accountAuthenticationEnabled {
1297 1297
 		// enabling SASL
1298
-		SupportedCapabilities[caps.SASL] = true
1299
-		addedCaps[caps.SASL] = true
1298
+		SupportedCapabilities.Enable(caps.SASL)
1299
+		CapValues.Set(caps.SASL, "PLAIN,EXTERNAL")
1300
+		addedCaps.Add(caps.SASL)
1300 1301
 	}
1301 1302
 	if !config.Accounts.AuthenticationEnabled && server.accountAuthenticationEnabled {
1302 1303
 		// disabling SASL
1303
-		SupportedCapabilities[caps.SASL] = false
1304
-		removedCaps[caps.SASL] = true
1304
+		SupportedCapabilities.Disable(caps.SASL)
1305
+		removedCaps.Add(caps.SASL)
1305 1306
 	}
1306 1307
 	server.accountAuthenticationEnabled = config.Accounts.AuthenticationEnabled
1307 1308
 
1308 1309
 	// STS
1309 1310
 	stsValue := config.Server.STS.Value()
1310 1311
 	var stsDisabled bool
1311
-	server.logger.Debug("rehash", "STS Vals", CapValues[caps.STS], stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled))
1312
+	stsCurrentCapValue, _ := CapValues.Get(caps.STS)
1313
+	server.logger.Debug("rehash", "STS Vals", stsCurrentCapValue, stsValue, fmt.Sprintf("server[%v] config[%v]", server.stsEnabled, config.Server.STS.Enabled))
1312 1314
 	if config.Server.STS.Enabled && !server.stsEnabled {
1313 1315
 		// enabling STS
1314
-		SupportedCapabilities[caps.STS] = true
1315
-		addedCaps[caps.STS] = true
1316
-		CapValues[caps.STS] = stsValue
1316
+		SupportedCapabilities.Enable(caps.STS)
1317
+		addedCaps.Add(caps.STS)
1318
+		CapValues.Set(caps.STS, stsValue)
1317 1319
 	} else if !config.Server.STS.Enabled && server.stsEnabled {
1318 1320
 		// disabling STS
1319
-		SupportedCapabilities[caps.STS] = false
1320
-		removedCaps[caps.STS] = true
1321
+		SupportedCapabilities.Disable(caps.STS)
1322
+		removedCaps.Add(caps.STS)
1321 1323
 		stsDisabled = true
1322
-	} else if config.Server.STS.Enabled && server.stsEnabled && stsValue != CapValues[caps.STS] {
1324
+	} else if config.Server.STS.Enabled && server.stsEnabled && stsValue != stsCurrentCapValue {
1323 1325
 		// STS policy updated
1324
-		CapValues[caps.STS] = stsValue
1325
-		updatedCaps[caps.STS] = true
1326
+		CapValues.Set(caps.STS, stsValue)
1327
+		updatedCaps.Add(caps.STS)
1326 1328
 	}
1327 1329
 	server.stsEnabled = config.Server.STS.Enabled
1328 1330
 
1329 1331
 	// burst new and removed caps
1330 1332
 	var capBurstClients ClientSet
1331
-	added := make(map[CapVersion]string)
1333
+	added := make(map[caps.Version]string)
1332 1334
 	var removed string
1333 1335
 
1334 1336
 	// updated caps get DEL'd and then NEW'd
1335 1337
 	// so, we can just add updated ones to both removed and added lists here and they'll be correctly handled
1336
-	server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(Cap301), strconv.Itoa(len(updatedCaps)))
1337
-	if len(updatedCaps) > 0 {
1338
-		for capab := range updatedCaps {
1339
-			addedCaps[capab] = true
1340
-			removedCaps[capab] = true
1341
-		}
1338
+	server.logger.Debug("rehash", "Updated Caps", updatedCaps.String(caps.Cap301, CapValues), strconv.Itoa(updatedCaps.Count()))
1339
+	for _, capab := range updatedCaps.List() {
1340
+		addedCaps.Enable(capab)
1341
+		removedCaps.Enable(capab)
1342 1342
 	}
1343 1343
 
1344
-	if len(addedCaps) > 0 || len(removedCaps) > 0 {
1344
+	if 0 < addedCaps.Count() || 0 < removedCaps.Count() {
1345 1345
 		capBurstClients = server.clients.AllWithCaps(caps.CapNotify)
1346 1346
 
1347
-		added[Cap301] = addedCaps.String(Cap301)
1348
-		added[Cap302] = addedCaps.String(Cap302)
1349
-		// removed never has values
1350
-		removed = removedCaps.String(Cap301)
1347
+		added[caps.Cap301] = addedCaps.String(caps.Cap301, CapValues)
1348
+		added[caps.Cap302] = addedCaps.String(caps.Cap302, CapValues)
1349
+		// removed never has values, so we leave it as Cap301
1350
+		removed = removedCaps.String(caps.Cap301, CapValues)
1351 1351
 	}
1352 1352
 
1353 1353
 	for sClient := range capBurstClients {
@@ -1355,18 +1355,18 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
1355 1355
 			// remove STS policy
1356 1356
 			//TODO(dan): this is an ugly hack. we can write this better.
1357 1357
 			stsPolicy := "sts=duration=0"
1358
-			if len(addedCaps) > 0 {
1359
-				added[Cap302] = added[Cap302] + " " + stsPolicy
1358
+			if 0 < addedCaps.Count() {
1359
+				added[caps.Cap302] = added[caps.Cap302] + " " + stsPolicy
1360 1360
 			} else {
1361
-				addedCaps[caps.STS] = true
1362
-				added[Cap302] = stsPolicy
1361
+				addedCaps.Enable(caps.STS)
1362
+				added[caps.Cap302] = stsPolicy
1363 1363
 			}
1364 1364
 		}
1365 1365
 		// DEL caps and then send NEW ones so that updated caps get removed/added correctly
1366
-		if len(removedCaps) > 0 {
1366
+		if 0 < removedCaps.Count() {
1367 1367
 			sClient.Send(nil, server.name, "CAP", sClient.nick, "DEL", removed)
1368 1368
 		}
1369
-		if len(addedCaps) > 0 {
1369
+		if 0 < addedCaps.Count() {
1370 1370
 			sClient.Send(nil, server.name, "CAP", sClient.nick, "NEW", added[sClient.capVersion])
1371 1371
 		}
1372 1372
 	}
@@ -1707,7 +1707,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
1707 1707
 	message := msg.Params[1]
1708 1708
 
1709 1709
 	// split privmsg
1710
-	splitMsg := server.splitMessage(message, !client.capabilities[caps.MaxLine])
1710
+	splitMsg := server.splitMessage(message, !client.capabilities.Has(caps.MaxLine))
1711 1711
 
1712 1712
 	for i, targetString := range targets {
1713 1713
 		// max of four targets per privmsg
@@ -1748,7 +1748,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
1748 1748
 				// errors silently ignored with NOTICE as per RFC
1749 1749
 				continue
1750 1750
 			}
1751
-			if !user.capabilities[caps.MessageTags] {
1751
+			if !user.capabilities.Has(caps.MessageTags) {
1752 1752
 				clientOnlyTags = nil
1753 1753
 			}
1754 1754
 			msgid := server.generateMessageID()
@@ -1757,7 +1757,7 @@ func noticeHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
1757 1757
 			if !user.flags[RegisteredOnly] || client.registered {
1758 1758
 				user.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg)
1759 1759
 			}
1760
-			if client.capabilities[caps.EchoMessage] {
1760
+			if client.capabilities.Has(caps.EchoMessage) {
1761 1761
 				client.SendSplitMsgFromClient(msgid, client, clientOnlyTags, "NOTICE", user.nick, splitMsg)
1762 1762
 			}
1763 1763
 		}

Loading…
取消
儲存