123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- // Copyright (c) 2018 Shivaram Lingamneni
- // released under the MIT license
-
- package connection_limits
-
- import (
- "net"
- "reflect"
- "testing"
- "time"
- )
-
// assertEqual reports a non-fatal test error via t.Errorf when supplied and
// expected are not deeply equal according to reflect.DeepEqual. The test keeps
// running after a mismatch, so later assertions are still evaluated.
//
// Note that reflect.DeepEqual compares dynamic types as well as values, so
// callers must pass expected values of the same type as supplied (for example
// a typed time.Duration rather than an untyped 0).
func assertEqual(supplied, expected interface{}, t *testing.T) {
	// Mark this as a helper so failures are attributed to the caller's line
	// rather than to this function.
	t.Helper()
	if !reflect.DeepEqual(supplied, expected) {
		t.Errorf("expected %v but got %v", expected, supplied)
	}
}
-
- func TestGenericThrottle(t *testing.T) {
- minute, _ := time.ParseDuration("1m")
- second, _ := time.ParseDuration("1s")
- zero, _ := time.ParseDuration("0s")
-
- throttler := GenericThrottle{
- Duration: minute,
- Limit: 2,
- }
-
- now := time.Now()
- throttled, remaining := throttler.touch(now)
- assertEqual(throttled, false, t)
- assertEqual(remaining, zero, t)
-
- now = now.Add(second)
- throttled, remaining = throttler.touch(now)
- assertEqual(throttled, false, t)
- assertEqual(remaining, zero, t)
-
- now = now.Add(second)
- throttled, remaining = throttler.touch(now)
- assertEqual(throttled, true, t)
- assertEqual(remaining, 58*second, t)
-
- now = now.Add(minute)
- throttled, remaining = throttler.touch(now)
- assertEqual(throttled, false, t)
- assertEqual(remaining, zero, t)
- }
-
- func TestGenericThrottleDisabled(t *testing.T) {
- minute, _ := time.ParseDuration("1m")
- throttler := GenericThrottle{
- Duration: minute,
- Limit: 0,
- }
-
- for i := 0; i < 1024; i += 1 {
- throttled, _ := throttler.Touch()
- if throttled {
- t.Error("disabled throttler should not throttle")
- }
- }
- }
-
- func makeTestThrottler(v4len, v6len int) *Throttler {
- minute, _ := time.ParseDuration("1m")
- maxConnections := 3
- config := ThrottlerConfig{
- Enabled: true,
- CidrLenIPv4: v4len,
- CidrLenIPv6: v6len,
- ConnectionsPerCidr: maxConnections,
- Duration: minute,
- }
- var throttler Throttler
- throttler.ApplyConfig(config)
- return &throttler
- }
-
- func TestConnectionThrottle(t *testing.T) {
- throttler := makeTestThrottler(32, 64)
- addr := net.ParseIP("8.8.8.8")
-
- for i := 0; i < 3; i += 1 {
- err := throttler.AddClient(addr)
- assertEqual(err, nil, t)
- }
- err := throttler.AddClient(addr)
- assertEqual(err, errTooManyClients, t)
- }
-
- func TestConnectionThrottleIPv6(t *testing.T) {
- throttler := makeTestThrottler(32, 64)
-
- var err error
- err = throttler.AddClient(net.ParseIP("2001:0db8::1"))
- assertEqual(err, nil, t)
- err = throttler.AddClient(net.ParseIP("2001:0db8::2"))
- assertEqual(err, nil, t)
- err = throttler.AddClient(net.ParseIP("2001:0db8::3"))
- assertEqual(err, nil, t)
-
- err = throttler.AddClient(net.ParseIP("2001:0db8::4"))
- assertEqual(err, errTooManyClients, t)
- }
-
- func TestConnectionThrottleIPv4(t *testing.T) {
- throttler := makeTestThrottler(24, 64)
-
- var err error
- err = throttler.AddClient(net.ParseIP("192.168.1.101"))
- assertEqual(err, nil, t)
- err = throttler.AddClient(net.ParseIP("192.168.1.102"))
- assertEqual(err, nil, t)
- err = throttler.AddClient(net.ParseIP("192.168.1.103"))
- assertEqual(err, nil, t)
-
- err = throttler.AddClient(net.ParseIP("192.168.1.104"))
- assertEqual(err, errTooManyClients, t)
- }
|