Coverage Report

Created: 2026-05-16 06:08

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/crosvm/third_party/vmm_vhost/src/lib.rs
Line
Count
Source
1
// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
3
4
//! Virtio Vhost Backend Drivers
5
//!
6
//! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue
7
//! is a set of three different single-producer, single-consumer ring structures designed to store
8
//! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact
9
//! virtqueue layout named "Packed Virtqueue", which is more friendly to memory cache system and
10
//! hardware implemented virtio devices. The packed virtqueue uses read-write memory, that means
11
//! the memory will be both read and written by both host and guest. The new Packed Virtqueue is
12
//! preferred for performance.
13
//!
14
//! Vhost is a mechanism to improve performance of Virtio devices by delegating data plane operations
15
//! to dedicated IO service processes. Only the configuration, I/O submission notification, and I/O
16
//! completion interruption are piped through the hypervisor.
17
//! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
18
//! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
19
//! hypervisor process with an existing Virtio (PCI) driver.
20
//!
21
//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
22
//! communicate with userspace applications. Dedicated kernel worker threads are created to handle
23
//! IO requests from the guest.
24
//!
25
//! Later, the Vhost-user protocol was introduced to complement the ioctl interface used to control the
26
//! vhost implementation in the Linux kernel. It implements the control plane needed to establish
27
//! virtqueues sharing with a user space process on the same host. It uses communication over a
28
//! Unix domain socket to share file descriptors in the ancillary data of the message. The protocol
29
//! defines 2 sides of the communication, frontend and backend. Frontend is the application that
30
//! shares its virtqueues. Backend is the consumer of the virtqueues. Frontend and backend can be
31
//! either a client (i.e. connecting) or server (listening) in the socket communication.
32
33
use std::fs::File;
34
use std::io::Error as IOError;
35
use std::num::TryFromIntError;
36
37
use remain::sorted;
38
use thiserror::Error as ThisError;
39
40
mod backend;
41
pub use backend::*;
42
43
pub mod message;
44
pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45
46
pub mod connection;
47
48
mod sys;
49
pub use connection::Connection;
50
pub use message::BackendReq;
51
pub use message::FrontendReq;
52
#[cfg(unix)]
53
pub use sys::unix;
54
55
pub(crate) mod backend_client;
56
pub use backend_client::BackendClient;
57
mod frontend_server;
58
pub use self::frontend_server::Frontend;
59
mod backend_server;
60
mod frontend_client;
61
pub use self::backend_server::Backend;
62
pub use self::backend_server::BackendServer;
63
pub use self::frontend_client::FrontendClient;
64
pub use self::frontend_server::FrontendServer;
65
66
/// Errors for vhost-user operations.
///
/// The `#[sorted]` attribute comes from the `remain` crate imported at the top of this file;
/// NOTE(review): it is expected to reject variants that are not kept in sorted order, so add new
/// variants in alphabetical position — confirm against the `remain` documentation.
67
#[sorted]
68
#[derive(Debug, ThisError)]
69
pub enum Error {
70
    /// Failure from the backend side.
71
    #[error("backend internal error")]
72
    BackendInternalError,
73
    /// The client exited properly.
74
    #[error("client exited properly")]
75
    ClientExit,
76
    /// The client disconnected.
77
    /// If the connection is closed properly, use `ClientExit` instead.
78
    #[error("client closed the connection")]
79
    Disconnect,
80
    /// Failure to enter the suspended state; carries the underlying `anyhow::Error`.
    #[error("Failed to enter suspended state")]
81
    EnterSuspendedState(anyhow::Error),
82
    /// Failure from the frontend side.
83
    #[error("frontend Internal error")]
84
    FrontendInternalError,
85
    /// The fd array in question is too big or too small.
86
    #[error("wrong number of attached fds")]
87
    IncorrectFds,
88
    /// Invalid cast to int.
89
    #[error("invalid cast to int: {0}")]
90
    InvalidCastToInt(TryFromIntError),
91
    /// Invalid message format, flag or content.
92
    #[error("invalid message")]
93
    InvalidMessage,
94
    /// Unsupported operation because the required protocol feature hasn't been negotiated.
95
    #[error("invalid operation")]
96
    InvalidOperation,
97
    /// Invalid parameters.
98
    #[error("invalid parameters: {0}")]
99
    InvalidParam(&'static str),
100
    /// The message is too large.
101
    #[error("oversized message")]
102
    OversizedMsg,
103
    /// Only part of a message has been sent or received successfully.
104
    #[error("partial message")]
105
    PartialMessage,
106
    /// Provided recv buffer was too small, and data was dropped.
107
    #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
108
    RecvBufferTooSmall {
109
        /// The size of the buffer received.
110
        got: usize,
111
        /// The expected size of the buffer.
112
        want: usize,
113
    },
114
    /// Error from the request handler.
115
    #[error("handler failed to handle request: {0}")]
116
    ReqHandlerError(#[source] IOError),
117
    /// Failure to restore.
118
    #[error("Failed to restore")]
119
    RestoreError(anyhow::Error),
120
    /// Failure to snapshot.
121
    #[error("Failed to snapshot")]
122
    SnapshotError(anyhow::Error),
123
    /// Generic socket errors.
124
    #[error("socket error: {0}")]
125
    SocketError(std::io::Error),
126
    /// Should retry the socket operation again.
127
    #[error("temporary socket error: {0}")]
128
    SocketRetry(std::io::Error),
129
    /// Error from tx/rx on a Tube.
130
    #[error("failed to read/write on Tube: {0}")]
131
    TubeError(base::TubeError),
132
}
133
134
/// Result of vhost-user operations
135
pub type Result<T> = std::result::Result<T, Error>;
136
137
/// Result of request handler.
138
pub type HandlerResult<T> = std::result::Result<T, IOError>;
139
140
/// Describes one shared memory region of a device by its id and length.
#[derive(Copy, Clone)]
141
pub struct SharedMemoryRegion {
142
    /// The id of the shared memory region. A device may have multiple regions, but each
143
    /// must have a unique id. The meaning of a particular region is device-specific.
144
    pub id: u8,
145
    /// The length of the shared memory region.
    /// NOTE(review): presumably in bytes — confirm against the users of this struct.
    pub length: u64,
146
}
147
148
/// Utility function to convert a vector of files into a single file.
149
/// Returns `None` if the vector contains no files or more than one file.
///
/// The vector is taken by value, so when `None` is returned any files it held are dropped
/// together with it.
150
0
pub(crate) fn into_single_file(mut files: Vec<File>) -> Option<File> {
151
0
    // Guard: anything other than exactly one file is rejected.
    if files.len() != 1 {
152
0
        return None;
153
0
    }
154
0
    // Exactly one element is present here, so `swap_remove(0)` just moves it out of the vector.
    Some(files.swap_remove(0))
155
0
}
156
157
#[cfg(test)]
158
mod test_backend;
159
160
#[cfg(test)]
161
mod tests {
162
    use std::io::ErrorKind;
163
    use std::sync::Arc;
164
    use std::sync::Barrier;
165
    use std::thread;
166
167
    use base::AsRawDescriptor;
168
    use tempfile::tempfile;
169
170
    use super::*;
171
    use crate::message::*;
172
    use crate::test_backend::TestBackend;
173
    use crate::test_backend::VIRTIO_FEATURES;
174
    use crate::VhostUserMemoryRegionInfo;
175
    use crate::VringConfigData;
176
177
    /// Builds a `BackendClient` and a `BackendServer` that talk to each other over the two ends
    /// returned by `Connection::pair()`. `backend` is handed to the server.
    ///
    /// Panics (via `unwrap`) if the connection pair cannot be created.
    fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
178
    where
179
        S: Backend,
180
    {
181
        let (client_connection, server_connection) = Connection::pair().unwrap();
182
        let backend_client = BackendClient::new(client_connection);
183
        (
184
            backend_client,
185
            BackendServer::<S>::new(server_connection, backend),
186
        )
187
    }
188
189
    /// Utility function to process a header and a message together.
    ///
    /// Receives one header (plus any attached files) from `h` and immediately processes the
    /// corresponding message. An error from `recv_header()` is propagated with `?`; otherwise
    /// the result of `process_message()` is returned.
190
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
191
        // We assume that a header comes together with message body in tests so we don't wait before
192
        // calling `process_message()`.
193
        let (hdr, files) = h.recv_header()?;
194
        h.process_message(hdr, files)
195
    }
196
197
    // Calls `TestBackend::set_owner()` directly (no client/server connection involved) and
    // checks that the first call succeeds while a second call is rejected.
    #[test]
198
    fn create_test_backend() {
199
        let mut backend = TestBackend::new();
200
201
        backend.set_owner().unwrap();
202
        assert!(backend.set_owner().is_err());
203
    }
204
205
    // Sends `set_owner` twice through a client/server pair on a single thread: the first request
    // flips the backend's `owned` flag, the second one makes `handle_request` fail while the
    // flag stays set.
    #[test]
206
    fn test_set_owner() {
207
        let test_backend = TestBackend::new();
208
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);
209
210
        // Not owned before any request has been handled.
        assert!(!backend_server.as_ref().owned);
211
        backend_client.set_owner().unwrap();
212
        handle_request(&mut backend_server).unwrap();
213
        assert!(backend_server.as_ref().owned);
214
        // The client-side call still succeeds; the failure shows up on the server side.
        backend_client.set_owner().unwrap();
215
        assert!(handle_request(&mut backend_server).is_err());
216
        assert!(backend_server.as_ref().owned);
217
    }
218
219
    // Negotiates virtio features and vhost-user protocol features between a client on the test
    // thread and a server on a spawned thread, then checks what the backend recorded as acked.
    #[test]
220
    fn test_set_features() {
221
        // `mbar` (main thread) and `sbar` (server thread) are two handles to the same two-party
        // barrier; the test function only returns once both sides have reached `wait()`.
        let mbar = Arc::new(Barrier::new(2));
222
        let sbar = mbar.clone();
223
        let test_backend = TestBackend::new();
224
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);
225
226
        // Server side: one `handle_request` per request issued by the client below.
        thread::spawn(move || {
227
            // set_owner()
            handle_request(&mut backend_server).unwrap();
228
            assert!(backend_server.as_ref().owned);
229
230
            // get_features() followed by set_features()
            handle_request(&mut backend_server).unwrap();
231
            handle_request(&mut backend_server).unwrap();
232
            assert_eq!(
233
                backend_server.as_ref().acked_features,
234
                VIRTIO_FEATURES & !0x1
235
            );
236
237
            // get_protocol_features() followed by set_protocol_features()
            handle_request(&mut backend_server).unwrap();
238
            handle_request(&mut backend_server).unwrap();
239
            assert_eq!(
240
                backend_server.as_ref().acked_protocol_features,
241
                VhostUserProtocolFeatures::all().bits()
242
            );
243
244
            sbar.wait();
245
        });
246
247
        backend_client.set_owner().unwrap();
248
249
        // set virtio features
250
        let features = backend_client.get_features().unwrap();
251
        assert_eq!(features, VIRTIO_FEATURES);
252
        // Ack everything the backend offered except bit 0.
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();
253
254
        // set vhost protocol features
255
        let features = backend_client.get_protocol_features().unwrap();
256
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
257
        backend_client.set_protocol_features(features).unwrap();
258
259
        mbar.wait();
260
    }
