Coverage Report

Created: 2026-03-31 07:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/axum-0.8.8/src/response/sse.rs
Line
Count
Source
1
//! Server-Sent Events (SSE) responses.
2
//!
3
//! # Example
4
//!
5
//! ```
6
//! use axum::{
7
//!     Router,
8
//!     routing::get,
9
//!     response::sse::{Event, KeepAlive, Sse},
10
//! };
11
//! use std::{time::Duration, convert::Infallible};
12
//! use tokio_stream::StreamExt as _ ;
13
//! use futures_util::stream::{self, Stream};
14
//!
15
//! let app = Router::new().route("/sse", get(sse_handler));
16
//!
17
//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
18
//!     // A `Stream` that repeats an event every second
19
//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
20
//!         .map(Ok)
21
//!         .throttle(Duration::from_secs(1));
22
//!
23
//!     Sse::new(stream).keep_alive(KeepAlive::default())
24
//! }
25
//! # let _: Router = app;
26
//! ```
27
28
use crate::{
29
    body::{Bytes, HttpBody},
30
    BoxError,
31
};
32
use axum_core::{
33
    body::Body,
34
    response::{IntoResponse, Response},
35
};
36
use bytes::{BufMut, BytesMut};
37
use futures_util::stream::{Stream, TryStream};
38
use http_body::Frame;
39
use pin_project_lite::pin_project;
40
use std::{
41
    fmt::{self, Write as _},
42
    io::Write as _,
43
    mem,
44
    pin::Pin,
45
    task::{ready, Context, Poll},
46
    time::Duration,
47
};
48
use sync_wrapper::SyncWrapper;
49
50
/// A Server-Sent Events (SSE) response wrapping a stream of events.
///
/// Construct one with [`Sse::new`]; see the [module docs](self) for a complete example.
#[must_use]
#[derive(Clone)]
pub struct Sse<S> {
    // The stream whose items are sent to the client.
    stream: S,
}
56
57
impl<S> Sse<S> {
58
    /// Create a new [`Sse`] response that will respond with the given stream of
59
    /// [`Event`]s.
60
    ///
61
    /// See the [module docs](self) for more details.
62
0
    pub fn new(stream: S) -> Self
63
0
    where
64
0
        S: TryStream<Ok = Event> + Send + 'static,
65
0
        S::Error: Into<BoxError>,
66
    {
67
0
        Sse { stream }
68
0
    }
69
70
    /// Configure the interval between keep-alive messages.
71
    ///
72
    /// Defaults to no keep-alive messages.
73
    #[cfg(feature = "tokio")]
74
    pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
75
        Sse {
76
            stream: KeepAliveStream::new(keep_alive, self.stream),
77
        }
78
    }
79
}
80
81
impl<S> fmt::Debug for Sse<S> {
82
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83
0
        f.debug_struct("Sse")
84
0
            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
85
0
            .finish()
86
0
    }
87
}
88
89
impl<S, E> IntoResponse for Sse<S>
90
where
91
    S: Stream<Item = Result<Event, E>> + Send + 'static,
92
    E: Into<BoxError>,
