Skip to content

Commit d609437

Browse files
changpeng and rampitec authored
AMDGPU: Support v_wmma_f32_16x16x128_f8f6f4 on gfx1250 (#149684)
Co-authored-by: Stanislav Mekhanoshin <[email protected]>
1 parent 0fa515f commit d609437

31 files changed

+1696
-34
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
705705
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
706706
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
707707
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
708+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
708709
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
709710
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
710711
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")

clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
855855
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
856856
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
857857
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
858+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
858859
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
859860
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
860861
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
11181119
ArgsForMatchingMatrixTypes = {4, 1};
11191120
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
11201121
break;
1122+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
1123+
ArgsForMatchingMatrixTypes = {5, 1, 3};
1124+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
1125+
break;
11211126
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
11221127
ArgsForMatchingMatrixTypes = {3, 0, 1};
11231128
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;

clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
157157
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
158158
}
159159

160+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
161+
// CHECK-GFX1250-NEXT: entry:
162+
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
163+
// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
164+
// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
165+
// CHECK-GFX1250-NEXT: ret void
166+
//
167+
// Exercises the f8f6f4 builtin with constant format selectors 1 (first
// argument) and 2 (third argument) and a zero C modifier. The CHECK lines for
// this function expect the second matrix operand to be narrowed from
// <16 x i32> to <12 x i32> before the intrinsic call.
void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
168+
{
169+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
170+
}
171+
160172
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
161173
// CHECK-GFX1250-NEXT: entry:
162174
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
114114
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
115115
}
116116

117+
// Negative test: passes the non-constant 'mod' in turn as the first, third and
// fifth argument of the f8f6f4 builtin. Each call is annotated with the
// "must be a constant integer" diagnostic it is expected to produce.
void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
118+
{
119+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
120+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
121+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
122+
}
123+
117124
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
118125
{
119126
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
37173717
IntrWillReturn, IntrNoCallback, IntrNoFree]
37183718
>;
37193719

3720+
// Intrinsic signature for WMMA variants that carry a per-matrix format
// selector. Operands are (matrix_a_fmt, %A, matrix_b_fmt, %B, %C_mod, %C).
// The result %D and %C share one overloaded float type; %A and %B are two
// separately overloaded integer types. Operands 0, 2 and 4 are marked ImmArg,
// so both format selectors and the C modifier must be immediates.
class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
3721+
Intrinsic<
3722+
[llvm_anyfloat_ty], // %D
3723+
[
3724+
llvm_i32_ty, // matrix_a_fmt
3725+
llvm_anyint_ty, // %A
3726+
llvm_i32_ty, // matrix_b_fmt
3727+
llvm_anyint_ty, // %B
3728+
llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
3729+
LLVMMatchType<0>, // %C (same type as %D)
3730+
],
3731+
[IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
3732+
>;
3733+
37203734
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
37213735
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
37223736
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
37413755
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
37423756
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
37433757
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
3758+
def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
37443759
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
37453760
}
37463761

