1- use std:: time:: Duration ;
1+ use std:: { sync :: Arc , time:: Duration } ;
22
33use crate :: {
44 error:: { ReceiveError , SendError , TypedReceiveError } ,
55 message:: { IncomingMessage , OutgoingMessage , PayloadTypeName , TypedIncomingMessage } ,
6- traits:: { CommunicationBackend , CryptoProvider , SessionRepository } ,
6+ traits:: {
7+ CommunicationBackend , CommunicationBackendReceiver , CryptoProvider , SessionRepository ,
8+ } ,
79} ;
810
911pub struct IpcClient < Crypto , Com , Ses >
@@ -17,45 +19,111 @@ where
1719 sessions : Ses ,
1820}
1921
22+ /// A subscription to receive messages over IPC.
23+ /// The subcription will start buffering messages after its creation and return them
24+ /// when receive() is called. Messages received before the subscription was created will not be
25+ /// returned.
26+ pub struct IpcClientSubscription < Crypto , Com , Ses >
27+ where
28+ Crypto : CryptoProvider < Com , Ses > ,
29+ Com : CommunicationBackend ,
30+ Ses : SessionRepository < Session = Crypto :: Session > ,
31+ {
32+ receiver : Com :: Receiver ,
33+ client : Arc < IpcClient < Crypto , Com , Ses > > ,
34+ topic : Option < String > ,
35+ }
36+
37+ /// A subscription to receive messages over IPC.
38+ /// The subcription will start buffering messages after its creation and return them
39+ /// when receive() is called. Messages received before the subscription was created will not be
40+ /// returned.
41+ pub struct IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
42+ where
43+ Crypto : CryptoProvider < Com , Ses > ,
44+ Com : CommunicationBackend ,
45+ Ses : SessionRepository < Session = Crypto :: Session > ,
46+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
47+ {
48+ receiver : Com :: Receiver ,
49+ client : Arc < IpcClient < Crypto , Com , Ses > > ,
50+ _payload : std:: marker:: PhantomData < Payload > ,
51+ }
52+
2053impl < Crypto , Com , Ses > IpcClient < Crypto , Com , Ses >
2154where
2255 Crypto : CryptoProvider < Com , Ses > ,
2356 Com : CommunicationBackend ,
2457 Ses : SessionRepository < Session = Crypto :: Session > ,
2558{
26- pub fn new ( crypto : Crypto , communication : Com , sessions : Ses ) -> Self {
27- Self {
59+ pub fn new ( crypto : Crypto , communication : Com , sessions : Ses ) -> Arc < Self > {
60+ Arc :: new ( Self {
2861 crypto,
2962 communication,
3063 sessions,
31- }
64+ } )
3265 }
3366
3467 /// Send a message
3568 pub async fn send (
36- & self ,
69+ self : & Arc < Self > ,
3770 message : OutgoingMessage ,
3871 ) -> Result < ( ) , SendError < Crypto :: SendError , Com :: SendError > > {
3972 self . crypto
4073 . send ( & self . communication , & self . sessions , message)
4174 . await
4275 }
4376
77+ /// Create a subscription to receive messages, optionally filtered by topic.
78+ /// Setting the topic to `None` will receive all messages.
79+ pub async fn subscribe (
80+ self : & Arc < Self > ,
81+ topic : Option < String > ,
82+ ) -> IpcClientSubscription < Crypto , Com , Ses > {
83+ IpcClientSubscription {
84+ receiver : self . communication . subscribe ( ) . await ,
85+ client : self . clone ( ) ,
86+ topic,
87+ }
88+ }
89+
90+ /// Create a subscription to receive messages that can be deserialized into the provided payload
91+ /// type.
92+ pub async fn subscribe_typed < Payload > (
93+ self : & Arc < Self > ,
94+ ) -> IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
95+ where
96+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
97+ {
98+ IpcClientTypedSubscription {
99+ receiver : self . communication . subscribe ( ) . await ,
100+ client : self . clone ( ) ,
101+ _payload : std:: marker:: PhantomData ,
102+ }
103+ }
104+
44105 /// Receive a message, optionally filtering by topic.
45106 /// Setting the topic to `None` will receive all messages.
46107 /// Setting the timeout to `None` will wait indefinitely.
47- pub async fn receive (
108+ async fn receive (
48109 & self ,
49- topic : Option < String > ,
110+ receiver : & Com :: Receiver ,
111+ topic : & Option < String > ,
50112 timeout : Option < Duration > ,
51- ) -> Result < IncomingMessage , ReceiveError < Crypto :: ReceiveError , Com :: ReceiveError > > {
113+ ) -> Result <
114+ IncomingMessage ,
115+ ReceiveError <
116+ Crypto :: ReceiveError ,
117+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
118+ > ,
119+ > {
52120 let receive_loop = async {
53121 loop {
54122 let received = self
55123 . crypto
56- . receive ( & self . communication , & self . sessions )
124+ . receive ( receiver , & self . communication , & self . sessions )
57125 . await ?;
58- if topic. is_none ( ) || received. topic == topic {
126+ if topic. is_none ( ) || & received. topic == topic {
59127 return Ok ( received) ;
60128 }
61129 }
@@ -72,26 +140,75 @@ where
72140
73141 /// Receive a message, skipping any messages that cannot be deserialized into the expected
74142 /// payload type.
75- pub async fn receive_typed < Payload > (
143+ async fn receive_typed < Payload > (
76144 & self ,
145+ receiver : & Com :: Receiver ,
77146 timeout : Option < Duration > ,
78147 ) -> Result <
79148 TypedIncomingMessage < Payload > ,
80149 TypedReceiveError <
81150 <Payload as TryFrom < Vec < u8 > > >:: Error ,
82151 Crypto :: ReceiveError ,
83- Com :: ReceiveError ,
152+ < Com :: Receiver as CommunicationBackendReceiver > :: ReceiveError ,
84153 > ,
85154 >
86155 where
87156 Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
88157 {
89158 let topic = Some ( Payload :: name ( ) ) ;
90- let received = self . receive ( topic, timeout) . await ?;
159+ let received = self . receive ( receiver , & topic, timeout) . await ?;
91160 received. try_into ( ) . map_err ( TypedReceiveError :: Typing )
92161 }
93162}
94163
164+ impl < Crypto , Com , Ses > IpcClientSubscription < Crypto , Com , Ses >
165+ where
166+ Crypto : CryptoProvider < Com , Ses > ,
167+ Com : CommunicationBackend ,
168+ Ses : SessionRepository < Session = Crypto :: Session > ,
169+ {
170+ /// Receive a message, optionally filtering by topic.
171+ /// Setting the timeout to `None` will wait indefinitely.
172+ pub async fn receive (
173+ & self ,
174+ timeout : Option < Duration > ,
175+ ) -> Result <
176+ IncomingMessage ,
177+ ReceiveError <
178+ Crypto :: ReceiveError ,
179+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
180+ > ,
181+ > {
182+ self . client
183+ . receive ( & self . receiver , & self . topic , timeout)
184+ . await
185+ }
186+ }
187+
188+ impl < Crypto , Com , Ses , Payload > IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
189+ where
190+ Crypto : CryptoProvider < Com , Ses > ,
191+ Com : CommunicationBackend ,
192+ Ses : SessionRepository < Session = Crypto :: Session > ,
193+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
194+ {
195+ /// Receive a message.
196+ /// Setting the timeout to `None` will wait indefinitely.
197+ pub async fn receive (
198+ & self ,
199+ timeout : Option < Duration > ,
200+ ) -> Result <
201+ TypedIncomingMessage < Payload > ,
202+ TypedReceiveError <
203+ <Payload as TryFrom < Vec < u8 > > >:: Error ,
204+ Crypto :: ReceiveError ,
205+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
206+ > ,
207+ > {
208+ self . client . receive_typed ( & self . receiver , timeout) . await
209+ }
210+ }
211+
95212#[ cfg( test) ]
96213mod tests {
97214 use std:: collections:: HashMap ;
@@ -121,6 +238,7 @@ mod tests {
121238
122239 async fn receive (
123240 & self ,
241+ _receiver : & <TestCommunicationBackend as CommunicationBackend >:: Receiver ,
124242 _communication : & TestCommunicationBackend ,
125243 _sessions : & TestSessionRepository ,
126244 ) -> Result < IncomingMessage , ReceiveError < String , TestCommunicationBackendReceiveError > >
@@ -176,7 +294,8 @@ mod tests {
176294 let session_map = TestSessionRepository :: new ( HashMap :: new ( ) ) ;
177295 let client = IpcClient :: new ( crypto_provider, communication_provider, session_map) ;
178296
179- let error = client. receive ( None , None ) . await . unwrap_err ( ) ;
297+ let subscription = client. subscribe ( None ) . await ;
298+ let error = subscription. receive ( None ) . await . unwrap_err ( ) ;
180299
181300 assert_eq ! ( error, ReceiveError :: Crypto ( "Crypto error" . to_string( ) ) ) ;
182301 }
@@ -212,8 +331,9 @@ mod tests {
212331 let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
213332 let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
214333
215- communication_provider. push_incoming ( message. clone ( ) ) . await ;
216- let received_message = client. receive ( None , None ) . await . unwrap ( ) ;
334+ let subscription = & client. subscribe ( None ) . await ;
335+ communication_provider. push_incoming ( message. clone ( ) ) ;
336+ let received_message = subscription. receive ( None ) . await . unwrap ( ) ;
217337
218338 assert_eq ! ( received_message, message) ;
219339 }
@@ -237,20 +357,12 @@ mod tests {
237357 let communication_provider = TestCommunicationBackend :: new ( ) ;
238358 let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
239359 let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
240- communication_provider
241- . push_incoming ( non_matching_message. clone ( ) )
242- . await ;
243- communication_provider
244- . push_incoming ( non_matching_message. clone ( ) )
245- . await ;
246- communication_provider
247- . push_incoming ( matching_message. clone ( ) )
248- . await ;
249-
250- let received_message: IncomingMessage = client
251- . receive ( Some ( "matching_topic" . to_owned ( ) ) , None )
252- . await
253- . unwrap ( ) ;
360+ let subscription = client. subscribe ( Some ( "matching_topic" . to_owned ( ) ) ) . await ;
361+ communication_provider. push_incoming ( non_matching_message. clone ( ) ) ;
362+ communication_provider. push_incoming ( non_matching_message. clone ( ) ) ;
363+ communication_provider. push_incoming ( matching_message. clone ( ) ) ;
364+
365+ let received_message: IncomingMessage = subscription. receive ( None ) . await . unwrap ( ) ;
254366
255367 assert_eq ! ( received_message, matching_message) ;
256368 }
@@ -302,18 +414,12 @@ mod tests {
302414 let communication_provider = TestCommunicationBackend :: new ( ) ;
303415 let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
304416 let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
305- communication_provider
306- . push_incoming ( unrelated. clone ( ) )
307- . await ;
308- communication_provider
309- . push_incoming ( unrelated. clone ( ) )
310- . await ;
311- communication_provider
312- . push_incoming ( typed_message. clone ( ) . try_into ( ) . unwrap ( ) )
313- . await ;
314-
315- let received_message: TypedIncomingMessage < TestPayload > =
316- client. receive_typed ( None ) . await . unwrap ( ) ;
417+ let subscription = client. subscribe_typed :: < TestPayload > ( ) . await ;
418+ communication_provider. push_incoming ( unrelated. clone ( ) ) ;
419+ communication_provider. push_incoming ( unrelated. clone ( ) ) ;
420+ communication_provider. push_incoming ( typed_message. clone ( ) . try_into ( ) . unwrap ( ) ) ;
421+
422+ let received_message = subscription. receive ( None ) . await . unwrap ( ) ;
317423
318424 assert_eq ! ( received_message, typed_message) ;
319425 }
@@ -358,11 +464,10 @@ mod tests {
358464 let communication_provider = TestCommunicationBackend :: new ( ) ;
359465 let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
360466 let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
361- communication_provider
362- . push_incoming ( non_deserializable_message. clone ( ) )
363- . await ;
467+ let subscription = client. subscribe_typed :: < TestPayload > ( ) . await ;
468+ communication_provider. push_incoming ( non_deserializable_message. clone ( ) ) ;
364469
365- let result: Result < TypedIncomingMessage < TestPayload > , _ > = client . receive_typed ( None ) . await ;
470+ let result = subscription . receive ( None ) . await ;
366471
367472 assert ! ( matches!(
368473 result,
0 commit comments