261
262
    // Runs the shared client/server scenario with `set_need_reply(false)` on the client.
    #[test]
263
    fn test_client_server_process_no_need_reply() {
264
        test_client_server_process(false);
265
    }
266
267
    // Runs the shared client/server scenario with `set_need_reply(true)` on the client.
    #[test]
268
    fn test_client_server_process_need_reply() {
269
        test_client_server_process(true);
270
    }
271
272
    /// Shared body of the two `test_client_server_process_*` tests.
    ///
    /// A server thread handles one request per `handle_request` call while the calling thread
    /// drives the client through the matching sequence of calls. `set_need_reply` is forwarded to
    /// `BackendClient::set_need_reply` after feature negotiation and decides how the final,
    /// unimplemented `set_log_base` request is expected to fail on the client side.
    fn test_client_server_process(set_need_reply: bool) {
273
        // Two handles to the same two-party barrier: `mbar` for this thread, `sbar` for the
        // server thread.
        let mbar = Arc::new(Barrier::new(2));
274
        let sbar = mbar.clone();
275
        let test_backend = TestBackend::new();
276
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);
277
278
        thread::spawn(move || {
279
            // set_owner()
280
            handle_request(&mut backend_server).unwrap();
281
            assert!(backend_server.as_ref().owned);
282
283
            // get/set_features()
284
            handle_request(&mut backend_server).unwrap();
285
            handle_request(&mut backend_server).unwrap();
286
            assert_eq!(
287
                backend_server.as_ref().acked_features,
288
                VIRTIO_FEATURES & !0x1
289
            );
290
291
            // get/set_protocol_features()
            handle_request(&mut backend_server).unwrap();
292
            handle_request(&mut backend_server).unwrap();
293
            assert_eq!(
294
                backend_server.as_ref().acked_protocol_features,
295
                VhostUserProtocolFeatures::all().bits()
296
            );
297
298
            // get_inflight_fd()
299
            handle_request(&mut backend_server).unwrap();
300
            // set_inflight_fd()
301
            handle_request(&mut backend_server).unwrap();
302
303
            // get_queue_num()
304
            handle_request(&mut backend_server).unwrap();
305
306
            // set_mem_table()
307
            handle_request(&mut backend_server).unwrap();
308
309
            // set_config() then get_config(), in the order the client issues them below
310
            handle_request(&mut backend_server).unwrap();
311
            handle_request(&mut backend_server).unwrap();
312
313
            // set_backend_req_fd
314
            handle_request(&mut backend_server).unwrap();
315
316
            // set_vring_enable
317
            handle_request(&mut backend_server).unwrap();
318
319
            // set_vring_xxx: six requests matching the client's set_vring_num, set_vring_base,
            // set_vring_addr, set_vring_call, set_vring_kick and set_vring_err calls.
320
            handle_request(&mut backend_server).unwrap();
321
            handle_request(&mut backend_server).unwrap();
322
            handle_request(&mut backend_server).unwrap();
323
            handle_request(&mut backend_server).unwrap();
324
            handle_request(&mut backend_server).unwrap();
325
            handle_request(&mut backend_server).unwrap();
326
327
            // get_max_mem_slots()
328
            handle_request(&mut backend_server).unwrap();
329
330
            // add_mem_region()
331
            handle_request(&mut backend_server).unwrap();
332
333
            // remove_mem_region()
334
            handle_request(&mut backend_server).unwrap();
335
336
            // set_log_base
337
            //
338
            // Results in an error because it isn't implemented. When `set_need_reply` is true, the
339
            // client waits for an ACK that will never come, instead they will get an error only
340
            // when we drop `backend_server` below.
341
            handle_request(&mut backend_server).unwrap_err();
342
343
            std::mem::drop(backend_server);
344
            sbar.wait();
345
        });
