Skip to content

[clang][bytecode] Make union activation more granular #148835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 141 additions & 49 deletions clang/lib/AST/ByteCode/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,34 @@ using APSInt = llvm::APSInt;
namespace clang {
namespace interp {

static bool refersToUnion(const Expr *E) {
for (;;) {
if (const auto *ME = dyn_cast<MemberExpr>(E)) {
if (const auto *FD = dyn_cast<FieldDecl>(ME->getMemberDecl());
FD && FD->getParent()->isUnion())
return true;
E = ME->getBase();
continue;
}

if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) {
E = ASE->getBase()->IgnoreImplicit();
continue;
}

if (const auto *ICE = dyn_cast<ImplicitCastExpr>(E);
ICE && (ICE->getCastKind() == CK_NoOp ||
ICE->getCastKind() == CK_DerivedToBase ||
ICE->getCastKind() == CK_UncheckedDerivedToBase)) {
E = ICE->getSubExpr();
continue;
}

break;
}
return false;
}

static std::optional<bool> getBoolValue(const Expr *E) {
if (const auto *CE = dyn_cast_if_present<ConstantExpr>(E);
CE && CE->hasAPValueResult() &&
Expand Down Expand Up @@ -880,22 +908,11 @@ bool Compiler<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
return this->VisitPointerArithBinOp(BO);
}

// Assignments require us to evaluate the RHS first.
if (BO->getOpcode() == BO_Assign) {

if (!visit(RHS) || !visit(LHS))
return false;

// We don't support assignments in C.
if (!Ctx.getLangOpts().CPlusPlus && !this->emitInvalid(BO))
return false;
if (BO->getOpcode() == BO_Assign)
return this->visitAssignment(LHS, RHS, BO);

if (!this->emitFlip(*LT, *RT, BO))
return false;
} else {
if (!visit(LHS) || !visit(RHS))
return false;
}
if (!visit(LHS) || !visit(RHS))
return false;

// For languages such as C, cast the result of one
// of our comparison opcodes to T (which is usually int).
Expand Down Expand Up @@ -946,22 +963,6 @@ bool Compiler<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
if (BO->getType()->isFloatingType())
return Discard(this->emitDivf(getFPOptions(BO), BO));
return Discard(this->emitDiv(*T, BO));
case BO_Assign:
if (DiscardResult)
return LHS->refersToBitField() ? this->emitStoreBitFieldPop(*T, BO)
: this->emitStorePop(*T, BO);
if (LHS->refersToBitField()) {
if (!this->emitStoreBitField(*T, BO))
return false;
} else {
if (!this->emitStore(*T, BO))
return false;
}
// Assignments aren't necessarily lvalues in C.
// Load from them in that case.
if (!BO->isLValue())
return this->emitLoadPop(*T, BO);
return true;
case BO_And:
return Discard(this->emitBitAnd(*T, BO));
case BO_Or:
Expand Down Expand Up @@ -1790,26 +1791,37 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
return this->delegate(Inits[0]);

auto initPrimitiveField = [=](const Record::Field *FieldToInit,
const Expr *Init, PrimType T) -> bool {
const Expr *Init, PrimType T,
bool Activate = false) -> bool {
InitStackScope<Emitter> ISS(this, isa<CXXDefaultInitExpr>(Init));
InitLinkScope<Emitter> ILS(this, InitLink::Field(FieldToInit->Offset));
if (!this->visit(Init))
return false;

if (FieldToInit->isBitField())
bool BitField = FieldToInit->isBitField();
if (BitField && Activate)
return this->emitInitBitFieldActivate(T, FieldToInit, E);
if (BitField)
return this->emitInitBitField(T, FieldToInit, E);
if (Activate)
return this->emitInitFieldActivate(T, FieldToInit->Offset, E);
return this->emitInitField(T, FieldToInit->Offset, E);
};

auto initCompositeField = [=](const Record::Field *FieldToInit,
const Expr *Init) -> bool {
const Expr *Init,
bool Activate = false) -> bool {
InitStackScope<Emitter> ISS(this, isa<CXXDefaultInitExpr>(Init));
InitLinkScope<Emitter> ILS(this, InitLink::Field(FieldToInit->Offset));

// Non-primitive case. Get a pointer to the field-to-initialize
// on the stack and recurse into visitInitializer().
if (!this->emitGetPtrField(FieldToInit->Offset, Init))
return false;

if (Activate && !this->emitActivate(E))
return false;

if (!this->visitInitializer(Init))
return false;
return this->emitPopPtr(E);
Expand All @@ -1829,10 +1841,10 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,

const Record::Field *FieldToInit = R->getField(FToInit);
if (std::optional<PrimType> T = classify(Init)) {
if (!initPrimitiveField(FieldToInit, Init, *T))
if (!initPrimitiveField(FieldToInit, Init, *T, /*Activate=*/true))
return false;
} else {
if (!initCompositeField(FieldToInit, Init))
if (!initCompositeField(FieldToInit, Init, /*Activate=*/true))
return false;
}
}
Expand Down Expand Up @@ -2023,7 +2035,8 @@ bool Compiler<Emitter>::visitArrayElemInit(unsigned ElemIndex, const Expr *Init,

template <class Emitter>
bool Compiler<Emitter>::visitCallArgs(ArrayRef<const Expr *> Args,
const FunctionDecl *FuncDecl) {
const FunctionDecl *FuncDecl,
bool Activate) {
assert(VarScope->getKind() == ScopeKind::Call);
llvm::BitVector NonNullArgs = collectNonNullArgs(FuncDecl, Args);

Expand All @@ -2046,6 +2059,11 @@ bool Compiler<Emitter>::visitCallArgs(ArrayRef<const Expr *> Args,
return false;
}

if (ArgIndex == 1 && Activate) {
if (!this->emitActivate(Arg))
return false;
}

if (FuncDecl && NonNullArgs[ArgIndex]) {
PrimType ArgT = classify(Arg).value_or(PT_Ptr);
if (ArgT == PT_Ptr) {
Expand Down Expand Up @@ -4227,10 +4245,13 @@ bool Compiler<Emitter>::visitZeroRecordInitializer(const Record *R,
PrimType T = classifyPrim(D->getType());
if (!this->visitZeroInitializer(T, QT, E))
return false;
if (R->isUnion()) {
if (!this->emitInitFieldActivate(T, Field.Offset, E))
return false;
break;
}
if (!this->emitInitField(T, Field.Offset, E))
return false;
if (R->isUnion())
break;
continue;
}

Expand All @@ -4256,13 +4277,15 @@ bool Compiler<Emitter>::visitZeroRecordInitializer(const Record *R,
} else
return false;

if (!this->emitFinishInitPop(E))
return false;

// C++11 [dcl.init]p5: If T is a (possibly cv-qualified) union type, the
// object's first non-static named data member is zero-initialized
if (R->isUnion())
if (R->isUnion()) {
if (!this->emitFinishInitActivatePop(E))
return false;
break;
}
if (!this->emitFinishInitPop(E))
return false;
}

for (const Record::Base &B : R->bases()) {
Expand Down Expand Up @@ -4325,6 +4348,59 @@ bool Compiler<Emitter>::visitZeroArrayInitializer(QualType T, const Expr *E) {
return false;
}

/// Emit bytecode for the assignment \p E, which stores \p RHS into \p LHS.
///
/// The RHS is visited before the LHS. Depending on whether the LHS refers to
/// a bit-field and/or to a union member (see refersToUnion()), one of the
/// Store / StoreBitField / StoreActivate / StoreBitFieldActivate opcodes is
/// emitted, using the "Pop" flavour when the result is discarded.
///
/// \returns false if classify() yields nothing for the type of \p E, or if
/// visiting an operand or emitting any opcode fails; true otherwise.
template <class Emitter>
bool Compiler<Emitter>::visitAssignment(const Expr *LHS, const Expr *RHS,
                                        const Expr *E) {
  if (!classify(E->getType()))
    return false;

  // Operand order matters: the right-hand side goes first.
  if (!this->visit(RHS) || !this->visit(LHS))
    return false;

  // We don't support assignments in C.
  if (!Ctx.getLangOpts().CPlusPlus) {
    if (!this->emitInvalid(E))
      return false;
  }

  PrimType ValueT = classifyPrim(RHS);
  const bool ActivatesUnion = refersToUnion(LHS);
  const bool IsBitField = LHS->refersToBitField();

  if (!this->emitFlip(PT_Ptr, ValueT, E))
    return false;

  // Result unused: the popping store variants are all that is needed.
  if (DiscardResult) {
    if (IsBitField)
      return ActivatesUnion ? this->emitStoreBitFieldActivatePop(ValueT, E)
                            : this->emitStoreBitFieldPop(ValueT, E);
    return ActivatesUnion ? this->emitStoreActivatePop(ValueT, E)
                          : this->emitStorePop(ValueT, E);
  }

  // Result is used: emit the non-popping store variant.
  bool Stored;
  if (IsBitField)
    Stored = ActivatesUnion ? this->emitStoreBitFieldActivate(ValueT, E)
                            : this->emitStoreBitField(ValueT, E);
  else
    Stored = ActivatesUnion ? this->emitStoreActivate(ValueT, E)
                            : this->emitStore(ValueT, E);
  if (!Stored)
    return false;

  // Assignments aren't necessarily lvalues in C.
  // Load from them in that case.
  if (E->isLValue())
    return true;
  return this->emitLoadPop(ValueT, E);
}

template <class Emitter>
template <typename T>
bool Compiler<Emitter>::emitConst(T Value, PrimType Ty, const Expr *E) {
Expand Down Expand Up @@ -5067,7 +5143,7 @@ bool Compiler<Emitter>::VisitCallExpr(const CallExpr *E) {
return false;
}

if (!this->visitCallArgs(Args, FuncDecl))
if (!this->visitCallArgs(Args, FuncDecl, IsAssignmentOperatorCall))
return false;

// Undo the argument reversal we did earlier.
Expand Down Expand Up @@ -5851,7 +5927,8 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
assert(!ReturnType);

auto emitFieldInitializer = [&](const Record::Field *F, unsigned FieldOffset,
const Expr *InitExpr) -> bool {
const Expr *InitExpr,
bool Activate = false) -> bool {
// We don't know what to do with these, so just return false.
if (InitExpr->getType().isNull())
return false;
Expand All @@ -5860,8 +5937,13 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
if (!this->visit(InitExpr))
return false;

if (F->isBitField())
bool BitField = F->isBitField();
if (BitField && Activate)
return this->emitInitThisBitFieldActivate(*T, F, FieldOffset, InitExpr);
if (BitField)
return this->emitInitThisBitField(*T, F, FieldOffset, InitExpr);
if (Activate)
return this->emitInitThisFieldActivate(*T, FieldOffset, InitExpr);
return this->emitInitThisField(*T, FieldOffset, InitExpr);
}
// Non-primitive case. Get a pointer to the field-to-initialize
Expand All @@ -5870,6 +5952,9 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
if (!this->emitGetPtrThisField(FieldOffset, InitExpr))
return false;

if (Activate && !this->emitActivate(InitExpr))
return false;

if (!this->visitInitializer(InitExpr))
return false;

Expand All @@ -5880,8 +5965,9 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
const Record *R = this->getRecord(RD);
if (!R)
return false;
bool IsUnion = R->isUnion();

if (R->isUnion() && Ctor->isCopyOrMoveConstructor()) {
if (IsUnion && Ctor->isCopyOrMoveConstructor()) {
if (R->getNumFields() == 0)
return this->emitRetVoid(Ctor);
// union copy and move ctors are special.
Expand All @@ -5908,7 +5994,7 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
if (const FieldDecl *Member = Init->getMember()) {
const Record::Field *F = R->getField(Member);

if (!emitFieldInitializer(F, F->Offset, InitExpr))
if (!emitFieldInitializer(F, F->Offset, InitExpr, IsUnion))
return false;
} else if (const Type *Base = Init->getBaseClass()) {
const auto *BaseDecl = Base->getAsCXXRecordDecl();
Expand All @@ -5928,11 +6014,15 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {
return false;
}

if (IsUnion && !this->emitActivate(InitExpr))
return false;

if (!this->visitInitializer(InitExpr))
return false;
if (!this->emitFinishInitPop(InitExpr))
return false;
} else if (const IndirectFieldDecl *IFD = Init->getIndirectMember()) {

assert(IFD->getChainingSize() >= 2);

unsigned NestedFieldOffset = 0;
Expand All @@ -5944,12 +6034,14 @@ bool Compiler<Emitter>::compileConstructor(const CXXConstructorDecl *Ctor) {

NestedField = FieldRecord->getField(FD);
assert(NestedField);
IsUnion = IsUnion || FieldRecord->isUnion();

NestedFieldOffset += NestedField->Offset;
}
assert(NestedField);

if (!emitFieldInitializer(NestedField, NestedFieldOffset, InitExpr))
if (!emitFieldInitializer(NestedField, NestedFieldOffset, InitExpr,
IsUnion))
return false;

// Mark all chain links as initialized.
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/AST/ByteCode/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
const Expr *E);
bool visitArrayElemInit(unsigned ElemIndex, const Expr *Init,
std::optional<PrimType> InitT);
bool visitCallArgs(ArrayRef<const Expr *> Args, const FunctionDecl *FuncDecl);
bool visitCallArgs(ArrayRef<const Expr *> Args, const FunctionDecl *FuncDecl,
bool Activate);

/// Creates a local primitive value.
unsigned allocateLocalPrimitive(DeclTy &&Decl, PrimType Ty, bool IsConst,
Expand Down Expand Up @@ -342,6 +343,7 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
bool visitZeroInitializer(PrimType T, QualType QT, const Expr *E);
bool visitZeroRecordInitializer(const Record *R, const Expr *E);
bool visitZeroArrayInitializer(QualType T, const Expr *E);
bool visitAssignment(const Expr *LHS, const Expr *RHS, const Expr *E);

/// Emits an APSInt constant.
bool emitConst(const llvm::APSInt &Value, PrimType Ty, const Expr *E);
Expand Down
Loading