1
#include "source/extensions/transport_sockets/alts/alts_tsi_handshaker.h"
2

            
3
#include <algorithm>
4
#include <cstdint>
5
#include <cstring>
6
#include <memory>
7
#include <string>
8
#include <utility>
9
#include <vector>
10

            
11
#include "absl/memory/memory.h"
12
#include "absl/status/status.h"
13
#include "absl/strings/str_format.h"
14
#include "absl/types/span.h"
15
#include "src/core/lib/iomgr/exec_ctx.h"
16
#include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"
17

            
18
namespace Envoy {
19
namespace Extensions {
20
namespace TransportSockets {
21
namespace Alts {
22

            
23
// Shorthand for the handshaker service response message type consumed by next().
using ::grpc::gcp::HandshakerResp;
24

            
25
// Number of bytes of HandshakerResult::key_data handed to the ALTSRP_GCM_AES128_REKEY
// frame protector. Results whose key material is shorter than this are rejected in
// getHandshakeResult(); any extra bytes are ignored.
constexpr std::size_t AltsAes128GcmRekeyKeyLength{44};
26

            
27
std::unique_ptr<AltsTsiHandshaker>
28
37
AltsTsiHandshaker::createForClient(std::shared_ptr<grpc::Channel> handshaker_service_channel) {
29
37
  return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/true, handshaker_service_channel));
30
37
}
31

            
32
std::unique_ptr<AltsTsiHandshaker>
33
34
AltsTsiHandshaker::createForServer(std::shared_ptr<grpc::Channel> handshaker_service_channel) {
34
34
  return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/false, handshaker_service_channel));
35
34
}
36

            
37
// Stores the handshake role and the channel to the handshaker service. No RPC is made
// here; the AltsProxy is created lazily on the first call to next().
//
// `handshaker_service_channel` is a sink parameter: it is moved into the member rather
// than copied, avoiding a redundant shared_ptr refcount increment/decrement.
AltsTsiHandshaker::AltsTsiHandshaker(bool is_client,
                                     std::shared_ptr<grpc::Channel> handshaker_service_channel)
    : is_client_(is_client), handshaker_service_channel_(std::move(handshaker_service_channel)) {}
