Browse Source

refactor the rehash implementation

tags/v0.9.2-alpha
Shivaram Lingamneni 6 years ago
parent
commit
e8b1870067
7 changed files with 323 additions and 251 deletions
  1. 1
    1
      irc/channel.go
  2. 4
    0
      irc/config.go
  3. 26
    0
      irc/database.go
  4. 21
    2
      irc/help.go
  5. 26
    15
      irc/rest_api.go
  6. 244
    232
      irc/server.go
  7. 1
    1
      oragono.go

+ 1
- 1
irc/channel.go View File

@@ -57,7 +57,7 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
57 57
 	}
58 58
 
59 59
 	if addDefaultModes {
60
-		for _, mode := range s.defaultChannelModes {
60
+		for _, mode := range s.GetDefaultChannelModes() {
61 61
 			channel.flags[mode] = true
62 62
 		}
63 63
 	}

+ 4
- 0
irc/config.go View File

@@ -244,6 +244,8 @@ type Config struct {
244 244
 		WhowasEntries  uint          `yaml:"whowas-entries"`
245 245
 		LineLen        LineLenConfig `yaml:"linelen"`
246 246
 	}
247
+
248
+	Filename string
247 249
 }
248 250
 
249 251
 // OperClass defines an assembled operator class.
@@ -390,6 +392,8 @@ func LoadConfig(filename string) (config *Config, err error) {
390 392
 		return nil, err
391 393
 	}
392 394
 
395
+	config.Filename = filename
396
+
393 397
 	// we need this so PasswordBytes returns the correct info
394 398
 	if config.Server.Password != "" {
395 399
 		config.Server.PassConfig.Password = config.Server.Password

+ 26
- 0
irc/database.go View File

@@ -53,6 +53,32 @@ func InitDB(path string) {
53 53
 	}
54 54
 }
55 55
 
56
+// open an existing database, performing a schema version check
57
+func OpenDatabase(path string) (*buntdb.DB, error) {
58
+	// open data store
59
+	db, err := buntdb.Open(path)
60
+	if err != nil {
61
+		return nil, err
62
+	}
63
+
64
+	// check db version
65
+	err = db.View(func(tx *buntdb.Tx) error {
66
+		version, _ := tx.Get(keySchemaVersion)
67
+		if version != latestDbSchema {
68
+			return fmt.Errorf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version)
69
+		}
70
+		return nil
71
+	})
72
+
73
+	if err != nil {
74
+		// close the db
75
+		db.Close()
76
+		return nil, err
77
+	}
78
+
79
+	return db, nil
80
+}
81
+
56 82
 // UpgradeDB upgrades the datastore to the latest schema.
57 83
 func UpgradeDB(path string) {
58 84
 	store, err := buntdb.Open(path)

+ 21
- 2
irc/help.go View File

@@ -530,10 +530,10 @@ Oragono supports the following channel membership prefixes:
530 530
 }
531 531
 
532 532
 // HelpIndex contains the list of all help topics for regular users.
533
-var HelpIndex = "list of all help topics for regular users"
533
+var HelpIndex string
534 534
 
535 535
 // HelpIndexOpers contains the list of all help topics for opers.
536
-var HelpIndexOpers = "list of all help topics for opers"
536
+var HelpIndexOpers string
537 537
 
538 538
 // GenerateHelpIndex is used to generate HelpIndex.
539 539
 func GenerateHelpIndex(forOpers bool) string {
@@ -582,6 +582,25 @@ Information:
582 582
 	return newHelpIndex
583 583
 }
584 584
 
585
+func GenerateHelpIndices() error {
586
+	if HelpIndex != "" && HelpIndexOpers != "" {
587
+		return nil
588
+	}
589
+
590
+	// startup check that we have HELP entries for every command
591
+	for name := range Commands {
592
+		_, exists := Help[strings.ToLower(name)]
593
+		if !exists {
594
+			return fmt.Errorf("Help entry does not exist for command %s", name)
595
+		}
596
+	}
597
+
598
+	// generate help indexes
599
+	HelpIndex = GenerateHelpIndex(false)
600
+	HelpIndexOpers = GenerateHelpIndex(true)
601
+	return nil
602
+}
603
+
585 604
 // sendHelp sends the client help of the given string.
