Skip to content
This repository was archived by the owner on Nov 1, 2021. It is now read-only.

Commit 4a58385

Browse files
committed
[OpenMP] Codegen support for 'target teams' on the host.
This patch adds support for codegen of 'target teams' on the host. This combined directive has two captured statements, one for the 'target' region, and the other for the 'teams' region. This target teams region is offloaded using the __tgt_target_teams() call. The patch sets the number of teams as an argument to this call. Reviewers: ABataev Differential Revision: https://reviews.llvm.org/D29084 git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@293005 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent ad4e2ce commit 4a58385

8 files changed

+1423
-41
lines changed

lib/Basic/OpenMPKinds.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,11 @@ void clang::getOpenMPCaptureRegions(
875875
case OMPD_parallel_sections:
876876
CaptureRegions.push_back(OMPD_parallel);
877877
break;
878-
case OMPD_teams:
879878
case OMPD_target_teams:
879+
CaptureRegions.push_back(OMPD_target);
880+
CaptureRegions.push_back(OMPD_teams);
881+
break;
882+
case OMPD_teams:
880883
case OMPD_simd:
881884
case OMPD_for:
882885
case OMPD_for_simd:

lib/CodeGen/CGOpenMPRuntime.cpp

+53-22
Original file line numberDiff line numberDiff line change
@@ -4911,18 +4911,28 @@ emitNumTeamsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
49114911
"teams directive expected to be "
49124912
"emitted only for the host!");
49134913

4914+
auto &Bld = CGF.Builder;
4915+
4916+
// If the target directive is combined with a teams directive:
4917+
// Return the value in the num_teams clause, if any.
4918+
// Otherwise, return 0 to denote the runtime default.
4919+
if (isOpenMPTeamsDirective(D.getDirectiveKind())) {
4920+
if (const auto *NumTeamsClause = D.getSingleClause<OMPNumTeamsClause>()) {
4921+
CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF);
4922+
auto NumTeams = CGF.EmitScalarExpr(NumTeamsClause->getNumTeams(),
4923+
/*IgnoreResultAssign*/ true);
4924+
return Bld.CreateIntCast(NumTeams, CGF.Int32Ty,
4925+
/*IsSigned=*/true);
4926+
}
4927+
4928+
// The default value is 0.
4929+
return Bld.getInt32(0);
4930+
}
4931+
49144932
// If the target directive is combined with a parallel directive but not a
49154933
// teams directive, start one team.
4916-
if (isOpenMPParallelDirective(D.getDirectiveKind()) &&
4917-
!isOpenMPTeamsDirective(D.getDirectiveKind()))
4918-
return CGF.Builder.getInt32(1);
4919-
4920-
// FIXME: For the moment we do not support combined directives with target and
4921-
// teams, so we do not expect to get any num_teams clause in the provided
4922-
// directive. Once we support that, this assertion can be replaced by the
4923-
// actual emission of the clause expression.
4924-
assert(D.getSingleClause<OMPNumTeamsClause>() == nullptr &&
4925-
"Not expecting clause in directive.");
4934+
if (isOpenMPParallelDirective(D.getDirectiveKind()))
4935+
return Bld.getInt32(1);
49264936

49274937
// If the current target region has a teams region enclosed, we need to get
49284938
// the number of teams to pass to the runtime function call. This is done
@@ -4940,13 +4950,13 @@ emitNumTeamsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
49404950
CGOpenMPInnerExprInfo CGInfo(CGF, CS);
49414951
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
49424952
llvm::Value *NumTeams = CGF.EmitScalarExpr(NTE->getNumTeams());
4943-
return CGF.Builder.CreateIntCast(NumTeams, CGF.Int32Ty,
4944-
/*IsSigned=*/true);
4953+
return Bld.CreateIntCast(NumTeams, CGF.Int32Ty,
4954+
/*IsSigned=*/true);
49454955
}
49464956

