Coverage Report

Created: 2025-02-25 06:39

/src/ztunnel/src/inpod/protocol.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use super::istio::zds::{self, Ack, Version, WorkloadRequest, WorkloadResponse, ZdsHello};
16
use super::{WorkloadData, WorkloadMessage};
17
use crate::drain::DrainWatcher;
18
use nix::sys::socket::{recvmsg, sendmsg, ControlMessageOwned, MsgFlags};
19
use prost::Message;
20
use std::io::{IoSlice, IoSliceMut};
21
use std::os::fd::OwnedFd;
22
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
23
use tokio::net::UnixStream;
24
use tracing::{debug, info, warn};
25
use zds::workload_request::Payload;
26
27
// Not dead code, but automock confuses Rust otherwise when built with certain targets
28
#[allow(dead_code)]
29
pub struct WorkloadStreamProcessor {
30
    stream: UnixStream,
31
    drain: DrainWatcher,
32
}
33
34
#[allow(dead_code)]
35
impl WorkloadStreamProcessor {
36
0
    pub fn new(stream: UnixStream, drain: DrainWatcher) -> Self {
37
0
        WorkloadStreamProcessor { stream, drain }
38
0
    }
39
40
0
    pub async fn send_hello(&mut self) -> std::io::Result<()> {
41
0
        let r = ZdsHello {
42
0
            version: Version::V1 as i32,
43
0
        };
44
0
        self.send_msg(r).await
45
0
    }
46
47
0
    pub async fn send_ack(&mut self) -> std::io::Result<()> {
48
0
        let r = WorkloadResponse {
49
0
            payload: Some(zds::workload_response::Payload::Ack(Ack {
50
0
                error: String::new(),
51
0
            })),
52
0
        };
53
0
        self.send_msg(r).await
54
0
    }
55
0
    pub async fn send_nack(&mut self, e: anyhow::Error) -> std::io::Result<()> {
56
0
        let r = WorkloadResponse {
57
0
            payload: Some(zds::workload_response::Payload::Ack(Ack {
58
0
                error: e.to_string(),
59
0
            })),
60
0
        };
61
0
        self.send_msg(r).await
62
0
    }
63
64
0
    async fn send_msg<T: prost::Message + 'static>(&mut self, r: T) -> std::io::Result<()> {
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>
65
0
        let mut buf = Vec::new();
66
0
        r.encode(&mut buf).unwrap();
67
0
68
0
        let iov = [IoSlice::new(&buf)];
69
0
        let raw_fd = self.stream.as_raw_fd();
70
0
71
0
        // async_io takes care of WouldBlock error, so no need for loop here
72
0
        self.stream
73
0
            .async_io(tokio::io::Interest::WRITABLE, || {
74
0
                sendmsg::<()>(raw_fd, &iov[..], &[], MsgFlags::empty(), None)
75
0
                    .map_err(|e| std::io::Error::from_raw_os_error(e as i32))
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#0}::{closure#0}
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#0}::{closure#0}
76
0
            })
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#0}
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#0}
77
0
            .await
78
0
            .map(|_| ())
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#1}
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#1}
79
0
    }
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}
Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}
80
0
    pub async fn read_message(&self) -> anyhow::Result<Option<WorkloadMessage>> {
81
0
        // TODO: support messages for removing workload
82
0
        let mut buffer = vec![0u8; 1024];
83
0
        let (flags, maybe_our_fd, len) = {
84
0
            let mut cmsgspace = nix::cmsg_space!(RawFd);
85
0
            let raw_fd = self.stream.as_raw_fd();
86
0
87
0
            // can't use async_io here because the borrow checker rejects it, so wait for readability and call try_io in a loop instead.
88
0
            let msgspace_ref = cmsgspace.as_mut();
89
0
            let mut iov = [IoSliceMut::new(&mut buffer)];
90
91
0
            let res = loop {
92
0
                tokio::select! {
93
                    biased; // check drain first, so we don't read from the socket if we are draining.
94
0
                    _ =   self.drain.clone().wait_for_drain() => {
95
0
                        info!("workload proxy manager: drain requested");
96
0
                        return Ok(None);
97
                    }
98
0
                    res =  self.stream.readable() => res,
99
0
                }?;
100
101
0
                let res = self.stream.try_io(tokio::io::Interest::READABLE, || {
102
0
                    recvmsg::<()>(
103
0
                        raw_fd,
104
0
                        &mut iov,
105
0
                        Some(msgspace_ref),
106
0
                        MsgFlags::MSG_CMSG_CLOEXEC,
107
0
                    )
108
0
                    .map_err(|e| std::io::Error::from_raw_os_error(e as i32))
109
0
                });
110
0
                let ok_res = match res {
111
0
                    Ok(res) => {
112
0
                        if res.bytes == 0 {
113
0
                            return Ok(None);
114
0
                        }
115
0
                        res
116
                    }
117
0
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
118
0
                        continue;
119
                    }
120
0
                    Err(e) => {
121
0
                        return Err(e.into());
122
                    }
123
                };
124
0
                break ok_res;
125
            };
126
127
            // call maybe_get_fd first (and not get_info_from_data), so that if parsing the data fails the received FDs are already owned and get closed.
128
0
            let maybe_our_fd = maybe_get_fd(res.cmsgs()?)?;
129
0
            let flags = res.flags;
130
0
            (flags, maybe_our_fd, res.bytes)
131
0
        };
