Skip to content

Commit f5e374a

Browse files
Unsafe fixes, smol executor added
1 parent aef441b commit f5e374a

File tree

13 files changed

+280
-62
lines changed

13 files changed

+280
-62
lines changed

Cargo.lock

Lines changed: 67 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ _unstable-all-types = [
7777
# Base runtime features without TLS
7878
runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor"]
7979
runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"]
80+
runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol"]
8081
runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio"]
8182

8283
# TLS features
@@ -98,12 +99,16 @@ runtime-async-global-executor-rustls = ["runtime-async-global-executor", "tls-ru
9899
runtime-async-std-native-tls = ["runtime-async-std", "tls-native-tls"]
99100
runtime-async-std-rustls = ["runtime-async-std", "tls-rustls-ring"]
100101

102+
runtime-smol-native-tls = ["runtime-smol", "tls-native-tls"]
103+
runtime-smol-rustls = ["runtime-smol", "tls-rustls-ring"]
104+
101105
runtime-tokio-native-tls = ["runtime-tokio", "tls-native-tls"]
102106
runtime-tokio-rustls = ["runtime-tokio", "tls-rustls-ring"]
103107

104108
# for conditional compilation
105109
_rt-async-global-executor = []
106110
_rt-async-std = []
111+
_rt-smol = []
107112
_rt-tokio = []
108113
_sqlite = []
109114

@@ -166,6 +171,10 @@ features = ["async-io"]
166171
[workspace.dependencies.async-std]
167172
version = "1.12"
168173

174+
[workspace.dependencies.smol]
175+
version = "2.0"
176+
default-features = false
177+
169178
[workspace.dependencies.tokio]
170179
version = "1"
171180
features = ["time", "net", "sync", "fs", "io-util", "rt"]

sqlx-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ _rt-async-global-executor = [
2525
"async-net",
2626
]
2727
_rt-async-std = ["async-std", "async-io-std"]
28+
_rt-smol = ["smol"]
2829
_rt-tokio = ["tokio", "tokio-stream"]
2930
_tls-native-tls = ["native-tls"]
3031
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
@@ -40,6 +41,7 @@ offline = ["serde", "either/serde"]
4041
# Runtimes
4142
async-global-executor = { workspace = true, optional = true }
4243
async-std = { workspace = true, optional = true }
44+
smol = { workspace = true, optional = true }
4345
tokio = { workspace = true, optional = true }
4446

4547
# TLS

sqlx-core/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
// The only unsafe code in SQLx is that necessary to interact with native APIs like with SQLite,
1919
// and that can live in its own separate driver crate.
2020
// temporary
21-
// #![forbid(unsafe_code)]
21+
#![forbid(unsafe_code)]
2222
// Allows an API be documented as only available in some specific platforms.
2323
// <https://doc.rust-lang.org/unstable-book/language-features/doc-cfg.html>
2424
#![cfg_attr(docsrs, feature(doc_cfg))]

sqlx-core/src/rt/mod.rs

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ pub mod rt_async_global_executor;
1010
#[cfg(feature = "_rt-async-std")]
1111
pub mod rt_async_std;
1212

13+
#[cfg(feature = "_rt-smol")]
14+
pub mod rt_smol;
15+
1316
#[cfg(feature = "_rt-tokio")]
1417
pub mod rt_tokio;
1518

@@ -22,6 +25,8 @@ pub enum JoinHandle<T> {
2225
AsyncGlobalExecutor(rt_async_global_executor::JoinHandle<T>),
2326
#[cfg(feature = "_rt-async-std")]
2427
AsyncStd(async_std::task::JoinHandle<T>),
28+
#[cfg(feature = "_rt-smol")]
29+
Smol(rt_smol::JoinHandle<T>),
2530
#[cfg(feature = "_rt-tokio")]
2631
Tokio(tokio::task::JoinHandle<T>),
2732
// `PhantomData<T>` requires `T: Unpin`
@@ -41,14 +46,23 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, T
4146
return rt_async_global_executor::timeout(duration, f).await;
4247
}
4348

49+
#[cfg(feature = "_rt-smol")]
50+
{
51+
return rt_smol::timeout(duration, f).await;
52+
}
53+
4454
#[cfg(feature = "_rt-async-std")]
4555
{
4656
return async_std::future::timeout(duration, f)
4757
.await
4858
.map_err(|_| TimeoutError);
4959
}
5060

51-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
61+
#[cfg(not(all(
62+
feature = "_rt-async-global-executor",
63+
feature = "_rt-async-std",
64+
feature = "_rt-smol"
65+
)))]
5266
#[allow(unreachable_code)]
5367
missing_rt((duration, f))
5468
}
@@ -64,12 +78,21 @@ pub async fn sleep(duration: Duration) {
6478
return rt_async_global_executor::sleep(duration).await;
6579
}
6680

81+
#[cfg(feature = "_rt-smol")]
82+
{
83+
return rt_smol::sleep(duration).await;
84+
}
85+
6786
#[cfg(feature = "_rt-async-std")]
6887
{
6988
return async_std::task::sleep(duration).await;
7089
}
7190

72-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
91+
#[cfg(not(all(
92+
feature = "_rt-async-global-executor",
93+
feature = "_rt-async-std",
94+
feature = "_rt-smol"
95+
)))]
7396
#[allow(unreachable_code)]
7497
missing_rt(duration)
7598
}
@@ -97,7 +120,11 @@ where
97120
return JoinHandle::AsyncStd(async_std::task::spawn(fut));
98121
}
99122

