/src/thrift/lib/rs/src/transport/socket.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::convert::From; |
19 | | use std::io; |
20 | | use std::io::{ErrorKind, Read, Write}; |
21 | | use std::net::{Shutdown, TcpStream, ToSocketAddrs}; |
22 | | use std::time::Duration; |
23 | | |
24 | | #[cfg(unix)] |
25 | | use std::os::unix::net::UnixStream; |
26 | | |
27 | | use super::{ReadHalf, TIoChannel, WriteHalf}; |
28 | | use crate::{new_transport_error, TransportErrorKind}; |
29 | | |
/// A bidirectional channel over a TCP/IP connection.
///
/// The channel either owns a connected `TcpStream` or is still unopened. Read and
/// write timeouts can be configured in either state; while unopened they are
/// remembered and applied to the socket once it is opened.
///
/// # Examples
///
/// Open a fresh `TTcpChannel` against a remote address.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let mut buf = vec![0u8; 4];
/// channel.read(&mut buf).unwrap();
/// channel.write(&[0, 1, 2]).unwrap();
/// ```
///
/// Adopt a `TcpStream` that is already connected.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
/// use thrift::transport::TTcpChannel;
///
/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
/// stream.set_nodelay(true).unwrap();
///
/// // The stream is already connected, so `open()` must not be called.
/// let mut channel = TTcpChannel::with_stream(stream);
///
/// let mut buf = vec![0u8; 4];
/// channel.read(&mut buf).unwrap();
/// channel.write(&[0, 1, 2]).unwrap();
/// ```
#[derive(Debug, Default)]
pub struct TTcpChannel {
    // `None` until the channel is opened or constructed around an existing stream.
    stream: Option<TcpStream>,
    // Cached read timeout; reported while unopened and applied by `open`.
    read_timeout: Option<Duration>,
    // Cached write timeout; reported while unopened and applied by `open`.
    write_timeout: Option<Duration>,
}
71 | | |
72 | | impl TTcpChannel { |
73 | | /// Create an uninitialized `TTcpChannel`. |
74 | | /// |
75 | | /// The returned instance must be opened using `TTcpChannel::open(...)` |
76 | | /// before it can be used. |
77 | | pub fn new() -> TTcpChannel { |
78 | | TTcpChannel { |
79 | | stream: None, |
80 | | read_timeout: None, |
81 | | write_timeout: None, |
82 | | } |
83 | | } |
84 | | |
85 | | /// Create a `TTcpChannel` that wraps an existing `TcpStream`. |
86 | | /// |
87 | | /// The passed-in stream is assumed to have been opened before being wrapped |
88 | | /// by the created `TTcpChannel` instance. |
89 | | pub fn with_stream(stream: TcpStream) -> TTcpChannel { |
90 | | let read_timeout = stream.read_timeout().unwrap_or_default(); |
91 | | let write_timeout = stream.write_timeout().unwrap_or_default(); |
92 | | |
93 | | TTcpChannel { |
94 | | stream: Some(stream), |
95 | | read_timeout, |
96 | | write_timeout, |
97 | | } |
98 | | } |
99 | | |
100 | | /// Return the read timeout for this channel. |
101 | | pub fn read_timeout(&self) -> crate::Result<Option<Duration>> { |
102 | | if let Some(ref stream) = self.stream { |
103 | | stream.read_timeout().map_err(From::from) |
104 | | } else { |
105 | | Ok(self.read_timeout) |
106 | | } |
107 | | } |
108 | | |
109 | | /// Return the write timeout for this channel. |
110 | | pub fn write_timeout(&self) -> crate::Result<Option<Duration>> { |
111 | | if let Some(ref stream) = self.stream { |
112 | | stream.write_timeout().map_err(From::from) |
113 | | } else { |
114 | | Ok(self.write_timeout) |
115 | | } |
116 | | } |
117 | | |
118 | | /// Set the read timeout for this channel. |
119 | | pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> { |
120 | | if let Some(ref stream) = self.stream { |
121 | | stream.set_read_timeout(timeout)?; |
122 | | } |
123 | | |
124 | | self.read_timeout = timeout; |
125 | | Ok(()) |
126 | | } |
127 | | |
128 | | /// Set the write timeout for this channel. |
129 | | pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> { |
130 | | if let Some(ref stream) = self.stream { |
131 | | stream.set_write_timeout(timeout)?; |
132 | | } |
133 | | |
134 | | self.write_timeout = timeout; |
135 | | Ok(()) |
136 | | } |
137 | | |
138 | | /// Set the read and write timeouts for this channel. |
139 | | pub fn set_timeouts( |
140 | | &mut self, |
141 | | read_timeout: Option<Duration>, |
142 | | write_timeout: Option<Duration>, |
143 | | ) -> crate::Result<()> { |
144 | | self.set_read_timeout(read_timeout)?; |
145 | | self.set_write_timeout(write_timeout)?; |
146 | | Ok(()) |
147 | | } |
148 | | |
149 | | /// Connect to `remote_address`, which should implement `ToSocketAddrs` trait. |
150 | | pub fn open<A: ToSocketAddrs>(&mut self, remote_address: A) -> crate::Result<()> { |
151 | | if self.stream.is_some() { |
152 | | Err(new_transport_error( |
153 | | TransportErrorKind::AlreadyOpen, |
154 | | "tcp connection previously opened", |
155 | | )) |
156 | | } else { |
157 | | match TcpStream::connect(&remote_address) { |
158 | | Ok(s) => { |
159 | | s.set_nodelay(true)?; |
160 | | s.set_read_timeout(self.read_timeout)?; |
161 | | s.set_write_timeout(self.write_timeout)?; |
162 | | self.stream = Some(s); |
163 | | Ok(()) |
164 | | } |
165 | | Err(e) => Err(From::from(e)), |
166 | | } |
167 | | } |
168 | | } |
169 | | |
170 | | /// Shut down this channel. |
171 | | /// |
172 | | /// Both send and receive halves are closed, and this instance can no |
173 | | /// longer be used to communicate with another endpoint. |
174 | | pub fn close(&mut self) -> crate::Result<()> { |
175 | 0 | self.if_set(|s| s.shutdown(Shutdown::Both)) |
176 | | .map_err(From::from) |
177 | | } |
178 | | |
179 | 0 | fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T> |
180 | 0 | where |
181 | 0 | F: FnMut(&mut TcpStream) -> io::Result<T>, |
182 | | { |
183 | 0 | if let Some(ref mut s) = self.stream { |
184 | 0 | stream_operation(s) |
185 | | } else { |
186 | 0 | Err(io::Error::new( |
187 | 0 | ErrorKind::NotConnected, |
188 | 0 | "tcp endpoint not connected", |
189 | 0 | )) |
190 | | } |
191 | 0 | } Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel>::close::{closure#0}, ()>Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Read>::read::{closure#0}, usize>Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Write>::flush::{closure#0}, ()>Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Write>::write::{closure#0}, usize> |
192 | | } |
193 | | |
194 | | impl TIoChannel for TTcpChannel { |
195 | | fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)> |
196 | | where |
197 | | Self: Sized, |
198 | | { |
199 | | let mut s = self; |
200 | | |
201 | | s.stream |
202 | | .as_mut() |
203 | 0 | .and_then(|s| s.try_clone().ok()) |
204 | 0 | .map(|cloned| { |
205 | | // Read from the socket so both halves start with consistent values. |
206 | 0 | let read_timeout = s |
207 | 0 | .stream |
208 | 0 | .as_ref() |
209 | 0 | .and_then(|st| st.read_timeout().ok()) |
210 | 0 | .unwrap_or(s.read_timeout); |
211 | 0 | let write_timeout = s |
212 | 0 | .stream |
213 | 0 | .as_ref() |
214 | 0 | .and_then(|st| st.write_timeout().ok()) |
215 | 0 | .unwrap_or(s.write_timeout); |
216 | | |
217 | 0 | let read_half = ReadHalf::new(TTcpChannel { |
218 | 0 | stream: s.stream.take(), |
219 | 0 | read_timeout, |
220 | 0 | write_timeout, |
221 | 0 | }); |
222 | 0 | let write_half = WriteHalf::new(TTcpChannel { |
223 | 0 | stream: Some(cloned), |
224 | 0 | read_timeout, |
225 | 0 | write_timeout, |
226 | 0 | }); |
227 | 0 | (read_half, write_half) |
228 | 0 | }) |
229 | 0 | .ok_or_else(|| { |
230 | 0 | new_transport_error( |
231 | 0 | TransportErrorKind::Unknown, |
232 | | "cannot clone underlying tcp stream", |
233 | | ) |
234 | 0 | }) |
235 | | } |
236 | | } |
237 | | |
238 | | impl Read for TTcpChannel { |
239 | | fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { |
240 | 0 | self.if_set(|s| s.read(b)) |
241 | | } |
242 | | } |
243 | | |
244 | | impl Write for TTcpChannel { |
245 | | fn write(&mut self, b: &[u8]) -> io::Result<usize> { |
246 | 0 | self.if_set(|s| s.write(b)) |
247 | | } |
248 | | |
249 | | fn flush(&mut self) -> io::Result<()> { |
250 | 0 | self.if_set(|s| s.flush()) |
251 | | } |
252 | | } |
253 | | |
254 | | #[cfg(unix)] |
255 | | impl TIoChannel for UnixStream { |
256 | | fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)> |
257 | | where |
258 | | Self: Sized, |
259 | | { |
260 | | let socket_rx = self.try_clone().unwrap(); |
261 | | |
262 | | Ok((ReadHalf::new(self), WriteHalf::new(socket_rx))) |
263 | | } |
264 | | } |
265 | | |
// Unit tests for `TTcpChannel` timeout handling: cached values before `open`,
// propagation to the socket, wrapped streams, and behavior after `split`.
#[cfg(test)]
mod tests {
    use std::net::{SocketAddr, TcpListener, TcpStream};
    use std::thread;
    use std::time::Duration;

