/src/ztunnel/src/inpod/protocol.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright Istio Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use super::istio::zds::{self, Ack, Version, WorkloadRequest, WorkloadResponse, ZdsHello}; |
16 | | use super::{WorkloadData, WorkloadMessage}; |
17 | | use crate::drain::DrainWatcher; |
18 | | use nix::sys::socket::{recvmsg, sendmsg, ControlMessageOwned, MsgFlags}; |
19 | | use prost::Message; |
20 | | use std::io::{IoSlice, IoSliceMut}; |
21 | | use std::os::fd::OwnedFd; |
22 | | use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; |
23 | | use tokio::net::UnixStream; |
24 | | use tracing::{debug, info, warn}; |
25 | | use zds::workload_request::Payload; |
26 | | |
27 | | // Not dead code, but automock confuses Rust otherwise when built with certain targets |
28 | | #[allow(dead_code)] |
29 | | pub struct WorkloadStreamProcessor { |
30 | | stream: UnixStream, |
31 | | drain: DrainWatcher, |
32 | | } |
33 | | |
34 | | #[allow(dead_code)] |
35 | | impl WorkloadStreamProcessor { |
36 | 0 | pub fn new(stream: UnixStream, drain: DrainWatcher) -> Self { |
37 | 0 | WorkloadStreamProcessor { stream, drain } |
38 | 0 | } |
39 | | |
40 | 0 | pub async fn send_hello(&mut self) -> std::io::Result<()> { |
41 | 0 | let r = ZdsHello { |
42 | 0 | version: Version::V1 as i32, |
43 | 0 | }; |
44 | 0 | self.send_msg(r).await |
45 | 0 | } |
46 | | |
47 | 0 | pub async fn send_ack(&mut self) -> std::io::Result<()> { |
48 | 0 | let r = WorkloadResponse { |
49 | 0 | payload: Some(zds::workload_response::Payload::Ack(Ack { |
50 | 0 | error: String::new(), |
51 | 0 | })), |
52 | 0 | }; |
53 | 0 | self.send_msg(r).await |
54 | 0 | } |
55 | 0 | pub async fn send_nack(&mut self, e: anyhow::Error) -> std::io::Result<()> { |
56 | 0 | let r = WorkloadResponse { |
57 | 0 | payload: Some(zds::workload_response::Payload::Ack(Ack { |
58 | 0 | error: e.to_string(), |
59 | 0 | })), |
60 | 0 | }; |
61 | 0 | self.send_msg(r).await |
62 | 0 | } |
63 | | |
64 | 0 | async fn send_msg<T: prost::Message + 'static>(&mut self, r: T) -> std::io::Result<()> { Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse> Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello> |
65 | 0 | let mut buf = Vec::new(); |
66 | 0 | r.encode(&mut buf).unwrap(); |
67 | 0 |
|
68 | 0 | let iov = [IoSlice::new(&buf)]; |
69 | 0 | let raw_fd = self.stream.as_raw_fd(); |
70 | 0 |
|
71 | 0 | // async_io takes care of WouldBlock error, so no need for loop here |
72 | 0 | self.stream |
73 | 0 | .async_io(tokio::io::Interest::WRITABLE, || { |
74 | 0 | sendmsg::<()>(raw_fd, &iov[..], &[], MsgFlags::empty(), None) |
75 | 0 | .map_err(|e| std::io::Error::from_raw_os_error(e as i32)) Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#0}::{closure#0} Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#0}::{closure#0} |
76 | 0 | }) Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#0} Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#0} |
77 | 0 | .await |
78 | 0 | .map(|_| ()) Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0}::{closure#1} Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0}::{closure#1} |
79 | 0 | } Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::WorkloadResponse>::{closure#0} Unexecuted instantiation: <ztunnel::inpod::protocol::WorkloadStreamProcessor>::send_msg::<ztunnel::inpod::istio::zds::ZdsHello>::{closure#0} |
80 | 0 | pub async fn read_message(&self) -> anyhow::Result<Option<WorkloadMessage>> { |
81 | 0 | // TODO: support messages for removing workload |
82 | 0 | let mut buffer = vec![0u8; 1024]; |
83 | 0 | let (flags, maybe_our_fd, len) = { |
84 | 0 | let mut cmsgspace = nix::cmsg_space!(RawFd); |
85 | 0 | let raw_fd = self.stream.as_raw_fd(); |
86 | 0 |
|
87 | 0 | // can't use async_io here as the borrow checker doesn't like it. i get it.. |
88 | 0 | let msgspace_ref = cmsgspace.as_mut(); |
89 | 0 | let mut iov = [IoSliceMut::new(&mut buffer)]; |
90 | | |
91 | 0 | let res = loop { |
92 | 0 | tokio::select! { |
93 | | biased; // check drain first, so we don't read from the socket if we are draining. |
94 | 0 | _ = self.drain.clone().wait_for_drain() => { |
95 | 0 | info!("workload proxy manager: drain requested"); |
96 | 0 | return Ok(None); |
97 | | } |
98 | 0 | res = self.stream.readable() => res, |
99 | 0 | }?; |
100 | | |
101 | 0 | let res = self.stream.try_io(tokio::io::Interest::READABLE, || { |
102 | 0 | recvmsg::<()>( |
103 | 0 | raw_fd, |
104 | 0 | &mut iov, |
105 | 0 | Some(msgspace_ref), |
106 | 0 | MsgFlags::MSG_CMSG_CLOEXEC, |
107 | 0 | ) |
108 | 0 | .map_err(|e| std::io::Error::from_raw_os_error(e as i32)) |
109 | 0 | }); |
110 | 0 | let ok_res = match res { |
111 | 0 | Ok(res) => { |
112 | 0 | if res.bytes == 0 { |
113 | 0 | return Ok(None); |
114 | 0 | } |
115 | 0 | res |
116 | | } |
117 | 0 | Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { |
118 | 0 | continue; |
119 | | } |
120 | 0 | Err(e) => { |
121 | 0 | return Err(e.into()); |
122 | | } |
123 | | }; |
124 | 0 | break ok_res; |
125 | | }; |
126 | | |
127 | | // call maybe_get_fd first (and not get_info_from_data), so that if it fails we will close the FDs. |
128 | 0 | let maybe_our_fd = maybe_get_fd(res.cmsgs()?)?; |
129 | 0 | let flags = res.flags; |
130 | 0 | (flags, maybe_our_fd, res.bytes) |
131 | 0 | }; |
132 | 0 |
|
133 | 0 | get_workload_data(&buffer[..len], maybe_our_fd, flags).map(Some) |
134 | 0 | } |
135 | | } |
136 | | |
137 | 0 | fn get_workload_data( |
138 | 0 | data: &[u8], |
139 | 0 | maybe_our_fd: Option<std::os::fd::OwnedFd>, |
140 | 0 | flags: MsgFlags, |
141 | 0 | ) -> anyhow::Result<WorkloadMessage> { |
142 | 0 | // do all other checks after we parsed fds, so no leaks happen. |
143 | 0 | if flags.contains(MsgFlags::MSG_TRUNC) { |
144 | | // TODO: add metrics |
145 | 0 | anyhow::bail!("received truncated message"); |
146 | 0 | } |
147 | 0 |
|
148 | 0 | if flags.contains(MsgFlags::MSG_CTRUNC) { |
149 | | // TODO: add metrics |
150 | 0 | anyhow::bail!("received truncated message"); |
151 | 0 | } |
152 | | |
153 | 0 | let req = get_info_from_data(data)?; |
154 | 0 | let payload = req.payload.ok_or(anyhow::anyhow!("no payload"))?; |
155 | 0 | match (payload, maybe_our_fd) { |
156 | 0 | (Payload::Add(a), Some(our_netns)) => { |
157 | 0 | let uid = a.uid; |
158 | 0 | Ok(WorkloadMessage::AddWorkload(WorkloadData { |
159 | 0 | netns: our_netns, |
160 | 0 | workload_uid: super::WorkloadUid::new(uid), |
161 | 0 | workload_info: a.workload_info, |
162 | 0 | })) |
163 | | } |
164 | 0 | (Payload::Add(_), None) => Err(anyhow::anyhow!("No control message")), |
165 | | // anything other than Add shouldn't have FDs |
166 | 0 | (_, Some(_)) => Err(anyhow::anyhow!("Unexpected control message")), |
167 | 0 | (Payload::Keep(k), None) => Ok(WorkloadMessage::KeepWorkload(super::WorkloadUid::new( |
168 | 0 | k.uid, |
169 | 0 | ))), |
170 | 0 | (Payload::Del(d), None) => Ok(WorkloadMessage::DelWorkload(super::WorkloadUid::new(d.uid))), |
171 | 0 | (Payload::SnapshotSent(_), None) => Ok(WorkloadMessage::WorkloadSnapshotSent), |
172 | | } |
173 | 0 | } |
174 | | |
175 | 0 | fn get_info_from_data<'a>(data: impl bytes::Buf + 'a) -> anyhow::Result<WorkloadRequest> { |
176 | 0 | Ok(WorkloadRequest::decode(data)?) |
177 | 0 | } |
178 | | |
179 | 0 | fn maybe_get_fd( |
180 | 0 | cmsgs: impl Iterator<Item = ControlMessageOwned>, |
181 | 0 | ) -> anyhow::Result<Option<std::os::fd::OwnedFd>> { |
182 | 0 | let mut our_netns = None; |
183 | 0 | let mut total_fds = 0; |
184 | 0 | for cmsg in cmsgs { |
185 | 0 | match cmsg { |
186 | 0 | ControlMessageOwned::ScmRights(fds) => { |
187 | 0 | let len = fds.len(); |
188 | 0 | total_fds += len; |
189 | 0 | if total_fds != 1 { |
190 | 0 | for fd in fds { |
191 | 0 | // fds in the vector are ours, so own them and drop them so they are closed (prevent resource leak). |
192 | 0 | // Safety: ScmRights returns a list of FDs that we own, so we can safely drop them. |
193 | 0 | std::mem::drop(unsafe { std::os::fd::OwnedFd::from_raw_fd(fd) }); |
194 | 0 | } |
195 | | } else { |
196 | | // Safety: ScmRights returns FDs opened by the kernel for us, so we can |
197 | | // safely own it. |
198 | 0 | our_netns = Some(unsafe { std::os::fd::OwnedFd::from_raw_fd(fds[0]) }) |
199 | | } |
200 | | } |
201 | 0 | u => { |
202 | 0 | warn!("Unexpected control message {:?}", u); |
203 | 0 | continue; |
204 | | } |
205 | | } |
206 | | } |
207 | | // only check for errors once we are done parsing all FDs |
208 | 0 | if total_fds > 1 { |
209 | 0 | anyhow::bail!("Expected 1 FD, got {}", total_fds); |
210 | 0 | } |
211 | | |
212 | | // make sure that we got a netns FD. |
213 | 0 | if let Some(our_netns) = &our_netns { |
214 | | // validate that the fd we got is a netns. This should never happen, and is here |
215 | | // to catch potential bugs in the node agent during development. |
216 | 0 | debug!("Validating netns FD: {:?}", validate_ns(our_netns)); |
217 | 0 | } |
218 | | |
219 | 0 | Ok(our_netns) |
220 | 0 | } |
221 | | |
222 | 0 | fn validate_ns(fd: &OwnedFd) -> anyhow::Result<()> { |
223 | | // on newer kernels we can get the ns type! note that this doesn't work on older kernels. |
224 | | // so an error doesn't mean its not a netns. |
225 | | // #define NSIO 0xb7 |
226 | | const NSIO: u8 = 0xb7; |
227 | | // #define NS_GET_NSTYPE _IO(NSIO, 0x3) |
228 | | const NS_GET_NSTYPE: u8 = 0x3; |
229 | | nix::ioctl_none!(get_ns_type, NSIO, NS_GET_NSTYPE); |
230 | 0 | let nstype = unsafe { get_ns_type(fd.as_raw_fd()) }; |
231 | 0 | if let Ok(nstype) = nstype { |
232 | | // ignore errors in case we are in an old kernel |
233 | 0 | if nstype != nix::libc::CLONE_NEWNET { |
234 | 0 | anyhow::bail!("Unexpected ns type: {:?}", nstype); |
235 | | } else { |
236 | 0 | debug!("FD {:?} type is netns", fd); |
237 | | } |
238 | | } else { |
239 | | // can get ns type, do a different check - that the fd came from the nsfs. |
240 | 0 | let data = nix::sys::statfs::fstatfs(fd)?; |
241 | 0 | let f_type = data.filesystem_type(); |
242 | 0 | if f_type != nix::sys::statfs::PROC_SUPER_MAGIC && f_type != nix::sys::statfs::NSFS_MAGIC { |
243 | 0 | anyhow::bail!("Unexpected FD type for netns: {:?}", f_type); |
244 | 0 | } |
245 | | } |
246 | | |
247 | 0 | debug!("FD {:?} looks like a netns", fd); |
248 | 0 | Ok(()) |
249 | 0 | } |
250 | | |
#[cfg(test)]
mod tests {
    use std::os::fd::OwnedFd;

