Coverage Report

Created: 2026-05-30 06:21

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/thrift/lib/rs/src/transport/socket.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements. See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership. The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License. You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied. See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::convert::From;
19
use std::io;
20
use std::io::{ErrorKind, Read, Write};
21
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
22
use std::time::Duration;
23
24
#[cfg(unix)]
25
use std::os::unix::net::UnixStream;
26
27
use super::{ReadHalf, TIoChannel, WriteHalf};
28
use crate::{new_transport_error, TransportErrorKind};
29
30
/// Bidirectional TCP/IP channel.
///
/// Read and write timeouts may be configured before the channel is opened;
/// they are stored and applied to the socket when `open` connects it.
///
/// # Examples
///
/// Create a `TTcpChannel`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use thrift::transport::TTcpChannel;
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut buf = vec![0u8; 4];
/// c.read_exact(&mut buf).unwrap();
/// c.write_all(&[0, 1, 2]).unwrap();
/// ```
///
/// Configure a read timeout before connecting.
///
/// ```no_run
/// use std::time::Duration;
/// use thrift::transport::TTcpChannel;
///
/// let mut c = TTcpChannel::new();
/// c.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
/// c.open("localhost:9090").unwrap();
/// ```
///
/// Create a `TTcpChannel` by wrapping an existing `TcpStream`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
/// use thrift::transport::TTcpChannel;
///
/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
/// stream.set_nodelay(true).unwrap();
///
/// // no need to call c.open() since we've already connected above
/// let mut c = TTcpChannel::with_stream(stream);
///
/// let mut buf = vec![0u8; 4];
/// c.read_exact(&mut buf).unwrap();
/// c.write_all(&[0, 1, 2]).unwrap();
/// ```
#[derive(Debug, Default)]
pub struct TTcpChannel {
    /// Connected socket; `None` until `open` succeeds (or when built via `new`).
    stream: Option<TcpStream>,
    /// Read timeout applied to the socket on `open` (`None` = reads block indefinitely).
    read_timeout: Option<Duration>,
    /// Write timeout applied to the socket on `open` (`None` = writes block indefinitely).
    write_timeout: Option<Duration>,
}
71
72
impl TTcpChannel {
73
    /// Create an uninitialized `TTcpChannel`.
74
    ///
75
    /// The returned instance must be opened using `TTcpChannel::open(...)`
76
    /// before it can be used.
77
    pub fn new() -> TTcpChannel {
78
        TTcpChannel {
79
            stream: None,
80
            read_timeout: None,
81
            write_timeout: None,
82
        }
83
    }
84
85
    /// Create a `TTcpChannel` that wraps an existing `TcpStream`.
86
    ///
87
    /// The passed-in stream is assumed to have been opened before being wrapped
88
    /// by the created `TTcpChannel` instance.
89
    pub fn with_stream(stream: TcpStream) -> TTcpChannel {
90
        let read_timeout = stream.read_timeout().unwrap_or_default();
91
        let write_timeout = stream.write_timeout().unwrap_or_default();
92
93
        TTcpChannel {
94
            stream: Some(stream),
95
            read_timeout,
96
            write_timeout,
97
        }
98
    }
99
100
    /// Return the read timeout for this channel.
101
    pub fn read_timeout(&self) -> crate::Result<Option<Duration>> {
102
        if let Some(ref stream) = self.stream {
103
            stream.read_timeout().map_err(From::from)
104
        } else {
105
            Ok(self.read_timeout)
106
        }
107
    }
108
109
    /// Return the write timeout for this channel.
110
    pub fn write_timeout(&self) -> crate::Result<Option<Duration>> {
111
        if let Some(ref stream) = self.stream {
112
            stream.write_timeout().map_err(From::from)
113
        } else {
114
            Ok(self.write_timeout)
115
        }
116
    }
117
118
    /// Set the read timeout for this channel.
119
    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> {
120
        if let Some(ref stream) = self.stream {
121
            stream.set_read_timeout(timeout)?;
122
        }
123
124
        self.read_timeout = timeout;
125
        Ok(())
126
    }
127
128
    /// Set the write timeout for this channel.
129
    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> {
130
        if let Some(ref stream) = self.stream {
131
            stream.set_write_timeout(timeout)?;
132
        }
133
134
        self.write_timeout = timeout;
135
        Ok(())
136
    }
137
138
    /// Set the read and write timeouts for this channel.
139
    pub fn set_timeouts(
140
        &mut self,
141
        read_timeout: Option<Duration>,
142
        write_timeout: Option<Duration>,
143
    ) -> crate::Result<()> {
144
        self.set_read_timeout(read_timeout)?;
145
        self.set_write_timeout(write_timeout)?;
146
        Ok(())
147
    }
148
149
    /// Connect to `remote_address`, which should implement `ToSocketAddrs` trait.
150
    pub fn open<A: ToSocketAddrs>(&mut self, remote_address: A) -> crate::Result<()> {
151
        if self.stream.is_some() {
152
            Err(new_transport_error(
153
                TransportErrorKind::AlreadyOpen,
154
                "tcp connection previously opened",
155
            ))
156
        } else {
157
            match TcpStream::connect(&remote_address) {
158
                Ok(s) => {
159
                    s.set_nodelay(true)?;
160
                    s.set_read_timeout(self.read_timeout)?;
161
                    s.set_write_timeout(self.write_timeout)?;
162
                    self.stream = Some(s);
163
                    Ok(())
164
                }
165
                Err(e) => Err(From::from(e)),
166
            }
167
        }
168
    }
169
170
    /// Shut down this channel.
171
    ///
172
    /// Both send and receive halves are closed, and this instance can no
173
    /// longer be used to communicate with another endpoint.
174
    pub fn close(&mut self) -> crate::Result<()> {
175
0
        self.if_set(|s| s.shutdown(Shutdown::Both))
176
            .map_err(From::from)
177
    }
178
179
0
    fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T>
180
0
    where
181
0
        F: FnMut(&mut TcpStream) -> io::Result<T>,
182
    {
183
0
        if let Some(ref mut s) = self.stream {
184
0
            stream_operation(s)
185
        } else {
186
0
            Err(io::Error::new(
187
0
                ErrorKind::NotConnected,
188
0
                "tcp endpoint not connected",
189
0
            ))
190
        }
191
0
    }
Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel>::close::{closure#0}, ()>
Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Read>::read::{closure#0}, usize>
Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Write>::flush::{closure#0}, ()>
Unexecuted instantiation: <thrift::transport::socket::TTcpChannel>::if_set::<<thrift::transport::socket::TTcpChannel as std::io::Write>::write::{closure#0}, usize>
192
}
193
194
impl TIoChannel for TTcpChannel {
195
    fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
196
    where
197
        Self: Sized,
198
    {
199
        let mut s = self;
200
201
        s.stream
202
            .as_mut()
203
0
            .and_then(|s| s.try_clone().ok())
204
0
            .map(|cloned| {
205
                // Read from the socket so both halves start with consistent values.
206
0
                let read_timeout = s
207
0
                    .stream
208
0
                    .as_ref()
209
0
                    .and_then(|st| st.read_timeout().ok())
210
0
                    .unwrap_or(s.read_timeout);
211
0
                let write_timeout = s
212
0
                    .stream
213
0
                    .as_ref()
214
0
                    .and_then(|st| st.write_timeout().ok())
215
0
                    .unwrap_or(s.write_timeout);
216
217
0
                let read_half = ReadHalf::new(TTcpChannel {
218
0
                    stream: s.stream.take(),
219
0
                    read_timeout,
220
0
                    write_timeout,
221
0
                });
222
0
                let write_half = WriteHalf::new(TTcpChannel {
223
0
                    stream: Some(cloned),
224
0
                    read_timeout,
225
0
                    write_timeout,
226
0
                });
227
0
                (read_half, write_half)
228
0
            })
229
0
            .ok_or_else(|| {
230
0
                new_transport_error(
231
0
                    TransportErrorKind::Unknown,
232
                    "cannot clone underlying tcp stream",
233
                )
234
0
            })
235
    }
236
}
237
238
impl Read for TTcpChannel {
239
    fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
240
0
        self.if_set(|s| s.read(b))
241
    }
242
}
243
244
impl Write for TTcpChannel {
245
    fn write(&mut self, b: &[u8]) -> io::Result<usize> {
246
0
        self.if_set(|s| s.write(b))
247
    }
248
249
    fn flush(&mut self) -> io::Result<()> {
250
0
        self.if_set(|s| s.flush())
251
    }
252
}
253
254
#[cfg(unix)]
255
impl TIoChannel for UnixStream {
256
    fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
257
    where
258
        Self: Sized,
259
    {
260
        let socket_rx = self.try_clone().unwrap();
261
262
        Ok((ReadHalf::new(self), WriteHalf::new(socket_rx)))
263
    }
264
}
265
266
#[cfg(test)]
mod tests {
    use std::net::{SocketAddr, TcpListener, TcpStream};
    use std::thread;
    use std::time::Duration;

    use super::*;

    fn listening_address() -> (TcpListener, SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let address = listener.local_addr().unwrap();
        (listener, address)
    }

    fn connected_streams() -> (TcpStream, TcpStream) {
        let (listener, address) = listening_address();
        let accept_handle = thread::spawn(move || listener.accept().unwrap().0);
        let client = TcpStream::connect(address).unwrap();
        let server = accept_handle.join().unwrap();
        (client, server)
    }

    fn wrapped_channel() -> (TTcpChannel, TcpStream) {
        let (client, server) = connected_streams();
        (TTcpChannel::with_stream(client), server)
    }

    fn assert_channel_timeouts(
        channel: &TTcpChannel,
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
    ) {
        assert_eq!(channel.read_timeout().unwrap(), read_timeout);
        assert_eq!(channel.write_timeout().unwrap(), write_timeout);
    }

    fn assert_stream_timeouts(
        channel: &TTcpChannel,
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
    ) {
        let stream = channel.stream.as_ref().unwrap();
        assert_eq!(stream.read_timeout().unwrap(), read_timeout);
        assert_eq!(stream.write_timeout().unwrap(), write_timeout);
    }

    #[test]
    fn must_store_read_timeout_before_open() {
        let timeout = Some(Duration::from_millis(80));
        let mut channel = TTcpChannel::new();

        channel.set_read_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, timeout, None);
    }

    #[test]
    fn must_apply_timeouts_when_opening_channel() {
        let timeout = Some(Duration::from_millis(80));
        let (listener, address) = listening_address();
        let accept_handle = thread::spawn(move || listener.accept().unwrap().0);
        let mut channel = TTcpChannel::new();

        channel.set_timeouts(timeout, timeout).unwrap();
        channel.open(address).unwrap();

        assert_channel_timeouts(&channel, timeout, timeout);
        assert_stream_timeouts(&channel, timeout, timeout);
        let _server = accept_handle.join().unwrap();
    }

    #[test]
    fn must_enforce_read_timeout_set_before_open() {
        let timeout = Duration::from_millis(80);
        let (listener, address) = listening_address();
        let server_handle = thread::spawn(move || {
            let stream = listener.accept().unwrap().0;
            thread::sleep(Duration::from_millis(200));
            drop(stream);
        });
        let mut channel = TTcpChannel::new();
        let mut buf = [0; 1];

        channel.set_read_timeout(Some(timeout)).unwrap();
        channel.open(address).unwrap();

        let err = channel.read(&mut buf).unwrap_err();

        assert!(matches!(
            err.kind(),
            ErrorKind::TimedOut | ErrorKind::WouldBlock
        ));
        server_handle.join().unwrap();
    }

    #[test]
    fn must_set_read_timeout_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_read_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, timeout, None);
        assert_stream_timeouts(&channel, timeout, None);
    }