132
0
133
0
        get_workload_data(&buffer[..len], maybe_our_fd, flags).map(Some)
134
0
    }
135
}
136
137
0
fn get_workload_data(
138
0
    data: &[u8],
139
0
    maybe_our_fd: Option<std::os::fd::OwnedFd>,
140
0
    flags: MsgFlags,
141
0
) -> anyhow::Result<WorkloadMessage> {
142
0
    // do all other checks after we parsed fds, so no leaks happen.
143
0
    if flags.contains(MsgFlags::MSG_TRUNC) {
144
        // TODO: add metrics
145
0
        anyhow::bail!("received truncated message");
146
0
    }
147
0
148
0
    if flags.contains(MsgFlags::MSG_CTRUNC) {
149
        // TODO: add metrics
150
0
        anyhow::bail!("received truncated message");
151
0
    }
152
153
0
    let req = get_info_from_data(data)?;
154
0
    let payload = req.payload.ok_or(anyhow::anyhow!("no payload"))?;
155
0
    match (payload, maybe_our_fd) {
156
0
        (Payload::Add(a), Some(our_netns)) => {
157
0
            let uid = a.uid;
158
0
            Ok(WorkloadMessage::AddWorkload(WorkloadData {
159
0
                netns: our_netns,
160
0
                workload_uid: super::WorkloadUid::new(uid),
161
0
                workload_info: a.workload_info,
162
0
            }))
163
        }
164
0
        (Payload::Add(_), None) => Err(anyhow::anyhow!("No control message")),
165
        // anything other than Add shouldn't have FDs
166
0
        (_, Some(_)) => Err(anyhow::anyhow!("Unexpected control message")),
167
0
        (Payload::Keep(k), None) => Ok(WorkloadMessage::KeepWorkload(super::WorkloadUid::new(
168
0
            k.uid,
169
0
        ))),
170
0
        (Payload::Del(d), None) => Ok(WorkloadMessage::DelWorkload(super::WorkloadUid::new(d.uid))),
171
0
        (Payload::SnapshotSent(_), None) => Ok(WorkloadMessage::WorkloadSnapshotSent),
172
    }
173
0
}
174
175
0
fn get_info_from_data<'a>(data: impl bytes::Buf + 'a) -> anyhow::Result<WorkloadRequest> {
176
0
    Ok(WorkloadRequest::decode(data)?)
177
0
}
178
179
0
fn maybe_get_fd(
180
0
    cmsgs: impl Iterator<Item = ControlMessageOwned>,
181
0
) -> anyhow::Result<Option<std::os::fd::OwnedFd>> {
182
0
    let mut our_netns = None;
183
0
    let mut total_fds = 0;
184
0
    for cmsg in cmsgs {
185
0
        match cmsg {
186
0
            ControlMessageOwned::ScmRights(fds) => {
187
0
                let len = fds.len();
188
0
                total_fds += len;
189
0
                if total_fds != 1 {
190
0
                    for fd in fds {
191
0
                        // fds in the vector are ours, so own them and drop them so they are closed (prevent resource leak).
192
0
                        // Safety: ScmRights returns a list of FDs that we own, so we can safely drop them.
193
0
                        std::mem::drop(unsafe { std::os::fd::OwnedFd::from_raw_fd(fd) });
194
0
                    }
195
                } else {
196
                    // Safety: ScmRights returns FDs opened by the kernel for us, so we can
197
                    // safely own it.
198
0
                    our_netns = Some(unsafe { std::os::fd::OwnedFd::from_raw_fd(fds[0]) })
199
                }
200
            }
201
0
            u => {
202
0
                warn!("Unexpected control message {:?}", u);
203
0
                continue;
204
            }
205
        }
206
    }
207
    // only check for errors once we are done parsing all FDs
208
0
    if total_fds > 1 {
209
0
        anyhow::bail!("Expected 1 FD, got {}", total_fds);
210
0
    }
211
212
    // make sure that we got a netns FD.
213
0
    if let Some(our_netns) = &our_netns {
214
        // validate that the fd we got is a netns. A failure should never happen; the check is here
215
        // to catch potential bugs in the node agent during development.
216
0
        debug!("Validating netns FD: {:?}", validate_ns(our_netns));
217
0
    }
218
219
0
    Ok(our_netns)
220
0
}
221
222
0
fn validate_ns(fd: &OwnedFd) -> anyhow::Result<()> {
223
    // on newer kernels we can get the ns type! note that this doesn't work on older kernels.
224
    // so an error doesn't mean it's not a netns.
225
    // #define NSIO 0xb7
226
    const NSIO: u8 = 0xb7;
227
    // #define NS_GET_NSTYPE    _IO(NSIO, 0x3)
228
    const NS_GET_NSTYPE: u8 = 0x3;
229
    nix::ioctl_none!(get_ns_type, NSIO, NS_GET_NSTYPE);
230
0
    let nstype = unsafe { get_ns_type(fd.as_raw_fd()) };
231
0
    if let Ok(nstype) = nstype {
232
        // ignore errors in case we are in an old kernel
233
0
        if nstype != nix::libc::CLONE_NEWNET {
234
0
            anyhow::bail!("Unexpected ns type: {:?}", nstype);
235
        } else {
236
0
            debug!("FD {:?} type is netns", fd);
237
        }
238
    } else {
239
        // can't get ns type, do a different check - that the fd came from the nsfs.
240
0
        let data = nix::sys::statfs::fstatfs(fd)?;
241
0
        let f_type = data.filesystem_type();
242
0
        if f_type != nix::sys::statfs::PROC_SUPER_MAGIC && f_type != nix::sys::statfs::NSFS_MAGIC {
243
0
            anyhow::bail!("Unexpected FD type for netns: {:?}", f_type);
244
0
        }
245
    }
246
247
0
    debug!("FD {:?} looks like a netns", fd);
248
0
    Ok(())
249
0
}
250
251
#[cfg(test)]
252
mod tests {
253
    use std::os::fd::OwnedFd;
254
255
    use super::super::istio;
256
    use super::*;
257
    use crate::inpod::test_helpers::uid;
258
259
    use nix::sys::socket::MsgFlags;
260
    // Helpers to test get_workload_data
261
262
    fn prep_request(p: zds::workload_request::Payload) -> Vec<u8> {
263
        let r = WorkloadRequest { payload: Some(p) };
264
        r.encode_to_vec()
265
    }
266
267
    #[test]
268
    fn test_parse_add_workload() {
269
        let owned_fd: OwnedFd = std::fs::File::open("/dev/null").unwrap().into();
270
        let flags = MsgFlags::empty();
271
        let data = prep_request(zds::workload_request::Payload::Add(
272
            istio::zds::AddWorkload {
273
                uid: uid(0).into_string(),
274
                ..Default::default()
275
            },
276
        ));
277
278
        let m = get_workload_data(&data[..], Some(owned_fd), flags).unwrap();
279
280
        assert!(matches!(m, WorkloadMessage::AddWorkload(_)));
281
    }
282
283
    #[test]
284
    fn test_parse_add_workload_with_info() {
285
        let owned_fd: OwnedFd = std::fs::File::open("/dev/null").unwrap().into();
286
        let flags = MsgFlags::empty();
287
        let wi = zds::WorkloadInfo {
288
            name: "test".to_string(),
289
            namespace: "default".to_string(),
290
            service_account: "defaultsvc".to_string(),
291
        };
292
        let uid = uid(0);
293
        let data = prep_request(zds::workload_request::Payload::Add(
294
            istio::zds::AddWorkload {
295
                uid: uid.clone().into_string(),
296
                workload_info: Some(wi.clone()),
297
            },
298
        ));
299
300
        let m = get_workload_data(&data[..], Some(owned_fd), flags).unwrap();
301
302
        match m {
303
            WorkloadMessage::AddWorkload(data) => {
304
                assert_eq!(data.workload_info, Some(wi));
305
                assert_eq!(data.workload_uid, uid);
306
            }
307
            _ => panic!("unexpected message"),
308
        }
309
    }
310
311
    #[test]
312
    fn test_parse_del_workload_with_fds_fails() {
313
        let owned_fd: OwnedFd = std::fs::File::open("/dev/null").unwrap().into();
314
        let flags = MsgFlags::empty();
315
        let data = prep_request(zds::workload_request::Payload::Del(
316
            istio::zds::DelWorkload {
317
                uid: uid(0).into_string(),
318
            },
319
        ));
320
321
        let res = get_workload_data(&data[..], Some(owned_fd), flags);
322
        assert!(res.is_err());
323
    }
324
325
    #[test]
326
    fn test_parse_del_workload() {
327
        let flags = MsgFlags::empty();
328
        let data = prep_request(zds::workload_request::Payload::Del(
329
            istio::zds::DelWorkload {
330
                uid: uid(0).into_string(),
331
            },
332
        ));
333
334
        let res = get_workload_data(&data[..], None, flags).unwrap();
335
        assert!(matches!(res, WorkloadMessage::DelWorkload(_)));
336
    }
337
}