Coverage Report

Created: 2024-05-21 06:19

/rust/git/checkouts/vfio-user-1ee9f6371fec66a1/a1f6e52/src/lib.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright © 2021 Intel Corporation
2
//
3
// SPDX-License-Identifier: Apache-2.0
4
//
5
6
use bitflags::bitflags;
7
use libc::{c_void, iovec, EINVAL};
8
use libc::{sysconf, _SC_PAGESIZE};
9
use std::ffi::CString;
10
use std::fs::File;
11
use std::io::{IoSlice, Read, Write};
12
use std::mem::size_of;
13
use std::num::Wrapping;
14
use std::os::unix::{
15
    io::{FromRawFd, RawFd},
16
    net::{UnixListener, UnixStream},
17
};
18
use std::path::Path;
19
use thiserror::Error;
20
use vfio_bindings::bindings::vfio::*;
21
use vm_memory::{ByteValued, FileOffset};
22
use vmm_sys_util::sock_ctrl_msg::ScmSocket;
23
24
#[macro_use]
25
extern crate serde_derive;
26
27
#[macro_use]
28
extern crate log;
29
30
/// Command identifiers carried in the `command` field of every message [`Header`].
///
/// The explicit discriminants are the on-the-wire values and must not be renumbered.
// Not every variant is constructed by this client.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default)]
#[repr(u16)]
pub enum Command {
    /// Value 0; also what a default-constructed header carries.
    #[default]
    Unknown = 0,
    Version = 1,
    DmaMap = 2,
    DmaUnmap = 3,
    DeviceGetInfo = 4,
    DeviceGetRegionInfo = 5,
    GetRegionIoFds = 6,
    GetIrqInfo = 7,
    SetIrqs = 8,
    RegionRead = 9,
    RegionWrite = 10,
    DmaRead = 11,
    DmaWrite = 12,
    DeviceReset = 13,
    UserDirtyPages = 14,
}
51
52
/// Values used in the `flags` field of a message [`Header`].
// Not every flag is used by this client.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u32)]
enum HeaderFlags {
    /// Message type "command" (value 0); every request sent by this client uses it.
    #[default]
    Command = 0,
    /// Message type "reply".
    Reply = 1,
    /// "No reply" flag bit.
    NoReply = 1 << 4,
    /// "Error" flag bit.
    Error = 1 << 5,
}
62
63
#[repr(C)]
64
0
#[derive(Default, Clone, Copy, Debug)]
65
struct Header {
66
    message_id: u16,
67
    command: Command,
68
    message_size: u32,
69
    flags: u32,
70
    error: u32,
71
}
72
73
#[repr(C)]
74
0
#[derive(Default, Clone, Copy, Debug)]
75
struct Version {
76
    header: Header,
77
    major: u16,
78
    minor: u16,
79
}
80
81
0
#[derive(Serialize, Deserialize, Debug)]
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_str::<serde_json::error::Error>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_u64::<_>
Unexecuted instantiation: <<vfio_user::MigrationCapabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_bytes::<_>
82
struct MigrationCapabilities {
83
    pgsize: u32,
84
}
85
86
0
/// Serde default for `Capabilities::max_msg_fds`: one file descriptor per message.
const fn default_max_msg_fds() -> u32 {
    1
}

