@@ -56,15 +56,16 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
56
56
} ;
57
57
58
58
let type_ref = & ret_type. ty ( ) ?;
59
- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
59
+ let ty = ctx. sema . resolve_type ( type_ref) ?;
60
+ let ty_adt = ty. as_adt ( ) ;
60
61
let famous_defs = FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) ;
61
62
62
63
for kind in WrapperKind :: ALL {
63
64
let Some ( core_wrapper) = kind. core_type ( & famous_defs) else {
64
65
continue ;
65
66
} ;
66
67
67
- if matches ! ( ty , Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_wrapper) {
68
+ if matches ! ( ty_adt , Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_wrapper) {
68
69
// The return type is already wrapped
69
70
cov_mark:: hit!( wrap_return_type_simple_return_type_already_wrapped) ;
70
71
continue ;
@@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
78
79
|builder| {
79
80
let mut editor = builder. make_editor ( & parent) ;
80
81
let make = SyntaxFactory :: with_mappings ( ) ;
81
- let alias = wrapper_alias ( ctx, & make, & core_wrapper, type_ref, kind. symbol ( ) ) ;
82
- let new_return_ty = alias. unwrap_or_else ( || match kind {
83
- WrapperKind :: Option => make. ty_option ( type_ref. clone ( ) ) ,
84
- WrapperKind :: Result => make. ty_result ( type_ref. clone ( ) , make. ty_infer ( ) . into ( ) ) ,
82
+ let alias = wrapper_alias ( ctx, & make, core_wrapper, type_ref, & ty, kind. symbol ( ) ) ;
83
+ let ( ast_new_return_ty, semantic_new_return_ty) = alias. unwrap_or_else ( || {
84
+ let ( ast_ty, ty_constructor) = match kind {
85
+ WrapperKind :: Option => {
86
+ ( make. ty_option ( type_ref. clone ( ) ) , famous_defs. core_option_Option ( ) )
87
+ }
88
+ WrapperKind :: Result => (
89
+ make. ty_result ( type_ref. clone ( ) , make. ty_infer ( ) . into ( ) ) ,
90
+ famous_defs. core_result_Result ( ) ,
91
+ ) ,
92
+ } ;
93
+ let semantic_ty = ty_constructor
94
+ . map ( |ty_constructor| {
95
+ hir:: Adt :: from ( ty_constructor) . ty_with_args ( ctx. db ( ) , [ ty. clone ( ) ] )
96
+ } )
97
+ . unwrap_or_else ( || ty. clone ( ) ) ;
98
+ ( ast_ty, semantic_ty)
85
99
} ) ;
86
100
87
101
let mut exprs_to_wrap = Vec :: new ( ) ;
@@ -96,19 +110,30 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
96
110
for_each_tail_expr ( & body_expr, tail_cb) ;
97
111
98
112
for ret_expr_arg in exprs_to_wrap {
113
+ if let Some ( ty) = ctx. sema . type_of_expr ( & ret_expr_arg) {
114
+ if ty. adjusted ( ) . could_unify_with ( ctx. db ( ) , & semantic_new_return_ty) {
115
+ // The type is already correct, don't wrap it.
116
+ // We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
117
+ // enum matches it's okay for us, as we don't trigger the assist if the return type
118
+ // is already `Option`/`Result`, so mismatched exact type is more likely a mistake
119
+ // than something intended.
120
+ continue ;
121
+ }
122
+ }
123
+
99
124
let happy_wrapped = make. expr_call (
100
125
make. expr_path ( make. ident_path ( kind. happy_ident ( ) ) ) ,
101
126
make. arg_list ( iter:: once ( ret_expr_arg. clone ( ) ) ) ,
102
127
) ;
103
128
editor. replace ( ret_expr_arg. syntax ( ) , happy_wrapped. syntax ( ) ) ;
104
129
}
105
130
106
- editor. replace ( type_ref. syntax ( ) , new_return_ty . syntax ( ) ) ;
131
+ editor. replace ( type_ref. syntax ( ) , ast_new_return_ty . syntax ( ) ) ;
107
132
108
133
if let WrapperKind :: Result = kind {
109
134
// Add a placeholder snippet at the first generic argument that doesn't equal the return type.
110
135
// This is normally the error type, but that may not be the case when we inserted a type alias.
111
- let args = new_return_ty
136
+ let args = ast_new_return_ty
112
137
. path ( )
113
138
. unwrap ( )
114
139
. segment ( )
@@ -188,35 +213,36 @@ impl WrapperKind {
188
213
}
189
214
190
215
// Try to find an wrapper type alias in the current scope (shadowing the default).
191
- fn wrapper_alias (
192
- ctx : & AssistContext < ' _ > ,
216
+ fn wrapper_alias < ' db > (
217
+ ctx : & AssistContext < ' db > ,
193
218
make : & SyntaxFactory ,
194
- core_wrapper : & hir:: Enum ,
195
- ret_type : & ast:: Type ,
219
+ core_wrapper : hir:: Enum ,
220
+ ast_ret_type : & ast:: Type ,
221
+ semantic_ret_type : & hir:: Type < ' db > ,
196
222
wrapper : hir:: Symbol ,
197
- ) -> Option < ast:: PathType > {
223
+ ) -> Option < ( ast:: PathType , hir :: Type < ' db > ) > {
198
224
let wrapper_path = hir:: ModPath :: from_segments (
199
225
hir:: PathKind :: Plain ,
200
226
iter:: once ( hir:: Name :: new_symbol_root ( wrapper) ) ,
201
227
) ;
202
228
203
- ctx. sema . resolve_mod_path ( ret_type . syntax ( ) , & wrapper_path) . and_then ( |def| {
229
+ ctx. sema . resolve_mod_path ( ast_ret_type . syntax ( ) , & wrapper_path) . and_then ( |def| {
204
230
def. filter_map ( |def| match def. into_module_def ( ) {
205
231
hir:: ModuleDef :: TypeAlias ( alias) => {
206
232
let enum_ty = alias. ty ( ctx. db ( ) ) . as_adt ( ) ?. as_enum ( ) ?;
207
- ( & enum_ty == core_wrapper) . then_some ( alias)
233
+ ( enum_ty == core_wrapper) . then_some ( ( alias, enum_ty ) )
208
234
}
209
235
_ => None ,
210
236
} )
211
- . find_map ( |alias| {
237
+ . find_map ( |( alias, enum_ty ) | {
212
238
let mut inserted_ret_type = false ;
213
239
let generic_args =
214
240
alias. source ( ctx. db ( ) ) ?. value . generic_param_list ( ) ?. generic_params ( ) . map ( |param| {
215
241
match param {
216
242
// Replace the very first type parameter with the function's return type.
217
243
ast:: GenericParam :: TypeParam ( _) if !inserted_ret_type => {
218
244
inserted_ret_type = true ;
219
- make. type_arg ( ret_type . clone ( ) ) . into ( )
245
+ make. type_arg ( ast_ret_type . clone ( ) ) . into ( )
220
246
}
221
247
ast:: GenericParam :: LifetimeParam ( _) => {
222
248
make. lifetime_arg ( make. lifetime ( "'_" ) ) . into ( )
@@ -231,7 +257,10 @@ fn wrapper_alias(
231
257
make. path_segment_generics ( make. name_ref ( name. as_str ( ) ) , generic_arg_list) ,
232
258
) ;
233
259
234
- Some ( make. ty_path ( path) )
260
+ let new_ty =
261
+ hir:: Adt :: from ( enum_ty) . ty_with_args ( ctx. db ( ) , [ semantic_ret_type. clone ( ) ] ) ;
262
+
263
+ Some ( ( make. ty_path ( path) , new_ty) )
235
264
} )
236
265
} )
237
266
}
@@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
605
634
check_assist_by_label (
606
635
wrap_return_type,
607
636
r#"
608
- //- minicore: option
637
+ //- minicore: option, future
638
+ struct F(i32);
639
+ impl core::future::Future for F {
640
+ type Output = i32;
641
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
642
+ }
609
643
async fn foo() -> i$032 {
610
644
if true {
611
645
if false {
612
- 1 .await
646
+ F(1) .await
613
647
} else {
614
- 2 .await
648
+ F(2) .await
615
649
}
616
650
} else {
617
- 24i32.await
651
+ F( 24i32) .await
618
652
}
619
653
}
620
654
"# ,
621
655
r#"
656
+ struct F(i32);
657
+ impl core::future::Future for F {
658
+ type Output = i32;
659
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
660
+ }
622
661
async fn foo() -> Option<i32> {
623
662
if true {
624
663
if false {
625
- Some(1 .await)
664
+ Some(F(1) .await)
626
665
} else {
627
- Some(2 .await)
666
+ Some(F(2) .await)
628
667
}
629
668
} else {
630
- Some(24i32.await)
669
+ Some(F( 24i32) .await)
631
670
}
632
671
}
633
672
"# ,
@@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
1666
1705
check_assist_by_label (
1667
1706
wrap_return_type,
1668
1707
r#"
1669
- //- minicore: result
1708
+ //- minicore: result, future
1709
+ struct F(i32);
1710
+ impl core::future::Future for F {
1711
+ type Output = i32;
1712
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1713
+ }
1670
1714
async fn foo() -> i$032 {
1671
1715
if true {
1672
1716
if false {
1673
- 1 .await
1717
+ F(1) .await
1674
1718
} else {
1675
- 2 .await
1719
+ F(2) .await
1676
1720
}
1677
1721
} else {
1678
- 24i32.await
1722
+ F( 24i32) .await
1679
1723
}
1680
1724
}
1681
1725
"# ,
1682
1726
r#"
1727
+ struct F(i32);
1728
+ impl core::future::Future for F {
1729
+ type Output = i32;
1730
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1731
+ }
1683
1732
async fn foo() -> Result<i32, ${0:_}> {
1684
1733
if true {
1685
1734
if false {
1686
- Ok(1 .await)
1735
+ Ok(F(1) .await)
1687
1736
} else {
1688
- Ok(2 .await)
1737
+ Ok(F(2) .await)
1689
1738
}
1690
1739
} else {
1691
- Ok(24i32.await)
1740
+ Ok(F( 24i32) .await)
1692
1741
}
1693
1742
}
1694
1743
"# ,
@@ -2455,6 +2504,56 @@ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
2455
2504
2456
2505
fn foo() -> Result<i32, ${0:_}> {
2457
2506
Ok(0)
2507
+ }
2508
+ "# ,
2509
+ WrapperKind :: Result . label ( ) ,
2510
+ ) ;
2511
+ }
2512
+
2513
+ #[ test]
2514
+ fn already_wrapped ( ) {
2515
+ check_assist_by_label (
2516
+ wrap_return_type,
2517
+ r#"
2518
+ //- minicore: option
2519
+ fn foo() -> i32$0 {
2520
+ if false {
2521
+ 0
2522
+ } else {
2523
+ Some(1)
2524
+ }
2525
+ }
2526
+ "# ,
2527
+ r#"
2528
+ fn foo() -> Option<i32> {
2529
+ if false {
2530
+ Some(0)
2531
+ } else {
2532
+ Some(1)
2533
+ }
2534
+ }
2535
+ "# ,
2536
+ WrapperKind :: Option . label ( ) ,
2537
+ ) ;
2538
+ check_assist_by_label (
2539
+ wrap_return_type,
2540
+ r#"
2541
+ //- minicore: result
2542
+ fn foo() -> i32$0 {
2543
+ if false {
2544
+ 0
2545
+ } else {
2546
+ Ok(1)
2547
+ }
2548
+ }
2549
+ "# ,
2550
+ r#"
2551
+ fn foo() -> Result<i32, ${0:_}> {
2552
+ if false {
2553
+ Ok(0)
2554
+ } else {
2555
+ Ok(1)
2556
+ }
2458
2557
}
2459
2558
"# ,
2460
2559
WrapperKind :: Result . label ( ) ,
0 commit comments