Skip to content

Commit d29178c

Browse files
Do check_coroutine_obligations once per typeck root
1 parent aa1653e commit d29178c

File tree

8 files changed

+118
-99
lines changed

8 files changed

+118
-99
lines changed

compiler/rustc_hir_analysis/src/check/check.rs

+11-36
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ use rustc_hir as hir;
1010
use rustc_hir::def::{CtorKind, DefKind};
1111
use rustc_hir::Node;
1212
use rustc_infer::infer::{RegionVariableOrigin, TyCtxtInferExt};
13-
use rustc_infer::traits::{Obligation, TraitEngineExt as _};
13+
use rustc_infer::traits::Obligation;
1414
use rustc_lint_defs::builtin::REPR_TRANSPARENT_EXTERNAL_PRIVATE_FIELDS;
1515
use rustc_middle::middle::stability::EvalResult;
16-
use rustc_middle::traits::ObligationCauseCode;
1716
use rustc_middle::ty::fold::BottomUpFolder;
1817
use rustc_middle::ty::layout::{LayoutError, MAX_SIMD_LANES};
1918
use rustc_middle::ty::util::{Discr, InspectCoroutineFields, IntTypeExt};
@@ -26,7 +25,7 @@ use rustc_target::abi::FieldIdx;
2625
use rustc_trait_selection::traits::error_reporting::on_unimplemented::OnUnimplementedDirective;
2726
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
2827
use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt as _;
29-
use rustc_trait_selection::traits::{self, TraitEngine, TraitEngineExt as _};
28+
use rustc_trait_selection::traits::{self};
3029
use rustc_type_ir::fold::TypeFoldable;
3130

3231
use std::cell::LazyCell;
@@ -1541,55 +1540,31 @@ fn opaque_type_cycle_error(
15411540
err.emit()
15421541
}
15431542

1544-
// FIXME(@lcnr): This should not be computed per coroutine, but instead once for
1545-
// each typeck root.
15461543
pub(super) fn check_coroutine_obligations(
15471544
tcx: TyCtxt<'_>,
15481545
def_id: LocalDefId,
15491546
) -> Result<(), ErrorGuaranteed> {
1550-
debug_assert!(tcx.is_coroutine(def_id.to_def_id()));
1547+
debug_assert!(!tcx.is_typeck_child(def_id.to_def_id()));
15511548

1552-
let typeck = tcx.typeck(def_id);
1553-
let param_env = tcx.param_env(typeck.hir_owner.def_id);
1549+
let typeck_results = tcx.typeck(def_id);
1550+
let param_env = tcx.param_env(def_id);
15541551

1555-
let coroutine_stalled_predicates = &typeck.coroutine_stalled_predicates[&def_id];
1556-
debug!(?coroutine_stalled_predicates);
1552+
debug!(?typeck_results.coroutine_stalled_predicates);
15571553

15581554
let infcx = tcx
15591555
.infer_ctxt()
15601556
// typeck writeback gives us predicates with their regions erased.
15611557
// As borrowck already has checked lifetimes, we do not need to do it again.
15621558
.ignoring_regions()
1563-
// Bind opaque types to type checking root, as they should have been checked by borrowck,
1564-
// but may show up in some cases, like when (root) obligations are stalled in the new solver.
1565-
.with_opaque_type_inference(typeck.hir_owner.def_id)
1559+
.with_opaque_type_inference(def_id)
15661560
.build();
15671561

1568-
let mut fulfillment_cx = <dyn TraitEngine<'_>>::new(&infcx);
1569-
for (predicate, cause) in coroutine_stalled_predicates {
1570-
let obligation = Obligation::new(tcx, cause.clone(), param_env, *predicate);
1571-
fulfillment_cx.register_predicate_obligation(&infcx, obligation);
1572-
}
1573-
1574-
if (tcx.features().unsized_locals || tcx.features().unsized_fn_params)
1575-
&& let Some(coroutine) = tcx.mir_coroutine_witnesses(def_id)
1576-
{
1577-
for field_ty in coroutine.field_tys.iter() {
1578-
fulfillment_cx.register_bound(
1579-
&infcx,
1580-
param_env,
1581-
field_ty.ty,
1582-
tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)),
1583-
ObligationCause::new(
1584-
field_ty.source_info.span,
1585-
def_id,
1586-
ObligationCauseCode::SizedCoroutineInterior(def_id),
1587-
),
1588-
);
1589-
}
1562+
let ocx = ObligationCtxt::new(&infcx);
1563+
for (predicate, cause) in &typeck_results.coroutine_stalled_predicates {
1564+
ocx.register_obligation(Obligation::new(tcx, cause.clone(), param_env, *predicate));
15901565
}
15911566

