// Copyright (c) 2016-2017 Daniel Oaks // released under the MIT license package connection_limits import ( "fmt" "net" "sync" "time" ) // ThrottlerConfig controls the automated connection throttling. type ThrottlerConfig struct { Enabled bool CidrLenIPv4 int `yaml:"cidr-len-ipv4"` CidrLenIPv6 int `yaml:"cidr-len-ipv6"` ConnectionsPerCidr int `yaml:"max-connections"` DurationString string `yaml:"duration"` Duration time.Duration `yaml:"duration-time"` BanDurationString string `yaml:"ban-duration"` BanDuration time.Duration BanMessage string `yaml:"ban-message"` Exempted []string } // ThrottleDetails holds the connection-throttling details for a subnet/IP. type ThrottleDetails struct { Start time.Time ClientCount int } // Throttler manages automated client connection throttling. type Throttler struct { sync.RWMutex enabled bool ipv4Mask net.IPMask ipv6Mask net.IPMask subnetLimit int duration time.Duration population map[string]ThrottleDetails // used by the server to ban clients that go over this limit banDuration time.Duration banMessage string // exemptedIPs holds IPs that are exempt from limits exemptedIPs map[string]bool // exemptedNets holds networks that are exempt from limits exemptedNets []net.IPNet } // maskAddr masks the given IPv4/6 address with our cidr limit masks. func (ct *Throttler) maskAddr(addr net.IP) net.IP { if addr.To4() == nil { // IPv6 addr addr = addr.Mask(ct.ipv6Mask) } else { // IPv4 addr addr = addr.Mask(ct.ipv4Mask) } return addr } // ResetFor removes any existing count for the given address. func (ct *Throttler) ResetFor(addr net.IP) { ct.Lock() defer ct.Unlock() if !ct.enabled { return } // remove ct.maskAddr(addr) addrString := addr.String() delete(ct.population, addrString) } // AddClient introduces a new client connection if possible. If we can't, throws an error instead. func (ct *Throttler) AddClient(addr net.IP) error { ct.Lock() defer ct.Unlock() if !ct.enabled { return nil } // check exempted lists if ct.exemptedIPs[addr.String()] { return nil } for _, ex := range ct.exemptedNets { if ex.Contains(addr) { return nil } } // check throttle ct.maskAddr(addr) addrString := addr.String() details, exists := ct.population[addrString] if !exists || details.Start.Add(ct.duration).Before(time.Now()) { details = ThrottleDetails{ Start: time.Now(), } } if details.ClientCount+1 > ct.subnetLimit { return errTooManyClients } details.ClientCount++ ct.population[addrString] = details return nil } func (ct *Throttler) BanDuration() time.Duration { ct.RLock() defer ct.RUnlock() return ct.banDuration } func (ct *Throttler) BanMessage() string { ct.RLock() defer ct.RUnlock() return ct.banMessage } // NewThrottler returns a new client connection throttler. // The throttler is functional, but disabled; it can be enabled via `ApplyConfig`. func NewThrottler() *Throttler { var ct Throttler // initialize empty population; all other state is configurable ct.population = make(map[string]ThrottleDetails) return &ct } // ApplyConfig atomically applies a config update to a throttler func (ct *Throttler) ApplyConfig(config ThrottlerConfig) error { // assemble exempted nets exemptedIPs := make(map[string]bool) var exemptedNets []net.IPNet for _, cidr := range config.Exempted { ipaddr := net.ParseIP(cidr) _, netaddr, err := net.ParseCIDR(cidr) if ipaddr == nil && err != nil { return fmt.Errorf("Could not parse exempted IP/network [%s]", cidr) } if ipaddr != nil { exemptedIPs[ipaddr.String()] = true } else { exemptedNets = append(exemptedNets, *netaddr) } } ct.Lock() defer ct.Unlock() ct.enabled = config.Enabled ct.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32) ct.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128) ct.subnetLimit = config.ConnectionsPerCidr ct.duration = config.Duration ct.banDuration = config.BanDuration ct.banMessage = config.BanMessage ct.exemptedIPs = exemptedIPs ct.exemptedNets = exemptedNets return nil }