llvm/lib/IR/Verifier.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6668,6 +6668,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
66686668
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
66696669
break;
66706670
}
6671+
// IR verification for llvm.amdgcn.wmma.f32.16x16x128.f8f6f4: both format
// selectors (operands 0 and 2) must be at most 4, both matrix operands
// (operands 1 and 3) must be fixed vectors of 8, 12 or 16 i32 elements, and
// each vector must have at least as many elements as its format requires.
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
6672+
Value *Src0 = Call.getArgOperand(1);
6673+
Value *Src1 = Call.getArgOperand(3);
6674+
6675+
// Operands 0 and 2 are cast unconditionally to ConstantInt.
unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
6676+
unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
6677+
// NOTE(review): the code after each Check assumes Check leaves this function
// when its condition is false (otherwise the llvm_unreachable default and the
// Src0Ty/Src1Ty dereferences below could be reached) - confirm against the
// Check macro definition.
Check(FmtA <= 4, "invalid value for matrix format", Call,
6678+
Call.getArgOperand(0));
6679+
Check(FmtB <= 4, "invalid value for matrix format", Call,
6680+
Call.getArgOperand(2));
6681+
6682+
// AMDGPU::MatrixFMT values
6683+
// Maps a format value to the minimum number of i32 elements (registers) the
// corresponding matrix operand must provide: 0/1 -> 16, 2/3 -> 12, 4 -> 8.
auto getFormatNumRegs = [](unsigned FormatVal) {
6684+
switch (FormatVal) {
6685+
case 0:
6686+
case 1:
6687+
return 16u;
6688+
case 2:
6689+
case 3:
6690+
return 12u;
6691+
case 4:
6692+
return 8u;
6693+
default:
6694+
llvm_unreachable("invalid format value");
6695+
}
6696+
};
6697+
6698+
// Accepts only a non-null fixed vector type with i32 elements and a length
// of 8, 12 or 16.
auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
6699+
if (!Ty || !Ty->getElementType()->isIntegerTy(32))
6700+
return false;
6701+
unsigned NumElts = Ty->getNumElements();
6702+
return NumElts == 16 || NumElts == 12 || NumElts == 8;
6703+
};
6704+
6705+
// dyn_cast yields null for non-fixed-vector operands; the lambda above
// rejects null.
auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
6706+
auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
6707+
Check(isValidSrcASrcBVector(Src0Ty),
6708+
"operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
6709+
Check(isValidSrcASrcBVector(Src1Ty),
6710+
"operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
6711+
6712+
// Permit excess registers for the format.
6713+
Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
6714+
"invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
6715+
Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
6716+
"invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
6717+
break;
6718+
}
66716719
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
66726720
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
66736721
Value *V = Call.getArgOperand(0);

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
16941694
NewII->takeName(&II);
16951695
return IC.replaceInstUsesWith(II, NewII);
16961696
}
1697+
// Narrow the matrix operands (operands 1 and 3) of the f8f6f4 WMMA intrinsic
// to the element count returned by wmmaScaleF8F6F4FormatToNumRegs for their
// format selectors (operands 0 and 2). Returns std::nullopt when neither
// operand is wider than needed; otherwise replaces the call with a new
// intrinsic call on the narrowed operands.
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
1698+
Value *Src0 = II.getArgOperand(1);
1699+
Value *Src1 = II.getArgOperand(3);
1700+
// NOTE(review): FmtA is declared unsigned while FmtB is uint64_t; both are
// passed to the same helper below - the mismatch looks unintentional.
unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
1701+
uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
1702+
auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
1703+
auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
1704+
1705+
bool MadeChange = false;
1706+
unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
1707+
unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
1708+
1709+
// Depending on the used format, fewer registers are required so shrink the
1710+
// vector type.
1711+
if (Src0Ty->getNumElements() > Src0NumElts) {
1712+
// Extract a Src0NumElts-wide subvector starting at index 0.
Src0 = IC.Builder.CreateExtractVector(
1713+
FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
1714+
IC.Builder.getInt64(0));
1715+
MadeChange = true;
1716+
}
1717+
1718+
if (Src1Ty->getNumElements() > Src1NumElts) {
1719+
// Same narrowing for the second matrix operand.
Src1 = IC.Builder.CreateExtractVector(
1720+
FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
1721+
IC.Builder.getInt64(0));
1722+
MadeChange = true;
1723+
}
1724+
1725+
if (!MadeChange)
1726+
return std::nullopt;
1727+
1728+
// Copy all call arguments and substitute the narrowed matrix operands.
// NOTE(review): inline capacity 13 exceeds the highest operand index used
// here (5); presumably carried over from a neighbouring case - confirm.
SmallVector<Value *, 13> Args(II.args());
1729+
Args[1] = Src0;
1730+
Args[3] = Src1;
1731+
1732+
// Overload types are taken from operand 5 and from the (possibly narrowed)
// matrix operands.
CallInst *NewII = IC.Builder.CreateIntrinsic(
1733+
IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
1734+
Args, &II);
1735+
NewII->takeName(&II);
1736+
return IC.replaceInstUsesWith(II, NewII);
1737+
}
16971738
}
16981739
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
16991740
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
47144714
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
47154715
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
47164716
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
4717+
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
47174718
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
47184719
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
47194720
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
176176
ImmTyWaitVAVDst,
177177
ImmTyWaitVMVSrc,
178178
ImmTyBitOp3,
179+
ImmTyMatrixAFMT,
180+
ImmTyMatrixBFMT,
179181
ImmTyMatrixAReuse,
180182
ImmTyMatrixBReuse,
181183
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
423425
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
424426
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
425427
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
428+
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
429+
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
426430
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
427431
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
428432
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11741178
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
11751179
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
11761180
case ImmTyBitOp3: OS << "BitOp3"; break;
1181+
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
1182+
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
11771183
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
11781184
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
11791185
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17141720
ParseStatus parseIndexKey8bit(OperandVector &Operands);
17151721
ParseStatus parseIndexKey16bit(OperandVector &Operands);
17161722
ParseStatus parseIndexKey32bit(OperandVector &Operands);
1723+
ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
1724+
AMDGPUOperand::ImmTy Type);
1725+
ParseStatus parseMatrixAFMT(OperandVector &Operands);
1726+
ParseStatus parseMatrixBFMT(OperandVector &Operands);
17171727

