Skip to content

Commit c104f3b

Browse files
committed
Lower return types for gen fn to impl Iterator
1 parent bc0d10d commit c104f3b

File tree

7 files changed

+167
-80
lines changed

7 files changed

+167
-80
lines changed

compiler/rustc_ast_lowering/src/item.rs

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::FnReturnTransformation;
2+
13
use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound};
24
use super::ResolverAstLoweringExt;
35
use super::{AstOwner, ImplTraitContext, ImplTraitPosition};
@@ -207,13 +209,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
207209
// only cares about the input argument patterns in the function
208210
// declaration (decl), not the return types.
209211
let asyncness = header.asyncness;
210-
let body_id =
211-
this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref());
212+
let genness = header.genness;
213+
let body_id = this.lower_maybe_coroutine_body(
214+
span,
215+
hir_id,
216+
decl,
217+
asyncness,
218+
genness,
219+
body.as_deref(),
220+
);
212221

213222
let itctx = ImplTraitContext::Universal;
214223
let (generics, decl) =
215224
this.lower_generics(generics, header.constness, id, &itctx, |this| {
216-
let ret_id = asyncness.opt_return_id();
225+
let ret_id = asyncness
226+
.opt_return_id()
227+
.map(|(node_id, span)| {
228+
crate::FnReturnTransformation::Async(node_id, span)
229+
})
230+
.or_else(|| match genness {
231+
Gen::Yes { span, closure_id: _, return_impl_trait_id } => {
232+
Some(crate::FnReturnTransformation::Iterator(
233+
return_impl_trait_id,
234+
span,
235+
))
236+
}
237+
_ => None,
238+
});
217239
this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id)
218240
});
219241
let sig = hir::FnSig {
@@ -732,20 +754,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
732754
sig,
733755
i.id,
734756
FnDeclKind::Trait,
735-
asyncness.opt_return_id(),
757+
asyncness
758+
.opt_return_id()
759+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
736760
);
737761
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false)
738762
}
739763
AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => {
740764
let asyncness = sig.header.asyncness;
741-
let body_id =
742-
self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body));
765+
let genness = sig.header.genness;
766+
let body_id = self.lower_maybe_coroutine_body(
767+
i.span,
768+
hir_id,
769+
&sig.decl,
770+
asyncness,
771+
genness,
772+
Some(body),
773+
);
743774
let (generics, sig) = self.lower_method_sig(
744775
generics,
745776
sig,
746777
i.id,
747778
FnDeclKind::Trait,
748-
asyncness.opt_return_id(),
779+
asyncness
780+
.opt_return_id()
781+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
749782
);
750783
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true)
751784
}
@@ -835,19 +868,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
835868
),
836869
AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => {
837870
let asyncness = sig.header.asyncness;
838-
let body_id = self.lower_maybe_async_body(
871+
let genness = sig.header.genness;
872+
let body_id = self.lower_maybe_coroutine_body(
839873
i.span,
840874
hir_id,
841875
&sig.decl,
842876
asyncness,
877+
genness,
843878
body.as_deref(),
844879
);
845880
let (generics, sig) = self.lower_method_sig(
846881
generics,
847882
sig,
848883
i.id,
849884
if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent },
850-
asyncness.opt_return_id(),
885+
asyncness
886+
.opt_return_id()
887+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
851888
);
852889

853890
(generics, hir::ImplItemKind::Fn(sig, body_id))
@@ -1011,16 +1048,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
10111048
})
10121049
}
10131050

