blob: 3bedcf9a980b073d03f978085c9b9b61ed123dbc [file] [log] [blame]
// Copyright 2017 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package localsrv provides helpers for running local TCP servers.
//
// It is used by various machine-local authentication protocols to launch
// local listening servers.
package localsrv
import (
"context"
"fmt"
"net"
"sync"
"time"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/errors"
"go.chromium.org/luci/common/logging"
)
// Server runs a local TCP server listening on 127.0.0.1.
//
// The zero value is ready to use: call Start once to launch the server and
// Stop to shut it down. A Server embeds a sync.Mutex and a sync.WaitGroup, so
// it must not be copied after first use.
type Server struct {
	l        sync.Mutex         // serializes Start, serve and killServe access to the fields below
	name     string             // name passed to Start
	listener net.Listener       // to know what to stop in killServe, nil after that
	wg       sync.WaitGroup     // +1 for each request being processed now
	ctx      context.Context    // derived from ctx in Start, never resets to nil after that
	cancel   context.CancelFunc // cancels 'ctx'
	stopped  chan struct{}      // closed when serve() goroutine stops
}
// ServeFunc is called from an internal goroutine to run the server loop.
//
// It is expected to block, serving connections accepted from l, until the
// server is stopped.
//
// When the server stops, the given listener is closed and the given context
// is canceled. The wait group is used to wait for pending requests: call
// wg.Add(1) when starting to process a request and wg.Done() when finished.
// After ServeFunc returns, the server waits on wg before reporting that it has
// stopped.
//
// If ServeFunc returns after the listener has been closed by a shutdown, the
// returned error is ignored (it is most likely caused by the closed listener).
// Any other non-nil error is logged.
type ServeFunc func(c context.Context, l net.Listener, wg *sync.WaitGroup) error
// Start launches a background goroutine running the serving loop 'serve'.
//
// It binds a listening TCP socket on 127.0.0.1 and returns the address that
// socket is bound to.
//
// The provided context is the parent of the context handed to 'serve' (and so
// the base context for request handlers); it is also used for logging. 'name'
// identifies this server in logs, and 'port' is the TCP port to bind to (or 0
// to auto-pick one).
//
// Start can succeed only once per Server: later calls return an error, even
// after Stop. The server must eventually be stopped with Stop().
func (s *Server) Start(ctx context.Context, name string, port int, serve ServeFunc) (*net.TCPAddr, error) {
	s.l.Lock()
	defer s.l.Unlock()

	// s.ctx is never reset once set, so a non-nil value means Start already ran.
	if s.ctx != nil {
		return nil, errors.New("already initialized")
	}

	bindAddr := fmt.Sprintf("127.0.0.1:%d", port)
	ln, err := net.Listen("tcp", bindAddr)
	if err != nil {
		return nil, errors.Annotate(err, "failed to create listening socket").Err()
	}

	serverCtx, cancel := context.WithCancel(ctx)
	done := make(chan struct{})

	s.name = name
	s.ctx = serverCtx
	s.cancel = cancel
	s.listener = ln
	s.stopped = done

	// Serve in the background. serve() acquires s.l first, so the loop really
	// begins only after this function returns and releases the lock.
	go func() {
		defer close(done)
		if loopErr := s.serve(serve); loopErr != nil {
			logging.WithError(loopErr).Errorf(serverCtx, "Unexpected error in the server loop of %q", name)
		}
	}()

	return ln.Addr().(*net.TCPAddr), nil
}
// Stop closes the listening socket, signals pending requests to abort and
// waits for the internal serving goroutine to exit.
//
// It is safe to call Stop multiple times. A stopped server cannot be started
// again; create a new Server instead. Stop returns an error only if the server
// was never successfully started.
//
// The wait for the serving goroutine is bounded: Stop waits up to 10 seconds,
// measured with the clock obtained from 'ctx'. If the goroutine has not exited
// by then, an error is logged and Stop still returns nil.
func (s *Server) Stop(ctx context.Context) error {
	// Closing the socket is what tells the serving loop to wind down.
	err := s.killServe()
	if err != nil {
		return err
	}

	giveUp := clock.After(ctx, 10*time.Second)
	select {
	case <-s.stopped:
		logging.Debugf(ctx, "The local server %q has stopped", s.name)
	case <-giveUp:
		logging.Errorf(ctx, "Giving up waiting for the local server %q to stop", s.name)
	}
	return nil
}
// serve runs the serving loop 'cb' and blocks until it finishes.
//
// It returns once cb has returned (normally because killServe closed the
// listener) and all requests registered in s.wg have completed.
//
// Returns nil if serving was stopped by killServe, including when killServe
// ran before the loop even started. Returns a non-nil error if cb failed for
// some other reason.
func (s *Server) serve(cb ServeFunc) error {
	s.l.Lock()
	listener := s.listener // used outside the lock below
	ctx := s.ctx
	s.l.Unlock()

	// Start always sets a non-nil listener before launching the goroutine that
	// calls serve, and only killServe resets it to nil. A nil listener here
	// therefore means Stop raced ahead of this goroutine. That is a planned
	// shutdown, not a failure. Returning an error would make Start's goroutine
	// log a spurious "Unexpected error in the server loop".
	if listener == nil {
		return nil
	}

	err := cb(ctx, listener, &s.wg) // blocks until killServe() closes the listener
	s.wg.Wait()                     // waits for all pending requests

	// If this was a planned shutdown via killServe(), ignore the error: it only
	// says that the listening socket was closed.
	s.l.Lock()
	stoppedByKill := s.listener == nil
	s.l.Unlock()

	if err == nil || stoppedByKill {
		return nil
	}
	return errors.Annotate(err, "error in the serving loop").Err()
}
// killServe notifies the serving goroutine to stop, if it is running.
//
// The first call closes the listening socket, which unblocks the ServeFunc
// running in serve(), and clears s.listener. Every call cancels the server
// context, so running handlers are told to stop. A failure to close the socket
// is logged rather than returned.
//
// Returns an error only if Start has not successfully run.
func (s *Server) killServe() error {
	s.l.Lock()
	defer s.l.Unlock()

	if s.ctx == nil {
		return errors.New("not initialized")
	}

	// Only the first call finds a listener to close; later calls skip this.
	if ln := s.listener; ln != nil {
		logging.Debugf(s.ctx, "Stopping the local server %q...", s.name)
		s.listener = nil
		if closeErr := ln.Close(); closeErr != nil {
			logging.WithError(closeErr).Errorf(s.ctx, "Failed to close the listening socket of %q", s.name)
		}
	}

	s.cancel() // notify all running handlers to stop
	return nil
}