Skip to content

Commit ae7eff0

Browse files
pnkfelixcelinval
authored andcommitted
Desugars contract into the internal AST extensions
Check ensures on early return due to Try / Yeet Expand these two expressions to include a call to contract checking
1 parent 38eff16 commit ae7eff0

File tree

12 files changed

+453
-84
lines changed

12 files changed

+453
-84
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+31-17
Original file line numberDiff line numberDiff line change
@@ -314,21 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
314314
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
315315
}
316316
ExprKind::Ret(e) => {
317-
let mut e = e.as_ref().map(|x| self.lower_expr(x));
318-
if let Some(Some((span, fresh_ident))) = self
319-
.contract
320-
.as_ref()
321-
.map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
322-
{
323-
let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2);
324-
let args = if let Some(e) = e {
325-
std::slice::from_ref(e)
326-
} else {
327-
std::slice::from_ref(self.expr_unit(span))
328-
};
329-
e = Some(self.expr_call(span, checker_fn, args));
330-
}
331-
hir::ExprKind::Ret(e)
317+
let expr = e.as_ref().map(|x| self.lower_expr(x));
318+
self.checked_return(expr)
332319
}
333320
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
334321
ExprKind::Become(sub_expr) => {
@@ -395,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
395382
})
396383
}
397384

385+
/// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause.
386+
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
387+
let checked_ret = if let Some(Some((span, fresh_ident))) =
388+
self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
389+
{
390+
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span));
391+
Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2))
392+
} else {
393+
opt_expr
394+
};
395+
hir::ExprKind::Ret(checked_ret)
396+
}
397+
398+
/// Wraps an expression with a call to the ensures check before it gets returned.
399+
pub(crate) fn inject_ensures_check(
400+
&mut self,
401+
expr: &'hir hir::Expr<'hir>,
402+
span: Span,
403+
check_ident: Ident,
404+
check_hir_id: HirId,
405+
) -> &'hir hir::Expr<'hir> {
406+
let checker_fn = self.expr_ident(span, check_ident, check_hir_id);
407+
let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None);
408+
self.expr_call(span, checker_fn, std::slice::from_ref(expr))
409+
}
410+
398411
pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
399412
self.with_new_scopes(c.value.span, |this| {
400413
let def_id = this.local_def_id(c.id);
@@ -1983,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
19831996
),
19841997
))
19851998
} else {
1986-
self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr))))
1999+
let ret_expr = self.checked_return(Some(from_residual_expr));
2000+
self.arena.alloc(self.expr(try_span, ret_expr))
19872001
};
19882002
self.lower_attrs(ret_expr.hir_id, &attrs);
19892003

@@ -2032,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
20322046
let target_id = Ok(catch_id);
20332047
hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr))
20342048
} else {
2035-
hir::ExprKind::Ret(Some(from_yeet_expr))
2049+
self.checked_return(Some(from_yeet_expr))
20362050
}
20372051
}
20382052

compiler/rustc_ast_lowering/src/item.rs

+48-67
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
215215
if let Some(contract) = contract {
216216
let requires = contract.requires.clone();
217217
let ensures = contract.ensures.clone();
218-
let ensures = if let Some(ens) = ensures {
218+
let ensures = ensures.map(|ens| {
219219
// FIXME: this needs to be a fresh (or illegal) identifier to prevent
220220
// accidental capture of a parameter or global variable.
221221
let checker_ident: Ident =
@@ -226,13 +226,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
226226
hir::BindingMode::NONE,
227227
);
228228

229-
Some(crate::FnContractLoweringEnsures {
229+
crate::FnContractLoweringEnsures {
230230
expr: ens,
231231
fresh_ident: (checker_ident, checker_pat, checker_hir_id),
232-
})
233-
} else {
234-
None
235-
};
232+
}
233+
});
236234

237235
// Note: `with_new_scopes` will reinstall the outer
238236
// item's contract (if any) after its callback finishes.
@@ -1095,73 +1093,56 @@ impl<'hir> LoweringContext<'_, 'hir> {
10951093

