Skip to content

Commit 1eb33b9

Browse files
Make generated rust async futures Send (#1405)
The futures for generated rust async bindings were not always Send, because the LoweredParams held a non-Send `*mut u8`. To fix this, explicitly implement Send for LoweredParams. This makes the futures compatible with more of the Rust ecosystem, eg. an axum webserver where all handlers must be Send. This is safe, at least for now, because the generated code is single threaded. To make LoweredParams Send, change the type from a tuple to a struct which we can explicitly mark as Send. This also requires moving it up a little bit in the generated code outside of the trait impl. Add a test that asserts the generated binding is Send.
1 parent 65ee505 commit 1eb33b9

File tree

4 files changed

+74
-12
lines changed

4 files changed

+74
-12
lines changed

crates/rust/src/interface.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,32 @@ pub mod vtable{ordinal} {{
812812
let sig = self
813813
.resolve
814814
.wasm_signature(AbiVariant::GuestImportAsync, func);
815+
816+
// Generate `type ParamsLower`
817+
//
818+
uwrite!(
819+
self.src,
820+
"
821+
#[derive(Copy, Clone)]
822+
struct ParamsLower(
823+
"
824+
);
825+
let mut params_lower = sig.params.as_slice();
826+
if sig.retptr {
827+
params_lower = &params_lower[..params_lower.len() - 1];
828+
}
829+
for ty in params_lower {
830+
self.src.push_str(wasm_type(*ty));
831+
self.src.push_str(", ");
832+
}
833+
uwriteln!(
834+
self.src,
835+
"
836+
);
837+
unsafe impl Send for ParamsLower {{}}
838+
"
839+
);
840+
815841
uwriteln!(
816842
self.src,
817843
"
@@ -844,16 +870,7 @@ unsafe impl<'a> _Subtask for _MySubtask<'a> {{
844870
}
845871

846872
// Generate `type ParamsLower`
847-
uwrite!(self.src, "type ParamsLower = (");
848-
let mut params_lower = sig.params.as_slice();
849-
if sig.retptr {
850-
params_lower = &params_lower[..params_lower.len() - 1];
851-
}
852-
for ty in params_lower {
853-
self.src.push_str(wasm_type(*ty));
854-
self.src.push_str(", ");
855-
}
856-
uwriteln!(self.src, ");");
873+
uwrite!(self.src, "type ParamsLower = ParamsLower;");
857874

858875
// Generate `const ABI_LAYOUT`
859876
let mut heap_types = Vec::new();
@@ -962,7 +979,7 @@ unsafe fn call_import(&self, _params: Self::ParamsLower, _results: *mut u8) -> u
962979
lowers.push(start);
963980
param_lowers.push(name);
964981
}
965-
lowers.push("(_ptr,)".to_string());
982+
lowers.push("ParamsLower(_ptr,)".to_string());
966983
} else {
967984
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
968985
let mut results = Vec::new();
@@ -974,7 +991,7 @@ unsafe fn call_import(&self, _params: Self::ParamsLower, _results: *mut u8) -> u
974991
for result in results.iter_mut() {
975992
result.push_str(",");
976993
}
977-
let result = format!("({})", results.join(" "));
994+
let result = format!("ParamsLower({})", results.join(" "));
978995
lowers.push(format!("unsafe {{ {} {result} }}", String::from(f.src)));
979996
}
980997

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
include!(env!("BINDINGS"));
2+
3+
use std::future::Future;
4+
5+
use crate::a::b::i::*;
6+
7+
// Explicitly require Send.
8+
#[allow(dead_code)]
9+
fn require_send<T: Send>(_t: &T) {}
10+
11+
// This is the type of block_on with a Send requirement added.
12+
pub fn block_on_require_send<T: 'static>(future: impl Future<Output = T> + Send + 'static) -> T {
13+
require_send(&future);
14+
wit_bindgen::block_on(future)
15+
}
16+
17+
fn main() {
18+
block_on_require_send(async {
19+
one_argument("hello".into()).await;
20+
});
21+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
include!(env!("BINDINGS"));
2+
3+
struct Component;
4+
5+
export!(Component);
6+
7+
impl crate::exports::a::b::i::Guest for Component {
8+
async fn one_argument(x: String) {
9+
assert_eq!(&x, "hello");
10+
}
11+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package a:b;
2+
3+
interface i {
4+
one-argument: async func(x: string);
5+
}
6+
7+
world test {
8+
export i;
9+
}
10+
11+
world runner {
12+
import i;
13+
}

0 commit comments

Comments
 (0)