1592-
let errors = fulfillment_cx.select_all_or_error(&infcx);
1567+
let errors = ocx.select_all_or_error();
15931568
debug!(?errors);
15941569
if !errors.is_empty() {
15951570
return Err(infcx.err_ctxt().report_fulfillment_errors(errors));

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -575,12 +575,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
575575
obligations
576576
.extend(self.fulfillment_cx.borrow_mut().drain_unstalled_obligations(&self.infcx));
577577

578-
let obligations = obligations.into_iter().map(|o| (o.predicate, o.cause)).collect();
579-
debug!(?obligations);
580-
self.typeck_results
581-
.borrow_mut()
582-
.coroutine_stalled_predicates
583-
.insert(expr_def_id, obligations);
578+
let obligations = obligations.into_iter().map(|o| (o.predicate, o.cause));
579+
self.typeck_results.borrow_mut().coroutine_stalled_predicates.extend(obligations);
584580
}
585581
}
586582

compiler/rustc_hir_typeck/src/writeback.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -550,15 +550,10 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
550550
fn visit_coroutine_interior(&mut self) {
551551
let fcx_typeck_results = self.fcx.typeck_results.borrow();
552552
assert_eq!(fcx_typeck_results.hir_owner, self.typeck_results.hir_owner);
553-
self.tcx().with_stable_hashing_context(move |ref hcx| {
554-
for (&expr_def_id, predicates) in
555-
fcx_typeck_results.coroutine_stalled_predicates.to_sorted(hcx, false).into_iter()
556-
{
557-
let predicates =
558-
self.resolve(predicates.clone(), &self.fcx.tcx.def_span(expr_def_id));
559-
self.typeck_results.coroutine_stalled_predicates.insert(expr_def_id, predicates);
560-
}
561-
})
553+
for (predicate, cause) in &fcx_typeck_results.coroutine_stalled_predicates {
554+
let (predicate, cause) = self.resolve((*predicate, cause.clone()), &cause.span);
555+
self.typeck_results.coroutine_stalled_predicates.insert((predicate, cause));
556+
}
562557
}
563558

564559
#[instrument(skip(self), level = "debug")]

compiler/rustc_interface/src/passes.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,9 @@ fn run_required_analyses(tcx: TyCtxt<'_>) {
759759
tcx.hir().par_body_owners(|def_id| {
760760
if tcx.is_coroutine(def_id.to_def_id()) {
761761
tcx.ensure().mir_coroutine_witnesses(def_id);
762-
tcx.ensure().check_coroutine_obligations(def_id);
762+
tcx.ensure().check_coroutine_obligations(
763+
tcx.typeck_root_def_id(def_id.to_def_id()).expect_local(),
764+
);
763765
}
764766
});
765767
sess.time("layout_testing", || layout_test::test_layout(tcx));

compiler/rustc_middle/src/ty/typeck_results.rs

+3-6
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ use crate::{
77
GenericArgs, GenericArgsRef, Ty, UserArgs,
88
},
99
};
10-
use rustc_data_structures::{
11-
fx::FxIndexMap,
12-
unord::{ExtendUnord, UnordItems, UnordSet},
13-
};
10+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
11+
use rustc_data_structures::unord::{ExtendUnord, UnordItems, UnordSet};
1412
use rustc_errors::ErrorGuaranteed;
1513
use rustc_hir::{
1614
self as hir,
@@ -201,8 +199,7 @@ pub struct TypeckResults<'tcx> {
201199

202200
/// Stores the predicates that apply on coroutine witness types.
203201
/// formatting modified file tests/ui/coroutine/retain-resume-ref.rs
204-
pub coroutine_stalled_predicates:
205-
LocalDefIdMap<Vec<(ty::Predicate<'tcx>, ObligationCause<'tcx>)>>,
202+
pub coroutine_stalled_predicates: FxIndexSet<(ty::Predicate<'tcx>, ObligationCause<'tcx>)>,
206203

207204
/// We sometimes treat byte string literals (which are of type `&[u8; N]`)
208205
/// as `&[u8]`, depending on the pattern in which they are used.

compiler/rustc_mir_transform/src/coroutine.rs

+40
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ use rustc_span::symbol::sym;
8080
use rustc_span::Span;
8181
use rustc_target::abi::{FieldIdx, VariantIdx};
8282
use rustc_target::spec::PanicStrategy;
83+
use rustc_trait_selection::infer::TyCtxtInferExt as _;
84+
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
85+
use rustc_trait_selection::traits::ObligationCtxt;
86+
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
8387
use std::{iter, ops};
8488

8589
pub struct StateTransform;
@@ -1584,10 +1588,46 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
15841588
let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
15851589

15861590
check_suspend_tys(tcx, &coroutine_layout, body);
1591+
check_field_tys_sized(tcx, &coroutine_layout, def_id);
15871592

15881593
Some(coroutine_layout)
15891594
}
15901595

