Skip to content

Commit 7884a32

Browse files
authored
Merge pull request #100 from weiznich/async_to_sync_connection_wrapper
Introduce an `AsyncConnectionWrapper` type
2 parents 5ba4375 + 4954cff commit 7884a32

File tree

17 files changed

+884
-464
lines changed

17 files changed

+884
-464
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All user visible changes to this project will be documented in this file.
44
This project adheres to [Semantic Versioning](http://semver.org/), as described
55
for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md)
66

7+
## Unreleased
8+
9+
* Add a `AsyncConnectionWrapper` type to turn a `diesel_async::AsyncConnection` into a `diesel::Connection`. This might be used to execute migrations via `diesel_migrations`.
10+
711
## [0.3.2] - 2023-07-24
812

913
* Fix `TinyInt` serialization
@@ -52,5 +56,3 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
5256
[0.2.1]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.2.1
5357
[0.2.2]: https://github.com/weiznich/diesel_async/compare/v0.2.1...v0.2.2
5458
[0.3.0]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.3.0
55-
[0.3.1]: https://github.com/weiznich/diesel_async/compare/v0.3.0...v0.3.1
56-
[0.3.2]: https://github.com/weiznich/diesel_async/compare/v0.3.1...v0.3.2

Cargo.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
1313
rust-version = "1.65.0"
1414

