blob: 97b8359cd3a477c5e5646c905255869480d7baf0 [file] [log] [blame]
// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"golang.org/x/net/context/ctxhttp"
"go.chromium.org/luci/common/logging"
"go.chromium.org/luci/common/retry/transient"
)
// ClientFactory knows how to produce an http.Client that attaches proper OAuth
// headers.
//
// If 'scopes' is empty, the factory should return a client that makes anonymous
// requests.
type ClientFactory func(ctx context.Context, scopes []string) (*http.Client, error)

// clientFactory is the factory installed via RegisterClientFactory.
//
// It stays nil until a factory is registered; while it is nil, Request.Do falls
// back to http.DefaultClient.
var clientFactory ClientFactory
// RegisterClientFactory installs f as the ClientFactory used by Do(...) to
// obtain HTTP clients.
//
// It exists so that an external module can supply the implementation, which
// resolves the module dependency cycle between server/auth and
// server/auth/internal. See init() in server/auth/client.go.
//
// Panics if a non-nil factory has already been registered.
//
// If no client factory is set, Do(...) uses http.DefaultClient. This happens
// in unit tests for various auth/* subpackages.
func RegisterClientFactory(f ClientFactory) {
	if clientFactory == nil {
		// First registration: remember the factory and we are done.
		clientFactory = f
		return
	}
	panic("ClientFactory is already registered")
}
// Request represents one JSON REST API request.
//
// Populate the fields and call Do to execute it.
type Request struct {
	// Method is the HTTP method to use.
	Method string
	// URL is the URL to access.
	URL string
	// Scopes are the OAuth2 scopes to authenticate with. The call is anonymous
	// if empty.
	Scopes []string
	// Headers is an optional map with request headers. When Body is marshaled
	// to JSON, Do sets Content-Type to application/json after applying these
	// headers, overriding any Content-Type given here.
	Headers map[string]string
	// Body is the request payload: nil for no payload, a []byte to send as is,
	// or any other object to convert to JSON and send as the body.
	Body any
	// Out is where to deserialize the JSON response to. If nil, the response
	// body is not parsed.
	Out any
}
// Do sends the HTTP request described by r and processes the reply via
// fetchJSON.
//
// The payload depends on Body:
//   - nil: no payload (suitable for GET or DELETE requests);
//   - []byte: the bytes are sent verbatim;
//   - anything else: the value is serialized to JSON and sent with the
//     "Content-Type: application/json" header (suitable for POST or PUT
//     requests).
//
// The HTTP client comes from the registered ClientFactory, called with
// r.Scopes. If no factory is registered, http.DefaultClient is used, unless the
// context carries a test transport, in which case a client wrapping that
// transport is used instead.
//
// The request is handed to fetchJSON, which makes one attempt and tags
// connection failures and HTTP 5xx replies as transient.
//
// Respects context's deadline and cancellation.
func (r *Request) Do(ctx context.Context) error {
	// Pick the HTTP client once, before building anything else.
	var httpClient *http.Client
	if clientFactory == nil {
		httpClient = http.DefaultClient
		if tr := ctx.Value(&testTransportKey); tr != nil {
			httpClient = &http.Client{Transport: tr.(http.RoundTripper)}
		}
	} else {
		made, err := clientFactory(ctx, r.Scopes)
		if err != nil {
			return err
		}
		httpClient = made
	}

	// Serialize the payload a single time, so the request builder below never
	// has to marshal it again.
	var (
		payload    []byte
		jsonEncode bool
	)
	switch body := r.Body.(type) {
	case nil:
		// Nothing to send.
	case []byte:
		payload = body
	default:
		encoded, err := json.Marshal(body)
		if err != nil {
			return err
		}
		payload = encoded
		jsonEncode = true
	}

	// buildRequest produces a fresh *http.Request from the prepared payload.
	buildRequest := func() (*http.Request, error) {
		httpReq, err := http.NewRequest(r.Method, r.URL, bytes.NewReader(payload))
		if err != nil {
			return nil, err
		}
		for name, value := range r.Headers {
			httpReq.Header.Set(name, value)
		}
		// Applied after the custom headers, so it wins over a Content-Type
		// supplied by the caller.
		if jsonEncode {
			httpReq.Header.Set("Content-Type", "application/json")
		}
		return httpReq, nil
	}

	return fetchJSON(ctx, httpClient, r.Out, buildRequest)
}
// TODO(vadimsh): Add retries on HTTP 500.

// fetchJSON makes a single HTTP request using the given client and, on
// success, decodes the JSON response body into val.
//
// f constructs the request to send. If f fails, its error is logged and
// returned as is.
//
// val may be nil, in which case the body of a successful response is not
// parsed.
//
// Errors:
//   - if no response is obtained (ctxhttp.Do returns an error), that error is
//     returned tagged as transient;
//   - an HTTP status >= 300 yields an error naming the status code and the
//     URL, tagged as transient when the status is >= 500. Before returning,
//     the response body is opportunistically unmarshaled into val (if val is
//     not nil), ignoring any failure to do so;
//   - a successful response whose body can't be decoded into val yields an
//     error that wraps the decoder error.
func fetchJSON(ctx context.Context, client *http.Client, val any, f func() (*http.Request, error)) error {
	req, err := f()
	if err != nil {
		logging.Errorf(ctx, "auth: URL fetch failed - %s", err)
		return err
	}
	logging.Infof(ctx, "auth: %s %s", req.Method, req.URL)
	resp, err := ctxhttp.Do(ctx, client, req)
	if err != nil {
		logging.Errorf(ctx, "auth: URL fetch failed, can't connect - %s", err)
		return transient.Tag.Apply(err)
	}
	defer func() {
		// Drain whatever is left of the body before closing it so the
		// transport can reuse the connection. Stream it to io.Discard instead
		// of buffering it in memory. Both steps are best effort: nothing
		// useful can be done if they fail at this point.
		_, _ = io.Copy(io.Discard, resp.Body)
		_ = resp.Body.Close()
	}()
	if resp.StatusCode >= 300 {
		// The body is used only for the log line and the opportunistic
		// unmarshal below, so a partial read is acceptable.
		body, _ := io.ReadAll(resp.Body)
		// Opportunistically try to unmarshal the response. Works with JSON
		// APIs; a body that isn't valid JSON is deliberately ignored.
		if val != nil {
			_ = json.Unmarshal(body, val)
		}
		logging.Errorf(ctx, "auth: URL fetch failed - HTTP %d - %s", resp.StatusCode, string(body))
		statusErr := fmt.Errorf("auth: HTTP code (%d) when fetching %s", resp.StatusCode, req.URL)
		if resp.StatusCode >= 500 {
			return transient.Tag.Apply(statusErr)
		}
		return statusErr
	}
	if val != nil {
		if err := json.NewDecoder(resp.Body).Decode(val); err != nil {
			logging.Errorf(ctx, "auth: URL fetch failed, bad JSON - %s", err)
			// %w keeps the decoder error in the chain for errors.Is/As; the
			// rendered message is the same as with %s.
			return fmt.Errorf("auth: can't deserialize JSON at %q - %w", req.URL, err)
		}
	}
	return nil
}