Skip to content

Commit 923e87a

Browse files
authored
feat: Sse server auto ping (#74)
1. auto ping in sse stream every second to make cursor happy 2. configurable sse keep alive --------- Co-authored-by: = <=>
1 parent 588a013 commit 923e87a

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

crates/rmcp/src/transport/sse_server.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
1+
use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
22

33
use axum::{
44
Json, Router,
55
extract::{Query, State},
66
http::StatusCode,
77
response::{
88
Response,
9-
sse::{Event, Sse},
9+
sse::{Event, KeepAlive, Sse},
1010
},
1111
routing::{get, post},
1212
};
13-
use futures::{Sink, SinkExt, Stream, StreamExt};
14-
use tokio::io;
13+
use futures::{Sink, SinkExt, Stream};
1514
use tokio_stream::wrappers::ReceiverStream;
1615
use tokio_util::sync::{CancellationToken, PollSender};
1716
use tracing::Instrument;
@@ -26,28 +25,33 @@ type TxStore =
2625
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::mpsc::Sender<ClientJsonRpcMessage>>>>;
2726
pub type TransportReceiver = ReceiverStream<RxJsonRpcMessage<RoleServer>>;
2827

28+
const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15);
29+
2930
#[derive(Clone)]
3031
struct App {
3132
txs: TxStore,
3233
transport_tx: tokio::sync::mpsc::UnboundedSender<SseServerTransport>,
3334
post_path: Arc<str>,
35+
sse_ping_interval: Duration,
3436
}
3537

3638
impl App {
3739
pub fn new(
3840
post_path: String,
41+
sse_ping_interval: Duration,
3942
) -> (
4043
Self,
4144
tokio::sync::mpsc::UnboundedReceiver<SseServerTransport>,
4245
) {
43-
let (transport_tx, tranport_rx) = tokio::sync::mpsc::unbounded_channel();
46+
let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel();
4447
(
4548
Self {
4649
txs: Default::default(),
4750
transport_tx,
4851
post_path: post_path.into(),
52+
sse_ping_interval,
4953
},
50-
tranport_rx,
54+
transport_rx,
5155
)
5256
}
5357
}
@@ -87,7 +91,7 @@ async fn sse_handler(
8791
) -> Result<Sse<impl Stream<Item = Result<Event, io::Error>>>, Response<String>> {
8892
let session = session_id();
8993
tracing::info!(%session, "sse connection");
90-
use tokio_stream::wrappers::ReceiverStream;
94+
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
9195
use tokio_util::sync::PollSender;
9296
let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64);
9397
let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64);
@@ -108,11 +112,12 @@ async fn sse_handler(
108112
if transport_send_result.is_err() {
109113
tracing::warn!("send transport out error");
110114
let mut response =
111-
Response::new("fail to send out trasnport, it seems server is closed".to_string());
115+
Response::new("fail to send out transport, it seems server is closed".to_string());
112116
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
113117
return Err(response);
114118
}
115119
let post_path = app.post_path.as_ref();
120+
let ping_interval = app.sse_ping_interval;
116121
let stream = futures::stream::once(futures::future::ok(
117122
Event::default()
118123
.event("endpoint")
@@ -124,7 +129,7 @@ async fn sse_handler(
124129
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)),
125130
}
126131
}));
127-
Ok(Sse::new(stream))
132+
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(ping_interval)))
128133
}
129134

130135
pub struct SseServerTransport {
@@ -190,6 +195,7 @@ impl Stream for SseServerTransport {
190195
mut self: std::pin::Pin<&mut Self>,
191196
cx: &mut std::task::Context<'_>,
192197
) -> std::task::Poll<Option<Self::Item>> {
198+
use futures::StreamExt;
193199
self.stream.poll_next_unpin(cx)
194200
}
195201
}
@@ -200,6 +206,7 @@ pub struct SseServerConfig {
200206
pub sse_path: String,
201207
pub post_path: String,
202208
pub ct: CancellationToken,
209+
pub sse_keep_alive: Option<Duration>,
203210
}
204211

205212
#[derive(Debug)]
@@ -215,6 +222,7 @@ impl SseServer {
215222
sse_path: "/sse".to_string(),
216223
post_path: "/message".to_string(),
217224
ct: CancellationToken::new(),
225+
sse_keep_alive: None,
218226
})
219227
.await
220228
}
@@ -240,7 +248,10 @@ impl SseServer {
240248
/// Warning: This function creates a new SseServer instance with the provided configuration.
241249
/// `App.post_path` may be incorrect if using `Router` as an embedded router.
242250
pub fn new(config: SseServerConfig) -> (SseServer, Router) {
243-
let (app, transport_rx) = App::new(config.post_path.clone());
251+
let (app, transport_rx) = App::new(
252+
config.post_path.clone(),
253+
config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL),
254+
);
244255
let router = Router::new()
245256
.route(&config.sse_path, get(sse_handler))
246257
.route(&config.post_path, post(post_event_handler))

examples/servers/src/axum_router.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ async fn main() -> anyhow::Result<()> {
2424
sse_path: "/sse".to_string(),
2525
post_path: "/message".to_string(),
2626
ct: tokio_util::sync::CancellationToken::new(),
27+
sse_keep_alive: None,
2728
};
2829

2930
let (sse_server, router) = SseServer::new(config);

0 commit comments

Comments
 (0)