1515
[dependencies]
16-
diesel = { version = "~2.1.0", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
16+
diesel = { version = "~2.1.1", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
1717
async-trait = "0.1.66"
1818
futures-channel = { version = "0.3.17", default-features = false, features = ["std", "sink"], optional = true }
1919
futures-util = { version = "0.3.17", default-features = false, features = ["std", "sink"] }
20-
tokio-postgres = { version = "0.7.2", optional = true}
20+
tokio-postgres = { version = "0.7.10", optional = true}
2121
tokio = { version = "1.26", optional = true}
2222
mysql_async = { version = ">=0.30.0,<0.33", optional = true}
2323
mysql_common = {version = ">=0.29.0,<0.31.0", optional = true}
@@ -31,12 +31,14 @@ scoped-futures = {version = "0.1", features = ["std"]}
3131
tokio = {version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"]}
3232
cfg-if = "1"
3333
chrono = "0.4"
34-
diesel = { version = "2.0.0", default-features = false, features = ["chrono"]}
34+
diesel = { version = "2.1.0", default-features = false, features = ["chrono"]}
3535

3636
[features]
3737
default = []
38-
mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel"]
38+
mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel", "tokio"]
3939
postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/rt"]
40+
async-connection-wrapper = []
41+
r2d2 = ["diesel/r2d2"]
4042

4143
[[test]]
4244
name = "integration_tests"
@@ -54,3 +56,4 @@ members = [
5456
".",
5557
"examples/postgres/pooled-with-rustls"
5658
]
59+

src/async_connection_wrapper.rs

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
//! This module contains an wrapper type
2+
//! that provides a [`diesel::Connection`]
3+
//! implementation for types that implement
4+
//! [`crate::AsyncConnection`]. Using this type
5+
//! might be useful for the following usecases:
6+
//!
7+
//! * Executing migrations on application startup
8+
//! * Using a pure rust diesel connection implementation
9+
//! as replacement for the existing connection
10+
//! implementations provided by diesel
11+
12+
use futures_util::Future;
13+
use futures_util::Stream;
14+
use futures_util::StreamExt;
15+
use std::pin::Pin;
16+
17+
/// This is a helper trait that allows to customize the
18+
/// async runtime used to execute futures as part of the
19+
/// [`AsyncConnectionWrapper`] type. By default a
20+
/// tokio runtime is used.
21+
pub trait BlockOn {
22+
/// This function should allow to execute a
23+
/// given future to get the result
24+
fn block_on<F>(&self, f: F) -> F::Output
25+
where
26+
F: Future;
27+
28+
/// This function should be used to construct
29+
/// a new runtime instance
30+
fn get_runtime() -> Self;
31+
}
32+
33+
/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
34+
/// provide a sync [`diesel::Connection`] implementation.
35+
///
36+
/// Internally this wrapper type will use `block_on` to wait for
37+
/// the execution of futures from the inner connection. This implies you
38+
/// cannot use functions of this type in a scope with an already existing
39+
/// tokio runtime. If you are in a situation where you want to use this
40+
/// connection wrapper in the scope of an existing tokio runtime (for example
41+
/// for running migrations via `diesel_migration`) you need to wrap
42+
/// the relevant code block into a `tokio::task::spawn_blocking` task.
43+
///
44+
/// # Examples
45+
///
46+
/// ```rust
47+
/// # include!("doctest_setup.rs");
48+
/// use schema::users;
49+
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
50+
/// #
51+
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
52+
/// use diesel::prelude::{RunQueryDsl, Connection};
53+
/// # let database_url = database_url();
54+
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
55+
///
56+
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
57+
/// # assert_eq!(all_users.len(), 0);
58+
/// # Ok(())
59+
/// # }
60+
/// ```
61+
///
62+
/// If you are in the scope of an existing tokio runtime you need to use
63+
/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks
64+
/// ```rust
65+
/// # include!("doctest_setup.rs");
66+
/// use schema::users;
67+
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
68+
///
69+
/// async fn some_async_fn() {
70+
/// # let database_url = database_url();
71+
/// // need to use `spawn_blocking` to execute
72+
/// // a blocking task in the scope of an existing runtime
73+
/// let res = tokio::task::spawn_blocking(move || {
74+
/// use diesel::prelude::{RunQueryDsl, Connection};
75+
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
76+
///
77+
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
78+
/// # assert_eq!(all_users.len(), 0);
79+
/// Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
80+
/// }).await;
81+
///
82+
/// # res.unwrap().unwrap();
83+
/// }
84+
///
85+
/// # #[tokio::main]
86+
/// # async fn main() {
87+
/// # some_async_fn().await;
88+
/// # }
89+
/// ```
90+
#[cfg(feature = "tokio")]
91+
pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
92+
self::implementation::AsyncConnectionWrapper<C, B>;
93+
94+
/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
95+
/// provide a sync [`diesel::Connection`] implementation.
96+
///
97+
/// Internally this wrapper type will use `block_on` to wait for
98+
/// the execution of futures from the inner connection.
99+
#[cfg(not(feature = "tokio"))]
100+
pub use self::implementation::AsyncConnectionWrapper;
101+
102+
mod implementation {
103+
use super::*;
104+
105+
pub struct AsyncConnectionWrapper<C, B> {
106+
inner: C,
107+
runtime: B,
108+
}
109+
110+
impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
111+
where
112+
C: crate::SimpleAsyncConnection,
113+
B: BlockOn,
114+
{
115+
fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
116+
let f = self.inner.batch_execute(query);
117+
self.runtime.block_on(f)
118+
}
119+
}
120+
121+
impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
122+
123+
impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
124+
where
125+
C: crate::AsyncConnection,
126+
B: BlockOn + Send,
127+
{
128+
type Backend = C::Backend;
129+
130+
type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
131+
132+
fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
133+
let runtime = B::get_runtime();
134+
let f = C::establish(database_url);
135+
let inner = runtime.block_on(f)?;
136+
Ok(Self { inner, runtime })
137+
}
138+
139+
fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
140+
where
141+
T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
142+
{
143+
let f = self.inner.execute_returning_count(source);
144+
self.runtime.block_on(f)
145+
}
146+
147+
fn transaction_state(
148+
&mut self,
149+
) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
150+
self.inner.transaction_state()
151+
}
152+
}
153+
154+
impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
155+
where
156+
C: crate::AsyncConnection,
157+
B: BlockOn + Send,
158+
{
159+
type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
160+
where
161+
Self: 'conn;
162+
163+
type Row<'conn, 'query> = C::Row<'conn, 'query>
164+
where
165+
Self: 'conn;
166+
167+
fn load<'conn, 'query, T>(
168+
&'conn mut self,
169+
source: T,
170+
) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
171+
where
172+
T: diesel::query_builder::Query
173+
+ diesel::query_builder::QueryFragment<Self::Backend>
174+
+ diesel::query_builder::QueryId
175+
+ 'query,
176+
Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
177+
{
178+
let f = self.inner.load(source);
179+
let stream = self.runtime.block_on(f)?;
180+
181+
Ok(AsyncCursorWrapper {
182+
stream: Box::pin(stream),
183+
runtime: &self.runtime,
184+
})
185+
}
186+
}
187+
188+
pub struct AsyncCursorWrapper<'a, S, B> {
189+
stream: Pin<Box<S>>,
190+
runtime: &'a B,
191+
}
192+
193+
impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B>
194+
where
195+
S: Stream,
196+
B: BlockOn,
197+
{
198+
type Item = S::Item;
199+
200+
fn next(&mut self) -> Option<Self::Item> {
201+
let f = self.stream.next();
202+
self.runtime.block_on(f)
203+
}
204+
}
205+
206+
pub struct AsyncConnectionWrapperTransactionManagerWrapper;
207+
208+
impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
209+
for AsyncConnectionWrapperTransactionManagerWrapper
210+
where
211+
C: crate::AsyncConnection,
212+
B: BlockOn + Send,
213+
{
214+
type TransactionStateData =
215+
<C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
216+
217+
fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
218+
let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
219+
&mut conn.inner,
220+
);
221+
conn.runtime.block_on(f)
222+
}
223+
224+
fn rollback_transaction(
225+
conn: &mut AsyncConnectionWrapper<C, B>,
226+
) -> diesel::QueryResult<()> {
227+
let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
228+
&mut conn.inner,
229+
);
230+
conn.runtime.block_on(f)
231+
}
232+
233+
fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
234+
let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
235+
&mut conn.inner,
236+
);
237+
conn.runtime.block_on(f)
238+
}
239+
240+
fn transaction_manager_status_mut(
241+
conn: &mut AsyncConnectionWrapper<C, B>,
242+
) -> &mut diesel::connection::TransactionManagerStatus {
243+
<C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
244+
&mut conn.inner,
245+
)
246+
}
247+
248+
fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
249+
<C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
250+
&mut conn.inner,
251+
)
252+
}
253+
}
254+
255+
#[cfg(feature = "r2d2")]
256+
impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
257+
where
258+
B: BlockOn,
259+
Self: diesel::Connection,
260+
C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
261+
+ crate::pooled_connection::PoolableConnection,
262+
{
263+
fn ping(&mut self) -> diesel::QueryResult<()> {
264+
diesel::Connection::execute_returning_count(self, &C::make_ping_query()).map(|_| ())
265+
}
266+
267+
fn is_broken(&mut self) -> bool {
268+
<C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
269+
&mut self.inner,
270+
)
271+
}
272+
}
273+
274+
#[cfg(feature = "tokio")]
275+
pub struct Tokio {
276+
handle: Option<tokio::runtime::Handle>,
277+
runtime: Option<tokio::runtime::Runtime>,
278+
}
279+
280+
#[cfg(feature = "tokio")]
281+
impl BlockOn for Tokio {
282+
fn block_on<F>(&self, f: F) -> F::Output
283+
where
284+
F: Future,
285+
{
286+
if let Some(handle) = &self.handle {
287+
handle.block_on(f)
288+
} else if let Some(runtime) = &self.runtime {
289+
runtime.block_on(f)
290+
} else {
291+
unreachable!()
292+
}
293+
}
294+
295+
fn get_runtime() -> Self {
296+
if let Ok(handle) = tokio::runtime::Handle::try_current() {
297+
Self {
298+
handle: Some(handle),
299+
runtime: None,
300+
}
301+
} else {
302+
let runtime = tokio::runtime::Builder::new_current_thread()
303+
.enable_io()
304+
.build()
305+
.unwrap();
306+
Self {
307+
handle: None,
308+
runtime: Some(runtime),
309+
}
310+
}
311+
}
312+
}
313+
}

0 commit comments

Comments
 (0)