Skip to content

Commit 6ac8221

Browse files
[ConstantTime][Clang] Add __builtin_ct_select for constant-time selection
1 parent 80a83ad commit 6ac8221

File tree

5 files changed

+1200
-1
lines changed

5 files changed

+1200
-1
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5278,3 +5278,11 @@ def CountedByRef : Builtin {
52785278
let Attributes = [NoThrow, CustomTypeChecking];
52795279
let Prototype = "int(...)";
52805280
}
5281+
5282+
// Constant-time select builtin
5283+
def CtSelect : Builtin {
5284+
let Spellings = ["__builtin_ct_select"];
5285+
let Attributes = [NoThrow, Const, UnevaluatedArguments,
5286+
ConstIgnoringExceptions, CustomTypeChecking];
5287+
let Prototype = "void(...)";
5288+
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
#include "TargetInfo.h"
2727
#include "clang/AST/OSLog.h"
2828
#include "clang/AST/StmtVisitor.h"
29+
#include "clang/Basic/DiagnosticFrontend.h"
30+
#include "clang/Basic/DiagnosticSema.h"
2931
#include "clang/Basic/TargetInfo.h"
30-
#include "clang/Frontend/FrontendDiagnostic.h"
3132
#include "llvm/IR/InlineAsm.h"
3233
#include "llvm/IR/Instruction.h"
3334
#include "llvm/IR/Intrinsics.h"
@@ -6450,6 +6451,40 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
64506451
auto Str = CGM.GetAddrOfConstantCString(Name, "");
64516452
return RValue::get(Str.getPointer());
64526453
}
6454+
case Builtin::BI__builtin_ct_select: {
6455+
if (E->getNumArgs() != 3) {
6456+
CGM.getDiags().Report(E->getBeginLoc(),
6457+
E->getNumArgs() > 3
6458+
? diag::err_typecheck_call_too_many_args
6459+
: diag::err_typecheck_call_too_few_args);
6460+
return GetUndefRValue(E->getType());
6461+
}
6462+
6463+
auto *Cond = EmitScalarExpr(E->getArg(0));
6464+
auto *A = EmitScalarExpr(E->getArg(1));
6465+
auto *B = EmitScalarExpr(E->getArg(2));
6466+
6467+
// Verify types match
6468+
if (A->getType() != B->getType()) {
6469+
CGM.getDiags().Report(E->getBeginLoc(),
6470+
diag::err_typecheck_convert_incompatible);
6471+
return GetUndefRValue(E->getType());
6472+
}
6473+
6474+
// Verify condition is integer type
6475+
if (!Cond->getType()->isIntegerTy()) {
6476+
CGM.getDiags().Report(E->getBeginLoc(), diag::err_typecheck_expect_int);
6477+
return GetUndefRValue(E->getType());
6478+
}
6479+
6480+
if (Cond->getType()->getIntegerBitWidth() != 1)
6481+
Cond = Builder.CreateICmpNE(
6482+
Cond, llvm::ConstantInt::get(Cond->getType(), 0), "cond.bool");
6483+
6484+
llvm::Function *Fn =
6485+
CGM.getIntrinsic(llvm::Intrinsic::ct_select, {A->getType()});
6486+
return RValue::get(Builder.CreateCall(Fn, {Cond, A, B}));
6487+
}
64536488
}
64546489

64556490
// If this is an alias for a lib function (e.g. __builtin_sin), emit