    use super::super::istio;
    use super::*;
    use crate::inpod::test_helpers::uid;

    use nix::sys::socket::MsgFlags;

    // Helpers to test get_workload_data

    /// Encodes a `WorkloadRequest` wrapping the given payload.
    fn prep_request(p: zds::workload_request::Payload) -> Vec<u8> {
        WorkloadRequest { payload: Some(p) }.encode_to_vec()
    }

    /// Any valid descriptor will do as a stand-in for the netns FD in these tests.
    fn dev_null_fd() -> OwnedFd {
        std::fs::File::open("/dev/null").unwrap().into()
    }

    /// Encoded `Del` request for the workload with uid index 0.
    fn del_request() -> Vec<u8> {
        prep_request(zds::workload_request::Payload::Del(
            istio::zds::DelWorkload {
                uid: uid(0).into_string(),
            },
        ))
    }

    /// An `Add` request accompanied by an FD parses into `AddWorkload`.
    #[test]
    fn test_parse_add_workload() {
        let netns = dev_null_fd();
        let encoded = prep_request(zds::workload_request::Payload::Add(
            istio::zds::AddWorkload {
                uid: uid(0).into_string(),
                ..Default::default()
            },
        ));

        let parsed = get_workload_data(&encoded[..], Some(netns), MsgFlags::empty()).unwrap();

        assert!(matches!(parsed, WorkloadMessage::AddWorkload(_)));
    }

