Skip to content

[SampleFDO][TypeProf] Support vtable type profiling in ext-binary and text format. #141649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 101 additions & 18 deletions llvm/include/llvm/ProfileData/SampleProf.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ enum class sampleprof_error {
ostream_seek_unsupported,
uncompress_failed,
zlib_unavailable,
hash_mismatch
hash_mismatch,
illegal_line_offset,
duplicate_vtable_type,
};

inline std::error_code make_error_code(sampleprof_error E) {
Expand Down Expand Up @@ -89,6 +91,8 @@ struct is_error_code_enum<llvm::sampleprof_error> : std::true_type {};
namespace llvm {
namespace sampleprof {

/// Literal prefix "vtables " (note the trailing space) associated with vtable
/// profile data attached to body samples.
/// NOTE(review): presumably the marker that introduces a vtable type profile
/// line in the text profile format -- confirm against the text reader/writer.
constexpr char kBodySampleVTableProfPrefix[] = "vtables ";

enum SampleProfileFormat {
SPF_None = 0,
SPF_Text = 0x1,
Expand Down Expand Up @@ -202,6 +206,9 @@ enum class SecProfSummaryFlags : uint32_t {
/// SecFlagIsPreInlined means this profile contains ShouldBeInlined
/// contexts thus this is CS preinliner computed.
SecFlagIsPreInlined = (1 << 4),

/// SecFlagHasVTableTypeProf means this profile contains vtable type profiles.
SecFlagHasVTableTypeProf = (1 << 5),
};

enum class SecFuncMetadataFlags : uint32_t {
Expand Down Expand Up @@ -285,7 +292,7 @@ struct LineLocation {
void dump() const;

// Serialize the line location to the output stream using ULEB128 encoding.
void serialize(raw_ostream &OS);
void serialize(raw_ostream &OS) const;

bool operator<(const LineLocation &O) const {
return LineOffset < O.LineOffset ||
Expand All @@ -301,7 +308,7 @@ struct LineLocation {
}

uint64_t getHashCode() const {
return ((uint64_t) Discriminator << 32) | LineOffset;
return ((uint64_t)Discriminator << 32) | LineOffset;
}

uint32_t LineOffset;
Expand All @@ -316,16 +323,28 @@ struct LineLocationHash {

raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc);

/// Key represents the id of a vtable and value represents its count.
/// TODO: Rename class FunctionId to SymbolId in a separate PR.
using TypeCountMap = std::map<FunctionId, uint64_t>;

/// Write \p Map to the output stream. Keys are linearized using \p NameTable
/// and written as ULEB128. Values are written as ULEB128 as well.
std::error_code
serializeTypeMap(const TypeCountMap &Map,
const MapVector<FunctionId, uint32_t> &NameTable,
raw_ostream &OS);

/// Representation of a single sample record.
///
/// A sample record is represented by a positive integer value, which
/// indicates how frequently the associated line location was executed.
///
/// Additionally, if the associated location contains a function call,
/// the record will hold a list of all the possible called targets. For
/// direct calls, this will be the exact function being invoked. For
/// indirect calls (function pointers, virtual table dispatch), this
/// will be a list of one or more functions.
/// the record will hold a list of all the possible called targets and the types
/// for virtual table dispatches. For direct calls, this will be the exact
/// function being invoked. For indirect calls (function pointers, virtual table
/// dispatch), this will be a list of one or more functions. For virtual table
/// dispatches, this record will also hold the type of the object.
class SampleRecord {
public:
using CallTarget = std::pair<FunctionId, uint64_t>;
Expand Down Expand Up @@ -369,14 +388,12 @@ class SampleRecord {
/// Sample counts accumulate using saturating arithmetic, to avoid wrapping
/// around unsigned integers.
sampleprof_error addCalledTarget(FunctionId F, uint64_t S,
uint64_t Weight = 1) {
uint64_t &TargetSamples = CallTargets[F];
bool Overflowed;
TargetSamples =
SaturatingMultiplyAdd(S, Weight, TargetSamples, &Overflowed);
return Overflowed ? sampleprof_error::counter_overflow
: sampleprof_error::success;
}
uint64_t Weight = 1);

/// Add vtable type \p F with samples \p S.
/// Optionally scale sample count \p S by \p Weight.
sampleprof_error addVTableAccessCount(FunctionId F, uint64_t S,
uint64_t Weight = 1);

/// Remove called function from the call target map. Return the target sample
/// count of the called function.
Expand Down Expand Up @@ -433,9 +450,9 @@ class SampleRecord {
void dump() const;
/// Serialize the sample record to the output stream using ULEB128 encoding.
/// The \p NameTable is used to map function names to their IDs.
std::error_code
serialize(raw_ostream &OS,
const MapVector<FunctionId, uint32_t> &NameTable) const;
std::error_code serialize(raw_ostream &OS,
const MapVector<FunctionId, uint32_t> &NameTable,
bool SerializeVTableProf) const;

bool operator==(const SampleRecord &Other) const {
return NumSamples == Other.NumSamples && CallTargets == Other.CallTargets;
Expand Down Expand Up @@ -743,6 +760,7 @@ using BodySampleMap = std::map<LineLocation, SampleRecord>;
// memory, which is *very* significant for large profiles.
using FunctionSamplesMap = std::map<FunctionId, FunctionSamples>;
using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
using CallsiteTypeMap = std::map<LineLocation, TypeCountMap>;
using LocToLocMap =
std::unordered_map<LineLocation, LineLocation, LineLocationHash>;

Expand Down Expand Up @@ -925,6 +943,14 @@ class FunctionSamples {
return &Iter->second;
}

/// Look up the vtable type counts recorded for the callsite at \p Loc.
///
/// \p Loc is first translated with mapIRLocToProfileLoc() and the result is
/// used as the key into VirtualCallsiteTypeCounts.
/// \returns a pointer to the matching TypeCountMap, or nullptr when no entry
/// exists for the translated location.
const TypeCountMap *findCallsiteTypeSamplesAt(const LineLocation &Loc) const {
  const auto Pos = VirtualCallsiteTypeCounts.find(mapIRLocToProfileLoc(Loc));
  return Pos != VirtualCallsiteTypeCounts.end() ? &Pos->second : nullptr;
}

/// Returns a pointer to FunctionSamples at the given callsite location
/// \p Loc with callee \p CalleeName. If no callsite can be found, relax
/// the restriction to return the FunctionSamples at callsite location
Expand Down Expand Up @@ -986,6 +1012,42 @@ class FunctionSamples {
return CallsiteSamples;
}

/// Return all the callsite type samples collected in the body of the
/// function.
///
/// This is a read-only view of VirtualCallsiteTypeCounts, which maps a
/// callsite LineLocation to the TypeCountMap recorded for it.
const CallsiteTypeMap &getCallsiteTypeCounts() const {
return VirtualCallsiteTypeCounts;
}

/// Give mutable access to the type samples stored for the un-drifted
/// location of \p Loc.
///
/// The lookup goes through std::map::operator[], so an empty TypeCountMap is
/// created and returned when the translated location has no entry yet.
TypeCountMap &getTypeSamplesAt(const LineLocation &Loc) {
  const auto &ProfileLoc = mapIRLocToProfileLoc(Loc);
  return VirtualCallsiteTypeCounts[ProfileLoc];
}

/// Scale \p Other sample counts by \p Weight and add the scaled result to the
/// type samples for the undrifted location of \p Loc.
///
/// \tparam T a map-like container whose key_type is StringRef or FunctionId
/// and whose mapped_type is uint64_t (enforced by the static_assert below).
/// \param Loc callsite location; translated by getTypeSamplesAt().
/// \param Other type -> count entries to accumulate.
/// \param Weight multiplier applied to each count of \p Other before adding.
/// \returns sampleprof_error::counter_overflow if SaturatingMultiplyAdd
/// reported an overflow for at least one entry, sampleprof_error::success
/// otherwise. All entries are processed even after an overflow.
template <typename T>
sampleprof_error addCallsiteVTableTypeProfAt(const LineLocation &Loc,
                                             const T &Other,
                                             uint64_t Weight = 1) {
  static_assert((std::is_same_v<typename T::key_type, StringRef> ||
                 std::is_same_v<typename T::key_type, FunctionId>) &&
                    std::is_same_v<typename T::mapped_type, uint64_t>,
                "T must be a map with StringRef or FunctionId as key and "
                "uint64_t as value");
  TypeCountMap &TypeCounts = getTypeSamplesAt(Loc);
  bool Overflowed = false;

  // Bind the entries by reference; binding by value copies every key/count
  // pair of Other on each iteration.
  for (const auto &[Type, Count] : Other) {
    const FunctionId TypeId(Type);
    // Resolve the destination slot once and update it in place instead of
    // performing two map lookups per entry.
    uint64_t &TypeSamples = TypeCounts[TypeId];
    bool RowOverflow = false;
    TypeSamples =
        SaturatingMultiplyAdd(Count, Weight, TypeSamples, &RowOverflow);
    Overflowed |= RowOverflow;
  }
  return Overflowed ? sampleprof_error::counter_overflow
                    : sampleprof_error::success;
}

/// Return the maximum of sample counts in a function body. When SkipCallSite
/// is false, which is the default, the return count includes samples in the
/// inlined functions. When SkipCallSite is true, the return count only
Expand Down Expand Up @@ -1040,6 +1102,10 @@ class FunctionSamples {
mergeSampleProfErrors(Result,
FSMap[Rec.first].merge(Rec.second, Weight));
}
for (const auto &[Loc, OtherTypeMap] : Other.getCallsiteTypeCounts())
mergeSampleProfErrors(
Result, addCallsiteVTableTypeProfAt(Loc, OtherTypeMap, Weight));

return Result;
}

Expand Down Expand Up @@ -1283,6 +1349,23 @@ class FunctionSamples {
/// collected in the call to baz() at line offset 8.
CallsiteSampleMap CallsiteSamples;

/// Map inlined virtual callsites to the vtable from which they are loaded.
///
/// Each entry is a mapping from the location to the list of vtables and their
/// sampled counts. For example, given:
///
/// void foo() {
/// ...
/// 5 inlined_vcall_bar();
/// ...
/// 5 inlined_vcall_baz();
/// ...
/// 200 inlined_vcall_qux();
/// }
/// This map will contain two entries. One with two types for line offset 5
/// and one with one type for line offset 200.
CallsiteTypeMap VirtualCallsiteTypeCounts;

/// IR to profile location map generated by stale profile matching.
///
/// Each entry is a mapping from the location on current build to the matched
Expand Down
11 changes: 11 additions & 0 deletions llvm/include/llvm/ProfileData/SampleProfReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,11 @@ class SampleProfileReaderBinary : public SampleProfileReader {
/// otherwise same as readStringFromTable, also return its hash value.
ErrorOr<std::pair<SampleContext, uint64_t>> readSampleContextFromTable();

/// NOTE(review): presumably reads the vtable access counts recorded at body
/// sample location \p Loc into \p FProfile (inferred from the name and from
/// readCallsiteVTableProf below) -- confirm against the definition.
std::error_code readBodySampleVTableProf(const LineLocation &Loc,
FunctionSamples &FProfile);
/// Read all callsites' vtable access counts for \p FProfile.
std::error_code readCallsiteVTableProf(FunctionSamples &FProfile);

/// Points to the current location in the buffer.
const uint8_t *Data = nullptr;

Expand All @@ -725,6 +730,12 @@ class SampleProfileReaderBinary : public SampleProfileReader {
/// to the start of MD5SampleContextTable.
const uint64_t *MD5SampleContextStart = nullptr;

/// Read bytes from the input buffer pointed by `Data` and decode them into
/// \p M. `Data` will be advanced to the end of the read bytes when this
/// function returns. Returns error if any.
std::error_code readVTableTypeCountMap(TypeCountMap &M);
bool ReadVTableProf = false;

private:
std::error_code readSummaryEntry(std::vector<ProfileSummaryEntry> &Entries);
virtual std::error_code verifySPMagic(uint64_t Magic) = 0;
Expand Down
18 changes: 15 additions & 3 deletions llvm/include/llvm/ProfileData/SampleProfWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
virtual MapVector<FunctionId, uint32_t> &getNameTable() { return NameTable; }
virtual std::error_code writeMagicIdent(SampleProfileFormat Format);
virtual std::error_code writeNameTable();

std::error_code writeHeader(const SampleProfileMap &ProfileMap) override;
std::error_code writeSummary();
virtual std::error_code writeContextIdx(const SampleContext &Context);
Expand All @@ -218,11 +219,23 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
std::set<FunctionId> &V);

MapVector<FunctionId, uint32_t> NameTable;

void addName(FunctionId FName);
// void addTypeName(FunctionId TypeName);
virtual void addContext(const SampleContext &Context);
void addNames(const FunctionSamples &S);

/// Add the type names to NameTable.
void addTypeNames(const TypeCountMap &M);

/// Write \p CallsiteTypeMap to the output stream \p OS.
std::error_code
writeCallsiteVTableProf(const CallsiteTypeMap &CallsiteTypeMap,
raw_ostream &OS);

// TODO: This should be configurable by a flag.
bool WriteVTableProf = false;

private:
friend ErrorOr<std::unique_ptr<SampleProfileWriter>>
SampleProfileWriter::create(std::unique_ptr<raw_ostream> &OS,
Expand Down Expand Up @@ -409,8 +422,7 @@ class SampleProfileWriterExtBinaryBase : public SampleProfileWriterBinary {

class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
public:
SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS)
: SampleProfileWriterExtBinaryBase(OS) {}
SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS);

private:
std::error_code writeDefaultLayout(const SampleProfileMap &ProfileMap);
Expand Down
Loading