/// Serde default for `Capabilities::max_data_xfer_size`: 1 MiB (2^20 bytes).
const fn default_max_data_xfer_size() -> u32 {
    1 << 20
}
93
94
#[inline(always)]
95
0
fn pagesize() -> u32 {
96
0
    // SAFETY: sysconf
97
0
    unsafe { sysconf(_SC_PAGESIZE) as u32 }
98
0
}
99
100
0
fn default_migration_capabilities() -> MigrationCapabilities {
101
0
    MigrationCapabilities { pgsize: pagesize() }
102
0
}
103
104
bitflags! {
    /// Access flags stored in `DmaMap::flags`.
    ///
    /// NOTE(review): despite the names, bit 0 and bit 1 behave as independent read and
    /// write permission bits; `READ_WRITE` is simply both of them set.
    pub struct DmaMapFlags: u32 {
        /// Bit 0.
        const READ_ONLY = 1 << 0;
        /// Bit 1.
        const WRITE_ONLY = 1 << 1;
        /// Bits 0 and 1 together.
        const READ_WRITE = Self::READ_ONLY.bits | Self::WRITE_ONLY.bits;
    }

    /// Flags stored in `DmaUnmap::flags`.
    pub struct DmaUnmapFlags: u32 {
        /// Bit 1.
        const GET_DIRTY_PAGE_INFO = 1 << 1;
        /// Bit 2.
        const UNMAP_ALL = 1 << 2;
    }
}
116
117
#[repr(C)]
118
0
#[derive(Default, Clone, Copy, Debug)]
119
struct DmaMap {
120
    header: Header,
121
    argsz: u32,
122
    flags: u32,
123
    offset: u64,
124
    address: u64,
125
    size: u64,
126
}
127
128
#[repr(C)]
129
0
#[derive(Default, Clone, Copy, Debug)]
130
struct DmaUnmap {
131
    header: Header,
132
    argsz: u32,
133
    flags: u32,
134
    address: u64,
135
    size: u64,
136
}
137
138
#[repr(C)]
139
0
#[derive(Default, Clone, Copy, Debug)]
140
struct DeviceGetInfo {
141
    header: Header,
142
    argsz: u32,
143
    flags: u32,
144
    num_regions: u32,
145
    num_irqs: u32,
146
}
147
148
#[repr(C)]
149
0
#[derive(Default, Clone, Copy, Debug)]
150
struct DeviceGetRegionInfo {
151
    header: Header,
152
    region_info: vfio_region_info,
153
}
154
155
#[repr(C)]
156
0
#[derive(Default, Clone, Copy, Debug)]
157
struct RegionAccess {
158
    header: Header,
159
    offset: u64,
160
    region: u32,
161
    count: u32,
162
}
163
164
#[repr(C)]
165
0
#[derive(Default, Clone, Copy, Debug)]
166
struct GetIrqInfo {
167
    header: Header,
168
    argsz: u32,
169
    flags: u32,
170
    index: u32,
171
    count: u32,
172
}
173
174
#[repr(C)]
175
0
#[derive(Default, Clone, Copy, Debug)]
176
struct SetIrqs {
177
    header: Header,
178
    argsz: u32,
179
    flags: u32,
180
    index: u32,
181
    start: u32,
182
    count: u32,
183
}
184
185
#[repr(C)]
186
0
#[derive(Default, Clone, Copy, Debug)]
187
struct DeviceReset {
188
    header: Header,
189
}
190
191
// SAFETY: data structure only contain a series of integers
192
unsafe impl ByteValued for Header {}
193
// SAFETY: data structure only contain a series of integers
194
unsafe impl ByteValued for Version {}
195
// SAFETY: data structure only contain a series of integers
196
unsafe impl ByteValued for DmaMap {}
197
// SAFETY: data structure only contain a series of integers
198
unsafe impl ByteValued for DmaUnmap {}
199
// SAFETY: data structure only contain a series of integers
200
unsafe impl ByteValued for DeviceGetInfo {}
201
// SAFETY: data structure only contain a series of integers
202
unsafe impl ByteValued for DeviceGetRegionInfo {}
203
// SAFETY: data structure only contain a series of integers
204
unsafe impl ByteValued for RegionAccess {}
205
// SAFETY: data structure only contain a series of integers
206
unsafe impl ByteValued for GetIrqInfo {}
207
// SAFETY: data structure only contain a series of integers
208
unsafe impl ByteValued for SetIrqs {}
209
// SAFETY: data structure only contain a series of integers
210
unsafe impl ByteValued for DeviceReset {}
211
212
0
#[derive(Serialize, Deserialize, Debug)]
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_str::<serde_json::error::Error>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <vfio_user::Capabilities as serde::de::Deserialize>::deserialize::<serde::__private::de::missing_field::MissingFieldDeserializer<serde_json::error::Error>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <vfio_user::Capabilities as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <vfio_user::Capabilities as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_bytes::<_>
Unexecuted instantiation: <<vfio_user::Capabilities as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_u64::<_>
213
struct Capabilities {
214
    #[serde(default = "default_max_msg_fds")]
215
    max_msg_fds: u32,
216
    #[serde(default = "default_max_data_xfer_size")]
217
    max_data_xfer_size: u32,
218
    #[serde(default = "default_migration_capabilities")]
219
    migration: MigrationCapabilities,
220
}
221
222
0
#[derive(Serialize, Deserialize, Debug, Default)]
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_map::<serde_json::de::MapAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Visitor as serde::de::Visitor>::visit_seq::<serde_json::de::SeqAccess<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_str::<serde_json::error::Error>
Unexecuted instantiation: <vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::<&mut serde_json::de::Deserializer<serde_json::read::SliceRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__Field as serde::de::Deserialize>::deserialize::<serde_json::de::MapKey<serde_json::read::StrRead>>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::expecting
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_bytes::<_>
Unexecuted instantiation: <<vfio_user::CapabilitiesData as serde::de::Deserialize>::deserialize::__FieldVisitor as serde::de::Visitor>::visit_u64::<_>
223
struct CapabilitiesData {
224
    capabilities: Capabilities,
225
}
226
227
impl Default for Capabilities {
228
0
    fn default() -> Self {
229
0
        Self {
230
0
            max_msg_fds: default_max_msg_fds(),
231
0
            max_data_xfer_size: default_max_data_xfer_size(),
232
0
            migration: default_migration_capabilities(),
233
0
        }
234
0
    }
235
}
236
237
pub struct Client {
238
    stream: UnixStream,
239
    next_message_id: Wrapping<u16>,
240
    num_irqs: u32,
241
    resettable: bool,
242
    regions: Vec<Region>,
243
}
244
245
0
#[derive(Debug)]
246
pub struct Region {
247
    pub flags: u32,
248
    pub index: u32,
249
    pub size: u64,
250
    pub file_offset: Option<FileOffset>,
251
    pub sparse_areas: Vec<vfio_region_sparse_mmap_area>,
252
}
253
254
0
#[derive(Debug)]
255
pub struct IrqInfo {
256
    pub index: u32,
257
    pub flags: u32,
258
    pub count: u32,
259
}
260
261
0
#[derive(Error, Debug)]
Unexecuted instantiation: <vfio_user::Error as core::error::Error>::source
Unexecuted instantiation: <vfio_user::Error as core::fmt::Display>::fmt
Unexecuted instantiation: <vfio_user::Error as core::fmt::Debug>::fmt
Unexecuted instantiation: <vfio_user::Error as core::fmt::Debug>::fmt
262
pub enum Error {
263
    #[error("Error connecting: {0}")]
264
    Connect(#[source] std::io::Error),
265
    #[error("Error serializing capabilities: {0}")]
266
    SerializeCapabilites(#[source] serde_json::Error),
267
    #[error("Error deserializing capabilities: {0}")]
268
    DeserializeCapabilites(#[source] serde_json::Error),
269
    #[error("Error writing to stream: {0}")]
270
    StreamWrite(#[source] std::io::Error),
271
    #[error("Error reading from stream: {0}")]
272
    StreamRead(#[source] std::io::Error),
273
    #[error("Error shutting down stream: {0}")]
274
    StreamShutdown(#[source] std::io::Error),
275
    #[error("Error writing with file descriptors: {0}")]
276
    SendWithFd(#[source] vmm_sys_util::errno::Error),
277
    #[error("Error reading with file descriptors: {0}")]
278
    ReceiveWithFd(#[source] vmm_sys_util::errno::Error),
279
    #[error("Not a PCI device")]
280
    NotPciDevice,
281
    #[error("Error binding to socket: {0}")]
282
    SocketBind(#[source] std::io::Error),
283
    #[error("Error accepting connection: {0}")]
284
    SocketAccept(#[source] std::io::Error),
285
    #[error("Unsupported command: {0:?}")]
286
    UnsupportedCommand(Command),
287
    #[error("Unsupported feature")]
288
    UnsupportedFeature,
289
    #[error("Error from backend: {0:?}")]
290
    Backend(#[source] std::io::Error),
291
    #[error("Invalid input")]
292
    InvalidInput,
293
}
294
295
impl Client {
296
0
    pub fn new(path: &Path) -> Result<Client, Error> {
297
0
        let stream = UnixStream::connect(path).map_err(Error::Connect)?;
298
299
0
        let mut client = Client {
300
0
            next_message_id: Wrapping(0),
301
0
            stream,
302
0
            num_irqs: 0,
303
0
            resettable: false,
304
0
            regions: Vec::new(),
305
0
        };
306
0
307
0
        client.negotiate_version()?;
308
309
0
        client.regions = client.get_regions()?;
310
311
0
        Ok(client)
312
0
    }
313
314
0
    fn negotiate_version(&mut self) -> Result<(), Error> {
315
0
        let caps = CapabilitiesData::default();
316
317
0
        let version_data = serde_json::to_string(&caps).map_err(Error::SerializeCapabilites)?;
318
319
0
        let version = Version {
320
0
            header: Header {
321
0
                message_id: self.next_message_id.0,
322
0
                command: Command::Version,
323
0
                flags: HeaderFlags::Command as u32,
324
0
                message_size: (size_of::<Version>() + version_data.len() + 1) as u32,
325
0
                ..Default::default()
326
0
            },
327
0
            major: 0,
328
0
            minor: 1,
329
0
        };
330
0
        debug!("Command: {:?}", version);
331
332
0
        let version_data = CString::new(version_data.as_bytes()).unwrap();
333
0
        let bufs = vec![
334
0
            IoSlice::new(version.as_slice()),
335
0
            IoSlice::new(version_data.as_bytes_with_nul()),
336
0
        ];
337
0
338
0
        // TODO: Use write_all_vectored() when ready
339
0
        let _ = self
340
0
            .stream
341
0
            .write_vectored(&bufs)
342
0
            .map_err(Error::StreamWrite)?;
343
344
0
        debug!(
345
0
            "Sent client version information: major = {} minor = {} capabilities = {:?}",
346
0
            version.major, version.minor, &caps.capabilities
347
        );
348
349
0
        self.next_message_id += Wrapping(1);
350
0
351
0
        let mut server_version: Version = Version::default();
352
0
        self.stream
353
0
            .read_exact(server_version.as_mut_slice())
354
0
            .map_err(Error::StreamRead)?;
355
356
0
        debug!("Reply: {:?}", server_version);
357
358
0
        let mut server_version_data =
359
0
            vec![0; server_version.header.message_size as usize - size_of::<Version>()];
360
0
        self.stream
361
0
            .read_exact(server_version_data.as_mut_slice())
362
0
            .map_err(Error::StreamRead)?;
363
364
0
        let server_caps: CapabilitiesData =
365
0
            serde_json::from_slice(&server_version_data[0..server_version_data.len() - 1])
366
0
                .map_err(Error::DeserializeCapabilites)?;
367
368
0
        debug!(
369
0
            "Received server version information: major = {} minor = {} capabilities = {:?}",
370
0
            server_version.major, server_version.minor, &server_caps.capabilities
371
        );
372
373
0
        Ok(())
374
0
    }
375
376
0
    pub fn dma_map(
377
0
        &mut self,
378
0
        offset: u64,
379
0
        address: u64,
380
0
        size: u64,
381
0
        fd: RawFd,
382
0
    ) -> Result<(), Error> {
383
0
        let dma_map = DmaMap {
384
0
            header: Header {
385
0
                message_id: self.next_message_id.0,
386
0
                command: Command::DmaMap,
387
0
                flags: HeaderFlags::Command as u32,
388
0
                message_size: size_of::<DmaMap>() as u32,
389
0
                ..Default::default()
390
0
            },
391
0
            argsz: (size_of::<DmaMap>() - size_of::<Header>()) as u32,
392
0
            flags: DmaMapFlags::READ_WRITE.bits,
393
0
            offset,
394
0
            address,
395
0
            size,
396
0
        };
397
0
        debug!("Command: {:?}", dma_map);
398
0
        self.next_message_id += Wrapping(1);
399
0
        self.stream
400
0
            .send_with_fd(dma_map.as_slice(), fd)
401
0
            .map_err(Error::SendWithFd)?;
402
403
0
        let mut reply = Header::default();
404
0
        self.stream
405
0
            .read_exact(reply.as_mut_slice())
406
0
            .map_err(Error::StreamRead)?;
407
0
        debug!("Reply: {:?}", reply);
408
409
0
        Ok(())
410
0
    }
411
412
0
    pub fn dma_unmap(&mut self, address: u64, size: u64) -> Result<(), Error> {
413
0
        let dma_unmap = DmaUnmap {
414
0
            header: Header {
415
0
                message_id: self.next_message_id.0,
416
0
                command: Command::DmaUnmap,
417
0
                flags: HeaderFlags::Command as u32,
418
0
                message_size: size_of::<DmaUnmap>() as u32,
419
0
                ..Default::default()
420
0
            },
421
0
            argsz: (size_of::<DmaUnmap>() - size_of::<Header>()) as u32,
422
0
            flags: 0,
423
0
            address,
424
0
            size,
425
0
        };
426
0
        debug!("Command: {:?}", dma_unmap);
427
0
        self.next_message_id += Wrapping(1);
428
0
        self.stream
429
0
            .write_all(dma_unmap.as_slice())
430
0
            .map_err(Error::StreamWrite)?;
431
432
0
        let mut reply = DmaUnmap::default();
433
0
        self.stream
434
0
            .read_exact(reply.as_mut_slice())
435
0
            .map_err(Error::StreamRead)?;
436
0
        debug!("Reply: {:?}", reply);
437
438
0
        Ok(())
439
0
    }
440
441
0
    pub fn reset(&mut self) -> Result<(), Error> {
442
0
        let reset = DeviceReset {
443
0
            header: Header {
444
0
                message_id: self.next_message_id.0,
445
0
                command: Command::DeviceReset,
446
0
                flags: HeaderFlags::Command as u32,
447
0
                message_size: size_of::<DeviceReset>() as u32,
448
0
                ..Default::default()
449
0
            },
450
0
        };
451
0
        debug!("Command: {:?}", reset);
452
0
        self.next_message_id += Wrapping(1);
453
0
        self.stream
454
0
            .write_all(reset.as_slice())
455
0
            .map_err(Error::StreamWrite)?;
456
457
0
        let mut reply = Header::default();
458
0
        self.stream
459
0
            .read_exact(reply.as_mut_slice())
460
0
            .map_err(Error::StreamRead)?;
461
0
        debug!("Reply: {:?}", reply);
462
463
0
        Ok(())
464
0
    }
465
466
0
    fn get_regions(&mut self) -> Result<Vec<Region>, Error> {
467
0
        let get_info = DeviceGetInfo {
468
0
            header: Header {
469
0
                message_id: self.next_message_id.0,
470
0
                command: Command::DeviceGetInfo,
471
0
                flags: HeaderFlags::Command as u32,
472
0
                message_size: size_of::<DeviceGetInfo>() as u32,
473
0
                ..Default::default()
474
0
            },
475
0
            argsz: size_of::<DeviceGetInfo>() as u32,
476
0
            ..Default::default()
477
0
        };
478
0
        debug!("Command: {:?}", get_info);
479
0
        self.next_message_id += Wrapping(1);
480
0
481
0
        self.stream
482
0
            .write_all(get_info.as_slice())
483
0
            .map_err(Error::StreamWrite)?;
484
485
0
        let mut reply = DeviceGetInfo::default();
486
0
        self.stream
487
0
            .read_exact(reply.as_mut_slice())
488
0
            .map_err(Error::StreamRead)?;
489
0
        debug!("Reply: {:?}", reply);
490
0
        self.num_irqs = reply.num_irqs;
491
0
492
0
        if reply.flags & VFIO_DEVICE_FLAGS_PCI != VFIO_DEVICE_FLAGS_PCI {
493
0
            return Err(Error::NotPciDevice);
494
0
        }
495
0
496
0
        self.resettable = reply.flags & VFIO_DEVICE_FLAGS_RESET != VFIO_DEVICE_FLAGS_RESET;
497
0
498
0
        let num_regions = reply.num_regions;
499
0
        let mut regions = Vec::new();
500
0
        for index in 0..num_regions {
501
0
            let (region_info, fd, sparse_areas) = self.get_region_info(index)?;
502
0
            regions.push(Region {
503
0
                flags: region_info.flags,
504
0
                index: region_info.index,
505
0
                size: region_info.size,
506
0
                file_offset: fd.map(|fd| FileOffset::new(fd, region_info.offset)),
507
0
                sparse_areas,
508
0
            });
509
        }
510
511
0
        Ok(regions)
512
0
    }
513
514
0
    fn get_region_info(
515
0
        &mut self,
516
0
        index: u32,
517
0
    ) -> Result<
518
0
        (
519
0
            vfio_region_info,
520
0
            Option<File>,
521
0
            Vec<vfio_region_sparse_mmap_area>,
522
0
        ),
523
0
        Error,
524
0
    > {
525
0
        // Retrieve the region info without capability
526
0
        let mut get_region_info = DeviceGetRegionInfo {
527
0
            header: Header {
528
0
                message_id: self.next_message_id.0,
529
0
                command: Command::DeviceGetRegionInfo,
530
0
                flags: HeaderFlags::Command as u32,
531
0
                message_size: std::mem::size_of::<DeviceGetRegionInfo>() as u32,
532
0
                ..Default::default()
533
0
            },
534
0
            region_info: vfio_region_info {
535
0
                argsz: size_of::<vfio_region_info>() as u32,
536
0
                index,
537
0
                ..Default::default()
538
0
            },
539
0
        };
540
0
        debug!("Command: {:?}", get_region_info);
541
0
        self.next_message_id += Wrapping(1);
542
0
543
0
        self.stream
544
0
            .write_all(get_region_info.as_slice())
545
0
            .map_err(Error::StreamWrite)?;
546
547
0
        let mut reply = DeviceGetRegionInfo::default();
548
0
        let (_, fd) = self
549
0
            .stream
550
0
            .recv_with_fd(reply.as_mut_slice())
551
0
            .map_err(Error::ReceiveWithFd)?;
552
0
        debug!("Reply: {:?}", reply);
553
554
        // Retrieve the region info again with capabilities if needed
555
0
        if reply.region_info.argsz > std::mem::size_of::<vfio_region_info>() as u32 {
556
0
            get_region_info.region_info.argsz = reply.region_info.argsz;
557
0
            debug!("Command: {:?}", get_region_info);
558
0
            self.next_message_id += Wrapping(1);
559
0
560
0
            self.stream
561
0
                .write_all(get_region_info.as_slice())
562
0
                .map_err(Error::StreamWrite)?;
563
564
0
            let mut reply = DeviceGetRegionInfo::default();
565
0
            let (_, fd) = self
566
0
                .stream
567
0
                .recv_with_fd(reply.as_mut_slice())
568
0
                .map_err(Error::ReceiveWithFd)?;
569
0
            debug!("Reply: {:?}", reply);
570
571
0
            let cap_size = reply.region_info.argsz - std::mem::size_of::<vfio_region_info>() as u32;
572
0
            assert_eq!(
573
0
                cap_size,
574
0
                reply.header.message_size - size_of::<DeviceGetRegionInfo>() as u32
575
0
            );
576
0
            let mut cap_data = vec![0; cap_size as usize];
577
0
            self.stream
578
0
                .read_exact(cap_data.as_mut_slice())
579
0
                .map_err(Error::StreamRead)?;
580
581
0
            let sparse_areas = Self::parse_region_caps(&cap_data, &reply.region_info)?;
582
583
0
            Ok((reply.region_info, fd, sparse_areas))
584
        } else {
585
0
            Ok((reply.region_info, fd, Vec::new()))
586
        }
587
0
    }
588
589
0
    fn parse_region_caps(
590
0
        cap_data: &[u8],
591
0
        region_info: &vfio_region_info,
592
0
    ) -> Result<Vec<vfio_region_sparse_mmap_area>, Error> {
593
0
        let mut sparse_areas: Vec<vfio_region_sparse_mmap_area> = Vec::new();
594
0
595
0
        let cap_size = cap_data.len() as u32;
596
0
        let cap_header_size = size_of::<vfio_info_cap_header>() as u32;
597
0
        let mmap_cap_size = size_of::<vfio_region_info_cap_sparse_mmap>() as u32;
598
0
        let mmap_area_size = size_of::<vfio_region_sparse_mmap_area>() as u32;
599
0
600
0
        let cap_data_ptr = cap_data.as_ptr();
601
0
        let mut region_info_offset = region_info.cap_offset;
602
0
        while region_info_offset != 0 {
603
            // calculate the offset from the begining of the cap_data based on the offset
604
            // that is relative to the begining of the VFIO region info structure
605
0
            let cap_offset = region_info_offset - size_of::<vfio_region_info>() as u32;
606
0
            if cap_offset + cap_header_size > cap_size {
607
0
                warn!(
608
0
                    "Unexpected end of cap data: 'cap_offset + cap_header_size > cap_size' \
609
0
                cap_offset = {}, cap_header_size = {}, cap_size = {}",
610
                    cap_offset, cap_header_size, cap_size
611
                );
612
0
                break;
613
0
            }
614
0
615
0
            // SAFETY: `cap_data_ptr` is valid and the `cap_offset` is checked above
616
0
            let cap_ptr = unsafe { cap_data_ptr.offset(cap_offset as isize) };
617
0
            // SAFETY: `cap_ptr` is valid
618
0
            let cap_header = unsafe { &*(cap_ptr as *const vfio_info_cap_header) };
619
0
            match cap_header.id as u32 {
620
                VFIO_REGION_INFO_CAP_SPARSE_MMAP => {
621
0
                    if cap_offset + mmap_cap_size > cap_size {
622
0
                        warn!(
623
0
                            "Unexpected end of cap data: 'cap_offset + mmap_cap_size > cap_size' \
624
0
                        cap_offset = {}, mmap_cap_size = {}, cap_size = {}",
625
                            cap_offset, mmap_cap_size, cap_size
626
                        );
627
0
                        break;
628
0
                    }
629
0
                    // SAFETY: `cap_ptr` is valid and its size is also checked above
630
0
                    let sparse_mmap = unsafe {
631
0
                        &*(cap_ptr as *mut u8 as *const vfio_region_info_cap_sparse_mmap)
632
0
                    };
633
0
634
0
                    let area_num = sparse_mmap.nr_areas;
635
0
                    if cap_offset + mmap_cap_size + area_num * mmap_area_size > cap_size {
636
0
                        warn!("Unexpected end of cap data: 'cap_offset + mmap_cap_size + area_num * mmap_area_size > cap_size' \
637
0
                        cap_offset = {}, mmap_cap_size = {}, area_num = {}, mmap_area_size = {}, cap_size = {}",
638
                        cap_offset, mmap_cap_size, area_num, mmap_area_size, cap_size);
639
0
                        break;
640
0
                    }
641
0
                    let areas =
642
0
                        // SAFETY: `sparse_mmap` is valid and its size is also checked above
643
0
                        unsafe { sparse_mmap.areas.as_slice(sparse_mmap.nr_areas as usize) };
644
0
                    for area in areas.iter() {
645
0
                        sparse_areas.push(*area);
646
0
                    }
647
                }
648
                _ => {
649
0
                    warn!(
650
0
                        "Ignoring unsupported vfio region capability (id = '{}')",
651
                        cap_header.id
652
                    );
653
                }
654
            }
655
0
            region_info_offset = cap_header.next;
656
        }
657
658
0
        Ok(sparse_areas)
659
0
    }
660
661
0
    pub fn region_read(&mut self, region: u32, offset: u64, data: &mut [u8]) -> Result<(), Error> {
662
0
        let region_read = RegionAccess {
663
0
            header: Header {
664
0
                message_id: self.next_message_id.0,
665
0
                command: Command::RegionRead,
666
0
                flags: HeaderFlags::Command as u32,
667
0
                message_size: size_of::<RegionAccess>() as u32,
668
0
                ..Default::default()
669
0
            },
670
0
            offset,
671
0
            count: data.len() as u32,
672
0
            region,
673
0
        };
674
0
        debug!("Command: {:?}", region_read);
675
0
        self.next_message_id += Wrapping(1);
676
0
        self.stream
677
0
            .write_all(region_read.as_slice())
678
0
            .map_err(Error::StreamWrite)?;
679
680
0
        let mut reply = RegionAccess::default();
681
0
        self.stream
682
0
            .read_exact(reply.as_mut_slice())
683
0
            .map_err(Error::StreamRead)?;
684
0
        debug!("Reply: {:?}", reply);
685
0
        self.stream.read_exact(data).map_err(Error::StreamRead)?;
686
0
        Ok(())
687
0
    }
688
689
0
    pub fn region_write(&mut self, region: u32, offset: u64, data: &[u8]) -> Result<(), Error> {
690
0
        let region_write = RegionAccess {
691
0
            header: Header {
692
0
                message_id: self.next_message_id.0,
693
0
                command: Command::RegionWrite,
694
0
                flags: HeaderFlags::Command as u32,
695
0
                message_size: (size_of::<RegionAccess>() + data.len()) as u32,
696
0
                ..Default::default()
697
0
            },
698
0
            offset,
699
0
            count: data.len() as u32,
700
0
            region,
701
0
        };
702
0
        debug!("Command: {:?}", region_write);
703
0
        self.next_message_id += Wrapping(1);
704
0
705
0
        let bufs = vec![IoSlice::new(region_write.as_slice()), IoSlice::new(data)];
706
0
707
0
        // TODO: Use write_all_vectored() when ready
708
0
        let _ = self
709
0
            .stream
710
0
            .write_vectored(&bufs)
711
0
            .map_err(Error::StreamWrite)?;
712
713
0
        let mut reply = RegionAccess::default();
714
0
        self.stream
715
0
            .read_exact(reply.as_mut_slice())
716
0
            .map_err(Error::StreamRead)?;
717
0
        debug!("Reply: {:?}", reply);
718
0
        Ok(())
719
0
    }
720
721
0
    pub fn get_irq_info(&mut self, index: u32) -> Result<IrqInfo, Error> {
722
0
        let get_irq_info = GetIrqInfo {
723
0
            header: Header {
724
0
                message_id: self.next_message_id.0,
725
0
                command: Command::GetIrqInfo,
726
0
                flags: HeaderFlags::Command as u32,
727
0
                message_size: size_of::<GetIrqInfo>() as u32,
728
0
                ..Default::default()
729
0
            },
730
0
            argsz: (size_of::<GetIrqInfo>() - size_of::<Header>()) as u32,
731
0
            flags: 0,
732
0
            index,
733
0
            count: 0,
734
0
        };
735
0
        debug!("Command: {:?}", get_irq_info);
736
0
        self.next_message_id += Wrapping(1);
737
0
738
0
        self.stream
739
0
            .write_all(get_irq_info.as_slice())
740
0
            .map_err(Error::StreamWrite)?;
741
742
0
        let mut reply = GetIrqInfo::default();
743
0
        self.stream
744
0
            .read_exact(reply.as_mut_slice())
745
0
            .map_err(Error::StreamRead)?;
746
0
        debug!("Reply: {:?}", reply);
747
748
0
        Ok(IrqInfo {
749
0
            index: reply.index,
750
0
            flags: reply.flags,
751
0
            count: reply.count,
752
0
        })
753
0
    }
754
755
0
    pub fn set_irqs(
756
0
        &mut self,
757
0
        index: u32,
758
0
        flags: u32,
759
0
        start: u32,
760
0
        count: u32,
761
0
        fds: &[RawFd],
762
0
    ) -> Result<(), Error> {
763
0
        let set_irqs = SetIrqs {
764
0
            header: Header {
765
0
                message_id: self.next_message_id.0,
766
0
                command: Command::SetIrqs,
767
0
                flags: HeaderFlags::Command as u32,
768
0
                message_size: size_of::<SetIrqs>() as u32,
769
0
                ..Default::default()
770
0
            },
771
0
            argsz: (size_of::<SetIrqs>() - size_of::<Header>()) as u32,
772
0
            flags,
773
0
            start,
774
0
            index,
775
0
            count,
776
0
        };
777
0
        debug!("Command: {:?}", set_irqs);
778
0
        self.next_message_id += Wrapping(1);
779
0
780
0
        self.stream
781
0
            .send_with_fds(&[set_irqs.as_slice()], fds)
782
0
            .map_err(Error::SendWithFd)?;
783
784
0
        let mut reply = Header::default();
785
0
        self.stream
786
0
            .read_exact(reply.as_mut_slice())
787
0
            .map_err(Error::StreamRead)?;
788
0
        debug!("Reply: {:?}", reply);
789
790
0
        Ok(())
791
0
    }
792
793
0
    pub fn region(&self, region_index: u32) -> Option<&Region> {
794
0
        self.regions
795
0
            .iter()
796
0
            .find(|&region| region.index == region_index)
797
0
    }
798
799
0
    pub fn resettable(&self) -> bool {
800
0
        self.resettable
801
0
    }
802
803
0
    pub fn shutdown(&self) -> Result<(), Error> {
804
0
        self.stream
805
0
            .shutdown(std::net::Shutdown::Both)
806
0
            .map_err(Error::StreamShutdown)
807
0
    }
808
}
809
810
pub trait ServerBackend {
811
    fn region_read(
812
        &mut self,
813
        _region: u32,
814
        _offset: u64,
815
        _data: &mut [u8],
816
    ) -> Result<(), std::io::Error>;
817
    fn region_write(
818
        &mut self,
819
        _region: u32,
820
        _offset: u64,
821
        _data: &[u8],
822
    ) -> Result<(), std::io::Error>;
823
    fn dma_map(
824
        &mut self,
825
        _flags: DmaMapFlags,
826
        _offset: u64,
827
        _address: u64,
828
        _size: u64,
829
        _fd: Option<&File>,
830
    ) -> Result<(), std::io::Error>;
831
    fn dma_unmap(
832
        &mut self,
833
        _flags: DmaUnmapFlags,
834
        _address: u64,
835
        _size: u64,
836
    ) -> Result<(), std::io::Error>;
837
    fn reset(&mut self) -> Result<(), std::io::Error>;
838
    fn set_irqs(
839
        &mut self,
840
        _index: u32,
841
        _flags: u32,
842
        _start: u32,
843
        _count: u32,
844
        _fds: Vec<File>,
845
    ) -> Result<(), std::io::Error>;
846
}
847
848
pub struct Server {
849
    listener: UnixListener,
850
    resettable: bool,
851
    irqs: Vec<IrqInfo>,
852
    regions: Vec<vfio_region_info>,
853
}
854
855
impl Server {
856
0
    pub fn new(
857
0
        path: &Path,
858
0
        resettable: bool,
859
0
        irqs: Vec<IrqInfo>,
860
0
        regions: Vec<vfio_region_info>,
861
0
    ) -> Result<Server, Error> {
862
0
        let listener = UnixListener::bind(path).map_err(Error::SocketBind)?;
863
864
0
        Ok(Server {
865
0
            listener,
866
0
            resettable,
867
0
            irqs,
868
0
            regions,
869
0
        })
870
0
    }
871
872
0
    fn handle_command(
873
0
        &self,
874
0
        backend: &mut dyn ServerBackend,
875
0
        stream: &mut UnixStream,
876
0
        header: Header,
877
0
        fds: Vec<File>,
878
0
    ) -> Result<(), Error> {
879
0
        match header.command {
880
            Command::Unknown
881
            | Command::GetRegionIoFds
882
            | Command::DmaRead
883
            | Command::DmaWrite
884
            | Command::UserDirtyPages => {
885
0
                return Err(Error::UnsupportedCommand(header.command));
886
            }
887
            Command::Version => {
888
                // TODO: Make version/capabilities configurable
889
0
                let mut client_version = Version {
890
0
                    header,
891
0
                    ..Default::default()
892
0
                };
893
0
                stream
894
0
                    .read_exact(&mut client_version.as_mut_slice()[size_of::<Header>()..])
895
0
                    .map_err(Error::StreamRead)?;
896
897
0
                let mut raw_version_data =
898
0
                    vec![0; header.message_size as usize - size_of::<Version>()];
899
0
                stream
900
0
                    .read_exact(&mut raw_version_data)
901
0
                    .map_err(Error::StreamRead)?;
902
0
                let client_version_data = CString::from_vec_with_nul(raw_version_data)
903
0
                    .unwrap()
904
0
                    .to_string_lossy()
905
0
                    .into_owned();
906
0
                let client_capabilities: CapabilitiesData =
907
0
                    serde_json::from_str(&client_version_data)
908
0
                        .map_err(Error::DeserializeCapabilites)?;
909
910
0
                info!(
911
0
                    "Received client version: major = {} minor = {} capabilities = {:?}",
912
                    client_version.major, client_version.minor, client_capabilities.capabilities,
913
                );
914
915
0
                let server_capabilities = CapabilitiesData::default();
916
0
                let server_version_data = serde_json::to_string(&server_capabilities)
917
0
                    .map_err(Error::SerializeCapabilites)?;
918
0
                let server_version = Version {
919
0
                    header: Header {
920
0
                        message_id: client_version.header.message_id,
921
0
                        command: Command::Version,
922
0
                        flags: HeaderFlags::Reply as u32,
923
0
                        message_size: (size_of::<Version>() + server_version_data.len() + 1) as u32,
924
0
                        ..Default::default()
925
0
                    },
926
0
                    major: 0,
927
0
                    minor: 0,
928
0
                };
929
0
930
0
                let server_version_data = CString::new(server_version_data.as_bytes()).unwrap();
931
0
932
0
                let bufs = vec![
933
0
                    IoSlice::new(server_version.as_slice()),
934
0
                    IoSlice::new(server_version_data.as_bytes_with_nul()),
935
0
                ];
936
0
937
0
                // TODO: Use write_all_vectored() when ready
938
0
                let _ = stream.write_vectored(&bufs).map_err(Error::StreamWrite)?;
939
940
0
                info!(
941
0
                    "Sent server version: major = {} minor = {} capabilities = {:?}",
942
                    server_version.major, server_version.minor, server_capabilities.capabilities
943
                );
944
            }
945
            Command::DmaMap => {
946
0
                let mut cmd = DmaMap {
947
0
                    header,
948
0
                    ..Default::default()
949
0
                };
950
0
                stream
951
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
952
0
                    .map_err(Error::StreamRead)?;
953
954
0
                backend
955
0
                    .dma_map(
956
0
                        DmaMapFlags::from_bits_truncate(cmd.flags),
957
0
                        cmd.offset,
958
0
                        cmd.address,
959
0
                        cmd.size,
960
0
                        if fds.len() > 1 { Some(&fds[0]) } else { None },
961
                    )
962
0
                    .map_err(Error::Backend)?;
963
964
0
                let reply = Header {
965
0
                    message_id: cmd.header.message_id,
966
0
                    command: Command::DmaMap,
967
0
                    flags: HeaderFlags::Reply as u32,
968
0
                    message_size: size_of::<Header>() as u32,
969
0
                    ..Default::default()
970
0
                };
971
0
                stream
972
0
                    .write_all(reply.as_slice())
973
0
                    .map_err(Error::StreamWrite)?;
974
            }
975
            Command::DmaUnmap => {
976
0
                let mut cmd = DmaUnmap {
977
0
                    header,
978
0
                    ..Default::default()
979
0
                };
980
0
                stream
981
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
982
0
                    .map_err(Error::StreamRead)?;
983
984
0
                backend
985
0
                    .dma_unmap(
986
0
                        DmaUnmapFlags::from_bits_truncate(cmd.flags),
987
0
                        cmd.address,
988
0
                        cmd.size,
989
0
                    )
990
0
                    .map_err(Error::Backend)?;
991
992
0
                let reply = DmaUnmap {
993
0
                    header: Header {
994
0
                        message_id: cmd.header.message_id,
995
0
                        command: Command::DmaUnmap,
996
0
                        flags: HeaderFlags::Reply as u32,
997
0
                        message_size: size_of::<DmaUnmap>() as u32,
998
0
                        ..Default::default()
999
0
                    },
1000
0
                    argsz: cmd.argsz,
1001
0
                    flags: cmd.flags,
1002
0
                    address: cmd.address,
1003
0
                    size: cmd.size,
1004
0
                };
1005
0
                stream
1006
0
                    .write_all(reply.as_slice())
1007
0
                    .map_err(Error::StreamWrite)?;
1008
            }
1009
            Command::DeviceGetInfo => {
1010
0
                let mut cmd = DeviceGetInfo {
1011
0
                    header,
1012
0
                    ..Default::default()
1013
0
                };
1014
0
                stream
1015
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1016
0
                    .map_err(Error::StreamRead)?;
1017
1018
0
                let reply = DeviceGetInfo {
1019
0
                    header: Header {
1020
0
                        message_id: cmd.header.message_id,
1021
0
                        command: Command::DeviceGetInfo,
1022
0
                        flags: HeaderFlags::Reply as u32,
1023
0
                        message_size: size_of::<DeviceGetInfo>() as u32,
1024
0
                        ..Default::default()
1025
0
                    },
1026
0
                    argsz: size_of::<DeviceGetInfo>() as u32 - size_of::<Header>() as u32,
1027
0
                    // TODO: Consider non-PCI devices
1028
0
                    flags: VFIO_DEVICE_FLAGS_PCI
1029
0
                        | if self.resettable {
1030
0
                            VFIO_DEVICE_FLAGS_RESET
1031
                        } else {
1032
0
                            0
1033
                        },
1034
0
                    num_regions: self.regions.len() as u32,
1035
0
                    num_irqs: self.irqs.len() as u32,
1036
0
                };
1037
0
                stream
1038
0
                    .write_all(reply.as_slice())
1039
0
                    .map_err(Error::StreamWrite)?;
1040
            }
1041
            Command::DeviceGetRegionInfo => {
1042
0
                let mut cmd = DeviceGetRegionInfo {
1043
0
                    header,
1044
0
                    ..Default::default()
1045
0
                };
1046
0
                stream
1047
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1048
0
                    .map_err(Error::StreamRead)?;
1049
1050
0
                if cmd.region_info.index as usize > self.regions.len() {
1051
0
                    return Err(Error::InvalidInput);
1052
0
                }
1053
0
1054
0
                // TODO: Need to handle region capabilities e.g. sparse regions
1055
0
                let reply = DeviceGetRegionInfo {
1056
0
                    header: Header {
1057
0
                        message_id: cmd.header.message_id,
1058
0
                        command: Command::DeviceGetRegionInfo,
1059
0
                        flags: HeaderFlags::Reply as u32,
1060
0
                        message_size: size_of::<DeviceGetRegionInfo>() as u32,
1061
0
                        ..Default::default()
1062
0
                    },
1063
0
                    region_info: self.regions[cmd.region_info.index as usize],
1064
0
                };
1065
0
                stream
1066
0
                    .write_all(reply.as_slice())
1067
0
                    .map_err(Error::StreamWrite)?;
1068
            }
1069
            Command::GetIrqInfo => {
1070
0
                let mut cmd = GetIrqInfo {
1071
0
                    header,
1072
0
                    ..Default::default()
1073
0
                };
1074
0
                stream
1075
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1076
0
                    .map_err(Error::StreamRead)?;
1077
1078
0
                if cmd.index as usize > self.irqs.len() {
1079
0
                    return Err(Error::InvalidInput);
1080
0
                }
1081
0
1082
0
                let irq = &self.irqs[cmd.index as usize];
1083
0
1084
0
                let reply = GetIrqInfo {
1085
0
                    header: Header {
1086
0
                        message_id: cmd.header.message_id,
1087
0
                        command: Command::GetIrqInfo,
1088
0
                        flags: HeaderFlags::Reply as u32,
1089
0
                        message_size: size_of::<GetIrqInfo>() as u32,
1090
0
                        ..Default::default()
1091
0
                    },
1092
0
                    argsz: (size_of::<GetIrqInfo>() - size_of::<Header>()) as u32,
1093
0
                    index: irq.index,
1094
0
                    flags: irq.flags,
1095
0
                    count: irq.count,
1096
0
                };
1097
0
                stream
1098
0
                    .write_all(reply.as_slice())
1099
0
                    .map_err(Error::StreamWrite)?;
1100
            }
1101
            Command::SetIrqs => {
1102
0
                let mut cmd = SetIrqs {
1103
0
                    header,
1104
0
                    ..Default::default()
1105
0
                };
1106
0
                stream
1107
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1108
0
                    .map_err(Error::StreamRead)?;
1109
1110
0
                if cmd.index as usize > self.irqs.len() {
1111
0
                    return Err(Error::InvalidInput);
1112
0
                }
1113
0
1114
0
                if cmd.flags & VFIO_IRQ_SET_DATA_BOOL > 0 {
1115
0
                    return Err(Error::UnsupportedFeature);
1116
0
                }
1117
0
1118
0
                backend
1119
0
                    .set_irqs(cmd.index, cmd.flags, cmd.start, cmd.count, fds)
1120
0
                    .map_err(Error::Backend)?;
1121
1122
0
                let reply = Header {
1123
0
                    message_id: cmd.header.message_id,
1124
0
                    command: Command::SetIrqs,
1125
0
                    flags: HeaderFlags::Reply as u32,
1126
0
                    message_size: size_of::<Header>() as u32,
1127
0
                    ..Default::default()
1128
0
                };
1129
0
                stream
1130
0
                    .write_all(reply.as_slice())
1131
0
                    .map_err(Error::StreamWrite)?;
1132
            }
1133
            Command::RegionRead => {
1134
0
                let mut cmd = RegionAccess {
1135
0
                    header,
1136
0
                    ..Default::default()
1137
0
                };
1138
0
                stream
1139
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1140
0
                    .map_err(Error::StreamRead)?;
1141
1142
0
                let (region, offset, count) = (cmd.region, cmd.offset, cmd.count);
1143
0
1144
0
                if region as usize > self.regions.len() {
1145
0
                    return Err(Error::InvalidInput);
1146
0
                }
1147
0
1148
0
                let mut data = vec![0u8; count as usize];
1149
0
                backend
1150
0
                    .region_read(region, offset, &mut data)
1151
0
                    .map_err(Error::Backend)?;
1152
1153
0
                let reply = RegionAccess {
1154
0
                    header: Header {
1155
0
                        message_id: cmd.header.message_id,
1156
0
                        command: Command::RegionRead,
1157
0
                        flags: HeaderFlags::Reply as u32,
1158
0
                        message_size: size_of::<RegionAccess>() as u32 + count,
1159
0
                        ..Default::default()
1160
0
                    },
1161
0
                    region,
1162
0
                    offset,
1163
0
                    count,
1164
0
                };
1165
0
                stream
1166
0
                    .write_all(reply.as_slice())
1167
0
                    .map_err(Error::StreamWrite)?;
1168
0
                stream.write_all(&data).map_err(Error::StreamWrite)?;
1169
            }
1170
            Command::RegionWrite => {
1171
0
                let mut cmd = RegionAccess {
1172
0
                    header,
1173
0
                    ..Default::default()
1174
0
                };
1175
0
                stream
1176
0
                    .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
1177
0
                    .map_err(Error::StreamRead)?;
1178
1179
0
                let (region, offset, count) = (cmd.region, cmd.offset, cmd.count);
1180
0
1181
0
                if region as usize > self.regions.len() {
1182
0
                    return Err(Error::InvalidInput);
1183
0
                }
1184
0
1185
0
                let mut data = vec![0u8; count as usize];
1186
0
                stream.read_exact(&mut data).map_err(Error::StreamRead)?;
1187
0
                backend
1188
0
                    .region_write(region, offset, &data)
1189
0
                    .map_err(Error::Backend)?;
1190
1191
0
                let reply = RegionAccess {
1192
0
                    header: Header {
1193
0
                        message_id: cmd.header.message_id,
1194
0
                        command: Command::RegionWrite,
1195
0
                        flags: HeaderFlags::Reply as u32,
1196
0
                        message_size: size_of::<RegionAccess>() as u32,
1197
0
                        ..Default::default()
1198
0
                    },
1199
0
                    region,
1200
0
                    offset,
1201
0
                    count,
1202
0
                };
1203
0
                stream
1204
0
                    .write_all(reply.as_slice())
1205
0
                    .map_err(Error::StreamWrite)?;
1206
            }
1207
            Command::DeviceReset => {
1208
0
                backend.reset().map_err(Error::Backend)?;
1209
0
                let reply = Header {
1210
0
                    message_id: header.message_id,
1211
0
                    command: Command::DeviceReset,
1212
0
                    flags: HeaderFlags::Reply as u32,
1213
0
                    message_size: size_of::<Header>() as u32,
1214
0
                    ..Default::default()
1215
0
                };
1216
0
                stream
1217
0
                    .write_all(reply.as_slice())
1218
0
                    .map_err(Error::StreamWrite)?;
1219
            }
1220
        }
1221
1222
0
        Ok(())
1223
0
    }
1224
1225
0
    pub fn run(&self, backend: &mut dyn ServerBackend) -> Result<(), Error> {
1226
0
        let (mut stream, _) = self.listener.accept().map_err(Error::SocketAccept)?;
1227
1228
0
        loop {
1229
0
            let mut header = Header::default();
1230
0
1231
0
            // The maximum number of FDs that can be sent is 16 so that is
1232
0
            // also the maximum that can be received.
1233
0
            let mut fds = vec![0; 16];
1234
0
            let mut iovecs = vec![iovec {
1235
0
                iov_base: header.as_mut_slice().as_mut_ptr() as *mut c_void,
1236
0
                iov_len: header.as_mut_slice().len(),
1237
0
            }];
1238
            // SAFETY: Safe as the iovect is correctly initialised and fds is big enough
1239
0
            let (bytes, fds_received) = unsafe {
1240
0
                stream
1241
0
                    .recv_with_fds(&mut iovecs, &mut fds)
1242
0
                    .map_err(Error::ReceiveWithFd)?
1243
            };
1244
1245
            // Other end closed connection
1246
0
            if bytes == 0 {
1247
0
                info!("Connection closed");
1248
0
                break;
1249
0
            }
1250
0
1251
0
            fds.resize(fds_received, 0);
1252
0
1253
0
            let fds: Vec<File> = fds
1254
0
                .iter()
1255
0
                // SAFETY: Safe as we have only valid FDs in the vector now
1256
0
                .map(|fd| unsafe { File::from_raw_fd(*fd) })
1257
0
                .collect();
1258
1259
0
            if let Err(e) = self.handle_command(backend, &mut stream, header, fds) {
1260
0
                error!("Error handling command: {:?}: {e}", header.command);
1261
0
                let reply = Header {
1262
0
                    message_id: header.message_id,
1263
0
                    command: header.command,
1264
0
                    flags: HeaderFlags::Error as u32,
1265
0
                    message_size: size_of::<Header>() as u32,
1266
0
                    error: if matches!(e, Error::InvalidInput) {
1267
0
                        EINVAL as u32
1268
                    } else {
1269
0
                        0
1270
                    },
1271
                };
1272
0
                stream
1273
0
                    .write_all(reply.as_slice())
1274
0
                    .map_err(Error::StreamWrite)?;
1275
0
            }
1276
        }
1277
1278
0
        Ok(())
1279
0
    }
1280
}