| /* |
| * |
| * Copyright 2018 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| // Package handshaker provides ALTS handshaking functionality for GCP. |
| package handshaker |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "sync" |
| |
| grpc "google.golang.org/grpc" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/credentials" |
| core "google.golang.org/grpc/credentials/alts/internal" |
| "google.golang.org/grpc/credentials/alts/internal/authinfo" |
| "google.golang.org/grpc/credentials/alts/internal/conn" |
| altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" |
| altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" |
| ) |
| |
const (
	// frameLimit is the maximum byte size of receive frames (64 KB). It is
	// used as the size of the buffers read from the peer connection.
	frameLimit = 1 << 16
	// rekeyRecordProtocolName is the name of the only record protocol this
	// package offers to the handshaker service.
	rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
	// maxPendingHandshakes represents the maximum number of concurrent
	// handshakes.
	maxPendingHandshakes = 100
)
| |
| var ( |
| hsProtocol = altspb.HandshakeProtocol_ALTS |
| appProtocols = []string{"grpc"} |
| recordProtocols = []string{rekeyRecordProtocolName} |
| keyLength = map[string]int{ |
| rekeyRecordProtocolName: 44, |
| } |
| altsRecordFuncs = map[string]conn.ALTSRecordFunc{ |
| // ALTS handshaker protocols. |
| rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) { |
| return conn.NewAES128GCMRekey(s, keyData) |
| }, |
| } |
| // control number of concurrent created (but not closed) handshakers. |
| mu sync.Mutex |
| concurrentHandshakes = int64(0) |
| // errDropped occurs when maxPendingHandshakes is reached. |
| errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached") |
| ) |
| |
| func init() { |
| for protocol, f := range altsRecordFuncs { |
| if err := conn.RegisterProtocol(protocol, f); err != nil { |
| panic(err) |
| } |
| } |
| } |
| |
| func acquire(n int64) bool { |
| mu.Lock() |
| success := maxPendingHandshakes-concurrentHandshakes >= n |
| if success { |
| concurrentHandshakes += n |
| } |
| mu.Unlock() |
| return success |
| } |
| |
| func release(n int64) { |
| mu.Lock() |
| concurrentHandshakes -= n |
| if concurrentHandshakes < 0 { |
| mu.Unlock() |
| panic("bad release") |
| } |
| mu.Unlock() |
| } |
| |
// ClientHandshakerOptions contains the client handshaker options that can be
// provided by the caller.
type ClientHandshakerOptions struct {
	// ClientIdentity is the handshaker client local identity. It is sent
	// as the local identity in the client start request.
	ClientIdentity *altspb.Identity
	// TargetName is the server service account name for secure name
	// checking.
	TargetName string
	// TargetServiceAccounts contains a list of expected target service
	// accounts. One of these accounts should match one of the accounts in
	// the handshaker results. Otherwise, the handshake fails. Each entry is
	// sent to the handshaker service as a service-account identity.
	TargetServiceAccounts []string
	// RPCVersions specifies the gRPC versions accepted by the client.
	RPCVersions *altspb.RpcProtocolVersions
}
| |
// ServerHandshakerOptions contains the server handshaker options that can be
// provided by the caller.
type ServerHandshakerOptions struct {
	// RPCVersions specifies the gRPC versions accepted by the server. It is
	// forwarded unchanged in the server start request.
	RPCVersions *altspb.RpcProtocolVersions
}
| |
| // DefaultClientHandshakerOptions returns the default client handshaker options. |
| func DefaultClientHandshakerOptions() *ClientHandshakerOptions { |
| return &ClientHandshakerOptions{} |
| } |
| |
| // DefaultServerHandshakerOptions returns the default client handshaker options. |
| func DefaultServerHandshakerOptions() *ServerHandshakerOptions { |
| return &ServerHandshakerOptions{} |
| } |
| |
// TODO: add support for future local and remote endpoint in both client options
// and server options (the server options struct does not exist now; when the
// caller can provide endpoints, it should be created).
| |
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
	// RPC stream used to access the ALTS Handshaker service.
	stream altsgrpc.HandshakerService_DoHandshakeClient
	// the connection to the peer. Handshake frames are written to and read
	// from it, and it is wrapped by the secure connection on success.
	conn net.Conn
	// client handshake options. Only set by NewClientHandshaker.
	clientOpts *ClientHandshakerOptions
	// server handshake options. Only set by NewServerHandshaker.
	serverOpts *ServerHandshakerOptions
	// defines the side doing the handshake, client or server.
	side core.Side
}
| |
| // NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC |
| // stub created using the passed conn and used to talk to the ALTS Handshaker |
| // service in the metadata server. |
| func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) { |
| stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false)) |
| if err != nil { |
| return nil, err |
| } |
| return &altsHandshaker{ |
| stream: stream, |
| conn: c, |
| clientOpts: opts, |
| side: core.ClientSide, |
| }, nil |
| } |
| |
| // NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC |
| // stub created using the passed conn and used to talk to the ALTS Handshaker |
| // service in the metadata server. |
| func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) { |
| stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false)) |
| if err != nil { |
| return nil, err |
| } |
| return &altsHandshaker{ |
| stream: stream, |
| conn: c, |
| serverOpts: opts, |
| side: core.ServerSide, |
| }, nil |
| } |
| |
| // ClientHandshake starts and completes a client ALTS handshaking for GCP. Once |
| // done, ClientHandshake returns a secure connection. |
| func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { |
| if !acquire(1) { |
| return nil, nil, errDropped |
| } |
| defer release(1) |
| |
| if h.side != core.ClientSide { |
| return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker") |
| } |
| |
| // Create target identities from service account list. |
| targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts)) |
| for _, account := range h.clientOpts.TargetServiceAccounts { |
| targetIdentities = append(targetIdentities, &altspb.Identity{ |
| IdentityOneof: &altspb.Identity_ServiceAccount{ |
| ServiceAccount: account, |
| }, |
| }) |
| } |
| req := &altspb.HandshakerReq{ |
| ReqOneof: &altspb.HandshakerReq_ClientStart{ |
| ClientStart: &altspb.StartClientHandshakeReq{ |
| HandshakeSecurityProtocol: hsProtocol, |
| ApplicationProtocols: appProtocols, |
| RecordProtocols: recordProtocols, |
| TargetIdentities: targetIdentities, |
| LocalIdentity: h.clientOpts.ClientIdentity, |
| TargetName: h.clientOpts.TargetName, |
| RpcVersions: h.clientOpts.RPCVersions, |
| }, |
| }, |
| } |
| |
| conn, result, err := h.doHandshake(req) |
| if err != nil { |
| return nil, nil, err |
| } |
| authInfo := authinfo.New(result) |
| return conn, authInfo, nil |
| } |
| |
| // ServerHandshake starts and completes a server ALTS handshaking for GCP. Once |
| // done, ServerHandshake returns a secure connection. |
| func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { |
| if !acquire(1) { |
| return nil, nil, errDropped |
| } |
| defer release(1) |
| |
| if h.side != core.ServerSide { |
| return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker") |
| } |
| |
| p := make([]byte, frameLimit) |
| n, err := h.conn.Read(p) |
| if err != nil { |
| return nil, nil, err |
| } |
| |
| // Prepare server parameters. |
| // TODO: currently only ALTS parameters are provided. Might need to use |
| // more options in the future. |
| params := make(map[int32]*altspb.ServerHandshakeParameters) |
| params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{ |
| RecordProtocols: recordProtocols, |
| } |
| req := &altspb.HandshakerReq{ |
| ReqOneof: &altspb.HandshakerReq_ServerStart{ |
| ServerStart: &altspb.StartServerHandshakeReq{ |
| ApplicationProtocols: appProtocols, |
| HandshakeParameters: params, |
| InBytes: p[:n], |
| RpcVersions: h.serverOpts.RPCVersions, |
| }, |
| }, |
| } |
| |
| conn, result, err := h.doHandshake(req) |
| if err != nil { |
| return nil, nil, err |
| } |
| authInfo := authinfo.New(result) |
| return conn, authInfo, nil |
| } |
| |
| func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) { |
| resp, err := h.accessHandshakerService(req) |
| if err != nil { |
| return nil, nil, err |
| } |
| // Check of the returned status is an error. |
| if resp.GetStatus() != nil { |
| if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want { |
| return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details) |
| } |
| } |
| |
| var extra []byte |
| if req.GetServerStart() != nil { |
| extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():] |
| } |
| result, extra, err := h.processUntilDone(resp, extra) |
| if err != nil { |
| return nil, nil, err |
| } |
| // The handshaker returns a 128 bytes key. It should be truncated based |
| // on the returned record protocol. |
| keyLen, ok := keyLength[result.RecordProtocol] |
| if !ok { |
| return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol) |
| } |
| sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra) |
| if err != nil { |
| return nil, nil, err |
| } |
| return sc, result, nil |
| } |
| |
| func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) { |
| if err := h.stream.Send(req); err != nil { |
| return nil, err |
| } |
| resp, err := h.stream.Recv() |
| if err != nil { |
| return nil, err |
| } |
| return resp, nil |
| } |
| |
| // processUntilDone processes the handshake until the handshaker service returns |
| // the results. Handshaker service takes care of frame parsing, so we read |
| // whatever received from the network and send it to the handshaker service. |
| func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) { |
| for { |
| if len(resp.OutFrames) > 0 { |
| if _, err := h.conn.Write(resp.OutFrames); err != nil { |
| return nil, nil, err |
| } |
| } |
| if resp.Result != nil { |
| return resp.Result, extra, nil |
| } |
| buf := make([]byte, frameLimit) |
| n, err := h.conn.Read(buf) |
| if err != nil && err != io.EOF { |
| return nil, nil, err |
| } |
| // If there is nothing to send to the handshaker service, and |
| // nothing is received from the peer, then we are stuck. |
| // This covers the case when the peer is not responding. Note |
| // that handshaker service connection issues are caught in |
| // accessHandshakerService before we even get here. |
| if len(resp.OutFrames) == 0 && n == 0 { |
| return nil, nil, core.PeerNotRespondingError |
| } |
| // Append extra bytes from the previous interaction with the |
| // handshaker service with the current buffer read from conn. |
| p := append(extra, buf[:n]...) |
| resp, err = h.accessHandshakerService(&altspb.HandshakerReq{ |
| ReqOneof: &altspb.HandshakerReq_Next{ |
| Next: &altspb.NextHandshakeMessageReq{ |
| InBytes: p, |
| }, |
| }, |
| }) |
| if err != nil { |
| return nil, nil, err |
| } |
| // Set extra based on handshaker service response. |
| if n == 0 { |
| extra = nil |
| } else { |
| extra = buf[resp.GetBytesConsumed():n] |
| } |
| } |
| } |
| |
| // Close terminates the Handshaker. It should be called when the caller obtains |
| // the secure connection. |
| func (h *altsHandshaker) Close() { |
| h.stream.CloseSend() |
| } |