100-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
123+
#[cfg(not(all(
124+
feature = "_rt-async-global-executor",
125+
feature = "_rt-async-std",
126+
feature = "_rt-smol"
127+
)))]
101128
#[allow(unreachable_code)]
102129
missing_rt(fut)
103130
}
@@ -125,7 +152,18 @@ where
125152
return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f));
126153
}
127154

128-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
155+
#[cfg(feature = "_rt-smol")]
156+
{
157+
return JoinHandle::Smol(rt_smol::JoinHandle {
158+
task: Some(smol::unblock(f)),
159+
});
160+
}
161+
162+
#[cfg(not(all(
163+
feature = "_rt-async-global-executor",
164+
feature = "_rt-async-std",
165+
feature = "_rt-smol"
166+
)))]
129167
#[allow(unreachable_code)]
130168
missing_rt(f)
131169
}
@@ -146,7 +184,16 @@ pub async fn yield_now() {
146184
return async_std::task::yield_now().await;
147185
}
148186

149-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
187+
#[cfg(feature = "_rt-smol")]
188+
{
189+
return smol::future::yield_now().await;
190+
}
191+
192+
#[cfg(not(all(
193+
feature = "_rt-async-global-executor",
194+
feature = "_rt-async-std",
195+
feature = "_rt-smol"
196+
)))]
150197
#[allow(unreachable_code)]
151198
missing_rt(())
152199
}
@@ -155,11 +202,13 @@ pub async fn yield_now() {
155202
pub fn test_block_on<F: Future>(f: F) -> F::Output {
156203
#[cfg(feature = "_rt-tokio")]
157204
{
158-
return tokio::runtime::Builder::new_current_thread()
159-
.enable_all()
160-
.build()
161-
.expect("failed to start Tokio runtime")
162-
.block_on(f);
205+
if rt_tokio::available() {
206+
return tokio::runtime::Builder::new_current_thread()
207+
.enable_all()
208+
.build()
209+
.expect("failed to start Tokio runtime")
210+
.block_on(f);
211+
}
163212
}
164213

165214
#[cfg(feature = "_rt-async-global-executor")]
@@ -172,7 +221,16 @@ pub fn test_block_on<F: Future>(f: F) -> F::Output {
172221
return async_std::task::block_on(f);
173222
}
174223

175-
#[cfg(not(all(feature = "_rt-async-global-executor", feature = "_rt-async-std",)))]
224+
#[cfg(feature = "_rt-smol")]
225+
{
226+
return smol::block_on(f);
227+
}
228+
229+
#[cfg(not(all(
230+
feature = "_rt-async-global-executor",
231+
feature = "_rt-async-std",
232+
feature = "_rt-smol"
233+
)))]
176234
#[allow(unreachable_code)]
177235
missing_rt(f)
178236
}
@@ -183,7 +241,7 @@ pub fn missing_rt<T>(_unused: T) -> ! {
183241
panic!("this functionality requires a Tokio context")
184242
}
185243

186-
panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, or `runtime-tokio` feature must be enabled")
244+
panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, `runtime-smol`, or `runtime-tokio` feature must be enabled")
187245
}
188246

189247
impl<T: Send + 'static> Future for JoinHandle<T> {
@@ -196,6 +254,8 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
196254
Self::AsyncGlobalExecutor(handle) => Pin::new(handle).poll(cx),
197255
#[cfg(feature = "_rt-async-std")]
198256
Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
257+
#[cfg(feature = "_rt-smol")]
258+
Self::Smol(handle) => Pin::new(handle).poll(cx),
199259
#[cfg(feature = "_rt-tokio")]
200260
Self::Tokio(handle) => Pin::new(handle)
201261
.poll(cx)

sqlx-core/src/rt/rt_async_global_executor/socket.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ use async_io_global_executor::Async;
1111

1212
impl Socket for Async<TcpStream> {
1313
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
14-
unsafe { self.get_mut().read(buf.init_mut()) }
14+
self.get_ref().read(buf.init_mut())
1515
}
1616

1717
fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
18-
unsafe { self.get_mut().write(buf) }
18+
self.get_ref().write(buf)
1919
}
2020

2121
fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@@ -27,18 +27,18 @@ impl Socket for Async<TcpStream> {
2727
}
2828

2929
fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
30-
unsafe { Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) }
30+
Poll::Ready(self.get_ref().shutdown(Shutdown::Both))
3131
}
3232
}
3333

3434
#[cfg(unix)]
3535
impl Socket for Async<std::os::unix::net::UnixStream> {
3636
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
37-
unsafe { self.get_mut().read(buf.init_mut()) }
37+
self.get_ref().read(buf.init_mut())
3838
}
3939

4040
fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
41-
unsafe { self.get_mut().write(buf) }
41+
self.get_ref().write(buf)
4242
}
4343

4444
fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@@ -50,6 +50,6 @@ impl Socket for Async<std::os::unix::net::UnixStream> {
5050
}
5151

5252
fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
53-
unsafe { Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) }
53+
Poll::Ready(self.get_ref().shutdown(Shutdown::Both))
5454
}
5555
}

0 commit comments

Comments
 (0)