/src/ztunnel/src/proxy.rs
Line | Count | Source |
1 | | // Copyright Istio Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::fmt::Debug; |
16 | | use std::fs::File; |
17 | | use std::io::Read; |
18 | | use std::net::{IpAddr, SocketAddr}; |
19 | | use std::sync::Arc; |
20 | | use std::time::Duration; |
21 | | use std::{fmt, io}; |
22 | | |
23 | | use hickory_proto::ProtoError; |
24 | | |
25 | | use crate::strng::Strng; |
26 | | use rand::Rng; |
27 | | use socket2::TcpKeepalive; |
28 | | use tokio::net::{TcpListener, TcpSocket, TcpStream}; |
29 | | use tokio::time::timeout; |
30 | | use tracing::{Instrument, debug, trace, warn}; |
31 | | |
32 | | use inbound::Inbound; |
33 | | pub use metrics::*; |
34 | | |
35 | | use crate::identity::{Identity, SecretManager}; |
36 | | |
37 | | use crate::dns::resolver::Resolver; |
38 | | use crate::drain::DrainWatcher; |
39 | | use crate::proxy::connection_manager::{ConnectionManager, PolicyWatcher}; |
40 | | use crate::proxy::inbound_passthrough::InboundPassthrough; |
41 | | use crate::proxy::outbound::Outbound; |
42 | | use crate::proxy::socks5::Socks5; |
43 | | use crate::rbac::Connection; |
44 | | use crate::state::service::{Service, ServiceDescription}; |
45 | | use crate::state::workload::address::Address; |
46 | | use crate::state::workload::{GatewayAddress, Workload}; |
47 | | use crate::state::{DemandProxyState, WorkloadInfo}; |
48 | | use crate::{config, identity, socket, tls}; |
49 | | |
50 | | pub mod connection_manager; |
51 | | pub mod inbound; |
52 | | |
53 | | mod h2; |
54 | | mod inbound_passthrough; |
55 | | #[allow(non_camel_case_types)] |
56 | | pub mod metrics; |
57 | | mod outbound; |
58 | | pub mod pool; |
59 | | mod socks5; |
60 | | pub mod util; |
61 | | |
62 | | pub trait SocketFactory { |
63 | | fn new_tcp_v4(&self) -> std::io::Result<TcpSocket>; |
64 | | |
65 | | fn new_tcp_v6(&self) -> std::io::Result<TcpSocket>; |
66 | | |
67 | | fn tcp_bind(&self, addr: SocketAddr) -> std::io::Result<socket::Listener>; |
68 | | |
69 | | fn udp_bind(&self, addr: SocketAddr) -> std::io::Result<tokio::net::UdpSocket>; |
70 | | |
71 | | fn ipv6_enabled_localhost(&self) -> std::io::Result<bool>; |
72 | | } |
73 | | |
74 | | #[derive(Clone, Copy, Default)] |
75 | | pub struct DefaultSocketFactory(pub config::SocketConfig); |
76 | | |
77 | | impl SocketFactory for DefaultSocketFactory { |
78 | 0 | fn new_tcp_v4(&self) -> std::io::Result<TcpSocket> { |
79 | 0 | TcpSocket::new_v4().and_then(|s| { |
80 | 0 | self.setup_socket(&s)?; |
81 | 0 | Ok(s) |
82 | 0 | }) |
83 | 0 | } |
84 | | |
85 | 0 | fn new_tcp_v6(&self) -> std::io::Result<TcpSocket> { |
86 | 0 | TcpSocket::new_v6().and_then(|s| { |
87 | 0 | self.setup_socket(&s)?; |
88 | 0 | Ok(s) |
89 | 0 | }) |
90 | 0 | } |
91 | | |
92 | 0 | fn tcp_bind(&self, addr: SocketAddr) -> std::io::Result<socket::Listener> { |
93 | 0 | let std_sock = std::net::TcpListener::bind(addr)?; |
94 | 0 | std_sock.set_nonblocking(true)?; |
95 | 0 | TcpListener::from_std(std_sock).map(socket::Listener::new) |
96 | 0 | } |
97 | | |
98 | 0 | fn udp_bind(&self, addr: SocketAddr) -> std::io::Result<tokio::net::UdpSocket> { |
99 | 0 | let std_sock = std::net::UdpSocket::bind(addr)?; |
100 | 0 | std_sock.set_nonblocking(true)?; |
101 | 0 | tokio::net::UdpSocket::from_std(std_sock) |
102 | 0 | } |
103 | | |
104 | 0 | fn ipv6_enabled_localhost(&self) -> io::Result<bool> { |
105 | 0 | ipv6_enabled_on_localhost() |
106 | 0 | } |
107 | | } |
108 | | |
109 | | impl DefaultSocketFactory { |
110 | 0 | fn setup_socket(&self, s: &TcpSocket) -> io::Result<()> { |
111 | 0 | s.set_nodelay(true)?; |
112 | 0 | let cfg = self.0; |
113 | 0 | if cfg.keepalive_enabled { |
114 | 0 | let ka = TcpKeepalive::new() |
115 | 0 | .with_time(cfg.keepalive_time) |
116 | 0 | .with_retries(cfg.keepalive_retries) |
117 | 0 | .with_interval(cfg.keepalive_interval); |
118 | 0 | let res = socket2::SockRef::from(&s).set_tcp_keepalive(&ka); |
119 | 0 | tracing::trace!("set keepalive: {:?}", res); |
120 | 0 | } |
121 | 0 | if cfg.user_timeout_enabled { |
122 | | // https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/ |
123 | | // TCP_USER_TIMEOUT = TCP_KEEPIDLE + TCP_KEEPINTVL * TCP_KEEPCNT. |
124 | 0 | let ut = cfg.keepalive_time + cfg.keepalive_retries * cfg.keepalive_interval; |
125 | 0 | let res = socket2::SockRef::from(&s).set_tcp_user_timeout(Some(ut)); |
126 | 0 | tracing::trace!("set user timeout: {:?}", res); |
127 | 0 | } |
128 | 0 | Ok(()) |
129 | 0 | } |
130 | | } |
131 | | |
132 | | pub struct MarkSocketFactory { |
133 | | pub inner: DefaultSocketFactory, |
134 | | pub mark: u32, |
135 | | } |
136 | | |
137 | | impl SocketFactory for MarkSocketFactory { |
138 | 0 | fn new_tcp_v4(&self) -> io::Result<TcpSocket> { |
139 | 0 | self.inner.new_tcp_v4().and_then(|s| { |
140 | 0 | socket::set_mark(&s, self.mark)?; |
141 | 0 | Ok(s) |
142 | 0 | }) |
143 | 0 | } |
144 | | |
145 | 0 | fn new_tcp_v6(&self) -> io::Result<TcpSocket> { |
146 | 0 | self.inner.new_tcp_v6().and_then(|s| { |
147 | 0 | socket::set_mark(&s, self.mark)?; |
148 | 0 | Ok(s) |
149 | 0 | }) |
150 | 0 | } |
151 | | |
152 | 0 | fn tcp_bind(&self, addr: SocketAddr) -> io::Result<socket::Listener> { |
153 | 0 | self.inner.tcp_bind(addr) |
154 | 0 | } |
155 | | |
156 | 0 | fn udp_bind(&self, addr: SocketAddr) -> io::Result<tokio::net::UdpSocket> { |
157 | 0 | self.inner.udp_bind(addr) |
158 | 0 | } |
159 | | |
160 | 0 | fn ipv6_enabled_localhost(&self) -> io::Result<bool> { |
161 | 0 | self.inner.ipv6_enabled_localhost() |
162 | 0 | } |
163 | | } |
164 | | |
165 | | pub struct Proxy { |
166 | | inbound: Inbound, |
167 | | inbound_passthrough: InboundPassthrough, |
168 | | outbound: Outbound, |
169 | | socks5: Option<Socks5>, |
170 | | policy_watcher: PolicyWatcher, |
171 | | } |
172 | | |
173 | | pub struct LocalWorkloadInformation { |
174 | | wi: Arc<WorkloadInfo>, |
175 | | state: DemandProxyState, |
176 | | // full_cert_manager gives access to the full SecretManager. This MUST only be given restricted |
177 | | // access to the appropriate certificates |
178 | | full_cert_manager: Arc<SecretManager>, |
179 | | } |
180 | | |
181 | | impl LocalWorkloadInformation { |
182 | 0 | pub fn new( |
183 | 0 | wi: Arc<WorkloadInfo>, |
184 | 0 | state: DemandProxyState, |
185 | 0 | cert_manager: Arc<SecretManager>, |
186 | 0 | ) -> LocalWorkloadInformation { |
187 | 0 | LocalWorkloadInformation { |
188 | 0 | wi, |
189 | 0 | state, |
190 | 0 | full_cert_manager: cert_manager, |
191 | 0 | } |
192 | 0 | } |
193 | | |
194 | 0 | pub async fn get_workload(&self) -> Result<Arc<Workload>, Error> { |
195 | 0 | get_workload(&self.state, self.wi.clone()).await |
196 | 0 | } |
197 | | |
198 | 0 | pub async fn fetch_certificate( |
199 | 0 | &self, |
200 | 0 | ) -> Result<Arc<tls::WorkloadCertificate>, identity::Error> { |
201 | | // We don't know the trust domain until we get the workload from XDS, so fetch that |
202 | 0 | let wl = self |
203 | 0 | .get_workload() |
204 | 0 | .await |
205 | 0 | .map_err(|_| identity::Error::UnknownWorkload(self.workload_info()))?; |
206 | 0 | let id = &Identity::Spiffe { |
207 | 0 | trust_domain: wl.trust_domain.clone(), |
208 | 0 | namespace: (&self.wi.namespace).into(), |
209 | 0 | service_account: (&self.wi.service_account).into(), |
210 | 0 | }; |
211 | 0 | self.full_cert_manager.fetch_certificate(id).await |
212 | 0 | } |
213 | | |
214 | 0 | pub fn workload_info(&self) -> Arc<WorkloadInfo> { |
215 | 0 | self.wi.clone() |
216 | 0 | } |
217 | | |
218 | 0 | pub fn as_fetcher(&self) -> Arc<LocalWorkloadFetcher> { |
219 | 0 | LocalWorkloadFetcher::new(self.wi.clone(), self.state.clone()) |
220 | 0 | } |
221 | | } |
222 | | |
223 | | /// LocalWorkloadFetcher is essentially LocalWorkloadInformation without CA access. |
224 | | /// This is used to down-scope the LocalWorkloadInformation for components who should not have access |
225 | | /// to certificates. |
226 | | pub struct LocalWorkloadFetcher { |
227 | | wi: Arc<WorkloadInfo>, |
228 | | state: DemandProxyState, |
229 | | } |
230 | | |
231 | | impl LocalWorkloadFetcher { |
232 | 0 | pub fn new(wi: Arc<WorkloadInfo>, state: DemandProxyState) -> Arc<Self> { |
233 | 0 | Arc::new(LocalWorkloadFetcher { wi, state }) |
234 | 0 | } |
235 | 0 | pub async fn get_workload(&self) -> Result<Arc<Workload>, Error> { |
236 | 0 | get_workload(&self.state, self.wi.clone()).await |
237 | 0 | } |
238 | | } |
239 | | |
240 | 0 | async fn get_workload( |
241 | 0 | state: &DemandProxyState, |
242 | 0 | wi: Arc<WorkloadInfo>, |
243 | 0 | ) -> Result<Arc<Workload>, Error> { |
244 | 0 | state |
245 | 0 | .wait_for_workload(&wi, Duration::from_secs(5)) |
246 | 0 | .await |
247 | 0 | .ok_or_else(|| Error::UnknownSourceWorkload(wi.clone())) |
248 | 0 | } |
249 | | |
250 | | #[derive(Clone)] |
251 | | pub(super) struct ProxyInputs { |
252 | | cfg: Arc<config::Config>, |
253 | | connection_manager: ConnectionManager, |
254 | | pub state: DemandProxyState, |
255 | | metrics: Arc<Metrics>, |
256 | | socket_factory: Arc<dyn SocketFactory + Send + Sync>, |
257 | | local_workload_information: Arc<LocalWorkloadInformation>, |
258 | | resolver: Option<Arc<dyn Resolver + Send + Sync>>, |
259 | | // If true, inbound connections created with these inputs will not attempt to preserve the original source IP. |
260 | | pub disable_inbound_freebind: bool, |
261 | | } |
262 | | |
263 | | #[allow(clippy::too_many_arguments)] |
264 | | impl ProxyInputs { |
265 | 0 | pub fn new( |
266 | 0 | cfg: Arc<config::Config>, |
267 | 0 | connection_manager: ConnectionManager, |
268 | 0 | state: DemandProxyState, |
269 | 0 | metrics: Arc<Metrics>, |
270 | 0 | socket_factory: Arc<dyn SocketFactory + Send + Sync>, |
271 | 0 | resolver: Option<Arc<dyn Resolver + Send + Sync>>, |
272 | 0 | local_workload_information: Arc<LocalWorkloadInformation>, |
273 | 0 | disable_inbound_freebind: bool, |
274 | 0 | ) -> Arc<Self> { |
275 | 0 | Arc::new(Self { |
276 | 0 | cfg, |
277 | 0 | state, |
278 | 0 | metrics, |
279 | 0 | connection_manager, |
280 | 0 | socket_factory, |
281 | 0 | local_workload_information, |
282 | 0 | resolver, |
283 | 0 | disable_inbound_freebind, |
284 | 0 | }) |
285 | 0 | } |
286 | | } |
287 | | |
288 | | impl Proxy { |
289 | | #[allow(unused_mut)] |
290 | 0 | pub(super) async fn from_inputs( |
291 | 0 | mut pi: Arc<ProxyInputs>, |
292 | 0 | drain: DrainWatcher, |
293 | 0 | ) -> Result<Self, Error> { |
294 | | // We setup all the listeners first so we can capture any errors that should block startup |
295 | 0 | let inbound = Inbound::new(pi.clone(), drain.clone()).await?; |
296 | | |
297 | | // This exists for `direct` integ tests, no other reason |
298 | | #[cfg(any(test, feature = "testing"))] |
299 | | if pi.cfg.fake_self_inbound { |
300 | | warn!("TEST FAKE - overriding inbound address for test"); |
301 | | let mut old_cfg = (*pi.cfg).clone(); |
302 | | old_cfg.inbound_addr = inbound.address(); |
303 | | let mut new_pi = (*pi).clone(); |
304 | | new_pi.cfg = Arc::new(old_cfg); |
305 | | pi = Arc::new(new_pi); |
306 | | warn!("TEST FAKE: new address is {:?}", pi.cfg.inbound_addr); |
307 | | } |
308 | | |
309 | 0 | let inbound_passthrough = InboundPassthrough::new(pi.clone(), drain.clone()).await?; |
310 | 0 | let outbound = Outbound::new(pi.clone(), drain.clone()).await?; |
311 | 0 | let socks5 = if pi.cfg.socks5_addr.is_some() { |
312 | 0 | let socks5 = Socks5::new(pi.clone(), drain.clone()).await?; |
313 | 0 | Some(socks5) |
314 | | } else { |
315 | 0 | None |
316 | | }; |
317 | 0 | let policy_watcher = |
318 | 0 | PolicyWatcher::new(pi.state.clone(), drain, pi.connection_manager.clone()); |
319 | | |
320 | 0 | Ok(Proxy { |
321 | 0 | inbound, |
322 | 0 | inbound_passthrough, |
323 | 0 | outbound, |
324 | 0 | socks5, |
325 | 0 | policy_watcher, |
326 | 0 | }) |
327 | 0 | } |
328 | | |
329 | 0 | pub async fn run(self) { |
330 | 0 | let mut tasks = vec![ |
331 | 0 | tokio::spawn(self.inbound_passthrough.run().in_current_span()), |
332 | 0 | tokio::spawn(self.policy_watcher.run().in_current_span()), |
333 | 0 | tokio::spawn(self.inbound.run().in_current_span()), |
334 | 0 | tokio::spawn(self.outbound.run().in_current_span()), |
335 | | ]; |
336 | | |
337 | 0 | if let Some(socks5) = self.socks5 { |
338 | 0 | tasks.push(tokio::spawn(socks5.run().in_current_span())); |
339 | 0 | }; |
340 | | |
341 | 0 | futures::future::join_all(tasks).await; |
342 | 0 | } |
343 | | |
344 | 0 | pub fn addresses(&self) -> Addresses { |
345 | | Addresses { |
346 | 0 | outbound: self.outbound.address(), |
347 | 0 | inbound: self.inbound.address(), |
348 | 0 | socks5: self.socks5.as_ref().map(|s| s.address()), |
349 | | } |
350 | 0 | } |
351 | | } |
352 | | |
353 | | #[derive(Copy, Clone)] |
354 | | pub struct Addresses { |
355 | | pub outbound: SocketAddr, |
356 | | pub inbound: SocketAddr, |
357 | | pub socks5: Option<SocketAddr>, |
358 | | } |
359 | | |
360 | | #[derive(Debug, PartialEq, Eq)] |
361 | | pub enum AuthorizationRejectionError { |
362 | | NoWorkload, |
363 | | WorkloadMismatch, |
364 | | ExplicitlyDenied(Strng, Strng), |
365 | | NotAllowed, |
366 | | } |
367 | | impl fmt::Display for AuthorizationRejectionError { |
368 | 0 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
369 | 0 | match self { |
370 | 0 | Self::NoWorkload => write!(fmt, "workload not found"), |
371 | 0 | Self::WorkloadMismatch => write!(fmt, "workload mismatch"), |
372 | 0 | Self::ExplicitlyDenied(a, b) => write!(fmt, "explicitly denied by: {a}/{b}"), |
373 | 0 | Self::NotAllowed => write!(fmt, "allow policies exist, but none allowed"), |
374 | | } |
375 | 0 | } |
376 | | } |
377 | | |
378 | | #[derive(thiserror::Error, Debug)] |
379 | | pub enum Error { |
380 | | #[error("failed to bind to address {0}: {1}")] |
381 | | Bind(SocketAddr, io::Error), |
382 | | |
383 | | #[error("io error: {0}")] |
384 | | Io(#[from] io::Error), |
385 | | |
386 | | #[error("while closing connection: {0}")] |
387 | | ShutdownError(Box<Error>), |
388 | | |
389 | | #[error("connection timed out, maybe a NetworkPolicy is blocking HBONE port 15008: {0}")] |
390 | | MaybeHBONENetworkPolicyError(io::Error), |
391 | | |
392 | | #[error("destination disconnected before all data was written")] |
393 | | BackendDisconnected, |
394 | | #[error("receive: {0}")] |
395 | | ReceiveError(Box<Error>), |
396 | | |
397 | | #[error("client disconnected before all data was written")] |
398 | | ClientDisconnected, |
399 | | #[error("send: {0}")] |
400 | | SendError(Box<Error>), |
401 | | |
402 | | #[error("connection failed: {0}")] |
403 | | ConnectionFailed(io::Error), |
404 | | |
405 | | #[error("connection tracking failed")] |
406 | | ConnectionTrackingFailed, |
407 | | |
408 | | #[error("connection closed due to policy change")] |
409 | | AuthorizationPolicyLateRejection, |
410 | | |
411 | | #[error("connection closed due to policy rejection: {0}")] |
412 | | AuthorizationPolicyRejection(AuthorizationRejectionError), |
413 | | |
414 | | #[error("pool draining")] |
415 | | WorkloadHBONEPoolDraining, |
416 | | |
417 | | #[error("{0}")] |
418 | | Generic(Box<dyn std::error::Error + Send + Sync>), |
419 | | |
420 | | #[error("{0}")] |
421 | | Anyhow(anyhow::Error), |
422 | | |
423 | | #[error("http2 handshake failed: {0}")] |
424 | | Http2Handshake(#[source] ::h2::Error), |
425 | | |
426 | | #[error("h2 failed: {0}")] |
427 | | H2(#[from] ::h2::Error), |
428 | | |
429 | | #[error("http status: {0}")] |
430 | | HttpStatus(http::StatusCode), |
431 | | |
432 | | #[error("expected method CONNECT, got {0}")] |
433 | | NonConnectMethod(String), |
434 | | |
435 | | #[error("invalid CONNECT address {0}")] |
436 | | ConnectAddress(String), |
437 | | |
438 | | #[error("tls error: {0}")] |
439 | | Tls(#[from] tls::Error), |
440 | | |
441 | | #[error("identity error: {0}")] |
442 | | Identity(#[from] identity::Error), |
443 | | |
444 | | #[error("failed to fetch information about local workload: {0}")] |
445 | | UnknownSourceWorkload(Arc<WorkloadInfo>), |
446 | | |
447 | | #[error("unknown waypoint: {0}")] |
448 | | UnknownWaypoint(String), |
449 | | |
450 | | #[error("unknown network gateway: {0}")] |
451 | | UnknownNetworkGateway(String), |
452 | | |
453 | | #[error("no service or workload for hostname: {0}")] |
454 | | NoHostname(String), |
455 | | |
456 | | #[error("no endpoints for workload: {0}")] |
457 | | NoWorkloadEndpoints(String), |
458 | | |
459 | | #[error("no valid authority pseudo header: {0}")] |
460 | | NoValidAuthority(String), |
461 | | |
462 | | #[error("no valid service port in authority header: {0}")] |
463 | | NoValidServicePort(String, u16), |
464 | | |
465 | | #[error("no valid target port for workload: {0}")] |
466 | | NoValidTargetPort(String, u16), |
467 | | |
468 | | #[error("no valid routing destination for workload: {0}")] |
469 | | NoValidDestination(Box<Workload>), |
470 | | |
471 | | #[error("no healthy upstream: {0}")] |
472 | | NoHealthyUpstream(SocketAddr), |
473 | | |
474 | | #[error("no ip addresses were resolved for workload: {0}")] |
475 | | NoResolvedAddresses(String), |
476 | | |
477 | | #[error("requested service {0}:{1} found, but cannot resolve port")] |
478 | | NoPortForServices(String, u16), |
479 | | |
480 | | #[error("requested service {0} found, but has no IP addresses")] |
481 | | NoIPForService(String), |
482 | | |
483 | | #[error("no service for target address: {0}")] |
484 | | NoService(SocketAddr), |
485 | | |
486 | | #[error( |
487 | | "ip addresses were resolved for workload {0}, but valid dns response had no A/AAAA records" |
488 | | )] |
489 | | EmptyResolvedAddresses(String), |
490 | | |
491 | | #[error("attempted recursive call to ourselves")] |
492 | | SelfCall, |
493 | | |
494 | | #[error("no gateway address: {0}")] |
495 | | NoGatewayAddress(Box<Workload>), |
496 | | |
497 | | #[error("unsupported feature: {0}")] |
498 | | UnsupportedFeature(String), |
499 | | |
500 | | #[error("ip mismatch: {0} != {1}")] |
501 | | IPMismatch(IpAddr, IpAddr), |
502 | | |
503 | | #[error("connection failed to drain within the timeout")] |
504 | | DrainTimeOut, |
505 | | #[error("connection closed due to connection drain")] |
506 | | ClosedFromDrain, |
507 | | |
508 | | #[error("dns: {0}")] |
509 | | Dns(#[from] ProtoError), |
510 | | #[error("dns lookup: {0}")] |
511 | | DnsLookup(#[from] hickory_server::authority::LookupError), |
512 | | #[error("dns response had no valid IP addresses")] |
513 | | DnsEmpty, |
514 | | } |
515 | | |
516 | | // Custom TLV for proxy protocol for the identity of the source |
517 | | const PROXY_PROTOCOL_AUTHORITY_TLV: u8 = 0xD0; |
518 | | |
519 | 0 | pub async fn write_proxy_protocol<T>( |
520 | 0 | stream: &mut TcpStream, |
521 | 0 | addresses: T, |
522 | 0 | src_id: Option<Identity>, |
523 | 0 | ) -> io::Result<()> |
524 | 0 | where |
525 | 0 | T: Into<ppp::v2::Addresses> + std::fmt::Debug, |
526 | 0 | { |
527 | | use ppp::v2::{Builder, Command, Protocol, Version}; |
528 | | use tokio::io::AsyncWriteExt; |
529 | | |
530 | | // When the hbone_addr populated from the authority header contains a svc hostname, the address included |
531 | | // with respect to the hbone_addr is the SocketAddr <dst svc IP>:<original dst port>. |
532 | | // This is done since addresses doesn't support hostnames. |
533 | | // See ref https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt |
534 | 0 | debug!("writing proxy protocol addresses: {:?}", addresses); |
535 | 0 | let mut builder = |
536 | 0 | Builder::with_addresses(Version::Two | Command::Proxy, Protocol::Stream, addresses); |
537 | | |
538 | 0 | if let Some(id) = src_id { |
539 | 0 | builder = builder.write_tlv(PROXY_PROTOCOL_AUTHORITY_TLV, id.to_string().as_bytes())?; |
540 | 0 | } |
541 | | |
542 | 0 | let header = builder.build()?; |
543 | 0 | stream.write_all(&header).await |
544 | 0 | } |
545 | | |
546 | | /// Represents a traceparent, as defined by https://www.w3.org/TR/trace-context/ |
547 | | #[derive(Eq, PartialEq)] |
548 | | pub struct TraceParent { |
549 | | version: u8, |
550 | | trace_id: u128, |
551 | | parent_id: u64, |
552 | | flags: u8, |
553 | | } |
554 | | |
555 | | pub const BAGGAGE_HEADER: &str = "baggage"; |
556 | | pub const TRACEPARENT_HEADER: &str = "traceparent"; |
557 | | |
558 | | impl TraceParent { |
559 | 0 | pub fn header(&self) -> hyper::header::HeaderValue { |
560 | 0 | hyper::header::HeaderValue::from_bytes(format!("{self:?}").as_bytes()).unwrap() |
561 | 0 | } |
562 | | } |
563 | | impl TraceParent { |
564 | 0 | fn new() -> Self { |
565 | 0 | let mut rng = rand::rng(); |
566 | 0 | Self { |
567 | 0 | version: 0, |
568 | 0 | trace_id: rng.random(), |
569 | 0 | parent_id: rng.random(), |
570 | 0 | flags: 0, |
571 | 0 | } |
572 | 0 | } |
573 | | } |
574 | | |
575 | | impl fmt::Debug for TraceParent { |
576 | 0 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
577 | 0 | write!( |
578 | 0 | f, |
579 | 0 | "{:02x}-{:032x}-{:016x}-{:02x}", |
580 | | self.version, self.trace_id, self.parent_id, self.flags |
581 | | ) |
582 | 0 | } |
583 | | } |
584 | | |
585 | | impl fmt::Display for TraceParent { |
586 | 0 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
587 | 0 | write!(f, "{:032x}", self.trace_id,) |
588 | 0 | } |
589 | | } |
590 | | |
591 | | impl TryFrom<&str> for TraceParent { |
592 | | type Error = anyhow::Error; |
593 | | |
594 | 0 | fn try_from(value: &str) -> Result<Self, Self::Error> { |
595 | 0 | if value.len() != 55 { |
596 | 0 | anyhow::bail!("traceparent malformed length was {}", value.len()) |
597 | 0 | } |
598 | | |
599 | 0 | let segs: Vec<&str> = value.split('-').collect(); |
600 | | |
601 | | Ok(Self { |
602 | 0 | version: u8::from_str_radix(segs[0], 16)?, |
603 | 0 | trace_id: u128::from_str_radix(segs[1], 16)?, |
604 | 0 | parent_id: u64::from_str_radix(segs[2], 16)?, |
605 | 0 | flags: u8::from_str_radix(segs[3], 16)?, |
606 | | }) |
607 | 0 | } |
608 | | } |
609 | | |
610 | 0 | pub(super) fn maybe_set_transparent( |
611 | 0 | pi: &ProxyInputs, |
612 | 0 | listener: &socket::Listener, |
613 | 0 | ) -> Result<bool, Error> { |
614 | 0 | Ok(match pi.cfg.require_original_source { |
615 | | Some(true) => { |
616 | | // Explicitly enabled. Return error if we cannot set it. |
617 | 0 | listener.set_transparent()?; |
618 | 0 | true |
619 | | } |
620 | | Some(false) => { |
621 | | // Explicitly disabled, don't even attempt to set it. |
622 | 0 | false |
623 | | } |
624 | | None => { |
625 | | // Best effort |
626 | 0 | listener.set_transparent().is_ok() |
627 | | } |
628 | | }) |
629 | 0 | } |
630 | | |
631 | 0 | pub fn get_original_src_from_stream(stream: &TcpStream) -> Option<IpAddr> { |
632 | 0 | stream |
633 | 0 | .peer_addr() |
634 | 0 | .map_or(None, |sa| Some(socket::to_canonical(sa).ip())) |
635 | 0 | } |
636 | | |
637 | | const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); |
638 | | |
639 | 0 | pub async fn freebind_connect( |
640 | 0 | local: Option<IpAddr>, |
641 | 0 | addr: SocketAddr, |
642 | 0 | socket_factory: &(dyn SocketFactory + Send + Sync), |
643 | 0 | ) -> io::Result<TcpStream> { |
644 | 0 | async fn connect( |
645 | 0 | local: Option<IpAddr>, |
646 | 0 | addr: SocketAddr, |
647 | 0 | socket_factory: &(dyn SocketFactory + Send + Sync), |
648 | 0 | ) -> io::Result<TcpStream> { |
649 | 0 | let create_socket = |is_ipv4: bool| { |
650 | 0 | if is_ipv4 { |
651 | 0 | socket_factory.new_tcp_v4() |
652 | | } else { |
653 | 0 | socket_factory.new_tcp_v6() |
654 | | } |
655 | 0 | }; |
656 | | |
657 | 0 | match local { |
658 | | None => { |
659 | 0 | let socket = create_socket(addr.is_ipv4())?; |
660 | 0 | trace!(dest=%addr, "no local address, connect directly"); |
661 | 0 | Ok(socket.connect(addr).await?) |
662 | | } |
663 | | // TODO: Need figure out how to handle case of loadbalancing to itself. |
664 | | // We use ztunnel addr instead, otherwise app side will be confused. |
665 | 0 | Some(src) if src == socket::to_canonical(addr).ip() => { |
666 | 0 | let socket = create_socket(addr.is_ipv4())?; |
667 | 0 | trace!(%src, dest=%addr, "dest and source are the same, connect directly"); |
668 | 0 | Ok(socket.connect(addr).await?) |
669 | | } |
670 | 0 | Some(src) => { |
671 | 0 | let socket = create_socket(src.is_ipv4())?; |
672 | 0 | let local_addr = SocketAddr::new(src, 0); |
673 | 0 | match socket::set_freebind_and_transparent(&socket) { |
674 | 0 | Err(err) => warn!("failed to set freebind: {:?}", err), |
675 | | _ => { |
676 | 0 | if let Err(err) = socket.bind(local_addr) { |
677 | 0 | warn!("failed to bind local addr: {:?}", err) |
678 | 0 | } |
679 | | } |
680 | | }; |
681 | 0 | trace!(%src, dest=%addr, "connect with source IP"); |
682 | 0 | Ok(socket.connect(addr).await?) |
683 | | } |
684 | | } |
685 | 0 | } |
686 | | // Wrap the entire connect function in a timeout |
687 | 0 | timeout(CONNECTION_TIMEOUT, connect(local, addr, socket_factory)) |
688 | 0 | .await |
689 | 0 | .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))? |
690 | 0 | } |
691 | | |
692 | | // guess_inbound_service selects an upstream service for inbound metrics. |
693 | | // There may be many services for a single workload. We find the the first one with an applicable port |
694 | | // as a best guess. |
695 | 0 | pub fn guess_inbound_service( |
696 | 0 | conn: &Connection, |
697 | 0 | for_host_header: &Option<String>, |
698 | 0 | upstream_service: Vec<Arc<Service>>, |
699 | 0 | dest: &Workload, |
700 | 0 | ) -> Option<ServiceDescription> { |
701 | | // First, if the client told us what Service they were reaching, look for that |
702 | | // Note: the set of Services we look for is bounded, so we won't blindly trust bogus info. |
703 | 0 | if let Some(found) = upstream_service |
704 | 0 | .iter() |
705 | 0 | .find(|s| for_host_header.as_deref() == Some(s.hostname.as_ref())) |
706 | 0 | .map(|s| ServiceDescription::from(s.as_ref())) |
707 | | { |
708 | 0 | return Some(found); |
709 | 0 | } |
710 | 0 | let dport = conn.dst.port(); |
711 | 0 | upstream_service |
712 | 0 | .iter() |
713 | 0 | .find(|s| { |
714 | 0 | for (sport, tport) in s.ports.iter() { |
715 | 0 | if tport == &dport { |
716 | | // TargetPort directly matches |
717 | 0 | return true; |
718 | 0 | } |
719 | | // The service itself didn't have a explicit TargetPort match, but an endpoint might. |
720 | | // This happens when there is a named port (in Kubernetes, anyways). |
721 | 0 | if s.endpoints.get(&dest.uid).and_then(|e| e.port.get(sport)) == Some(&dport) { |
722 | | // Named port matched |
723 | 0 | return true; |
724 | 0 | } |
725 | | // no match |
726 | | } |
727 | 0 | false |
728 | 0 | }) |
729 | 0 | .map(|s| ServiceDescription::from(s.as_ref())) |
730 | 0 | } |
731 | | |
732 | | // Checks that the source identiy and address match the upstream's waypoint |
733 | 0 | async fn check_from_waypoint( |
734 | 0 | state: &DemandProxyState, |
735 | 0 | upstream: &Workload, |
736 | 0 | src_identity: Option<&Identity>, |
737 | 0 | src_ip: &IpAddr, |
738 | 0 | ) -> bool { |
739 | 0 | let is_waypoint = |wl: &Workload| { |
740 | 0 | Some(wl.identity()).as_ref() == src_identity && wl.workload_ips.contains(src_ip) |
741 | 0 | }; |
742 | 0 | check_gateway_address(state, upstream.waypoint.as_ref(), is_waypoint).await |
743 | 0 | } |
744 | | |
745 | | // Check if the source's identity matches any workloads that make up the given gateway |
746 | | // TODO: This can be made more accurate by also checking addresses. |
747 | 0 | async fn check_gateway_address<F>( |
748 | 0 | state: &DemandProxyState, |
749 | 0 | gateway_address: Option<&GatewayAddress>, |
750 | 0 | predicate: F, |
751 | 0 | ) -> bool |
752 | 0 | where |
753 | 0 | F: Fn(&Workload) -> bool, |
754 | 0 | { |
755 | 0 | let Some(gateway_address) = gateway_address else { |
756 | 0 | return false; |
757 | | }; |
758 | | |
759 | 0 | match state.fetch_destination(&gateway_address.destination).await { |
760 | 0 | Some(Address::Workload(wl)) => return predicate(wl.as_ref()), |
761 | 0 | Some(Address::Service(svc)) => { |
762 | 0 | for ep in svc.endpoints.iter() { |
763 | | // fetch workloads by workload UID since we may not have an IP for an endpoint (e.g., endpoint is just a hostname) |
764 | 0 | let wl = state.fetch_workload_by_uid(&ep.workload_uid).await; |
765 | 0 | if wl.as_ref().is_some_and(|wl| predicate(wl.as_ref())) { |
766 | 0 | return true; |
767 | 0 | } |
768 | | } |
769 | | } |
770 | 0 | None => {} |
771 | | }; |
772 | | |
773 | 0 | false |
774 | 0 | } |
775 | | |
776 | | const IPV6_DISABLED_LO: &str = "/proc/sys/net/ipv6/conf/lo/disable_ipv6"; |
777 | | |
778 | 0 | fn read_sysctl(key: &str) -> io::Result<String> { |
779 | 0 | let mut file = File::open(key)?; |
780 | 0 | let mut data = String::new(); |
781 | 0 | file.read_to_string(&mut data)?; |
782 | 0 | Ok(data.trim().to_string()) |
783 | 0 | } |
784 | | |
785 | 0 | pub fn ipv6_enabled_on_localhost() -> io::Result<bool> { |
786 | 0 | read_sysctl(IPV6_DISABLED_LO).map(|s| s != "1") |
787 | 0 | } |
788 | | |
789 | 1.23k | pub fn parse_forwarded_host(input: &str) -> Option<String> { |
790 | 1.23k | if !input.is_ascii() { |
791 | 17 | return None; |
792 | 1.22k | } |
793 | 1.22k | input |
794 | 1.22k | .split(';') |
795 | 414k | .find(|part| part.trim().starts_with("host=")) |
796 | 1.22k | .and_then(|host_part| { |
797 | 234 | host_part |
798 | 234 | .trim() |
799 | 234 | .strip_prefix("host=") |
800 | 234 | .map(|h| h.strip_prefix("\"").unwrap_or(h)) |
801 | 234 | .map(|h| h.strip_suffix("\"").unwrap_or(h)) |
802 | 234 | .map(|s| s.to_string()) |
803 | 234 | }) |
804 | 1.22k | .filter(|host| !host.is_empty()) |
805 | 1.23k | } |
806 | | |
807 | | #[derive(Debug, Clone, PartialEq)] |
808 | | pub enum HboneAddress { |
809 | | SocketAddr(SocketAddr), |
810 | | SvcHostname(Strng, u16), |
811 | | } |
812 | | |
813 | | impl HboneAddress { |
814 | 0 | pub fn port(&self) -> u16 { |
815 | 0 | match self { |
816 | 0 | HboneAddress::SocketAddr(s) => s.port(), |
817 | 0 | HboneAddress::SvcHostname(_, p) => *p, |
818 | | } |
819 | 0 | } |
820 | | |
821 | 0 | pub fn ip(&self) -> Option<IpAddr> { |
822 | 0 | match self { |
823 | 0 | HboneAddress::SocketAddr(s) => Some(s.ip()), |
824 | 0 | HboneAddress::SvcHostname(_, _) => None, |
825 | | } |
826 | 0 | } |
827 | | |
828 | 0 | pub fn svc_hostname(&self) -> Option<Strng> { |
829 | 0 | match self { |
830 | 0 | HboneAddress::SocketAddr(_) => None, |
831 | 0 | HboneAddress::SvcHostname(s, _) => Some(s.into()), |
832 | | } |
833 | 0 | } |
834 | | |
835 | 0 | pub fn hostname_addr(&self) -> Option<Strng> { |
836 | 0 | match self { |
837 | 0 | HboneAddress::SocketAddr(_) => None, |
838 | 0 | HboneAddress::SvcHostname(_, _) => Some(Strng::from(self.to_string())), |
839 | | } |
840 | 0 | } |
841 | | } |
842 | | |
843 | | impl std::fmt::Display for HboneAddress { |
844 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
845 | 0 | match self { |
846 | 0 | HboneAddress::SocketAddr(addr) => write!(f, "{addr}"), |
847 | 0 | HboneAddress::SvcHostname(host, port) => write!(f, "{host}:{port}"), |
848 | | } |
849 | 0 | } |
850 | | } |
851 | | |
852 | | impl From<SocketAddr> for HboneAddress { |
853 | 0 | fn from(socket_addr: SocketAddr) -> Self { |
854 | 0 | HboneAddress::SocketAddr(socket_addr) |
855 | 0 | } |
856 | | } |
857 | | |
858 | | impl From<(Strng, u16)> for HboneAddress { |
859 | 0 | fn from(svc_hostname: (Strng, u16)) -> Self { |
860 | 0 | HboneAddress::SvcHostname(svc_hostname.0, svc_hostname.1) |
861 | 0 | } |
862 | | } |
863 | | |
864 | | impl TryFrom<&http::Uri> for HboneAddress { |
865 | | type Error = Error; |
866 | | |
867 | 0 | fn try_from(value: &http::Uri) -> Result<Self, Self::Error> { |
868 | 0 | match value.to_string().parse::<SocketAddr>() { |
869 | 0 | Ok(addr) => Ok(HboneAddress::SocketAddr(addr)), |
870 | | Err(_) => { |
871 | 0 | let hbone_host = value |
872 | 0 | .host() |
873 | 0 | .ok_or_else(|| Error::NoValidAuthority(value.to_string()))?; |
874 | 0 | let hbone_port = value |
875 | 0 | .port_u16() |
876 | 0 | .ok_or_else(|| Error::NoValidAuthority(value.to_string()))?; |
877 | 0 | Ok(HboneAddress::SvcHostname(hbone_host.into(), hbone_port)) |
878 | | } |
879 | | } |
880 | 0 | } |
881 | | } |
882 | | |
883 | | #[cfg(test)] |
884 | | mod tests { |
885 | | use super::*; |
886 | | |
887 | | #[test] |
888 | | fn test_parse_forwarded_host() { |
889 | | let header = "by=identifier;for=identifier;host=example.com;proto=https"; |
890 | | assert_eq!( |
891 | | parse_forwarded_host(header), |
892 | | Some("example.com".to_string()) |
893 | | ); |
894 | | let header = "by=identifier;for=identifier;host=\"example.com\";proto=https"; |
895 | | assert_eq!( |
896 | | parse_forwarded_host(header), |
897 | | Some("example.com".to_string()) |
898 | | ); |
899 | | let header = "by=identifier;for=identifier;proto=https"; |
900 | | assert_eq!(parse_forwarded_host(header), None); |
901 | | let header = "by=identifier;for=identifier;host=;proto=https"; |
902 | | assert_eq!(parse_forwarded_host(header), None); |
903 | | let header = r#"for=for;by=by;host=host;proto="pröto""#; |
904 | | assert_eq!(parse_forwarded_host(header), None); |
905 | | } |
906 | | } |