17181728
ParseStatus parseDfmtNfmt(int64_t &Format);
17191729
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
18491859
const unsigned CPol);
18501860
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
18511861
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
1862+
bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
18521863
unsigned getConstantBusLimit(unsigned Opcode) const;
18531864
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
18541865
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5409,6 +5420,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
54095420
return true;
54105421
}
54115422

5423+
// Checks that the register tuples used for src0/src1 have the size implied by
// the matrix_a_fmt/matrix_b_fmt operands. Returns true when the opcode has no
// such format operand or when both sizes match; otherwise reports an error at
// the offending source register and returns false.
bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
5424+
const OperandVector &Operands) {
5425+
unsigned Opc = Inst.getOpcode();
5426+
const MCRegisterInfo *TRI = getContext().getRegisterInfo();
5427+
const MCInstrDesc &Desc = MII.get(Opc);
5428+
5429+
// Validates one (format operand, source operand) pair.
auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
5430+
int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
5431+
// Opcode has no such format operand: nothing to validate.
if (FmtIdx == -1)
5432+
return true;
5433+
unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
5434+
// NOTE(review): SrcIdx is used without a -1 check; presumably every opcode
// with a format operand also has the matching source operand - confirm.
int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
5435+
unsigned RegSize =
5436+
TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
5437+
5438+
// The register class size in bits must equal the register count for the
// format times 32.
if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
5439+
return true;
5440+
5441+
// NOTE(review): FmtNames is indexed by Fmt with no bounds check here;
// presumably operand parsing limits Fmt to 0..4 - confirm.
static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
5442+
"MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
5443+
"MATRIX_FMT_FP4"};
5444+
5445+
Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
5446+
"wrong register tuple size for " + Twine(FmtNames[Fmt]));
5447+
return false;
5448+
};
5449+
5450+
// The B check runs only if the A check passes, so at most one error is
// reported per instruction.
return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
5451+
validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
5452+
}
5453+
54125454
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
54135455
const SMLoc &IDLoc,
54145456
const OperandVector &Operands) {
@@ -5542,6 +5584,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
55425584
if (!validateTFE(Inst, Operands)) {
55435585
return false;
55445586
}
5587+
if (!validateWMMA(Inst, Operands)) {
5588+
return false;
5589+
}
55455590

55465591
return true;
55475592
}
@@ -7215,6 +7260,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
72157260
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
72167261
}
72177262

7263+
// Shared parser for the matrix format operands: forwards to
// parseStringOrIntWithPrefix with the operand prefix Name, the five
// MATRIX_FMT_* names and the immediate type to create.
// NOTE(review): presumably a name's position in the list is its numeric
// encoding - confirm against parseStringOrIntWithPrefix.
ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
7264+
StringRef Name,
7265+
AMDGPUOperand::ImmTy Type) {
7266+
return parseStringOrIntWithPrefix(Operands, Name,
7267+
{"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
7268+
"MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
7269+
"MATRIX_FMT_FP4"},
7270+
Type);
7271+
}
7272+
7273+
// Parses the "matrix_a_fmt" operand into an ImmTyMatrixAFMT immediate.
ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
7274+
return tryParseMatrixFMT(Operands, "matrix_a_fmt",
7275+
AMDGPUOperand::ImmTyMatrixAFMT);
7276+
}
7277+
7278+
// Parses the "matrix_b_fmt" operand into an ImmTyMatrixBFMT immediate.
ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
7279+
return tryParseMatrixFMT(Operands, "matrix_b_fmt",
7280+
AMDGPUOperand::ImmTyMatrixBFMT);
7281+
}
7282+
72187283
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
72197284
// values to live in a joint format operand in the MCInst encoding.
72207285
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9316,6 +9381,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
93169381
DefaultVal);
93179382
}
93189383

9384+
int MatrixAFMTIdx =
9385+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
9386+
if (MatrixAFMTIdx != -1) {
9387+
addOptionalImmOperand(Inst, Operands, OptIdx,
9388+
AMDGPUOperand::ImmTyMatrixAFMT, 0);
9389+
}
9390+
9391+
int MatrixBFMTIdx =
9392+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
9393+
if (MatrixBFMTIdx != -1) {
9394+
addOptionalImmOperand(Inst, Operands, OptIdx,
9395+
AMDGPUOperand::ImmTyMatrixBFMT, 0);
9396+
}
9397+
93199398
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
93209399
addOptionalImmOperand(Inst, Operands, OptIdx,
93219400
AMDGPUOperand::ImmTyMatrixAReuse, 0);

0 commit comments

Comments
 (0)