소스 검색

introduce "flat ip" representations

tags/v2.5.0-rc1
Shivaram Lingamneni 3 년 전
부모
커밋
44cc4c2092
6개의 변경된 파일488개의 추가작업 그리고 91개의 파일을 삭제
  1. 1
    0
      Makefile
  2. 48
    40
      irc/connection_limits/limiter.go
  3. 19
    6
      irc/connection_limits/limiter_test.go
  4. 29
    45
      irc/dline.go
  5. 217
    0
      irc/flatip/flatip.go
  6. 174
    0
      irc/flatip/flatip_test.go

+ 1
- 0
Makefile 파일 보기

@@ -25,6 +25,7 @@ test:
25 25
 	cd irc/cloaks && go test . && go vet .
26 26
 	cd irc/connection_limits && go test . && go vet .
27 27
 	cd irc/email && go test . && go vet .
28
+	cd irc/flatip && go test . && go vet .
28 29
 	cd irc/history && go test . && go vet .
29 30
 	cd irc/isupport && go test . && go vet .
30 31
 	cd irc/migrations && go test . && go vet .

+ 48
- 40
irc/connection_limits/limiter.go 파일 보기

@@ -4,12 +4,14 @@
4 4
 package connection_limits
5 5
 
6 6
 import (
7
+	"crypto/md5"
7 8
 	"errors"
8 9
 	"fmt"
9 10
 	"net"
10 11
 	"sync"
11 12
 	"time"
12 13
 
14
+	"github.com/oragono/oragono/irc/flatip"
13 15
 	"github.com/oragono/oragono/irc/utils"
14 16
 )
15 17
 
