1
#include "source/common/network/io_uring_socket_handle_impl.h"
2

            
3
#include "envoy/buffer/buffer.h"
4
#include "envoy/common/exception.h"
5
#include "envoy/event/dispatcher.h"
6

            
7
#include "source/common/api/os_sys_calls_impl.h"
8
#include "source/common/common/assert.h"
9
#include "source/common/common/utility.h"
10
#include "source/common/io/io_uring_worker_impl.h"
11
#include "source/common/network/address_impl.h"
12
#include "source/common/network/io_socket_error_impl.h"
13
#include "source/common/network/io_socket_handle_impl.h"
14
#include "source/common/network/socket_interface_impl.h"
15

            
16
namespace Envoy {
17
namespace Network {
18

            
19
// Wraps an already opened fd. When `is_server_socket` is set (as accept() does for the sockets it
// returns) the handle starts as a Server socket; otherwise the type is Unknown until listen() or
// initializeFileEvent() decides it.
IoUringSocketHandleImpl::IoUringSocketHandleImpl(Io::IoUringWorkerFactory& io_uring_worker_factory,
                                                 os_fd_t fd, bool socket_v6only,
                                                 absl::optional<int> domain, bool is_server_socket)
    : IoSocketHandleBaseImpl(fd, socket_v6only, domain),
      io_uring_worker_factory_(io_uring_worker_factory),
      io_uring_socket_type_(is_server_socket ? IoUringSocketType::Server
                                             : IoUringSocketType::Unknown) {
  ENVOY_LOG(trace, "construct io uring socket handle, fd = {}, type = {}", fd_,
            ioUringSocketTypeStr());
}
29

            
30
39
// Releases the fd if close() has not already invalidated it.
IoUringSocketHandleImpl::~IoUringSocketHandleImpl() {
  ENVOY_LOG(trace, "~IoUringSocketHandleImpl, type = {}", ioUringSocketTypeStr());

  if (SOCKET_INVALID(fd_)) {
    return;
  }

  // If the socket is owned by the main thread like a listener, it may outlive the IoUringWorker.
  // We have to ensure that the current thread has been registered and the io_uring in the thread
  // is still available.
  // TODO(zhxie): for current usage of server socket and client socket, the check may be
  // redundant.
  const bool close_through_io_uring = io_uring_socket_type_ != IoUringSocketType::Unknown &&
                                      io_uring_socket_type_ != IoUringSocketType::Accept &&
                                      io_uring_worker_factory_.currentThreadRegistered() &&
                                      io_uring_socket_.has_value();
  if (!close_through_io_uring) {
    // The TLS slot has been shut down by this moment with io_uring wiped out, thus use the
    // POSIX system call instead of IoUringSocketHandleImpl::close().
    ::close(fd_);
    return;
  }

  // An io_uring socket that already reports Closed needs no further action.
  if (io_uring_socket_->getStatus() != Io::IoUringSocketStatus::Closed) {
    io_uring_socket_.ref().close(false);
  }
}
54

            
55
26
// Closes the handle and marks the fd invalid. Server/Client handles that own an io_uring socket
// close through it; every other handle drops its file event (if any) and closes the fd directly.
// Always returns the "no error" result.
Api::IoCallUint64Result IoUringSocketHandleImpl::close() {
  ENVOY_LOG(trace, "close, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  ASSERT(SOCKET_VALID(fd_));

  const bool owned_by_io_uring = io_uring_socket_type_ != IoUringSocketType::Unknown &&
                                 io_uring_socket_type_ != IoUringSocketType::Accept &&
                                 io_uring_socket_.has_value();
  if (owned_by_io_uring) {
    io_uring_socket_.ref().close(false);
    io_uring_socket_.reset();
  } else {
    if (file_event_) {
      file_event_.reset();
    }
    ::close(fd_);
  }

  SET_SOCKET_INVALID(fd_);
  return Api::ioCallUint64ResultNoError();
}
73

            
74
// Copies pending read data into `slices` (bounded by `max_length`) and then removes the copied
// bytes from the io_uring socket's read buffer. Errors and zero-length results from copyOut() are
// passed through unchanged.
Api::IoCallUint64Result
IoUringSocketHandleImpl::readv(uint64_t max_length, Buffer::RawSlice* slices, uint64_t num_slice) {
  ENVOY_LOG(debug, "readv, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  Api::IoCallUint64Result copied = copyOut(max_length, slices, num_slice);
  // If the return value is 0, there should be a remote close. Return the value directly.
  if (copied.ok() && copied.return_value_ != 0) {
    io_uring_socket_->getReadParam()->buf_.drain(copied.return_value_);
  }
  return copied;
}
87

            
88
// Moves up to `max_length_opt` bytes (or everything buffered when it is unset) from the io_uring
// socket's read buffer into `buffer`. If checkReadResult() already produced a result (EAGAIN,
// error or remote close), that result is returned as is.
Api::IoCallUint64Result IoUringSocketHandleImpl::read(Buffer::Instance& buffer,
                                                      absl::optional<uint64_t> max_length_opt) {
  ENVOY_LOG(trace, "read, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  absl::optional<Api::IoCallUint64Result> early_result = checkReadResult();
  if (early_result.has_value()) {
    return std::move(*early_result);
  }

  const OptRef<Io::ReadParam>& pending = io_uring_socket_->getReadParam();
  const uint64_t requested = max_length_opt.value_or(UINT64_MAX);
  uint64_t bytes_to_move = std::min(requested, pending->buf_.length());
  buffer.move(pending->buf_, bytes_to_move);
  return {bytes_to_move, IoSocketError::none()};
}
103

            
104
// Hands `slices` to the io_uring socket for writing and reports the value returned by that call.
// A failure recorded by a previous write (see checkWriteResult()) is returned instead.
Api::IoCallUint64Result IoUringSocketHandleImpl::writev(const Buffer::RawSlice* slices,
                                                        uint64_t num_slice) {
  ENVOY_LOG(trace, "writev, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  absl::optional<Api::IoCallUint64Result> prior_failure = checkWriteResult();
  if (prior_failure.has_value()) {
    return std::move(*prior_failure);
  }

  const uint64_t accepted = io_uring_socket_->write(slices, num_slice);
  return {accepted, IoSocketError::none()};
}
116

            
117
3
// Hands `buffer` to the io_uring socket for writing and reports the length the buffer had before
// the call. A failure recorded by a previous write (see checkWriteResult()) is returned instead.
Api::IoCallUint64Result IoUringSocketHandleImpl::write(Buffer::Instance& buffer) {
  ENVOY_LOG(trace, "write {}, fd = {}, type = {}", buffer.length(), fd_, ioUringSocketTypeStr());

  absl::optional<Api::IoCallUint64Result> prior_failure = checkWriteResult();
  if (prior_failure.has_value()) {
    return std::move(*prior_failure);
  }

  // Capture the size first: it is read before the io_uring socket gets to touch the buffer.
  const uint64_t length_before_write = buffer.length();
  io_uring_socket_->write(buffer);
  return {length_before_write, IoSocketError::none()};
}
129

            
130
// All arguments are ignored: the call is only logged and always yields the invalid-address
// result.
Api::IoCallUint64Result IoUringSocketHandleImpl::sendmsg(const Buffer::RawSlice*, uint64_t, int,
                                                         const Address::Ip*,
                                                         const Address::Instance&) {
  ENVOY_LOG(trace, "sendmsg, fd = {}, type = {}", fd_, ioUringSocketTypeStr());
  return IoSocketError::ioResultSocketInvalidAddress();
}
136

            
137
// All arguments are ignored: the call is only logged and always yields the invalid-address
// result.
Api::IoCallUint64Result IoUringSocketHandleImpl::recvmsg(Buffer::RawSlice*, const uint64_t,
                                                         uint32_t,
                                                         const IoHandle::UdpSaveCmsgConfig&,
                                                         RecvMsgOutput&) {
  ENVOY_LOG(trace, "recvmsg, fd = {}, type = {}", fd_, ioUringSocketTypeStr());
  return IoSocketError::ioResultSocketInvalidAddress();
}
144

            
145
// All arguments are ignored: the call is only logged and always yields the invalid-address
// result.
Api::IoCallUint64Result IoUringSocketHandleImpl::recvmmsg(RawSliceArrays&, uint32_t,
                                                          const IoHandle::UdpSaveCmsgConfig&,
                                                          RecvMsgOutput&) {
  ENVOY_LOG(trace, "recvmmsg, fd = {}, type = {}", fd_, ioUringSocketTypeStr());
  return IoSocketError::ioResultSocketInvalidAddress();
}
151

            
152
2
// Reads into the caller's raw buffer. With `flags == 0` the data is consumed through readv();
// with any other value (expected to be MSG_PEEK) it is only copied out and stays buffered.
Api::IoCallUint64Result IoUringSocketHandleImpl::recv(void* buffer, size_t length, int flags) {
  ASSERT(io_uring_socket_.has_value());
  ENVOY_LOG(trace, "recv, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  // The only used flag in Envoy is MSG_PEEK for listener filters, including TLS inspectors.
  ASSERT(flags == 0 || flags == MSG_PEEK);

  // Present the caller's buffer as a single slice so the slice based helpers can fill it.
  Buffer::RawSlice destination;
  destination.mem_ = buffer;
  destination.len_ = length;

  if (flags != 0) {
    return copyOut(length, &destination, 1);
  }
  return readv(length, &destination, 1);
}
167

            
168
4
// Binds the fd to `address` through the OS syscall wrapper and returns that call's result.
Api::SysCallIntResult IoUringSocketHandleImpl::bind(Address::InstanceConstSharedPtr address) {
  ENVOY_LOG(trace, "bind {}, fd = {}, io_uring_socket_type = {}", address->asString(), fd_,
            ioUringSocketTypeStr());

  auto& os_sys_calls = Api::OsSysCallsSingleton::get();
  return os_sys_calls.bind(fd_, address->sockAddr(), address->sockAddrLen());
}
173

            
174
3
// Turns a handle of still Unknown type into an Accept handle, switches it to non-blocking mode
// and starts listening with the given backlog. Returns the result of the listen syscall wrapper.
Api::SysCallIntResult IoUringSocketHandleImpl::listen(int backlog) {
  ENVOY_LOG(trace, "listen, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  // Only a handle whose role has not been decided yet may become a listening socket.
  ASSERT(io_uring_socket_type_ == IoUringSocketType::Unknown);
  io_uring_socket_type_ = IoUringSocketType::Accept;

  setBlocking(false);

  auto& os_sys_calls = Api::OsSysCallsSingleton::get();
  return os_sys_calls.listen(fd_, backlog);
}
183

            
184
4
// Accepts one connection on an Accept handle through the OS syscall wrapper. Returns nullptr when
// the syscall yields an invalid socket; otherwise returns a new handle for the accepted fd that
// shares this handle's worker factory, v6only flag and domain and is created as a server socket.
IoHandlePtr IoUringSocketHandleImpl::accept(struct sockaddr* addr, socklen_t* addrlen) {
  ENVOY_LOG(trace, "accept, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  ASSERT(io_uring_socket_type_ == IoUringSocketType::Accept);

  Envoy::Api::SysCallSocketResult accepted =
      Api::OsSysCallsSingleton::get().accept(fd_, addr, addrlen);
  if (!SOCKET_INVALID(accepted.return_value_)) {
    return std::make_unique<IoUringSocketHandleImpl>(
        io_uring_worker_factory_, accepted.return_value_, socket_v6only_, domain_, true);
  }
  return nullptr;
}
197

            
198
20
// Asks the io_uring socket to connect to `address`. The call always reports {-1, EINPROGRESS};
// the connect outcome is surfaced later (getOption() exposes it for SO_ERROR).
Api::SysCallIntResult IoUringSocketHandleImpl::connect(Address::InstanceConstSharedPtr address) {
  ENVOY_LOG(trace, "connect, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  ASSERT(io_uring_socket_type_ == IoUringSocketType::Client);
  io_uring_socket_->connect(address);

  const Api::SysCallIntResult in_progress{-1, EINPROGRESS};
  return in_progress;
}
206

            
207
// Reads a socket option.
//
// For a client handle that owns an io_uring socket, SO_ERROR is answered from the result of the
// last write-side completion (negated, so a failed connect shows up as a positive errno) and
// {0, 0} is returned. Every other request is forwarded to IoSocketHandleBaseImpl::getOption().
//
// If no write-side completion has been recorded yet there is nothing to report from io_uring, so
// the request is forwarded to the base implementation as well instead of dereferencing an empty
// write param.
Api::SysCallIntResult IoUringSocketHandleImpl::getOption(int level, int optname, void* optval,
                                                         socklen_t* optlen) {
  // io_uring socket does not populate connect error in getsockopt. Instead, the connect error is
  // returned in onConnect() handling. We will imitate the default socket behavior here for client
  // socket with optname SO_ERROR, which is only used to check connect error.
  if (io_uring_socket_type_ == IoUringSocketType::Client && optname == SO_ERROR &&
      io_uring_socket_.has_value()) {
    const OptRef<Io::WriteParam>& write_param = io_uring_socket_->getWriteParam();
    // The write param is optional (checkWriteResult() treats it the same way); only use it when
    // a completion is actually available.
    if (write_param != absl::nullopt) {
      int* intval = static_cast<int*>(optval);
      *intval = -write_param->result_;
      *optlen = sizeof(int);
      return {0, 0};
    }
  }

  return IoSocketHandleBaseImpl::getOption(level, optname, optval, optlen);
}
222

            
223
1
// Duplicates the fd through the OS syscall wrapper and wraps the new fd in a platform specific
// socket built with the same v6only flag, domain and io_uring worker factory. A failed
// duplication trips a RELEASE_ASSERT.
IoHandlePtr IoUringSocketHandleImpl::duplicate() {
  ENVOY_LOG(trace, "duplicate, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  Api::SysCallSocketResult dup_result = Api::OsSysCallsSingleton::get().duplicate(fd_);
  RELEASE_ASSERT(dup_result.return_value_ != -1,
                 fmt::format("duplicate failed for '{}': ({}) {}", fd_, dup_result.errno_,
                             errorDetails(dup_result.errno_)));

  return SocketInterfaceImpl::makePlatformSpecificSocket(
      dup_result.return_value_, socket_v6only_, domain_, Network::SocketCreationOptions{},
      &io_uring_worker_factory_);
}
234

            
235
// Registers `cb` as the event callback of this handle.
//
// - If an io_uring socket already exists and belongs to the current thread's worker, only its
//   callback is replaced and read/close events are re-armed.
// - If it belongs to another worker, the old socket is closed on its own dispatcher, the data
//   still sitting in its read buffer is collected, and a new server socket seeded with that data
//   is added to the current thread's worker. The calling thread blocks until the old socket has
//   handed over its buffer.
// - Otherwise the socket is created according to the handle type: Accept handles use a regular
//   dispatcher file event, Server handles add a server socket, and Unknown/Client handles become
//   Client and add a client socket.
//
// `dispatcher` and `trigger` are only used for Accept handles. Of `events`, only the Closed bit
// is forwarded to the io_uring socket.
void IoUringSocketHandleImpl::initializeFileEvent(Event::Dispatcher& dispatcher,
                                                  Event::FileReadyCb cb,
                                                  Event::FileTriggerType trigger, uint32_t events) {
  ENVOY_LOG(trace, "initialize file event, fd = {}, type = {}, has socket = {}", fd_,
            ioUringSocketTypeStr(), io_uring_socket_.has_value());

  // The IoUringSocket has already been created. It usually happened after a resetFileEvents.
  if (io_uring_socket_.has_value()) {
    if (&io_uring_socket_->getIoUringWorker().dispatcher() ==
        &io_uring_worker_factory_.getIoUringWorker()->dispatcher()) {
      io_uring_socket_->setFileReadyCb(std::move(cb));
      io_uring_socket_->enableRead();
      io_uring_socket_->enableCloseEvent(events & Event::FileReadyType::Closed);
    } else {
      ENVOY_LOG(trace, "initialize file event from another thread, fd = {}, type = {}", fd_,
                ioUringSocketTypeStr());
      Thread::CondVar wait_cv;
      Thread::MutexBasicLockable mutex;
      Buffer::OwnedImpl buf;
      // Set (under `mutex`) by the close callback once the read buffer has been handed over.
      // Waiting on this flag rather than on the bare condition variable keeps a spurious wakeup
      // from letting this thread continue while the original socket is still being closed.
      bool buffer_handed_over = false;
      os_fd_t fd = io_uring_socket_->fd();

      {
        Thread::LockGuard lock(mutex);
        // Close the original socket in its running thread.
        io_uring_socket_->getIoUringWorker().dispatcher().post(
            [&origin_socket = io_uring_socket_, &wait_cv, &mutex, &buf, &buffer_handed_over]() {
              // Move the data of original socket's read buffer to the temporary buf.
              origin_socket->close(
                  true, [&wait_cv, &mutex, &buf, &buffer_handed_over](Buffer::Instance& buffer) {
                    Thread::LockGuard lock(mutex);
                    buf.move(buffer);
                    buffer_handed_over = true;
                    wait_cv.notifyOne();
                  });
            });
        while (!buffer_handed_over) {
          wait_cv.wait(mutex);
        }
      }

      // Move the temporary buf to the newly created one.
      io_uring_socket_ = io_uring_worker_factory_.getIoUringWorker()->addServerSocket(
          fd, buf, std::move(cb), events & Event::FileReadyType::Closed);
    }
    return;
  }

  switch (io_uring_socket_type_) {
  case IoUringSocketType::Accept:
    // `cb` is not used again after this point, so hand it over instead of copying it.
    file_event_ = dispatcher.createFileEvent(fd_, std::move(cb), trigger, events);
    break;
  case IoUringSocketType::Server:
    io_uring_socket_ = io_uring_worker_factory_.getIoUringWorker()->addServerSocket(
        fd_, std::move(cb), events & Event::FileReadyType::Closed);
    break;
  case IoUringSocketType::Unknown:
  case IoUringSocketType::Client:
    // A handle that was neither accepted nor put into listen mode is treated as a client socket.
    io_uring_socket_type_ = IoUringSocketType::Client;
    io_uring_socket_ = io_uring_worker_factory_.getIoUringWorker()->addClientSocket(
        fd_, std::move(cb), events & Event::FileReadyType::Closed);
    break;
  }
}
294

            
295
4
void IoUringSocketHandleImpl::activateFileEvents(uint32_t events) {
296
4
  ENVOY_LOG(trace, "activate file events {}, fd = {}, type = {}", events, fd_,
297
4
            ioUringSocketTypeStr());
298

            
299
4
  if (io_uring_socket_type_ == IoUringSocketType::Accept) {
300
1
    ASSERT(file_event_ != nullptr);
301
1
    file_event_->activate(events);
302
1
    return;
303
1
  }
304

            
305
3
  if (events & Event::FileReadyType::Read) {
306
1
    io_uring_socket_->injectCompletion(Io::Request::RequestType::Read);
307
1
  }
308
3
  if (events & Event::FileReadyType::Write) {
309
2
    io_uring_socket_->injectCompletion(Io::Request::RequestType::Write);
310
2
  }
311
3
}
312

            
313
3
void IoUringSocketHandleImpl::enableFileEvents(uint32_t events) {
314
3
  ENVOY_LOG(trace, "enable file events {}, fd = {}, type = {}", events, fd_,
315
3
            ioUringSocketTypeStr());
316

            
317
3
  if (io_uring_socket_type_ == IoUringSocketType::Accept) {
318
1
    ASSERT(file_event_ != nullptr);
319
1
    file_event_->setEnabled(events);
320
1
    return;
321
1
  }
322

            
323
2
  if (events & Event::FileReadyType::Read) {
324
1
    io_uring_socket_->enableRead();
325
1
  } else {
326
1
    io_uring_socket_->disableRead();
327
1
  }
328
2
  io_uring_socket_->enableCloseEvent(events & Event::FileReadyType::Closed);
329
2
}
330

            
331
2
void IoUringSocketHandleImpl::resetFileEvents() {
332
2
  ENVOY_LOG(trace, "reset file events, fd = {}, type = {}", fd_, ioUringSocketTypeStr());
333

            
334
2
  if (io_uring_socket_type_ == IoUringSocketType::Accept) {
335
1
    file_event_.reset();
336
1
    return;
337
1
  }
338

            
339
1
  io_uring_socket_->disableRead();
340
1
  io_uring_socket_->enableCloseEvent(false);
341
1
}
342

            
343
1
// Forwards the shutdown request to the io_uring socket of a Server or Client handle. The call
// always reports success ({0, 0}).
Api::SysCallIntResult IoUringSocketHandleImpl::shutdown(int how) {
  ENVOY_LOG(trace, "shutdown, fd = {}, type = {}", fd_, ioUringSocketTypeStr());

  ASSERT(io_uring_socket_type_ == IoUringSocketType::Server ||
         io_uring_socket_type_ == IoUringSocketType::Client);
  io_uring_socket_->shutdown(how);

  const Api::SysCallIntResult success{0, 0};
  return success;
}
352

            
353
29
// Inspects the state of the last read before any data is handed to the caller.
//
// Returns a ready-made result when the read cannot proceed:
//   - EAGAIN if no read has completed yet, the completed read reported -EAGAIN, or the buffered
//     data has already been consumed;
//   - a zero-length "no error" result on remote close;
//   - the socket error for any other negative read result.
// Returns absl::nullopt when there is buffered data the caller may consume.
absl::optional<Api::IoCallUint64Result> IoUringSocketHandleImpl::checkReadResult() const {
  ASSERT(io_uring_socket_.has_value());
  ASSERT(io_uring_socket_type_ == IoUringSocketType::Server ||
         io_uring_socket_type_ == IoUringSocketType::Client);

  const OptRef<Io::ReadParam>& read_param = io_uring_socket_->getReadParam();
  // An absl::nullopt read param means that there is no io_uring request which has been done.
  if (read_param == absl::nullopt) {
    if (io_uring_socket_->getStatus() == Io::IoUringSocketStatus::RemoteClosed) {
      ENVOY_LOG(trace, "read, fd = {}, type = {}, remote close", fd_, ioUringSocketTypeStr());
      return Api::ioCallUint64ResultNoError();
    }
    return Api::IoCallUint64Result{0, IoSocketError::getIoSocketEagainError()};
  }

  const auto read_outcome = read_param->result_;

  if (read_outcome == 0) {
    ENVOY_LOG(trace, "read remote close, fd = {}, type = {}", fd_, ioUringSocketTypeStr());
    return Api::ioCallUint64ResultNoError();
  }

  if (read_outcome < 0) {
    ASSERT(read_param->buf_.length() == 0);
    ENVOY_LOG(trace, "read error = {}, fd = {}, type = {}", -read_outcome, fd_,
              ioUringSocketTypeStr());
    if (read_outcome != -EAGAIN) {
      return Api::IoCallUint64Result{0, IoSocketError::create(-read_outcome)};
    }
    return Api::IoCallUint64Result{0, IoSocketError::getIoSocketEagainError()};
  }

  // The buffer has been read in the previous call, return EAGAIN to tell the caller to wait for
  // the next read event.
  if (read_param->buf_.length() == 0) {
    return Api::IoCallUint64Result{0, IoSocketError::getIoSocketEagainError()};
  }
  return absl::nullopt;
}
391

            
392
5
// Inspects the state of the last write. Returns the socket error when a completed write failed
// with anything other than -EAGAIN; returns absl::nullopt when a new write may be submitted.
absl::optional<Api::IoCallUint64Result> IoUringSocketHandleImpl::checkWriteResult() const {
  ASSERT(io_uring_socket_.has_value());
  ASSERT(io_uring_socket_type_ == IoUringSocketType::Server ||
         io_uring_socket_type_ == IoUringSocketType::Client);

  const OptRef<Io::WriteParam>& write_param = io_uring_socket_->getWriteParam();
  // EAGAIN indicates an injected write event to trigger IO handle write. Submit the new write to
  // the io_uring.
  const bool write_failed = write_param != absl::nullopt && write_param->result_ < 0 &&
                            write_param->result_ != -EAGAIN;
  if (!write_failed) {
    return absl::nullopt;
  }
  return Api::IoCallUint64Result{0, IoSocketError::create(-write_param->result_)};
}
407

            
408
// Copies buffered read data into `slices` without removing it from the read buffer. The amount
// is bounded by `max_length` and by the result of the last completed read. A result produced by
// checkReadResult() (EAGAIN, error or remote close) is returned as is.
Api::IoCallUint64Result IoUringSocketHandleImpl::copyOut(uint64_t max_length,
                                                         Buffer::RawSlice* slices,
                                                         uint64_t num_slice) {
  absl::optional<Api::IoCallUint64Result> early_result = checkReadResult();
  if (early_result.has_value()) {
    return std::move(*early_result);
  }

  const OptRef<Io::ReadParam>& pending = io_uring_socket_->getReadParam();
  const uint64_t last_read_size = static_cast<uint64_t>(pending->result_);
  const uint64_t copy_limit = std::min(max_length, last_read_size);
  uint64_t bytes_copied = pending->buf_.copyOutToSlices(copy_limit, slices, num_slice);
  return {bytes_copied, IoSocketError::none()};
}
421

            
422
} // namespace Network
423
} // namespace Envoy