    #[test]
    fn must_set_write_timeout_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_write_timeout(timeout).unwrap();

        assert_channel_timeouts(&channel, None, timeout);
        assert_stream_timeouts(&channel, None, timeout);
    }

    #[test]
    fn must_set_both_timeouts_on_wrapped_stream() {
        let read_timeout = Some(Duration::from_millis(80));
        let write_timeout = Some(Duration::from_millis(120));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(read_timeout, write_timeout).unwrap();

        assert_channel_timeouts(&channel, read_timeout, write_timeout);
        assert_stream_timeouts(&channel, read_timeout, write_timeout);
    }

    #[test]
    fn must_clear_timeouts_on_wrapped_stream() {
        let timeout = Some(Duration::from_millis(80));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(timeout, timeout).unwrap();
        channel.set_timeouts(None, None).unwrap();

        assert_channel_timeouts(&channel, None, None);
        assert_stream_timeouts(&channel, None, None);
    }

    #[test]
    fn must_store_timeouts_from_wrapped_stream() {
        let read_timeout = Some(Duration::from_millis(80));
        let write_timeout = Some(Duration::from_millis(120));
        let (client, _server) = connected_streams();

        client.set_read_timeout(read_timeout).unwrap();
        client.set_write_timeout(write_timeout).unwrap();
        let channel = TTcpChannel::with_stream(client);

        assert_channel_timeouts(&channel, read_timeout, write_timeout);
        assert_stream_timeouts(&channel, read_timeout, write_timeout);
    }

    /// Regression: after split() one half must not clobber the other's timeout.
    #[test]
    fn split_halves_must_not_clobber_each_others_timeout() {
        let initial = Some(Duration::from_millis(80));
        let updated = Some(Duration::from_millis(250));
        let updated_write = Some(Duration::from_millis(500));
        let (mut channel, _server) = wrapped_channel();

        channel.set_timeouts(initial, None).unwrap();
        let (mut read_half, mut write_half) = channel.split().unwrap();

        read_half.set_read_timeout(updated).unwrap();

        // Both halves share the same socket, so write_half must see the new value.
        let seen = write_half.read_timeout().unwrap();
        assert_eq!(seen, updated);

        // set_timeouts must not write the stale pre-split value back.
        write_half.set_timeouts(seen, updated_write).unwrap();

        assert_eq!(read_half.read_timeout().unwrap(), updated);
        assert_eq!(write_half.write_timeout().unwrap(), updated_write);
    }

    #[test]
    fn must_fail_io_when_not_open() {
        let mut channel = TTcpChannel::new();
        let mut buf = [0u8; 1];

        assert_eq!(
            channel.read(&mut buf).unwrap_err().kind(),
            ErrorKind::NotConnected
        );
        assert_eq!(
            channel.write(b"x").unwrap_err().kind(),
            ErrorKind::NotConnected
        );
        assert_eq!(channel.flush().unwrap_err().kind(), ErrorKind::NotConnected);
    }

    #[test]
    fn must_fail_close_when_not_open() {
        let mut channel = TTcpChannel::new();

        assert!(channel.close().is_err());
    }

    #[test]
    fn must_reject_open_when_already_open() {
        let (_listener, address) = listening_address();
        let (mut channel, _server) = wrapped_channel();

        assert!(channel.open(address).is_err());
    }

    #[test]
    fn must_exchange_bytes_over_open_channel() {
        let (mut channel, mut server) = wrapped_channel();
        let mut buf = [0u8; 4];

        channel.write_all(b"ping").unwrap();
        channel.flush().unwrap();
        server.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        server.write_all(b"pong").unwrap();
        channel.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }

    #[test]
    fn must_signal_eof_to_peer_after_close() {
        let (mut channel, mut server) = wrapped_channel();
        let mut buf = [0u8; 1];
        // Guard against hanging forever if the shutdown is not delivered.
        server
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();

        channel.close().unwrap();

        assert_eq!(server.read(&mut buf).unwrap(), 0);
    }

    #[test]
    fn split_must_fail_when_not_open() {
        assert!(TTcpChannel::new().split().is_err());
    }

    #[test]
    fn split_halves_must_share_connection() {
        let (channel, mut server) = wrapped_channel();
        let (mut read_half, mut write_half) = channel.split().unwrap();
        let mut buf = [0u8; 4];

        write_half.write_all(b"ping").unwrap();
        server.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        server.write_all(b"pong").unwrap();
        read_half.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }

    #[cfg(unix)]
    #[test]
    fn unix_stream_split_halves_must_share_connection() {
        let (local, mut peer) = UnixStream::pair().unwrap();
        let (mut read_half, mut write_half) = local.split().unwrap();
        let mut buf = [0u8; 4];

        write_half.write_all(b"ping").unwrap();
        peer.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        peer.write_all(b"pong").unwrap();
        read_half.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }
}