93
{
94
0
    fn into_response(self) -> Response {
95
0
        (
96
0
            [
97
0
                (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
98
0
                (http::header::CACHE_CONTROL, "no-cache"),
99
0
            ],
100
0
            Body::new(SseBody {
101
0
                event_stream: SyncWrapper::new(self.stream),
102
0
            }),
103
0
        )
104
0
            .into_response()
105
0
    }
106
}
107
108
pin_project! {
    /// HTTP body that turns each event of the wrapped stream into one data frame.
    struct SseBody<S> {
        // Structurally pinned so the inner stream can be polled through `get_pin_mut`.
        #[pin]
        event_stream: SyncWrapper<S>,
    }
}
114
115
impl<S, E> HttpBody for SseBody<S>
116
where
117
    S: Stream<Item = Result<Event, E>>,
118
{
119
    type Data = Bytes;
120
    type Error = E;
121
122
0
    fn poll_frame(
123
0
        self: Pin<&mut Self>,
124
0
        cx: &mut Context<'_>,
125
0
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
126
0
        let this = self.project();
127
128
0
        match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
129
0
            Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
130
0
            Some(Err(error)) => Poll::Ready(Some(Err(error))),
131
0
            None => Poll::Ready(None),
132
        }
133
0
    }
134
}
135
136
/// The state of an event's buffer.
137
///
138
/// This type allows creating events in a `const` context
139
/// by using a finalized buffer.
140
///
141
/// While the buffer is active, more bytes can be written to it.
142
/// Once finalized, it's immutable and cheap to clone.
143
/// The buffer is active during the event building, but eventually
144
/// becomes finalized to send http body frames as [`Bytes`].
145
#[derive(Debug, Clone)]
146
enum Buffer {
147
    Active(BytesMut),
148
    Finalized(Bytes),
149
}
150
151
impl Buffer {
152
    /// Returns a mutable reference to the internal buffer.
153
    ///
154
    /// If the buffer was finalized, this method creates
155
    /// a new active buffer with the previous contents.
156
0
    fn as_mut(&mut self) -> &mut BytesMut {
157
0
        match self {
158
0
            Buffer::Active(bytes_mut) => bytes_mut,
159
0
            Buffer::Finalized(bytes) => {
160
0
                *self = Buffer::Active(BytesMut::from(mem::take(bytes)));
161
0
                match self {
162
0
                    Buffer::Active(bytes_mut) => bytes_mut,
163
0
                    Buffer::Finalized(_) => unreachable!(),
164
                }
165
            }
166
        }
167
0
    }
168
}
169
170
/// Server-sent event
171
#[derive(Debug, Clone)]
172
#[must_use]
173
pub struct Event {
174
    buffer: Buffer,
175
    flags: EventFlags,
176
}
177
178
/// Expose [`Event`] as a [`std::fmt::Write`]
179
/// such that any form of data can be written as data safely.
180
///
181
/// This also ensures that newline characters `\r` and `\n`
182
/// correctly trigger a split with a new `data: ` prefix.
183
///
184
/// # Panics
185
///
186
/// Panics if any `data` has already been written prior to the first write
187
/// of this [`EventDataWriter`] instance.
188
#[derive(Debug)]
189
#[must_use]
190
pub struct EventDataWriter {
191
    event: Event,
192
193
    // Indicates if _this_ EventDataWriter has written data,
194
    // this does not say anything about whether or not `event` contains
195
    // data or not.
196
    data_written: bool,
197
}
198
199
impl Event {
200
    /// Default keep-alive event
201
    pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
202
203
0
    const fn finalized(bytes: Bytes) -> Self {
204
0
        Self {
205
0
            buffer: Buffer::Finalized(bytes),
206
0
            flags: EventFlags::from_bits(0),
207
0
        }
208
0
    }
209
210
    /// Use this [`Event`] as a [`EventDataWriter`] to write custom data.
211
    ///
212
    /// - [`Self::data`] can be used as a shortcut to write `str` data
213
    /// - [`Self::json_data`] can be used as a shortcut to write `json` data
214
    ///
215
    /// Turn it into an [`Event`] again using [`EventDataWriter::into_event`].
216
0
    pub fn into_data_writer(self) -> EventDataWriter {
217
0
        EventDataWriter {
218
0
            event: self,
219
0
            data_written: false,
220
0
        }
221
0
    }
222
223
    /// Set the event's data data field(s) (`data: <content>`)
224
    ///
225
    /// Newlines in `data` will automatically be broken across `data: ` fields.
226
    ///
227
    /// This corresponds to [`MessageEvent`'s data field].
228
    ///
229
    /// Note that events with an empty data field will be ignored by the browser.
230
    ///
231
    /// # Panics
232
    ///
233
    /// Panics if any `data` has already been written before.
234
    ///
235
    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
236
0
    pub fn data<T>(self, data: T) -> Self
237
0
    where
238
0
        T: AsRef<str>,
239
    {
240
0
        let mut writer = self.into_data_writer();
241
0
        let _ = writer.write_str(data.as_ref());
242
0
        writer.into_event()
243
0
    }
244
245
    /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
246
    ///
247
    /// This corresponds to [`MessageEvent`'s data field].
248
    ///
249
    /// # Panics
250
    ///
251
    /// Panics if any `data` has already been written before.
252
    ///
253
    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
254
    #[cfg(feature = "json")]
255
    pub fn json_data<T>(self, data: T) -> Result<Self, axum_core::Error>
256
    where
257
        T: serde_core::Serialize,
258
    {
259
        struct JsonWriter<'a>(&'a mut EventDataWriter);
260
        impl std::io::Write for JsonWriter<'_> {
261
            #[inline]
262
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
263
                Ok(self.0.write_buf(buf))
264
            }