586 605
 func (client *Client) sendHelp(name string, text string) {
587 606
 	splitName := strings.Split(name, " ")

+ 26
- 15
irc/rest_api.go View File

@@ -19,9 +19,9 @@ import (
19 19
 
20 20
 const restErr = "{\"error\":\"An unknown error occurred\"}"
21 21
 
22
-// restAPIServer is used to keep a link to the current running server since this is the best
22
+// ircServer is used to keep a link to the current running server since this is the best
23 23
 // way to do it, given how HTTP handlers dispatch and work.
24
-var restAPIServer *Server
24
+var ircServer *Server
25 25
 
26 26
 type restInfoResp struct {
27 27
 	ServerName  string `json:"server-name"`
@@ -60,8 +60,8 @@ type restRehashResp struct {
60 60
 func restInfo(w http.ResponseWriter, r *http.Request) {
61 61
 	rs := restInfoResp{
62 62
 		Version:     SemVer,
63
-		ServerName:  restAPIServer.name,
64
-		NetworkName: restAPIServer.networkName,
63
+		ServerName:  ircServer.name,
64
+		NetworkName: ircServer.networkName,
65 65
 	}
66 66
 	b, err := json.Marshal(rs)
67 67
 	if err != nil {
@@ -73,9 +73,9 @@ func restInfo(w http.ResponseWriter, r *http.Request) {
73 73
 
74 74
 func restStatus(w http.ResponseWriter, r *http.Request) {
75 75
 	rs := restStatusResp{
76
-		Clients:  restAPIServer.clients.Count(),
77
-		Opers:    len(restAPIServer.operators),
78
-		Channels: restAPIServer.channels.Len(),
76
+		Clients:  ircServer.clients.Count(),
77
+		Opers:    len(ircServer.operators),
78
+		Channels: ircServer.channels.Len(),
79 79
 	}
80 80
 	b, err := json.Marshal(rs)
81 81
 	if err != nil {
@@ -87,8 +87,8 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
87 87
 
88 88
 func restGetXLines(w http.ResponseWriter, r *http.Request) {
89 89
 	rs := restXLinesResp{
90
-		DLines: restAPIServer.dlines.AllBans(),
91
-		KLines: restAPIServer.klines.AllBans(),
90
+		DLines: ircServer.dlines.AllBans(),
91
+		KLines: ircServer.klines.AllBans(),
92 92
 	}
93 93
 	b, err := json.Marshal(rs)
94 94
 	if err != nil {
@@ -104,7 +104,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
104 104
 	}
105 105
 
106 106
 	// get accounts
107
-	err := restAPIServer.store.View(func(tx *buntdb.Tx) error {
107
+	err := ircServer.store.View(func(tx *buntdb.Tx) error {
108 108
 		tx.AscendKeys("account.exists *", func(key, value string) bool {
109 109
 			key = key[len("account.exists "):]
110 110
 			_, err := tx.Get(fmt.Sprintf(keyAccountVerified, key))
@@ -118,7 +118,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
118 118
 			regTime := time.Unix(regTimeInt, 0)
119 119
 
120 120
 			var clients int
121
-			acct := restAPIServer.accounts[key]
121
+			acct := ircServer.accounts[key]
122 122
 			if acct != nil {
123 123
 				clients = len(acct.Clients)
124 124
 			}
@@ -148,7 +148,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
148 148
 }
149 149
 
150 150
 func restRehash(w http.ResponseWriter, r *http.Request) {
151
-	err := restAPIServer.rehash()
151
+	err := ircServer.rehash()
152 152
 
153 153
 	rs := restRehashResp{
154 154
 		Successful: err == nil,
@@ -166,9 +166,9 @@ func restRehash(w http.ResponseWriter, r *http.Request) {
166 166
 	}
167 167
 }
168 168
 
169
-func (s *Server) startRestAPI() {
169
+func StartRestAPI(s *Server, listenAddr string) (*http.Server, error) {
170 170
 	// so handlers can ref it later
171
-	restAPIServer = s
171
+	ircServer = s
172 172
 
173 173
 	// start router
174 174
 	r := mux.NewRouter()
@@ -185,5 +185,16 @@ func (s *Server) startRestAPI() {
185 185
 	rp.HandleFunc("/rehash", restRehash)
186 186
 
187 187
 	// start api
188
-	go http.ListenAndServe(s.restAPI.Listen, r)
188
+	httpserver := http.Server{
189
+		Addr:    listenAddr,
190
+		Handler: r,
191
+	}
192
+
193
+	go func() {
194
+		if err := httpserver.ListenAndServe(); err != nil {
195
+			s.logger.Error("listeners", fmt.Sprintf("Rest API listenAndServe error: %s", err))
196
+		}
197
+	}()
198
+
199
+	return &httpserver, nil
189 200
 }

+ 244
- 232
irc/server.go View File

@@ -7,9 +7,9 @@ package irc
7 7
 
8 8
 import (
9 9
 	"bufio"
10
+	"context"
10 11
 	"crypto/tls"
11 12
 	"encoding/base64"
12
-	"errors"
13 13
 	"fmt"
14 14
 	"log"
15 15
 	"math/rand"
@@ -35,8 +35,14 @@ var (
35 35
 	tooManyClientsMsg, _   = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line()
36 36
 	couldNotParseIPMsg, _  = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
37 37
 	bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line()
38
+)
38 39
 
39
-	errDbOutOfDate = errors.New("Database schema is old")
40
+const (
41
+	// when shutting down the REST and websocket servers, wait this long
42
+	// before killing active non-WS connections. TODO: this might not be
43
+	// necessary at all? but it seems prudent to avoid potential resource
44
+	// leaks
45
+	httpShutdownTimeout = time.Second
40 46
 )
41 47
 
42 48
 // Limits holds the maximum limits for various things such as topic lengths.
@@ -80,6 +86,7 @@ type Server struct {
80 86
 	clients                      *ClientLookupSet
81 87
 	commands                     chan Command
82 88
 	configFilename               string
89
+	configurableStateMutex       sync.RWMutex // generic protection for server state modified by rehash()
83 90
 	connectionLimits             *ConnectionLimits
84 91
 	connectionLimitsMutex        sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack
85 92
 	connectionThrottle           *ConnectionThrottle
@@ -109,13 +116,15 @@ type Server struct {
109 116
 	registeredChannelsMutex      sync.RWMutex
110 117
 	rehashMutex                  sync.Mutex
111 118
 	rehashSignal                 chan os.Signal
112
-	restAPI                      *RestAPIConfig
119
+	restAPI                      RestAPIConfig
120
+	restAPIServer                *http.Server
113 121
 	proxyAllowedFrom             []string
114 122
 	signals                      chan os.Signal
115 123
 	snomasks                     *SnoManager
116 124
 	store                        *buntdb.DB
117 125
 	stsEnabled                   bool
118 126
 	whoWas                       *WhoWasList
127
+	wsServer                     *http.Server
119 128
 }
120 129
 
121 130
 var (
@@ -133,216 +142,38 @@ type clientConn struct {
133 142
 }
134 143
 
135 144
 // NewServer returns a new Oragono server.
136
-func NewServer(configFilename string, config *Config, logger *logger.Manager) (*Server, error) {
137
-	casefoldedName, err := Casefold(config.Server.Name)
138
-	if err != nil {
139
-		return nil, fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
140
-	}
141
-
142
-	// startup check that we have HELP entries for every command
143
-	for name := range Commands {
144
-		_, exists := Help[strings.ToLower(name)]
145
-		if !exists {
146
-			return nil, fmt.Errorf("Help entry does not exist for command %s", name)
147
-		}
148
-	}
149
-	// generate help indexes
150
-	HelpIndex = GenerateHelpIndex(false)
151
-	HelpIndexOpers = GenerateHelpIndex(true)
152
-
153
-	if config.Accounts.AuthenticationEnabled {
154
-		SupportedCapabilities[SASL] = true
155
-	}
156
-
157
-	if config.Server.STS.Enabled {
158
-		SupportedCapabilities[STS] = true
159
-		CapValues[STS] = config.Server.STS.Value()
160
-	}
161
-
162
-	if config.Limits.LineLen.Tags > 512 || config.Limits.LineLen.Rest > 512 {
163
-		SupportedCapabilities[MaxLine] = true
164
-		CapValues[MaxLine] = fmt.Sprintf("%d,%d", config.Limits.LineLen.Tags, config.Limits.LineLen.Rest)
165
-	}
166
-
167
-	operClasses, err := config.OperatorClasses()
168
-	if err != nil {
169
-		return nil, fmt.Errorf("Error loading oper classes: %s", err.Error())
170
-	}
171
-	opers, err := config.Operators(operClasses)
172
-	if err != nil {
173
-		return nil, fmt.Errorf("Error loading operators: %s", err.Error())
174
-	}
175
-
176
-	connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
177
-	if err != nil {
178
-		return nil, fmt.Errorf("Error loading connection limits: %s", err.Error())
179
-	}
180
-	connectionThrottle, err := NewConnectionThrottle(config.Server.ConnectionThrottle)
181
-	if err != nil {
182
-		return nil, fmt.Errorf("Error loading connection throttler: %s", err.Error())
145
+func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
146
+	// TODO move this to main?
147
+	if err := GenerateHelpIndices(); err != nil {
148
+		return nil, err
183 149
 	}
184 150
 
151
+	// initialize data structures
185 152
 	server := &Server{
186
-		accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
187
-		accounts:                     make(map[string]*ClientAccount),
188
-		channelRegistrationEnabled:   config.Channels.Registration.Enabled,
189
-		channels:                     *NewChannelNameMap(),
190
-		checkIdent:                   config.Server.CheckIdent,
191
-		clients:                      NewClientLookupSet(),
192
-		commands:                     make(chan Command),
193
-		configFilename:               configFilename,
194
-		connectionLimits:             connectionLimits,
195
-		connectionThrottle:           connectionThrottle,
196
-		ctime:                        time.Now(),
197
-		currentOpers:                 make(map[*Client]bool),
198
-		defaultChannelModes:          ParseDefaultChannelModes(config),
199
-		limits: Limits{
200
-			AwayLen:        int(config.Limits.AwayLen),
201
-			ChannelLen:     int(config.Limits.ChannelLen),
202
-			KickLen:        int(config.Limits.KickLen),
203
-			MonitorEntries: int(config.Limits.MonitorEntries),
204
-			NickLen:        int(config.Limits.NickLen),
205
-			TopicLen:       int(config.Limits.TopicLen),
206
-			ChanListModes:  int(config.Limits.ChanListModes),
207
-			LineLen: LineLenLimits{
208
-				Tags: config.Limits.LineLen.Tags,
209
-				Rest: config.Limits.LineLen.Rest,
210
-			},
211
-		},
153
+		accounts:           make(map[string]*ClientAccount),
154
+		channels:           *NewChannelNameMap(),
155
+		clients:            NewClientLookupSet(),
156
+		commands:           make(chan Command),
157
+		currentOpers:       make(map[*Client]bool),
212 158
 		listeners:          make(map[string]*ListenerWrapper),
213 159
 		logger:             logger,
214
-		MaxSendQBytes:      config.Server.MaxSendQBytes,
215 160
 		monitoring:         make(map[string][]*Client),
216
-		name:               config.Server.Name,
217
-		nameCasefolded:     casefoldedName,
218
-		networkName:        config.Network.Name,
219 161
 		newConns:           make(chan clientConn),
220
-		operators:          opers,
221
-		operclasses:        *operClasses,
222
-		proxyAllowedFrom:   config.Server.ProxyAllowedFrom,
223 162
 		registeredChannels: make(map[string]*RegisteredChannel),
224 163
 		rehashSignal:       make(chan os.Signal, 1),
225
-		restAPI:            &config.Server.RestAPI,
226 164
 		signals:            make(chan os.Signal, len(ServerExitSignals)),
227 165
 		snomasks:           NewSnoManager(),
228
-		stsEnabled:         config.Server.STS.Enabled,
229 166
 		whoWas:             NewWhoWasList(config.Limits.WhowasEntries),
230 167
 	}
231 168
 
232
-	// open data store
233
-	server.logger.Debug("startup", "Opening datastore")
234
-	db, err := buntdb.Open(config.Datastore.Path)
235
-	if err != nil {
236
-		return nil, fmt.Errorf("Failed to open datastore: %s", err.Error())
237
-	}
238
-	server.store = db
239
-
240
-	// check db version
241
-	err = server.store.View(func(tx *buntdb.Tx) error {
242
-		version, _ := tx.Get(keySchemaVersion)
243
-		if version != latestDbSchema {
244
-			logger.Error("startup", "server", fmt.Sprintf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version))
245
-			return errDbOutOfDate
246
-		}
247
-		return nil
248
-	})
249
-	if err != nil {
250
-		// close the db
251
-		db.Close()
252
-		return nil, errDbOutOfDate
253
-	}
254
-
255
-	// load *lines
256
-	server.logger.Debug("startup", "Loading D/Klines")
257
-	server.loadDLines()
258
-	server.loadKLines()
259
-
260
-	// load password manager
261
-	server.logger.Debug("startup", "Loading passwords")
262
-	err = server.store.View(func(tx *buntdb.Tx) error {
263
-		saltString, err := tx.Get(keySalt)
264
-		if err != nil {
265
-			return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
266
-		}
267
-
268
-		salt, err := base64.StdEncoding.DecodeString(saltString)
269
-		if err != nil {
270
-			return err
271
-		}
272
-
273
-		pwm := NewPasswordManager(salt)
274
-		server.passwords = &pwm
275
-		return nil
276
-	})
277
-	if err != nil {
278
-		return nil, fmt.Errorf("Could not load salt: %s", err.Error())
279
-	}
280
-
281
-	server.logger.Debug("startup", "Loading MOTD")
282
-	if config.Server.MOTD != "" {
283
-		file, err := os.Open(config.Server.MOTD)
284
-		if err == nil {
285
-			defer file.Close()
286
-
287
-			reader := bufio.NewReader(file)
288
-			for {
289
-				line, err := reader.ReadString('\n')
290
-				if err != nil {
291
-					break
292
-				}
293
-				line = strings.TrimRight(line, "\r\n")
294
-				// "- " is the required prefix for MOTD, we just add it here to make
295
-				// bursting it out to clients easier
296
-				line = fmt.Sprintf("- %s", line)
297
-
298
-				server.motdLines = append(server.motdLines, line)
299
-			}
300
-		}
301
-	}
302
-
303
-	if config.Server.Password != "" {
304
-		server.password = config.Server.PasswordBytes()
169
+	if err := server.applyConfig(config, true); err != nil {
170
+		return nil, err
305 171
 	}
306 172
 
307
-	tlsListeners := config.TLSListeners()
308
-	for _, addr := range config.Server.Listen {
309
-		server.listeners[addr] = server.createListener(addr, tlsListeners[addr])
310
-	}
311
-
312
-	if len(tlsListeners) == 0 {
313
-		server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
314
-	}
315
-	var usesStandardTLSPort bool
316
-	for addr := range config.TLSListeners() {
317
-		if strings.Contains(addr, "6697") {
318
-			usesStandardTLSPort = true
319
-			break
320
-		}
321
-	}
322
-	if 0 < len(tlsListeners) && !usesStandardTLSPort {
323
-		server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
324
-	}
325
-
326
-	if config.Server.Wslisten != "" {
327
-		server.wslisten(config.Server.Wslisten, config.Server.TLSListeners)
328
-	}
329
-
330
-	// registration
331
-	accountReg := NewAccountRegistration(config.Accounts.Registration)
332
-	server.accountRegistration = &accountReg
333
-
334 173
 	// Attempt to clean up when receiving these signals.
335 174
 	signal.Notify(server.signals, ServerExitSignals...)
336 175
 	signal.Notify(server.rehashSignal, syscall.SIGHUP)
337 176
 
338
-	server.setISupport()
339
-
340
-	// start API if enabled
341
-	if server.restAPI.Enabled {
342
-		logger.Info("startup", "server", fmt.Sprintf("%s rest API started on %s.", server.name, server.restAPI.Listen))
343
-		server.startRestAPI()
344
-	}
345
-
346 177
 	return server, nil
347 178
 }
348 179
 
@@ -514,13 +345,6 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
514 345
 		stopEvent:  make(chan bool, 1),
515 346
 	}
516 347
 
517
-	// TODO(slingamn) move all logging of listener status to rehash()
518
-	tlsString := "plaintext"
519
-	if tlsConfig != nil {
520
-		tlsString = "TLS"
521
-	}
522
-	server.logger.Info("listeners", fmt.Sprintf("listening on %s using %s.", addr, tlsString))
523
-
524 348
 	var shouldStop bool
525 349
 
526 350
 	// setup accept goroutine
@@ -562,8 +386,29 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
562 386
 // websocket listen goroutine
563 387
 //
564 388
 
565
-func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) {
566
-	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
389
+func (server *Server) setupWSListener(config *Config) {
390
+	// unconditionally shut down the old listener because we can't tell
391
+	// whether we need to reload the TLS certificate
392
+	if server.wsServer != nil {
393
+		ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
394
+		server.wsServer.Shutdown(ctx)
395
+		server.wsServer.Close()
396
+	}
397
+
398
+	if config.Server.Wslisten == "" {
399
+		server.wsServer = nil
400
+		return
401
+	}
402
+
403
+	addr := config.Server.Wslisten
404
+	tlsConfig := config.Server.TLSListeners[addr]
405
+	handler := http.NewServeMux()
406
+	wsServer := http.Server{
407
+		Addr:    addr,
408
+		Handler: handler,
409
+	}
410
+
411
+	handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
567 412
 		if r.Method != "GET" {
568 413
 			server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method))
569 414
 			return
@@ -584,29 +429,26 @@ func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig)
584 429
 
585 430
 		newConn := clientConn{
586 431
 			Conn:  WSContainer{ws},
587
-			IsTLS: false, //TODO(dan): track TLS or not here properly
432
+			IsTLS: tlsConfig != nil,
588 433
 		}
589 434
 		server.newConns <- newConn
590 435
 	})
591
-	go func() {
592
-		config, listenTLS := tlsMap[addr]
593 436
 
594
-		tlsString := "plaintext"
437
+	go func() {
595 438
 		var err error
596
-		if listenTLS {
597
-			tlsString = "TLS"
598
-		}
599
-		server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s using %s.", addr, tlsString))
439
+		server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s, tls=%t.", addr, tlsConfig != nil))
600 440
 
601
-		if listenTLS {
602
-			err = http.ListenAndServeTLS(addr, config.Cert, config.Key, nil)
441
+		if tlsConfig != nil {
442
+			err = wsServer.ListenAndServeTLS(tlsConfig.Cert, tlsConfig.Key)
603 443
 		} else {
604
-			err = http.ListenAndServe(addr, nil)
444
+			err = wsServer.ListenAndServe()
605 445
 		}
606 446
 		if err != nil {
607
-			server.logger.Error("listeners", fmt.Sprintf("listenAndServe error [%s]: %s", tlsString, err))
447
+			server.logger.Error("listeners", fmt.Sprintf("websocket ListenAndServe error: %s", err))
608 448
 		}
609 449
 	}()
450
+
451
+	server.wsServer = &wsServer
610 452
 }
611 453
 
612 454
 // generateMessageID returns a network-unique message ID.
@@ -660,6 +502,9 @@ func (server *Server) tryRegister(c *Client) {
660 502
 
661 503
 // MOTD serves the Message of the Day.
662 504
 func (server *Server) MOTD(client *Client) {
505
+	server.configurableStateMutex.RLock()
506
+	defer server.configurableStateMutex.RUnlock()
507
+
663 508
 	if len(server.motdLines) < 1 {
664 509
 		client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
665 510
 		return
@@ -1415,14 +1260,34 @@ func (server *Server) rehash() error {
1415 1260
 	server.logger.Debug("rehash", "Got rehash lock")
1416 1261
 
1417 1262
 	config, err := LoadConfig(server.configFilename)
1263
+	if err != nil {
1264
+		return fmt.Errorf("Error loading config file config: %s", err.Error())
1265
+	}
1266
+
1267
+	err = server.applyConfig(config, false)
1268
+	if err != nil {
1269
+		return fmt.Errorf("Error applying config changes: %s", err.Error())
1270
+	}
1271
+
1272
+	return nil
1273
+}
1274
+
1275
+func (server *Server) applyConfig(config *Config, initial bool) error {
1276
+	if initial {
1277
+		server.ctime = time.Now()
1278
+		server.configFilename = config.Filename
1279
+	}
1418 1280
 
1281
+	casefoldedName, err := Casefold(config.Server.Name)
1419 1282
 	if err != nil {
1420
-		return fmt.Errorf("Error rehashing config file config: %s", err.Error())
1283
+		return fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
1421 1284
 	}
1422 1285
 
1423
-	// line lengths cannot be changed after launching the server
1424
-	if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
1425
-		return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
1286
+	if !initial {
1287
+		// line lengths cannot be changed after launching the server
1288
+		if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
1289
+			return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
1290
+		}
1426 1291
 	}
1427 1292
 
1428 1293
 	// confirm connectionLimits are fine
@@ -1453,6 +1318,18 @@ func (server *Server) rehash() error {
1453 1318
 		}
1454 1319
 	}
1455 1320
 
1321
+	// sanity checks complete, start modifying server state
1322
+
1323
+	server.name = config.Server.Name
1324
+	server.nameCasefolded = casefoldedName
1325
+	server.networkName = config.Network.Name
1326
+
1327
+	if config.Server.Password != "" {
1328
+		server.password = config.Server.PasswordBytes()
1329
+	} else {
1330
+		server.password = nil
1331
+	}
1332
+
1456 1333
 	// apply new connectionlimits
1457 1334
 	server.connectionLimitsMutex.Lock()
1458 1335
 	server.connectionLimits = connectionLimits
@@ -1578,7 +1455,9 @@ func (server *Server) rehash() error {
1578 1455
 	server.accountRegistration = &accountReg
1579 1456
 	server.channelRegistrationEnabled = config.Channels.Registration.Enabled
1580 1457
 
1458
+	server.configurableStateMutex.Lock()
1581 1459
 	server.defaultChannelModes = ParseDefaultChannelModes(config)
1460
+	server.configurableStateMutex.Unlock()
1582 1461
 
1583 1462
 	// set new sendqueue size
1584 1463
 	if config.Server.MaxSendQBytes != server.MaxSendQBytes {
@@ -1595,18 +1474,108 @@ func (server *Server) rehash() error {
1595 1474
 	// set RPL_ISUPPORT
1596 1475
 	oldISupportList := server.isupport
1597 1476
 	server.setISupport()
1598
-	newISupportReplies := oldISupportList.GetDifference(server.isupport)
1477
+	if oldISupportList != nil {
1478
+		newISupportReplies := oldISupportList.GetDifference(server.isupport)
1479
+		// push new info to all of our clients
1480
+		server.clients.ByNickMutex.RLock()
1481
+		for _, sClient := range server.clients.ByNick {
1482
+			for _, tokenline := range newISupportReplies {
1483
+				// ugly trickery ahead
1484
+				sClient.Send(nil, server.name, RPL_ISUPPORT, append([]string{sClient.nick}, tokenline...)...)
1485
+			}
1486
+		}
1487
+		server.clients.ByNickMutex.RUnlock()
1488
+	}
1599 1489
 
1600
-	// push new info to all of our clients
1601
-	server.clients.ByNickMutex.RLock()
1602
-	for _, sClient := range server.clients.ByNick {
1603
-		for _, tokenline := range newISupportReplies {
1604
-			// ugly trickery ahead
1605
-			sClient.Send(nil, server.name, RPL_ISUPPORT, append([]string{sClient.nick}, tokenline...)...)
1490
+	server.loadMOTD(config.Server.MOTD)
1491
+
1492
+	if initial {
1493
+		if err := server.loadDatastore(config.Datastore.Path); err != nil {
1494
+			return err
1606 1495
 		}
1607 1496
 	}
1608
-	server.clients.ByNickMutex.RUnlock()
1609 1497
 
1498
+	// we are now open for business
1499
+	server.setupListeners(config)
1500
+	server.setupWSListener(config)
1501
+	server.setupRestAPI(config)
1502
+
1503
+	return nil
1504
+}
1505
+
1506
+func (server *Server) loadMOTD(motdPath string) error {
1507
+	server.logger.Debug("rehash", "Loading MOTD")
1508
+	motdLines := make([]string, 0)
1509
+	if motdPath != "" {
1510
+		file, err := os.Open(motdPath)
1511
+		if err == nil {
1512
+			defer file.Close()
1513
+
1514
+			reader := bufio.NewReader(file)
1515
+			for {
1516
+				line, err := reader.ReadString('\n')
1517
+				if err != nil {
1518
+					break
1519
+				}
1520
+				line = strings.TrimRight(line, "\r\n")
1521
+				// "- " is the required prefix for MOTD, we just add it here to make
1522
+				// bursting it out to clients easier
1523
+				line = fmt.Sprintf("- %s", line)
1524
+
1525
+				motdLines = append(motdLines, line)
1526
+			}
1527
+		} else {
1528
+			return err
1529
+		}
1530
+	}
1531
+
1532
+	server.configurableStateMutex.Lock()
1533
+	defer server.configurableStateMutex.Unlock()
1534
+	server.motdLines = motdLines
1535
+	return nil
1536
+}
1537
+
1538
+func (server *Server) loadDatastore(datastorePath string) error {
1539
+	// open the datastore and load server state for which it (rather than config)
1540
+	// is the source of truth
1541
+
1542
+	server.logger.Debug("startup", "Opening datastore")
1543
+	db, err := OpenDatabase(datastorePath)
1544
+	if err == nil {
1545
+		server.store = db
1546
+	} else {
1547
+		return fmt.Errorf("Failed to open datastore: %s", err.Error())
1548
+	}
1549
+
1550
+	// load *lines (from the datastores)
1551
+	server.logger.Debug("startup", "Loading D/Klines")
1552
+	server.loadDLines()
1553
+	server.loadKLines()
1554
+
1555
+	// load password manager
1556
+	server.logger.Debug("startup", "Loading passwords")
1557
+	err = server.store.View(func(tx *buntdb.Tx) error {
1558
+		saltString, err := tx.Get(keySalt)
1559
+		if err != nil {
1560
+			return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
1561
+		}
1562
+
1563
+		salt, err := base64.StdEncoding.DecodeString(saltString)
1564
+		if err != nil {
1565
+			return err
1566
+		}
1567
+
1568
+		pwm := NewPasswordManager(salt)
1569
+		server.passwords = &pwm
1570
+		return nil
1571
+	})
1572
+	if err != nil {
1573
+		return fmt.Errorf("Could not load salt: %s", err.Error())
1574
+	}
1575
+	return nil
1576
+}
1577
+
1578
+func (server *Server) setupListeners(config *Config) {
1610 1579
 	// update or destroy all existing listeners
1611 1580
 	tlsListeners := config.TLSListeners()
1612 1581
 	for addr := range server.listeners {
@@ -1629,18 +1598,16 @@ func (server *Server) rehash() error {
1629 1598
 		currentListener.configMutex.Unlock()
1630 1599
 
1631 1600
 		if stillConfigured {
1632
-			server.logger.Info("rehash",
1601
+			server.logger.Info("listeners",
1633 1602
 				fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)),
1634 1603
 			)
1635 1604
 		} else {
1636 1605
 			// tell the listener it should stop by interrupting its Accept() call:
1637 1606
 			currentListener.listener.Close()
1638
-			// XXX there is no guarantee from the API when the address will actually
1639
-			// free for bind(2) again; this solution "seems to work". See here:
1640
-			// https://github.com/golang/go/issues/21833
1607
+			// TODO(golang1.10) delete stopEvent once issue #21856 is released
1641 1608
 			<-currentListener.stopEvent
1642 1609
 			delete(server.listeners, addr)
1643
-			server.logger.Info("rehash", fmt.Sprintf("stopped listening on %s.", addr))
1610
+			server.logger.Info("listeners", fmt.Sprintf("stopped listening on %s.", addr))
1644 1611
 		}
1645 1612
 	}
1646 1613
 
@@ -1653,7 +1620,52 @@ func (server *Server) rehash() error {
1653 1620
 		}
1654 1621
 	}
1655 1622
 
1656
-	return nil
1623
+	if len(tlsListeners) == 0 {
1624
+		server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
1625
+	}
1626
+
1627
+	var usesStandardTLSPort bool
1628
+	for addr := range config.TLSListeners() {
1629
+		if strings.Contains(addr, "6697") {
1630
+			usesStandardTLSPort = true
1631
+			break
1632
+		}
1633
+	}
1634
+	if 0 < len(tlsListeners) && !usesStandardTLSPort {
1635
+		server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
1636
+	}
1637
+}
1638
+
1639
+func (server *Server) setupRestAPI(config *Config) {
1640
+	restAPIEnabled := config.Server.RestAPI.Enabled
1641
+	restAPIStarted := server.restAPIServer != nil
1642
+	restAPIListenAddrChanged := server.restAPI.Listen != config.Server.RestAPI.Listen
1643
+
1644
+	// stop an existing REST server if it's been disabled or the addr changed
1645
+	if restAPIStarted && (!restAPIEnabled || restAPIListenAddrChanged) {
1646
+		ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
1647
+		server.restAPIServer.Shutdown(ctx)
1648
+		server.restAPIServer.Close()
1649
+		server.logger.Info("rehash", "server", fmt.Sprintf("%s rest API stopped on %s.", server.name, server.restAPI.Listen))
1650
+		server.restAPIServer = nil
1651
+	}
1652
+
1653
+	// start a new one if it's enabled or the addr changed
1654
+	if restAPIEnabled && (!restAPIStarted || restAPIListenAddrChanged) {
1655
+		server.restAPIServer, _ = StartRestAPI(server, config.Server.RestAPI.Listen)
1656
+		server.logger.Info(
1657
+			"rehash", "server",
1658
+			fmt.Sprintf("%s rest API started on %s.", server.name, config.Server.RestAPI.Listen))
1659
+	}
1660
+
1661
+	// save the config information
1662
+	server.restAPI = config.Server.RestAPI
1663
+}
1664
+
1665
+func (server *Server) GetDefaultChannelModes() Modes {
1666
+	server.configurableStateMutex.RLock()
1667
+	defer server.configurableStateMutex.RUnlock()
1668
+	return server.defaultChannelModes
1657 1669
 }
1658 1670
 
1659 1671
 // REHASH

+ 1
- 1
oragono.go View File

@@ -132,7 +132,7 @@ Options:
132 132
 			logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.")
133 133
 		}
134 134
 
135
-		server, err := irc.NewServer(configfile, config, logger)
135
+		server, err := irc.NewServer(config, logger)
136 136
 		if err != nil {
137 137
 			logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error()))
138 138
 			return

Loading…
Cancel
Save