Skip to content

[SampleFDO][TypeProf]Support vtable type profiling for ext-binary and text format #148002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions llvm/include/llvm/ProfileData/SampleProf.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ enum class sampleprof_error {
uncompress_failed,
zlib_unavailable,
hash_mismatch,
illegal_line_offset
illegal_line_offset,
duplicate_vtable_type
};

inline std::error_code make_error_code(sampleprof_error E) {
Expand Down Expand Up @@ -91,6 +92,8 @@ struct is_error_code_enum<llvm::sampleprof_error> : std::true_type {};
namespace llvm {
namespace sampleprof {

constexpr char kVTableProfPrefix[] = "vtables ";

enum SampleProfileFormat {
SPF_None = 0,
SPF_Text = 0x1,
Expand Down Expand Up @@ -204,6 +207,9 @@ enum class SecProfSummaryFlags : uint32_t {
/// SecFlagIsPreInlined means this profile contains ShouldBeInlined
/// contexts thus this is CS preinliner computed.
SecFlagIsPreInlined = (1 << 4),

/// SecFlagHasVTableTypeProf means this profile contains vtable type profiles.
SecFlagHasVTableTypeProf = (1 << 5),
};

enum class SecFuncMetadataFlags : uint32_t {
Expand Down Expand Up @@ -303,7 +309,7 @@ struct LineLocation {
}

uint64_t getHashCode() const {
return ((uint64_t) Discriminator << 32) | LineOffset;
return ((uint64_t)Discriminator << 32) | LineOffset;
}

uint32_t LineOffset;
Expand All @@ -318,16 +324,28 @@ struct LineLocationHash {

LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc);

/// Key represents the id of a vtable and value represents its count.
/// TODO: Rename class FunctionId to SymbolId in a separate PR.
using TypeCountMap = std::map<FunctionId, uint64_t>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation noting that FunctionId/SymbolId refers to a vtable symbol here. 'Type' refers to C++ polymorphic class types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


/// Write \p Map to the output stream. Keys are linearized using \p NameTable
/// and written as ULEB128. Values are written as ULEB128 as well.
std::error_code
serializeTypeMap(const TypeCountMap &Map,
const MapVector<FunctionId, uint32_t> &NameTable,
raw_ostream &OS);

/// Representation of a single sample record.
///
/// A sample record is represented by a positive integer value, which
/// indicates how frequently the associated line location was executed.
///
/// Additionally, if the associated location contains a function call,
/// the record will hold a list of all the possible called targets. For
/// direct calls, this will be the exact function being invoked. For
/// indirect calls (function pointers, virtual table dispatch), this
/// will be a list of one or more functions.
/// the record will hold a list of all the possible called targets and the types
/// for virtual table dispatches. For direct calls, this will be the exact
/// function being invoked. For indirect calls (function pointers, virtual table
/// dispatch), this will be a list of one or more functions. For virtual table
/// dispatches, this record will also hold the type of the object.
class SampleRecord {
public:
using CallTarget = std::pair<FunctionId, uint64_t>;
Expand Down Expand Up @@ -746,6 +764,7 @@ using BodySampleMap = std::map<LineLocation, SampleRecord>;
// memory, which is *very* significant for large profiles.
using FunctionSamplesMap = std::map<FunctionId, FunctionSamples>;
using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
using CallsiteTypeMap = std::map<LineLocation, TypeCountMap>;
using LocToLocMap =
std::unordered_map<LineLocation, LineLocation, LineLocationHash>;

Expand Down Expand Up @@ -928,6 +947,14 @@ class FunctionSamples {
return &Iter->second;
}

/// Look up the vtable type counts recorded for the callsite at \p Loc.
///
/// \p Loc is first translated with mapIRLocToProfileLoc. Returns a pointer to
/// the stored TypeCountMap, or nullptr when no entry exists for that location.
/// The lookup never inserts into VirtualCallsiteTypeCounts.
const TypeCountMap *findCallsiteTypeSamplesAt(const LineLocation &Loc) const {
  const auto &ProfileLoc = mapIRLocToProfileLoc(Loc);
  const auto Pos = VirtualCallsiteTypeCounts.find(ProfileLoc);
  return Pos != VirtualCallsiteTypeCounts.end() ? &Pos->second : nullptr;
}

/// Returns a pointer to FunctionSamples at the given callsite location
/// \p Loc with callee \p CalleeName. If no callsite can be found, relax
/// the restriction to return the FunctionSamples at callsite location
Expand Down Expand Up @@ -989,6 +1016,42 @@ class FunctionSamples {
return CallsiteSamples;
}

/// Return all the callsite type samples collected in the body of the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : return vtable access samples for C++ types collected ..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

/// function.
const CallsiteTypeMap &getCallsiteTypeCounts() const {
  // Read-only view of the per-location vtable type counts.
  return VirtualCallsiteTypeCounts;
}

/// Returns the type samples for the un-drifted location of \p Loc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

TypeCountMap &getTypeSamplesAt(const LineLocation &Loc) {
  // Translate the IR location before indexing. std::map::operator[] creates
  // an empty TypeCountMap when the location has no entry yet.
  const auto &ProfileLoc = mapIRLocToProfileLoc(Loc);
  return VirtualCallsiteTypeCounts[ProfileLoc];
}

/// Scale \p Other sample counts by \p Weight and add the scaled result to the
/// type samples for the undrifted location of \p Loc.
///
/// \tparam T a map type whose key_type is StringRef or FunctionId and whose
///   mapped_type is uint64_t (enforced by the static_assert below).
/// \param Loc callsite location; resolved through getTypeSamplesAt, which
///   creates an empty entry if none exists yet.
/// \param Other the type -> count entries to merge in.
/// \param Weight multiplier applied to every count in \p Other.
/// \returns sampleprof_error::counter_overflow if SaturatingMultiplyAdd
///   reported overflow for at least one entry, otherwise
///   sampleprof_error::success.
template <typename T>
sampleprof_error addCallsiteVTableTypeProfAt(const LineLocation &Loc,
                                             const T &Other,
                                             uint64_t Weight = 1) {
  static_assert((std::is_same_v<typename T::key_type, StringRef> ||
                 std::is_same_v<typename T::key_type, FunctionId>) &&
                    std::is_same_v<typename T::mapped_type, uint64_t>,
                "T must be a map with StringRef or FunctionId as key and "
                "uint64_t as value");
  TypeCountMap &TypeCounts = getTypeSamplesAt(Loc);
  bool Overflowed = false;

  // Bind by reference so a map entry is not copied on every iteration.
  for (const auto &[Type, Count] : Other) {
    FunctionId TypeId(Type);
    // operator[] value-initializes the count for a type seen for the first
    // time; hold on to the reference so the slot is looked up only once.
    uint64_t &Accumulated = TypeCounts[TypeId];
    bool RowOverflow = false;
    Accumulated =
        SaturatingMultiplyAdd(Count, Weight, Accumulated, &RowOverflow);
    Overflowed |= RowOverflow;
  }
  return Overflowed ? sampleprof_error::counter_overflow
                    : sampleprof_error::success;
}

/// Return the maximum of sample counts in a function body. When SkipCallSite
/// is false, which is the default, the return count includes samples in the
/// inlined functions. When SkipCallSite is true, the return count only
Expand Down Expand Up @@ -1043,6 +1106,10 @@ class FunctionSamples {
mergeSampleProfErrors(Result,
FSMap[Rec.first].merge(Rec.second, Weight));
}
for (const auto &[Loc, OtherTypeMap] : Other.getCallsiteTypeCounts())
mergeSampleProfErrors(
Result, addCallsiteVTableTypeProfAt(Loc, OtherTypeMap, Weight));

return Result;
}

Expand Down Expand Up @@ -1286,6 +1353,23 @@ class FunctionSamples {
/// collected in the call to baz() at line offset 8.
CallsiteSampleMap CallsiteSamples;

/// Map virtual callsites to the vtable from which they are loaded.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: map a vcall site to the list of accessed vtables by the site. The vcallsite is referenced by its source location, and ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

///
/// Each entry is a mapping from the location to the list of vtables and their
/// sampled counts. For example, given:
///
/// void foo() {
/// ...
/// 5 inlined_vcall_bar();
/// ...
/// 5 inlined_vcall_baz();
/// ...
/// 200 inlined_vcall_qux();
/// }
/// This map will contain two entries. One with two types for line offset 5
/// and one with one type for line offset 200.
CallsiteTypeMap VirtualCallsiteTypeCounts;

/// IR to profile location map generated by stale profile matching.
///
/// Each entry is a mapping from the location on current build to the matched
Expand Down
12 changes: 12 additions & 0 deletions llvm/include/llvm/ProfileData/SampleProfReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,14 @@ class LLVM_ABI SampleProfileReaderBinary : public SampleProfileReader {
/// otherwise same as readStringFromTable, also return its hash value.
ErrorOr<std::pair<SampleContext, uint64_t>> readSampleContextFromTable();

/// Read all virtual functions' vtable access counts for \p FProfile.
std::error_code readCallsiteVTableProf(FunctionSamples &FProfile);

/// Read bytes from the input buffer pointed to by `Data` and decode them into
/// \p M. `Data` will be advanced to the end of the read bytes when this
/// function returns. Returns an error code on failure.
std::error_code readVTableTypeCountMap(TypeCountMap &M);

/// Points to the current location in the buffer.
const uint8_t *Data = nullptr;

Expand All @@ -727,6 +735,10 @@ class LLVM_ABI SampleProfileReaderBinary : public SampleProfileReader {
/// to the start of MD5SampleContextTable.
const uint64_t *MD5SampleContextStart = nullptr;

/// If true, the profile has vtable profiles and reader should decode them
/// to parse profiles correctly.
bool ReadVTableProf = false;

private:
std::error_code readSummaryEntry(std::vector<ProfileSummaryEntry> &Entries);
virtual std::error_code verifySPMagic(uint64_t Magic) = 0;
Expand Down
14 changes: 10 additions & 4 deletions llvm/include/llvm/ProfileData/SampleProfWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,20 @@ class LLVM_ABI SampleProfileWriterBinary : public SampleProfileWriter {
std::error_code writeBody(const FunctionSamples &S);
inline void stablizeNameTable(MapVector<FunctionId, uint32_t> &NameTable,
std::set<FunctionId> &V);

MapVector<FunctionId, uint32_t> NameTable;

void addName(FunctionId FName);
virtual void addContext(const SampleContext &Context);
void addNames(const FunctionSamples &S);

/// Write \p CallsiteTypeMap to the output stream \p OS.
std::error_code
writeCallsiteVTableProf(const CallsiteTypeMap &CallsiteTypeMap,
raw_ostream &OS);

bool WriteVTableProf = false;

private:
LLVM_ABI friend ErrorOr<std::unique_ptr<SampleProfileWriter>>
SampleProfileWriter::create(std::unique_ptr<raw_ostream> &OS,
Expand Down Expand Up @@ -412,8 +419,7 @@ class LLVM_ABI SampleProfileWriterExtBinaryBase
class LLVM_ABI SampleProfileWriterExtBinary
: public SampleProfileWriterExtBinaryBase {
public:
SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS)
: SampleProfileWriterExtBinaryBase(OS) {}
SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS);

private:
std::error_code writeDefaultLayout(const SampleProfileMap &ProfileMap);
Expand Down
43 changes: 43 additions & 0 deletions llvm/lib/ProfileData/SampleProf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ bool FunctionSamples::ProfileIsPreInlined = false;
bool FunctionSamples::UseMD5 = false;
bool FunctionSamples::HasUniqSuffix = true;
bool FunctionSamples::ProfileIsFS = false;

std::error_code
serializeTypeMap(const TypeCountMap &Map,
                 const MapVector<FunctionId, uint32_t> &NameTable,
                 raw_ostream &OS) {
  // Layout: the number of entries, followed by one (name index, count) pair
  // per entry, each value ULEB128-encoded.
  encodeULEB128(Map.size(), OS);
  for (const auto &Entry : Map) {
    const auto IndexPos = NameTable.find(Entry.first);
    // A type absent from the name table has no index to emit. Bytes already
    // written to OS are left in place.
    if (IndexPos == NameTable.end())
      return sampleprof_error::truncated_name_table;
    encodeULEB128(IndexPos->second, OS);
    encodeULEB128(Entry.second, OS);
  }
  return sampleprof_error::success;
}
} // namespace sampleprof
} // namespace llvm

Expand Down Expand Up @@ -93,6 +111,8 @@ class SampleProfErrorCategoryType : public std::error_category {
return "Function hash mismatch";
case sampleprof_error::illegal_line_offset:
return "Illegal line offset in sample profile data";
case sampleprof_error::duplicate_vtable_type:
return "Duplicate vtable type in one map";
}
llvm_unreachable("A value of sampleprof_error has no message.");
}
Expand Down Expand Up @@ -126,6 +146,7 @@ sampleprof_error SampleRecord::merge(const SampleRecord &Other,
for (const auto &I : Other.getCallTargets()) {
mergeSampleProfErrors(Result, addCalledTarget(I.first, I.second, Weight));
}

return Result;
}

Expand Down Expand Up @@ -178,6 +199,17 @@ raw_ostream &llvm::sampleprof::operator<<(raw_ostream &OS,
return OS;
}

/// Print the vtable type counts recorded at \p Loc as a single line of the
/// form "<Loc>: vtables: <type>:<count> <type>:<count> ", followed by a
/// newline. Nothing is written when the map is empty.
static void printTypeCountMap(raw_ostream &OS, LineLocation Loc,
                              const TypeCountMap &TypeCountMap) {
  if (!TypeCountMap.empty()) {
    OS << Loc << ": vtables: ";
    for (const auto &Entry : TypeCountMap)
      OS << Entry.first << ":" << Entry.second << " ";
    OS << "\n";
  }
}

/// Print the samples collected for a function on stream \p OS.
void FunctionSamples::print(raw_ostream &OS, unsigned Indent) const {
if (getFunctionHash())
Expand All @@ -192,7 +224,13 @@ void FunctionSamples::print(raw_ostream &OS, unsigned Indent) const {
SampleSorter<LineLocation, SampleRecord> SortedBodySamples(BodySamples);
for (const auto &SI : SortedBodySamples.get()) {
OS.indent(Indent + 2);
const auto &Loc = SI->first;
OS << SI->first << ": " << SI->second;
if (const TypeCountMap *TypeCountMap =
this->findCallsiteTypeSamplesAt(Loc)) {
OS.indent(Indent + 2);
printTypeCountMap(OS, Loc, *TypeCountMap);
}
}
OS.indent(Indent);
OS << "}\n";
Expand All @@ -214,6 +252,11 @@ void FunctionSamples::print(raw_ostream &OS, unsigned Indent) const {
OS << Loc << ": inlined callee: " << FuncSample.getFunction() << ": ";
FuncSample.print(OS, Indent + 4);
}
auto TypeSamplesIter = VirtualCallsiteTypeCounts.find(Loc);
if (TypeSamplesIter != VirtualCallsiteTypeCounts.end()) {
OS.indent(Indent + 2);
printTypeCountMap(OS, Loc, TypeSamplesIter->second);
}
}
OS.indent(Indent);
OS << "}\n";
Expand Down
Loading
Loading