/rust/registry/src/index.crates.io-1949cf8c6b5b557f/axum-0.8.8/src/response/sse.rs
Line | Count | Source |
1 | | //! Server-Sent Events (SSE) responses. |
2 | | //! |
3 | | //! # Example |
4 | | //! |
5 | | //! ``` |
6 | | //! use axum::{ |
7 | | //! Router, |
8 | | //! routing::get, |
9 | | //! response::sse::{Event, KeepAlive, Sse}, |
10 | | //! }; |
11 | | //! use std::{time::Duration, convert::Infallible}; |
12 | | //! use tokio_stream::StreamExt as _ ; |
13 | | //! use futures_util::stream::{self, Stream}; |
14 | | //! |
15 | | //! let app = Router::new().route("/sse", get(sse_handler)); |
16 | | //! |
17 | | //! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> { |
18 | | //! // A `Stream` that repeats an event every second |
19 | | //! let stream = stream::repeat_with(|| Event::default().data("hi!")) |
20 | | //! .map(Ok) |
21 | | //! .throttle(Duration::from_secs(1)); |
22 | | //! |
23 | | //! Sse::new(stream).keep_alive(KeepAlive::default()) |
24 | | //! } |
25 | | //! # let _: Router = app; |
26 | | //! ``` |
27 | | |
28 | | use crate::{ |
29 | | body::{Bytes, HttpBody}, |
30 | | BoxError, |
31 | | }; |
32 | | use axum_core::{ |
33 | | body::Body, |
34 | | response::{IntoResponse, Response}, |
35 | | }; |
36 | | use bytes::{BufMut, BytesMut}; |
37 | | use futures_util::stream::{Stream, TryStream}; |
38 | | use http_body::Frame; |
39 | | use pin_project_lite::pin_project; |
40 | | use std::{ |
41 | | fmt::{self, Write as _}, |
42 | | io::Write as _, |
43 | | mem, |
44 | | pin::Pin, |
45 | | task::{ready, Context, Poll}, |
46 | | time::Duration, |
47 | | }; |
48 | | use sync_wrapper::SyncWrapper; |
49 | | |
50 | | /// An SSE response |
51 | | #[derive(Clone)] |
52 | | #[must_use] |
53 | | pub struct Sse<S> { |
54 | | stream: S, |
55 | | } |
56 | | |
57 | | impl<S> Sse<S> { |
58 | | /// Create a new [`Sse`] response that will respond with the given stream of |
59 | | /// [`Event`]s. |
60 | | /// |
61 | | /// See the [module docs](self) for more details. |
62 | 0 | pub fn new(stream: S) -> Self |
63 | 0 | where |
64 | 0 | S: TryStream<Ok = Event> + Send + 'static, |
65 | 0 | S::Error: Into<BoxError>, |
66 | | { |
67 | 0 | Sse { stream } |
68 | 0 | } |
69 | | |
70 | | /// Configure the interval between keep-alive messages. |
71 | | /// |
72 | | /// Defaults to no keep-alive messages. |
73 | | #[cfg(feature = "tokio")] |
74 | | pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> { |
75 | | Sse { |
76 | | stream: KeepAliveStream::new(keep_alive, self.stream), |
77 | | } |
78 | | } |
79 | | } |
80 | | |
81 | | impl<S> fmt::Debug for Sse<S> { |
82 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
83 | 0 | f.debug_struct("Sse") |
84 | 0 | .field("stream", &format_args!("{}", std::any::type_name::<S>())) |
85 | 0 | .finish() |
86 | 0 | } |
87 | | } |
88 | | |
89 | | impl<S, E> IntoResponse for Sse<S> |
90 | | where |
91 | | S: Stream<Item = Result<Event, E>> + Send + 'static, |
92 | | E: Into<BoxError>, |
93 | | { |
94 | 0 | fn into_response(self) -> Response { |
95 | 0 | ( |
96 | 0 | [ |
97 | 0 | (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()), |
98 | 0 | (http::header::CACHE_CONTROL, "no-cache"), |
99 | 0 | ], |
100 | 0 | Body::new(SseBody { |
101 | 0 | event_stream: SyncWrapper::new(self.stream), |
102 | 0 | }), |
103 | 0 | ) |
104 | 0 | .into_response() |
105 | 0 | } |
106 | | } |
107 | | |
108 | | pin_project! { |
109 | | struct SseBody<S> { |
110 | | #[pin] |
111 | | event_stream: SyncWrapper<S>, |
112 | | } |
113 | | } |
114 | | |
115 | | impl<S, E> HttpBody for SseBody<S> |
116 | | where |
117 | | S: Stream<Item = Result<Event, E>>, |
118 | | { |
119 | | type Data = Bytes; |
120 | | type Error = E; |
121 | | |
122 | 0 | fn poll_frame( |
123 | 0 | self: Pin<&mut Self>, |
124 | 0 | cx: &mut Context<'_>, |
125 | 0 | ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { |
126 | 0 | let this = self.project(); |
127 | | |
128 | 0 | match ready!(this.event_stream.get_pin_mut().poll_next(cx)) { |
129 | 0 | Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))), |
130 | 0 | Some(Err(error)) => Poll::Ready(Some(Err(error))), |
131 | 0 | None => Poll::Ready(None), |
132 | | } |
133 | 0 | } |
134 | | } |
135 | | |
136 | | /// The state of an event's buffer. |
137 | | /// |
138 | | /// This type allows creating events in a `const` context |
139 | | /// by using a finalized buffer. |
140 | | /// |
141 | | /// While the buffer is active, more bytes can be written to it. |
142 | | /// Once finalized, it's immutable and cheap to clone. |
143 | | /// The buffer is active during the event building, but eventually |
144 | | /// becomes finalized to send http body frames as [`Bytes`]. |
145 | | #[derive(Debug, Clone)] |
146 | | enum Buffer { |
147 | | Active(BytesMut), |
148 | | Finalized(Bytes), |
149 | | } |
150 | | |
151 | | impl Buffer { |
152 | | /// Returns a mutable reference to the internal buffer. |
153 | | /// |
154 | | /// If the buffer was finalized, this method creates |
155 | | /// a new active buffer with the previous contents. |
156 | 0 | fn as_mut(&mut self) -> &mut BytesMut { |
157 | 0 | match self { |
158 | 0 | Buffer::Active(bytes_mut) => bytes_mut, |
159 | 0 | Buffer::Finalized(bytes) => { |
160 | 0 | *self = Buffer::Active(BytesMut::from(mem::take(bytes))); |
161 | 0 | match self { |
162 | 0 | Buffer::Active(bytes_mut) => bytes_mut, |
163 | 0 | Buffer::Finalized(_) => unreachable!(), |
164 | | } |
165 | | } |
166 | | } |
167 | 0 | } |
168 | | } |
169 | | |
170 | | /// Server-sent event |
171 | | #[derive(Debug, Clone)] |
172 | | #[must_use] |
173 | | pub struct Event { |
174 | | buffer: Buffer, |
175 | | flags: EventFlags, |
176 | | } |
177 | | |
178 | | /// Expose [`Event`] as a [`std::fmt::Write`] |
179 | | /// such that any form of data can be written as data safely. |
180 | | /// |
181 | | /// This also ensures that newline characters `\r` and `\n` |
182 | | /// correctly trigger a split with a new `data: ` prefix. |
183 | | /// |
184 | | /// # Panics |
185 | | /// |
186 | | /// Panics if any `data` has already been written prior to the first write |
187 | | /// of this [`EventDataWriter`] instance. |
188 | | #[derive(Debug)] |
189 | | #[must_use] |
190 | | pub struct EventDataWriter { |
191 | | event: Event, |
192 | | |
193 | | // Indicates if _this_ EventDataWriter has written data, |
194 | | // this does not say anything about whether or not `event` contains |
195 | | // data or not. |
196 | | data_written: bool, |
197 | | } |
198 | | |
199 | | impl Event { |
200 | | /// Default keep-alive event |
201 | | pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n")); |
202 | | |
203 | 0 | const fn finalized(bytes: Bytes) -> Self { |
204 | 0 | Self { |
205 | 0 | buffer: Buffer::Finalized(bytes), |
206 | 0 | flags: EventFlags::from_bits(0), |
207 | 0 | } |
208 | 0 | } |
209 | | |
210 | | /// Use this [`Event`] as a [`EventDataWriter`] to write custom data. |
211 | | /// |
212 | | /// - [`Self::data`] can be used as a shortcut to write `str` data |
213 | | /// - [`Self::json_data`] can be used as a shortcut to write `json` data |
214 | | /// |
215 | | /// Turn it into an [`Event`] again using [`EventDataWriter::into_event`]. |
216 | 0 | pub fn into_data_writer(self) -> EventDataWriter { |
217 | 0 | EventDataWriter { |
218 | 0 | event: self, |
219 | 0 | data_written: false, |
220 | 0 | } |
221 | 0 | } |
222 | | |
223 | | /// Set the event's data data field(s) (`data: <content>`) |
224 | | /// |
225 | | /// Newlines in `data` will automatically be broken across `data: ` fields. |
226 | | /// |
227 | | /// This corresponds to [`MessageEvent`'s data field]. |
228 | | /// |
229 | | /// Note that events with an empty data field will be ignored by the browser. |
230 | | /// |
231 | | /// # Panics |
232 | | /// |
233 | | /// Panics if any `data` has already been written before. |
234 | | /// |
235 | | /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data |
236 | 0 | pub fn data<T>(self, data: T) -> Self |
237 | 0 | where |
238 | 0 | T: AsRef<str>, |
239 | | { |
240 | 0 | let mut writer = self.into_data_writer(); |
241 | 0 | let _ = writer.write_str(data.as_ref()); |
242 | 0 | writer.into_event() |
243 | 0 | } |
244 | | |
245 | | /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`). |
246 | | /// |
247 | | /// This corresponds to [`MessageEvent`'s data field]. |
248 | | /// |
249 | | /// # Panics |
250 | | /// |
251 | | /// Panics if any `data` has already been written before. |
252 | | /// |
253 | | /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data |
254 | | #[cfg(feature = "json")] |
255 | | pub fn json_data<T>(self, data: T) -> Result<Self, axum_core::Error> |
256 | | where |
257 | | T: serde_core::Serialize, |
258 | | { |
259 | | struct JsonWriter<'a>(&'a mut EventDataWriter); |
260 | | impl std::io::Write for JsonWriter<'_> { |
261 | | #[inline] |
262 | | fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { |
263 | | Ok(self.0.write_buf(buf)) |
264 | | } |
265 | | fn flush(&mut self) -> std::io::Result<()> { |
266 | | Ok(()) |
267 | | } |
268 | | } |
269 | | |
270 | | let mut writer = self.into_data_writer(); |
271 | | |
272 | | let json_writer = JsonWriter(&mut writer); |
273 | | serde_json::to_writer(json_writer, &data).map_err(axum_core::Error::new)?; |
274 | | |
275 | | Ok(writer.into_event()) |
276 | | } |
277 | | |
278 | | /// Set the event's comment field (`:<comment-text>`). |
279 | | /// |
280 | | /// This field will be ignored by most SSE clients. |
281 | | /// |
282 | | /// Unlike other functions, this function can be called multiple times to add many comments. |
283 | | /// |
284 | | /// # Panics |
285 | | /// |
286 | | /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in |
287 | | /// comments. |
288 | 0 | pub fn comment<T>(mut self, comment: T) -> Event |
289 | 0 | where |
290 | 0 | T: AsRef<str>, |
291 | | { |
292 | 0 | self.field("", comment.as_ref()); |
293 | 0 | self |
294 | 0 | } |
295 | | |
296 | | /// Set the event's name field (`event:<event-name>`). |
297 | | /// |
298 | | /// This corresponds to the `type` parameter given when calling `addEventListener` on an |
299 | | /// [`EventSource`]. For example, `.event("update")` should correspond to |
300 | | /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a |
301 | | /// [`message` event] instead. |
302 | | /// |
303 | | /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource |
304 | | /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event |
305 | | /// |
306 | | /// # Panics |
307 | | /// |
308 | | /// - Panics if `event` contains any newlines or carriage returns. |
309 | | /// - Panics if this function has already been called on this event. |
310 | 0 | pub fn event<T>(mut self, event: T) -> Event |
311 | 0 | where |
312 | 0 | T: AsRef<str>, |
313 | | { |
314 | 0 | if self.flags.contains(EventFlags::HAS_EVENT) { |
315 | 0 | panic!("Called `Event::event` multiple times"); |
316 | 0 | } |
317 | 0 | self.flags.insert(EventFlags::HAS_EVENT); |
318 | | |
319 | 0 | self.field("event", event.as_ref()); |
320 | | |
321 | 0 | self |
322 | 0 | } |
323 | | |
324 | | /// Set the event's retry timeout field (`retry: <timeout>`). |
325 | | /// |
326 | | /// This sets how long clients will wait before reconnecting if they are disconnected from the |
327 | | /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they |
328 | | /// wish, such as if they implement exponential backoff. |
329 | | /// |
330 | | /// # Panics |
331 | | /// |
332 | | /// Panics if this function has already been called on this event. |
333 | 0 | pub fn retry(mut self, duration: Duration) -> Event { |
334 | 0 | if self.flags.contains(EventFlags::HAS_RETRY) { |
335 | 0 | panic!("Called `Event::retry` multiple times"); |
336 | 0 | } |
337 | 0 | self.flags.insert(EventFlags::HAS_RETRY); |
338 | | |
339 | 0 | let buffer = self.buffer.as_mut(); |
340 | 0 | buffer.extend_from_slice(b"retry: "); |
341 | | |
342 | 0 | let secs = duration.as_secs(); |
343 | 0 | let millis = duration.subsec_millis(); |
344 | | |
345 | 0 | if secs > 0 { |
346 | | // format seconds |
347 | 0 | buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes()); |
348 | | |
349 | | // pad milliseconds |
350 | 0 | if millis < 10 { |
351 | 0 | buffer.extend_from_slice(b"00"); |
352 | 0 | } else if millis < 100 { |
353 | 0 | buffer.extend_from_slice(b"0"); |
354 | 0 | } |
355 | 0 | } |
356 | | |
357 | | // format milliseconds |
358 | 0 | buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes()); |
359 | | |
360 | 0 | buffer.put_u8(b'\n'); |
361 | | |
362 | 0 | self |
363 | 0 | } |
364 | | |
365 | | /// Set the event's identifier field (`id:<identifier>`). |
366 | | /// |
367 | | /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself, |
368 | | /// the browser will set that field to the last known message ID, starting with the empty |
369 | | /// string. |
370 | | /// |
371 | | /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId |
372 | | /// |
373 | | /// # Panics |
374 | | /// |
375 | | /// - Panics if `id` contains any newlines, carriage returns or null characters. |
376 | | /// - Panics if this function has already been called on this event. |
377 | 0 | pub fn id<T>(mut self, id: T) -> Event |
378 | 0 | where |
379 | 0 | T: AsRef<str>, |
380 | | { |
381 | 0 | if self.flags.contains(EventFlags::HAS_ID) { |
382 | 0 | panic!("Called `Event::id` multiple times"); |
383 | 0 | } |
384 | 0 | self.flags.insert(EventFlags::HAS_ID); |
385 | | |
386 | 0 | let id = id.as_ref().as_bytes(); |
387 | 0 | assert_eq!( |
388 | 0 | memchr::memchr(b'\0', id), |
389 | | None, |
390 | | "Event ID cannot contain null characters", |
391 | | ); |
392 | | |
393 | 0 | self.field("id", id); |
394 | 0 | self |
395 | 0 | } |
396 | | |
397 | 0 | fn field(&mut self, name: &str, value: impl AsRef<[u8]>) { |
398 | 0 | let value = value.as_ref(); |
399 | 0 | assert_eq!( |
400 | 0 | memchr::memchr2(b'\r', b'\n', value), |
401 | | None, |
402 | | "SSE field value cannot contain newlines or carriage returns", |
403 | | ); |
404 | | |
405 | 0 | let buffer = self.buffer.as_mut(); |
406 | 0 | buffer.extend_from_slice(name.as_bytes()); |
407 | 0 | buffer.put_u8(b':'); |
408 | 0 | buffer.put_u8(b' '); |
409 | 0 | buffer.extend_from_slice(value); |
410 | 0 | buffer.put_u8(b'\n'); |
411 | 0 | } |
412 | | |
413 | 0 | fn finalize(self) -> Bytes { |
414 | 0 | match self.buffer { |
415 | 0 | Buffer::Finalized(bytes) => bytes, |
416 | 0 | Buffer::Active(mut bytes_mut) => { |
417 | 0 | bytes_mut.put_u8(b'\n'); |
418 | 0 | bytes_mut.freeze() |
419 | | } |
420 | | } |
421 | 0 | } |
422 | | } |
423 | | |
424 | | impl EventDataWriter { |
425 | | /// Consume the [`EventDataWriter`] and return the [`Event`] once again. |
426 | | /// |
427 | | /// In case any data was written by this instance |
428 | | /// it will also write the trailing `\n` character. |
429 | 0 | pub fn into_event(self) -> Event { |
430 | 0 | let mut event = self.event; |
431 | 0 | if self.data_written { |
432 | 0 | let _ = event.buffer.as_mut().write_char('\n'); |
433 | 0 | } |
434 | 0 | event |
435 | 0 | } |
436 | | } |
437 | | |
438 | | impl EventDataWriter { |
439 | | // Assumption: underlying writer never returns an error: |
440 | | // <https://docs.rs/bytes/latest/src/bytes/buf/writer.rs.html#79-82> |
441 | 0 | fn write_buf(&mut self, buf: &[u8]) -> usize { |
442 | 0 | if buf.is_empty() { |
443 | 0 | return 0; |
444 | 0 | } |
445 | | |
446 | 0 | let buffer = self.event.buffer.as_mut(); |
447 | | |
448 | 0 | if !std::mem::replace(&mut self.data_written, true) { |
449 | 0 | if self.event.flags.contains(EventFlags::HAS_DATA) { |
450 | 0 | panic!("Called `Event::data*` multiple times"); |
451 | 0 | } |
452 | | |
453 | 0 | let _ = buffer.write_str("data: "); |
454 | 0 | self.event.flags.insert(EventFlags::HAS_DATA); |
455 | 0 | } |
456 | | |
457 | 0 | let mut writer = buffer.writer(); |
458 | | |
459 | 0 | let mut last_split = 0; |
460 | 0 | for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) { |
461 | 0 | let _ = writer.write_all(&buf[last_split..=delimiter]); |
462 | 0 | let _ = writer.write_all(b"data: "); |
463 | 0 | last_split = delimiter + 1; |
464 | 0 | } |
465 | 0 | let _ = writer.write_all(&buf[last_split..]); |
466 | | |
467 | 0 | buf.len() |
468 | 0 | } |
469 | | } |
470 | | |
471 | | impl fmt::Write for EventDataWriter { |
472 | 0 | fn write_str(&mut self, s: &str) -> fmt::Result { |
473 | 0 | let _ = self.write_buf(s.as_bytes()); |
474 | 0 | Ok(()) |
475 | 0 | } |
476 | | } |
477 | | |
478 | | impl Default for Event { |
479 | 0 | fn default() -> Self { |
480 | 0 | Self { |
481 | 0 | buffer: Buffer::Active(BytesMut::new()), |
482 | 0 | flags: EventFlags::from_bits(0), |
483 | 0 | } |
484 | 0 | } |
485 | | } |
486 | | |
/// Bit set recording which single-use fields an event already contains.
#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self(1 << 0);
    const HAS_EVENT: Self = Self(1 << 1);
    const HAS_RETRY: Self = Self(1 << 2);
    const HAS_ID: Self = Self(1 << 3);

    /// The raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Build a flag set from a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Set every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.0;
    }
}
512 | | |
513 | | /// Configure the interval between keep-alive messages, the content |
514 | | /// of each message, and the associated stream. |
515 | | #[derive(Debug, Clone)] |
516 | | #[must_use] |
517 | | pub struct KeepAlive { |
518 | | event: Event, |
519 | | max_interval: Duration, |
520 | | } |
521 | | |
522 | | impl KeepAlive { |
523 | | /// Create a new `KeepAlive`. |
524 | 0 | pub fn new() -> Self { |
525 | 0 | Self { |
526 | 0 | event: Event::DEFAULT_KEEP_ALIVE, |
527 | 0 | max_interval: Duration::from_secs(15), |
528 | 0 | } |
529 | 0 | } |
530 | | |
531 | | /// Customize the interval between keep-alive messages. |
532 | | /// |
533 | | /// Default is 15 seconds. |
534 | 0 | pub fn interval(mut self, time: Duration) -> Self { |
535 | 0 | self.max_interval = time; |
536 | 0 | self |
537 | 0 | } |
538 | | |
539 | | /// Customize the text of the keep-alive message. |
540 | | /// |
541 | | /// Default is an empty comment. |
542 | | /// |
543 | | /// # Panics |
544 | | /// |
545 | | /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE |
546 | | /// comments. |
547 | 0 | pub fn text<I>(self, text: I) -> Self |
548 | 0 | where |
549 | 0 | I: AsRef<str>, |
550 | | { |
551 | 0 | self.event(Event::default().comment(text)) |
552 | 0 | } |
553 | | |
554 | | /// Customize the event of the keep-alive message. |
555 | | /// |
556 | | /// Default is an empty comment. |
557 | | /// |
558 | | /// # Panics |
559 | | /// |
560 | | /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE |
561 | | /// comments. |
562 | 0 | pub fn event(mut self, event: Event) -> Self { |
563 | 0 | self.event = Event::finalized(event.finalize()); |
564 | 0 | self |
565 | 0 | } |
566 | | } |
567 | | |
568 | | impl Default for KeepAlive { |
569 | 0 | fn default() -> Self { |
570 | 0 | Self::new() |
571 | 0 | } |
572 | | } |
573 | | |
#[cfg(feature = "tokio")]
pin_project! {
    /// A stream adapter that yields keep-alive events while the wrapped
    /// stream has nothing ready
    #[derive(Debug)]
    pub struct KeepAliveStream<S> {
        #[pin]
        alive_timer: tokio::time::Sleep,
        #[pin]
        inner: S,
        keep_alive: KeepAlive,
    }
}

#[cfg(feature = "tokio")]
impl<S> KeepAliveStream<S> {
    /// Wrap `inner`, starting a timer that elapses after the configured
    /// keep-alive interval.
    fn new(keep_alive: KeepAlive, inner: S) -> Self {
        let alive_timer = tokio::time::sleep(keep_alive.max_interval);

        Self {
            alive_timer,
            inner,
            keep_alive,
        }
    }

    /// Restart the timer so that it elapses one full interval from now.
    fn reset(self: Pin<&mut Self>) {
        let this = self.project();
        let deadline = tokio::time::Instant::now() + this.keep_alive.max_interval;
        this.alive_timer.reset(deadline);
    }
}
603 | | |
#[cfg(feature = "tokio")]
impl<S, E> Stream for KeepAliveStream<S>
where
    S: Stream<Item = Result<Event, E>>,
{
    type Item = Result<Event, E>;

    /// Yield the next event of the wrapped stream, or a keep-alive event.
    ///
    /// - An event from the wrapped stream is forwarded and restarts the timer.
    /// - Errors and the end of the wrapped stream are forwarded unchanged,
    ///   without touching the timer.
    /// - While the wrapped stream is pending, a copy of the keep-alive event
    ///   is yielded once the timer elapses, and the timer is restarted.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::future::Future;

        let mut this = self.as_mut().project();

        let event = match this.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(event))) => event,
            Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => {
                ready!(this.alive_timer.poll(cx));

                this.keep_alive.event.clone()
            }
        };

        // Whether a real event or a keep-alive is about to be sent, the next
        // keep-alive is due one full interval from now.
        self.reset();

        Poll::Ready(Some(Ok(event)))
    }
}
636 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    /// The payload is written verbatim after the `data: ` prefix, so a payload
    /// that itself starts with a space must produce two consecutive spaces.
    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        // Separator space from the prefix + the payload's own leading space.
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    /// Several writes through one writer are concatenated, empty writes do
    /// nothing, and every `\n` / `\r` starts a new `data: ` line.
    #[test]
    fn write_data_writer_str() {
        // also confirm that nop writers do nothing :)
        let mut writer = Event::default()
            .into_data_writer()
            .into_event()
            .into_data_writer();
        writer.write_str("").unwrap();
        let mut writer = writer.into_event().into_data_writer();

        writer.write_str("").unwrap();
        writer.write_str("moon ").unwrap();
        writer.write_str("star\nsun").unwrap();
        writer.write_str("").unwrap();
        writer.write_str("set").unwrap();
        writer.write_str("").unwrap();
        writer.write_str(" bye\r").unwrap();

        let event = writer.into_event();

        assert_eq!(
            &*event.finalize(),
            b"data: moon star\ndata: sunset bye\rdata: \n\n"
        );
    }

    /// Raw JSON keeps its original whitespace, including newlines and
    /// carriage returns, each of which starts a new `data: ` line.
    #[test]
    fn valid_json_raw_value_chars_handled() {
        let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        // The line holding ` "bar\n"` starts with the payload's own space, so
        // it follows the prefix's separator space (two spaces in total).
        assert_eq!(
            &*json_raw_value_event.finalize(),
            b"data: {\rdata: \"foo\": \ndata: \rdata: \rdata:  \"bar\\n\"\ndata: }\n\n"
        );
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    /// Split one event payload into a map from field name to trimmed value.
    ///
    /// Comment lines (empty field name) are stored under the key `comment`.
    /// Parsing stops at the first empty line, which must be the last line of
    /// the payload. A later occurrence of a field name overwrites an earlier
    /// one.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }
}