瀏覽代碼

introduce "flat ip" representations

tags/v2.5.0-rc1
Shivaram Lingamneni 3 年之前
父節點
當前提交
44cc4c2092
共有 6 個檔案被更改,包括 488 行新增91 行删除
  1. 1
    0
      Makefile
  2. 48
    40
      irc/connection_limits/limiter.go
  3. 19
    6
      irc/connection_limits/limiter_test.go
  4. 29
    45
      irc/dline.go
  5. 217
    0
      irc/flatip/flatip.go
  6. 174
    0
      irc/flatip/flatip_test.go

+ 1
- 0
Makefile 查看文件

25
 	cd irc/cloaks && go test . && go vet .
25
 	cd irc/cloaks && go test . && go vet .
26
 	cd irc/connection_limits && go test . && go vet .
26
 	cd irc/connection_limits && go test . && go vet .
27
 	cd irc/email && go test . && go vet .
27
 	cd irc/email && go test . && go vet .
28
+	cd irc/flatip && go test . && go vet .
28
 	cd irc/history && go test . && go vet .
29
 	cd irc/history && go test . && go vet .
29
 	cd irc/isupport && go test . && go vet .
30
 	cd irc/isupport && go test . && go vet .
30
 	cd irc/migrations && go test . && go vet .
31
 	cd irc/migrations && go test . && go vet .

+ 48
- 40
irc/connection_limits/limiter.go 查看文件

4
 package connection_limits
4
 package connection_limits
5
 
5
 
6
 import (
6
 import (
7
+	"crypto/md5"
7
 	"errors"
8
 	"errors"
8
 	"fmt"
9
 	"fmt"
9
 	"net"
10
 	"net"
10
 	"sync"
11
 	"sync"
11
 	"time"
12
 	"time"
12
 
13
 
14
+	"github.com/oragono/oragono/irc/flatip"
13
 	"github.com/oragono/oragono/irc/utils"
15
 	"github.com/oragono/oragono/irc/utils"
14
 )
16
 )
15
 
17
 
26
 
28
 
27
 // tuples the key-value pair of a CIDR and its custom limit/throttle values
29
 // tuples the key-value pair of a CIDR and its custom limit/throttle values
28
 type customLimit struct {
30
 type customLimit struct {
29
-	name          string
31
+	name          [16]byte
30
 	maxConcurrent int
32
 	maxConcurrent int
31
 	maxPerWindow  int
33
 	maxPerWindow  int
32
-	nets          []net.IPNet
34
+	nets          []flatip.IPNet
35
+}
36
+
37
+type limiterKey struct {
38
+	maskedIP  flatip.IP
39
+	prefixLen uint8 // 0 for the fake nets we generate for custom limits
33
 }
40
 }
34
 
41
 
35
 // LimiterConfig controls the automated connection limits.
42
 // LimiterConfig controls the automated connection limits.
55
 type LimiterConfig struct {
62
 type LimiterConfig struct {
56
 	rawLimiterConfig
63
 	rawLimiterConfig
57
 
64
 
58
-	ipv4Mask     net.IPMask
59
-	ipv6Mask     net.IPMask
60
-	exemptedNets []net.IPNet
65
+	exemptedNets []flatip.IPNet
61
 	customLimits []customLimit
66
 	customLimits []customLimit
62
 }
67
 }
63
 
68
 
69
 }
74
 }
70
 
75
 
71
 func (config *LimiterConfig) postprocess() (err error) {
76
 func (config *LimiterConfig) postprocess() (err error) {
72
-	config.exemptedNets, err = utils.ParseNetList(config.Exempted)
77
+	exemptedNets, err := utils.ParseNetList(config.Exempted)
73
 	if err != nil {
78
 	if err != nil {
74
 		return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error())
79
 		return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error())
75
 	}
80
 	}
81
+	config.exemptedNets = make([]flatip.IPNet, len(exemptedNets))
82
+	for i, exempted := range exemptedNets {
83
+		config.exemptedNets[i] = flatip.FromNetIPNet(exempted)
84
+	}
76
 
85
 
