Skip to content
This repository was archived by the owner on Nov 1, 2021. It is now read-only.

Commit 440f379

Browse files
committed
[OpenMP] Support for the num_threads-clause on 'target parallel'.
The num_threads-clause on the combined directive applies to the 'parallel' region of this construct. We modify the NumThreadsClause class to capture the clause expression within the 'target' region. The offload runtime call for 'target parallel' is changed to __tgt_target_teams() with 1 team and the number of threads set by this clause or a default if none. Reviewers: ABataev Differential Revision: https://reviews.llvm.org/D29082 git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@292997 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 0e1da5f commit 440f379

12 files changed

+555
-52
lines changed

include/clang/AST/OpenMPClause.h

+15-6
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ class OMPFinalClause : public OMPClause {
345345
/// In this example directive '#pragma omp parallel' has simple 'num_threads'
346346
/// clause with number of threads '6'.
347347
///
348-
class OMPNumThreadsClause : public OMPClause {
348+
class OMPNumThreadsClause : public OMPClause, public OMPClauseWithPreInit {
349349
friend class OMPClauseReader;
350350
/// \brief Location of '('.
351351
SourceLocation LParenLoc;
@@ -360,20 +360,29 @@ class OMPNumThreadsClause : public OMPClause {
360360
/// \brief Build 'num_threads' clause with condition \a NumThreads.
361361
///
362362
/// \param NumThreads Number of threads for the construct.
363+
/// \param HelperNumThreads Pre-init statement for the clause (may be null).
364+
/// \param CaptureRegion Innermost OpenMP region where expressions in this
365+
/// clause must be captured.
363366
/// \param StartLoc Starting location of the clause.
364367
/// \param LParenLoc Location of '('.
365368
/// \param EndLoc Ending location of the clause.
366369
///
367-
OMPNumThreadsClause(Expr *NumThreads, SourceLocation StartLoc,
368-
SourceLocation LParenLoc, SourceLocation EndLoc)
369-
: OMPClause(OMPC_num_threads, StartLoc, EndLoc), LParenLoc(LParenLoc),
370-
NumThreads(NumThreads) {}
370+
OMPNumThreadsClause(Expr *NumThreads, Stmt *HelperNumThreads,
371+
OpenMPDirectiveKind CaptureRegion,
372+
SourceLocation StartLoc, SourceLocation LParenLoc,
373+
SourceLocation EndLoc)
374+
: OMPClause(OMPC_num_threads, StartLoc, EndLoc),
375+
OMPClauseWithPreInit(this), LParenLoc(LParenLoc),
376+
NumThreads(NumThreads) {
377+
setPreInitStmt(HelperNumThreads, CaptureRegion);
378+
}
371379

372380
/// \brief Build an empty clause.
373381
///
374382
OMPNumThreadsClause()
375383
: OMPClause(OMPC_num_threads, SourceLocation(), SourceLocation()),
376-
LParenLoc(SourceLocation()), NumThreads(nullptr) {}
384+
OMPClauseWithPreInit(this), LParenLoc(SourceLocation()),
385+
NumThreads(nullptr) {}
377386

378387
/// \brief Sets the location of '('.
379388
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

include/clang/AST/RecursiveASTVisitor.h

+1
Original file line numberDiff line numberDiff line change
@@ -2725,6 +2725,7 @@ bool RecursiveASTVisitor<Derived>::VisitOMPFinalClause(OMPFinalClause *C) {
27252725
template <typename Derived>
27262726
bool
27272727
RecursiveASTVisitor<Derived>::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
2728+
TRY_TO(VisitOMPClauseWithPreInit(C));
27282729
TRY_TO(TraverseStmt(C->getNumThreads()));
27292730
return true;
27302731
}

lib/AST/OpenMPClause.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,11 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
5050
return static_cast<const OMPLinearClause *>(C);
5151
case OMPC_if:
5252
return static_cast<const OMPIfClause *>(C);
53+
case OMPC_num_threads:
54+
return static_cast<const OMPNumThreadsClause *>(C);
5355
case OMPC_default:
5456
case OMPC_proc_bind:
5557
case OMPC_final:
56-
case OMPC_num_threads:
5758
case OMPC_safelen:
5859
case OMPC_simdlen:
5960
case OMPC_collapse:

lib/AST/StmtProfile.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ void OMPClauseProfiler::VisitOMPFinalClause(const OMPFinalClause *C) {
294294
}
295295

296296
void OMPClauseProfiler::VisitOMPNumThreadsClause(const OMPNumThreadsClause *C) {
297+
VistOMPClauseWithPreInit(C);
297298
if (C->getNumThreads())
298299
Profiler->VisitStmt(C->getNumThreads());
299300
}

lib/CodeGen/CGOpenMPRuntime.cpp

+98-25
Original file line numberDiff line numberDiff line change
@@ -4894,19 +4894,29 @@ static const Stmt *ignoreCompoundStmts(const Stmt *Body) {
48944894
return Body;
48954895
}
48964896

4897-
/// \brief Emit the num_teams clause of an enclosed teams directive at the
4898-
/// target region scope. If there is no teams directive associated with the
4899-
/// target directive, or if there is no num_teams clause associated with the
4900-
/// enclosed teams directive, return nullptr.
4897+
/// Emit the number of teams for a target directive. Inspect the num_teams
4898+
/// clause associated with a teams construct combined or closely nested
4899+
/// with the target directive.
4900+
///
4901+
/// Emit a single team for directives such as 'target parallel' that
4902+
/// have no associated teams construct.
4903+
///
4904+
/// Otherwise, return nullptr.
49014905
static llvm::Value *
4902-
emitNumTeamsClauseForTargetDirective(CGOpenMPRuntime &OMPRuntime,
4903-
CodeGenFunction &CGF,
4904-
const OMPExecutableDirective &D) {
4906+
emitNumTeamsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
4907+
CodeGenFunction &CGF,
4908+
const OMPExecutableDirective &D) {
49054909

49064910
assert(!CGF.getLangOpts().OpenMPIsDevice && "Clauses associated with the "
49074911
"teams directive expected to be "
49084912
"emitted only for the host!");
49094913

4914+
// If the target directive is combined with a parallel directive but not a
4915+
// teams directive, start one team.
4916+
if (isOpenMPParallelDirective(D.getDirectiveKind()) &&
4917+
!isOpenMPTeamsDirective(D.getDirectiveKind()))
4918+
return CGF.Builder.getInt32(1);
4919+
49104920
// FIXME: For the moment we do not support combined directives with target and
49114921
// teams, so we do not expect to get any num_teams clause in the provided
49124922
// directive. Once we support that, this assertion can be replaced by the
@@ -4943,19 +4953,56 @@ emitNumTeamsClauseForTargetDirective(CGOpenMPRuntime &OMPRuntime,
49434953
return nullptr;
49444954
}
49454955

4946-
/// \brief Emit the thread_limit clause of an enclosed teams directive at the
4947-
/// target region scope. If there is no teams directive associated with the
4948-
/// target directive, or if there is no thread_limit clause associated with the
4949-
/// enclosed teams directive, return nullptr.
4956+
/// Emit the number of threads for a target directive. Inspect the
4957+
/// thread_limit clause associated with a teams construct combined or closely
4958+
/// nested with the target directive.
4959+
///
4960+
/// Emit the num_threads clause for directives such as 'target parallel' that
4961+
/// have no associated teams construct.
4962+
///
4963+
/// Otherwise, return nullptr.
49504964
static llvm::Value *
4951-
emitThreadLimitClauseForTargetDirective(CGOpenMPRuntime &OMPRuntime,
4952-
CodeGenFunction &CGF,
4953-
const OMPExecutableDirective &D) {
4965+
emitNumThreadsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
4966+
CodeGenFunction &CGF,
4967+
const OMPExecutableDirective &D) {
49544968

49554969
assert(!CGF.getLangOpts().OpenMPIsDevice && "Clauses associated with the "
49564970
"teams directive expected to be "
49574971
"emitted only for the host!");
49584972

4973+
auto &Bld = CGF.Builder;
4974+
4975+
//
4976+
// If the target directive is combined with a teams directive:
4977+
// Return the value in the thread_limit clause, if any.
4978+
//
4979+
// If the target directive is combined with a parallel directive:
4980+
// Return the value in the num_threads clause, if any.
4981+
//
4982+
// If both clauses are set, select the minimum of the two.
4983+
//
4984+
// If neither teams nor parallel combined directives set the number of threads
4985+
// in a team, return 0 to denote the runtime default.
4986+
//
4987+
// If this is not a teams directive return nullptr.
4988+
4989+
if (isOpenMPParallelDirective(D.getDirectiveKind())) {
4990+
llvm::Value *DefaultThreadLimitVal = Bld.getInt32(0);
4991+
llvm::Value *NumThreadsVal = nullptr;
4992+
4993+
if (const auto *NumThreadsClause =
4994+
D.getSingleClause<OMPNumThreadsClause>()) {
4995+
CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);
4996+
llvm::Value *NumThreads =
4997+
CGF.EmitScalarExpr(NumThreadsClause->getNumThreads(),
4998+
/*IgnoreResultAssign*/ true);
4999+
NumThreadsVal =
5000+
Bld.CreateIntCast(NumThreads, CGF.Int32Ty, /*IsSigned=*/true);
5001+
}
5002+
5003+
return NumThreadsVal ? NumThreadsVal : DefaultThreadLimitVal;
5004+
}
5005+
49595006
// FIXME: For the moment we do not support combined directives with target and
49605007
// teams, so we do not expect to get any thread_limit clause in the provided
49615008
// directive. Once we support that, this assertion can be replaced by the
@@ -6041,24 +6088,50 @@ void CGOpenMPRuntime::emitTargetCall(CodeGenFunction &CGF,
60416088
// Return value of the runtime offloading call.
60426089
llvm::Value *Return;
60436090

6044-
auto *NumTeams = emitNumTeamsClauseForTargetDirective(RT, CGF, D);
6045-
auto *ThreadLimit = emitThreadLimitClauseForTargetDirective(RT, CGF, D);
6091+
auto *NumTeams = emitNumTeamsForTargetDirective(RT, CGF, D);
6092+
auto *NumThreads = emitNumThreadsForTargetDirective(RT, CGF, D);
60466093

6047-
// If we have NumTeams defined this means that we have an enclosed teams
6048-
// region. Therefore we also expect to have ThreadLimit defined. These two
6049-
// values should be defined in the presence of a teams directive, regardless
6050-
// of having any clauses associated. If the user is using teams but no
6051-
// clauses, these two values will be the default that should be passed to
6052-
// the runtime library - a 32-bit integer with the value zero.
6094+
// The target region is an outlined function launched by the runtime
6095+
// via calls __tgt_target() or __tgt_target_teams().
6096+
//
6097+
// __tgt_target() launches a target region with one team and one thread,
6098+
// executing a serial region. This master thread may in turn launch
6099+
// more threads within its team upon encountering a parallel region,
6100+
// however, no additional teams can be launched on the device.
6101+
//
6102+
// __tgt_target_teams() launches a target region with one or more teams,
6103+
// each with one or more threads. This call is required for target
6104+
// constructs such as:
6105+
// 'target teams'
6106+
// 'target' / 'teams'
6107+
// 'target teams distribute parallel for'
6108+
// 'target parallel'
6109+
// and so on.
6110+
//
6111+
// Note that on the host and CPU targets, the runtime implementation of
6112+
// these calls simply call the outlined function without forking threads.
6113+
// The outlined functions themselves have runtime calls to
6114+
// __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
6115+
// the compiler in emitTeamsCall() and emitParallelCall().
6116+
//
6117+
// In contrast, on the NVPTX target, the implementation of
6118+
// __tgt_target_teams() launches a GPU kernel with the requested number
6119+
// of teams and threads so no additional calls to the runtime are required.
60536120
if (NumTeams) {
6054-
assert(ThreadLimit && "Thread limit expression should be available along "
6055-
"with number of teams.");
6121+
// If we have NumTeams defined this means that we have an enclosed teams
6122+
// region or a combined directive such as 'target parallel'. Therefore we
6123+
// also expect to have NumThreads defined, regardless of having any clauses
6124+
// associated. If the user is using teams but no clauses, these two
6125+
// values will be the default that should be passed to the runtime
6126+
// library - a 32-bit integer with the value zero.
6127+
assert(NumThreads && "Thread limit expression should be available along "
6128+
"with number of teams.");
60566129
llvm::Value *OffloadingArgs[] = {
60576130
DeviceID, OutlinedFnID,
60586131
PointerNum, Info.BasePointersArray,
60596132
Info.PointersArray, Info.SizesArray,
60606133
Info.MapTypesArray, NumTeams,
6061-
ThreadLimit};
6134+
NumThreads};
60626135
Return = CGF.EmitRuntimeCall(
60636136
RT.createRuntimeFunction(OMPRTL__tgt_target_teams), OffloadingArgs);
60646137
} else {

lib/Sema/SemaOpenMP.cpp

+78-7
Original file line numberDiff line numberDiff line change
@@ -6635,10 +6635,9 @@ OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
66356635
// the region in which to capture expressions associated with a clause.
66366636
// A return value of OMPD_unknown signifies that the expression should not
66376637
// be captured.
6638-
static OpenMPDirectiveKind
6639-
getOpenMPCaptureRegionForClause(OpenMPDirectiveKind DKind,
6640-
OpenMPClauseKind CKind,
6641-
OpenMPDirectiveKind NameModifier) {
6638+
static OpenMPDirectiveKind getOpenMPCaptureRegionForClause(
6639+
OpenMPDirectiveKind DKind, OpenMPClauseKind CKind,
6640+
OpenMPDirectiveKind NameModifier = OMPD_unknown) {
66426641
OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
66436642

66446643
switch (CKind) {
@@ -6708,6 +6707,69 @@ getOpenMPCaptureRegionForClause(OpenMPDirectiveKind DKind,
67086707
llvm_unreachable("Unknown OpenMP directive");
67096708
}
67106709
break;
6710+
case OMPC_num_threads:
6711+
switch (DKind) {
6712+
case OMPD_target_parallel:
6713+
CaptureRegion = OMPD_target;
6714+
break;
6715+
case OMPD_cancel:
6716+
case OMPD_parallel:
6717+
case OMPD_parallel_sections:
6718+
case OMPD_parallel_for:
6719+
case OMPD_parallel_for_simd:
6720+
case OMPD_target:
6721+
case OMPD_target_simd:
6722+
case OMPD_target_parallel_for:
6723+
case OMPD_target_parallel_for_simd:
6724+
case OMPD_target_teams:
6725+
case OMPD_target_teams_distribute:
6726+
case OMPD_target_teams_distribute_simd:
6727+
case OMPD_target_teams_distribute_parallel_for:
6728+
case OMPD_target_teams_distribute_parallel_for_simd:
6729+
case OMPD_teams_distribute_parallel_for:
6730+
case OMPD_teams_distribute_parallel_for_simd:
6731+
case OMPD_distribute_parallel_for:
6732+
case OMPD_distribute_parallel_for_simd:
6733+
case OMPD_task:
6734+
case OMPD_taskloop:
6735+
case OMPD_taskloop_simd:
6736+
case OMPD_target_data:
6737+
case OMPD_target_enter_data:
6738+
case OMPD_target_exit_data:
6739+
case OMPD_target_update:
6740+
// Do not capture num_threads-clause expressions.
6741+
break;
6742+
case OMPD_threadprivate:
6743+
case OMPD_taskyield:
6744+
case OMPD_barrier:
6745+
case OMPD_taskwait:
6746+
case OMPD_cancellation_point:
6747+
case OMPD_flush:
6748+
case OMPD_declare_reduction:
6749+
case OMPD_declare_simd:
6750+
case OMPD_declare_target:
6751+
case OMPD_end_declare_target:
6752+
case OMPD_teams:
6753+
case OMPD_simd:
6754+
case OMPD_for:
6755+
case OMPD_for_simd:
6756+
case OMPD_sections:
6757+
case OMPD_section:
6758+
case OMPD_single:
6759+
case OMPD_master:
6760+
case OMPD_critical:
6761+
case OMPD_taskgroup:
6762+
case OMPD_distribute:
6763+
case OMPD_ordered:
6764+
case OMPD_atomic:
6765+
case OMPD_distribute_simd:
6766+
case OMPD_teams_distribute:
6767+
case OMPD_teams_distribute_simd:
6768+
llvm_unreachable("Unexpected OpenMP directive with num_threads-clause");
6769+
case OMPD_unknown:
6770+
llvm_unreachable("Unknown OpenMP directive");
6771+
}
6772+
break;
67116773
case OMPC_schedule:
67126774
case OMPC_dist_schedule:
67136775
case OMPC_firstprivate:
@@ -6717,7 +6779,6 @@ getOpenMPCaptureRegionForClause(OpenMPDirectiveKind DKind,
67176779
case OMPC_default:
67186780
case OMPC_proc_bind:
67196781
case OMPC_final:
6720-
case OMPC_num_threads:
67216782
case OMPC_safelen:
67226783
case OMPC_simdlen:
67236784
case OMPC_collapse:
@@ -6887,15 +6948,25 @@ OMPClause *Sema::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
68876948
SourceLocation LParenLoc,
68886949
SourceLocation EndLoc) {
68896950
Expr *ValExpr = NumThreads;
6951+
Stmt *HelperValStmt = nullptr;
6952+
OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
68906953

68916954
// OpenMP [2.5, Restrictions]
68926955
// The num_threads expression must evaluate to a positive integer value.
68936956
if (!IsNonNegativeIntegerValue(ValExpr, *this, OMPC_num_threads,
68946957
/*StrictlyPositive=*/true))
68956958
return nullptr;
68966959

6897-
return new (Context)
6898-
OMPNumThreadsClause(ValExpr, StartLoc, LParenLoc, EndLoc);
6960+
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
6961+
CaptureRegion = getOpenMPCaptureRegionForClause(DKind, OMPC_num_threads);
6962+
if (CaptureRegion != OMPD_unknown) {
6963+
llvm::MapVector<Expr *, DeclRefExpr *> Captures;
6964+
ValExpr = tryBuildCapture(*this, ValExpr, Captures).get();
6965+
HelperValStmt = buildPreInits(Context, Captures);
6966+
}
6967+
6968+
return new (Context) OMPNumThreadsClause(
6969+
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
68996970
}
69006971

69016972
ExprResult Sema::VerifyPositiveIntegerConstantInClause(Expr *E,

lib/Serialization/ASTReaderStmt.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,7 @@ void OMPClauseReader::VisitOMPFinalClause(OMPFinalClause *C) {
19521952
}
19531953

19541954
void OMPClauseReader::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
1955+
VisitOMPClauseWithPreInit(C);
19551956
C->setNumThreads(Reader->Record.readSubExpr());
19561957
C->setLParenLoc(Reader->ReadSourceLocation());
19571958
}

lib/Serialization/ASTWriterStmt.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,7 @@ void OMPClauseWriter::VisitOMPFinalClause(OMPFinalClause *C) {
18181818
}
18191819

18201820
void OMPClauseWriter::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
1821+
VisitOMPClauseWithPreInit(C);
18211822
Record.AddStmt(C->getNumThreads());
18221823
Record.AddSourceLocation(C->getLParenLoc());
18231824
}

0 commit comments

Comments
 (0)