You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

throttler.go 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. // Copyright (c) 2016-2017 Daniel Oaks <daniel@danieloaks.net>
  2. // released under the MIT license
  3. package connection_limits
  4. import (
  5. "fmt"
  6. "net"
  7. "sync"
  8. "time"
  9. )
// ThrottlerConfig controls the automated connection throttling.
//
// ApplyConfig reads Enabled, CidrLenIPv4, CidrLenIPv6, ConnectionsPerCidr,
// Duration, BanDuration, BanMessage and Exempted. DurationString and
// BanDurationString are not read in this file; presumably the config loader
// parses them into Duration and BanDuration (TODO confirm against the caller).
type ThrottlerConfig struct {
	// Enabled turns throttling on; while false, every client is admitted.
	Enabled bool
	// CidrLenIPv4 and CidrLenIPv6 are the prefix lengths used to group client
	// addresses into shared buckets (e.g. 24 groups a whole IPv4 /24).
	CidrLenIPv4 int `yaml:"cidr-len-ipv4"`
	CidrLenIPv6 int `yaml:"cidr-len-ipv6"`
	// ConnectionsPerCidr is the number of connections allowed per bucket
	// within one Duration window; 0 disables the limit.
	ConnectionsPerCidr int    `yaml:"max-connections"`
	DurationString     string `yaml:"duration"`
	// Duration is the length of the counting window.
	Duration          time.Duration `yaml:"duration-time"`
	BanDurationString string        `yaml:"ban-duration"`
	// BanDuration and BanMessage are not used by the throttling logic itself;
	// the Throttler only stores them and hands them back via its getters.
	BanDuration time.Duration
	BanMessage  string `yaml:"ban-message"`
	// Exempted lists single IPs and CIDR networks (e.g. "127.0.0.1",
	// "10.0.0.0/8") whose clients are never throttled.
	Exempted []string
}
// ThrottleDetails holds the connection-throttling details for a subnet/IP.
// It is the mutable part of a GenericThrottle, and the value type the
// Throttler stores per address key.
type ThrottleDetails struct {
	// Start is when the current counting window began.
	Start time.Time
	// Count is the number of events recorded in the current window.
	Count int
}
// GenericThrottle allows enforcing limits of the form
// "at most X events per time window of duration Y".
//
// The window is fixed, not sliding: it opens at the first event after the
// previous window has expired, and the count restarts from that event.
type GenericThrottle struct {
	ThrottleDetails // variable state: what events have been seen

	// these are constant after creation:
	Duration time.Duration // window length to consider
	Limit    int           // number of events allowed per window; 0 disables throttling
}
  36. // Touch checks whether an additional event is allowed:
  37. // it either denies it (by returning false) or allows it (by returning true)
  38. // and records it
  39. func (g *GenericThrottle) Touch() (throttled bool, remainingTime time.Duration) {
  40. return g.touch(time.Now())
  41. }
  42. func (g *GenericThrottle) touch(now time.Time) (throttled bool, remainingTime time.Duration) {
  43. if g.Limit == 0 {
  44. return // limit of 0 disables throttling
  45. }
  46. elapsed := now.Sub(g.Start)
  47. if elapsed > g.Duration {
  48. // reset window, record the operation
  49. g.Start = now
  50. g.Count = 1
  51. return false, 0
  52. } else if g.Count >= g.Limit {
  53. // we are throttled
  54. return true, g.Start.Add(g.Duration).Sub(now)
  55. } else {
  56. // we are not throttled, record the operation
  57. g.Count += 1
  58. return false, 0
  59. }
  60. }