77
 	for identifier, customLimitConf := range config.CustomLimits {
86
 	for identifier, customLimitConf := range config.CustomLimits {
78
-		nets := make([]net.IPNet, len(customLimitConf.Nets))
87
+		nets := make([]flatip.IPNet, len(customLimitConf.Nets))
79
 		for i, netStr := range customLimitConf.Nets {
88
 		for i, netStr := range customLimitConf.Nets {
80
-			normalizedNet, err := utils.NormalizedNetFromString(netStr)
89
+			normalizedNet, err := flatip.ParseToNormalizedNet(netStr)
81
 			if err != nil {
90
 			if err != nil {
82
 				return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err)
91
 				return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err)
83
 			}
92
 			}
86
 		if len(customLimitConf.Nets) == 0 {
95
 		if len(customLimitConf.Nets) == 0 {
87
 			// see #1421: this is the legacy config format where the
96
 			// see #1421: this is the legacy config format where the
88
 			// dictionary key of the block is a CIDR string
97
 			// dictionary key of the block is a CIDR string
89
-			normalizedNet, err := utils.NormalizedNetFromString(identifier)
98
+			normalizedNet, err := flatip.ParseToNormalizedNet(identifier)
90
 			if err != nil {
99
 			if err != nil {
91
 				return fmt.Errorf("Custom limit block %s has no defined nets", identifier)
100
 				return fmt.Errorf("Custom limit block %s has no defined nets", identifier)
92
 			}
101
 			}
93
-			nets = []net.IPNet{normalizedNet}
102
+			nets = []flatip.IPNet{normalizedNet}
94
 		}
103
 		}
95
 		config.customLimits = append(config.customLimits, customLimit{
104
 		config.customLimits = append(config.customLimits, customLimit{
96
 			maxConcurrent: customLimitConf.MaxConcurrent,
105
 			maxConcurrent: customLimitConf.MaxConcurrent,
97
 			maxPerWindow:  customLimitConf.MaxPerWindow,
106
 			maxPerWindow:  customLimitConf.MaxPerWindow,
98
-			name:          "*" + identifier,
107
+			name:          md5.Sum([]byte(identifier)),
99
 			nets:          nets,
108
 			nets:          nets,
100
 		})
109
 		})
101
 	}
110
 	}
102
 
111
 
103
-	config.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32)
104
-	config.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128)
105
-
106
 	return nil
112
 	return nil
107
 }
113
 }
108
 
114
 
113
 	config *LimiterConfig
119
 	config *LimiterConfig
114
 
120
 
115
 	// IP/CIDR -> count of clients connected from there:
121
 	// IP/CIDR -> count of clients connected from there:
116
-	limiter map[string]int
122
+	limiter map[limiterKey]int
117
 	// IP/CIDR -> throttle state:
123
 	// IP/CIDR -> throttle state:
118
-	throttler map[string]ThrottleDetails
124
+	throttler map[limiterKey]ThrottleDetails
119
 }
125
 }
120
 
126
 
121
 // addrToKey canonicalizes `addr` to a string key, and returns
127
 // addrToKey canonicalizes `addr` to a string key, and returns
122
 // the relevant connection limit and throttle max-per-window values
128
 // the relevant connection limit and throttle max-per-window values
123
-func (cl *Limiter) addrToKey(addr net.IP) (key string, limit int, throttle int) {
124
-	// `key` will be a CIDR string like "8.8.8.8/32" or "2001:0db8::/32"
129
+func (cl *Limiter) addrToKey(flat flatip.IP) (key limiterKey, limit int, throttle int) {
125
 	for _, custom := range cl.config.customLimits {
130
 	for _, custom := range cl.config.customLimits {
126
 		for _, net := range custom.nets {
131
 		for _, net := range custom.nets {
127
-			if net.Contains(addr) {
128
-				return custom.name, custom.maxConcurrent, custom.maxPerWindow
132
+			if net.Contains(flat) {
133
+				return limiterKey{maskedIP: custom.name, prefixLen: 0}, custom.maxConcurrent, custom.maxPerWindow
129
 			}
134
 			}
130
 		}
135
 		}
131
 	}
136
 	}
132
 
137
 
133
-	var ipNet net.IPNet
134
-	addrv4 := addr.To4()
135
-	if addrv4 != nil {
136
-		ipNet = net.IPNet{
137
-			IP:   addrv4.Mask(cl.config.ipv4Mask),
138
-			Mask: cl.config.ipv4Mask,
139
-		}
138
+	var prefixLen int
139
+	if flat.IsIPv4() {
140
+		prefixLen = cl.config.CidrLenIPv4
141
+		flat = flat.Mask(prefixLen, 32)
142
+		prefixLen += 96
140
 	} else {
143
 	} else {
141
-		ipNet = net.IPNet{
142
-			IP:   addr.Mask(cl.config.ipv6Mask),
143
-			Mask: cl.config.ipv6Mask,
144
-		}
144
+		prefixLen = cl.config.CidrLenIPv6
145
+		flat = flat.Mask(prefixLen, 128)
145
 	}
146
 	}
146
-	return ipNet.String(), cl.config.MaxConcurrent, cl.config.MaxPerWindow
147
+
148
+	return limiterKey{maskedIP: flat, prefixLen: uint8(prefixLen)}, cl.config.MaxConcurrent, cl.config.MaxPerWindow
147
 }
149
 }
148
 
150
 
149
 // AddClient adds a client to our population if possible. If we can't, throws an error instead.
151
 // AddClient adds a client to our population if possible. If we can't, throws an error instead.