265
            fn flush(&mut self) -> std::io::Result<()> {
266
                Ok(())
267
            }
268
        }
269
270
        let mut writer = self.into_data_writer();
271
272
        let json_writer = JsonWriter(&mut writer);
273
        serde_json::to_writer(json_writer, &data).map_err(axum_core::Error::new)?;
274
275
        Ok(writer.into_event())
276
    }
277
278
    /// Set the event's comment field (`:<comment-text>`).
279
    ///
280
    /// This field will be ignored by most SSE clients.
281
    ///
282
    /// Unlike other functions, this function can be called multiple times to add many comments.
283
    ///
284
    /// # Panics
285
    ///
286
    /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
287
    /// comments.
288
0
    pub fn comment<T>(mut self, comment: T) -> Event
289
0
    where
290
0
        T: AsRef<str>,
291
    {
292
0
        self.field("", comment.as_ref());
293
0
        self
294
0
    }
295
296
    /// Set the event's name field (`event:<event-name>`).
297
    ///
298
    /// This corresponds to the `type` parameter given when calling `addEventListener` on an
299
    /// [`EventSource`]. For example, `.event("update")` should correspond to
300
    /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
301
    /// [`message` event] instead.
302
    ///
303
    /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
304
    /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
305
    ///
306
    /// # Panics
307
    ///
308
    /// - Panics if `event` contains any newlines or carriage returns.
309
    /// - Panics if this function has already been called on this event.
310
0
    pub fn event<T>(mut self, event: T) -> Event
311
0
    where
312
0
        T: AsRef<str>,
313
    {
314
0
        if self.flags.contains(EventFlags::HAS_EVENT) {
315
0
            panic!("Called `Event::event` multiple times");
316
0
        }
317
0
        self.flags.insert(EventFlags::HAS_EVENT);
318
319
0
        self.field("event", event.as_ref());
320
321
0
        self
322
0
    }
323
324
    /// Set the event's retry timeout field (`retry: <timeout>`).
325
    ///
326
    /// This sets how long clients will wait before reconnecting if they are disconnected from the
327
    /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
328
    /// wish, such as if they implement exponential backoff.
329
    ///
330
    /// # Panics
331
    ///
332
    /// Panics if this function has already been called on this event.
333
0
    pub fn retry(mut self, duration: Duration) -> Event {
334
0
        if self.flags.contains(EventFlags::HAS_RETRY) {
335
0
            panic!("Called `Event::retry` multiple times");
336
0
        }
337
0
        self.flags.insert(EventFlags::HAS_RETRY);
338
339
0
        let buffer = self.buffer.as_mut();
340
0
        buffer.extend_from_slice(b"retry: ");
341
342
0
        let secs = duration.as_secs();
343
0
        let millis = duration.subsec_millis();
344
345
0
        if secs > 0 {
346
            // format seconds
347
0
            buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
348
349
            // pad milliseconds
350
0
            if millis < 10 {
351
0
                buffer.extend_from_slice(b"00");
352
0
            } else if millis < 100 {
353
0
                buffer.extend_from_slice(b"0");
354
0
            }
355
0
        }
356
357
        // format milliseconds
358
0
        buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
359
360
0
        buffer.put_u8(b'\n');
361
362
0
        self
363
0
    }
364
365
    /// Set the event's identifier field (`id:<identifier>`).
366
    ///
367
    /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
368
    /// the browser will set that field to the last known message ID, starting with the empty
369
    /// string.
370
    ///