10961094
// { body }
10971095
// ==>
1098-
// { rustc_contract_requires(PRECOND); { body } }
1099-
let result: hir::Expr<'hir> = if let Some(contract) = opt_contract {
1100-
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
1101-
this.expr(contract.span, hir::ExprKind::Tup(&[]))
1102-
};
1103-
1104-
let precond: hir::Stmt<'hir> = if let Some(req) = contract.requires {
1105-
let lowered_req = this.lower_expr_mut(&req);
1106-
let precond = this.expr_call_lang_item_fn_mut(
1107-
req.span,
1108-
hir::LangItem::ContractCheckRequires,
1109-
&*arena_vec![this; lowered_req],
1110-
);
1111-
this.stmt_expr(req.span, precond)
1112-
} else {
1113-
let u = lit_unit(this);
1114-
this.stmt_expr(contract.span, u)
1115-
};
1096+
// { contract_requires(PRECOND); { body } }
1097+
let Some(contract) = opt_contract else { return (params, result) };
1098+
let result_ref = this.arena.alloc(result);
1099+
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
1100+
this.expr(contract.span, hir::ExprKind::Tup(&[]))
1101+
};
11161102

1117-
let (postcond_checker, result) = if let Some(ens) = contract.ensures {
1118-
let crate::FnContractLoweringEnsures { expr: ens, fresh_ident } = ens;
1119-
let lowered_ens: hir::Expr<'hir> = this.lower_expr_mut(&ens);
1120-
let postcond_checker = this.expr_call_lang_item_fn(
1121-
ens.span,
1122-
hir::LangItem::ContractBuildCheckEnsures,
1123-
&*arena_vec![this; lowered_ens],
1124-
);
1125-
let checker_binding_pat = fresh_ident.1;
1126-
(
1127-
this.stmt_let_pat(
1128-
None,
1129-
ens.span,
1130-
Some(postcond_checker),
1131-
this.arena.alloc(checker_binding_pat),
1132-
hir::LocalSource::Contract,
1133-
),
1134-
{
1135-
let checker_fn =
1136-
this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2);
1137-
let span = this.mark_span_with_reason(
1138-
DesugaringKind::Contract,
1139-
ens.span,
1140-
None,
1141-
);
1142-
this.expr_call_mut(
1143-
span,
1144-
checker_fn,
1145-
std::slice::from_ref(this.arena.alloc(result)),
1146-
)
1147-
},
1148-
)
1149-
} else {
1150-
let u = lit_unit(this);
1151-
(this.stmt_expr(contract.span, u), result)
1152-
};
1103+
let precond: hir::Stmt<'hir> = if let Some(req) = contract.requires {
1104+
let lowered_req = this.lower_expr_mut(&req);
1105+
let precond = this.expr_call_lang_item_fn_mut(
1106+
req.span,
1107+
hir::LangItem::ContractCheckRequires,
1108+
&*arena_vec![this; lowered_req],
1109+
);
1110+
this.stmt_expr(req.span, precond)
1111+
} else {
1112+
let u = lit_unit(this);
1113+
this.stmt_expr(contract.span, u)
1114+
};
11531115