1596+
fn check_field_tys_sized<'tcx>(
1597+
tcx: TyCtxt<'tcx>,
1598+
coroutine_layout: &CoroutineLayout<'tcx>,
1599+
def_id: LocalDefId,
1600+
) {
1601+
// No need to check if unsized_locals/unsized_fn_params is disabled,
1602+
// since we will error during typeck.
1603+
if !tcx.features().unsized_locals && !tcx.features().unsized_fn_params {
1604+
return;
1605+
}
1606+
1607+
let infcx = tcx.infer_ctxt().ignoring_regions().build();
1608+
let param_env = tcx.param_env(def_id);
1609+
1610+
let ocx = ObligationCtxt::new(&infcx);
1611+
for field_ty in &coroutine_layout.field_tys {
1612+
ocx.register_bound(
1613+
ObligationCause::new(
1614+
field_ty.source_info.span,
1615+
def_id,
1616+
ObligationCauseCode::SizedCoroutineInterior(def_id),
1617+
),
1618+
param_env,
1619+
field_ty.ty,
1620+
tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)),
1621+
);
1622+
}
1623+
1624+
let errors = ocx.select_all_or_error();
1625+
debug!(?errors);
1626+
if !errors.is_empty() {
1627+
infcx.err_ctxt().report_fulfillment_errors(errors);
1628+
}
1629+
}
1630+
15911631
impl<'tcx> MirPass<'tcx> for StateTransform {
15921632
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
15931633
let Some(old_yield_ty) = body.yield_ty() else {

tests/ui/coroutine/clone-impl.rs

+19-5
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66

77
struct NonClone;
88

9-
fn main() {
9+
fn test1() {
1010
let copyable: u32 = 123;
11-
let clonable_0: Vec<u32> = Vec::new();
12-
let clonable_1: Vec<u32> = Vec::new();
13-
let non_clonable: NonClone = NonClone;
14-
1511
let gen_copy_0 = move || {
1612
yield;
1713
drop(copyable);
1814
};
1915
check_copy(&gen_copy_0);
2016
check_clone(&gen_copy_0);
17+
}
18+
19+
fn test2() {
20+
let copyable: u32 = 123;
2121
let gen_copy_1 = move || {
2222
/*
2323
let v = vec!['a'];
@@ -33,6 +33,10 @@ fn main() {
3333
};
3434
check_copy(&gen_copy_1);
3535
check_clone(&gen_copy_1);
36+
}
37+
38+
fn test3() {
39+
let clonable_0: Vec<u32> = Vec::new();
3640
let gen_clone_0 = move || {
3741
let v = vec!['a'];
3842
yield;
@@ -43,6 +47,10 @@ fn main() {
4347
//~^ ERROR the trait bound `Vec<u32>: Copy` is not satisfied
4448
//~| ERROR the trait bound `Vec<char>: Copy` is not satisfied
4549
check_clone(&gen_clone_0);
50+
}
51+
52+
fn test4() {
53+
let clonable_1: Vec<u32> = Vec::new();
4654
let gen_clone_1 = move || {
4755
let v = vec!['a'];
4856
/*
@@ -59,6 +67,10 @@ fn main() {
5967
//~^ ERROR the trait bound `Vec<u32>: Copy` is not satisfied
6068
//~| ERROR the trait bound `Vec<char>: Copy` is not satisfied
6169
check_clone(&gen_clone_1);
70+
}
71+
72+
fn test5() {
73+
let non_clonable: NonClone = NonClone;
6274
let gen_non_clone = move || {
6375
yield;
6476
drop(non_clonable);
@@ -71,3 +83,5 @@ fn main() {
7183

7284
fn check_copy<T: Copy>(_x: &T) {}
7385
fn check_clone<T: Clone>(_x: &T) {}
86+
87+
fn main() {}

0 commit comments

Comments
 (0)