@@ -26,10 +28,15 @@ type CustomLimitConfig struct {
26 28
 
27 29
 // tuples the key-value pair of a CIDR and its custom limit/throttle values
28 30
 type customLimit struct {
29
-	name          string
31
+	name          [16]byte
30 32
 	maxConcurrent int
31 33
 	maxPerWindow  int
32
-	nets          []net.IPNet
34
+	nets          []flatip.IPNet
35
+}
36
+
37
+type limiterKey struct {
38
+	maskedIP  flatip.IP
39
+	prefixLen uint8 // 0 for the fake nets we generate for custom limits
33 40
 }
34 41
 
35 42
 // LimiterConfig controls the automated connection limits.
@@ -55,9 +62,7 @@ type rawLimiterConfig struct {
55 62
 type LimiterConfig struct {
56 63
 	rawLimiterConfig
57 64
 
58
-	ipv4Mask     net.IPMask
59
-	ipv6Mask     net.IPMask
60
-	exemptedNets []net.IPNet
65
+	exemptedNets []flatip.IPNet
61 66
 	customLimits []customLimit
62 67
 }
63 68
 
@@ -69,15 +74,19 @@ func (config *LimiterConfig) UnmarshalYAML(unmarshal func(interface{}) error) (e
69 74
 }
70 75
 
71 76
 func (config *LimiterConfig) postprocess() (err error) {
72
-	config.exemptedNets, err = utils.ParseNetList(config.Exempted)
77
+	exemptedNets, err := utils.ParseNetList(config.Exempted)
73 78
 	if err != nil {
74 79
 		return fmt.Errorf("Could not parse limiter exemption list: %v", err.Error())
75 80
 	}
81
+	config.exemptedNets = make([]flatip.IPNet, len(exemptedNets))
82
+	for i, exempted := range exemptedNets {
83
+		config.exemptedNets[i] = flatip.FromNetIPNet(exempted)
84
+	}
76 85
 
77 86
 	for identifier, customLimitConf := range config.CustomLimits {
78
-		nets := make([]net.IPNet, len(customLimitConf.Nets))
87
+		nets := make([]flatip.IPNet, len(customLimitConf.Nets))
79 88
 		for i, netStr := range customLimitConf.Nets {
80
-			normalizedNet, err := utils.NormalizedNetFromString(netStr)
89
+			normalizedNet, err := flatip.ParseToNormalizedNet(netStr)
81 90
 			if err != nil {
82 91
 				return fmt.Errorf("Bad net %s in custom-limits block %s: %w", netStr, identifier, err)
83 92
 			}
@@ -86,23 +95,20 @@ func (config *LimiterConfig) postprocess() (err error) {
86 95
 		if len(customLimitConf.Nets) == 0 {
87 96
 			// see #1421: this is the legacy config format where the
88 97
 			// dictionary key of the block is a CIDR string
89
-			normalizedNet, err := utils.NormalizedNetFromString(identifier)
98
+			normalizedNet, err := flatip.ParseToNormalizedNet(identifier)
90 99
 			if err != nil {
91 100
 				return fmt.Errorf("Custom limit block %s has no defined nets", identifier)
92 101
 			}
93
-			nets = []net.IPNet{normalizedNet}
102
+			nets = []flatip.IPNet{normalizedNet}
94 103
 		}
95 104
 		config.customLimits = append(config.customLimits, customLimit{
96 105
 			maxConcurrent: customLimitConf.MaxConcurrent,
97 106
 			maxPerWindow:  customLimitConf.MaxPerWindow,
98
-			name:          "*" + identifier,
107
+			name:          md5.Sum([]byte(identifier)),
99 108
 			nets:          nets,
100 109
 		})
101 110
 	}
102 111
 
103
-	config.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32)
104
-	config.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128)
105
-
106 112
 	return nil
107 113
 }
108 114
 
@@ -113,50 +119,48 @@ type Limiter struct {
113 119
 	config *LimiterConfig
114 120
 
115 121
 	// IP/CIDR -> count of clients connected from there:
116
-	limiter map[string]int
122
+	limiter map[limiterKey]int
117 123
 	// IP/CIDR -> throttle state:
118
-	throttler map[string]ThrottleDetails
124
+	throttler map[limiterKey]ThrottleDetails
119 125
 }
120 126
 
121 127
 // addrToKey canonicalizes `addr` to a string key, and returns
122 128
 // the relevant connection limit and throttle max-per-window values
123
-func (cl *Limiter) addrToKey(addr net.IP) (key string, limit int, throttle int) {
124
-	// `key` will be a CIDR string like "8.8.8.8/32" or "2001:0db8::/32"
129
+func (cl *Limiter) addrToKey(flat flatip.IP) (key limiterKey, limit int, throttle int) {
125 130
 	for _, custom := range cl.config.customLimits {
126 131
 		for _, net := range custom.nets {
127
-			if net.Contains(addr) {
128
-				return custom.name, custom.maxConcurrent, custom.maxPerWindow
132
+			if net.Contains(flat) {
133
+				return limiterKey{maskedIP: custom.name, prefixLen: 0}, custom.maxConcurrent, custom.maxPerWindow
129 134
 			}
130 135
 		}
131 136
 	}
132 137
 
133
-	var ipNet net.IPNet
134
-	addrv4 := addr.To4()
135
-	if addrv4 != nil {
136
-		ipNet = net.IPNet{
137
-			IP:   addrv4.Mask(cl.config.ipv4Mask),
138
-			Mask: cl.config.ipv4Mask,
139
-		}
138
+	var prefixLen int
139
+	if flat.IsIPv4() {
140
+		prefixLen = cl.config.CidrLenIPv4
141
+		flat = flat.Mask(prefixLen, 32)
142
+		prefixLen += 96
140 143
 	} else {
141
-		ipNet = net.IPNet{
142
-			IP:   addr.Mask(cl.config.ipv6Mask),
143
-			Mask: cl.config.ipv6Mask,
144
-		}
144
+		prefixLen = cl.config.CidrLenIPv6
145
+		flat = flat.Mask(prefixLen, 128)
145 146
 	}
146
-	return ipNet.String(), cl.config.MaxConcurrent, cl.config.MaxPerWindow
147
+
148
+	return limiterKey{maskedIP: flat, prefixLen: uint8(prefixLen)}, cl.config.MaxConcurrent, cl.config.MaxPerWindow
147 149
 }
148 150
 
149 151
 // AddClient adds a client to our population if possible. If we can't, throws an error instead.
150 152
 func (cl *Limiter) AddClient(addr net.IP) error {
153
+	flat := flatip.FromNetIP(addr)
154
+
151 155
 	cl.Lock()
152 156
 	defer cl.Unlock()
153 157
 
154 158
 	// we don't track populations for exempted addresses or nets - this is by design
155
-	if utils.IPInNets(addr, cl.config.exemptedNets) {
159
+	if flatip.IPInNets(flat, cl.config.exemptedNets) {
156 160
 		return nil
157 161
 	}
158 162
 
159
-	addrString, maxConcurrent, maxPerWindow := cl.addrToKey(addr)
163
+	addrString, maxConcurrent, maxPerWindow := cl.addrToKey(flat)
160 164
 
161 165
 	// XXX check throttle first; if we checked limit first and then checked throttle,
162 166
 	// we'd have to decrement the limit on an unsuccessful throttle check
@@ -189,14 +193,16 @@ func (cl *Limiter) AddClient(addr net.IP) error {
189 193
 
190 194
 // RemoveClient removes the given address from our population
191 195
 func (cl *Limiter) RemoveClient(addr net.IP) {
196
+	flat := flatip.FromNetIP(addr)
197
+
192 198
 	cl.Lock()
193 199
 	defer cl.Unlock()
194 200
 
195
-	if !cl.config.Count || utils.IPInNets(addr, cl.config.exemptedNets) {
201
+	if !cl.config.Count || flatip.IPInNets(flat, cl.config.exemptedNets) {
196 202
 		return
197 203
 	}
198 204
 
199
-	addrString, _, _ := cl.addrToKey(addr)
205
+	addrString, _, _ := cl.addrToKey(flat)
200 206
 	count := cl.limiter[addrString]
201 207
 	count -= 1
202 208
 	if count < 0 {
@@ -207,14 +213,16 @@ func (cl *Limiter) RemoveClient(addr net.IP) {
207 213
 
208 214
 // ResetThrottle resets the throttle count for an IP
209 215
 func (cl *Limiter) ResetThrottle(addr net.IP) {
216
+	flat := flatip.FromNetIP(addr)
217
+
210 218
 	cl.Lock()
211 219
 	defer cl.Unlock()
212 220
 
213
-	if !cl.config.Throttle || utils.IPInNets(addr, cl.config.exemptedNets) {
221
+	if !cl.config.Throttle || flatip.IPInNets(flat, cl.config.exemptedNets) {
214 222
 		return
215 223
 	}
216 224
 
217
-	addrString, _, _ := cl.addrToKey(addr)
225
+	addrString, _, _ := cl.addrToKey(flat)
218 226
 	delete(cl.throttler, addrString)
219 227
 }
220 228
 
@@ -224,10 +232,10 @@ func (cl *Limiter) ApplyConfig(config *LimiterConfig) {
224 232
 	defer cl.Unlock()
225 233
 
226 234
 	if cl.limiter == nil {
227
-		cl.limiter = make(map[string]int)
235
+		cl.limiter = make(map[limiterKey]int)
228 236
 	}
229 237
 	if cl.throttler == nil {
230
-		cl.throttler = make(map[string]ThrottleDetails)
238
+		cl.throttler = make(map[limiterKey]ThrottleDetails)
231 239
 	}
232 240
 
233 241
 	cl.config = config

+ 19
- 6
irc/connection_limits/limiter_test.go 파일 보기

@@ -4,9 +4,12 @@
4 4
 package connection_limits
5 5
 
6 6
 import (
7
+	"crypto/md5"
7 8
 	"net"
8 9
 	"testing"
9 10
 	"time"
11
+
12
+	"github.com/oragono/oragono/irc/flatip"
10 13
 )
11 14
 
12 15
 func easyParseIP(ipstr string) (result net.IP) {
@@ -17,6 +20,11 @@ func easyParseIP(ipstr string) (result net.IP) {
17 20
 	return
18 21
 }
19 22
 
23
+func easyParseFlat(ipstr string) (result flatip.IP) {
24
+	r1 := easyParseIP(ipstr)
25
+	return flatip.FromNetIP(r1)
26
+}
27
+
20 28
 var baseConfig = LimiterConfig{
21 29
 	rawLimiterConfig: rawLimiterConfig{
22 30
 		Count:         true,
@@ -47,18 +55,23 @@ func TestKeying(t *testing.T) {
47 55
 	var limiter Limiter
48 56
 	limiter.ApplyConfig(&config)
49 57
 
50
-	key, maxConc, maxWin := limiter.addrToKey(easyParseIP("1.1.1.1"))
51
-	assertEqual(key, "1.1.1.1/32", t)
58
+	// an ipv4 /32 looks like a /128 to us after applying the 4-in-6 mapping
59
+	key, maxConc, maxWin := limiter.addrToKey(easyParseFlat("1.1.1.1"))
60
+	assertEqual(key.prefixLen, uint8(128), t)
61
+	assertEqual(key.maskedIP[12:], []byte{1, 1, 1, 1}, t)
52 62
 	assertEqual(maxConc, 4, t)
53 63
 	assertEqual(maxWin, 8, t)
54 64
 
55
-	key, maxConc, maxWin = limiter.addrToKey(easyParseIP("2607:5301:201:3100::7426"))
56
-	assertEqual(key, "2607:5301:201:3100::/64", t)
65
+	testIPv6 := easyParseFlat("2607:5301:201:3100::7426")
66
+	key, maxConc, maxWin = limiter.addrToKey(testIPv6)
67
+	assertEqual(key.prefixLen, uint8(64), t)
68
+	assertEqual(key.maskedIP[:], []byte(easyParseIP("2607:5301:201:3100::")), t)
57 69
 	assertEqual(maxConc, 4, t)
58 70
 	assertEqual(maxWin, 8, t)
59 71
 
60
-	key, maxConc, maxWin = limiter.addrToKey(easyParseIP("8.8.4.4"))
61
-	assertEqual(key, "*google", t)
72
+	key, maxConc, maxWin = limiter.addrToKey(easyParseFlat("8.8.4.4"))
73
+	assertEqual(key.prefixLen, uint8(0), t)
74
+	assertEqual([16]byte(key.maskedIP), md5.Sum([]byte("google")), t)
62 75
 	assertEqual(maxConc, 128, t)
63 76
 	assertEqual(maxWin, 256, t)
64 77
 }

+ 29
- 45
irc/dline.go 파일 보기

@@ -11,6 +11,7 @@ import (
11 11
 	"sync"
12 12
 	"time"
13 13
 
14
+	"github.com/oragono/oragono/irc/flatip"
14 15
 	"github.com/oragono/oragono/irc/utils"
15 16
 	"github.com/tidwall/buntdb"
16 17
 )
@@ -54,34 +55,22 @@ func (info IPBanInfo) BanMessage(message string) string {
54 55
 	return message
55 56
 }
56 57
 
57
-// dLineNet contains the net itself and expiration time for a given network.
58
-type dLineNet struct {
59
-	// Network is the network that is blocked.
60
-	// This is always an IPv6 CIDR; IPv4 CIDRs are translated with the 4-in-6 prefix,
61
-	// individual IPv4 and IPV6 addresses are translated to the relevant /128.
62
-	Network net.IPNet
63
-	// Info contains information on the ban.
64
-	Info IPBanInfo
65
-}
66
-
67 58
 // DLineManager manages and dlines.
68 59
 type DLineManager struct {
69 60
 	sync.RWMutex                // tier 1
70 61
 	persistenceMutex sync.Mutex // tier 2
71 62
 	// networks that are dlined:
72
-	// XXX: the keys of this map (which are also the database persistence keys)
73
-	// are the human-readable representations returned by NetToNormalizedString
74
-	networks map[string]dLineNet
63
+	networks map[flatip.IPNet]IPBanInfo
75 64
 	// this keeps track of expiration timers for temporary bans
76
-	expirationTimers map[string]*time.Timer
65
+	expirationTimers map[flatip.IPNet]*time.Timer
77 66
 	server           *Server
78 67
 }
79 68
 
80 69
 // NewDLineManager returns a new DLineManager.
81 70
 func NewDLineManager(server *Server) *DLineManager {
82 71
 	var dm DLineManager
83
-	dm.networks = make(map[string]dLineNet)
84
-	dm.expirationTimers = make(map[string]*time.Timer)
72
+	dm.networks = make(map[flatip.IPNet]IPBanInfo)
73
+	dm.expirationTimers = make(map[flatip.IPNet]*time.Timer)
85 74
 	dm.server = server
86 75
 
87 76
 	dm.loadFromDatastore()
@@ -96,9 +85,8 @@ func (dm *DLineManager) AllBans() map[string]IPBanInfo {
96 85
 	dm.RLock()
97 86
 	defer dm.RUnlock()
98 87
 
99
-	// map keys are already the human-readable forms, just return a copy of the map
100 88
 	for key, info := range dm.networks {
101
-		allb[key] = info.Info
89
+		allb[key.String()] = info
102 90
 	}
103 91
 
104 92
 	return allb
@@ -122,9 +110,9 @@ func (dm *DLineManager) AddNetwork(network net.IPNet, duration time.Duration, re
122 110
 	return dm.persistDline(id, info)
123 111
 }
124 112
 
125
-func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id string) {
126
-	network = utils.NormalizeNet(network)
127
-	id = utils.NetToNormalizedString(network)
113
+func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id flatip.IPNet) {
114
+	flatnet := flatip.FromNetIPNet(network)
115
+	id = flatnet
128 116
 
129 117
 	var timeLeft time.Duration
130 118
 	if info.Duration != 0 {
@@ -137,12 +125,9 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i
137 125
 	dm.Lock()
138 126
 	defer dm.Unlock()
139 127
 
140
-	dm.networks[id] = dLineNet{
141
-		Network: network,
142
-		Info:    info,
143
-	}
128
+	dm.networks[flatnet] = info
144 129
 
145
-	dm.cancelTimer(id)
130
+	dm.cancelTimer(flatnet)
146 131
 
147 132
 	if info.Duration == 0 {
148 133
 		return
@@ -154,29 +139,29 @@ func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (i
154 139
 		dm.Lock()
155 140
 		defer dm.Unlock()
156 141
 
157
-		netBan, ok := dm.networks[id]
158
-		if ok && netBan.Info.TimeCreated.Equal(timeCreated) {
159
-			delete(dm.networks, id)
142
+		banInfo, ok := dm.networks[flatnet]
143
+		if ok && banInfo.TimeCreated.Equal(timeCreated) {
144
+			delete(dm.networks, flatnet)
160 145
 			// TODO(slingamn) here's where we'd remove it from the radix tree
161
-			delete(dm.expirationTimers, id)
146
+			delete(dm.expirationTimers, flatnet)
162 147
 		}
163 148
 	}
164
-	dm.expirationTimers[id] = time.AfterFunc(timeLeft, processExpiration)
149
+	dm.expirationTimers[flatnet] = time.AfterFunc(timeLeft, processExpiration)
165 150
 
166 151
 	return
167 152
 }
168 153
 
169
-func (dm *DLineManager) cancelTimer(id string) {
170
-	oldTimer := dm.expirationTimers[id]
154
+func (dm *DLineManager) cancelTimer(flatnet flatip.IPNet) {
155
+	oldTimer := dm.expirationTimers[flatnet]
171 156
 	if oldTimer != nil {
172 157
 		oldTimer.Stop()
173
-		delete(dm.expirationTimers, id)
158
+		delete(dm.expirationTimers, flatnet)
174 159
 	}
175 160
 }
176 161
 
177
-func (dm *DLineManager) persistDline(id string, info IPBanInfo) error {
162
+func (dm *DLineManager) persistDline(id flatip.IPNet, info IPBanInfo) error {
178 163
 	// save in datastore
179
-	dlineKey := fmt.Sprintf(keyDlineEntry, id)
164
+	dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
180 165
 	// assemble json from ban info
181 166
 	b, err := json.Marshal(info)
182 167
 	if err != nil {
@@ -199,8 +184,8 @@ func (dm *DLineManager) persistDline(id string, info IPBanInfo) error {
199 184
 	return err
200 185
 }
201 186
 
202
-func (dm *DLineManager) unpersistDline(id string) error {
203
-	dlineKey := fmt.Sprintf(keyDlineEntry, id)
187
+func (dm *DLineManager) unpersistDline(id flatip.IPNet) error {
188
+	dlineKey := fmt.Sprintf(keyDlineEntry, id.String())
204 189
 	return dm.server.store.Update(func(tx *buntdb.Tx) error {
205 190
 		_, err := tx.Delete(dlineKey)
206 191
 		return err
@@ -212,7 +197,7 @@ func (dm *DLineManager) RemoveNetwork(network net.IPNet) error {
212 197
 	dm.persistenceMutex.Lock()
213 198
 	defer dm.persistenceMutex.Unlock()
214 199
 
215
-	id := utils.NetToNormalizedString(utils.NormalizeNet(network))
200
+	id := flatip.FromNetIPNet(network)
216 201
 
217 202
 	present := func() bool {
218 203
 		dm.Lock()
@@ -241,8 +226,8 @@ func (dm *DLineManager) RemoveIP(addr net.IP) error {
241 226
 }
242 227
 
243 228
 // CheckIP returns whether or not an IP address was banned, and how long it is banned for.
244
-func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) {
245
-	addr = addr.To16() // almost certainly unnecessary
229
+func (dm *DLineManager) CheckIP(netAddr net.IP) (isBanned bool, info IPBanInfo) {
230
+	addr := flatip.FromNetIP(netAddr)
246 231
 	if addr.IsLoopback() {
247 232
 		return // #671
248 233
 	}
@@ -252,13 +237,12 @@ func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) {
252 237
 
253 238
 	// check networks
254 239
 	// TODO(slingamn) use a radix tree as the data plane for this
255
-	for _, netBan := range dm.networks {
256
-		if netBan.Network.Contains(addr) {
257
-			return true, netBan.Info
240
+	for flatnet, info := range dm.networks {
241
+		if flatnet.Contains(addr) {
242
+			return true, info
258 243
 		}
259 244
 	}
260 245
 	// no matches!
261
-	isBanned = false
262 246
 	return
263 247
 }
264 248
 

+ 217
- 0
irc/flatip/flatip.go 파일 보기

@@ -0,0 +1,217 @@
1
+// Copyright 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
2
+// Copyright 2009 The Go Authors
3
+
4
+package flatip
5
+
6
+import (
7
+	"bytes"
8
+	"errors"
9
+	"net"
10
+)
11
+
12
+var (
13
+	v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
14
+
15
+	IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
16
+
17
+	ErrInvalidIPString = errors.New("String could not be interpreted as an IP address")
18
+)
19
+
20
+// packed versions of net.IP and net.IPNet; these are pure value types,
21
+// so they can be compared with == and used as map keys.
22
+
23
+// IP is the 128-bit representation of the IPv6 address, using the 4-in-6 mapping
24
+// if necessary:
25
+type IP [16]byte
26
+
27
+// IPNet is a IP network. In a valid value, all bits after PrefixLen are zeroes.
28
+type IPNet struct {
29
+	IP
30
+	PrefixLen uint8
31
+}
32
+
33
+// NetIP converts an IP into a net.IP.
34
+func (ip IP) NetIP() (result net.IP) {
35
+	result = make(net.IP, 16)
36
+	copy(result[:], ip[:])
37
+	return
38
+}
39
+
40
+// FromNetIP converts a net.IP into an IP.
41
+func FromNetIP(ip net.IP) (result IP) {
42
+	if len(ip) == 16 {
43
+		copy(result[:], ip[:])
44
+	} else {
45
+		result[10] = 0xff
46
+		result[11] = 0xff
47
+		copy(result[12:], ip[:])
48
+	}
49
+	return
50
+}
51
+
52
+// IPv4 returns the IP address representation of a.b.c.d
53
+func IPv4(a, b, c, d byte) (result IP) {
54
+	copy(result[:12], v4InV6Prefix)
55
+	result[12] = a
56
+	result[13] = b
57
+	result[14] = c
58
+	result[15] = d
59
+	return
60
+}
61
+
62
+// ParseIP parses a string representation of an IP address into an IP.
63
+// Unlike net.ParseIP, it returns an error instead of a zero value on failure,
64
+// since the zero value of `IP` is a representation of a valid IP (::0, the
65
+// IPv6 "unspecified address").
66
+func ParseIP(ipstr string) (ip IP, err error) {
67
+	// TODO reimplement this without net.ParseIP
68
+	netip := net.ParseIP(ipstr)
69
+	if netip == nil {
70
+		err = ErrInvalidIPString
71
+		return
72
+	}
73
+	netip = netip.To16()
74
+	copy(ip[:], netip)
75
+	return
76
+}
77
+
78
+// String returns the string representation of an IP
79
+func (ip IP) String() string {
80
+	// TODO reimplement this without using (net.IP).String()
81
+	return (net.IP)(ip[:]).String()
82
+}
83
+
84
+// IsIPv4 returns whether the IP is an IPv4 address.
85
+func (ip IP) IsIPv4() bool {
86
+	return bytes.Equal(ip[:12], v4InV6Prefix)
87
+}
88
+
89
+// IsLoopback returns whether the IP is a loopback address.
90
+func (ip IP) IsLoopback() bool {
91
+	if ip.IsIPv4() {
92
+		return ip[12] == 127
93
+	} else {
94
+		return ip == IPv6loopback
95
+	}
96
+}
97
+
98
+func rawCidrMask(length int) (m IP) {
99
+	n := uint(length)
100
+	for i := 0; i < 16; i++ {
101
+		if n >= 8 {
102
+			m[i] = 0xff
103
+			n -= 8
104
+			continue
105
+		}
106
+		m[i] = ^byte(0xff >> n)
107
+		return
108
+	}
109
+	return
110
+}
111
+
112
+func (ip IP) applyMask(mask IP) (result IP) {
113
+	for i := 0; i < 16; i += 1 {
114
+		result[i] = ip[i] & mask[i]
115
+	}
116
+	return
117
+}
118
+
119
+func cidrMask(ones, bits int) (result IP) {
120
+	switch bits {
121
+	case 32:
122
+		return rawCidrMask(96 + ones)
123
+	case 128:
124
+		return rawCidrMask(ones)
125
+	default:
126
+		return
127
+	}
128
+}
129
+
130
+// Mask returns the result of masking ip with the CIDR mask of
131
+// length 'ones', out of a total of 'bits' (which must be either
132
+// 32 for an IPv4 subnet or 128 for an IPv6 subnet).
133
+func (ip IP) Mask(ones, bits int) (result IP) {
134
+	return ip.applyMask(cidrMask(ones, bits))
135
+}
136
+
137
+// ToNetIPNet converts an IPNet into a net.IPNet.
138
+func (cidr IPNet) ToNetIPNet() (result net.IPNet) {
139
+	return net.IPNet{
140
+		IP:   cidr.IP.NetIP(),
141
+		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
142
+	}
143
+}
144
+
145
+// Contains retuns whether the network contains `ip`.
146
+func (cidr IPNet) Contains(ip IP) bool {
147
+	maskedIP := ip.Mask(int(cidr.PrefixLen), 128)
148
+	return cidr.IP == maskedIP
149
+}
150
+
151
+// FromNetIPnet converts a net.IPNet into an IPNet.
152
+func FromNetIPNet(network net.IPNet) (result IPNet) {
153
+	ones, _ := network.Mask.Size()
154
+	if len(network.IP) == 16 {
155
+		copy(result.IP[:], network.IP[:])
156
+	} else {
157
+		result.IP[10] = 0xff
158
+		result.IP[11] = 0xff
159
+		copy(result.IP[12:], network.IP[:])
160
+		ones += 96
161
+	}
162
+	// perform masking so that equal CIDRs are ==
163
+	result.IP = result.IP.Mask(ones, 128)
164
+	result.PrefixLen = uint8(ones)
165
+	return
166
+}
167
+
168
+// String returns a string representation of an IPNet.
169
+func (cidr IPNet) String() string {
170
+	ip := make(net.IP, 16)
171
+	copy(ip[:], cidr.IP[:])
172
+	ipnet := net.IPNet{
173
+		IP:   ip,
174
+		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
175
+	}
176
+	return ipnet.String()
177
+}
178
+
179
+// ParseCIDR parses a string representation of an IP network in CIDR notation,
180
+// then returns it as an IPNet (along with the original, unmasked address).
181
+func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) {
182
+	// TODO reimplement this without net.ParseCIDR
183
+	nip, nipnet, err := net.ParseCIDR(netstr)
184
+	if err != nil {
185
+		return
186
+	}
187
+	return FromNetIP(nip), FromNetIPNet(*nipnet), nil
188
+}
189
+
190
+// begin ad-hoc utilities
191
+
192
+// ParseToNormalizedNet attempts to interpret a string either as an IP
193
+// network in CIDR notation, returning an IPNet, or as an IP address,
194
+// returning an IPNet that contains only that address.
195
+func ParseToNormalizedNet(netstr string) (ipnet IPNet, err error) {
196
+	_, ipnet, err = ParseCIDR(netstr)
197
+	if err == nil {
198
+		return
199
+	}
200
+	ip, err := ParseIP(netstr)
201
+	if err == nil {
202
+		ipnet.IP = ip
203
+		ipnet.PrefixLen = 128
204
+	}
205
+	return
206
+}
207
+
208
+// IPInNets is a convenience function for testing whether an IP is contained
209
+// in any member of a slice of IPNet's.
210
+func IPInNets(addr IP, nets []IPNet) bool {
211
+	for _, net := range nets {
212
+		if net.Contains(addr) {
213
+			return true
214
+		}
215
+	}
216
+	return false
217
+}

+ 174
- 0
irc/flatip/flatip_test.go 파일 보기

@@ -0,0 +1,174 @@
1
+package flatip
2
+
3
+import (
4
+	"bytes"
5
+	"math/rand"
6
+	"net"
7
+	"testing"
8
+	"time"
9
+)
10
+
11
+func easyParseIP(ipstr string) (result net.IP) {
12
+	result = net.ParseIP(ipstr)
13
+	if result == nil {
14
+		panic(ipstr)
15
+	}
16
+	return
17
+}
18
+
19
+func easyParseFlat(ipstr string) (result IP) {
20
+	x := easyParseIP(ipstr)
21
+	return FromNetIP(x)
22
+}
23
+
24
+func easyParseIPNet(nipstr string) (result net.IPNet) {
25
+	_, nip, err := net.ParseCIDR(nipstr)
26
+	if err != nil {
27
+		panic(err)
28
+	}
29
+	return *nip
30
+}
31
+
32
+func TestBasic(t *testing.T) {
33
+	nip := easyParseIP("8.8.8.8")
34
+	flatip := FromNetIP(nip)
35
+	if flatip.String() != "8.8.8.8" {
36
+		t.Errorf("conversions don't work")
37
+	}
38
+}
39
+
40
+func TestLoopback(t *testing.T) {
41
+	localhost_v4 := easyParseFlat("127.0.0.1")
42
+	localhost_v4_again := easyParseFlat("127.2.3.4")
43
+	google := easyParseFlat("8.8.8.8")
44
+	loopback_v6 := easyParseFlat("::1")
45
+	google_v6 := easyParseFlat("2607:f8b0:4006:801::2004")
46
+
47
+	if !(localhost_v4.IsLoopback() && localhost_v4_again.IsLoopback() && loopback_v6.IsLoopback()) {
48
+		t.Errorf("can't detect loopbacks")
49
+	}
50
+
51
+	if google_v6.IsLoopback() || google.IsLoopback() {
52
+		t.Errorf("incorrectly detected loopbacks")
53
+	}
54
+}
55
+
56
+func TestContains(t *testing.T) {
57
+	nipnet := easyParseIPNet("8.8.0.0/16")
58
+	flatipnet := FromNetIPNet(nipnet)
59
+	nip := easyParseIP("8.8.8.8")
60
+	flatip_ := FromNetIP(nip)
61
+	if !flatipnet.Contains(flatip_) {
62
+		t.Errorf("contains doesn't work")
63
+	}
64
+}
65
+
66
+var testIPStrs = []string{
67
+	"8.8.8.8",
68
+	"127.0.0.1",
69
+	"1.1.1.1",
70
+	"128.127.65.64",
71
+	"2001:0db8::1",
72
+	"::1",
73
+	"255.255.255.255",
74
+}
75
+
76
+func doMaskingTest(ip net.IP, t *testing.T) {
77
+	flat := FromNetIP(ip)
78
+	netLen := len(ip) * 8
79
+	for i := 0; i < netLen; i++ {
80
+		masked := flat.Mask(i, netLen)
81
+		netMask := net.CIDRMask(i, netLen)
82
+		netMasked := ip.Mask(netMask)
83
+		if !bytes.Equal(masked[:], netMasked.To16()) {
84
+			t.Errorf("Masking %s with %d/%d; expected %s, got %s", ip.String(), i, netLen, netMasked.String(), masked.String())
85
+		}
86
+	}
87
+}
88
+
89
+func TestMasking(t *testing.T) {
90
+	for _, ipstr := range testIPStrs {
91
+		doMaskingTest(easyParseIP(ipstr), t)
92
+	}
93
+}
94
+
95
+func TestMaskingFuzz(t *testing.T) {
96
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
97
+	buf := make([]byte, 4)
98
+	for i := 0; i < 10000; i++ {
99
+		r.Read(buf)
100
+		doMaskingTest(net.IP(buf), t)
101
+	}
102
+
103
+	buf = make([]byte, 16)
104
+	for i := 0; i < 10000; i++ {
105
+		r.Read(buf)
106
+		doMaskingTest(net.IP(buf), t)
107
+	}
108
+}
109
+
110
+func BenchmarkMasking(b *testing.B) {
111
+	ip := easyParseIP("2001:0db8::42")
112
+	flat := FromNetIP(ip)
113
+	b.ResetTimer()
114
+
115
+	for i := 0; i < b.N; i++ {
116
+		flat.Mask(64, 128)
117
+	}
118
+}
119
+
120
+func BenchmarkMaskingLegacy(b *testing.B) {
121
+	ip := easyParseIP("2001:0db8::42")
122
+	mask := net.CIDRMask(64, 128)
123
+	b.ResetTimer()
124
+
125
+	for i := 0; i < b.N; i++ {
126
+		ip.Mask(mask)
127
+	}
128
+}
129
+
130
+func BenchmarkMaskingCached(b *testing.B) {
131
+	i := easyParseIP("2001:0db8::42")
132
+	flat := FromNetIP(i)
133
+	mask := cidrMask(64, 128)
134
+	b.ResetTimer()
135
+
136
+	for i := 0; i < b.N; i++ {
137
+		flat.applyMask(mask)
138
+	}
139
+}
140
+
141
+func BenchmarkMaskingConstruct(b *testing.B) {
142
+	for i := 0; i < b.N; i++ {
143
+		cidrMask(69, 128)
144
+	}
145
+}
146
+
147
+func BenchmarkContains(b *testing.B) {
148
+	ip := easyParseIP("2001:0db8::42")
149
+	flat := FromNetIP(ip)
150
+	_, ipnet, err := net.ParseCIDR("2001:0db8::/64")
151
+	if err != nil {
152
+		panic(err)
153
+	}
154
+	flatnet := FromNetIPNet(*ipnet)
155
+	b.ResetTimer()
156
+
157
+	for i := 0; i < b.N; i++ {
158
+		flatnet.Contains(flat)
159
+	}
160
+}
161
+
162
+func BenchmarkContainsLegacy(b *testing.B) {
163
+	ip := easyParseIP("2001:0db8::42")
164
+	_, ipnetptr, err := net.ParseCIDR("2001:0db8::/64")
165
+	if err != nil {
166
+		panic(err)
167
+	}
168
+	ipnet := *ipnetptr
169
+	b.ResetTimer()
170
+
171
+	for i := 0; i < b.N; i++ {
172
+		ipnet.Contains(ip)
173
+	}
174
+}

Loading…
취소
저장