346
347
        backend_client.set_owner().unwrap();
348
349
        // set virtio features
350
        let features = backend_client.get_features().unwrap();
351
        assert_eq!(features, VIRTIO_FEATURES);
352
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();
353
354
        // set vhost protocol features
355
        let features = backend_client.get_protocol_features().unwrap();
356
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
357
        backend_client.set_protocol_features(features).unwrap();
358
359
        // From here on the client uses the reply mode selected by the caller.
        backend_client.set_need_reply(set_need_reply);
360
361
        // Retrieve inflight I/O tracking information
362
        let (inflight_info, inflight_file) = backend_client
363
            .get_inflight_fd(&VhostUserInflight {
364
                num_queues: 2,
365
                queue_size: 256,
366
                ..Default::default()
367
            })
368
            .unwrap();
369
        // Set the buffer back to the backend
370
        backend_client
371
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
372
            .unwrap();
373
374
        let num = backend_client.get_queue_num().unwrap();
375
        assert_eq!(num, 2);
376
377
        // `event` is reused below as the memory region handle, the vring call/kick/err
        // descriptor and the `set_log_base` descriptor.
        let event = base::Event::new().unwrap();
378
        let mem = [VhostUserMemoryRegionInfo {
379
            guest_phys_addr: 0,
380
            memory_size: 0x10_0000,
381
            userspace_addr: 0,
382
            mmap_offset: 0,
383
            mmap_handle: event.as_raw_descriptor(),
384
        }];
