blob: dcb03168651e7bd45c7048bc49c19268a19dbba2 [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prpc
import (
"context"
"io"
"math"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/errors"
"go.chromium.org/luci/grpc/grpcutil"
)
// This file implements decoding of HTTP requests to RPC parameters.
const headerContentType = "Content-Type"
// decompressionLimitErr is returned if readMessage decompressed too much data.
var decompressionLimitErr = errors.Reason("the decompressed request size exceeds the server limit").Err()
// readMessage decodes a protobuf message from an HTTP request.
//
// Uses given headers (together with `codec` callback) to decide how to
// uncompress and deserialize the message. When decompressing, makes sure to
// decompress no more than `maxDecompressedBytes`. Assumes `body` is already a
// limited reader or it doesn't exceed a limit (see Server.maxBytesReader).
//
// `codec` defines what codec exactly to use to deserialize messages in
// particular format. This callback is implemented by the server based on its
// configuration.
func readMessage(body io.Reader, header http.Header, msg proto.Message, maxDecompressedBytes int64, codec func(Format) protoCodec) error {
format, err := FormatFromContentType(header.Get(headerContentType))
if err != nil {
// Spec: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.16
return protocolErr(
codes.InvalidArgument,
http.StatusUnsupportedMediaType,
"bad Content-Type header: %s", err,
)
}
var buf []byte
if header.Get("Content-Encoding") == "gzip" {
reader, err := getGZipReader(body)
if err != nil {
return requestReadErr(err, "decompressing the request")
}
if maxDecompressedBytes == math.MaxInt64 {
// No limit. Just read until the EOF or OOM.
buf, err = io.ReadAll(reader)
} else {
// Read one extra byte. This is how we'll know that we read past the
// limit, because LimitReader uses io.EOF to indicate the limit is
// reached, which is indistinguishable from just truncating the original
// reader.
buf, err = io.ReadAll(io.LimitReader(reader, maxDecompressedBytes+1))
if err == nil {
if int64(len(buf)) > maxDecompressedBytes {
err = decompressionLimitErr
}
}
}
if err == nil {
err = reader.Close() // this just checks the crc32 checksum
} else {
_ = reader.Close() // this will be some bogus error since we abandoned the reader
}
returnGZipReader(reader)
if err != nil {
return requestReadErr(err, "decompressing the request")
}
} else {
buf, err = io.ReadAll(body)
if err != nil {
return requestReadErr(err, "reading the request")
}
}
err = codec(format).Decode(buf, msg)
if err != nil {
return protocolErr(
codes.InvalidArgument,
http.StatusBadRequest,
"could not decode body: %s", err,
)
}
return nil
}
// requestReadErr interprets an error from reading and unzipping of a request.
func requestReadErr(err error, msg string) *protocolError {
var maxBytesErr *http.MaxBytesError
code := codes.InvalidArgument
errMsg := err.Error()
switch {
case errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, context.Canceled):
code = codes.Canceled
case errors.Is(err, context.DeadlineExceeded):
code = codes.DeadlineExceeded
case errors.Is(err, decompressionLimitErr):
code = codes.Unavailable
case errors.As(err, &maxBytesErr):
code = codes.Unavailable
errMsg = "the request size exceeds the server limit"
}
return protocolErr(code, grpcutil.CodeStatus(code), "%s: %s", msg, errMsg)
}
// parseHeader parses HTTP headers and derives a new context.
//
// Recognizes "X-Prpc-Grpc-Timeout" header. Other recognized reserved headers
// are skipped. If there are unrecognized HTTP headers they are added to
// a metadata.MD and a new context is derived. If ctx already has metadata,
// the latter is copied.
//
// If host is not empty, sets "host" and ":authority" metadata.
func parseHeader(ctx context.Context, header http.Header, host string) (context.Context, context.CancelFunc, error) {
// Parse headers into a metadata map. This skips all reserved headers to avoid
// leaking pRPC protocol implementation details to gRPC servers.
parsed, err := headersIntoMetadata(header)
if err != nil {
return nil, nil, err
}
// Modify metadata.MD in the context only if we have something to add.
if parsed != nil || host != "" {
existing, _ := metadata.FromIncomingContext(ctx)
merged := metadata.Join(existing, parsed)
if host != "" {
merged["host"] = []string{host}
merged[":authority"] = []string{host}
}
ctx = metadata.NewIncomingContext(ctx, merged)
}
// Parse the reserved timeout header, if any. Apply to the context.
if timeoutStr := header.Get(HeaderTimeout); timeoutStr != "" {
timeout, err := DecodeTimeout(timeoutStr)
if err != nil {
return nil, nil, errors.Annotate(err, "%q header", HeaderTimeout).Err()
}
ctx, cancel := clock.WithTimeout(ctx, timeout)
return ctx, cancel, nil
}
return ctx, func() {}, nil
}