1014-
fn lower_maybe_async_body(
1051+
/// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or
1052+
/// `gen {}` block as appropriate.
1053+
fn lower_maybe_coroutine_body(
10151054
&mut self,
10161055
span: Span,
10171056
fn_id: hir::HirId,
10181057
decl: &FnDecl,
10191058
asyncness: Async,
1059+
genness: Gen,
10201060
body: Option<&Block>,
10211061
) -> hir::BodyId {
1022-
let (closure_id, body) = match (asyncness, body) {
1023-
(Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
1062+
let (closure_id, body) = match (asyncness, genness, body) {
1063+
// FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error
1064+
// for now since it's not supported.
1065+
(Async::Yes { closure_id, .. }, _, Some(body))
1066+
| (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
10241067
_ => return self.lower_fn_body_block(span, decl, body),
10251068
};
10261069

@@ -1163,44 +1206,55 @@ impl<'hir> LoweringContext<'_, 'hir> {
11631206
parameters.push(new_parameter);
11641207
}
11651208

1166-
let async_expr = this.make_async_expr(
1167-
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1168-
closure_id,
1169-
None,
1170-
body.span,
1171-
hir::CoroutineSource::Fn,
1172-
|this| {
1173-
// Create a block from the user's function body:
1174-
let user_body = this.lower_block_expr(body);
1209+
let mkbody = |this: &mut LoweringContext<'_, 'hir>| {
1210+
// Create a block from the user's function body:
1211+
let user_body = this.lower_block_expr(body);
11751212

1176-
// Transform into `drop-temps { <user-body> }`, an expression:
1177-
let desugared_span =
1178-
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1179-
let user_body =
1180-
this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
1213+
// Transform into `drop-temps { <user-body> }`, an expression:
1214+
let desugared_span =
1215+
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1216+
let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
11811217

1182-
// As noted above, create the final block like
1183-
//
1184-
// ```
1185-
// {
1186-
// let $param_pattern = $raw_param;
1187-
// ...
1188-
// drop-temps { <user-body> }
1189-
// }
1190-
// ```
1191-
let body = this.block_all(
1192-
desugared_span,
1193-
this.arena.alloc_from_iter(statements),
1194-
Some(user_body),
1195-
);
1218+
// As noted above, create the final block like
1219+
//
1220+
// ```
1221+
// {
1222+
// let $param_pattern = $raw_param;
1223+
// ...
1224+
// drop-temps { <user-body> }
1225+
// }
1226+
// ```
1227+
let body = this.block_all(
1228+
desugared_span,
1229+
this.arena.alloc_from_iter(statements),
1230+
Some(user_body),
1231+
);
11961232

1197-
this.expr_block(body)
1198-
},
1199-
);
1233+
this.expr_block(body)
1234+
};
1235+
let coroutine_expr = match (asyncness, genness) {
1236+
(Async::Yes { .. }, _) => this.make_async_expr(
1237+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1238+
closure_id,
1239+
None,
1240+
body.span,
1241+
hir::CoroutineSource::Fn,
1242+
mkbody,
1243+
),
1244+
(_, Gen::Yes { .. }) => this.make_gen_expr(
1245+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1246+
closure_id,
1247+
None,
1248+
body.span,
1249+
hir::CoroutineSource::Fn,
1250+
mkbody,
1251+
),
1252+
_ => unreachable!("we must have either an async fn or a gen fn"),
1253+
};
12001254

12011255
let hir_id = this.lower_node_id(closure_id);
12021256
this.maybe_forward_track_caller(body.span, fn_id, hir_id);
1203-
let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) };
1257+
let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) };
12041258

12051259
(this.arena.alloc_from_iter(parameters), expr)
12061260
})
@@ -1212,13 +1266,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
12121266
sig: &FnSig,
12131267
id: NodeId,
12141268
kind: FnDeclKind,
1215-
is_async: Option<(NodeId, Span)>,
1269+
transform_return_type: Option<FnReturnTransformation>,
12161270
) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
12171271
let header = self.lower_fn_header(sig.header);
12181272
let itctx = ImplTraitContext::Universal;
12191273
let (generics, decl) =
12201274
self.lower_generics(generics, sig.header.constness, id, &itctx, |this| {
1221-
this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async)
1275+
this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type)
12221276
});
12231277
(generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
12241278
}

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,21 @@ enum ParenthesizedGenericArgs {
493493
Err,
494494
}
495495

496+
/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl`
497+
#[derive(Debug)]
498+
enum FnReturnTransformation {
499+
/// Replaces a return type `T` with `impl Future<Output = T>`.
500+
///
501+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
502+
/// `async` keyword.
503+
Async(NodeId, Span),
504+
/// Replaces a return type `T` with `impl Iterator<Item = T>`.
505+
///
506+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
507+
/// `gen` keyword.
508+
Iterator(NodeId, Span),
509+
}
510+
496511
impl<'a, 'hir> LoweringContext<'a, 'hir> {
497512
fn create_def(
498513
&mut self,
@@ -1778,21 +1793,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
17781793
}))
17791794
}
17801795

1781-
// Lowers a function declaration.
1782-
//
1783-
// `decl`: the unlowered (AST) function declaration.
1784-
// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`.
1785-
// `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future<Output = T>` in the
1786-
// return type. This is used for `async fn` declarations. The `NodeId` is the ID of the
1787-
// return type `impl Trait` item, and the `Span` points to the `async` keyword.
1796+
/// Lowers a function declaration.
1797+
///
1798+
/// `decl`: the unlowered (AST) function declaration.
1799+
///
1800+
/// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given
1801+
/// `NodeId`.
1802+
///
1803+
/// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is
1804+
/// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details.
17881805
#[instrument(level = "debug", skip(self))]
17891806
fn lower_fn_decl(
17901807
&mut self,
17911808
decl: &FnDecl,
17921809
fn_node_id: NodeId,
17931810
fn_span: Span,
17941811
kind: FnDeclKind,
1795-
make_ret_async: Option<(NodeId, Span)>,
1812+
transform_return_type: Option<FnReturnTransformation>,
17961813
) -> &'hir hir::FnDecl<'hir> {
17971814
let c_variadic = decl.c_variadic();
17981815

@@ -1821,11 +1838,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18211838
self.lower_ty_direct(&param.ty, &itctx)
18221839
}));
18231840