385
        backend_client.set_mem_table(&mem).unwrap();
386
387
        // Write one byte at offset 0x100, then read the config back and check that byte.
        backend_client
388
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
389
            .unwrap();
390
        let buf = [0x0u8; 4];
391
        let (reply_body, reply_payload) = backend_client
392
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
393
            .unwrap();
394
        let offset = reply_body.offset;
395
        assert_eq!(offset, 0x100);
396
        assert_eq!(reply_payload[0], 0xa5);
397
398
        // The descriptor passed to `set_backend_req_fd` is platform specific: a packed Tube on
        // Windows, an Event on Unix.
        #[cfg(windows)]
399
        let tubes = base::Tube::pair().unwrap();
400
        #[cfg(windows)]
401
        let descriptor =
402
            // SAFETY:
403
            // Safe because we will be importing the Tube in the other thread.
404
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };
405
406
        #[cfg(unix)]
407
        let descriptor = base::Event::new().unwrap();
408
409
        backend_client.set_backend_req_fd(&descriptor).unwrap();
410
        backend_client.set_vring_enable(0, true).unwrap();
411
412
        backend_client.set_vring_num(0, 256).unwrap();
413
        backend_client.set_vring_base(0, 0).unwrap();
414
        let config = VringConfigData {
415
            queue_size: 128,
416
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
417
            desc_table_addr: 0x1000,
418
            used_ring_addr: 0x2000,
419
            avail_ring_addr: 0x3000,
420
            log_addr: Some(0x4000),
421
        };