371
    /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
372
    ///
373
    /// # Panics
374
    ///
375
    /// - Panics if `id` contains any newlines, carriage returns or null characters.
376
    /// - Panics if this function has already been called on this event.
377
0
    pub fn id<T>(mut self, id: T) -> Event
378
0
    where
379
0
        T: AsRef<str>,
380
    {
381
0
        if self.flags.contains(EventFlags::HAS_ID) {
382
0
            panic!("Called `Event::id` multiple times");
383
0
        }
384
0
        self.flags.insert(EventFlags::HAS_ID);
385
386
0
        let id = id.as_ref().as_bytes();
387
0
        assert_eq!(
388
0
            memchr::memchr(b'\0', id),
389
            None,
390
            "Event ID cannot contain null characters",
391
        );
392
393
0
        self.field("id", id);
394
0
        self
395
0
    }
396
397
0
    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
398
0
        let value = value.as_ref();
399
0
        assert_eq!(
400
0
            memchr::memchr2(b'\r', b'\n', value),
401
            None,
402
            "SSE field value cannot contain newlines or carriage returns",
403
        );
404
405
0
        let buffer = self.buffer.as_mut();
406
0
        buffer.extend_from_slice(name.as_bytes());
407
0
        buffer.put_u8(b':');
408
0
        buffer.put_u8(b' ');
409
0
        buffer.extend_from_slice(value);
410
0
        buffer.put_u8(b'\n');
411
0
    }
412
413
0
    fn finalize(self) -> Bytes {
414
0
        match self.buffer {
415
0
            Buffer::Finalized(bytes) => bytes,
416
0
            Buffer::Active(mut bytes_mut) => {
417
0
                bytes_mut.put_u8(b'\n');
418
0
                bytes_mut.freeze()
419
            }
420
        }
421
0
    }
422
}
423
424
impl EventDataWriter {
425
    /// Consume the [`EventDataWriter`] and return the [`Event`] once again.
426
    ///
427
    /// In case any data was written by this instance
428
    /// it will also write the trailing `\n` character.
429
0
    pub fn into_event(self) -> Event {
430
0
        let mut event = self.event;
431
0
        if self.data_written {
432
0
            let _ = event.buffer.as_mut().write_char('\n');
433
0
        }
434
0
        event
435
0
    }
436
}
437
438
impl EventDataWriter {
439
    // Assumption: underlying writer never returns an error:
440
    // <https://docs.rs/bytes/latest/src/bytes/buf/writer.rs.html#79-82>
441
0
    fn write_buf(&mut self, buf: &[u8]) -> usize {
442
0
        if buf.is_empty() {
443
0
            return 0;
444
0
        }
445
446
0
        let buffer = self.event.buffer.as_mut();
447
448
0
        if !std::mem::replace(&mut self.data_written, true) {
449
0
            if self.event.flags.contains(EventFlags::HAS_DATA) {
450
0
                panic!("Called `Event::data*` multiple times");
451
0
            }
452
453
0
            let _ = buffer.write_str("data: ");
454
0
            self.event.flags.insert(EventFlags::HAS_DATA);
455
0
        }
456
457
0
        let mut writer = buffer.writer();
458
459
0
        let mut last_split = 0;
460
0
        for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
461
0
            let _ = writer.write_all(&buf[last_split..=delimiter]);
462
0
            let _ = writer.write_all(b"data: ");
463
0
            last_split = delimiter + 1;
464
0
        }
465
0
        let _ = writer.write_all(&buf[last_split..]);
466
467
0
        buf.len()
468
0
    }
469
}
470
471
impl fmt::Write for EventDataWriter {
472
0
    fn write_str(&mut self, s: &str) -> fmt::Result {
473
0
        let _ = self.write_buf(s.as_bytes());
474
0
        Ok(())
475
0
    }
476
}
477
478
impl Default for Event {
479
0
    fn default() -> Self {
480
0
        Self {
481
0
            buffer: Buffer::Active(BytesMut::new()),
482
0
            flags: EventFlags::from_bits(0),
483
0
        }
484
0
    }
485
}
486
487
/// Bit set recording which of the set-once fields an event already has.
#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self(1 << 0);
    const HAS_EVENT: Self = Self(1 << 1);
    const HAS_RETRY: Self = Self(1 << 2);
    const HAS_ID: Self = Self(1 << 3);

    /// Returns the raw bits.
    const fn bits(&self) -> u8 {
        let Self(raw) = *self;
        raw
    }

    /// Wraps raw bits without any validation.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