// Throttler manages automated client connection throttling.
//
// The embedded RWMutex guards every field below. The exported methods take
// the lock themselves; maskAddr expects its caller to already hold it.
// Because of the mutex, a Throttler must not be copied. Construct one with
// NewThrottler, which also allocates population (AddClient writes to that
// map, and writing to a nil map panics).
type Throttler struct {
	sync.RWMutex

	// enabled gates everything: while false, AddClient admits all clients
	// and ResetFor does nothing.
	enabled bool
	// ipv4Mask and ipv6Mask are the masks maskAddr applies, by address
	// family, to group client addresses into subnets.
	ipv4Mask net.IPMask
	ipv6Mask net.IPMask
	// subnetLimit is the number of connections allowed per key per window;
	// 0 disables the limit.
	subnetLimit int
	// duration is the length of each counting window.
	duration time.Duration
	// population maps the address key computed by AddClient to its throttle
	// state.
	// NOTE(review): in this file entries are only removed by ResetFor, so the
	// map grows with the number of distinct keys seen.
	population map[string]ThrottleDetails

	// used by the server to ban clients that go over this limit
	banDuration time.Duration
	banMessage  string

	// exemptedIPs holds IPs that are exempt from limits
	exemptedIPs map[string]bool
	// exemptedNets holds networks that are exempt from limits
	exemptedNets []net.IPNet
}
  78. // maskAddr masks the given IPv4/6 address with our cidr limit masks.
  79. func (ct *Throttler) maskAddr(addr net.IP) net.IP {
  80. if addr.To4() == nil {
  81. // IPv6 addr
  82. addr = addr.Mask(ct.ipv6Mask)
  83. } else {
  84. // IPv4 addr
  85. addr = addr.Mask(ct.ipv4Mask)
  86. }
  87. return addr
  88. }
  89. // ResetFor removes any existing count for the given address.
  90. func (ct *Throttler) ResetFor(addr net.IP) {
  91. ct.Lock()
  92. defer ct.Unlock()
  93. if !ct.enabled {
  94. return
  95. }
  96. // remove
  97. ct.maskAddr(addr)
  98. addrString := addr.String()
  99. delete(ct.population, addrString)
  100. }
  101. // AddClient introduces a new client connection if possible. If we can't, throws an error instead.
  102. func (ct *Throttler) AddClient(addr net.IP) error {
  103. ct.Lock()
  104. defer ct.Unlock()
  105. if !ct.enabled {
  106. return nil
  107. }
  108. // check exempted lists
  109. if ct.exemptedIPs[addr.String()] {
  110. return nil
  111. }
  112. for _, ex := range ct.exemptedNets {
  113. if ex.Contains(addr) {
  114. return nil
  115. }
  116. }
  117. // check throttle
  118. ct.maskAddr(addr)
  119. addrString := addr.String()
  120. details := ct.population[addrString] // retrieve mutable throttle state from the map
  121. // add in constant state to process the limiting operation
  122. g := GenericThrottle{
  123. ThrottleDetails: details,
  124. Duration: ct.duration,
  125. Limit: ct.subnetLimit,
  126. }
  127. throttled, _ := g.Touch() // actually check the limit
  128. ct.population[addrString] = g.ThrottleDetails // store modified mutable state
  129. if throttled {
  130. return errTooManyClients
  131. } else {
  132. return nil
  133. }
  134. }
  135. func (ct *Throttler) BanDuration() time.Duration {
  136. ct.RLock()
  137. defer ct.RUnlock()
  138. return ct.banDuration
  139. }
  140. func (ct *Throttler) BanMessage() string {
  141. ct.RLock()
  142. defer ct.RUnlock()
  143. return ct.banMessage
  144. }
  145. // NewThrottler returns a new client connection throttler.
  146. // The throttler is functional, but disabled; it can be enabled via `ApplyConfig`.
  147. func NewThrottler() *Throttler {
  148. var ct Throttler
  149. // initialize empty population; all other state is configurable
  150. ct.population = make(map[string]ThrottleDetails)
  151. return &ct
  152. }
  153. // ApplyConfig atomically applies a config update to a throttler
  154. func (ct *Throttler) ApplyConfig(config ThrottlerConfig) error {
  155. // assemble exempted nets
  156. exemptedIPs := make(map[string]bool)
  157. var exemptedNets []net.IPNet
  158. for _, cidr := range config.Exempted {
  159. ipaddr := net.ParseIP(cidr)
  160. _, netaddr, err := net.ParseCIDR(cidr)
  161. if ipaddr == nil && err != nil {
  162. return fmt.Errorf("Could not parse exempted IP/network [%s]", cidr)
  163. }
  164. if ipaddr != nil {
  165. exemptedIPs[ipaddr.String()] = true
  166. } else {
  167. exemptedNets = append(exemptedNets, *netaddr)
  168. }
  169. }
  170. ct.Lock()
  171. defer ct.Unlock()
  172. ct.enabled = config.Enabled
  173. ct.ipv4Mask = net.CIDRMask(config.CidrLenIPv4, 32)
  174. ct.ipv6Mask = net.CIDRMask(config.CidrLenIPv6, 128)
  175. ct.subnetLimit = config.ConnectionsPerCidr
  176. ct.duration = config.Duration
  177. ct.banDuration = config.BanDuration
  178. ct.banMessage = config.BanMessage
  179. ct.exemptedIPs = exemptedIPs
  180. ct.exemptedNets = exemptedNets
  181. return nil
  182. }