422
        backend_client.set_vring_addr(0, &config).unwrap();
423
        backend_client.set_vring_call(0, &event).unwrap();
424
        backend_client.set_vring_kick(0, &event).unwrap();
425
        backend_client.set_vring_err(0, &event).unwrap();
426
427
        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
428
        assert_eq!(max_mem_slots, 32);
429
430
        // Add and then remove a second memory region backed by a temporary file.
        let region_file = tempfile().unwrap();
431
        let region = VhostUserMemoryRegionInfo {
432
            guest_phys_addr: 0x10_0000,
433
            memory_size: 0x10_0000,
434
            userspace_addr: 0,
435
            mmap_offset: 0,
436
            mmap_handle: region_file.as_raw_descriptor(),
437
        };
438
        backend_client.add_mem_region(&region).unwrap();
439
440
        backend_client.remove_mem_region(&region).unwrap();
441
442
        // set_log_base isn't implemented by the server and so will break the connection.
443
        let result = backend_client.set_log_base(0, Some(event.as_raw_descriptor()));
444
        if set_need_reply {
445
            // When using `set_need_reply`, we'll get an immediate disconnect error.
446
            assert!(
447
                matches!(result, Err(Error::Disconnect)),
448
                "unexpected result: {result:?}"
449
            );
450
        } else {
451
            // When not using `set_need_reply`, it will seem to succeed and then the next request
452
            // will fail.
453
            result.unwrap();
454
            let result = backend_client.get_features();
455
            match &result {
456
                // Windows errors with Disconnect and Unix with a SocketError.
457
                Err(Error::Disconnect) => {}
458
                Err(Error::SocketError(e))
459
                    if e.kind() == ErrorKind::ConnectionReset
460
                        || e.kind() == ErrorKind::BrokenPipe => {}
461
                _ => panic!("unexpected result: {result:?}"),
462
            }
463
        }
464
465
        // Rendezvous with the server thread, which waits on the barrier after dropping the server.
        mbar.wait();
466
    }
467
468
    // Spot-checks the `Display` output generated from the `#[error(...)]` attributes for one
    // variant with a payload and one without.
    #[test]
469
    fn test_error_display() {
470
        assert_eq!(
471
            format!("{}", Error::InvalidParam("")),
472
            "invalid parameters: "
473
        );
474
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
475
    }
476
}