1
#include "source/extensions/transport_sockets/alts/config.h"
2

            
3
#include "envoy/extensions/transport_sockets/alts/v3/alts.pb.h"
4
#include "envoy/extensions/transport_sockets/alts/v3/alts.pb.validate.h"
5
#include "envoy/registry/registry.h"
6
#include "envoy/server/transport_socket_config.h"
7

            
8
#include "source/common/common/assert.h"
9
#include "source/common/grpc/google_grpc_context.h"
10
#include "source/common/protobuf/protobuf.h"
11
#include "source/common/protobuf/utility.h"
12
#include "source/extensions/transport_sockets/alts/alts_channel_pool.h"
13
#include "source/extensions/transport_sockets/alts/alts_tsi_handshaker.h"
14
#include "source/extensions/transport_sockets/alts/grpc_tsi.h"
15
#include "source/extensions/transport_sockets/alts/tsi_handshaker.h"
16
#include "source/extensions/transport_sockets/alts/tsi_socket.h"
17

            
18
#include "absl/container/node_hash_set.h"
19
#include "absl/strings/str_join.h"
20
#include "absl/strings/string_view.h"
21
#include "grpcpp/channel.h"
22

            
23
namespace Envoy {
24
namespace Extensions {
25
namespace TransportSockets {
26
namespace Alts {
27
namespace {
28

            
29
// Manage ALTS singleton state via SingletonManager
30
class AltsSharedState : public Singleton::Instance {
31
public:
32
  explicit AltsSharedState(absl::string_view handshaker_service_address)
33
18
      : channel_pool_(AltsChannelPool::create(handshaker_service_address)) {}
34

            
35
18
  ~AltsSharedState() override = default;
36

            
37
15
  std::shared_ptr<grpc::Channel> getChannel() const { return channel_pool_->getChannel(); }
38

            
39
private:
40
  // There is blanket google-grpc initialization in MainCommonBase, but that
41
  // doesn't cover unit tests. However, putting blanket coverage in ProcessWide
42
  // causes background threaded memory allocation in all unit tests making it
43
  // hard to measure memory. Thus we also initialize grpc using our idempotent
44
  // wrapper-class in classes that need it. See
45
  // https://github.com/envoyproxy/envoy/issues/8282 for details.
46
#ifdef ENVOY_GOOGLE_GRPC
47
  Grpc::GoogleGrpcContext google_grpc_context_;
48
#endif
49
  std::unique_ptr<AltsChannelPool> channel_pool_;
50
};
51

            
52
// Registers the `alts_shared_state` singleton name. createTransportSocketFactoryHelper
// refers to it through SINGLETON_MANAGER_REGISTERED_NAME(alts_shared_state) when
// fetching the shared AltsSharedState instance.
SINGLETON_MANAGER_REGISTRATION(alts_shared_state);
53

            
54
// Returns true if the peer's service account is found in peers, otherwise
55
// returns false and fills out err with an error message.
56
bool doValidate(const absl::node_hash_set<std::string>& peers, TsiInfo& tsi_info,
57
4
                std::string& err) {
58
4
  if (peers.find(tsi_info.peer_identity_) != peers.end()) {
59
2
    return true;
60
2
  }
61
2
  err =
62
2
      "Couldn't find peer's service account in peer_service_accounts: " + absl::StrJoin(peers, ",");
63
2
  return false;
64
4
}
65

            
66
// Builds the post-handshake peer validator from the ALTS config.
//
// If config.peer_service_accounts() is empty, this returns an empty
// (default-constructed) HandshakeValidator, so no peer check is installed.
// Otherwise the returned callable accepts a peer only when its identity is one
// of the configured service accounts (see doValidate).
HandshakeValidator
createHandshakeValidator(const envoy::extensions::transport_sockets::alts::v3::Alts& config) {
  const auto& peer_service_accounts = config.peer_service_accounts();
  // Skip validation if peers is empty.
  if (peer_service_accounts.empty()) {
    return HandshakeValidator{};
  }
  absl::node_hash_set<std::string> accepted_peers(peer_service_accounts.cbegin(),
                                                  peer_service_accounts.cend());
  // Move the set into the closure rather than copying it.
  return [peers = std::move(accepted_peers)](TsiInfo& tsi_info, std::string& err) {
    return doValidate(peers, tsi_info, err);
  };
}
80

            
81
template <class TransportSocketFactoryPtr>
82
TransportSocketFactoryPtr createTransportSocketFactoryHelper(
83
    const Protobuf::Message& message, bool is_upstream,
84
18
    Server::Configuration::TransportSocketFactoryContext& factory_ctxt) {
85
18
  auto config =
86
18
      MessageUtil::downcastAndValidate<const envoy::extensions::transport_sockets::alts::v3::Alts&>(
87
18
          message, factory_ctxt.messageValidationVisitor());
88
18
  HandshakeValidator validator = createHandshakeValidator(config);
89
18
  const std::string& handshaker_service_address = config.handshaker_service();
90

            
91
  // A reference to this is held in the factory closure to keep the singleton
92
  // instance alive.
93
18
  auto alts_shared_state =
94
18
      factory_ctxt.serverFactoryContext().singletonManager().getTyped<AltsSharedState>(
95
18
          SINGLETON_MANAGER_REGISTERED_NAME(alts_shared_state), [handshaker_service_address] {
96
18
            return std::make_shared<AltsSharedState>(handshaker_service_address);
97
18
          });
98
18
  HandshakerFactory factory =
99
18
      [handshaker_service_address, is_upstream,
100
18
       alts_shared_state](Event::Dispatcher& dispatcher,
101
18
                          const Network::Address::InstanceConstSharedPtr& local_address,
102
18
                          const Network::Address::InstanceConstSharedPtr&) -> TsiHandshakerPtr {
103
15
    ASSERT(local_address != nullptr);
104
15
    std::unique_ptr<AltsTsiHandshaker> tsi_handshaker;
105
15
    if (is_upstream) {
106
8
      tsi_handshaker = AltsTsiHandshaker::createForClient(alts_shared_state->getChannel());
107
8
    } else {
108
7
      tsi_handshaker = AltsTsiHandshaker::createForServer(alts_shared_state->getChannel());
109
7
    }
110
15
    return std::make_unique<TsiHandshaker>(std::move(tsi_handshaker), dispatcher);
111
15
  };
112

            
113
18
  return std::make_unique<TsiSocketFactory>(factory, validator);
114
18
}
115

            
116
} // namespace
117

            
118
24
// Returns a new, default-constructed
// envoy::extensions::transport_sockets::alts::v3::Alts message.
ProtobufTypes::MessagePtr AltsTransportSocketConfigFactory::createEmptyConfigProto() {
  ProtobufTypes::MessagePtr empty_config =
      std::make_unique<envoy::extensions::transport_sockets::alts::v3::Alts>();
  return empty_config;
}
121

            
122
// Upstream entry point. Handshakers produced by the returned factory are
// created with AltsTsiHandshaker::createForClient.
absl::StatusOr<Network::UpstreamTransportSocketFactoryPtr>
UpstreamAltsTransportSocketConfigFactory::createTransportSocketFactory(
    const Protobuf::Message& message,
    Server::Configuration::TransportSocketFactoryContext& factory_ctxt) {
  constexpr bool upstream_side = true;
  return createTransportSocketFactoryHelper<Network::UpstreamTransportSocketFactoryPtr>(
      message, upstream_side, factory_ctxt);
}
129

            
130
// Downstream entry point. Handshakers produced by the returned factory are
// created with AltsTsiHandshaker::createForServer. The string-vector argument
// is not used by ALTS.
absl::StatusOr<Network::DownstreamTransportSocketFactoryPtr>
DownstreamAltsTransportSocketConfigFactory::createTransportSocketFactory(
    const Protobuf::Message& message,
    Server::Configuration::TransportSocketFactoryContext& factory_ctxt,
    const std::vector<std::string>& /* unused */) {
  constexpr bool upstream_side = false;
  return createTransportSocketFactoryHelper<Network::DownstreamTransportSocketFactoryPtr>(
      message, upstream_side, factory_ctxt);
}
138

            
139
// Register the ALTS transport socket config factories with the factory
// registry: one for upstream sockets, one for downstream sockets.
REGISTER_FACTORY(
    UpstreamAltsTransportSocketConfigFactory,
    Server::Configuration::UpstreamTransportSocketConfigFactory);
REGISTER_FACTORY(
    DownstreamAltsTransportSocketConfigFactory,
    Server::Configuration::DownstreamTransportSocketConfigFactory);
144

            
145
} // namespace Alts
146
} // namespace TransportSockets
147
} // namespace Extensions
148
} // namespace Envoy