    /// The uid and workload info of an `Add` request survive parsing unchanged.
    #[test]
    fn test_parse_add_workload_with_info() {
        let netns = dev_null_fd();
        let info = zds::WorkloadInfo {
            name: "test".to_string(),
            namespace: "default".to_string(),
            service_account: "defaultsvc".to_string(),
        };
        let expected_uid = uid(0);
        let encoded = prep_request(zds::workload_request::Payload::Add(
            istio::zds::AddWorkload {
                uid: expected_uid.clone().into_string(),
                workload_info: Some(info.clone()),
            },
        ));

        let parsed = get_workload_data(&encoded[..], Some(netns), MsgFlags::empty()).unwrap();

        match parsed {
            WorkloadMessage::AddWorkload(added) => {
                assert_eq!(added.workload_info, Some(info));
                assert_eq!(added.workload_uid, expected_uid);
            }
            _ => panic!("unexpected message"),
        }
    }

    /// A `Del` request must not carry an FD.
    #[test]
    fn test_parse_del_workload_with_fds_fails() {
        let stray_fd = dev_null_fd();
        let encoded = del_request();

        let outcome = get_workload_data(&encoded[..], Some(stray_fd), MsgFlags::empty());
        assert!(outcome.is_err());
    }

    /// A `Del` request without an FD parses into `DelWorkload`.
    #[test]
    fn test_parse_del_workload() {
        let encoded = del_request();

        let parsed = get_workload_data(&encoded[..], None, MsgFlags::empty()).unwrap();
        assert!(matches!(parsed, WorkloadMessage::DelWorkload(_)));
    }
}