1
#include "source/server/hot_restarting_base.h"
2

            
3
#include "source/common/api/os_sys_calls_impl.h"
4
#include "source/common/common/mem_block_builder.h"
5
#include "source/common/common/safe_memcpy.h"
6
#include "source/common/common/utility.h"
7
#include "source/common/network/address_impl.h"
8
#include "source/common/stats/utility.h"
9

            
10
namespace Envoy {
11
namespace Server {
12

            
13
using HotRestartMessage = envoy::HotRestartMessage;
14

            
15
static constexpr uint64_t MaxSendmsgSize = 4096;
16
static constexpr absl::Duration CONNECTION_REFUSED_RETRY_DELAY = absl::Seconds(1);
17
static constexpr int SENDMSG_MAX_RETRIES = 10;
18

            
19
346
// Closes the domain socket on destruction. Nothing is closed while domain_socket_ still holds -1.
RpcStream::~RpcStream() {
  if (domain_socket_ == -1) {
    return;
  }
  Api::SysCallIntResult result = Api::OsSysCallsSingleton::get().close(domain_socket_);
  ASSERT(result.return_value_ == 0);
}
26

            
27
594
// Resets `address` to an all-zero sockaddr_un and tags it as a Unix-domain (AF_UNIX) address.
// The path portion is left zeroed for the caller to fill in.
void RpcStream::initDomainSocketAddress(sockaddr_un* address) {
  sockaddr_un& target = *address;
  memset(&target, 0, sizeof(target));
  target.sun_family = AF_UNIX;
}
31

            
32
sockaddr_un RpcStream::createDomainSocketAddress(uint64_t id, const std::string& role,
33
                                                 const std::string& socket_path,
34
318
                                                 mode_t socket_mode) {
35
  // Right now we only allow a maximum of 3 concurrent envoy processes to be running. When the third
36
  // starts up it will kill the oldest parent.
37
318
  static constexpr uint64_t MaxConcurrentProcesses = 3;
38
318
  id = id % MaxConcurrentProcesses;
39
318
  sockaddr_un address;
40
318
  initDomainSocketAddress(&address);
41
318
  auto addr = THROW_OR_RETURN_VALUE(
42
318
      Network::Address::PipeInstance::create(
43
318
          fmt::format("{}_{}_{}", socket_path, role, base_id_ + id), socket_mode, nullptr),
44
318
      std::unique_ptr<Network::Address::PipeInstance>);
45
318
  safeMemcpy(&address, &(addr->getSockAddr()));
46
318
  fchmod(domain_socket_, socket_mode);
47

            
48
318
  return address;
49
318
}
50

            
51
void RpcStream::bindDomainSocket(uint64_t id, const std::string& role,
52
240
                                 const std::string& socket_path, mode_t socket_mode) {
53
240
  Api::OsSysCalls& os_sys_calls = Api::OsSysCallsSingleton::get();
54
  // This actually creates the socket and binds it. We use the socket in datagram mode so we can
55
  // easily read single messages.
56
240
  domain_socket_ = socket(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0);
57
240
  sockaddr_un address = createDomainSocketAddress(id, role, socket_path, socket_mode);
58
240
  unlink(address.sun_path);
59
240
  Api::SysCallIntResult result =
60
240
      os_sys_calls.bind(domain_socket_, reinterpret_cast<sockaddr*>(&address), sizeof(address));
61
240
  if (result.return_value_ != 0) {
62
108
    const auto msg = fmt::format(
63
108
        "unable to bind domain socket with base_id={}, id={}, errno={} (see --base-id option)",
64
108
        base_id_, id, result.errno_);
65
108
    if (result.errno_ == SOCKET_ERROR_ADDR_IN_USE) {
66
104
      throw HotRestartDomainSocketInUseException(msg);
67
104
    }
68
4
    throw EnvoyException(msg);
69
108
  }
70
240
}
71

            
72
bool RpcStream::sendHotRestartMessage(sockaddr_un& address, const HotRestartMessage& proto,
73
12
                                      bool allow_failure) {
74
12
  Api::OsSysCalls& os_sys_calls = Api::OsSysCallsSingleton::get();
75
12
  const uint64_t serialized_size = proto.ByteSizeLong();
76
12
  const uint64_t total_size = sizeof(uint64_t) + serialized_size;
77
  // Fill with uint64_t 'length' followed by the serialized HotRestartMessage.
78
12
  std::vector<uint8_t> send_buf;
79
12
  send_buf.resize(total_size);
80
12
  *reinterpret_cast<uint64_t*>(send_buf.data()) = htobe64(serialized_size);
81
12
  RELEASE_ASSERT(proto.SerializeWithCachedSizesToArray(send_buf.data() + sizeof(uint64_t)),
82
12
                 "failed to serialize a HotRestartMessage");
83

            
84
12
  RELEASE_ASSERT(fcntl(domain_socket_, F_SETFL, 0) != -1,
85
12
                 fmt::format("Set domain socket blocking failed, errno = {}", errno));
86

            
87
12
  uint8_t* next_byte_to_send = send_buf.data();
88
12
  uint64_t sent = 0;
89
23
  while (sent < total_size) {
90
12
    const uint64_t cur_chunk_size = std::min(MaxSendmsgSize, total_size - sent);
91
12
    iovec iov[1];
92
12
    iov[0].iov_base = next_byte_to_send;
93
12
    iov[0].iov_len = cur_chunk_size;
94
12
    next_byte_to_send += cur_chunk_size;
95
12
    sent += cur_chunk_size;
96
12
    msghdr message;
97
12
    memset(&message, 0, sizeof(message));
98
12
    message.msg_name = &address;
99
12
    message.msg_namelen = sizeof(address);
100
12
    message.msg_iov = iov;
101
12
    message.msg_iovlen = 1;
102

            
103
    // Control data stuff, only relevant for the fd passing done with PassListenSocketReply.
104
12
    uint8_t control_buffer[CMSG_SPACE(sizeof(int))];
105
12
    if (replyIsExpectedType(&proto, HotRestartMessage::Reply::kPassListenSocket) &&
106
12
        proto.reply().pass_listen_socket().fd() != -1) {
107
      memset(control_buffer, 0, CMSG_SPACE(sizeof(int)));
108
      message.msg_control = control_buffer;
109
      message.msg_controllen = CMSG_SPACE(sizeof(int));
110
      cmsghdr* control_message = CMSG_FIRSTHDR(&message);
111
      control_message->cmsg_level = SOL_SOCKET;
112
      control_message->cmsg_type = SCM_RIGHTS;
113
      control_message->cmsg_len = CMSG_LEN(sizeof(int));
114
      *reinterpret_cast<int*>(CMSG_DATA(control_message)) = proto.reply().pass_listen_socket().fd();
115
      ASSERT(sent == total_size, "an fd passing message was too long for one sendmsg().");
116
    }
117

            
118
    // A transient connection refused error probably means the old process is not ready.
119
12
    int saved_errno = 0;
120
12
    int rc = 0;
121
12
    bool sent = false;
122
15
    for (int i = 0; i < SENDMSG_MAX_RETRIES; i++) {
123
13
      auto result = os_sys_calls.sendmsg(domain_socket_, &message, 0);
124
13
      rc = result.return_value_;
125
13
      saved_errno = result.errno_;
126

            
127
13
      if (rc == static_cast<int>(cur_chunk_size)) {
128
9
        sent = true;
129
9
        break;
130
9
      }
131

            
132
4
      if (saved_errno == ECONNREFUSED) {
133
2
        if (allow_failure) {
134
1
          return false;
135
1
        }
136
1
        ENVOY_LOG(error, "hot restart sendmsg() connection refused, retrying");
137
1
        absl::SleepFor(CONNECTION_REFUSED_RETRY_DELAY);
138
1
        continue;
139
2
      }
140

            
141
2
      RELEASE_ASSERT(false, fmt::format("hot restart sendmsg() failed: returned {}, errno {}", rc,
142
2
                                        saved_errno));
143
2
    }
144

            
145
11
    if (!sent) {
146
      RELEASE_ASSERT(false, fmt::format("hot restart sendmsg() failed: returned {}, errno {}", rc,
147
                                        saved_errno));
148
    }
149
11
  }
150

            
151
11
  RELEASE_ASSERT(fcntl(domain_socket_, F_SETFL, O_NONBLOCK) != -1,
152
11
                 fmt::format("Set domain socket nonblocking failed, errno = {}", errno));
153
11
  return true;
154
11
}
155

            
156
bool RpcStream::replyIsExpectedType(const HotRestartMessage* proto,
157
12
                                    HotRestartMessage::Reply::ReplyCase oneof_type) const {
158
12
  return proto != nullptr && proto->requestreply_case() == HotRestartMessage::kReply &&
159
12
         proto->reply().reply_case() == oneof_type;
160
12
}
161

            
162
// Pull the cloned fd, if present, out of the control data and write it into the
163
// PassListenSocketReply proto; the higher level code will see a listening fd that Just Works. We
164
// should only get control data in a PassListenSocketReply, it should only be the fd passing type,
165
// and there should only be one at a time. Crash on any other control data.
166
6
void RpcStream::getPassedFdIfPresent(HotRestartMessage* out, msghdr* message) {
167
  // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
168
6
  cmsghdr* cmsg = CMSG_FIRSTHDR(message);
169
6
  if (cmsg != nullptr) {
170
    RELEASE_ASSERT(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS &&
171
                       replyIsExpectedType(out, HotRestartMessage::Reply::kPassListenSocket),
172
                   "recvmsg() came with control data when the message's purpose was not to pass a "
173
                   "file descriptor.");
174

            
175
    out->mutable_reply()->mutable_pass_listen_socket()->set_fd(
176
        *reinterpret_cast<int*>(CMSG_DATA(cmsg)));
177

            
178
    RELEASE_ASSERT(CMSG_NXTHDR(message, cmsg) == nullptr,
179
                   "More than one control data on a single hot restart recvmsg().");
180
  }
181
6
}
182

            
183
// While in use, recv_buf_ is always >= MaxSendmsgSize. In between messages, it is kept empty,
184
// to be grown back to MaxSendmsgSize at the start of the next message.
185
10
void RpcStream::initRecvBufIfNewMessage() {
186
10
  if (recv_buf_.empty()) {
187
10
    ASSERT(cur_msg_recvd_bytes_ == 0);
188
10
    ASSERT(!expected_proto_length_.has_value());
189
10
    recv_buf_.resize(MaxSendmsgSize);
190
10
  }
191
10
}
192

            
193
// Must only be called when recv_buf_ contains a full proto. Returns that proto, and resets all of
194
// our receive-buffering state back to empty, to await a new message.
195
6
std::unique_ptr<HotRestartMessage> RpcStream::parseProtoAndResetState() {
196
6
  auto ret = std::make_unique<HotRestartMessage>();
197
6
  RELEASE_ASSERT(
198
6
      ret->ParseFromArray(recv_buf_.data() + sizeof(uint64_t), expected_proto_length_.value()),
199
6
      "failed to parse a HotRestartMessage.");
200
6
  recv_buf_.resize(0);
201
6
  cur_msg_recvd_bytes_ = 0;
202
6
  expected_proto_length_.reset();
203
6
  return ret;
204
6
}
205

            
206
10
// Receives one length-prefixed HotRestartMessage from the domain socket.
//
// Each recvmsg() appends to recv_buf_ at offset cur_msg_recvd_bytes_. The first fragment of a
// message must contain the whole 8-byte big-endian length prefix; recv_buf_ is enlarged when the
// announced proto does not fit in the initial MaxSendmsgSize bytes. The loop ends once exactly
// length + 8 bytes have arrived, at which point the proto is parsed and the receive state reset.
// Receive state is kept in members, so a message that is only partially received when a
// non-blocking call returns is continued by the next call.
//
// Errors from recvmsg() other than the non-blocking "try again" case, non-zero msg_flags
// (for example a truncated datagram), a too-short first fragment, and data running past the end
// of the current message are all fatal (RELEASE_ASSERT).
//
// @param block Blocking::Yes switches the socket to blocking mode for the duration of the call
//        and back to non-blocking afterwards; Blocking::No leaves the socket mode untouched.
// @return the parsed message, with any passed fd filled in; nullptr if `block` is Blocking::No
//         and recvmsg() reported SOCKET_ERROR_AGAIN.
std::unique_ptr<HotRestartMessage> RpcStream::receiveHotRestartMessage(Blocking block) {
  // By default the domain socket is non blocking. If we need to block, make it blocking first.
  if (block == Blocking::Yes) {
    RELEASE_ASSERT(fcntl(domain_socket_, F_SETFL, 0) != -1,
                   fmt::format("Set domain socket blocking failed, errno = {}", errno));
  }

  initRecvBufIfNewMessage();

  iovec iov[1];
  msghdr message;
  uint8_t control_buffer[CMSG_SPACE(sizeof(int))];
  std::unique_ptr<HotRestartMessage> ret = nullptr;
  Api::OsSysCalls& os_sys_calls = Api::OsSysCallsSingleton::get();
  while (!ret) {
    iov[0].iov_base = recv_buf_.data() + cur_msg_recvd_bytes_;
    // Never let the kernel write past the end of recv_buf_: accept at most one chunk, and no more
    // than the space left after the bytes already received. A datagram larger than that is
    // truncated by recvmsg(), which is reported via msg_flags and rejected below.
    iov[0].iov_len =
        std::min<uint64_t>(MaxSendmsgSize, recv_buf_.size() - cur_msg_recvd_bytes_);

    // We always setup to receive an FD even though most messages do not pass one.
    memset(control_buffer, 0, CMSG_SPACE(sizeof(int)));
    memset(&message, 0, sizeof(message));
    message.msg_iov = iov;
    message.msg_iovlen = 1;
    message.msg_control = control_buffer;
    message.msg_controllen = CMSG_SPACE(sizeof(int));

    const Api::SysCallSizeResult recv_result = os_sys_calls.recvmsg(domain_socket_, &message, 0);
    if (block == Blocking::No && recv_result.return_value_ == -1 &&
        recv_result.errno_ == SOCKET_ERROR_AGAIN) {
      return nullptr;
    }
    RELEASE_ASSERT(recv_result.return_value_ != -1,
                   fmt::format("recvmsg() returned -1, errno = {}", recv_result.errno_));
    RELEASE_ASSERT(message.msg_flags == 0,
                   fmt::format("recvmsg() left msg_flags = {}", message.msg_flags));
    cur_msg_recvd_bytes_ += recv_result.return_value_;

    // If we don't already know 'length', we're at the start of a new length+protobuf message!
    if (!expected_proto_length_.has_value()) {
      // We are not ok with messages so fragmented that the length doesn't even come in one piece.
      RELEASE_ASSERT(recv_result.return_value_ >= 8, "received a brokenly tiny message fragment.");

      expected_proto_length_ = be64toh(*reinterpret_cast<uint64_t*>(recv_buf_.data()));
      // Expand the buffer from its default 4096 if this message is going to be longer.
      if (expected_proto_length_.value() > MaxSendmsgSize - sizeof(uint64_t)) {
        recv_buf_.resize(expected_proto_length_.value() + sizeof(uint64_t));
        cur_msg_recvd_bytes_ = recv_result.return_value_;
      }
    }
    // If we have received beyond the end of the current in-flight proto, then next is misaligned.
    RELEASE_ASSERT(cur_msg_recvd_bytes_ <= sizeof(uint64_t) + expected_proto_length_.value(),
                   "received a length+protobuf message not aligned to start of sendmsg().");

    if (cur_msg_recvd_bytes_ == sizeof(uint64_t) + expected_proto_length_.value()) {
      ret = parseProtoAndResetState();
    }
  }

  // Turn non-blocking back on if we made it blocking.
  if (block == Blocking::Yes) {
    RELEASE_ASSERT(fcntl(domain_socket_, F_SETFL, O_NONBLOCK) != -1,
                   fmt::format("Set domain socket nonblocking failed, errno = {}", errno));
  }
  // `message` still describes the last recvmsg(), i.e. the final fragment of this message.
  getPassedFdIfPresent(ret.get(), &message);
  return ret;
}
272

            
273
24
// Returns the "server.hot_restart_generation" gauge in `scope`, created through a dynamic
// stat-name with ImportMode::Accumulate.
//
// With the gauge's accumulate semantics, increments are combined across hot restarts. That may
// become useful in itself, but the main motivation for this stat is an integration test showing
// that dynamic stat-names can be coalesced across hot restarts; nothing else requires this
// particular name to be created dynamically.
//
// The stat cannot currently be a counter because of the way stats get latched on sink update.
// See the comment in InstanceUtil::flushMetricsToSinks.
Stats::Gauge& HotRestartingBase::hotRestartGeneration(Stats::Scope& scope) {
  const Stats::DynamicName generation_name("server.hot_restart_generation");
  return Stats::Utility::gaugeFromElements(scope, {generation_name},
                                           Stats::Gauge::ImportMode::Accumulate);
}
288

            
289
} // namespace Server
290
} // namespace Envoy