150
 func (cl *Limiter) AddClient(addr net.IP) error {
152
 func (cl *Limiter) AddClient(addr net.IP) error {
153
+	flat := flatip.FromNetIP(addr)
154
+
151
 	cl.Lock()
155
 	cl.Lock()
152
 	defer cl.Unlock()
156
 	defer cl.Unlock()
153
 
157
 
154
 	// we don't track populations for exempted addresses or nets - this is by design
158
 	// we don't track populations for exempted addresses or nets - this is by design
155
-	if utils.IPInNets(addr, cl.config.exemptedNets) {
159
+	if flatip.IPInNets(flat, cl.config.exemptedNets) {
156
 		return nil
160
 		return nil
157
 	}
161
 	}
158
 
162
 
159
-	addrString, maxConcurrent, maxPerWindow := cl.addrToKey(addr)
163
+	addrString, maxConcurrent, maxPerWindow := cl.addrToKey(flat)
160
 
164
 
161
 	// XXX check throttle first; if we checked limit first and then checked throttle,
165
 	// XXX check throttle first; if we checked limit first and then checked throttle,
162
 	// we'd have to decrement the limit on an unsuccessful throttle check
166
 	// we'd have to decrement the limit on an unsuccessful throttle check
189
 
193
 
190
 // RemoveClient removes the given address from our population
194
 // RemoveClient removes the given address from our population
191
 func (cl *Limiter) RemoveClient(addr net.IP) {
195
 func (cl *Limiter) RemoveClient(addr net.IP) {
196
+	flat := flatip.FromNetIP(addr)
197
+
192
 	cl.Lock()
198
 	cl.Lock()
193
 	defer cl.Unlock()
199
 	defer cl.Unlock()
194
 
200
 
195
-	if !cl.config.Count || utils.IPInNets(addr, cl.config.exemptedNets) {
201
+	if !cl.config.Count || flatip.IPInNets(flat, cl.config.exemptedNets) {
196
 		return
202
 		return
197
 	}
203
 	}
198
 
204
 
199
-	addrString, _, _ := cl.addrToKey(addr)
205
+	addrString, _, _ := cl.addrToKey(flat)
200
 	count := cl.limiter[addrString]
206
 	count := cl.limiter[addrString]
201
 	count -= 1
207
 	count -= 1
202
 	if count < 0 {
208
 	if count < 0 {
207
 
213
 
208
 // ResetThrottle resets the throttle count for an IP
214
 // ResetThrottle resets the throttle count for an IP
209
 func (cl *Limiter) ResetThrottle(addr net.IP) {
215
 func (cl *Limiter) ResetThrottle(addr net.IP) {
216
+	flat := flatip.FromNetIP(addr)
217
+
210
 	cl.Lock()
218
 	cl.Lock()
211
 	defer cl.Unlock()
219
 	defer cl.Unlock()
212
 
220
 
213
-	if !cl.config.Throttle || utils.IPInNets(addr, cl.config.exemptedNets) {
221
+	if !cl.config.Throttle || flatip.IPInNets(flat, cl.config.exemptedNets) {
214
 		return
222
 		return
215
 	}
223
 	}
216
 
224
 
217
-	addrString, _, _ := cl.addrToKey(addr)
225
+	addrString, _, _ := cl.addrToKey(flat)
218
 	delete(cl.throttler, addrString)
226
 	delete(cl.throttler, addrString)
219
 }
227
 }
220
 
228
 
224
 	defer cl.Unlock()
232
 	defer cl.Unlock()
225
 
233
 
226
 	if cl.limiter == nil {
234
 	if cl.limiter == nil {
227
-		cl.limiter = make(map[string]int)
235
+		cl.limiter = make(map[limiterKey]int)
228
 	}
236
 	}
229
 	if cl.throttler == nil {
237
 	if cl.throttler == nil {
230
-		cl.throttler = make(map[string]ThrottleDetails)
238
+		cl.throttler = make(map[limiterKey]ThrottleDetails)
231
 	}
239
 	}
232
 
240
 
233
 	cl.config = config
241
 	cl.config = config

+ 19
- 6
irc/connection_limits/limiter_test.go 查看文件

4
 package connection_limits
4
 package connection_limits
5
 
5
 
6
 import (
6
 import (
7
+	"crypto/md5"
7
 	"net"
8
 	"net"
8
 	"testing"
9
 	"testing"
9
 	"time"
10
 	"time"
11
+
12
+	"github.com/oragono/oragono/irc/flatip"
10
 )
13
 )
11
 
14
 
12
 func easyParseIP(ipstr string) (result net.IP) {
15
 func easyParseIP(ipstr string) (result net.IP) {
17
 	return
20
 	return
18
 }
21
 }
19
 
22
 
23
+func easyParseFlat(ipstr string) (result flatip.IP) {
24
+	r1 := easyParseIP(ipstr)
25
+	return flatip.FromNetIP(r1)
26
+}
27
+
20
 var baseConfig = LimiterConfig{
28
 var baseConfig = LimiterConfig{
21
 	rawLimiterConfig: rawLimiterConfig{
29
 	rawLimiterConfig: rawLimiterConfig{
22
 		Count:         true,
30
 		Count:         true,
47
 	var limiter Limiter
55
 	var limiter Limiter
48
 	limiter.ApplyConfig(&config)
56
 	limiter.ApplyConfig(&config)
49
 
57
 
50
-	key, maxConc, maxWin := limiter.addrToKey(easyParseIP("1.1.1.1"))
51
-	assertEqual(key, "1.1.1.1/32", t)
58
+	// an ipv4 /32 looks like a /128 to us after applying the 4-in-6 mapping
59
+	key, maxConc, maxWin := limiter.addrToKey(easyParseFlat("1.1.1.1"))
60
+	assertEqual(key.prefixLen, uint8(128), t)
61
+	assertEqual(key.maskedIP[12:], []byte{1, 1, 1, 1}, t)
52
 	assertEqual(maxConc, 4, t)
62
 	assertEqual(maxConc, 4, t)
53
 	assertEqual(maxWin, 8, t)
63
 	assertEqual(maxWin, 8, t)
54
 
64
 
55
-	key, maxConc, maxWin = limiter.addrToKey(easyParseIP("2607:5301:201:3100::7426"))
56
-	assertEqual(key, "2607:5301:201:3100::/64", t)
65
+	testIPv6 := easyParseFlat("2607:5301:201:3100::7426")
66
+	key, maxConc, maxWin = limiter.addrToKey(testIPv6)
67
+	assertEqual(key.prefixLen, uint8(64), t)
68
+	assertEqual(key.maskedIP[:], []byte(easyParseIP("2607:5301:201:3100::")), t)
57
 	assertEqual(maxConc, 4, t)
69
 	assertEqual(maxConc, 4, t)
58
 	assertEqual(maxWin, 8, t)
70
 	assertEqual(maxWin, 8, t)
59
 
71
 
60
-	key, maxConc, maxWin = limiter.addrToKey(easyParseIP("8.8.4.4"))
61
-	assertEqual(key, "*google", t)
72
+	key, maxConc, maxWin = limiter.addrToKey(easyParseFlat("8.8.4.4"))
73
+	assertEqual(key.prefixLen, uint8(0), t)
74
+	assertEqual([16]byte(key.maskedIP), md5.Sum([]byte("google")), t)
62
 	assertEqual(maxConc, 128, t)
75
 	assertEqual(maxConc, 128, t)
63
 	assertEqual(maxWin, 256, t)
76
 	assertEqual(maxWin, 256, t)
64
 }
77
 }

+ 29
- 45
irc/dline.go 查看文件

11
 	"sync"
11
 	"sync"
12
 	"time"
12
 	"time"
13
 
13
 
14
+	"github.com/oragono/oragono/irc/flatip"
14
 	"github.com/oragono/oragono/irc/utils"
15
 	"github.com/oragono/oragono/irc/utils"
15
 	"github.com/tidwall/buntdb"
16
 	"github.com/tidwall/buntdb"
16
 )
17
 )
54
 	return message
55
 	return message
55
 }
56
 }
56
 
57
 
57
-// dLineNet contains the net itself and expiration time for a given network.
58
-type dLineNet struct {
59
-	// Network is the network that is blocked.
60
-	// This is always an IPv6 CIDR; IPv4 CIDRs are translated with the 4-in-6 prefix,
61
-	// individual IPv4 and IPV6 addresses are translated to the relevant /128.
62
-	Network net.IPNet
63
-	// Info contains information on the ban.
64
-	Info IPBanInfo
65
-}
66
-
67
 // DLineManager manages and dlines.
58
 // DLineManager manages and dlines.
68
 type DLineManager struct {
59
 type DLineManager struct {
69
 	sync.RWMutex                // tier 1
60
 	sync.RWMutex                // tier 1
70
 	persistenceMutex sync.Mutex // tier 2
61
 	persistenceMutex sync.Mutex // tier 2
71
 	// networks that are dlined:
62
 	// networks that are dlined:
72
-	// XXX: the keys of this map (which are also the database persistence keys)
73
-	// are the human-readable representations returned by NetToNormalizedString
74
-	networks map[string]dLineNet
63
+	networks map[flatip.IPNet]IPBanInfo
75
 	// this keeps track of expiration timers for temporary bans
64
 	// this keeps track of expiration timers for temporary bans
76
-	expirationTimers map[string]*time.Timer
65
+	expirationTimers map[flatip.IPNet]*time.Timer
77
 	server           *Server
66
 	server           *Server
78
 }
67
 }
79
 
68
 
80
 // NewDLineManager returns a new DLineManager.
69
 // NewDLineManager returns a new DLineManager.
81
 func NewDLineManager(server *Server) *DLineManager {
70
 func NewDLineManager(server *Server) *DLineManager {
82
 	var dm DLineManager
71
 	var dm DLineManager
83
-	dm.networks = make(map[string]dLineNet)
84
-	dm.expirationTimers = make(map[string]*time.Timer)
72
+	dm.networks = make(map[flatip.IPNet]IPBanInfo)
73
+	dm.expirationTimers = make(map[flatip.IPNet]*time.Timer)
85
 	dm.server = server
74
 	dm.server = server
86
 
75
 
87
 	dm.loadFromDatastore()
76
 	dm.loadFromDatastore()
96
 	dm.RLock()
85
 	dm.RLock()
97
 	defer dm.RUnlock()
86
 	defer dm.RUnlock()
98
 
87
 
99
-	// map keys are already the human-readable forms, just return a copy of the map
100
 	for key, info := range dm.networks {
88
 	for key, info := range dm.networks {
101
-		allb[key] = info.Info
89
+		allb[key.String()] = info
102
 	}
90
 	}
103
 
91
 
104
 	return allb
92
 	return allb
122
 	return dm.persistDline(id, info)
110
 	return dm.persistDline(id, info)
123
 }
111
 }
124
 
112
 
125
-func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id string) {
126
-	network = utils.NormalizeNet(network)
127
-	id = utils.NetToNormalizedString(network)
113
+func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id flatip.IPNet) {
114
+	flatnet := flatip.FromNetIPNet(network)
115
+	id = flatnet
128
 
116
 
129
 	var timeLeft time.Duration
117
 	var timeLeft time.Duration
130
 	if info.Duration != 0 {
118
 	if info.Duration != 0 {
137
 	dm.Lock()
125
 	dm.Lock()
138
 	defer dm.Unlock()
126
 	defer dm.Unlock()
139
 
127
 
140
-	dm.networks[id] = dLineNet{
141
-		Network: network,
142
-		Info:    info,
143
-	}
128
+	dm.networks[flatnet] = info
144
 
129
 
145
-	dm.cancelTimer(id)
130
+	dm.cancelTimer(flatnet)
146
 
131
 
147
 	if info.Duration == 0 {
132
 	if info.Duration == 0 {
148
 		return
133
 		return
154
 		dm.Lock()
139
 		dm.Lock()
155
 		defer dm.Unlock()
140
 		defer dm.Unlock()
156
 
141
 
157
-		netBan, ok := dm.networks[id]
158
-		if ok && netBan.Info.TimeCreated.Equal(timeCreated) {
159
-			delete(dm.networks, id)
142
+		banInfo, ok := dm.networks[flatnet]
143
+		if ok && banInfo.TimeCreated.Equal(timeCreated) {
144
+			delete(dm.networks, flatnet)
160
 			// TODO(slingamn) here's where we'd remove it from the radix tree
145
 			// TODO(slingamn) here's where we'd remove it from the radix tree
161
-			delete(dm.expirationTimers, id)
146
+			delete(dm.expirationTimers, flatnet)
162
 		}
147
 		}
163
 	}
148
 	}
164
-	dm.expirationTimers[id] = time.AfterFunc(timeLeft, processExpiration)
149
+	dm.expirationTimers[flatnet] = time.AfterFunc(timeLeft, processExpiration)
165
 
150
 
166
 	return
151
 	return
167
 }
152
 }
168
 
153
 
169
-func (dm *DLineManager) cancelTimer(id string) {
170
-	oldTimer := dm.expirationTimers[id]
154
+func (dm *DLineManager) cancelTimer(flatnet flatip.IPNet) {
155
+	oldTimer := dm.expirationTimers[flatnet]
171
 	if oldTimer != nil {
156
 	if oldTimer != nil {
172
 		oldTimer.Stop()
157
 		oldTimer.Stop()
173
-		delete(dm.expirationTimers, id)
158
+		delete(dm.expirationTimers, flatnet)
174
 	}
159
 	}
175
 }
160
 }
176
 
161
 
177
-func (dm *DLineManager) persistDline(id string, info IPBanInfo) error {
162
+func (dm *DLineManager) persistDline(id flatip.IPNet, info IPBanInfo) error {
178
 	// save in datastore
163
 	// save in datastore
179
-	dlineKey := fmt.Sprintf(keyDlineEntry, id)
164
+	dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
180
 	// assemble json from ban info
165
 	// assemble json from ban info
181
 	b, err := json.Marshal(info)
166
 	b, err := json.Marshal(info)
182
 	if err != nil {
167
 	if err != nil {
199
 	return err
184
 	return err
200
 }
185
 }
201
 
186
 
202
-func (dm *DLineManager) unpersistDline(id string) error {
203
-	dlineKey := fmt.Sprintf(keyDlineEntry, id)
187
+func (dm *DLineManager) unpersistDline(id flatip.IPNet) error {
188
+	dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
204
 	return dm.server.store.Update(func(tx *buntdb.Tx) error {
189
 	return dm.server.store.Update(func(tx *buntdb.Tx) error {
205
 		_, err := tx.Delete(dlineKey)
190
 		_, err := tx.Delete(dlineKey)
206
 		return err
191
 		return err
212
 	dm.persistenceMutex.Lock()
197
 	dm.persistenceMutex.Lock()
213
 	defer dm.persistenceMutex.Unlock()
198
 	defer dm.persistenceMutex.Unlock()
214
 
199
 
215
-	id := utils.NetToNormalizedString(utils.NormalizeNet(network))
200
+	id := flatip.FromNetIPNet(network)
216
 
201
 
217
 	present := func() bool {
202
 	present := func() bool {
218
 		dm.Lock()
203
 		dm.Lock()
241
 }
226
 }
242
 
227
 
243
 // CheckIP returns whether or not an IP address was banned, and how long it is banned for.
228
 // CheckIP returns whether or not an IP address was banned, and how long it is banned for.
244
-func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) {
245
-	addr = addr.To16() // almost certainly unnecessary
229
+func (dm *DLineManager) CheckIP(netAddr net.IP) (isBanned bool, info IPBanInfo) {
230
+	addr := flatip.FromNetIP(netAddr)
246
 	if addr.IsLoopback() {
231
 	if addr.IsLoopback() {
247
 		return // #671
232
 		return // #671
248
 	}
233
 	}
252
 
237
 
253
 	// check networks
238
 	// check networks
254
 	// TODO(slingamn) use a radix tree as the data plane for this
239
 	// TODO(slingamn) use a radix tree as the data plane for this
255
-	for _, netBan := range dm.networks {
256
-		if netBan.Network.Contains(addr) {
257
-			return true, netBan.Info
240
+	for flatnet, info := range dm.networks {
241
+		if flatnet.Contains(addr) {
242
+			return true, info
258
 		}
243
 		}
259
 	}
244
 	}
260
 	// no matches!
245
 	// no matches!
261
-	isBanned = false
262
 	return
246
 	return
263
 }
247
 }
264
 
248
 

+ 217
- 0
irc/flatip/flatip.go 查看文件

1
+// Copyright 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
2
+// Copyright 2009 The Go Authors
3
+
4
+package flatip
5
+
6
+import (
7
+	"bytes"
8
+	"errors"
9
+	"net"
10
+)
11
+
12
+var (
13
+	v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
14
+
15
+	IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
16
+
17
+	ErrInvalidIPString = errors.New("String could not be interpreted as an IP address")
18
+)
19
+
20
+// packed versions of net.IP and net.IPNet; these are pure value types,
21
+// so they can be compared with == and used as map keys.
22
+
23
+// IP is the 128-bit representation of the IPv6 address, using the 4-in-6 mapping
24
+// if necessary:
25
+type IP [16]byte
26
+
27
+// IPNet is a IP network. In a valid value, all bits after PrefixLen are zeroes.
28
+type IPNet struct {
29
+	IP
30
+	PrefixLen uint8
31
+}
32
+
33
+// NetIP converts an IP into a net.IP.
34
+func (ip IP) NetIP() (result net.IP) {
35
+	result = make(net.IP, 16)
36
+	copy(result[:], ip[:])
37
+	return
38
+}
39
+
40
+// FromNetIP converts a net.IP into an IP.
41
+func FromNetIP(ip net.IP) (result IP) {
42
+	if len(ip) == 16 {
43
+		copy(result[:], ip[:])
44
+	} else {
45
+		result[10] = 0xff
46
+		result[11] = 0xff
47
+		copy(result[12:], ip[:])
48
+	}
49
+	return
50
+}
51
+
52
+// IPv4 returns the IP address representation of a.b.c.d
53
+func IPv4(a, b, c, d byte) (result IP) {
54
+	copy(result[:12], v4InV6Prefix)
55
+	result[12] = a
56
+	result[13] = b
57
+	result[14] = c
58
+	result[15] = d
59
+	return
60
+}
61
+
62
+// ParseIP parses a string representation of an IP address into an IP.
63
+// Unlike net.ParseIP, it returns an error instead of a zero value on failure,
64
+// since the zero value of `IP` is a representation of a valid IP (::0, the
65
+// IPv6 "unspecified address").
66
+func ParseIP(ipstr string) (ip IP, err error) {
67
+	// TODO reimplement this without net.ParseIP
68
+	netip := net.ParseIP(ipstr)
69
+	if netip == nil {
70
+		err = ErrInvalidIPString
71
+		return
72
+	}
73
+	netip = netip.To16()
74
+	copy(ip[:], netip)
75
+	return
76
+}
77
+
78
+// String returns the string representation of an IP
79
+func (ip IP) String() string {
80
+	// TODO reimplement this without using (net.IP).String()
81
+	return (net.IP)(ip[:]).String()
82
+}
83
+
84
+// IsIPv4 returns whether the IP is an IPv4 address.
85
+func (ip IP) IsIPv4() bool {
86
+	return bytes.Equal(ip[:12], v4InV6Prefix)
87
+}
88
+
89
+// IsLoopback returns whether the IP is a loopback address.
90
+func (ip IP) IsLoopback() bool {
91
+	if ip.IsIPv4() {
92
+		return ip[12] == 127
93
+	} else {
94
+		return ip == IPv6loopback
95
+	}
96
+}
97
+
98
+func rawCidrMask(length int) (m IP) {
99
+	n := uint(length)
100
+	for i := 0; i < 16; i++ {
101
+		if n >= 8 {
102
+			m[i] = 0xff
103
+			n -= 8
104
+			continue
105
+		}
106
+		m[i] = ^byte(0xff >> n)
107
+		return
108
+	}
109
+	return
110
+}
111
+
112
+func (ip IP) applyMask(mask IP) (result IP) {
113
+	for i := 0; i < 16; i += 1 {
114
+		result[i] = ip[i] & mask[i]
115
+	}
116
+	return
117
+}
118
+
119
+func cidrMask(ones, bits int) (result IP) {
120
+	switch bits {
121
+	case 32:
122
+		return rawCidrMask(96 + ones)
123
+	case 128:
124
+		return rawCidrMask(ones)
125
+	default:
126
+		return
127
+	}
128
+}
129
+
130
+// Mask returns the result of masking ip with the CIDR mask of
131
+// length 'ones', out of a total of 'bits' (which must be either
132
+// 32 for an IPv4 subnet or 128 for an IPv6 subnet).
133
+func (ip IP) Mask(ones, bits int) (result IP) {
134
+	return ip.applyMask(cidrMask(ones, bits))
135
+}
136
+
137
+// ToNetIPNet converts an IPNet into a net.IPNet.
138
+func (cidr IPNet) ToNetIPNet() (result net.IPNet) {
139
+	return net.IPNet{
140
+		IP:   cidr.IP.NetIP(),
141
+		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
142
+	}
143
+}
144
+
145
+// Contains retuns whether the network contains `ip`.
146
+func (cidr IPNet) Contains(ip IP) bool {
147
+	maskedIP := ip.Mask(int(cidr.PrefixLen), 128)
148
+	return cidr.IP == maskedIP
149
+}
150
+
151
+// FromNetIPnet converts a net.IPNet into an IPNet.
152
+func FromNetIPNet(network net.IPNet) (result IPNet) {
153
+	ones, _ := network.Mask.Size()
154
+	if len(network.IP) == 16 {
155
+		copy(result.IP[:], network.IP[:])
156
+	} else {
157
+		result.IP[10] = 0xff
158
+		result.IP[11] = 0xff
159
+		copy(result.IP[12:], network.IP[:])
160
+		ones += 96
161
+	}
162
+	// perform masking so that equal CIDRs are ==
163
+	result.IP = result.IP.Mask(ones, 128)
164
+	result.PrefixLen = uint8(ones)
165
+	return
166
+}
167
+
168
+// String returns a string representation of an IPNet.
169
+func (cidr IPNet) String() string {
170
+	ip := make(net.IP, 16)
171
+	copy(ip[:], cidr.IP[:])
172
+	ipnet := net.IPNet{
173
+		IP:   ip,
174
+		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
175
+	}
176
+	return ipnet.String()
177
+}
178
+
179
+// ParseCIDR parses a string representation of an IP network in CIDR notation,
180
+// then returns it as an IPNet (along with the original, unmasked address).
181
+func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) {
182
+	// TODO reimplement this without net.ParseCIDR
183
+	nip, nipnet, err := net.ParseCIDR(netstr)
184
+	if err != nil {
185
+		return
186
+	}
187
+	return FromNetIP(nip), FromNetIPNet(*nipnet), nil
188
+}
189
+
190
+// begin ad-hoc utilities
191
+
192
+// ParseToNormalizedNet attempts to interpret a string either as an IP
193
+// network in CIDR notation, returning an IPNet, or as an IP address,
194
+// returning an IPNet that contains only that address.
195
+func ParseToNormalizedNet(netstr string) (ipnet IPNet, err error) {
196
+	_, ipnet, err = ParseCIDR(netstr)
197
+	if err == nil {
198
+		return
199
+	}
200
+	ip, err := ParseIP(netstr)
201
+	if err == nil {
202
+		ipnet.IP = ip
203
+		ipnet.PrefixLen = 128
204
+	}
205
+	return
206
+}
207
+
208
+// IPInNets is a convenience function for testing whether an IP is contained
209
+// in any member of a slice of IPNet's.
210
+func IPInNets(addr IP, nets []IPNet) bool {
211
+	for _, net := range nets {
212
+		if net.Contains(addr) {
213
+			return true
214
+		}
215
+	}
216
+	return false
217
+}

+ 174
- 0
irc/flatip/flatip_test.go 查看文件

1
+package flatip
2
+
3
+import (
4
+	"bytes"
5
+	"math/rand"
6
+	"net"
7
+	"testing"
8
+	"time"
9
+)
10
+
11
+func easyParseIP(ipstr string) (result net.IP) {
12
+	result = net.ParseIP(ipstr)
13
+	if result == nil {
14
+		panic(ipstr)
15
+	}
16
+	return
17
+}
18
+
19
+func easyParseFlat(ipstr string) (result IP) {
20
+	x := easyParseIP(ipstr)
21
+	return FromNetIP(x)
22
+}
23
+
24
+func easyParseIPNet(nipstr string) (result net.IPNet) {
25
+	_, nip, err := net.ParseCIDR(nipstr)
26
+	if err != nil {
27
+		panic(err)
28
+	}
29
+	return *nip
30
+}
31
+
32
+func TestBasic(t *testing.T) {
33
+	nip := easyParseIP("8.8.8.8")
34
+	flatip := FromNetIP(nip)
35
+	if flatip.String() != "8.8.8.8" {
36
+		t.Errorf("conversions don't work")
37
+	}
38
+}
39
+
40
+func TestLoopback(t *testing.T) {
41
+	localhost_v4 := easyParseFlat("127.0.0.1")
42
+	localhost_v4_again := easyParseFlat("127.2.3.4")
43
+	google := easyParseFlat("8.8.8.8")
44
+	loopback_v6 := easyParseFlat("::1")
45
+	google_v6 := easyParseFlat("2607:f8b0:4006:801::2004")
46
+
47
+	if !(localhost_v4.IsLoopback() && localhost_v4_again.IsLoopback() && loopback_v6.IsLoopback()) {
48
+		t.Errorf("can't detect loopbacks")
49
+	}
50
+
51
+	if google_v6.IsLoopback() || google.IsLoopback() {
52
+		t.Errorf("incorrectly detected loopbacks")
53
+	}
54
+}
55
+
56
+func TestContains(t *testing.T) {
57
+	nipnet := easyParseIPNet("8.8.0.0/16")
58
+	flatipnet := FromNetIPNet(nipnet)
59
+	nip := easyParseIP("8.8.8.8")
60
+	flatip_ := FromNetIP(nip)
61
+	if !flatipnet.Contains(flatip_) {
62
+		t.Errorf("contains doesn't work")
63
+	}
64
+}
65
+
66
+var testIPStrs = []string{
67
+	"8.8.8.8",
68
+	"127.0.0.1",
69
+	"1.1.1.1",
70
+	"128.127.65.64",
71
+	"2001:0db8::1",
72
+	"::1",
73
+	"255.255.255.255",
74
+}
75
+
76
+func doMaskingTest(ip net.IP, t *testing.T) {
77
+	flat := FromNetIP(ip)
78
+	netLen := len(ip) * 8
79
+	for i := 0; i < netLen; i++ {
80
+		masked := flat.Mask(i, netLen)
81
+		netMask := net.CIDRMask(i, netLen)
82
+		netMasked := ip.Mask(netMask)
83
+		if !bytes.Equal(masked[:], netMasked.To16()) {
84
+			t.Errorf("Masking %s with %d/%d; expected %s, got %s", ip.String(), i, netLen, netMasked.String(), masked.String())
85
+		}
86
+	}
87
+}
88
+
89
+func TestMasking(t *testing.T) {
90
+	for _, ipstr := range testIPStrs {
91
+		doMaskingTest(easyParseIP(ipstr), t)
92
+	}
93
+}
94
+
95
+func TestMaskingFuzz(t *testing.T) {
96
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
97
+	buf := make([]byte, 4)
98
+	for i := 0; i < 10000; i++ {
99
+		r.Read(buf)
100
+		doMaskingTest(net.IP(buf), t)
101
+	}
102
+
103
+	buf = make([]byte, 16)
104
+	for i := 0; i < 10000; i++ {
105
+		r.Read(buf)
106
+		doMaskingTest(net.IP(buf), t)
107
+	}
108
+}
109
+
110
+func BenchmarkMasking(b *testing.B) {
111
+	ip := easyParseIP("2001:0db8::42")
112
+	flat := FromNetIP(ip)
113
+	b.ResetTimer()
114
+
115
+	for i := 0; i < b.N; i++ {
116
+		flat.Mask(64, 128)
117
+	}
118
+}
119
+
120
+func BenchmarkMaskingLegacy(b *testing.B) {
121
+	ip := easyParseIP("2001:0db8::42")
122
+	mask := net.CIDRMask(64, 128)
123
+	b.ResetTimer()
124
+
125
+	for i := 0; i < b.N; i++ {
126
+		ip.Mask(mask)
127
+	}
128
+}
129
+
130
+func BenchmarkMaskingCached(b *testing.B) {
131
+	i := easyParseIP("2001:0db8::42")
132
+	flat := FromNetIP(i)
133
+	mask := cidrMask(64, 128)
134
+	b.ResetTimer()
135
+
136
+	for i := 0; i < b.N; i++ {
137
+		flat.applyMask(mask)
138
+	}
139
+}
140
+
141
+func BenchmarkMaskingConstruct(b *testing.B) {
142
+	for i := 0; i < b.N; i++ {
143
+		cidrMask(69, 128)
144
+	}
145
+}
146
+
147
+func BenchmarkContains(b *testing.B) {
148
+	ip := easyParseIP("2001:0db8::42")
149
+	flat := FromNetIP(ip)
150
+	_, ipnet, err := net.ParseCIDR("2001:0db8::/64")
151
+	if err != nil {
152
+		panic(err)
153
+	}
154
+	flatnet := FromNetIPNet(*ipnet)
155
+	b.ResetTimer()
156
+
157
+	for i := 0; i < b.N; i++ {
158
+		flatnet.Contains(flat)
159
+	}
160
+}
161
+
162
+func BenchmarkContainsLegacy(b *testing.B) {
163
+	ip := easyParseIP("2001:0db8::42")
164
+	_, ipnetptr, err := net.ParseCIDR("2001:0db8::/64")
165
+	if err != nil {
166
+		panic(err)
167
+	}
168
+	ipnet := *ipnetptr
169
+	b.ResetTimer()
170
+
171
+	for i := 0; i < b.N; i++ {
172
+		ipnet.Contains(ip)
173
+	}
174
+}

Loading…
取消
儲存