blob: 0c4b06d75ec040f861105ddebe551b2af00f79ef [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prpc
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"reflect"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/errors"
luciproto "go.chromium.org/luci/common/proto"
"go.chromium.org/luci/grpc/grpcutil"
)
// This file implements decoding of HTTP requests to RPC parameters.
const headerContentType = "Content-Type"
// readMessage decodes a protobuf message from an HTTP request.
//
// Does not close the request body.
//
// fixFieldMasksForJSON indicates whether to attempt a workaround for
// https://github.com/golang/protobuf/issues/745 for requests with FormatJSONPB.
// TODO(crbug/1082369): Remove this workaround once field masks can be decoded.
func readMessage(r *http.Request, msg proto.Message, fixFieldMasksForJSON bool) *protocolError {
format, err := FormatFromContentType(r.Header.Get(headerContentType))
if err != nil {
// Spec: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.16
return protocolErr(
codes.InvalidArgument,
http.StatusUnsupportedMediaType,
"bad Content-Type header: %s", err,
)
}
var buf []byte
if r.Header.Get("Content-Encoding") == "gzip" {
reader, err := getGZipReader(r.Body)
if err != nil {
return requestReadErr(err, "failed to start decompressing gzip request body")
}
buf, err = io.ReadAll(reader)
if err == nil {
err = reader.Close() // this just checks the checksum
}
returnGZipReader(reader)
if err != nil {
return requestReadErr(err, "could not read or decompress request body")
}
} else {
buf, err = io.ReadAll(r.Body)
if err != nil {
return requestReadErr(err, "could not read request body")
}
}
switch format {
// Do not redefine "err" below.
case FormatJSONPB:
if fixFieldMasksForJSON {
t := reflect.TypeOf(msg)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
buf, err = luciproto.FixFieldMasksBeforeUnmarshal(buf, t)
}
if err == nil {
err = jsonpb.Unmarshal(bytes.NewBuffer(buf), msg)
}
case FormatText:
err = proto.UnmarshalText(string(buf), msg)
case FormatBinary:
err = proto.Unmarshal(buf, msg)
default:
panic(fmt.Errorf("impossible: invalid format %v", format))
}
if err != nil {
return protocolErr(
codes.InvalidArgument,
http.StatusBadRequest,
"could not decode body: %s", err,
)
}
return nil
}
// requestReadErr interprets an error from reading and unzipping of a request.
func requestReadErr(err error, msg string) *protocolError {
code := codes.InvalidArgument
switch errors.Unwrap(err) {
case io.ErrUnexpectedEOF, context.Canceled:
code = codes.Canceled
case context.DeadlineExceeded:
code = codes.DeadlineExceeded
}
return protocolErr(code, grpcutil.CodeStatus(code), "%s: %s", msg, err)
}
// parseHeader parses HTTP headers and derives a new context.
//
// Recognizes "X-Prpc-Grpc-Timeout" header. Other recognized reserved headers
// are skipped. If there are unrecognized HTTP headers they are added to
// a metadata.MD and a new context is derived. If ctx already has metadata,
// the latter is copied.
//
// If host is not empty, sets "host" and ":authority" metadata.
func parseHeader(ctx context.Context, header http.Header, host string) (context.Context, context.CancelFunc, error) {
// Parse headers into a metadata map. This skips all reserved headers to avoid
// leaking pRPC protocol implementation details to gRPC servers.
parsed, err := headersIntoMetadata(header)
if err != nil {
return nil, nil, err
}
// Modify metadata.MD in the context only if we have something to add.
if parsed != nil || host != "" {
existing, _ := metadata.FromIncomingContext(ctx)
merged := metadata.Join(existing, parsed)
if host != "" {
merged["host"] = []string{host}
merged[":authority"] = []string{host}
}
ctx = metadata.NewIncomingContext(ctx, merged)
}
// Parse the reserved timeout header, if any. Apply to the context.
if timeoutStr := header.Get(HeaderTimeout); timeoutStr != "" {
timeout, err := DecodeTimeout(timeoutStr)
if err != nil {
return nil, nil, errors.Annotate(err, "%q header", HeaderTimeout).Err()
}
ctx, cancel := clock.WithTimeout(ctx, timeout)
return ctx, cancel, nil
}
return ctx, func() {}, nil
}