    use super::*;

    /// Bind a listener on an OS-assigned loopback port and return it with its address.
    fn listening_address() -> (TcpListener, SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let address = listener.local_addr().unwrap();
        (listener, address)
    }

    /// Return a connected `(client, server)` pair of loopback streams.
    fn connected_streams() -> (TcpStream, TcpStream) {
        let (listener, address) = listening_address();
        // `accept` blocks, so it runs on a helper thread while this thread connects.
        let accept_handle = thread::spawn(move || listener.accept().unwrap().0);
        let client = TcpStream::connect(address).unwrap();
        let server = accept_handle.join().unwrap();
        (client, server)
    }

    /// Return a channel wrapping the client end, plus the server end.
    ///
    /// The caller must keep the server end alive for the connection to stay open.
    fn wrapped_channel() -> (TTcpChannel, TcpStream) {
        let (client, server) = connected_streams();
        (TTcpChannel::with_stream(client), server)
    }

    /// Assert the timeouts reported by the channel's public getters.
    fn assert_channel_timeouts(
        channel: &TTcpChannel,
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
    ) {
        assert_eq!(channel.read_timeout().unwrap(), read_timeout);
        assert_eq!(channel.write_timeout().unwrap(), write_timeout);
    }

    /// Assert the timeouts actually set on the underlying socket.
    ///
    /// Panics if the channel has no stream.
    fn assert_stream_timeouts(
        channel: &TTcpChannel,
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
    ) {
        let stream = channel.stream.as_ref().unwrap();
        assert_eq!(stream.read_timeout().unwrap(), read_timeout);
        assert_eq!(stream.write_timeout().unwrap(), write_timeout);
    }

