1
- use std:: { collections:: HashMap , net:: SocketAddr , sync:: Arc } ;
1
+ use std:: { collections:: HashMap , io , net:: SocketAddr , sync:: Arc , time :: Duration } ;
2
2
3
3
use axum:: {
4
4
Json , Router ,
5
5
extract:: { Query , State } ,
6
6
http:: StatusCode ,
7
7
response:: {
8
8
Response ,
9
- sse:: { Event , Sse } ,
9
+ sse:: { Event , KeepAlive , Sse } ,
10
10
} ,
11
11
routing:: { get, post} ,
12
12
} ;
13
- use futures:: { Sink , SinkExt , Stream , StreamExt } ;
14
- use tokio:: io;
13
+ use futures:: { Sink , SinkExt , Stream } ;
15
14
use tokio_stream:: wrappers:: ReceiverStream ;
16
15
use tokio_util:: sync:: { CancellationToken , PollSender } ;
17
16
use tracing:: Instrument ;
@@ -26,28 +25,33 @@ type TxStore =
26
25
Arc < tokio:: sync:: RwLock < HashMap < SessionId , tokio:: sync:: mpsc:: Sender < ClientJsonRpcMessage > > > > ;
27
26
pub type TransportReceiver = ReceiverStream < RxJsonRpcMessage < RoleServer > > ;
28
27
28
+ const DEFAULT_AUTO_PING_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
29
+
29
30
#[ derive( Clone ) ]
30
31
struct App {
31
32
txs : TxStore ,
32
33
transport_tx : tokio:: sync:: mpsc:: UnboundedSender < SseServerTransport > ,
33
34
post_path : Arc < str > ,
35
+ sse_ping_interval : Duration ,
34
36
}
35
37
36
38
impl App {
37
39
pub fn new (
38
40
post_path : String ,
41
+ sse_ping_interval : Duration ,
39
42
) -> (
40
43
Self ,
41
44
tokio:: sync:: mpsc:: UnboundedReceiver < SseServerTransport > ,
42
45
) {
43
- let ( transport_tx, tranport_rx ) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
46
+ let ( transport_tx, transport_rx ) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
44
47
(
45
48
Self {
46
49
txs : Default :: default ( ) ,
47
50
transport_tx,
48
51
post_path : post_path. into ( ) ,
52
+ sse_ping_interval,
49
53
} ,
50
- tranport_rx ,
54
+ transport_rx ,
51
55
)
52
56
}
53
57
}
@@ -87,7 +91,7 @@ async fn sse_handler(
87
91
) -> Result < Sse < impl Stream < Item = Result < Event , io:: Error > > > , Response < String > > {
88
92
let session = session_id ( ) ;
89
93
tracing:: info!( %session, "sse connection" ) ;
90
- use tokio_stream:: wrappers:: ReceiverStream ;
94
+ use tokio_stream:: { StreamExt , wrappers:: ReceiverStream } ;
91
95
use tokio_util:: sync:: PollSender ;
92
96
let ( from_client_tx, from_client_rx) = tokio:: sync:: mpsc:: channel ( 64 ) ;
93
97
let ( to_client_tx, to_client_rx) = tokio:: sync:: mpsc:: channel ( 64 ) ;
@@ -108,11 +112,12 @@ async fn sse_handler(
108
112
if transport_send_result. is_err ( ) {
109
113
tracing:: warn!( "send transport out error" ) ;
110
114
let mut response =
111
- Response :: new ( "fail to send out trasnport , it seems server is closed" . to_string ( ) ) ;
115
+ Response :: new ( "fail to send out transport , it seems server is closed" . to_string ( ) ) ;
112
116
* response. status_mut ( ) = StatusCode :: INTERNAL_SERVER_ERROR ;
113
117
return Err ( response) ;
114
118
}
115
119
let post_path = app. post_path . as_ref ( ) ;
120
+ let ping_interval = app. sse_ping_interval ;
116
121
let stream = futures:: stream:: once ( futures:: future:: ok (
117
122
Event :: default ( )
118
123
. event ( "endpoint" )
@@ -124,7 +129,7 @@ async fn sse_handler(
124
129
Err ( e) => Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ,
125
130
}
126
131
} ) ) ;
127
- Ok ( Sse :: new ( stream) )
132
+ Ok ( Sse :: new ( stream) . keep_alive ( KeepAlive :: new ( ) . interval ( ping_interval ) ) )
128
133
}
129
134
130
135
pub struct SseServerTransport {
@@ -190,6 +195,7 @@ impl Stream for SseServerTransport {
190
195
mut self : std:: pin:: Pin < & mut Self > ,
191
196
cx : & mut std:: task:: Context < ' _ > ,
192
197
) -> std:: task:: Poll < Option < Self :: Item > > {
198
+ use futures:: StreamExt ;
193
199
self . stream . poll_next_unpin ( cx)
194
200
}
195
201
}
@@ -200,6 +206,7 @@ pub struct SseServerConfig {
200
206
pub sse_path : String ,
201
207
pub post_path : String ,
202
208
pub ct : CancellationToken ,
209
+ pub sse_keep_alive : Option < Duration > ,
203
210
}
204
211
205
212
#[ derive( Debug ) ]
@@ -215,6 +222,7 @@ impl SseServer {
215
222
sse_path : "/sse" . to_string ( ) ,
216
223
post_path : "/message" . to_string ( ) ,
217
224
ct : CancellationToken :: new ( ) ,
225
+ sse_keep_alive : None ,
218
226
} )
219
227
. await
220
228
}
@@ -240,7 +248,10 @@ impl SseServer {
240
248
/// Warning: This function creates a new SseServer instance with the provided configuration.
241
249
/// `App.post_path` may be incorrect if using `Router` as an embedded router.
242
250
pub fn new ( config : SseServerConfig ) -> ( SseServer , Router ) {
243
- let ( app, transport_rx) = App :: new ( config. post_path . clone ( ) ) ;
251
+ let ( app, transport_rx) = App :: new (
252
+ config. post_path . clone ( ) ,
253
+ config. sse_keep_alive . unwrap_or ( DEFAULT_AUTO_PING_INTERVAL ) ,
254
+ ) ;
244
255
let router = Router :: new ( )
245
256
. route ( & config. sse_path , get ( sse_handler) )
246
257
. route ( & config. post_path , post ( post_event_handler) )
0 commit comments