40

            
41
// Drives one step of the ALTS handshake through the handshaker service.
//
// First call: creates the AltsProxy, then sends either a StartClientHandshakeReq
// (client) or a StartServerHandshakeReq carrying `received_bytes` (server).
// Later calls: forward `received_bytes` in a NextHandshakeReq.
//
// If the service response carries a HandshakerResult, the handshake is marked complete
// and an AltsHandshakeResult is built from it.
//
// On success, `on_next_done` is invoked before this function returns. It receives an
// OK status, `handshaker`, the response's out_frames bytes, and the handshake result.
// The handshake result is null until the handshake completes.
//
// Returns:
//  - InvalidArgumentError if `handshaker` or `on_next_done` is null, or if
//    `received_bytes` is null while `received_bytes_size` is non-zero.
//  - InternalError if the handshake already completed, or if no AltsProxy exists on a
//    follow-up call.
//  - Otherwise, any error from the proxy or from getHandshakeResult(). In these error
//    cases `on_next_done` is not called.
absl::Status AltsTsiHandshaker::next(void* handshaker, const unsigned char* received_bytes,
                                     size_t received_bytes_size, OnNextDone on_next_done) {
  // Argument and state checks.
  if (handshaker == nullptr || (received_bytes == nullptr && received_bytes_size > 0) ||
      on_next_done == nullptr) {
    return absl::InvalidArgumentError("Invalid nullptr argument to AltsTsiHandshaker::Next.");
  }
  if (is_handshake_complete_) {
    return absl::InternalError("Handshake is already complete.");
  }

  // Get a handshake message from the handshaker service.
  absl::Span<const uint8_t> in_bytes = absl::MakeConstSpan(received_bytes, received_bytes_size);
  HandshakerResp response;
  if (!has_sent_initial_handshake_message_) {
    has_sent_initial_handshake_message_ = true;
    auto alts_proxy = AltsProxy::create(handshaker_service_channel_);
    if (!alts_proxy.ok()) {
      return alts_proxy.status();
    }
    alts_proxy_ = *std::move(alts_proxy);
    if (is_client_) {
      auto client_start = alts_proxy_->sendStartClientHandshakeReq();
      if (!client_start.ok()) {
        return client_start.status();
      }
      response = *std::move(client_start);
    } else {
      auto server_start = alts_proxy_->sendStartServerHandshakeReq(in_bytes);
      if (!server_start.ok()) {
        return server_start.status();
      }
      response = *std::move(server_start);
    }
  } else {
    // has_sent_initial_handshake_message_ is set before AltsProxy::create() runs. If that
    // creation failed on an earlier call, alts_proxy_ was never assigned. A retry must
    // then fail cleanly instead of dereferencing a null proxy.
    if (alts_proxy_ == nullptr) {
      return absl::InternalError(
          "ALTS proxy is unavailable: the initial handshake request did not succeed.");
    }
    auto next = alts_proxy_->sendNextHandshakeReq(in_bytes);
    if (!next.ok()) {
      return next.status();
    }
    response = *std::move(next);
  }

  // Maybe prepare the handshake result.
  std::unique_ptr<AltsHandshakeResult> handshake_result = nullptr;
  if (response.has_result()) {
    is_handshake_complete_ = true;
    auto result = getHandshakeResult(response.result(), in_bytes, response.bytes_consumed());
    if (!result.ok()) {
      return result.status();
    }
    handshake_result = *std::move(result);
  }

  // Write the out bytes.
  const std::string& out_bytes = response.out_frames();
  on_next_done(absl::OkStatus(), handshaker,
               reinterpret_cast<const unsigned char*>(out_bytes.c_str()), out_bytes.size(),
               std::move(handshake_result));
  return absl::OkStatus();
}
101

            
102
62
// Chooses the maximum frame size to pass to the zero-copy frame protector.
//
// When the HandshakerResult carries a positive max_frame_size, that value is clamped
// into [AltsMinFrameSize, MaxFrameSize]. Otherwise AltsMinFrameSize is used.
std::size_t AltsTsiHandshaker::computeMaxFrameSize(const grpc::gcp::HandshakerResult& result) {
  const bool has_requested_size = result.max_frame_size() > 0;
  if (!has_requested_size) {
    return AltsMinFrameSize;
  }
  const auto requested_size = static_cast<std::size_t>(result.max_frame_size());
  return std::clamp(requested_size, AltsMinFrameSize, MaxFrameSize);
}
109

            
110
// Validates a completed HandshakerResult and turns it into an AltsHandshakeResult.
//
// The result holds:
//  - a zero-copy ALTS frame protector keyed with the first AltsAes128GcmRekeyKeyLength
//    bytes of `result.key_data()`;
//  - the peer's service account as `peer_identity`;
//  - a copy of whatever part of `received_bytes` the handshaker service did not consume
//    (everything past `bytes_consumed`).
//
// Returns FailedPreconditionError if any of these hold:
//  - a required field (peer/local identity, peer rpc versions) is missing;
//  - the application protocol is empty;
//  - the record protocol is not ALTSRP_GCM_AES128_REKEY;
//  - the key data is too short;
//  - `bytes_consumed` exceeds `received_bytes.size()`.
// Returns InternalError if the frame protector cannot be created.
absl::StatusOr<std::unique_ptr<AltsHandshakeResult>>
AltsTsiHandshaker::getHandshakeResult(const grpc::gcp::HandshakerResult& result,
                                      absl::Span<const uint8_t> received_bytes,
                                      std::size_t bytes_consumed) {
  // Validate the HandshakerResult message.
  if (!result.has_peer_identity()) {
    return absl::FailedPreconditionError("Handshake result is missing peer identity.");
  }
  if (!result.has_local_identity()) {
    return absl::FailedPreconditionError("Handshake result is missing local identity.");
  }
  if (!result.has_peer_rpc_versions()) {
    return absl::FailedPreconditionError("Handshake result is missing peer rpc versions.");
  }
  if (result.application_protocol().empty()) {
    return absl::FailedPreconditionError("Handshake result has empty application protocol.");
  }
  if (result.record_protocol() != RecordProtocol) {
    return absl::FailedPreconditionError(
        "Handshake result's record protocol is not ALTSRP_GCM_AES128_REKEY.");
  }
  if (result.key_data().size() < AltsAes128GcmRekeyKeyLength) {
    return absl::FailedPreconditionError("Handshake result's key data is too short.");
  }
  if (bytes_consumed > received_bytes.size()) {
    return absl::FailedPreconditionError(
        "Handshaker service consumed more bytes than were received from the "
        "peer.");
  }

  // Create the frame protector. `max_frame_size` is passed by pointer, so the protector
  // factory may read and update it.
  std::size_t max_frame_size = computeMaxFrameSize(result);
  tsi_zero_copy_grpc_protector* protector = nullptr;
  grpc_core::ExecCtx exec_ctx;
  tsi_result ok = alts_zero_copy_grpc_protector_create(
      grpc_core::GsecKeyFactory(
          {reinterpret_cast<const uint8_t*>(result.key_data().data()), AltsAes128GcmRekeyKeyLength},
          /*is_rekey=*/true),
      is_client_,
      /*is_integrity_only=*/false, /*enable_extra_copy=*/false, &max_frame_size, &protector);
  if (ok != TSI_OK) {
    // tsi_result is an enum; format it explicitly as an int.
    return absl::InternalError(
        absl::StrFormat("Failed to create frame protector: %d", static_cast<int>(ok)));
  }
  // Wrap the raw protector immediately so later allocations cannot strand it.
  // NOTE(review): assumes TsiFrameProtector takes ownership of `protector` — confirm.
  auto frame_protector = std::make_unique<TsiFrameProtector>(protector);

  // Bytes past `bytes_consumed` were not part of the handshake and are returned to the
  // caller. The bounds were validated above, so subspan() cannot go out of range.
  absl::Span<const uint8_t> unused_span = received_bytes.subspan(bytes_consumed);
  std::vector<uint8_t> unused_bytes(unused_span.begin(), unused_span.end());

  // Create and return the AltsHandshakeResult.
  auto handshake_result = std::make_unique<AltsHandshakeResult>();
  handshake_result->frame_protector = std::move(frame_protector);
  handshake_result->peer_identity = result.peer_identity().service_account();
  // Move rather than copy the leftover bytes into the result.
  handshake_result->unused_bytes = std::move(unused_bytes);
  return handshake_result;
}
166

            
167
} // namespace Alts
168
} // namespace TransportSockets
169
} // namespace Extensions
170
} // namespace Envoy