512
513
/// Configure the interval between keep-alive messages, the content
514
/// of each message, and the associated stream.
515
#[derive(Debug, Clone)]
516
#[must_use]
517
pub struct KeepAlive {
518
    event: Event,
519
    max_interval: Duration,
520
}
521
522
impl KeepAlive {
523
    /// Create a new `KeepAlive`.
524
0
    pub fn new() -> Self {
525
0
        Self {
526
0
            event: Event::DEFAULT_KEEP_ALIVE,
527
0
            max_interval: Duration::from_secs(15),
528
0
        }
529
0
    }
530
531
    /// Customize the interval between keep-alive messages.
532
    ///
533
    /// Default is 15 seconds.
534
0
    pub fn interval(mut self, time: Duration) -> Self {
535
0
        self.max_interval = time;
536
0
        self
537
0
    }
538
539
    /// Customize the text of the keep-alive message.
540
    ///
541
    /// Default is an empty comment.
542
    ///
543
    /// # Panics
544
    ///
545
    /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
546
    /// comments.
547
0
    pub fn text<I>(self, text: I) -> Self
548
0
    where
549
0
        I: AsRef<str>,
550
    {
551
0
        self.event(Event::default().comment(text))
552
0
    }
553
554
    /// Customize the event of the keep-alive message.
555
    ///
556
    /// Default is an empty comment.
557
    ///
558
    /// # Panics
559
    ///
560
    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
561
    /// comments.
562
0
    pub fn event(mut self, event: Event) -> Self {
563
0
        self.event = Event::finalized(event.finalize());
564
0
        self
565
0
    }
566
}
567
568
impl Default for KeepAlive {
569
0
    fn default() -> Self {
570
0
        Self::new()
571
0
    }
572
}
573
574
#[cfg(feature = "tokio")]
pin_project! {
    /// A wrapper around a stream that produces keep-alive events
    #[derive(Debug)]
    pub struct KeepAliveStream<S> {
        // Fires when a keep-alive message is due.
        #[pin]
        alive_timer: tokio::time::Sleep,
        // The wrapped event stream.
        #[pin]
        inner: S,
        // Supplies the keep-alive event and the interval between messages.
        keep_alive: KeepAlive,
    }
}
586
587
#[cfg(feature = "tokio")]
impl<S> KeepAliveStream<S> {
    /// Wraps `inner`, with the first keep-alive due one interval from now.
    fn new(keep_alive: KeepAlive, inner: S) -> Self {
        let alive_timer = tokio::time::sleep(keep_alive.max_interval);

        Self {
            alive_timer,
            inner,
            keep_alive,
        }
    }

