/src/ztunnel/src/proxy/connection_manager.rs
Line | Count | Source |
1 | | // Copyright Istio Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use crate::proxy::Error; |
16 | | |
17 | | use crate::state::DemandProxyState; |
18 | | use crate::state::ProxyRbacContext; |
19 | | use serde::{Serialize, Serializer}; |
20 | | use std::collections::hash_map::Entry; |
21 | | use std::collections::{HashMap, HashSet}; |
22 | | use std::fmt::Formatter; |
23 | | use std::net::SocketAddr; |
24 | | |
25 | | use crate::drain; |
26 | | use crate::drain::{DrainTrigger, DrainWatcher}; |
27 | | use crate::state::workload::{InboundProtocol, OutboundProtocol}; |
28 | | use std::sync::Arc; |
29 | | use std::sync::RwLock; |
30 | | use tracing::{debug, error, info, warn}; |
31 | | |
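| | /// ConnectionDrain holds the drain trigger/watcher pair shared by all registrations of the same
| | /// inbound connection, plus a count of how many registrations currently share it.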
32 | | struct ConnectionDrain { |
33 | | // TODO: this should almost certainly be changed to a type which exposes counted references.
34 | | // tokio::sync::watch can be subscribed to without taking a write lock, exposes references,
35 | | // and also provides a receiver_count method.
36 | | tx: DrainTrigger, |
37 | | rx: DrainWatcher, |
38 | | count: usize, |
39 | | } |
40 | | |
41 | | impl ConnectionDrain { |
42 | 0 | fn new() -> Self { |
43 | 0 | let (tx, rx) = drain::new(); |
44 | 0 | ConnectionDrain { tx, rx, count: 1 } |
45 | 0 | } |
46 | | |
47 | | /// drain drops the internal reference to rx and then signals drain on the tx |
48 | | // Always inline: this is for convenience so that we don't forget to drop the rx, but there's really no reason it needs to grow the stack.
49 | | #[inline(always)] |
50 | 0 | async fn drain(self) { |
51 | 0 | drop(self.rx); // very important, drain cannot complete if there are outstanding rx
52 | 0 | self.tx |
53 | 0 | .start_drain_and_wait(drain::DrainMode::Immediate) |
54 | 0 | .await; |
55 | 0 | } |
56 | | } |
57 | | |
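| | /// ConnectionManager tracks all open connections handled by the proxy: inbound connections keep a
| | /// drain channel so they can be terminated when policy changes, while outbound connections are
| | /// recorded only for reporting/serialization.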
58 | | #[derive(Clone)] |
59 | | pub struct ConnectionManager { |
60 | | drains: Arc<RwLock<HashMap<InboundConnection, ConnectionDrain>>>, |
61 | | outbound_connections: Arc<RwLock<HashSet<OutboundConnection>>>, |
62 | | } |
63 | | |
64 | | impl std::fmt::Debug for ConnectionManager { |
65 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
66 | 0 | f.debug_struct("ConnectionManager").finish() |
67 | 0 | } |
68 | | } |
69 | | |
70 | | impl Default for ConnectionManager { |
71 | 0 | fn default() -> Self { |
72 | 0 | ConnectionManager { |
73 | 0 | drains: Arc::new(RwLock::new(HashMap::new())), |
74 | 0 | outbound_connections: Arc::new(RwLock::new(HashSet::new())), |
75 | 0 | } |
76 | 0 | } |
77 | | } |
78 | | |
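| | /// ConnectionGuard couples an inbound connection to its ConnectionManager entry. release() un-tracks
| | /// it explicitly; the Drop impl un-tracks it automatically if the drain watcher was never taken.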
79 | | pub struct ConnectionGuard { |
80 | | cm: ConnectionManager, |
81 | | conn: InboundConnection, |
82 | | watch: Option<DrainWatcher>, |
83 | | } |
84 | | |
85 | | // For reasons that I don't fully understand, this uses an obscene amount of stack space when written as a normal function, |
86 | | // amounting to ~1kb overhead per connection. |
87 | | // Inlining it removes this entirely, and the macro ensures we do it consistently across the various areas we use it. |
88 | | #[macro_export] |
89 | | macro_rules! handle_connection { |
90 | | ($connguard:expr, $future:expr) => {{ |
91 | | let watch = $connguard.watcher(); |
92 | | tokio::select! { |
93 | | res = $future => { |
94 | | $connguard.release(); |
95 | | res |
96 | | } |
97 | | _signaled = watch.wait_for_drain() => Err(Error::AuthorizationPolicyLateRejection) |
98 | | } |
99 | | }}; |
100 | | } |
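| |
| | // Usage sketch (hypothetical call site; `state`, `ctx`, and `serve_conn` are stand-ins for the
| | // caller's DemandProxyState, ProxyRbacContext, and connection-handling future):
| | //   let mut guard = connection_manager.assert_rbac(&state, &ctx, None).await?;
| | //   let res = handle_connection!(guard, serve_conn);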
101 | | |
102 | | impl ConnectionGuard { |
103 | 0 | pub fn watcher(&mut self) -> drain::DrainWatcher { |
104 | 0 | self.watch.take().expect("watch cannot be taken twice") |
105 | 0 | } |
106 | 0 | pub fn release(self) { |
107 | 0 | self.cm.release(&self.conn); |
108 | 0 | } |
109 | | } |
110 | | |
111 | | impl Drop for ConnectionGuard { |
112 | 0 | fn drop(&mut self) { |
113 | 0 | if self.watch.is_some() { |
114 | 0 | debug!("rbac context {:?} auto-dropped", &self.conn); |
115 | 0 | self.cm.release(&self.conn) |
116 | 0 | } |
117 | 0 | } |
118 | | } |
119 | | |
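| | /// OutboundConnectionGuard removes the outbound connection from tracking when dropped.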
120 | | pub struct OutboundConnectionGuard { |
121 | | cm: ConnectionManager, |
122 | | conn: OutboundConnection, |
123 | | } |
124 | | |
125 | | impl Drop for OutboundConnectionGuard { |
126 | 0 | fn drop(&mut self) { |
127 | 0 | self.cm.release_outbound(&self.conn) |
128 | 0 | } |
129 | | } |
130 | | |
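| | /// A tracked outbound connection, keyed by source, original/actual destination, and protocol.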
131 | | #[derive(Debug, Clone, Eq, Hash, Ord, PartialEq, PartialOrd, serde::Serialize)] |
132 | | #[serde(rename_all = "camelCase")] |
133 | | pub struct OutboundConnection { |
134 | | pub src: SocketAddr, |
135 | | pub original_dst: SocketAddr, |
136 | | pub actual_dst: SocketAddr, |
137 | | pub protocol: OutboundProtocol, |
138 | | } |
139 | | |
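| | /// Serialized view of an inbound connection, as exposed by the ConnectionManager Serialize impl.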
140 | | #[derive(Debug, Clone, Eq, Hash, Ord, PartialEq, PartialOrd, serde::Serialize)] |
141 | | #[serde(rename_all = "camelCase")] |
142 | | pub struct InboundConnectionDump { |
143 | | pub src: SocketAddr, |
144 | | pub original_dst: Option<String>, |
145 | | pub actual_dst: SocketAddr, |
146 | | pub protocol: InboundProtocol, |
147 | | } |
148 | | |
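| | /// Key used to track an inbound connection: its RBAC context plus the destination service (if any).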
149 | | #[derive(Debug, Clone, Eq, PartialEq, Hash, serde::Serialize)] |
150 | | #[serde(rename_all = "camelCase")] |
151 | | pub struct InboundConnection { |
152 | | #[serde(flatten)] |
153 | | pub ctx: ProxyRbacContext, |
154 | | pub dest_service: Option<String>, |
155 | | } |
156 | | |
157 | | impl ConnectionManager { |
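| | /// track_outbound records an outbound connection and returns a guard that un-tracks it on drop.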
158 | 0 | pub fn track_outbound( |
159 | 0 | &self, |
160 | 0 | src: SocketAddr, |
161 | 0 | original_dst: SocketAddr, |
162 | 0 | actual_dst: SocketAddr, |
163 | 0 | protocol: OutboundProtocol, |
164 | 0 | ) -> OutboundConnectionGuard { |
165 | 0 | let c = OutboundConnection { |
166 | 0 | src, |
167 | 0 | original_dst, |
168 | 0 | actual_dst, |
169 | 0 | protocol, |
170 | 0 | }; |
171 | | |
172 | 0 | self.outbound_connections |
173 | 0 | .write() |
174 | 0 | .expect("mutex") |
175 | 0 | .insert(c.clone()); |
176 | | |
177 | 0 | OutboundConnectionGuard { |
178 | 0 | cm: self.clone(), |
179 | 0 | conn: c, |
180 | 0 | } |
181 | 0 | } |
182 | | |
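| | /// assert_rbac registers the connection, then checks it against current policy; on success it
| | /// returns a ConnectionGuard whose watcher is signalled if a later policy update closes the connection.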
183 | 0 | pub async fn assert_rbac( |
184 | 0 | &self, |
185 | 0 | state: &DemandProxyState, |
186 | 0 | ctx: &ProxyRbacContext, |
187 | 0 | dest_service: Option<String>, |
188 | 0 | ) -> Result<ConnectionGuard, Error> { |
189 | | // Register before our initial assert. This prevents a race if policy changes between assert() and |
190 | | // track() |
191 | 0 | let conn = InboundConnection { |
192 | 0 | ctx: ctx.clone(), |
193 | 0 | dest_service, |
194 | 0 | }; |
195 | 0 | let Some(watch) = self.register(&conn) else { |
196 | 0 | warn!("failed to track {conn:?}"); |
197 | 0 | debug_assert!(false, "failed to track {conn:?}"); |
198 | 0 | return Err(Error::ConnectionTrackingFailed); |
199 | | }; |
200 | 0 | if let Err(err) = state.assert_rbac(ctx).await { |
201 | 0 | self.release(&conn); |
202 | 0 | return Err(Error::AuthorizationPolicyRejection(err)); |
203 | 0 | } |
204 | 0 | Ok(ConnectionGuard { |
205 | 0 | cm: self.clone(), |
206 | 0 | conn, |
207 | 0 | watch: Some(watch), |
208 | 0 | }) |
209 | 0 | } |
210 | | // register a connection with the connection manager |
211 | | // this must be done before a connection can be tracked |
212 | | // allows policy to be asserted against the connection |
213 | | // even if no tasks have a receiver channel yet
214 | 0 | fn register(&self, c: &InboundConnection) -> Option<DrainWatcher> { |
215 | 0 | match self.drains.write().expect("mutex").entry(c.clone()) { |
216 | 0 | Entry::Occupied(mut cd) => { |
217 | 0 | cd.get_mut().count += 1; |
218 | 0 | let rx = cd.get().rx.clone(); |
219 | 0 | Some(rx) |
220 | | } |
221 | 0 | Entry::Vacant(entry) => { |
222 | 0 | let drain = ConnectionDrain::new(); |
223 | 0 | let rx = drain.rx.clone(); |
224 | 0 | entry.insert(drain); |
225 | 0 | Some(rx) |
226 | | } |
227 | | } |
228 | 0 | } |
229 | | |
230 | | // releases tracking on a connection |
231 | | // uses a counter to determine whether other tracked connections remain, so it can retain the tx/rx channels when necessary
232 | 0 | pub fn release(&self, c: &InboundConnection) { |
233 | 0 | let mut drains = self.drains.write().expect("mutex"); |
234 | 0 | if let Some((k, mut v)) = drains.remove_entry(c) |
235 | 0 | && v.count > 1 |
236 | 0 | { |
237 | 0 | // something else is tracking this connection, decrement count but retain |
238 | 0 | v.count -= 1; |
239 | 0 | drains.insert(k, v); |
240 | 0 | } |
241 | 0 | } |
242 | | |
243 | 0 | fn release_outbound(&self, c: &OutboundConnection) { |
244 | 0 | self.outbound_connections.write().expect("mutex").remove(c); |
245 | 0 | } |
246 | | |
247 | | // signal all connections listening to this channel to take action (typically terminate traffic) |
248 | 0 | async fn close(&self, c: &InboundConnection) { |
249 | 0 | let drain = { self.drains.write().expect("mutex").remove(c) }; |
250 | 0 | if let Some(cd) = drain { |
251 | 0 | cd.drain().await; |
252 | | } else { |
253 | | // this is bad; drain was possibly called twice
254 | 0 | error!("requested drain on a Connection which wasn't initialized"); |
255 | | } |
256 | 0 | } |
257 | | |
258 | | // get a list of all connections being tracked |
259 | 0 | pub fn connections(&self) -> Vec<InboundConnection> { |
260 | | // potentially large copy under read lock, could require optimization |
261 | 0 | self.drains.read().expect("mutex").keys().cloned().collect() |
262 | 0 | } |
263 | | } |
264 | | |
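| | /// Snapshot of tracked inbound and outbound connections used to serialize a ConnectionManager.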
265 | | #[derive(serde::Serialize)] |
266 | | struct ConnectionManagerDump { |
267 | | inbound: Vec<InboundConnectionDump>, |
268 | | outbound: Vec<OutboundConnection>, |
269 | | } |
270 | | |
271 | | impl Serialize for ConnectionManager { |
272 | 0 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
273 | 0 | where |
274 | 0 | S: Serializer, |
275 | | { |
276 | 0 | let inbound: Vec<_> = self |
277 | 0 | .drains |
278 | 0 | .read() |
279 | 0 | .expect("mutex") |
280 | 0 | .keys() |
281 | 0 | .cloned() |
282 | 0 | .map(|c| InboundConnectionDump { |
283 | 0 | src: c.ctx.conn.src, |
284 | 0 | original_dst: c.dest_service, |
285 | 0 | actual_dst: c.ctx.conn.dst, |
286 | 0 | protocol: if c.ctx.conn.src_identity.is_some() { |
287 | 0 | InboundProtocol::HBONE |
288 | | } else { |
289 | 0 | InboundProtocol::TCP |
290 | | }, |
291 | 0 | }) |
292 | 0 | .collect(); |
293 | 0 | let outbound: Vec<_> = self |
294 | 0 | .outbound_connections |
295 | 0 | .read() |
296 | 0 | .expect("mutex") |
297 | 0 | .iter() |
298 | 0 | .cloned() |
299 | 0 | .collect(); |
300 | 0 | let dump = ConnectionManagerDump { inbound, outbound }; |
301 | 0 | dump.serialize(serializer) |
302 | 0 | } |
303 | | } |
304 | | |
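| | /// PolicyWatcher subscribes to policy updates and closes any tracked inbound connection that is no
| | /// longer allowed after a change; it runs until the stop signal drains.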
305 | | pub struct PolicyWatcher { |
306 | | state: DemandProxyState, |
307 | | stop: DrainWatcher, |
308 | | connection_manager: ConnectionManager, |
309 | | } |
310 | | |
311 | | impl PolicyWatcher { |
312 | 0 | pub fn new( |
313 | 0 | state: DemandProxyState, |
314 | 0 | stop: DrainWatcher, |
315 | 0 | connection_manager: ConnectionManager, |
316 | 0 | ) -> Self { |
317 | 0 | PolicyWatcher { |
318 | 0 | state, |
319 | 0 | stop, |
320 | 0 | connection_manager, |
321 | 0 | } |
322 | 0 | } |
323 | | |
324 | 0 | pub async fn run(self) { |
325 | 0 | let mut policies_changed = self.state.read().policies.subscribe(); |
326 | | loop { |
327 | 0 | tokio::select! { |
328 | 0 | _ = self.stop.clone().wait_for_drain() => { |
329 | 0 | break; |
330 | | } |
331 | 0 | _ = policies_changed.changed() => { |
332 | 0 | let connections = self.connection_manager.connections(); |
333 | 0 | for conn in connections { |
334 | 0 | if self.state.assert_rbac(&conn.ctx).await.is_err() { |
335 | 0 | self.connection_manager.close(&conn).await; |
336 | 0 | info!("connection {} closed because it's no longer allowed after a policy update", conn.ctx); |
337 | 0 | } |
338 | | } |
339 | | } |
340 | | } |
341 | | } |
342 | 0 | } |
343 | | } |
344 | | |
345 | | #[cfg(test)] |
346 | | mod tests { |
347 | | use crate::drain; |
348 | | use crate::drain::DrainWatcher; |
349 | | use hickory_resolver::config::{ResolverConfig, ResolverOpts}; |
350 | | use prometheus_client::registry::Registry; |
351 | | use std::net::{Ipv4Addr, SocketAddrV4}; |
352 | | use std::sync::{Arc, RwLock}; |
353 | | use std::time::Duration; |
354 | | |
355 | | use crate::rbac::Connection; |
356 | | use crate::state::{DemandProxyState, ProxyState}; |
357 | | use crate::test_helpers::test_default_workload; |
358 | | use crate::xds::ProxyStateUpdateMutator; |
359 | | use crate::xds::istio::security::{Action, Authorization, Scope}; |
360 | | |
361 | | use super::{ConnectionGuard, ConnectionManager, InboundConnection, PolicyWatcher}; |
362 | | |
363 | | #[tokio::test] |
364 | | async fn test_connection_manager_close() { |
365 | | // setup a new ConnectionManager |
366 | | let cm = ConnectionManager::default(); |
367 | | // ensure drains is empty |
368 | | assert_eq!(cm.drains.read().unwrap().len(), 0); |
369 | | assert_eq!(cm.connections().len(), 0); |
370 | | |
371 | | let register = |cm: &ConnectionManager, c: &InboundConnection| { |
372 | | let cm = cm.clone(); |
373 | | let c = c.clone(); |
374 | | |
375 | | let watch = cm.register(&c).unwrap(); |
376 | | ConnectionGuard { |
377 | | cm, |
378 | | conn: c, |
379 | | watch: Some(watch), |
380 | | } |
381 | | }; |
382 | | |
383 | | // track a new connection |
384 | | let rbac_ctx1 = InboundConnection { |
385 | | ctx: crate::state::ProxyRbacContext { |
386 | | conn: Connection { |
387 | | src_identity: None, |
388 | | src: std::net::SocketAddr::new( |
389 | | std::net::Ipv4Addr::new(192, 168, 0, 1).into(), |
390 | | 80, |
391 | | ), |
392 | | dst_network: "".into(), |
393 | | dst: std::net::SocketAddr::V4(SocketAddrV4::new( |
394 | | Ipv4Addr::new(192, 168, 0, 2), |
395 | | 8080, |
396 | | )), |
397 | | }, |
398 | | dest_workload: Arc::new(test_default_workload()), |
399 | | }, |
400 | | dest_service: None, |
401 | | }; |
402 | | |
403 | | // ensure drains contains exactly 1 item |
404 | | let mut close1 = register(&cm, &rbac_ctx1); |
405 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
406 | | assert_eq!(cm.connections().len(), 1); |
407 | | assert_eq!(cm.connections(), vec!(rbac_ctx1.clone())); |
408 | | |
409 | | // setup a second track on the same connection |
410 | | let mut another_close1 = register(&cm, &rbac_ctx1); |
411 | | |
412 | | // ensure drains contains exactly 1 item |
413 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
414 | | assert_eq!(cm.connections().len(), 1); |
415 | | assert_eq!(cm.connections(), vec!(rbac_ctx1.clone())); |
416 | | |
417 | | // track a second connection |
418 | | let rbac_ctx2 = InboundConnection { |
419 | | ctx: crate::state::ProxyRbacContext { |
420 | | conn: Connection { |
421 | | src_identity: None, |
422 | | src: std::net::SocketAddr::new( |
423 | | std::net::Ipv4Addr::new(192, 168, 0, 3).into(), |
424 | | 80, |
425 | | ), |
426 | | dst_network: "".into(), |
427 | | dst: std::net::SocketAddr::V4(SocketAddrV4::new( |
428 | | Ipv4Addr::new(192, 168, 0, 2), |
429 | | 8080, |
430 | | )), |
431 | | }, |
432 | | dest_workload: Arc::new(test_default_workload()), |
433 | | }, |
434 | | dest_service: None, |
435 | | }; |
436 | | |
437 | | let mut close2 = register(&cm, &rbac_ctx2); |
438 | | // ensure drains contains exactly 2 items |
439 | | assert_eq!(cm.drains.read().unwrap().len(), 2); |
440 | | assert_eq!(cm.connections().len(), 2); |
441 | | let mut connections = cm.connections(); |
442 | | // ordering cannot be guaranteed without sorting |
443 | | connections.sort_by(|a, b| a.ctx.conn.cmp(&b.ctx.conn)); |
444 | | assert_eq!(connections, vec![rbac_ctx1.clone(), rbac_ctx2.clone()]); |
445 | | |
446 | | // spawn tasks to assert that we close in a timely manner for rbac_ctx1 |
447 | | tokio::spawn(assert_close(close1.watch.take().unwrap())); |
448 | | tokio::spawn(assert_close(another_close1.watch.take().unwrap())); |
449 | | // close rbac_ctx1 |
450 | | cm.close(&rbac_ctx1).await; |
451 | | // ensure drains contains exactly 1 item |
452 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
453 | | assert_eq!(cm.connections().len(), 1); |
454 | | assert_eq!(cm.connections(), vec!(rbac_ctx2.clone())); |
455 | | |
456 | | // spawn a task to assert that we close in a timely manner for rbac_ctx2 |
457 | | tokio::spawn(assert_close(close2.watch.take().unwrap())); |
458 | | // close rbac_ctx2 |
459 | | cm.close(&rbac_ctx2).await; |
460 | | // assert that drains is empty again |
461 | | assert_eq!(cm.drains.read().unwrap().len(), 0); |
462 | | assert_eq!(cm.connections().len(), 0); |
463 | | } |
464 | | |
465 | | #[tokio::test] |
466 | | async fn test_connection_manager_release() { |
467 | | // setup a new ConnectionManager |
468 | | let cm = ConnectionManager::default(); |
469 | | // ensure drains is empty |
470 | | assert_eq!(cm.drains.read().unwrap().len(), 0); |
471 | | assert_eq!(cm.connections().len(), 0); |
472 | | |
473 | | let register = |cm: &ConnectionManager, c: &InboundConnection| { |
474 | | let cm = cm.clone(); |
475 | | let c = c.clone(); |
476 | | |
477 | | let watch = cm.register(&c).unwrap(); |
478 | | ConnectionGuard { |
479 | | cm, |
480 | | conn: c, |
481 | | watch: Some(watch), |
482 | | } |
483 | | }; |
484 | | |
485 | | // create a new connection |
486 | | let conn1 = InboundConnection { |
487 | | ctx: crate::state::ProxyRbacContext { |
488 | | conn: Connection { |
489 | | src_identity: None, |
490 | | src: std::net::SocketAddr::new( |
491 | | std::net::Ipv4Addr::new(192, 168, 0, 1).into(), |
492 | | 80, |
493 | | ), |
494 | | dst_network: "".into(), |
495 | | dst: std::net::SocketAddr::V4(SocketAddrV4::new( |
496 | | Ipv4Addr::new(192, 168, 0, 2), |
497 | | 8080, |
498 | | )), |
499 | | }, |
500 | | dest_workload: Arc::new(test_default_workload()), |
501 | | }, |
502 | | dest_service: None, |
503 | | }; |
504 | | |
505 | | // create a second connection |
506 | | let conn2 = InboundConnection { |
507 | | ctx: crate::state::ProxyRbacContext { |
508 | | conn: Connection { |
509 | | src_identity: None, |
510 | | src: std::net::SocketAddr::new( |
511 | | std::net::Ipv4Addr::new(192, 168, 0, 3).into(), |
512 | | 80, |
513 | | ), |
514 | | dst_network: "".into(), |
515 | | dst: std::net::SocketAddr::V4(SocketAddrV4::new( |
516 | | Ipv4Addr::new(192, 168, 0, 2), |
517 | | 8080, |
518 | | )), |
519 | | }, |
520 | | dest_workload: Arc::new(test_default_workload()), |
521 | | }, |
522 | | dest_service: None, |
523 | | }; |
524 | | let another_conn1 = conn1.clone(); |
525 | | |
526 | | let close1 = register(&cm, &conn1); |
527 | | let another_close1 = register(&cm, &another_conn1); |
528 | | |
529 | | // ensure drains contains exactly 1 item |
530 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
531 | | assert_eq!(cm.connections().len(), 1); |
532 | | assert_eq!(cm.connections(), vec!(conn1.clone())); |
533 | | |
534 | | // release conn1's clone |
535 | | drop(another_close1); |
536 | | // ensure drains still contains exactly 1 item |
537 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
538 | | assert_eq!(cm.connections().len(), 1); |
539 | | assert_eq!(cm.connections(), vec!(conn1.clone())); |
540 | | |
541 | | let close2 = register(&cm, &conn2); |
542 | | // ensure drains contains exactly 2 items |
543 | | assert_eq!(cm.drains.read().unwrap().len(), 2); |
544 | | assert_eq!(cm.connections().len(), 2); |
545 | | let mut connections = cm.connections(); |
546 | | // ordering cannot be guaranteed without sorting |
547 | | connections.sort_by(|a, b| a.ctx.conn.cmp(&b.ctx.conn)); |
548 | | assert_eq!(connections, vec![conn1.clone(), conn2.clone()]); |
549 | | |
550 | | // release conn1 |
551 | | drop(close1); |
552 | | // ensure drains contains exactly 1 item |
553 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
554 | | assert_eq!(cm.connections().len(), 1); |
555 | | assert_eq!(cm.connections(), vec!(conn2.clone())); |
556 | | |
557 | | // clone conn2 and track it |
558 | | let another_conn2 = conn2.clone(); |
559 | | let another_close2 = register(&cm, &another_conn2); |
560 | | // release tracking on conn2 |
561 | | drop(close2); |
562 | | // ensure drains still contains exactly 1 item |
563 | | assert_eq!(cm.drains.read().unwrap().len(), 1); |
564 | | assert_eq!(cm.connections().len(), 1); |
565 | | assert_eq!(cm.connections(), vec!(another_conn2.clone())); |
566 | | |
567 | | // release tracking on conn2's clone |
568 | | drop(another_close2); |
569 | | // ensure drains contains exactly 0 items |
570 | | assert_eq!(cm.drains.read().unwrap().len(), 0); |
571 | | assert_eq!(cm.connections().len(), 0); |
572 | | } |
573 | | |
574 | | #[tokio::test] |
575 | | async fn test_policy_watcher_lifecycle() { |
576 | | // preamble: setup an environment |
577 | | let state = Arc::new(RwLock::new(ProxyState::new(None))); |
578 | | let mut registry = Registry::default(); |
579 | | let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry)); |
580 | | let dstate = DemandProxyState::new( |
581 | | state.clone(), |
582 | | None, |
583 | | ResolverConfig::default(), |
584 | | ResolverOpts::default(), |
585 | | metrics, |
586 | | ); |
587 | | let connection_manager = ConnectionManager::default(); |
588 | | let (tx, stop) = drain::new(); |
589 | | let state_mutator = ProxyStateUpdateMutator::new_no_fetch(); |
590 | | |
591 | | // clones to move into spawned task |
592 | | let ds = dstate.clone(); |
593 | | let cm = connection_manager.clone(); |
594 | | let pw = PolicyWatcher::new(ds, stop, cm); |
595 | | // spawn a task which watches policy and asserts that the policy watcher stops correctly
596 | | tokio::spawn(async move { |
597 | | let res = tokio::time::timeout(Duration::from_secs(1), pw.run()).await; |
598 | | assert!(res.is_ok()) |
599 | | }); |
600 | | |
601 | | // create a test connection |
602 | | let conn1 = InboundConnection { |
603 | | ctx: crate::state::ProxyRbacContext { |
604 | | conn: Connection { |
605 | | src_identity: None, |
606 | | src: std::net::SocketAddr::new( |
607 | | std::net::Ipv4Addr::new(192, 168, 0, 1).into(), |
608 | | 80, |
609 | | ), |
610 | | dst_network: "".into(), |
611 | | dst: std::net::SocketAddr::V4(SocketAddrV4::new( |
612 | | Ipv4Addr::new(192, 168, 0, 2), |
613 | | 8080, |
614 | | )), |
615 | | }, |
616 | | dest_workload: Arc::new(test_default_workload()), |
617 | | }, |
618 | | dest_service: None, |
619 | | }; |
620 | | // watch the connection |
621 | | let close1 = connection_manager |
622 | | .register(&conn1) |
623 | | .expect("should not be None"); |
624 | | |
625 | | // generate policy which denies everything |
626 | | let auth_name = "allow-nothing"; |
627 | | let auth_namespace = "default"; |
628 | | let auth = Authorization { |
629 | | name: auth_name.into(), |
630 | | action: Action::Deny as i32, |
631 | | scope: Scope::Global as i32, |
632 | | namespace: auth_namespace.into(), |
633 | | rules: vec![], |
634 | | dry_run: false, |
635 | | }; |
636 | | let mut auth_xds_name = String::with_capacity(1 + auth_namespace.len() + auth_name.len()); |
637 | | auth_xds_name.push_str(auth_namespace); |
638 | | auth_xds_name.push('/'); |
639 | | auth_xds_name.push_str(auth_name); |
640 | | |
641 | | // spawn an assertion that our connection close is received |
642 | | tokio::spawn(assert_close(close1)); |
643 | | |
644 | | // this block will scope our guard appropriately |
645 | | { |
646 | | // update our state |
647 | | let mut s = state |
648 | | .write() |
649 | | .expect("test fails if we're unable to get a write lock on state"); |
650 | | let res = |
651 | | state_mutator.insert_authorization(&mut s, auth_xds_name.clone().into(), auth); |
652 | | // assert that the update was OK |
653 | | assert!(res.is_ok()); |
654 | | } // release lock |
655 | | |
656 | | // send the signal which stops policy watcher |
657 | | tx.start_drain_and_wait(drain::DrainMode::Immediate).await; |
658 | | } |
659 | | |
660 | | // small helper to assert that the Watches are working in a timely manner |
661 | | async fn assert_close(c: DrainWatcher) { |
662 | | let result = tokio::time::timeout(Duration::from_secs(1), c.wait_for_drain()).await; |
663 | | assert!(result.is_ok()) |
664 | | } |
665 | | } |