    // An unopened channel must remember a read timeout and report it back.
    #[test]
    fn must_store_read_timeout_before_open() {
        let timeout = Some(Duration::from_millis(80));
        let mut channel = TTcpChannel::new();

        channel.set_read_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, timeout, None);
    }

    // Timeouts configured before `open` must be applied to the new socket.
    #[test]
    fn must_apply_timeouts_when_opening_channel() {
        let timeout = Some(Duration::from_millis(80));
        let (listener, address) = listening_address();
        let accept_handle = thread::spawn(move || listener.accept().unwrap().0);
        let mut channel = TTcpChannel::new();

        channel.set_timeouts(timeout, timeout).unwrap();
        channel.open(address).unwrap();

        assert_channel_timeouts(&channel, timeout, timeout);
        assert_stream_timeouts(&channel, timeout, timeout);
        let _server = accept_handle.join().unwrap();
    }

    // A read timeout set before `open` must actually interrupt a blocking read.
    #[test]
    fn must_enforce_read_timeout_set_before_open() {
        let timeout = Duration::from_millis(80);
        let (listener, address) = listening_address();
        // The server accepts, then holds the connection idle for longer than the
        // client's read timeout before closing it.
        let server_handle = thread::spawn(move || {
            let stream = listener.accept().unwrap().0;
            thread::sleep(Duration::from_millis(200));
            drop(stream);
        });
        let mut channel = TTcpChannel::new();
        let mut buf = [0; 1];

        channel.set_read_timeout(Some(timeout)).unwrap();
        channel.open(address).unwrap();

        let err = channel.read(&mut buf).unwrap_err();

        // Platforms report an expired socket read timeout as either kind.
        assert!(matches!(
            err.kind(),
            ErrorKind::TimedOut | ErrorKind::WouldBlock
        ));
        server_handle.join().unwrap();
    }

    // Setting a read timeout on a wrapped stream must reach the socket.
    #[test]
    fn must_set_read_timeout_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_read_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, timeout, None);
        assert_stream_timeouts(&channel, timeout, None);
    }

    // Setting a write timeout on a wrapped stream must reach the socket.
    #[test]
    fn must_set_write_timeout_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_write_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, None, timeout);
        assert_stream_timeouts(&channel, None, timeout);
    }

    // `set_timeouts` must apply distinct read and write values.
    #[test]
    fn must_set_both_timeouts_on_wrapped_stream() {
        let read_timeout = Some(Duration::from_millis(80));
        let write_timeout = Some(Duration::from_millis(120));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(read_timeout, write_timeout).unwrap();

        assert_channel_timeouts(&channel, read_timeout, write_timeout);
        assert_stream_timeouts(&channel, read_timeout, write_timeout);
    }

    // Passing `None` must clear previously set timeouts.
    #[test]
    fn must_clear_timeouts_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(timeout, timeout).unwrap();
        channel.set_timeouts(None, None).unwrap();

        assert_channel_timeouts(&channel, None, None);
        assert_stream_timeouts(&channel, None, None);
    }

    // `with_stream` must pick up timeouts already configured on the stream.
    #[test]
    fn must_store_timeouts_from_wrapped_stream() {
        let read_timeout = Some(Duration::from_millis(80));
        let write_timeout = Some(Duration::from_millis(120));
        let (client, _server) = connected_streams();

        client.set_read_timeout(read_timeout).unwrap();
        client.set_write_timeout(write_timeout).unwrap();
        let channel = TTcpChannel::with_stream(client);

        assert_channel_timeouts(&channel, read_timeout, write_timeout);
        assert_stream_timeouts(&channel, read_timeout, write_timeout);
    }

    /// Regression: after split() one half must not clobber the other's timeout.
    #[test]
    fn split_halves_must_not_clobber_each_others_timeout() {
        let initial = Some(Duration::from_millis(80));
        let updated = Some(Duration::from_millis(250));
        let updated_write = Some(Duration::from_millis(500));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(initial, None).unwrap();
        let (mut read_half, mut write_half) = channel.split().unwrap();

        read_half.set_read_timeout(updated).unwrap();

        // Both halves share the same socket, so write_half must see the new value.
        let seen = write_half.read_timeout().unwrap();
        assert_eq!(seen, updated);

        // set_timeouts must not write the stale pre-split value back.
        write_half.set_timeouts(seen, updated_write).unwrap();

        assert_eq!(read_half.read_timeout().unwrap(), updated);
        assert_eq!(write_half.write_timeout().unwrap(), updated_write);
    }
}