|
| 1 | +//! This module contains an wrapper type |
| 2 | +//! that provides a [`diesel::Connection`] |
| 3 | +//! implementation for types that implement |
| 4 | +//! [`crate::AsyncConnection`]. Using this type |
| 5 | +//! might be useful for the following usecases: |
| 6 | +//! |
| 7 | +//! * Executing migrations on application startup |
| 8 | +//! * Using a pure rust diesel connection implementation |
| 9 | +//! as replacement for the existing connection |
| 10 | +//! implementations provided by diesel |
| 11 | +
|
| 12 | +use futures_util::Future; |
| 13 | +use futures_util::Stream; |
| 14 | +use futures_util::StreamExt; |
| 15 | +use std::pin::Pin; |
| 16 | + |
| 17 | +/// This is a helper trait that allows to customize the |
| 18 | +/// async runtime used to execute futures as part of the |
| 19 | +/// [`AsyncConnectionWrapper`] type. By default a |
| 20 | +/// tokio runtime is used. |
| 21 | +pub trait BlockOn { |
| 22 | + /// This function should allow to execute a |
| 23 | + /// given future to get the result |
| 24 | + fn block_on<F>(&self, f: F) -> F::Output |
| 25 | + where |
| 26 | + F: Future; |
| 27 | + |
| 28 | + /// This function should be used to construct |
| 29 | + /// a new runtime instance |
| 30 | + fn get_runtime() -> Self; |
| 31 | +} |
| 32 | + |
| 33 | +/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to |
| 34 | +/// provide a sync [`diesel::Connection`] implementation. |
| 35 | +/// |
| 36 | +/// Internally this wrapper type will use `block_on` to wait for |
| 37 | +/// the execution of futures from the inner connection. This implies you |
| 38 | +/// cannot use functions of this type in a scope with an already existing |
| 39 | +/// tokio runtime. If you are in a situation where you want to use this |
| 40 | +/// connection wrapper in the scope of an existing tokio runtime (for example |
| 41 | +/// for running migrations via `diesel_migration`) you need to wrap |
| 42 | +/// the relevant code block into a `tokio::task::spawn_blocking` task. |
| 43 | +/// |
| 44 | +/// # Examples |
| 45 | +/// |
| 46 | +/// ```rust |
| 47 | +/// # include!("doctest_setup.rs"); |
| 48 | +/// use schema::users; |
| 49 | +/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; |
| 50 | +/// # |
| 51 | +/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { |
| 52 | +/// use diesel::prelude::{RunQueryDsl, Connection}; |
| 53 | +/// # let database_url = database_url(); |
| 54 | +/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?; |
| 55 | +/// |
| 56 | +/// let all_users = users::table.load::<(i32, String)>(&mut conn)?; |
| 57 | +/// # assert_eq!(all_users.len(), 0); |
| 58 | +/// # Ok(()) |
| 59 | +/// # } |
| 60 | +/// ``` |
| 61 | +/// |
| 62 | +/// If you are in the scope of an existing tokio runtime you need to use |
| 63 | +/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks |
| 64 | +/// ```rust |
| 65 | +/// # include!("doctest_setup.rs"); |
| 66 | +/// use schema::users; |
| 67 | +/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; |
| 68 | +/// |
| 69 | +/// async fn some_async_fn() { |
| 70 | +/// # let database_url = database_url(); |
| 71 | +/// // need to use `spawn_blocking` to execute |
| 72 | +/// // a blocking task in the scope of an existing runtime |
| 73 | +/// let res = tokio::task::spawn_blocking(move || { |
| 74 | +/// use diesel::prelude::{RunQueryDsl, Connection}; |
| 75 | +/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?; |
| 76 | +/// |
| 77 | +/// let all_users = users::table.load::<(i32, String)>(&mut conn)?; |
| 78 | +/// # assert_eq!(all_users.len(), 0); |
| 79 | +/// Ok::<_, Box<dyn std::error::Error + Send + Sync>>(()) |
| 80 | +/// }).await; |
| 81 | +/// |
| 82 | +/// # res.unwrap().unwrap(); |
| 83 | +/// } |
| 84 | +/// |
| 85 | +/// # #[tokio::main] |
| 86 | +/// # async fn main() { |
| 87 | +/// # some_async_fn().await; |
| 88 | +/// # } |
| 89 | +/// ``` |
| 90 | +#[cfg(feature = "tokio")] |
| 91 | +pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> = |
| 92 | + self::implementation::AsyncConnectionWrapper<C, B>; |
| 93 | + |
| 94 | +/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to |
| 95 | +/// provide a sync [`diesel::Connection`] implementation. |
| 96 | +/// |
| 97 | +/// Internally this wrapper type will use `block_on` to wait for |
| 98 | +/// the execution of futures from the inner connection. |
| 99 | +#[cfg(not(feature = "tokio"))] |
| 100 | +pub use self::implementation::AsyncConnectionWrapper; |
| 101 | + |
| 102 | +mod implementation { |
| 103 | + use super::*; |
| 104 | + |
| 105 | + pub struct AsyncConnectionWrapper<C, B> { |
| 106 | + inner: C, |
| 107 | + runtime: B, |
| 108 | + } |
| 109 | + |
| 110 | + impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B> |
| 111 | + where |
| 112 | + C: crate::SimpleAsyncConnection, |
| 113 | + B: BlockOn, |
| 114 | + { |
| 115 | + fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> { |
| 116 | + let f = self.inner.batch_execute(query); |
| 117 | + self.runtime.block_on(f) |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {} |
| 122 | + |
| 123 | + impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B> |
| 124 | + where |
| 125 | + C: crate::AsyncConnection, |
| 126 | + B: BlockOn + Send, |
| 127 | + { |
| 128 | + type Backend = C::Backend; |
| 129 | + |
| 130 | + type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper; |
| 131 | + |
| 132 | + fn establish(database_url: &str) -> diesel::ConnectionResult<Self> { |
| 133 | + let runtime = B::get_runtime(); |
| 134 | + let f = C::establish(database_url); |
| 135 | + let inner = runtime.block_on(f)?; |
| 136 | + Ok(Self { inner, runtime }) |
| 137 | + } |
| 138 | + |
| 139 | + fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize> |
| 140 | + where |
| 141 | + T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId, |
| 142 | + { |
| 143 | + let f = self.inner.execute_returning_count(source); |
| 144 | + self.runtime.block_on(f) |
| 145 | + } |
| 146 | + |
| 147 | + fn transaction_state( |
| 148 | + &mut self, |
| 149 | + ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{ |
| 150 | + self.inner.transaction_state() |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B> |
| 155 | + where |
| 156 | + C: crate::AsyncConnection, |
| 157 | + B: BlockOn + Send, |
| 158 | + { |
| 159 | + type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B> |
| 160 | + where |
| 161 | + Self: 'conn; |
| 162 | + |
| 163 | + type Row<'conn, 'query> = C::Row<'conn, 'query> |
| 164 | + where |
| 165 | + Self: 'conn; |
| 166 | + |
| 167 | + fn load<'conn, 'query, T>( |
| 168 | + &'conn mut self, |
| 169 | + source: T, |
| 170 | + ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>> |
| 171 | + where |
| 172 | + T: diesel::query_builder::Query |
| 173 | + + diesel::query_builder::QueryFragment<Self::Backend> |
| 174 | + + diesel::query_builder::QueryId |
| 175 | + + 'query, |
| 176 | + Self::Backend: diesel::expression::QueryMetadata<T::SqlType>, |
| 177 | + { |
| 178 | + let f = self.inner.load(source); |
| 179 | + let stream = self.runtime.block_on(f)?; |
| 180 | + |
| 181 | + Ok(AsyncCursorWrapper { |
| 182 | + stream: Box::pin(stream), |
| 183 | + runtime: &self.runtime, |
| 184 | + }) |
| 185 | + } |
| 186 | + } |
| 187 | + |
| 188 | + pub struct AsyncCursorWrapper<'a, S, B> { |
| 189 | + stream: Pin<Box<S>>, |
| 190 | + runtime: &'a B, |
| 191 | + } |
| 192 | + |
| 193 | + impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B> |
| 194 | + where |
| 195 | + S: Stream, |
| 196 | + B: BlockOn, |
| 197 | + { |
| 198 | + type Item = S::Item; |
| 199 | + |
| 200 | + fn next(&mut self) -> Option<Self::Item> { |
| 201 | + let f = self.stream.next(); |
| 202 | + self.runtime.block_on(f) |
| 203 | + } |
| 204 | + } |
| 205 | + |
| 206 | + pub struct AsyncConnectionWrapperTransactionManagerWrapper; |
| 207 | + |
| 208 | + impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>> |
| 209 | + for AsyncConnectionWrapperTransactionManagerWrapper |
| 210 | + where |
| 211 | + C: crate::AsyncConnection, |
| 212 | + B: BlockOn + Send, |
| 213 | + { |
| 214 | + type TransactionStateData = |
| 215 | + <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData; |
| 216 | + |
| 217 | + fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> { |
| 218 | + let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction( |
| 219 | + &mut conn.inner, |
| 220 | + ); |
| 221 | + conn.runtime.block_on(f) |
| 222 | + } |
| 223 | + |
| 224 | + fn rollback_transaction( |
| 225 | + conn: &mut AsyncConnectionWrapper<C, B>, |
| 226 | + ) -> diesel::QueryResult<()> { |
| 227 | + let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction( |
| 228 | + &mut conn.inner, |
| 229 | + ); |
| 230 | + conn.runtime.block_on(f) |
| 231 | + } |
| 232 | + |
| 233 | + fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> { |
| 234 | + let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction( |
| 235 | + &mut conn.inner, |
| 236 | + ); |
| 237 | + conn.runtime.block_on(f) |
| 238 | + } |
| 239 | + |
| 240 | + fn transaction_manager_status_mut( |
| 241 | + conn: &mut AsyncConnectionWrapper<C, B>, |
| 242 | + ) -> &mut diesel::connection::TransactionManagerStatus { |
| 243 | + <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut( |
| 244 | + &mut conn.inner, |
| 245 | + ) |
| 246 | + } |
| 247 | + |
| 248 | + fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool { |
| 249 | + <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager( |
| 250 | + &mut conn.inner, |
| 251 | + ) |
| 252 | + } |
| 253 | + } |
| 254 | + |
| 255 | + #[cfg(feature = "r2d2")] |
| 256 | + impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B> |
| 257 | + where |
| 258 | + B: BlockOn, |
| 259 | + Self: diesel::Connection, |
| 260 | + C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend> |
| 261 | + + crate::pooled_connection::PoolableConnection, |
| 262 | + { |
| 263 | + fn ping(&mut self) -> diesel::QueryResult<()> { |
| 264 | + diesel::Connection::execute_returning_count(self, &C::make_ping_query()).map(|_| ()) |
| 265 | + } |
| 266 | + |
| 267 | + fn is_broken(&mut self) -> bool { |
| 268 | + <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager( |
| 269 | + &mut self.inner, |
| 270 | + ) |
| 271 | + } |
| 272 | + } |
| 273 | + |
| 274 | + #[cfg(feature = "tokio")] |
| 275 | + pub struct Tokio { |
| 276 | + handle: Option<tokio::runtime::Handle>, |
| 277 | + runtime: Option<tokio::runtime::Runtime>, |
| 278 | + } |
| 279 | + |
| 280 | + #[cfg(feature = "tokio")] |
| 281 | + impl BlockOn for Tokio { |
| 282 | + fn block_on<F>(&self, f: F) -> F::Output |
| 283 | + where |
| 284 | + F: Future, |
| 285 | + { |
| 286 | + if let Some(handle) = &self.handle { |
| 287 | + handle.block_on(f) |
| 288 | + } else if let Some(runtime) = &self.runtime { |
| 289 | + runtime.block_on(f) |
| 290 | + } else { |
| 291 | + unreachable!() |
| 292 | + } |
| 293 | + } |
| 294 | + |
| 295 | + fn get_runtime() -> Self { |
| 296 | + if let Ok(handle) = tokio::runtime::Handle::try_current() { |
| 297 | + Self { |
| 298 | + handle: Some(handle), |
| 299 | + runtime: None, |
| 300 | + } |
| 301 | + } else { |
| 302 | + let runtime = tokio::runtime::Builder::new_current_thread() |
| 303 | + .enable_io() |
| 304 | + .build() |
| 305 | + .unwrap(); |
| 306 | + Self { |
| 307 | + handle: None, |
| 308 | + runtime: Some(runtime), |
| 309 | + } |
| 310 | + } |
| 311 | + } |
| 312 | + } |
| 313 | +} |
0 commit comments