| /* |
| * |
| * Copyright 2021 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| // Package s2a provides the S2A transport credentials used by a gRPC |
| // application. |
| package s2a |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "errors" |
| "fmt" |
| "net" |
| "sync" |
| "time" |
| |
| "github.com/golang/protobuf/proto" |
| "github.com/google/s2a-go/fallback" |
| "github.com/google/s2a-go/internal/handshaker" |
| "github.com/google/s2a-go/internal/handshaker/service" |
| "github.com/google/s2a-go/internal/tokenmanager" |
| "github.com/google/s2a-go/internal/v2" |
| "github.com/google/s2a-go/retry" |
| "google.golang.org/grpc/credentials" |
| "google.golang.org/grpc/grpclog" |
| |
| commonpb "github.com/google/s2a-go/internal/proto/common_go_proto" |
| s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto" |
| ) |
| |
// s2aSecurityProtocol is the SecurityProtocol value advertised in the
// credentials.ProtocolInfo of the legacy S2A transport credentials.
const s2aSecurityProtocol = "tls"

// defaultTimeout specifies the default server handshake timeout.
const defaultTimeout = 30 * time.Second
| |
// s2aTransportCreds are the transport credentials required for establishing
// a secure connection using the S2A. They implement the
// credentials.TransportCredentials interface.
//
// In this file, values of this type are constructed by NewClientCreds and
// NewServerCreds only when legacy mode is enabled; otherwise those
// constructors delegate to the v2 package.
type s2aTransportCreds struct {
	// info is returned (by value) from Info and mutated in place by
	// OverrideServerName.
	info *credentials.ProtocolInfo
	// minTLSVersion and maxTLSVersion are forwarded to the handshaker options
	// as MinTLSVersion and MaxTLSVersion.
	minTLSVersion commonpb.TLSVersion
	maxTLSVersion commonpb.TLSVersion
	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
	// Note that these are currently unconfigurable.
	tlsCiphersuites []commonpb.Ciphersuite
	// localIdentity should only be used by the client.
	localIdentity *commonpb.Identity
	// localIdentities should only be used by the server.
	localIdentities []*commonpb.Identity
	// targetIdentities should only be used by the client.
	targetIdentities []*commonpb.Identity
	// isClient selects which of ClientHandshake / ServerHandshake may be
	// called; the other one returns an error.
	isClient bool
	// s2aAddr is the S2A address dialed at the start of every handshake.
	s2aAddr string
	// ensureProcessSessionTickets is forwarded to the client handshaker
	// options; presumably it lets callers wait for session-ticket processing
	// to complete — TODO confirm against the handshaker package.
	ensureProcessSessionTickets *sync.WaitGroup
}
| |
| // NewClientCreds returns a client-side transport credentials object that uses |
| // the S2A to establish a secure connection with a server. |
| func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) { |
| if opts == nil { |
| return nil, errors.New("nil client options") |
| } |
| var targetIdentities []*commonpb.Identity |
| for _, targetIdentity := range opts.TargetIdentities { |
| protoTargetIdentity, err := toProtoIdentity(targetIdentity) |
| if err != nil { |
| return nil, err |
| } |
| targetIdentities = append(targetIdentities, protoTargetIdentity) |
| } |
| localIdentity, err := toProtoIdentity(opts.LocalIdentity) |
| if err != nil { |
| return nil, err |
| } |
| if opts.EnableLegacyMode { |
| return &s2aTransportCreds{ |
| info: &credentials.ProtocolInfo{ |
| SecurityProtocol: s2aSecurityProtocol, |
| }, |
| minTLSVersion: commonpb.TLSVersion_TLS1_3, |
| maxTLSVersion: commonpb.TLSVersion_TLS1_3, |
| tlsCiphersuites: []commonpb.Ciphersuite{ |
| commonpb.Ciphersuite_AES_128_GCM_SHA256, |
| commonpb.Ciphersuite_AES_256_GCM_SHA384, |
| commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256, |
| }, |
| localIdentity: localIdentity, |
| targetIdentities: targetIdentities, |
| isClient: true, |
| s2aAddr: opts.S2AAddress, |
| ensureProcessSessionTickets: opts.EnsureProcessSessionTickets, |
| }, nil |
| } |
| verificationMode := getVerificationMode(opts.VerificationMode) |
| var fallbackFunc fallback.ClientHandshake |
| if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil { |
| fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc |
| } |
| return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy) |
| } |
| |
| // NewServerCreds returns a server-side transport credentials object that uses |
| // the S2A to establish a secure connection with a client. |
| func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) { |
| if opts == nil { |
| return nil, errors.New("nil server options") |
| } |
| var localIdentities []*commonpb.Identity |
| for _, localIdentity := range opts.LocalIdentities { |
| protoLocalIdentity, err := toProtoIdentity(localIdentity) |
| if err != nil { |
| return nil, err |
| } |
| localIdentities = append(localIdentities, protoLocalIdentity) |
| } |
| if opts.EnableLegacyMode { |
| return &s2aTransportCreds{ |
| info: &credentials.ProtocolInfo{ |
| SecurityProtocol: s2aSecurityProtocol, |
| }, |
| minTLSVersion: commonpb.TLSVersion_TLS1_3, |
| maxTLSVersion: commonpb.TLSVersion_TLS1_3, |
| tlsCiphersuites: []commonpb.Ciphersuite{ |
| commonpb.Ciphersuite_AES_128_GCM_SHA256, |
| commonpb.Ciphersuite_AES_256_GCM_SHA384, |
| commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256, |
| }, |
| localIdentities: localIdentities, |
| isClient: false, |
| s2aAddr: opts.S2AAddress, |
| }, nil |
| } |
| verificationMode := getVerificationMode(opts.VerificationMode) |
| return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, localIdentities, verificationMode, opts.getS2AStream) |
| } |
| |
| // ClientHandshake initiates a client-side TLS handshake using the S2A. |
| func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { |
| if !c.isClient { |
| return nil, nil, errors.New("client handshake called using server transport credentials") |
| } |
| |
| var cancel context.CancelFunc |
| ctx, cancel = context.WithCancel(ctx) |
| defer cancel() |
| |
| // Connect to the S2A. |
| hsConn, err := service.Dial(ctx, c.s2aAddr, nil) |
| if err != nil { |
| grpclog.Infof("Failed to connect to S2A: %v", err) |
| return nil, nil, err |
| } |
| |
| opts := &handshaker.ClientHandshakerOptions{ |
| MinTLSVersion: c.minTLSVersion, |
| MaxTLSVersion: c.maxTLSVersion, |
| TLSCiphersuites: c.tlsCiphersuites, |
| TargetIdentities: c.targetIdentities, |
| LocalIdentity: c.localIdentity, |
| TargetName: serverAuthority, |
| EnsureProcessSessionTickets: c.ensureProcessSessionTickets, |
| } |
| chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts) |
| if err != nil { |
| grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err) |
| return nil, nil, err |
| } |
| defer func() { |
| if err != nil { |
| if closeErr := chs.Close(); closeErr != nil { |
| grpclog.Infof("Close failed unexpectedly: %v", err) |
| err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr) |
| } |
| } |
| }() |
| |
| secConn, authInfo, err := chs.ClientHandshake(context.Background()) |
| if err != nil { |
| grpclog.Infof("Handshake failed: %v", err) |
| return nil, nil, err |
| } |
| return secConn, authInfo, nil |
| } |
| |
| // ServerHandshake initiates a server-side TLS handshake using the S2A. |
| func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { |
| if c.isClient { |
| return nil, nil, errors.New("server handshake called using client transport credentials") |
| } |
| |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) |
| defer cancel() |
| |
| // Connect to the S2A. |
| hsConn, err := service.Dial(ctx, c.s2aAddr, nil) |
| if err != nil { |
| grpclog.Infof("Failed to connect to S2A: %v", err) |
| return nil, nil, err |
| } |
| |
| opts := &handshaker.ServerHandshakerOptions{ |
| MinTLSVersion: c.minTLSVersion, |
| MaxTLSVersion: c.maxTLSVersion, |
| TLSCiphersuites: c.tlsCiphersuites, |
| LocalIdentities: c.localIdentities, |
| } |
| shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts) |
| if err != nil { |
| grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err) |
| return nil, nil, err |
| } |
| defer func() { |
| if err != nil { |
| if closeErr := shs.Close(); closeErr != nil { |
| grpclog.Infof("Close failed unexpectedly: %v", err) |
| err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr) |
| } |
| } |
| }() |
| |
| secConn, authInfo, err := shs.ServerHandshake(context.Background()) |
| if err != nil { |
| grpclog.Infof("Handshake failed: %v", err) |
| return nil, nil, err |
| } |
| return secConn, authInfo, nil |
| } |
| |
| func (c *s2aTransportCreds) Info() credentials.ProtocolInfo { |
| return *c.info |
| } |
| |
| func (c *s2aTransportCreds) Clone() credentials.TransportCredentials { |
| info := *c.info |
| var localIdentity *commonpb.Identity |
| if c.localIdentity != nil { |
| localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity) |
| } |
| var localIdentities []*commonpb.Identity |
| if c.localIdentities != nil { |
| localIdentities = make([]*commonpb.Identity, len(c.localIdentities)) |
| for i, localIdentity := range c.localIdentities { |
| localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity) |
| } |
| } |
| var targetIdentities []*commonpb.Identity |
| if c.targetIdentities != nil { |
| targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities)) |
| for i, targetIdentity := range c.targetIdentities { |
| targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity) |
| } |
| } |
| return &s2aTransportCreds{ |
| info: &info, |
| minTLSVersion: c.minTLSVersion, |
| maxTLSVersion: c.maxTLSVersion, |
| tlsCiphersuites: c.tlsCiphersuites, |
| localIdentity: localIdentity, |
| localIdentities: localIdentities, |
| targetIdentities: targetIdentities, |
| isClient: c.isClient, |
| s2aAddr: c.s2aAddr, |
| } |
| } |
| |
| func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error { |
| c.info.ServerName = serverNameOverride |
| return nil |
| } |
| |
// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the
	// hostname found in the server's certificate. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
| |
// TLSClientConfigFactory defines the interface for a client TLS config factory.
type TLSClientConfigFactory interface {
	// Build returns a client *tls.Config configured according to opts, or an
	// error if one cannot be produced.
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}
| |
| // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory. |
| func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) { |
| if opts == nil { |
| return nil, fmt.Errorf("opts must be non-nil") |
| } |
| if opts.EnableLegacyMode { |
| return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2") |
| } |
| tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager() |
| if err != nil { |
| // The only possible error is: access token not set in the environment, |
| // which is okay in environments other than serverless. |
| grpclog.Infof("Access token manager not initialized: %v", err) |
| return &s2aTLSClientConfigFactory{ |
| s2av2Address: opts.S2AAddress, |
| transportCreds: opts.TransportCreds, |
| tokenManager: nil, |
| verificationMode: getVerificationMode(opts.VerificationMode), |
| serverAuthorizationPolicy: opts.serverAuthorizationPolicy, |
| }, nil |
| } |
| return &s2aTLSClientConfigFactory{ |
| s2av2Address: opts.S2AAddress, |
| transportCreds: opts.TransportCreds, |
| tokenManager: tokenManager, |
| verificationMode: getVerificationMode(opts.VerificationMode), |
| serverAuthorizationPolicy: opts.serverAuthorizationPolicy, |
| }, nil |
| } |
| |
// s2aTLSClientConfigFactory is the TLSClientConfigFactory returned by
// NewTLSClientConfigFactory. Its Build method forwards these fields to
// v2.NewClientTLSConfig.
type s2aTLSClientConfigFactory struct {
	// s2av2Address is the S2A address, taken from ClientOptions.S2AAddress.
	s2av2Address string
	// transportCreds is taken from ClientOptions.TransportCreds; presumably
	// used to secure the connection to the S2A itself — TODO confirm.
	transportCreds credentials.TransportCredentials
	// tokenManager is nil when no access token manager could be created.
	tokenManager tokenmanager.AccessTokenManager
	// verificationMode is the protobuf form of ClientOptions.VerificationMode.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// serverAuthorizationPolicy is copied from the client options unchanged.
	serverAuthorizationPolicy []byte
}
| |
| func (f *s2aTLSClientConfigFactory) Build( |
| ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) { |
| serverName := "" |
| if opts != nil && opts.ServerName != "" { |
| serverName = opts.ServerName |
| } |
| return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy) |
| } |
| |
| func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode { |
| switch verificationMode { |
| case ConnectToGoogle: |
| return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE |
| case Spiffe: |
| return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE |
| default: |
| return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED |
| } |
| } |
| |
| // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A. |
| // Example use with http.RoundTripper: |
| // |
| // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{ |
| // S2AAddress: s2aAddress, // required |
| // }) |
| // transport := http.DefaultTransport |
| // transport.DialTLSContext = dialTLSContext |
| func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) { |
| |
| return func(ctx context.Context, network, addr string) (net.Conn, error) { |
| |
| fallback := func(err error) (net.Conn, error) { |
| if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil && |
| opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" { |
| fbDialer := opts.FallbackOpts.FallbackDialer |
| grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr) |
| fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr) |
| if fbErr != nil { |
| return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err) |
| } |
| return fbConn, nil |
| } |
| return nil, err |
| } |
| |
| factory, err := NewTLSClientConfigFactory(opts) |
| if err != nil { |
| grpclog.Infof("error creating S2A client config factory: %v", err) |
| return fallback(err) |
| } |
| |
| serverName, _, err := net.SplitHostPort(addr) |
| if err != nil { |
| serverName = addr |
| } |
| timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout()) |
| defer cancel() |
| |
| var s2aTLSConfig *tls.Config |
| retry.Run(timeoutCtx, |
| func() error { |
| s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{ |
| ServerName: serverName, |
| }) |
| return err |
| }) |
| if err != nil { |
| grpclog.Infof("error building S2A TLS config: %v", err) |
| return fallback(err) |
| } |
| |
| s2aDialer := &tls.Dialer{ |
| Config: s2aTLSConfig, |
| } |
| var c net.Conn |
| retry.Run(timeoutCtx, |
| func() error { |
| c, err = s2aDialer.DialContext(timeoutCtx, network, addr) |
| return err |
| }) |
| if err != nil { |
| grpclog.Infof("error dialing with S2A to %s: %v", addr, err) |
| return fallback(err) |
| } |
| grpclog.Infof("success dialing MTLS to %s with S2A", addr) |
| return c, nil |
| } |
| } |