Skip to content

Commit ee012e0

Browse files
authored
Merge pull request #20061 from ChayimFriedman2/wrap-ret-ty
fix: In "Wrap return type" assist, don't wrap exit points if they already have the right type
2 parents 5cda2dd + 78427be commit ee012e0

File tree

2 files changed

+133
-34
lines changed

2 files changed

+133
-34
lines changed

crates/hir/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,10 +1727,10 @@ impl Adt {
17271727
pub fn ty_with_args<'db>(
17281728
self,
17291729
db: &'db dyn HirDatabase,
1730-
args: impl Iterator<Item = Type<'db>>,
1730+
args: impl IntoIterator<Item = Type<'db>>,
17311731
) -> Type<'db> {
17321732
let id = AdtId::from(self);
1733-
let mut it = args.map(|t| t.ty);
1733+
let mut it = args.into_iter().map(|t| t.ty);
17341734
let ty = TyBuilder::def_ty(db, id.into(), None)
17351735
.fill(|x| {
17361736
let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner));

crates/ide-assists/src/handlers/wrap_return_type.rs

Lines changed: 131 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,16 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
5656
};
5757

5858
let type_ref = &ret_type.ty()?;
59-
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
59+
let ty = ctx.sema.resolve_type(type_ref)?;
60+
let ty_adt = ty.as_adt();
6061
let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate());
6162