1824-
let output = if let Some((ret_id, _span)) = make_ret_async {
1825-
let fn_def_id = self.local_def_id(fn_node_id);
1826-
self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span)
1827-
} else {
1828-
match &decl.output {
1841+
let output = match transform_return_type {
1842+
Some(transform) => {
1843+
let fn_def_id = self.local_def_id(fn_node_id);
1844+
self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span)
1845+
}
1846+
None => match &decl.output {
18291847
FnRetTy::Ty(ty) => {
18301848
let context = if kind.return_impl_trait_allowed() {
18311849
let fn_def_id = self.local_def_id(fn_node_id);
@@ -1849,7 +1867,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18491867
hir::FnRetTy::Return(self.lower_ty(ty, &context))
18501868
}
18511869
FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)),
1852-
}
1870+
},
18531871
};
18541872

18551873
self.arena.alloc(hir::FnDecl {
@@ -1888,17 +1906,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18881906
// `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition)
18891907
// `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created
18901908
#[instrument(level = "debug", skip(self))]
1891-
fn lower_async_fn_ret_ty(
1909+
fn lower_coroutine_fn_ret_ty(
18921910
&mut self,
18931911
output: &FnRetTy,
18941912
fn_def_id: LocalDefId,
1895-
opaque_ty_node_id: NodeId,
1913+
transform: FnReturnTransformation,
18961914
fn_kind: FnDeclKind,
18971915
fn_span: Span,
18981916
) -> hir::FnRetTy<'hir> {
18991917
let span = self.lower_span(fn_span);
19001918
let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
19011919

1920+
let opaque_ty_node_id = match transform {
1921+
FnReturnTransformation::Async(opaque_ty_node_id, _)
1922+
| FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id,
1923+
};
1924+
19021925
let captured_lifetimes: Vec<_> = self
19031926
.resolver
19041927
.take_extra_lifetime_params(opaque_ty_node_id)
@@ -1914,8 +1937,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19141937
span,
19151938
opaque_ty_span,
19161939
|this| {
1917-
let future_bound = this.lower_async_fn_output_type_to_future_bound(
1940+
let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
19181941
output,
1942+
transform,
19191943
span,
19201944
ImplTraitContext::ReturnPositionOpaqueTy {
19211945
origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
@@ -1931,9 +1955,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19311955
}
19321956

19331957
/// Transforms `-> T` into `Future<Output = T>`.
1934-
fn lower_async_fn_output_type_to_future_bound(
1958+
fn lower_coroutine_fn_output_type_to_future_bound(
19351959
&mut self,
19361960
output: &FnRetTy,
1961+
transform: FnReturnTransformation,
19371962
span: Span,
19381963
nested_impl_trait_context: ImplTraitContext,
19391964
) -> hir::GenericBound<'hir> {
@@ -1948,17 +1973,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19481973
FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
19491974
};
19501975

1951-
// "<Output = T>"
1976+
// "<Output|Item = T>"
1977+
let (symbol, lang_item) = match transform {
1978+
FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
1979+
FnReturnTransformation::Iterator(..) => {
1980+
(hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator)
1981+
}
1982+
};
1983+
19521984
let future_args = self.arena.alloc(hir::GenericArgs {
19531985
args: &[],
1954-
bindings: arena_vec![self; self.output_ty_binding(span, output_ty)],
1986+
bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
19551987
parenthesized: hir::GenericArgsParentheses::No,
19561988
span_ext: DUMMY_SP,
19571989
});
19581990

19591991
hir::GenericBound::LangItemTrait(
1960-
// ::std::future::Future<future_params>
1961-
hir::LangItem::Future,
1992+
lang_item,
19621993
self.lower_span(span),
19631994
self.next_id(),
19641995
future_args,

0 commit comments

Comments
 (0)