    /// Restarts the timer so the next keep-alive is due one interval from now.
    fn reset(self: Pin<&mut Self>) {
        let this = self.project();

        let deadline = tokio::time::Instant::now() + this.keep_alive.max_interval;
        this.alive_timer.reset(deadline);
    }
}
603
604
#[cfg(feature = "tokio")]
impl<S, E> Stream for KeepAliveStream<S>
where
    S: Stream<Item = Result<Event, E>>,
{
    type Item = Result<Event, E>;

    /// Yields the inner stream's items. While the inner stream is pending, yields a copy of
    /// the keep-alive event each time the timer fires.
    ///
    /// Errors and the end of the inner stream are passed through without touching the timer.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::future::Future;

        let this = self.as_mut().project();

        let event = match this.inner.poll_next(cx) {
            Poll::Ready(Some(Ok(event))) => event,
            Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => {
                // Nothing from the inner stream: only produce output once the timer fires.
                ready!(this.alive_timer.poll(cx));
                this.keep_alive.event.clone()
            }
        };

        // Both a real event and a keep-alive restart the timer.
        self.reset();

        Poll::Ready(Some(Ok(event)))
    }
}
636
637
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    // Whitespace at the start of the data is written after the `data: ` prefix unchanged.
    #[test]
    fn leading_space_is_not_stripped() {
        let tab_prefixed = Event::default().data("\tfoobar").finalize();
        assert_eq!(&*tab_prefixed, b"data: \tfoobar\n\n");

        let space_prefixed = Event::default().data(" foobar").finalize();
        assert_eq!(&*space_prefixed, b"data:  foobar\n\n");
    }

    #[test]
    fn write_data_writer_str() {
        // also confirm that nop writers do nothing :)
        let mut nop_writer = Event::default()
            .into_data_writer()
            .into_event()
            .into_data_writer();
        nop_writer.write_str("").unwrap();

        let mut writer = nop_writer.into_event().into_data_writer();
        for chunk in ["", "moon ", "star\nsun", "", "set", "", " bye\r"] {
            writer.write_str(chunk).unwrap();
        }

        let bytes = writer.into_event().finalize();
        assert_eq!(
            &*bytes,
            b"data: moon star\ndata: sunset bye\rdata: \n\n"
        );
    }

    // Line breaks inside a raw JSON value are split across `data: ` fields as well.
    #[test]
    fn valid_json_raw_value_chars_handled() {
        let json_string = "{\r\"foo\":  \n\r\r   \"bar\\n\"\n}";
        let raw_value = serde_json::from_str::<&RawValue>(json_string).unwrap();

        let bytes = Event::default().json_data(raw_value).unwrap().finalize();
        assert_eq!(
            &*bytes,
            b"data: {\rdata: \"foo\":  \ndata: \rdata: \rdata:    \"bar\\n\"\ndata: }\n\n"
        );
    }

    #[crate::test]
    async fn basic() {
        let handler = || async {
            let events = vec![
                Event::default().data("one").comment("this is a comment"),
                Event::default()
                    .json_data(serde_json::json!({ "foo": "bar" }))
                    .unwrap(),
                Event::default()
                    .event("three")
                    .retry(Duration::from_secs(30))
                    .id("unique-id"),
            ];

            Sse::new(stream::iter(events).map(Ok::<_, Infallible>))
        };
        let app = Router::new().route("/", get(handler));

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let first = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(first.get("data").unwrap(), "one");
        assert_eq!(first.get("comment").unwrap(), "this is a comment");

        let second = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(second.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!second.contains_key("comment"));

        let third = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(third.get("event").unwrap(), "three");
        assert_eq!(third.get("retry").unwrap(), "30000");
        assert_eq!(third.get("id").unwrap(), "unique-id");
        assert!(!third.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let handler = || async {
            let events = stream::repeat_with(|| Event::default().data("msg"))
                .map(Ok::<_, Infallible>)
                .throttle(DELAY);
            let keep_alive = KeepAlive::new()
                .interval(Duration::from_secs(1))
                .text("keep-alive-text");

            Sse::new(events).keep_alive(keep_alive)
        };
        let app = Router::new().route("/", get(handler));

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let message = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(message.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let keep_alive = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(keep_alive.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let handler = || async {
            let events = stream::repeat_with(|| Event::default().data("msg"))
                .map(Ok::<_, Infallible>)
                .throttle(DELAY)
                .take(2);
            let keep_alive = KeepAlive::new()
                .interval(Duration::from_secs(1))
                .text("keep-alive-text");

            Sse::new(events).keep_alive(keep_alive)
        };
        let app = Router::new().route("/", get(handler));

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // first message should be an event
        let first = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(first.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let keep_alive = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(keep_alive.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let last = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(last.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    // Splits one event payload into its fields, keyed by field name with trimmed values.
    // A field with an empty name is stored under "comment". A blank line must be the last
    // line of the payload.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();
        let mut lines = payload.lines();

        while let Some(line) = lines.next() {
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (name, value) = line.split_once(':').unwrap();
            let key = if name.is_empty() { "comment" } else { name };
            fields.insert(key.to_owned(), value.trim().to_owned());
        }

        fields
    }
}