@@ -5,7 +5,7 @@ use rustc_hir as hir;
5
5
use rustc_hir:: def:: { DefKind , Res } ;
6
6
use rustc_lint_defs:: Applicability ;
7
7
use rustc_lint_defs:: builtin:: BARE_TRAIT_OBJECTS ;
8
- use rustc_span:: Span ;
8
+ use rustc_span:: { Span , sym } ;
9
9
use rustc_trait_selection:: error_reporting:: traits:: suggestions:: NextTypeParamName ;
10
10
11
11
use super :: HirTyLowerer ;
@@ -181,14 +181,14 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
181
181
/// Make sure that we are in the condition to suggest `impl Trait`.
182
182
fn maybe_suggest_impl_trait ( & self , self_ty : & hir:: Ty < ' _ > , diag : & mut Diag < ' _ > ) -> bool {
183
183
let tcx = self . tcx ( ) ;
184
- let parent_id = tcx. hir ( ) . get_parent_item ( self_ty. hir_id ) . def_id ;
184
+ let parent_node = tcx. hir_node_by_def_id ( tcx . hir ( ) . get_parent_item ( self_ty. hir_id ) . def_id ) ;
185
185
// FIXME: If `type_alias_impl_trait` is enabled, also look for `Trait0<Ty = Trait1>`
186
186
// and suggest `Trait0<Ty = impl Trait1>`.
187
187
// Functions are found in three different contexts.
188
188
// 1. Independent functions
189
189
// 2. Functions inside trait blocks
190
190
// 3. Functions inside impl blocks
191
- let ( sig, generics) = match tcx . hir_node_by_def_id ( parent_id ) {
191
+ let ( sig, generics) = match parent_node {
192
192
hir:: Node :: Item ( hir:: Item { kind : hir:: ItemKind :: Fn ( sig, generics, _) , .. } ) => {
193
193
( sig, generics)
194
194
}
@@ -223,10 +223,17 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
223
223
tcx. parent_hir_node( self_ty. hir_id) ,
224
224
hir:: Node :: Ty ( hir:: Ty { kind: hir:: TyKind :: Ref ( ..) , .. } )
225
225
) ;
226
+ let is_non_trait_object = |ty : & ' tcx hir:: Ty < ' _ > | {
227
+ if sig. header . is_async ( ) {
228
+ Self :: get_future_inner_return_ty ( ty) . map_or ( false , |ty| ty. hir_id == self_ty. hir_id )
229
+ } else {
230
+ ty. peel_refs ( ) . hir_id == self_ty. hir_id
231
+ }
232
+ } ;
226
233
227
234
// Suggestions for function return type.
228
235
if let hir:: FnRetTy :: Return ( ty) = sig. decl . output
229
- && ty . peel_refs ( ) . hir_id == self_ty . hir_id
236
+ && is_non_trait_object ( ty )
230
237
{
231
238
let pre = if !is_dyn_compatible {
232
239
format ! ( "`{trait_name}` is dyn-incompatible, " )
@@ -311,10 +318,21 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
311
318
}
312
319
313
320
fn maybe_suggest_assoc_ty_bound ( & self , self_ty : & hir:: Ty < ' _ > , diag : & mut Diag < ' _ > ) {
314
- let mut parents = self . tcx ( ) . hir ( ) . parent_iter ( self_ty. hir_id ) ;
321
+ let mut parents = self . tcx ( ) . hir ( ) . parent_iter ( self_ty. hir_id ) . peekable ( ) ;
322
+ let is_async_fn = if let Some ( ( _, parent) ) = parents. peek ( )
323
+ && let Some ( sig) = parent. fn_sig ( )
324
+ && sig. header . is_async ( )
325
+ && let hir:: FnRetTy :: Return ( ty) = sig. decl . output
326
+ && Self :: get_future_inner_return_ty ( ty) . is_some ( )
327
+ {
328
+ true
329
+ } else {
330
+ false
331
+ } ;
315
332
316
333
if let Some ( ( _, hir:: Node :: AssocItemConstraint ( constraint) ) ) = parents. next ( )
317
334
&& let Some ( obj_ty) = constraint. ty ( )
335
+ && !is_async_fn
318
336
{
319
337
if let Some ( ( _, hir:: Node :: TraitRef ( ..) ) ) = parents. next ( )
320
338
&& let Some ( ( _, hir:: Node :: Ty ( ty) ) ) = parents. next ( )
@@ -343,4 +361,34 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
343
361
) ;
344
362
}
345
363
}
364
+
365
+ /// From the [`hir::Ty`] of an async function's lowered return type,
366
+ /// retrieve the `hir::Ty` representing the type the user originally wrote.
367
+ ///
368
+ /// e.g. given the function:
369
+ ///
370
+ /// ```
371
+ /// async fn foo() -> i32 { 2 }
372
+ /// ```
373
+ ///
374
+ /// this function, given the lowered return type of `foo`, an [`OpaqueDef`] that implements `Future<Output=i32>`,
375
+ /// returns the `i32`.
376
+ ///
377
+ /// [`OpaqueDef`]: hir::TyKind::OpaqueDef
378
+ pub fn get_future_inner_return_ty < ' a > ( hir_ty : & ' a hir:: Ty < ' a > ) -> Option < & ' a hir:: Ty < ' a > > {
379
+ let hir:: TyKind :: OpaqueDef ( opaque_ty, _) = hir_ty. kind else {
380
+ return None ;
381
+ } ;
382
+ if let hir:: OpaqueTy { bounds : [ hir:: GenericBound :: Trait ( trait_ref) ] , .. } = opaque_ty
383
+ && let Some ( segment) = trait_ref. trait_ref . path . segments . last ( )
384
+ && let Some ( args) = segment. args
385
+ && let [ constraint] = args. constraints
386
+ && constraint. ident . name == sym:: Output
387
+ && let Some ( ty) = constraint. ty ( )
388
+ {
389
+ Some ( ty)
390
+ } else {
391
+ None
392
+ }
393
+ }
346
394
}
0 commit comments