| // Copyright 2017 The LUCI Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package localauth |
| |
| import ( |
| "context" |
| "crypto/subtle" |
| "encoding/json" |
| "fmt" |
| "io" |
| "mime" |
| "net" |
| "net/http" |
| "regexp" |
| "sort" |
| "sync" |
| "time" |
| |
| "golang.org/x/oauth2" |
| |
| "go.chromium.org/luci/auth" |
| "go.chromium.org/luci/auth/integration/localauth/rpcs" |
| "go.chromium.org/luci/common/data/rand/cryptorand" |
| "go.chromium.org/luci/common/data/stringset" |
| "go.chromium.org/luci/common/errors" |
| "go.chromium.org/luci/common/logging" |
| "go.chromium.org/luci/common/retry/transient" |
| "go.chromium.org/luci/common/runtime/paniccatcher" |
| "go.chromium.org/luci/lucictx" |
| |
| "go.chromium.org/luci/auth/integration/internal/localsrv" |
| ) |
| |
| // TokenGenerator produces access tokens. |
| type TokenGenerator interface { |
| // GenerateToken returns an access token for a combination of scopes (given as |
| // a sorted list of strings without duplicates). |
| // |
| // It is called for each request to the local auth server. It may be called |
| // concurrently from multiple goroutines and must implement its own caching |
| // and synchronization if necessary. |
| // |
| // It is expected that the returned token lives for at least given 'lifetime' |
| // duration (which is typically on order of minutes), but it may live longer. |
| // Clients may cache the returned token for the duration of its lifetime. |
| // |
| // May return transient errors (in transient.Tag.In(err) returning true |
| // sense). Such errors result in HTTP 500 responses. This is appropriate for |
| // non-fatal errors. Clients may immediately retry requests on such errors. |
| // |
| // Any non-transient error is considered fatal and results in an RPC-level |
| // error response ({"error": ...}). Clients must treat such responses as fatal |
| // and don't retry requests. |
| // |
| // If the error implements ErrorWithCode interface, the error code returned to |
| // clients will be grabbed from the error object, otherwise the error code is |
| // set to -1. |
| GenerateToken(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) |
| |
| // GetEmail returns an email associated with all tokens produced by this |
| // generator or auth.ErrNoEmail if it's not available. |
| // |
| // Any other error will bubble up through Server.Start. |
| GetEmail() (string, error) |
| } |
| |
// ErrorWithCode is a fatal error that carries a numeric code.
//
// A TokenGenerator may return such an error to make the server reply with a
// specific non-zero error code instead of the default -1.
type ErrorWithCode interface {
	// Error returns the human-readable message, as for any error.
	Error() string

	// Code returns the value placed into the RPC response next to the error
	// message.
	Code() int
}
| |
// Server runs a local RPC server that hands out access tokens.
//
// Processes that need a token can discover the location of this server by
// looking at the "local_auth" section of LUCI_CONTEXT.
//
// Launch it with Start and shut it down with Stop. A stopped Server cannot be
// started again.
type Server struct {
	// TokenGenerators produce access tokens for given account IDs.
	//
	// Keys are account IDs. Start advertises them (sorted, each with its
	// generator's email or "-") in LUCI_CONTEXT, and incoming requests select
	// a generator by their "account_id" field.
	TokenGenerators map[string]TokenGenerator

	// DefaultAccountID is account ID subprocesses should pick by default.
	//
	// It is put into "local_auth" section of LUCI_CONTEXT. If empty string,
	// subprocesses won't attempt to use any account by default (they still can
	// pick some non-default account though).
	//
	// Start does not check that it is a key of TokenGenerators.
	DefaultAccountID string

	// Port is a local TCP port to bind to or 0 to allow the OS to pick one.
	Port int

	// srv owns the listening socket and the serving goroutine (see Start/Stop).
	srv localsrv.Server

	testingServeHook func() // if set, called right before serving (presumably tests only)
}
| |
| // Start launches background goroutine with the serving loop. |
| // |
| // The provided context is used as base context for request handlers and for |
| // logging. |
| // |
| // Returns a copy of lucictx.LocalAuth structure that specifies how to contact |
| // the server. It should be put into "local_auth" section of LUCI_CONTEXT where |
| // clients can discover it. |
| // |
| // The server must be eventually stopped with Stop(). |
| func (s *Server) Start(ctx context.Context) (*lucictx.LocalAuth, error) { |
| la, err := s.initLocalAuth(ctx) |
| if err != nil { |
| return nil, errors.Annotate(err, "failed to initialize LocalAuth").Err() |
| } |
| |
| addr, err := s.srv.Start(ctx, "local_auth", s.Port, func(c context.Context, l net.Listener, wg *sync.WaitGroup) error { |
| return s.serve(c, l, wg, la.Secret) |
| }) |
| if err != nil { |
| return nil, errors.Annotate(err, "failed to start the local server").Err() |
| } |
| |
| la.RpcPort = uint32(addr.Port) |
| return la, nil |
| } |
| |
| // Stop closes the listening socket, notifies pending requests to abort and |
| // stops the internal serving goroutine. |
| // |
| // Safe to call multiple times. Once stopped, the server cannot be started again |
| // (make a new instance of Server instead). |
| // |
| // Uses the given context for the deadline when waiting for the serving loop |
| // to stop. |
| func (s *Server) Stop(ctx context.Context) error { |
| return s.srv.Stop(ctx) |
| } |
| |
| // initLocalAuth generates new LocalAuth struct with RPC port blank. |
| func (s *Server) initLocalAuth(ctx context.Context) (*lucictx.LocalAuth, error) { |
| // Build a sorted list of LocalAuthAccount to put into the context, grab |
| // emails from the generators. |
| ids := make([]string, 0, len(s.TokenGenerators)) |
| for id := range s.TokenGenerators { |
| ids = append(ids, id) |
| } |
| sort.Strings(ids) |
| accounts := make([]*lucictx.LocalAuthAccount, len(ids)) |
| for i, id := range ids { |
| email, err := s.TokenGenerators[id].GetEmail() |
| switch { |
| case err == auth.ErrNoEmail: |
| email = "-" |
| case err != nil: |
| return nil, errors.Annotate(err, "could not grab email of account %q", id).Err() |
| } |
| accounts[i] = &lucictx.LocalAuthAccount{Id: id, Email: email} |
| } |
| |
| secret := make([]byte, 48) |
| if _, err := cryptorand.Read(ctx, secret); err != nil { |
| return nil, err |
| } |
| |
| return &lucictx.LocalAuth{ |
| Secret: secret, |
| Accounts: accounts, |
| DefaultAccountId: s.DefaultAccountID, |
| }, nil |
| } |
| |
| // serve runs the serving loop. |
| func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup, secret []byte) error { |
| if s.testingServeHook != nil { |
| s.testingServeHook() |
| } |
| srv := http.Server{ |
| Handler: &protocolHandler{ |
| ctx: ctx, |
| wg: wg, |
| secret: secret, |
| tokens: s.TokenGenerators, |
| }, |
| } |
| return srv.Serve(l) |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Protocol implementation. |
| |
// methodRe matches the request URI of an RPC call, such as
// "/rpc/LuciLocalAuthService.GetOAuthToken", and captures the method name.
var methodRe = regexp.MustCompile(`^/rpc/LuciLocalAuthService\.([a-zA-Z0-9_]+)$`)

// minTokenLifetime is the lifetime passed to TokenGenerator.GenerateToken,
// i.e. how long returned tokens must stay valid at minimum.
//
// It has to exceed 'minAcceptedLifetime' of the auth package. Otherwise an
// auth.Authenticator built on top of a local_auth server may misbehave.
const minTokenLifetime = 180 * time.Second
| |
// protocolHandler is the http.Handler implementing the server side of the
// local_auth RPC protocol. http.Server calls its ServeHTTP method, in a
// separate goroutine, to handle a request.
//
// Protocol summary:
//   - Each request is POST to /rpc/LuciLocalAuthService.<Method>.
//   - Request content type is "application/json; ...". A request without a
//     Content-Type header is also accepted and treated as JSON.
//   - The sender must set Content-Length header; the body must be smaller
//     than 64 KiB.
//   - Response content type is also "application/json".
//   - The server sets Content-Length header in the response.
//   - Protocol-level errors (and transient token generation errors, which
//     use HTTP 500) have non-200 HTTP status code.
//   - Logic errors have 200 HTTP status code and error is communicated in
//     the response body.
//
// The only supported method currently is 'GetOAuthToken':
//
//	Request body:
//	{
//	  "scopes": [<string scope1>, <string scope2>, ...],
//	  "secret": <string from LUCI_CONTEXT.local_auth.secret>,
//	  "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts>
//	}
//	Response body:
//	{
//	  "error_code": <int, on success not set or 0>,
//	  "error_message": <string, on success not set>,
//	  "access_token": <string with actual token (on success)>,
//	  "expiry": <int with unix timestamp in seconds (on success)>
//	}
//
// See also python counterpart of this code:
// https://chromium.googlesource.com/infra/luci/luci-py/+/HEAD/client/utils/auth_server.py
type protocolHandler struct {
	ctx    context.Context           // the parent context, used for logging and token generation
	wg     *sync.WaitGroup           // tracks in-flight requests, for graceful shutdown
	secret []byte                    // expected "secret" value
	tokens map[string]TokenGenerator // the actual producer of tokens (per account ID)
}
| |
// protocolError is an error that makes ServeHTTP reply with the non-200 HTTP
// status code Status and with Message as the response body.
type protocolError struct {
	Status  int    // HTTP status code of the reply
	Message string // text put into the reply body
}

// Error implements the error interface, e.g. "Invalid secret. (HTTP 403)".
func (e *protocolError) Error() string {
	// fmt.Sprint inserts no spaces here since every adjacent pair has a string.
	return fmt.Sprint(e.Message, " (HTTP ", e.Status, ")")
}
| |
| // ServeHTTP implements the protocol marshaling logic. |
| func (h *protocolHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { |
| h.wg.Add(1) |
| defer h.wg.Done() |
| |
| defer paniccatcher.Catch(func(p *paniccatcher.Panic) { |
| logging.Fields{ |
| "panic.error": p.Reason, |
| }.Errorf(h.ctx, "Caught panic during handling of %q: %s\n%s", r.RequestURI, p.Reason, p.Stack) |
| http.Error(rw, "Internal Server Error. See logs.", http.StatusInternalServerError) |
| }) |
| |
| logging.Debugf(h.ctx, "Handling %s %s", r.Method, r.RequestURI) |
| |
| if r.Method != "POST" { |
| http.Error(rw, "Expecting POST", http.StatusMethodNotAllowed) |
| return |
| } |
| |
| // Grab <method> from /rpc/LuciLocalAuthService.<method>. |
| matches := methodRe.FindStringSubmatch(r.RequestURI) |
| if len(matches) != 2 { |
| http.Error(rw, "Expecting /rpc/LuciLocalAuthService.<method>", http.StatusNotFound) |
| return |
| } |
| method := matches[1] |
| |
| // The content type must be JSON, which is also the default. |
| if ct := r.Header.Get("Content-Type"); ct != "" { |
| baseType, _, err := mime.ParseMediaType(ct) |
| if err != nil { |
| http.Error(rw, fmt.Sprintf("Can't parse Content-Type: %s", err), http.StatusBadRequest) |
| return |
| } |
| if baseType != "application/json" { |
| http.Error(rw, "Expecting 'application/json' Content-Type", http.StatusBadRequest) |
| return |
| } |
| } |
| |
| // The content length must be given and be small enough. |
| if r.ContentLength < 0 || r.ContentLength >= 64*1024 { |
| http.Error(rw, "Expecting 'Content-Length' header, <64Kb", http.StatusBadRequest) |
| return |
| } |
| |
| // Slurp the body, it's easier to deal with []byte going forward. The body is |
| // tiny anyway. |
| request := make([]byte, r.ContentLength) |
| if _, err := io.ReadFull(r.Body, request); err != nil { |
| http.Error(rw, "Can't read the request body", http.StatusBadGateway) |
| return |
| } |
| |
| // Route to the appropriate RPC handler. |
| response, err := h.routeToImpl(method, request) |
| |
| // *protocolError are sent as HTTP errors. |
| if pErr, _ := err.(*protocolError); pErr != nil { |
| http.Error(rw, pErr.Message, pErr.Status) |
| return |
| } |
| |
| // Transient errors are returned as HTTP 500 responses. |
| if transient.Tag.In(err) { |
| http.Error(rw, fmt.Sprintf("Transient error - %s", err), http.StatusInternalServerError) |
| return |
| } |
| |
| // Fatal errors are returned as specially structured JSON responses with |
| // HTTP 200 code. Replace 'response' with it. |
| if err != nil { |
| fatalError := rpcs.BaseResponse{ |
| ErrorCode: -1, |
| ErrorMessage: err.Error(), |
| } |
| if withCode, ok := err.(ErrorWithCode); ok && withCode.Code() != 0 { |
| fatalError.ErrorCode = withCode.Code() |
| } |
| response = &fatalError |
| } |
| |
| // Serialize the response to grab its length. |
| blob, err := json.Marshal(response) |
| if err != nil { |
| http.Error(rw, fmt.Sprintf("Failed to serialize the response - %s", err), http.StatusInternalServerError) |
| return |
| } |
| blob = append(blob, '\n') // for curl's sake |
| |
| // Finally write the response. |
| rw.Header().Set("Content-Type", "application/json; charset=utf-8") |
| rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(blob))) |
| rw.WriteHeader(http.StatusOK) |
| if _, err := rw.Write(blob); err != nil { |
| logging.WithError(err).Warningf(h.ctx, "Failed to write the response") |
| } |
| } |
| |
| // routeToImpl calls appropriate RPC method implementation. |
| func (h *protocolHandler) routeToImpl(method string, request []byte) (interface{}, error) { |
| switch method { |
| case "GetOAuthToken": |
| req := &rpcs.GetOAuthTokenRequest{} |
| if err := json.Unmarshal(request, req); err != nil { |
| return nil, &protocolError{ |
| Status: http.StatusBadRequest, |
| Message: fmt.Sprintf("Not JSON body - %s", err), |
| } |
| } |
| return h.handleGetOAuthToken(req) |
| default: |
| return nil, &protocolError{ |
| Status: http.StatusNotFound, |
| Message: fmt.Sprintf("Unknown RPC method %q", method), |
| } |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // RPC implementations. |
| |
| func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*rpcs.GetOAuthTokenResponse, error) { |
| // Validate the request. |
| if err := req.Validate(); err != nil { |
| return nil, &protocolError{ |
| Status: 400, |
| Message: fmt.Sprintf("Bad request: %s.", err.Error()), |
| } |
| } |
| if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 { |
| return nil, &protocolError{ |
| Status: 403, |
| Message: "Invalid secret.", |
| } |
| } |
| generator := h.tokens[req.AccountID] |
| if generator == nil { |
| return nil, &protocolError{ |
| Status: 404, |
| Message: fmt.Sprintf("Unrecognized account ID %q.", req.AccountID), |
| } |
| } |
| |
| // Dedup and sort scopes. |
| scopes := stringset.New(len(req.Scopes)) |
| for _, s := range req.Scopes { |
| scopes.Add(s) |
| } |
| sortedScopes := scopes.ToSlice() |
| sort.Strings(sortedScopes) |
| |
| // Ask the token provider for the token. This may produce ErrorWithCode. |
| tok, err := generator.GenerateToken(h.ctx, sortedScopes, minTokenLifetime) |
| if err != nil { |
| return nil, err |
| } |
| return &rpcs.GetOAuthTokenResponse{ |
| AccessToken: tok.AccessToken, |
| Expiry: tok.Expiry.Unix(), |
| }, nil |
| } |