Skip to content
This repository was archived by the owner on Nov 1, 2021. It is now read-only.

Commit ec7d993

Browse files
committed
[OpenMP] Support for the if-clause on the combined directive 'target parallel'.
The if-clause on the combined directive potentially applies to both the 'target' and the 'parallel' regions. Codegen'ing the if-clause on the combined directive requires additional support because the expression in the clause must be captured by the 'target' capture statement but not the 'parallel' capture statement. Note that this situation arises for other clauses such as num_threads. The OMPIfClause class inherits OMPClauseWithPreInit to support capturing of expressions in the clause. A member CaptureRegion is added to OMPClauseWithPreInit to indicate which captured statement (in this case 'target' but not 'parallel') captures these expressions. To ensure correct codegen of captured expressions in the presence of combined 'target' directives, OMPParallelScope was added to 'parallel' codegen. Reviewers: ABataev Differential Revision: https://reviews.llvm.org/D28781 git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@292437 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent ea8cb15 commit ec7d993

File tree

10 files changed

+652
-40
lines changed

10 files changed

+652
-40
lines changed

include/clang/AST/OpenMPClause.h

+27-13
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,17 @@ class OMPClauseWithPreInit {
7676
friend class OMPClauseReader;
7777
/// Pre-initialization statement for the clause.
7878
Stmt *PreInit;
79+
/// Region that captures the associated stmt.
80+
OpenMPDirectiveKind CaptureRegion;
81+
7982
protected:
8083
/// Set pre-initialization statement for the clause.
81-
void setPreInitStmt(Stmt *S) { PreInit = S; }
82-
OMPClauseWithPreInit(const OMPClause *This) : PreInit(nullptr) {
84+
void setPreInitStmt(Stmt *S, OpenMPDirectiveKind ThisRegion = OMPD_unknown) {
85+
PreInit = S;
86+
CaptureRegion = ThisRegion;
87+
}
88+
OMPClauseWithPreInit(const OMPClause *This)
89+
: PreInit(nullptr), CaptureRegion(OMPD_unknown) {
8390
assert(get(This) && "get is not tuned for pre-init.");
8491
}
8592

@@ -88,6 +95,8 @@ class OMPClauseWithPreInit {
8895
const Stmt *getPreInitStmt() const { return PreInit; }
8996
/// Get pre-initialization statement for the clause.
9097
Stmt *getPreInitStmt() { return PreInit; }
98+
/// Get capture region for the stmt in the clause.
99+
OpenMPDirectiveKind getCaptureRegion() { return CaptureRegion; }
91100
static OMPClauseWithPreInit *get(OMPClause *C);
92101
static const OMPClauseWithPreInit *get(const OMPClause *C);
93102
};
@@ -194,7 +203,7 @@ template <class T> class OMPVarListClause : public OMPClause {
194203
/// In this example directive '#pragma omp parallel' has simple 'if' clause with
195204
/// condition 'a > 5' and directive name modifier 'parallel'.
196205
///
197-
class OMPIfClause : public OMPClause {
206+
class OMPIfClause : public OMPClause, public OMPClauseWithPreInit {
198207
friend class OMPClauseReader;
199208
/// \brief Location of '('.
200209
SourceLocation LParenLoc;
@@ -225,26 +234,31 @@ class OMPIfClause : public OMPClause {
225234
///
226235
/// \param NameModifier [OpenMP 4.1] Directive name modifier of clause.
227236
/// \param Cond Condition of the clause.
237+
/// \param HelperCond Helper condition for the clause.
238+
/// \param CaptureRegion Innermost OpenMP region where expressions in this
239+
/// clause must be captured.
228240
/// \param StartLoc Starting location of the clause.
229241
/// \param LParenLoc Location of '('.
230242
/// \param NameModifierLoc Location of directive name modifier.
231243
/// \param ColonLoc [OpenMP 4.1] Location of ':'.
232244
/// \param EndLoc Ending location of the clause.
233245
///
234-
OMPIfClause(OpenMPDirectiveKind NameModifier, Expr *Cond,
235-
SourceLocation StartLoc, SourceLocation LParenLoc,
236-
SourceLocation NameModifierLoc, SourceLocation ColonLoc,
237-
SourceLocation EndLoc)
238-
: OMPClause(OMPC_if, StartLoc, EndLoc), LParenLoc(LParenLoc),
239-
Condition(Cond), ColonLoc(ColonLoc), NameModifier(NameModifier),
240-
NameModifierLoc(NameModifierLoc) {}
246+
OMPIfClause(OpenMPDirectiveKind NameModifier, Expr *Cond, Stmt *HelperCond,
247+
OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc,
248+
SourceLocation LParenLoc, SourceLocation NameModifierLoc,
249+
SourceLocation ColonLoc, SourceLocation EndLoc)
250+
: OMPClause(OMPC_if, StartLoc, EndLoc), OMPClauseWithPreInit(this),
251+
LParenLoc(LParenLoc), Condition(Cond), ColonLoc(ColonLoc),
252+
NameModifier(NameModifier), NameModifierLoc(NameModifierLoc) {
253+
setPreInitStmt(HelperCond, CaptureRegion);
254+
}
241255

242256
/// \brief Build an empty clause.
243257
///
244258
OMPIfClause()
245-
: OMPClause(OMPC_if, SourceLocation(), SourceLocation()), LParenLoc(),
246-
Condition(nullptr), ColonLoc(), NameModifier(OMPD_unknown),
247-
NameModifierLoc() {}
259+
: OMPClause(OMPC_if, SourceLocation(), SourceLocation()),
260+
OMPClauseWithPreInit(this), LParenLoc(), Condition(nullptr), ColonLoc(),
261+
NameModifier(OMPD_unknown), NameModifierLoc() {}
248262

249263
/// \brief Sets the location of '('.
250264
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

include/clang/AST/RecursiveASTVisitor.h

+1
Original file line numberDiff line numberDiff line change
@@ -2711,6 +2711,7 @@ bool RecursiveASTVisitor<Derived>::VisitOMPClauseWithPostUpdate(
27112711

27122712
template <typename Derived>
27132713
bool RecursiveASTVisitor<Derived>::VisitOMPIfClause(OMPIfClause *C) {
2714+
TRY_TO(VisitOMPClauseWithPreInit(C));
27142715
TRY_TO(TraverseStmt(C->getCondition()));
27152716
return true;
27162717
}

lib/AST/OpenMPClause.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
4848
return static_cast<const OMPReductionClause *>(C);
4949
case OMPC_linear:
5050
return static_cast<const OMPLinearClause *>(C);
51+
case OMPC_if:
52+
return static_cast<const OMPIfClause *>(C);
5153
case OMPC_default:
5254
case OMPC_proc_bind:
53-
case OMPC_if:
5455
case OMPC_final:
5556
case OMPC_num_threads:
5657
case OMPC_safelen:

lib/AST/StmtProfile.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ void OMPClauseProfiler::VistOMPClauseWithPostUpdate(
283283
}
284284

285285
void OMPClauseProfiler::VisitOMPIfClause(const OMPIfClause *C) {
286+
VistOMPClauseWithPreInit(C);
286287
if (C->getCondition())
287288
Profiler->VisitStmt(C->getCondition());
288289
}

lib/CodeGen/CGStmtOpenMP.cpp

+30-11
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using namespace CodeGen;
2626
namespace {
2727
/// Lexical scope for OpenMP executable constructs, that handles correct codegen
2828
/// for captured expressions.
29-
class OMPLexicalScope final : public CodeGenFunction::LexicalScope {
29+
class OMPLexicalScope : public CodeGenFunction::LexicalScope {
3030
void emitPreInitStmt(CodeGenFunction &CGF, const OMPExecutableDirective &S) {
3131
for (const auto *C : S.clauses()) {
3232
if (auto *CPI = OMPClauseWithPreInit::get(C)) {
@@ -54,10 +54,11 @@ class OMPLexicalScope final : public CodeGenFunction::LexicalScope {
5454

5555
public:
5656
OMPLexicalScope(CodeGenFunction &CGF, const OMPExecutableDirective &S,
57-
bool AsInlined = false)
57+
bool AsInlined = false, bool EmitPreInitStmt = true)
5858
: CodeGenFunction::LexicalScope(CGF, S.getSourceRange()),
5959
InlinedShareds(CGF) {
60-
emitPreInitStmt(CGF, S);
60+
if (EmitPreInitStmt)
61+
emitPreInitStmt(CGF, S);
6162
if (AsInlined) {
6263
if (S.hasAssociatedStmt()) {
6364
auto *CS = cast<CapturedStmt>(S.getAssociatedStmt());
@@ -81,6 +82,22 @@ class OMPLexicalScope final : public CodeGenFunction::LexicalScope {
8182
}
8283
};
8384

85+
/// Lexical scope for OpenMP parallel construct, that handles correct codegen
86+
/// for captured expressions.
87+
class OMPParallelScope final : public OMPLexicalScope {
88+
bool EmitPreInitStmt(const OMPExecutableDirective &S) {
89+
OpenMPDirectiveKind Kind = S.getDirectiveKind();
90+
return !isOpenMPTargetExecutionDirective(Kind) &&
91+
isOpenMPParallelDirective(Kind);
92+
}
93+
94+
public:
95+
OMPParallelScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
96+
: OMPLexicalScope(CGF, S,
97+
/*AsInlined=*/false,
98+
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
99+
};
100+
84101
/// Private scope for OpenMP loop-based directives, that supports capturing
85102
/// of used expression from loop statement.
86103
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
@@ -1237,7 +1254,7 @@ static void emitCommonOMPParallelDirective(CodeGenFunction &CGF,
12371254
}
12381255
}
12391256

1240-
OMPLexicalScope Scope(CGF, S);
1257+
OMPParallelScope Scope(CGF, S);
12411258
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
12421259
CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
12431260
CGF.CGM.getOpenMPRuntime().emitParallelCall(CGF, S.getLocStart(), OutlinedFn,
@@ -3409,17 +3426,17 @@ static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
34093426
CodeGenModule &CGM = CGF.CGM;
34103427
const CapturedStmt &CS = *cast<CapturedStmt>(S.getAssociatedStmt());
34113428

3412-
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
3413-
CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
3414-
34153429
llvm::Function *Fn = nullptr;
34163430
llvm::Constant *FnID = nullptr;
34173431

3418-
// Check if we have any if clause associated with the directive.
34193432
const Expr *IfCond = nullptr;
3420-
3421-
if (auto *C = S.getSingleClause<OMPIfClause>()) {
3422-
IfCond = C->getCondition();
3433+
// Check for the at most one if clause associated with the target region.
3434+
for (const auto *C : S.getClausesOfKind<OMPIfClause>()) {
3435+
if (C->getNameModifier() == OMPD_unknown ||
3436+
C->getNameModifier() == OMPD_target) {
3437+
IfCond = C->getCondition();
3438+
break;
3439+
}
34233440
}
34243441

34253442
// Check if we have any device clause associated with the directive.
@@ -3456,6 +3473,8 @@ static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
34563473
CGM.getOpenMPRuntime().emitTargetOutlinedFunction(S, ParentName, Fn, FnID,
34573474
IsOffloadEntry, CodeGen);
34583475
OMPLexicalScope Scope(CGF, S);
3476+
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
3477+
CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
34593478
CGM.getOpenMPRuntime().emitTargetCall(CGF, S, Fn, FnID, IfCond, Device,
34603479
CapturedVars);
34613480
}

0 commit comments

Comments
 (0)