6263
for kind in WrapperKind::ALL {
6364
let Some(core_wrapper) = kind.core_type(&famous_defs) else {
6465
continue;
6566
};
6667

67-
if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
68+
if matches!(ty_adt, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
6869
// The return type is already wrapped
6970
cov_mark::hit!(wrap_return_type_simple_return_type_already_wrapped);
7071
continue;
@@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
7879
|builder| {
7980
let mut editor = builder.make_editor(&parent);
8081
let make = SyntaxFactory::with_mappings();
81-
let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
82-
let new_return_ty = alias.unwrap_or_else(|| match kind {
83-
WrapperKind::Option => make.ty_option(type_ref.clone()),
84-
WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
82+
let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol());
83+
let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| {
84+
let (ast_ty, ty_constructor) = match kind {
85+
WrapperKind::Option => {
86+
(make.ty_option(type_ref.clone()), famous_defs.core_option_Option())
87+
}
88+
WrapperKind::Result => (
89+
make.ty_result(type_ref.clone(), make.ty_infer().into()),
90+
famous_defs.core_result_Result(),
91+
),
92+
};
93+
let semantic_ty = ty_constructor
94+
.map(|ty_constructor| {
95+
hir::Adt::from(ty_constructor).ty_with_args(ctx.db(), [ty.clone()])
96+
})
97+
.unwrap_or_else(|| ty.clone());
98+
(ast_ty, semantic_ty)
8599
});
86100

87101
let mut exprs_to_wrap = Vec::new();
@@ -96,19 +110,30 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
96110
for_each_tail_expr(&body_expr, tail_cb);
97111

98112
for ret_expr_arg in exprs_to_wrap {
113+
if let Some(ty) = ctx.sema.type_of_expr(&ret_expr_arg) {
114+
if ty.adjusted().could_unify_with(ctx.db(), &semantic_new_return_ty) {
115+
// The type is already correct, don't wrap it.
116+
// We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
117+
// enum matches it's okay for us, as we don't trigger the assist if the return type
118+
// is already `Option`/`Result`, so mismatched exact type is more likely a mistake
119+
// than something intended.
120+
continue;
121+
}
122+
}
123+
99124
let happy_wrapped = make.expr_call(
100125
make.expr_path(make.ident_path(kind.happy_ident())),
101126
make.arg_list(iter::once(ret_expr_arg.clone())),
102127
);
103128
editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
104129
}
105130

106-
editor.replace(type_ref.syntax(), new_return_ty.syntax());
131+
editor.replace(type_ref.syntax(), ast_new_return_ty.syntax());
107132

108133
if let WrapperKind::Result = kind {
109134
// Add a placeholder snippet at the first generic argument that doesn't equal the return type.
110135
// This is normally the error type, but that may not be the case when we inserted a type alias.
111-
let args = new_return_ty
136+
let args = ast_new_return_ty
112137
.path()
113138
.unwrap()
114139
.segment()
@@ -188,35 +213,36 @@ impl WrapperKind {
188213
}
189214

190215
// Try to find an wrapper type alias in the current scope (shadowing the default).
191-
fn wrapper_alias(
192-
ctx: &AssistContext<'_>,
216+
fn wrapper_alias<'db>(
217+
ctx: &AssistContext<'db>,
193218
make: &SyntaxFactory,
194-
core_wrapper: &hir::Enum,
195-
ret_type: &ast::Type,
219+
core_wrapper: hir::Enum,
220+
ast_ret_type: &ast::Type,
221+
semantic_ret_type: &hir::Type<'db>,
196222
wrapper: hir::Symbol,
197-
) -> Option<ast::PathType> {
223+
) -> Option<(ast::PathType, hir::Type<'db>)> {
198224
let wrapper_path = hir::ModPath::from_segments(
199225
hir::PathKind::Plain,
200226
iter::once(hir::Name::new_symbol_root(wrapper)),
201227
);
202228

203-
ctx.sema.resolve_mod_path(ret_type.syntax(), &wrapper_path).and_then(|def| {
229+
ctx.sema.resolve_mod_path(ast_ret_type.syntax(), &wrapper_path).and_then(|def| {
204230
def.filter_map(|def| match def.into_module_def() {
205231
hir::ModuleDef::TypeAlias(alias) => {
206232
let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
207-
(&enum_ty == core_wrapper).then_some(alias)
233+
(enum_ty == core_wrapper).then_some((alias, enum_ty))
208234
}
209235
_ => None,
210236
})
211-
.find_map(|alias| {
237+
.find_map(|(alias, enum_ty)| {
212238
let mut inserted_ret_type = false;
213239
let generic_args =
214240
alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
215241
match param {
216242
// Replace the very first type parameter with the function's return type.
217243
ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
218244
inserted_ret_type = true;
219-
make.type_arg(ret_type.clone()).into()
245+
make.type_arg(ast_ret_type.clone()).into()
220246
}
221247
ast::GenericParam::LifetimeParam(_) => {
222248
make.lifetime_arg(make.lifetime("'_")).into()
@@ -231,7 +257,10 @@ fn wrapper_alias(
231257
make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
232258
);
233259

234-
Some(make.ty_path(path))
260+
let new_ty =
261+
hir::Adt::from(enum_ty).ty_with_args(ctx.db(), [semantic_ret_type.clone()]);
262+
263+
Some((make.ty_path(path), new_ty))
235264
})
236265
})
237266
}
@@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
605634
check_assist_by_label(
606635
wrap_return_type,
607636
r#"
608-
//- minicore: option
637+
//- minicore: option, future
638+
struct F(i32);
639+
impl core::future::Future for F {
640+
type Output = i32;
641+
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
642+
}
609643
async fn foo() -> i$032 {
610644
if true {
611645
if false {
612-
1.await
646+
F(1).await
613647
} else {
614-
2.await
648+
F(2).await
615649
}
616650
} else {
617-
24i32.await
651+
F(24i32).await
618652
}
619653
}
620654
"#,
621655
r#"
656+
struct F(i32);
657+
impl core::future::Future for F {
658+
type Output = i32;
659+
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
660+
}
622661
async fn foo() -> Option<i32> {
623662
if true {
624663
if false {
625-
Some(1.await)
664+
Some(F(1).await)
626665
} else {
627-
Some(2.await)
666+
Some(F(2).await)
628667
}
629668
} else {
630-
Some(24i32.await)
669+
Some(F(24i32).await)
631670
}
632671
}
633672
"#,
@@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
16661705
check_assist_by_label(
16671706
wrap_return_type,
16681707
r#"
1669-
//- minicore: result
1708+
//- minicore: result, future
1709+
struct F(i32);
1710+
impl core::future::Future for F {
1711+
type Output = i32;
1712+
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1713+
}
16701714
async fn foo() -> i$032 {
16711715
if true {
16721716
if false {
1673-
1.await
1717+
F(1).await
16741718
} else {
1675-
2.await
1719+
F(2).await
16761720
}
16771721
} else {
1678-
24i32.await
1722+
F(24i32).await
16791723
}
16801724
}
16811725
"#,
16821726
r#"
1727+
struct F(i32);
1728+
impl core::future::Future for F {
1729+
type Output = i32;
1730+
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1731+
}
16831732
async fn foo() -> Result<i32, ${0:_}> {
16841733
if true {
16851734
if false {
1686-
Ok(1.await)
1735+
Ok(F(1).await)
16871736
} else {
1688-
Ok(2.await)
1737+
Ok(F(2).await)
16891738
}
16901739
} else {
1691-
Ok(24i32.await)
1740+
Ok(F(24i32).await)
16921741
}
16931742
}
16941743
"#,
@@ -2455,6 +2504,56 @@ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
24552504
24562505
fn foo() -> Result<i32, ${0:_}> {
24572506
Ok(0)
2507+
}
2508+
"#,
2509+
WrapperKind::Result.label(),
2510+
);
2511+
}
2512+
2513+
#[test]
2514+
fn already_wrapped() {
2515+
check_assist_by_label(
2516+
wrap_return_type,
2517+
r#"
2518+
//- minicore: option
2519+
fn foo() -> i32$0 {
2520+
if false {
2521+
0
2522+
} else {
2523+
Some(1)
2524+
}
2525+
}
2526+
"#,
2527+
r#"
2528+
fn foo() -> Option<i32> {
2529+
if false {
2530+
Some(0)
2531+
} else {
2532+
Some(1)
2533+
}
2534+
}
2535+
"#,
2536+
WrapperKind::Option.label(),
2537+
);
2538+
check_assist_by_label(
2539+
wrap_return_type,
2540+
r#"
2541+
//- minicore: result
2542+
fn foo() -> i32$0 {
2543+
if false {
2544+
0
2545+
} else {
2546+
Ok(1)
2547+
}
2548+
}
2549+
"#,
2550+
r#"
2551+
fn foo() -> Result<i32, ${0:_}> {
2552+
if false {
2553+
Ok(0)
2554+
} else {
2555+
Ok(1)
2556+
}
24582557
}
24592558
"#,
24602559
WrapperKind::Result.label(),

0 commit comments

Comments
 (0)