49474957
// If we have an enclosed teams directive but no num_teams clause we use
49484958
// the default value 0.
4949-
return CGF.Builder.getInt32(0);
4959+
return Bld.getInt32(0);
49504960
}
49514961

49524962
// No teams associated with the directive.
@@ -4986,9 +4996,20 @@ emitNumThreadsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
49864996
//
49874997
// If this is not a teams directive return nullptr.
49884998

4989-
if (isOpenMPParallelDirective(D.getDirectiveKind())) {
4999+
if (isOpenMPTeamsDirective(D.getDirectiveKind()) ||
5000+
isOpenMPParallelDirective(D.getDirectiveKind())) {
49905001
llvm::Value *DefaultThreadLimitVal = Bld.getInt32(0);
49915002
llvm::Value *NumThreadsVal = nullptr;
5003+
llvm::Value *ThreadLimitVal = nullptr;
5004+
5005+
if (const auto *ThreadLimitClause =
5006+
D.getSingleClause<OMPThreadLimitClause>()) {
5007+
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
5008+
auto ThreadLimit = CGF.EmitScalarExpr(ThreadLimitClause->getThreadLimit(),
5009+
/*IgnoreResultAssign*/ true);
5010+
ThreadLimitVal = Bld.CreateIntCast(ThreadLimit, CGF.Int32Ty,
5011+
/*IsSigned=*/true);
5012+
}
49925013

49935014
if (const auto *NumThreadsClause =
49945015
D.getSingleClause<OMPNumThreadsClause>()) {
@@ -5000,15 +5021,21 @@ emitNumThreadsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
50005021
Bld.CreateIntCast(NumThreads, CGF.Int32Ty, /*IsSigned=*/true);
50015022
}
50025023

5003-
return NumThreadsVal ? NumThreadsVal : DefaultThreadLimitVal;
5004-
}
5024+
// Select the lesser of thread_limit and num_threads.
5025+
if (NumThreadsVal)
5026+
ThreadLimitVal = ThreadLimitVal
5027+
? Bld.CreateSelect(Bld.CreateICmpSLT(NumThreadsVal,
5028+
ThreadLimitVal),
5029+
NumThreadsVal, ThreadLimitVal)
5030+
: NumThreadsVal;
50055031

5006-
// FIXME: For the moment we do not support combined directives with target and
5007-
// teams, so we do not expect to get any thread_limit clause in the provided
5008-
// directive. Once we support that, this assertion can be replaced by the
5009-
// actual emission of the clause expression.
5010-
assert(D.getSingleClause<OMPThreadLimitClause>() == nullptr &&
5011-
"Not expecting clause in directive.");
5032+
// Set default value passed to the runtime if either teams or a target
5033+
// parallel type directive is found but no clause is specified.
5034+
if (!ThreadLimitVal)
5035+
ThreadLimitVal = DefaultThreadLimitVal;
5036+
5037+
return ThreadLimitVal;
5038+
}
50125039

50135040
// If the current target region has a teams region enclosed, we need to get
50145041
// the thread limit to pass to the runtime function call. This is done
@@ -6217,6 +6244,10 @@ void CGOpenMPRuntime::scanForTargetRegionsFunctions(const Stmt *S,
62176244
CodeGenFunction::EmitOMPTargetParallelDeviceFunction(
62186245
CGM, ParentName, cast<OMPTargetParallelDirective>(*S));
62196246
break;
6247+
case Stmt::OMPTargetTeamsDirectiveClass:
6248+
CodeGenFunction::EmitOMPTargetTeamsDeviceFunction(
6249+
CGM, ParentName, cast<OMPTargetTeamsDirective>(*S));
6250+
break;
62206251
default:
62216252
llvm_unreachable("Unknown target directive for OpenMP device codegen.");
62226253
}

lib/CodeGen/CGStmtOpenMP.cpp

+52-13
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,22 @@ class OMPParallelScope final : public OMPLexicalScope {
9898
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
9999
};
100100

