123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- // Copyright (c) 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
- // released under the MIT license
-
- package utils
-
- import (
- "crypto/tls"
- "encoding/binary"
- "io"
- "net"
- "strings"
- "sync/atomic"
- "time"
- )
-
- const (
- // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
- // "a 108-byte buffer is always enough to store all the line and a trailing zero
- // for string processing."
- maxProxyLineLenV1 = 107
- )
-
- // XXX implement net.Error with a Temporary() method that returns true;
- // otherwise, ErrBadProxyLine will cause (*http.Server).Serve() to exit
- type proxyLineError struct{}
-
- func (p *proxyLineError) Error() string {
- return "invalid PROXY line"
- }
-
- func (p *proxyLineError) Timeout() bool {
- return false
- }
-
- func (p *proxyLineError) Temporary() bool {
- return true
- }
-
- var (
- ErrBadProxyLine error = &proxyLineError{}
- )
-
- // ListenerConfig is all the information about how to process
- // incoming IRC connections on a listener.
- type ListenerConfig struct {
- TLSConfig *tls.Config
- ProxyDeadline time.Duration
- RequireProxy bool
- // these are just metadata for easier tracking,
- // they are not used by ReloadableListener:
- Tor bool
- STSOnly bool
- WebSocket bool
- HideSTS bool
- }
-
- // read a PROXY header (either v1 or v2), ensuring we don't read anything beyond
- // the header into a buffer (this would break the TLS handshake)
- func readRawProxyLine(conn net.Conn, deadline time.Duration) (result []byte, err error) {
- // normally this is covered by ping timeouts, but we're doing this outside
- // of the normal client goroutine:
- conn.SetDeadline(time.Now().Add(deadline))
- defer conn.SetDeadline(time.Time{})
-
- // read the first 16 bytes of the proxy header
- buf := make([]byte, 16, maxProxyLineLenV1)
- _, err = io.ReadFull(conn, buf)
- if err != nil {
- return
- }
-
- switch buf[0] {
- case 'P':
- // PROXY v1: starts with "PROXY"
- return readRawProxyLineV1(conn, buf)
- case '\r':
- // PROXY v2: starts with "\r\n\r\n"
- return readRawProxyLineV2(conn, buf)
- default:
- return nil, ErrBadProxyLine
- }
- }
-
- func readRawProxyLineV1(conn net.Conn, buf []byte) (result []byte, err error) {
- for {
- i := len(buf)
- if i >= maxProxyLineLenV1 {
- return nil, ErrBadProxyLine // did not find \r\n, fail
- }
- // prepare a single byte of free space, then read into it
- buf = buf[0 : i+1]
- _, err = io.ReadFull(conn, buf[i:])
- if err != nil {
- return nil, err
- }
- if buf[i] == '\n' {
- return buf, nil
- }
- }
- }
-
- func readRawProxyLineV2(conn net.Conn, buf []byte) (result []byte, err error) {
- // "The 15th and 16th bytes is the address length in bytes in network endian order."
- addrLen := int(binary.BigEndian.Uint16(buf[14:16]))
- if addrLen == 0 {
- return buf[0:16], nil
- } else if addrLen <= cap(buf)-16 {
- buf = buf[0 : 16+addrLen]
- } else {
- // proxy source is unix domain, we don't really handle this
- buf2 := make([]byte, 16+addrLen)
- copy(buf2[0:16], buf[0:16])
- buf = buf2
- }
- _, err = io.ReadFull(conn, buf[16:16+addrLen])
- if err != nil {
- return
- }
- return buf[0 : 16+addrLen], nil
- }
-
- // ParseProxyLine parses a PROXY protocol (v1 or v2) line and returns the remote IP.
- func ParseProxyLine(line []byte) (ip net.IP, err error) {
- if len(line) == 0 {
- return nil, ErrBadProxyLine
- }
- switch line[0] {
- case 'P':
- return ParseProxyLineV1(string(line))
- case '\r':
- return parseProxyLineV2(line)
- default:
- return nil, ErrBadProxyLine
- }
- }
-
- // ParseProxyLineV1 parses a PROXY protocol (v1) line and returns the remote IP.
- func ParseProxyLineV1(line string) (ip net.IP, err error) {
- params := strings.Fields(line)
- if len(params) != 6 || params[0] != "PROXY" {
- return nil, ErrBadProxyLine
- }
- ip = net.ParseIP(params[2])
- if ip == nil {
- return nil, ErrBadProxyLine
- }
- return ip.To16(), nil
- }
-
- func parseProxyLineV2(line []byte) (ip net.IP, err error) {
- if len(line) < 16 {
- return nil, ErrBadProxyLine
- }
- // this doesn't allocate
- if string(line[:12]) != "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" {
- return nil, ErrBadProxyLine
- }
- // "The next byte (the 13th one) is the protocol version and command."
- versionCmd := line[12]
- // "The highest four bits contains the version [....] it must always be sent as \x2"
- if (versionCmd >> 4) != 2 {
- return nil, ErrBadProxyLine
- }
- // "The lowest four bits represents the command"
- switch versionCmd & 0x0f {
- case 0:
- return nil, nil // LOCAL command
- case 1:
- // PROXY command, continue below
- default:
- // "Receivers must drop connections presenting unexpected values here"
- return nil, ErrBadProxyLine
- }
-
- var addrLen int
- // "The 14th byte contains the transport protocol and address family."
- protoAddr := line[13]
- // "The highest 4 bits contain the address family"
- switch protoAddr >> 4 {
- case 1:
- addrLen = 4 // AF_INET
- case 2:
- addrLen = 16 // AF_INET6
- default:
- return nil, nil // AF_UNSPEC or AF_UNIX, either way there's no IP address
- }
-
- // header, source and destination address, two 16-bit port numbers:
- expectedLen := 16 + 2*addrLen + 4
- if len(line) < expectedLen {
- return nil, ErrBadProxyLine
- }
-
- // "Starting from the 17th byte, addresses are presented in network byte order.
- // The address order is always the same :
- // - source layer 3 address in network byte order [...]"
- if addrLen == 4 {
- ip = net.IP(line[16 : 16+addrLen]).To16()
- } else {
- ip = make(net.IP, addrLen)
- copy(ip, line[16:16+addrLen])
- }
- return ip, nil
- }
-
- // / WrappedConn is a net.Conn with some additional data stapled to it;
- // the proxied IP, if one was read via the PROXY protocol, and the listener
- // configuration.
- type WrappedConn struct {
- net.Conn
- ProxiedIP net.IP
- TLS bool
- Tor bool
- STSOnly bool
- WebSocket bool
- HideSTS bool
- // Secure indicates whether we believe the connection between us and the client
- // was secure against interception and modification (including all proxies):
- Secure bool
- }
-
- // ReloadableListener is a wrapper for net.Listener that allows reloading
- // of config data for postprocessing connections (TLS, PROXY protocol, etc.)
- type ReloadableListener struct {
- realListener net.Listener
- // nil means the listener is closed:
- config atomic.Pointer[ListenerConfig]
- }
-
- func NewReloadableListener(realListener net.Listener, config ListenerConfig) *ReloadableListener {
- result := &ReloadableListener{
- realListener: realListener,
- }
- result.config.Store(&config) // heap escape
- return result
- }
-
- func (rl *ReloadableListener) Reload(config ListenerConfig) {
- rl.config.Store(&config)
- }
-
- func (rl *ReloadableListener) Accept() (conn net.Conn, err error) {
- conn, err = rl.realListener.Accept()
-
- config := rl.config.Load()
-
- if config == nil {
- // Close() was called
- if err == nil {
- conn.Close()
- }
- err = net.ErrClosed
- }
- if err != nil {
- return nil, err
- }
-
- var proxiedIP net.IP
- if config.RequireProxy {
- // this will occur synchronously on the goroutine calling Accept(),
- // but that's OK because this listener *requires* a PROXY line,
- // therefore it must be used with proxies that always send the line
- // and we won't get slowloris'ed waiting for the client response
- proxyLine, err := readRawProxyLine(conn, config.ProxyDeadline)
- if err == nil {
- proxiedIP, err = ParseProxyLine(proxyLine)
- }
- if err != nil {
- conn.Close()
- return nil, err
- }
- }
-
- if config.TLSConfig != nil {
- conn = tls.Server(conn, config.TLSConfig)
- }
-
- return &WrappedConn{
- Conn: conn,
- ProxiedIP: proxiedIP,
- TLS: config.TLSConfig != nil,
- Tor: config.Tor,
- STSOnly: config.STSOnly,
- WebSocket: config.WebSocket,
- HideSTS: config.HideSTS,
- // Secure will be set later by client code
- }, nil
- }
-
- func (rl *ReloadableListener) Close() error {
- rl.config.Store(nil)
-
- return rl.realListener.Close()
- }
-
- func (rl *ReloadableListener) Addr() net.Addr {
- return rl.realListener.Addr()
- }
|