1154-
let block = this.block_all(
1155-
contract.span,
1156-
arena_vec![this; precond, postcond_checker],
1157-
Some(this.arena.alloc(result)),
1116+
let (postcond_checker, result) = if let Some(ens) = contract.ensures {
1117+
let crate::FnContractLoweringEnsures { expr: ens, fresh_ident } = ens;
1118+
let lowered_ens: hir::Expr<'hir> = this.lower_expr_mut(&ens);
1119+
let postcond_checker = this.expr_call_lang_item_fn(
1120+
ens.span,
1121+
hir::LangItem::ContractBuildCheckEnsures,
1122+
&*arena_vec![this; lowered_ens],
11581123
);
1159-
this.expr_block(block)
1124+
let checker_binding_pat = fresh_ident.1;
1125+
(
1126+
this.stmt_let_pat(
1127+
None,
1128+
ens.span,
1129+
Some(postcond_checker),
1130+
this.arena.alloc(checker_binding_pat),
1131+
hir::LocalSource::Contract,
1132+
),
1133+
this.inject_ensures_check(result_ref, ens.span, fresh_ident.0, fresh_ident.2),
1134+
)
11601135
} else {
1161-
result
1136+
let u = lit_unit(this);
1137+
(this.stmt_expr(contract.span, u), &*result_ref)
11621138
};
11631139

1164-
(params, result)
1140+
let block = this.block_all(
1141+
contract.span,
1142+
arena_vec![this; precond, postcond_checker],
1143+
Some(result),
1144+
);
1145+
(params, this.expr_block(block))
11651146
})
11661147
}
11671148

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#![allow(unused_imports, unused_variables)]
2+
3+
use rustc_ast::token;
4+
use rustc_ast::tokenstream::{DelimSpacing, DelimSpan, Spacing, TokenStream, TokenTree};
5+
use rustc_errors::ErrorGuaranteed;
6+
use rustc_expand::base::{AttrProcMacro, ExtCtxt};
7+
use rustc_span::Span;
8+
use rustc_span::symbol::{Ident, Symbol, kw, sym};
9+
10+
pub(crate) struct ExpandRequires;
11+
12+
pub(crate) struct ExpandEnsures;
13+
14+
impl AttrProcMacro for ExpandRequires {
15+
fn expand<'cx>(
16+
&self,
17+
ecx: &'cx mut ExtCtxt<'_>,
18+
span: Span,
19+
annotation: TokenStream,
20+
annotated: TokenStream,
21+
) -> Result<TokenStream, ErrorGuaranteed> {
22+
expand_requires_tts(ecx, span, annotation, annotated)
23+
}
24+
}
25+
26+
impl AttrProcMacro for ExpandEnsures {
27+
fn expand<'cx>(
28+
&self,
29+
ecx: &'cx mut ExtCtxt<'_>,
30+
span: Span,
31+
annotation: TokenStream,
32+
annotated: TokenStream,
33+
) -> Result<TokenStream, ErrorGuaranteed> {
34+
expand_ensures_tts(ecx, span, annotation, annotated)
35+
}
36+
}
37+
38+
fn expand_injecting_circa_where_clause(
39+
_ecx: &mut ExtCtxt<'_>,
40+
attr_span: Span,
41+
annotated: TokenStream,
42+
inject: impl FnOnce(&mut Vec<TokenTree>) -> Result<(), ErrorGuaranteed>,
43+
) -> Result<TokenStream, ErrorGuaranteed> {
44+
let mut new_tts = Vec::with_capacity(annotated.len());
45+
let mut cursor = annotated.into_trees();
46+
47+
// Find the `fn name<G,...>(x:X,...)` and inject the AST contract forms right after
48+
// the formal parameters (and return type if any).
49+
while let Some(tt) = cursor.next_ref() {
50+
new_tts.push(tt.clone());
51+
if let TokenTree::Token(tok, _) = tt
52+
&& tok.is_ident_named(kw::Fn)
53+
{
54+
break;
55+
}
56+
}
57+
58+
// Found the `fn` keyword, now find the formal parameters.
59+
//
60+
// FIXME: can this fail if you have parentheticals in a generics list, like `fn foo<F: Fn(X) -> Y>` ?
61+
while let Some(tt) = cursor.next_ref() {
62+
new_tts.push(tt.clone());
63+
64+
if let TokenTree::Delimited(_, _, token::Delimiter::Parenthesis, _) = tt {
65+
break;
66+
}
67+
if let TokenTree::Token(token::Token { kind: token::TokenKind::Semi, .. }, _) = tt {
68+
panic!("contract attribute applied to fn without parameter list.");
69+
}
70+
}
71+
72+
// There *might* be a return type declaration (and figuring out where that ends would require
73+
// parsing an arbitrary type expression, e.g. `-> Foo<args ...>`
74+
//
75+
// Instead of trying to figure that out, scan ahead and look for the first occurence of a
76+
// `where`, a `{ ... }`, or a `;`.
77+
//
78+
// FIXME: this might still fall into a trap for something like `-> Ctor<T, const { 0 }>`. I
79+
// *think* such cases must be under a Delimited (e.g. `[T; { N }]` or have the braced form
80+
// prefixed by e.g. `const`, so we should still be able to filter them out without having to
81+
// parse the type expression itself. But rather than try to fix things with hacks like that,
82+
// time might be better spent extending the attribute expander to suport tt-annotation atop
83+
// ast-annotated, which would be an elegant way to sidestep all of this.
84+
let mut opt_next_tt = cursor.next_ref();
85+
while let Some(next_tt) = opt_next_tt {
86+
if let TokenTree::Token(tok, _) = next_tt
87+
&& tok.is_ident_named(kw::Where)
88+
{
89+
break;
90+
}
91+
if let TokenTree::Delimited(_, _, token::Delimiter::Brace, _) = next_tt {
92+
break;
93+
}
94+
if let TokenTree::Token(token::Token { kind: token::TokenKind::Semi, .. }, _) = next_tt {
95+
break;
96+
}
97+
98+
// for anything else, transcribe the tt and keep looking.
99+
new_tts.push(next_tt.clone());
100+
opt_next_tt = cursor.next_ref();
101+
continue;
102+
}
103+
104+
// At this point, we've transcribed everything from the `fn` through the formal parameter list
105+
// and return type declaration, (if any), but `tt` itself has *not* been transcribed.
106+
//
107+
// Now inject the AST contract form.
108+
//
109+
// FIXME: this kind of manual token tree munging does not have significant precedent among
110+
// rustc builtin macros, probably because most builtin macros use direct AST manipulation to
111+
// accomplish similar goals. But since our attributes need to take arbitrary expressions, and
112+
// our attribute infrastructure does not yet support mixing a token-tree annotation with an AST
113+
// annotated, we end up doing token tree manipulation.
114+
inject(&mut new_tts)?;
115+
116+
// Above we injected the internal AST requires/ensures contruct. Now copy over all the other
117+
// token trees.
118+
if let Some(tt) = opt_next_tt {
119+
new_tts.push(tt.clone());
120+
}
121+
while let Some(tt) = cursor.next_ref() {
122+
new_tts.push(tt.clone());
123+
}
124+
125+
Ok(TokenStream::new(new_tts))
126+
}
127+
128+
fn expand_requires_tts(
129+
_ecx: &mut ExtCtxt<'_>,
130+
attr_span: Span,
131+
annotation: TokenStream,
132+
annotated: TokenStream,
133+
) -> Result<TokenStream, ErrorGuaranteed> {
134+
expand_injecting_circa_where_clause(_ecx, attr_span, annotated, |new_tts| {
135+
new_tts.push(TokenTree::Token(
136+
token::Token::from_ast_ident(Ident::new(kw::RustcContractRequires, attr_span)),
137+
Spacing::Joint,
138+
));
139+
new_tts.push(TokenTree::Token(
140+
token::Token::new(token::TokenKind::OrOr, attr_span),
141+
Spacing::Alone,
142+
));
143+
new_tts.push(TokenTree::Delimited(
144+
DelimSpan::from_single(attr_span),
145+
DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
146+
token::Delimiter::Parenthesis,
147+
annotation,
148+
));
149+
Ok(())
150+
})
151+
}
152+
153+
fn expand_ensures_tts(
154+
_ecx: &mut ExtCtxt<'_>,
155+
attr_span: Span,
156+
annotation: TokenStream,
157+
annotated: TokenStream,
158+
) -> Result<TokenStream, ErrorGuaranteed> {
159+
expand_injecting_circa_where_clause(_ecx, attr_span, annotated, |new_tts| {
160+
new_tts.push(TokenTree::Token(
161+
token::Token::from_ast_ident(Ident::new(kw::RustcContractEnsures, attr_span)),
162+
Spacing::Joint,
163+
));
164+
new_tts.push(TokenTree::Delimited(
165+
DelimSpan::from_single(attr_span),
166+
DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
167+
token::Delimiter::Parenthesis,
168+
annotation,
169+
));
170+
Ok(())
171+
})
172+
}

0 commit comments

Comments
 (0)