101+
/// Lexical scope for OpenMP teams construct, that handles correct codegen
102+
/// for captured expressions.
103+
class OMPTeamsScope final : public OMPLexicalScope {
104+
bool EmitPreInitStmt(const OMPExecutableDirective &S) {
105+
OpenMPDirectiveKind Kind = S.getDirectiveKind();
106+
return !isOpenMPTargetExecutionDirective(Kind) &&
107+
isOpenMPTeamsDirective(Kind);
108+
}
109+
110+
public:
111+
OMPTeamsScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
112+
: OMPLexicalScope(CGF, S,
113+
/*AsInlined=*/false,
114+
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
115+
};
116+
101117
/// Private scope for OpenMP loop-based directives, that supports capturing
102118
/// of used expression from loop statement.
103119
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
@@ -2018,15 +2034,6 @@ void CodeGenFunction::EmitOMPTeamsDistributeParallelForDirective(
20182034
});
20192035
}
20202036

2021-
void CodeGenFunction::EmitOMPTargetTeamsDirective(
2022-
const OMPTargetTeamsDirective &S) {
2023-
CGM.getOpenMPRuntime().emitInlinedDirective(
2024-
*this, OMPD_target_teams, [&S](CodeGenFunction &CGF, PrePostActionTy &) {
2025-
CGF.EmitStmt(
2026-
cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
2027-
});
2028-
}
2029-
20302037
void CodeGenFunction::EmitOMPTargetTeamsDistributeDirective(
20312038
const OMPTargetTeamsDistributeDirective &S) {
20322039
CGM.getOpenMPRuntime().emitInlinedDirective(
@@ -3519,9 +3526,8 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
35193526
auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitTeamsOutlinedFunction(
35203527
S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);
35213528

3522-
const OMPTeamsDirective &TD = *dyn_cast<OMPTeamsDirective>(&S);
3523-
const OMPNumTeamsClause *NT = TD.getSingleClause<OMPNumTeamsClause>();
3524-
const OMPThreadLimitClause *TL = TD.getSingleClause<OMPThreadLimitClause>();
3529+
const OMPNumTeamsClause *NT = S.getSingleClause<OMPNumTeamsClause>();
3530+
const OMPThreadLimitClause *TL = S.getSingleClause<OMPThreadLimitClause>();
35253531
if (NT || TL) {
35263532
Expr *NumTeams = (NT) ? NT->getNumTeams() : nullptr;
35273533
Expr *ThreadLimit = (TL) ? TL->getThreadLimit() : nullptr;
@@ -3530,7 +3536,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
35303536
S.getLocStart());
35313537
}
35323538

3533-
OMPLexicalScope Scope(CGF, S);
3539+
OMPTeamsScope Scope(CGF, S);
35343540
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
35353541
CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
35363542
CGF.CGM.getOpenMPRuntime().emitTeamsCall(CGF, S, S.getLocStart(), OutlinedFn,
@@ -3549,6 +3555,39 @@ void CodeGenFunction::EmitOMPTeamsDirective(const OMPTeamsDirective &S) {
35493555
emitCommonOMPTeamsDirective(*this, S, OMPD_teams, CodeGen);
35503556
}
35513557

3558+
static void emitTargetTeamsRegion(CodeGenFunction &CGF, PrePostActionTy &Action,
3559+
const OMPTargetTeamsDirective &S) {
3560+
auto *CS = S.getCapturedStmt(OMPD_teams);
3561+
Action.Enter(CGF);
3562+
auto &&CodeGen = [CS](CodeGenFunction &CGF, PrePostActionTy &) {
3563+
// TODO: Add support for clauses.
3564+
CGF.EmitStmt(CS->getCapturedStmt());
3565+
};
3566+
emitCommonOMPTeamsDirective(CGF, S, OMPD_teams, CodeGen);
3567+
}
3568+
3569+
void CodeGenFunction::EmitOMPTargetTeamsDeviceFunction(
3570+
CodeGenModule &CGM, StringRef ParentName,
3571+
const OMPTargetTeamsDirective &S) {
3572+
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
3573+
emitTargetTeamsRegion(CGF, Action, S);
3574+
};
3575+
llvm::Function *Fn;
3576+
llvm::Constant *Addr;
3577+
// Emit target region as a standalone region.
3578+
CGM.getOpenMPRuntime().emitTargetOutlinedFunction(
3579+
S, ParentName, Fn, Addr, /*IsOffloadEntry=*/true, CodeGen);
3580+
assert(Fn && Addr && "Target device function emission failed.");
3581+
}
3582+
3583+
void CodeGenFunction::EmitOMPTargetTeamsDirective(
3584+
const OMPTargetTeamsDirective &S) {
3585+
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
3586+
emitTargetTeamsRegion(CGF, Action, S);
3587+
};
3588+
emitCommonOMPTargetDirective(*this, S, CodeGen);
3589+
}
3590+
35523591
void CodeGenFunction::EmitOMPCancellationPointDirective(
35533592
const OMPCancellationPointDirective &S) {
35543593
CGM.getOpenMPRuntime().emitCancellationPointCall(*this, S.getLocStart(),

lib/CodeGen/CodeGenFunction.h

+3
Original file line numberDiff line numberDiff line change
@@ -2711,6 +2711,9 @@ class CodeGenFunction : public CodeGenTypeCache {
27112711
static void
27122712
EmitOMPTargetParallelDeviceFunction(CodeGenModule &CGM, StringRef ParentName,
27132713
const OMPTargetParallelDirective &S);
2714+
static void
2715+
EmitOMPTargetTeamsDeviceFunction(CodeGenModule &CGM, StringRef ParentName,
2716+
const OMPTargetTeamsDirective &S);
27142717
/// \brief Emit inner loop of the worksharing/simd construct.
27152718
///
27162719
/// \param S Directive, for which the inner loop must be emitted.

lib/Sema/SemaOpenMP.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -1594,8 +1594,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
15941594
case OMPD_parallel_for:
15951595
case OMPD_parallel_for_simd:
15961596
case OMPD_parallel_sections:
1597-
case OMPD_teams:
1598-
case OMPD_target_teams: {
1597+
case OMPD_teams: {
15991598
QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1);
16001599
QualType KmpInt32PtrTy =
16011600
Context.getPointerType(KmpInt32Ty).withConst().withRestrict();
@@ -1608,6 +1607,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
16081607
Params);
16091608
break;
16101609
}
1610+
case OMPD_target_teams:
16111611
case OMPD_target_parallel: {
16121612
Sema::CapturedParamNameType ParamsTarget[] = {
16131613
std::make_pair(StringRef(), QualType()) // __context with shared vars
@@ -1618,14 +1618,15 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
16181618
QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1);
16191619
QualType KmpInt32PtrTy =
16201620
Context.getPointerType(KmpInt32Ty).withConst().withRestrict();
1621-
Sema::CapturedParamNameType ParamsParallel[] = {
1621+
Sema::CapturedParamNameType ParamsTeamsOrParallel[] = {
16221622
std::make_pair(".global_tid.", KmpInt32PtrTy),
16231623
std::make_pair(".bound_tid.", KmpInt32PtrTy),
16241624
std::make_pair(StringRef(), QualType()) // __context with shared vars
16251625
};
1626-
// Start a captured region for 'parallel'.
1626+
// Start a captured region for 'teams' or 'parallel'. Both regions have
1627+
// the same implicit parameters.
16271628
ActOnCapturedRegionStart(DSAStack->getConstructLoc(), CurScope, CR_OpenMP,
1628-
ParamsParallel);
1629+
ParamsTeamsOrParallel);
16291630
break;
16301631
}
16311632
case OMPD_simd:

0 commit comments

Comments
 (0)