clang/lib/Sema/SemaChecking.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3494,6 +3494,95 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
34943494
if (BuiltinCountedByRef(TheCall))
34953495
return ExprError();
34963496
break;
3497+
3498+
case Builtin::BI__builtin_ct_select: {
3499+
if (TheCall->getNumArgs() != 3) {
3500+
// Simple argument count check without complex diagnostics
3501+
if (TheCall->getNumArgs() < 3) {
3502+
return Diag(TheCall->getEndLoc(),
3503+
diag::err_typecheck_call_too_few_args_at_least)
3504+
<< 0 << 3 << TheCall->getNumArgs() << 0
3505+
<< TheCall->getCallee()->getSourceRange();
3506+
} else {
3507+
return Diag(TheCall->getEndLoc(),
3508+
diag::err_typecheck_call_too_many_args)
3509+
<< 0 << 3 << TheCall->getNumArgs() << 0
3510+
<< TheCall->getCallee()->getSourceRange();
3511+
}
3512+
}
3513+
auto *Cond = TheCall->getArg(0);
3514+
auto *A = TheCall->getArg(1);
3515+
auto *B = TheCall->getArg(2);
3516+
3517+
QualType CondTy = Cond->getType();
3518+
if (!CondTy->isIntegerType()) {
3519+
return Diag(Cond->getBeginLoc(), diag::err_typecheck_cond_expect_scalar)
3520+
<< CondTy << Cond->getSourceRange();
3521+
}
3522+
3523+
QualType ATy = A->getType();
3524+
QualType BTy = B->getType();
3525+
3526+
// check for scalar or vector scalar type
3527+
if ((!ATy->isScalarType() && !ATy->isVectorType()) ||
3528+
(!BTy->isScalarType() && !BTy->isVectorType())) {
3529+
return Diag(A->getBeginLoc(),
3530+
diag::err_typecheck_cond_incompatible_operands)
3531+
<< ATy << BTy << A->getSourceRange() << B->getSourceRange();
3532+
}
3533+
3534+
// Check if both operands have the same type or can be implicitly converted
3535+
QualType ResultTy;
3536+
if (Context.hasSameType(ATy, BTy)) {
3537+
ResultTy = ATy;
3538+
} else {
3539+
// Try to find a common type using the same logic as conditional
3540+
// expressions
3541+
ExprResult ARes = ExprResult(A);
3542+
ExprResult BRes = ExprResult(B);
3543+
3544+
// For arithmetic types, allow promotions within the same category only
3545+
if (ATy->isArithmeticType() && BTy->isArithmeticType()) {
3546+
// Check if both are integer types or both are floating types
3547+
bool AIsInteger = ATy->isIntegerType();
3548+
bool BIsInteger = BTy->isIntegerType();
3549+
bool AIsFloating = ATy->isFloatingType();
3550+
bool BIsFloating = BTy->isFloatingType();
3551+
3552+
if ((AIsInteger && BIsInteger) || (AIsFloating && BIsFloating)) {
3553+
// Both are in the same category, allow usual arithmetic conversions
3554+
ResultTy = UsualArithmeticConversions(
3555+
ARes, BRes, TheCall->getBeginLoc(), ArithConvKind::Conditional);
3556+
if (ARes.isInvalid() || BRes.isInvalid() || ResultTy.isNull()) {
3557+
return Diag(A->getBeginLoc(),
3558+
diag::err_typecheck_cond_incompatible_operands)
3559+
<< ATy << BTy << A->getSourceRange() << B->getSourceRange();
3560+
}
3561+
// Update the arguments with any necessary implicit casts
3562+
TheCall->setArg(1, ARes.get());
3563+
TheCall->setArg(2, BRes.get());
3564+
} else {
3565+
// Different categories (int vs float), not allowed
3566+
return Diag(A->getBeginLoc(),
3567+
diag::err_typecheck_cond_incompatible_operands)
3568+
<< ATy << BTy << A->getSourceRange() << B->getSourceRange();
3569+
}
3570+
} else {
3571+
// For non-arithmetic types, they must be exactly the same
3572+
return Diag(A->getBeginLoc(),
3573+
diag::err_typecheck_cond_incompatible_operands)
3574+
<< ATy << BTy << A->getSourceRange() << B->getSourceRange();
3575+
}
3576+
}
3577+
3578+
ExprResult CondRes = PerformContextuallyConvertToBool(Cond);
3579+
if (CondRes.isInvalid())
3580+
return ExprError();
3581+
3582+
TheCall->setArg(0, CondRes.get());
3583+
TheCall->setType(ResultTy);
3584+
return TheCall;
3585+
} break;
34973586
}
34983587

34993588
if (getLangOpts().HLSL && HLSL().CheckBuiltinFunctionCall(BuiltinID, TheCall))

0 commit comments

Comments
 (0)