MetalPerformancePrimitives iOS xcode26.0 b5
# MetalPerformancePrimitives.framework https://github.com/dotnet/macios/issues/23418
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h 2025-07-11 22:46:13
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h 2025-07-26 22:15:31
@@ -113,7 +113,7 @@
//
// Above matrix multiplication implementation will do edge checking for all
// thread groups against extents of original tensor although for large enough
-// matrices most of thread groups will be working on "inside" tiles, requring no
+// matrices most of thread groups will be working on "inside" tiles, requiring no
// bounds check. In high performance code we can avoid edge checking for inside
// thread groups and get better performance
//
@@ -228,7 +228,7 @@
// number of threads in opscope with which op was created. Note that
// cooperative_tensor created from an op is only valid for threads that are part
// of opscope on which op was created. Though the layout of cooperative_tensor
-// is implemtation defined, we provide accessor functions as shown in the
+// is implementation defined, we provide accessor functions as shown in the
// example below
//
// kernel void simpleMatMulCooperative(tensor<device half, dextents<int32_t,
@@ -347,6 +347,7 @@
#include "__impl/MPPTensorOpsBase.h"
#include "__impl/MPPTensorOpsTypes.h"
+#include "__impl/MPPTensorOpsUtility.h"
#include <metal_numeric>
#pragma METAL internals : enable
@@ -438,10 +439,10 @@
RunArgs...>(left, right, destination);
}
- template <typename ElementType, typename CoordType, typename... CoopArgs>
+ template <typename LeftOperand, typename RightOperand, typename ElementType, typename CoordType, typename... CoopArgs>
using cooperative_tensor_destination_t =
__mutmul2d_detail::__cooperative_tensor_destination_t<
- Descriptor, Scope, ElementType, CoordType, CoopArgs...>;
+ Descriptor, Scope, LeftOperand, RightOperand, ElementType, CoordType, CoopArgs...>;
template <typename LeftOperandType, typename RightOperandType,
typename ElementType, typename CoordType = int,
@@ -451,38 +452,100 @@
__tensor_ops_detail::__is_thread_addrspace_v<ElementType> &&
__tensor_ops_detail::__is_integral_v<CoordType>>,
typename... CoopArgs>
- INLINE cooperative_tensor_destination_t<ElementType, CoordType, CoopArgs...>
+ INLINE cooperative_tensor_destination_t<LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>
get_destination_cooperative_tensor() thread const
{
return __mutmul2d_detail::__get_destination_cooperative_tensor<
- Descriptor, Scope, ElementType, CoordType, LeftOperandType,
- RightOperandType, CoopArgs...>();
+ Descriptor, Scope, LeftOperandType, RightOperandType, ElementType,
+ CoordType, CoopArgs...>();
}
+
+ template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+ typename CoordType, typename... CoopArgs>
+ using cooperative_tensor_row_reduction_destination_t =
+ __mutmul2d_detail::__cooperative_tensor_row_reduction_destination_t<
+ Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>;
+
+ template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+ typename CoordType = int,
+ typename U = __tensor_ops_detail::__enable_if_t<
+ __tensor_ops_detail::__is_integral_v<CoordType>>,
+ typename... CoopArgs>
+ INLINE cooperative_tensor_row_reduction_destination_t<LeftOperandType, RightOperandType,
+ ElementType, CoordType, CoopArgs...>
+ get_row_reduction_destination_cooperative_tensor() thread const
+ {
+ return __mutmul2d_detail::__get_row_reduction_destination_cooperative_tensor<
+ Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>();
+ }
+
+
+ template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+ typename CoordType, typename... CoopArgs>
+ using cooperative_tensor_column_reduction_destination_t =
+ __mutmul2d_detail::__cooperative_tensor_column_reduction_destination_t<
+ Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>;
+
+ template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+ typename CoordType = int, typename U = __tensor_ops_detail::__enable_if_t<
+ __tensor_ops_detail::__is_integral_v<CoordType>>,
+ typename... CoopArgs>
+ INLINE cooperative_tensor_column_reduction_destination_t<LeftOperandType, RightOperandType,
+ ElementType, CoordType, CoopArgs...>
+ get_column_reduction_destination_cooperative_tensor() thread const
+ {
+ return __mutmul2d_detail::__get_column_reduction_destination_cooperative_tensor<
+ Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>();
+ }
};
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
inline void reduce_rows(
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &source,
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &destination,
+ thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &source,
+ thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destination,
reduction_operation op = reduction_operation::sum,
ElementType identity =
reduction_operation_identity<ElementType>::sum_identity)
{
- __mutmul2d_detail::__reduce_rows<ElementType, Extents, Layout>(
+ __mutmul2d_detail::__reduce_rows<ElementType, SrcExtents, DstExtents, SrcLayout, DstLayout>(
source, destination, identity, op);
}
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
inline void reduce_columns(
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &source,
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &destination,
+ thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &source,
+ thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destination,
reduction_operation op = reduction_operation::sum,
ElementType identity =
reduction_operation_identity<ElementType>::sum_identity)
{
- __mutmul2d_detail::__reduce_columns<ElementType, Extents, Layout>(
+ __mutmul2d_detail::__reduce_columns<ElementType, SrcExtents, DstExtents, SrcLayout, DstLayout>(
source, destination, identity, op);
+}
+
+// Returns whether the iterators are compatible between a source and destination cooperative tensor.
+//
+// Use this to check whether map_iterator will be return a valid iterator. For example:
+//
+// if (is_iterator_compatible(sourceCT, destCT)) {
+// for (auto it = sourceCT.begin(); it != sourceCT.end(); it++) {
+// auto dst_it = destCT.map_iterator(sourceCT)
+//
+// *it += *dst_it;
+// }
+// }
+// else {
+// // Fall back to storing sourceCT to threadgroup memory and access via
+// // destCT's multidimensional indices
+// }
+template <class SrcElementType, class DstElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
+inline bool is_iterator_compatible(
+ const thread metal::cooperative_tensor<SrcElementType, SrcExtents, SrcLayout> &source,
+ const thread metal::cooperative_tensor<DstElementType, DstExtents, DstLayout> &destination)
+{
+ return __mutmul2d_detail::__is_iterator_compatible<SrcElementType, DstElementType, SrcExtents, DstExtents,
+ SrcLayout, DstLayout>(source, destination);
}
} // namespace tensor_ops
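
For context, here is a minimal sketch of how the beta 5 additions above might fit together: `get_row_reduction_destination_cooperative_tensor` yields a 1-D destination whose extents differ from the 2-D matmul destination, which is why `reduce_rows`/`reduce_columns` now take separate source and destination extents and layouts. This kernel is illustrative only and not taken from the diff; the descriptor construction, the `execution_simdgroups` scope, the tile sizes, and the `store` call are all assumptions.

```metal
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace mpp::tensor_ops;

// Illustrative descriptor; field values and constructor form are placeholders.
constexpr constant matmul2d_descriptor desc(
    /*m=*/64, /*n=*/64, /*k=*/64,
    /*transpose_left=*/false, /*transpose_right=*/false,
    /*relaxed_precision=*/false);

kernel void matmulRowSums(
    tensor<device half, dextents<int32_t, 2>> left   [[buffer(0)]],
    tensor<device half, dextents<int32_t, 2>> right  [[buffer(1)]],
    tensor<device float, dextents<int32_t, 2>> sums  [[buffer(2)]])
{
    matmul2d<desc, execution_simdgroups<4>> op;

    // Beta 5 change: the destination aliases are now parameterized on the
    // left/right operand types as well as the element/coordinate types.
    auto acc  = op.get_destination_cooperative_tensor<
        decltype(left), decltype(right), float>();
    auto rows = op.get_row_reduction_destination_cooperative_tensor<
        decltype(left), decltype(right), float>();

    op.run(left, right, acc);

    // Source (2-D accumulator) and destination (1-D row reduction) now
    // legitimately have different extents and layouts.
    reduce_rows(acc, rows, reduction_operation::sum);

    rows.store(sums); // assumed store-to-device call
}
```

When instead combining two same-shaped cooperative tensors elementwise via iterators, the new `is_iterator_compatible` query above is the guard to use before calling `map_iterator`, with a threadgroup-memory fallback otherwise, as shown in the header comment.
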
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h 2025-07-11 23:47:32
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h 2025-07-26 06:00:32
@@ -4395,6 +4395,9 @@
using a_elem_type = typename a_type::element_type;
using w_elem_type = typename w_type::element_type;
+ using a_value_type = typename a_type::value_type;
+ using w_value_type = typename w_type::value_type;
+
static size_t thread_storage_size()
{
metal::execution_threads t = scope();
@@ -4402,9 +4405,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
return __tensorops_impl_conv2d_cooperative_destination_data_size(
descriptor, d_data_type, a_data_type, w_data_type, threads);
}
@@ -4414,7 +4417,7 @@
return alignof(element_t);
};
- static uint16_t size(const_thread_storage_t storage)
+ static uint16_t get_capacity(const_thread_storage_t storage)
{
metal::execution_threads t = scope();
int threads = t.size();
@@ -4429,9 +4432,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
__tensorops_impl_conv2d_cooperative_destination_tensor_init(
this_, descriptor, d_data_type, a_data_type, w_data_type, threads);
@@ -4444,9 +4447,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
__tensorops_impl_conv2d_cooperative_destination_tensor_copy(
this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4460,9 +4463,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
__tensorops_impl_conv2d_cooperative_destination_tensor_move(
this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4476,9 +4479,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
__tensorops_impl_conv2d_cooperative_destination_tensor_copy(
this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4492,9 +4495,9 @@
__tensor_ops_detail::__tensor_ops_datatype d_data_type =
__tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
__tensor_ops_detail::__tensor_ops_datatype a_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
__tensor_ops_detail::__tensor_ops_datatype w_data_type =
- __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
__tensorops_impl_conv2d_cooperative_destination_tensor_move(
this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4693,8 +4696,8 @@
"Unsupported type");
};
- static thread element_t *get_pointer_to(const_thread_storage_t storage,
- index_t idx)
+ static thread element_t *get_element_pointer(const_thread_storage_t storage,
+ index_t idx)
{
metal::execution_threads t = scope();
int threads = t.size();
@@ -4716,9 +4719,14 @@
return (thread element_t *)
__tensorops_impl_conv2d_cooperative_destination_tensor_elements(
(thread_storage_t)storage, idx, dataType, threads);
- };
+ }
- static bool mask(const_thread_storage_t storage, index_t idx)
+ static index_t get_element_index(thread_storage_t storage,
+ const thread element_type *) {
+ // TODO
+ }
+
+ static bool is_valid_element(const_thread_storage_t storage, index_t idx)
{
metal::execution_threads t = scope();
int threads = t.size();
@@ -4744,7 +4752,7 @@
template <typename index_t, __tensor_ops_detail::__rank_t rank>
static metal::array<index_t, rank>
- multidimensional_indices(const_thread_storage_t storage, index_t idx)
+ get_multidimensional_index(const_thread_storage_t storage, index_t idx)
{
metal::execution_threads t = scope();
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h 2025-07-11 23:40:46
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h 2025-07-25 22:24:17
@@ -27,131 +27,455 @@
using __reduction_operation = reduction_operation;
+constexpr bool matmul2d_descriptor_is_equal(matmul2d_descriptor a, matmul2d_descriptor b) {
+ return a.m == b.m &&
+ a.n == b.n &&
+ a.k == b.k &&
+ a.transpose_left == b.transpose_left &&
+ a.transpose_right == b.transpose_right &&
+ a.relaxed_precision == b.relaxed_precision &&
+ a.matmul_mode == b.matmul_mode;
+}
+
extern "C" EXTERNALLY_DEFINED_ATTR size_t
__tensorops_impl_matmul2d_op_cooperative_destination_data_size(
- const __matmul2d_descriptor descriptor, const int threads);
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_num_elements(
- const __matmul2d_descriptor descriptor, const int threads);
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR thread void *
-__tensorops_impl_matmul2d_op_cooperative_destination_tensor_elements(
- __tensor_ops_detail::__thread_void_t, uint16_t,
+__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_pointer(
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ uint16_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
__tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR thread uint16_t
+__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_index(
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
- const __matmul2d_descriptor descriptor,
- __tensor_ops_detail::__thread_void_t, uint16_t,
- __tensor_ops_detail::__tensor_ops_datatype, thread void *,
- __tensor_ops_detail::__tensor_ops_datatype, const int);
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ uint16_t,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_init(
- __tensor_ops_detail::__thread_void_t, __matmul2d_descriptor,
- __tensor_ops_detail::__tensor_ops_datatype, const int);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR bool
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_is_valid_element(
- const __matmul2d_descriptor descriptor,
- __tensor_ops_detail::__thread_void_t, uint16_t,
- __tensor_ops_detail::__tensor_ops_datatype, const int);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ uint16_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f16(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f16(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_i32(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_i32(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f32(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f32(
- thread __matmul2d_descriptor &desc, thread void *storage,
- const thread void *source,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
- int sourceRank, int threads);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f16(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f16(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_i32(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_i32(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f32(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f32(
- thread __matmul2d_descriptor &desc, const thread void *storage,
- const thread void *destination,
- __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
int threads);
+extern "C" EXTERNALLY_DEFINED_ATTR size_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_data_size(
+ __matmul2d_descriptor,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
+extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_num_elements(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
+extern "C" EXTERNALLY_DEFINED_ATTR thread void *
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_pointer(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ uint16_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR thread uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_index(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+ __matmul2d_descriptor,
+ int,
+ __tensor_ops_detail::__const_thread_void_t,
+ uint16_t,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_init(
+ __tensor_ops_detail::__thread_void_t,
+ __matmul2d_descriptor,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
+extern "C" EXTERNALLY_DEFINED_ATTR bool
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_is_valid_element(
+ __matmul2d_descriptor descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ int,
+ uint16_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ int);
+extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_map_index(
+ __tensor_ops_detail::__const_thread_void_t,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __matmul2d_descriptor,
+ int,
+ int,
+ uint16_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR bool
+__tensorops_impl_matmul2d_op_cooperative_destination_is_iterator_compatible(
+ __matmul2d_descriptor,
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f16(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f16(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_i32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_i32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f16(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f16(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_i32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_i32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f32(
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+ int,
+ int,
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f16(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, half identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ half,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f32(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, float identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ float,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_i32(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, int identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ int,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f16(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, half identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ half,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f32(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, float identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ float,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_i32(
- thread __matmul2d_descriptor &desc, const thread void *src,
- thread void *dst, int identity, __reduction_operation op);
+ __matmul2d_descriptor,
+ __tensor_ops_detail::__const_thread_void_t,
+ __tensor_ops_detail::__thread_void_t,
+ int,
+ __reduction_operation,
+ __tensor_ops_detail::__tensor_ops_datatype,
+ __tensor_ops_detail::__tensor_ops_datatype);
extern "C" EXTERNALLY_DEFINED_ATTR void
__tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f16(
@@ -2318,7 +2642,8 @@
template <__matmul2d_descriptor descriptor,
__matmul2d_cooperative_operand_index operand_index, typename scope,
- typename element_type, typename coord_type, typename... args>
+ typename left_operand, typename right_operand, typename element_type,
+ typename coord_type, typename... args>
struct __operand_layout
{
@@ -2327,12 +2652,9 @@
"only destination can be cooperative tensor");
static_assert(__tensor_ops_detail::__is_same_v<element_type, float> ||
__tensor_ops_detail::__is_same_v<element_type, half> ||
-#if __HAVE_BFLOAT__
- __tensor_ops_detail::__is_same_v<element_type, bfloat> ||
-#endif
__tensor_ops_detail::__is_same_v<element_type, int32_t>,
"cooperative tensor data type can only be one of "
- "float/half/bfloat/int32_t");
+ "float/half/int32_t");
static constant constexpr __tensor_ops_detail::__rank_t rank = 2;
using element_t = element_type;
@@ -2342,11 +2664,20 @@
using const_thread_storage_t = const thread void *;
using index_t = uint16_t;
using operand_layout_t =
- __operand_layout<descriptor, operand_index, scope, element_t, coord_t>;
+ __operand_layout<descriptor, operand_index, scope, left_operand, right_operand, element_t, coord_t>;
using cooperative_tensor_t =
metal::cooperative_tensor<element_t, extent_t, operand_layout_t>;
using scope_t = scope;
+ using left_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<left_operand>>;
+ using right_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<right_operand>>;
+
+ using left_elem_t = typename left_t::element_type;
+ using right_elem_t = typename right_t::element_type;
+
+ using left_value_t = typename left_t::value_type;
+ using right_value_t = typename right_t::value_type;
+
static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
"scope should be of type __tensorops_scope");
@@ -2367,7 +2698,7 @@
{
thread element_t *this_e = (thread element_t *)(this_);
thread element_t *other_e = (thread element_t *)(other);
- for (size_t i = 0, e = size(this_); i != e; ++i)
+ for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
{
other_e[i] = this_e[i];
}
@@ -2385,7 +2716,7 @@
{
thread element_t *this_e = (thread element_t *)(this_);
thread element_t *other_e = (thread element_t *)(other);
- for (size_t i = 0, e = size(this_); i != e; ++i)
+ for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
{
other_e[i] = this_e[i];
}
@@ -2405,8 +2736,16 @@
{
metal::execution_threads t = scope();
int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype elemDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
+
return __tensorops_impl_matmul2d_op_cooperative_destination_data_size(
- descriptor, threads);
+ descriptor, leftDataType, rightDataType, elemDataType, threads);
}
template <class ElemType, class Extents, class Descriptor, class... Tags>
@@ -2438,15 +2777,22 @@
const thread void *source = (const thread void *)(&sourceT);
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
{
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f16(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f16(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
"Unsupported address space");
@@ -2455,11 +2801,13 @@
{
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_i32(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_i32(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
"Unsupported address space");
@@ -2468,11 +2816,13 @@
{
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f32(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
sourcePtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f32(
- desc, storage, source, sourceDescType, sourceRank, threads);
+ desc, storage, source, sourceDescType, sourceRank, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
"Unsupported address space");
@@ -2510,16 +2860,23 @@
const thread void *destination = (const thread void *)(&destinationT);
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
{
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f16(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f16(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
"Unsupported address space");
@@ -2529,11 +2886,13 @@
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_i32(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_i32(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
"Unsupported address space");
@@ -2543,11 +2902,13 @@
if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f32(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
destinationPtrType>)
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f32(
- desc, storage, destination, destinationDescType, threads);
+ desc, storage, destination, destinationDescType, leftDataType,
+ rightDataType, threads);
else
static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
"Unsupported address space");
@@ -2557,107 +2918,108 @@
"Unsupported type");
};
- static uint16_t size(const_thread_storage_t storage)
+ static uint16_t get_capacity(const_thread_storage_t storage)
{
metal::execution_threads t = scope();
int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
return __tensorops_impl_matmul2d_op_cooperative_destination_tensor_num_elements(
- descriptor, threads);
+ descriptor, storage, leftDataType, rightDataType, threads);
}
- static thread element_t *get_pointer_to(const_thread_storage_t storage,
- index_t idx)
+ static thread element_t *get_element_pointer(const_thread_storage_t storage,
+ index_t idx)
{
- __tensor_ops_detail::__tensor_ops_datatype dataType;
- if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
- else
- static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
- "unsupported data type");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
return (thread element_t *)
- __tensorops_impl_matmul2d_op_cooperative_destination_tensor_elements(
- (thread_storage_t)storage, idx, dataType);
- };
+ __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_pointer(
+ descriptor, (thread_storage_t)storage, idx, leftDataType,
+ rightDataType, dataType);
+ }
- static bool mask(const_thread_storage_t storage, index_t idx)
+ static index_t get_element_index(const_thread_storage_t storage,
+ const thread element_type *element)
{
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ return (index_t)
+ __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_index(
+ descriptor, (thread_storage_t)storage, element, leftDataType,
+ rightDataType, dataType);
+ }
+
+ static bool is_valid_element(const_thread_storage_t storage, index_t idx)
+ {
metal::execution_threads t = scope();
int threads = t.size();
- __tensor_ops_detail::__tensor_ops_datatype dataType;
- if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
- else
- static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
- "unsupported data type");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
return __tensorops_impl_matmul2d_op_cooperative_destination_tensor_is_valid_element(
descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
- dataType, threads);
+ leftDataType, rightDataType, dataType, threads);
}
template <typename index_t, __tensor_ops_detail::__rank_t rank = 2>
static metal::array<index_t, rank>
- multidimensional_indices(const_thread_storage_t storage, index_t idx)
+ get_multidimensional_index(const_thread_storage_t storage, index_t idx)
{
metal::execution_threads t = scope();
int threads = t.size();
- __tensor_ops_detail::__tensor_ops_datatype dataType;
- if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
- else
- static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
- "unsupported data type");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
if constexpr (__tensor_ops_detail::__is_same_v<coord_t, ushort>)
{
ushort coords[2];
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
- descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
- dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
- threads);
+ descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
+ threads, leftDataType, rightDataType, dataType);
return {coords[0], coords[1]};
}
else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, short>)
{
short coords[2];
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
- descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
- dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
- threads);
+ descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
+ threads, leftDataType, rightDataType, dataType);
return {coords[0], coords[1]};
}
else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, uint>)
{
uint coords[2];
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
- descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
- dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
- threads);
+ descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
+ threads, leftDataType, rightDataType, dataType);
;
return {coords[0], coords[1]};
}
@@ -2665,67 +3027,540 @@
{
int coords[2];
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
- descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
- dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
- threads);
+ descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
+ threads, leftDataType, rightDataType, dataType);
return {coords[0], coords[1]};
}
+ else {
+ static_assert(__tensor_ops_detail::__assert_false_v<coord_t>,
+ "unsupported coordinate data type");
+ }
}
static void construct(thread_storage_t storage)
{
metal::execution_threads t = scope();
int threads = t.size();
- __tensor_ops_detail::__tensor_ops_datatype dataType;
- if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
- else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
- dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
- else
- static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
- "unsupported data type");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype elemDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
+
__tensorops_impl_matmul2d_op_cooperative_destination_tensor_init(
- (__tensor_ops_detail::__thread_void_t)storage, descriptor, dataType,
- threads);
+ descriptor, (__tensor_ops_detail::__thread_void_t)storage, leftDataType,
+ rightDataType, elemDataType, threads);
}
};
template <__matmul2d_descriptor descriptor,
__matmul2d_cooperative_operand_index operand_index, typename scope,
- typename element_type, typename coord_type, typename... args>
+ typename left_operand, typename right_operand, typename element_type, typename coord_type, typename... args>
using __cooperative_tensor_t =
- typename __operand_layout<descriptor, operand_index, scope, element_type,
+ typename __operand_layout<descriptor, operand_index, scope, left_operand, right_operand, element_type,
coord_type, args...>::cooperative_tensor_t;
-template <__matmul2d_descriptor descriptor, typename scope,
+template <__matmul2d_descriptor descriptor, typename scope, typename left_operand, typename right_operand,
typename element_type, typename coord_type, typename... args>
using __cooperative_tensor_destination_t =
__cooperative_tensor_t<descriptor,
matmul2d_cooperative_operand_index::destination,
- scope, element_type, coord_type, args...>;
+ scope, left_operand, right_operand, element_type, coord_type, args...>;
template <__matmul2d_descriptor descriptor, typename scope,
- typename element_type, typename coord_type, typename left_operand,
- typename right_operand, typename... args>
-__cooperative_tensor_destination_t<descriptor, scope, element_type, coord_type,
+ typename left_operand, typename right_operand, typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_destination_t<descriptor, scope, left_operand, right_operand, element_type, coord_type,
args...>
__get_destination_cooperative_tensor()
{
static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
"scope should be of type __tensorops_scope");
- return __cooperative_tensor_destination_t<descriptor, scope, element_type,
+ return __cooperative_tensor_destination_t<descriptor, scope, left_operand, right_operand, element_type,
coord_type, args...>();
}
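// Illustrative sketch (not part of the SDK diff): with the operand types now
// leading the template parameter list, fetching and using a destination
// cooperative tensor might look as follows. The descriptor values, kernel
// name, and tensor extents are assumptions for illustration only.
//
// constexpr matmul2d_descriptor desc(64 /*m*/, 64 /*n*/, 64 /*k*/);
//
// kernel void sketchCooperativeDestination(
//     tensor<device half, dextents<int32_t, 2>> left,
//     tensor<device half, dextents<int32_t, 2>> right,
//     tensor<device float, dextents<int32_t, 2>> destination)
// {
//   matmul2d<desc, execution_simdgroup> op;
//   // Operand types are now explicit leading template arguments.
//   auto dT = op.get_destination_cooperative_tensor<decltype(left),
//                                                   decltype(right), float>();
//   op.run(left, right, dT);
//   dT.store(destination);
// }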
+template <__matmul2d_descriptor descriptor, int reduction_dim, typename scope,
+ typename left_operand, typename right_operand,
+ typename element_type, typename coord_type, typename... args>
+struct __reduction_operand_layout
+{
+ static_assert(__tensor_ops_detail::__is_same_v<element_type, float> ||
+ __tensor_ops_detail::__is_same_v<element_type, half> ||
+ __tensor_ops_detail::__is_same_v<element_type, int32_t>,
+ "cooperative tensor data type can only be one of "
+ "float/half/int32_t");
+
+ static constant constexpr __tensor_ops_detail::__rank_t rank = 1;
+ using element_t = element_type;
+ using coord_t = coord_type;
+ using extent_t = metal::dextents<coord_t, rank>;
+ using thread_storage_t = thread void *;
+ using const_thread_storage_t = const thread void *;
+ using index_t = uint16_t;
+ using operand_layout_t =
+ __reduction_operand_layout<descriptor, reduction_dim, scope, left_operand,
+ right_operand, element_t, coord_t>;
+ using cooperative_tensor_t =
+ metal::cooperative_tensor<element_t, extent_t, operand_layout_t>;
+ using scope_t = scope;
+
+ using left_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<left_operand>>;
+ using right_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<right_operand>>;
+
+ using left_elem_t = typename left_t::element_type;
+ using right_elem_t = typename right_t::element_type;
+
+ using left_value_t = typename left_t::value_type;
+ using right_value_t = typename right_t::value_type;
+
+ static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+ "scope should be of type __tensorops_scope");
+ static_assert(reduction_dim == 0 || reduction_dim == 1, "Reduction dimension must be 0 or 1");
+
+ static constexpr constant bool is_matmul2d_reduction_cooperative_destination_layout =
+ true;
+ static constexpr constant int __reduction_dim = reduction_dim;
+
+ static constexpr constant __matmul2d_descriptor matmul2d_desc = descriptor;
+
+ // Returns the alignment of the storage allocated in each thread
+ // for this cooperative_tensor.
+ static constexpr size_t thread_storage_align()
+ {
+ return alignof(element_t);
+ };
+
+ // Copy-constructs from the cooperative_tensor `other`.
+ static void copy_construct(thread void *this_, thread void *other)
+ {
+ thread element_t *this_e = (thread element_t *)(this_);
+ thread element_t *other_e = (thread element_t *)(other);
+ for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+ {
+      this_e[i] = other_e[i];
+ }
+ };
+
+  // Move-constructs from the cooperative_tensor `other`. The per-thread
+  // storage is trivially copyable, so the move is an element-wise copy.
+  static void move_construct(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+ // Copy-assigns from the cooperative_tensor `other`.
+ static void copy_assign(thread void *this_, thread void *other)
+ {
+ thread element_t *this_e = (thread element_t *)(this_);
+ thread element_t *other_e = (thread element_t *)(other);
+ for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+ {
+      this_e[i] = other_e[i];
+ }
+ };
+
+  // Move-assigns from the cooperative_tensor `other`. As with move
+  // construction, this is an element-wise copy of the per-thread storage.
+  static void move_assign(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+ // Destroys the per-thread object.
+ static void destroy(thread void *) {};
+
+ static size_t thread_storage_size()
+ {
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_data_size(
+ descriptor, reduction_dim, leftDataType, rightDataType, elementDataType, threads);
+ }
+
+ template <class ElemType, class Extents, class Descriptor, class... Tags>
+ static void load(thread_storage_t storage,
+ const thread metal::tensor<ElemType, Extents, Descriptor,
+ Tags...> &sourceT)
+ {
+ using elem_t = __tensor_ops_detail::__remove_addrspace_t<ElemType>;
+
+ static_assert(__tensor_ops_detail::__is_same_v<elem_t, element_t>,
+ "Source tensor datatype does not match cooperative tensor");
+ static_assert(Extents::rank() == 1,
+ "Source tensor must be rank 1");
+
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __matmul2d_descriptor desc = descriptor;
+
+ using tensorType = metal::tensor<ElemType, Extents, Descriptor, Tags...>;
+
+ using sourcePtrType = typename tensorType::data_handle_type;
+
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType =
+ __tensor_ops_detail::__tensor_type_to_tensor_descriptor_type<
+ tensorType>();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+ const thread void *source = (const thread void *)(&sourceT);
+
+ if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f16(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f16(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+ "Unsupported address space");
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, int32_t>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_i32(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_i32(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+ "Unsupported address space");
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, float>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f32(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ sourcePtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f32(
+ desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+ "Unsupported address space");
+ }
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<elem_t>,
+ "Unsupported type");
+ };
+
+ template <class ElemType, class Extents, class Descriptor, class... Tags>
+ static void store(const_thread_storage_t storage,
+ const thread metal::tensor<ElemType, Extents, Descriptor,
+ Tags...> &destinationT)
+ {
+ using elem_t = __tensor_ops_detail::__remove_addrspace_t<ElemType>;
+
+ static_assert(__tensor_ops_detail::__is_same_v<elem_t, element_t>,
+ "Tensor datatype does not match cooperative tensor");
+ static_assert(Extents::rank() == 1,
+ "Tensor must be rank 1");
+
+ __matmul2d_descriptor desc = descriptor;
+
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ using tensorType = metal::tensor<ElemType, Extents, Descriptor, Tags...>;
+
+ using destinationPtrType = typename tensorType::data_handle_type;
+
+ __tensor_ops_detail::__tensor_ops_tensor_descriptor_type
+ destinationDescType =
+ __tensor_ops_detail::__tensor_type_to_tensor_descriptor_type<
+ tensorType>();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+ const thread void *destination = (const thread void *)(&destinationT);
+
+ if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f16(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f16(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+ "Unsupported address space");
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, int32_t>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_i32(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_i32(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+ "Unsupported address space");
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, float>)
+ {
+ if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f32(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+ destinationPtrType>)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f32(
+ desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+ "Unsupported address space");
+ }
+ else
+ static_assert(__tensor_ops_detail::__assert_false_v<elem_t>,
+ "Unsupported type");
+ };
+
+ static uint16_t get_capacity(const_thread_storage_t storage)
+ {
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+ return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_num_elements(
+ descriptor, storage, reduction_dim, leftDataType, rightDataType, threads);
+ }
+
+ static thread element_t *get_element_pointer(const_thread_storage_t storage,
+ index_t idx)
+ {
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ return (thread element_t *)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_pointer(
+ descriptor, (thread_storage_t)storage, idx, leftDataType, rightDataType, dataType);
+ }
+
+ static index_t get_element_index(const_thread_storage_t storage,
+ const thread element_type *element)
+ {
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ return (index_t)
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_index(
+ descriptor, (thread_storage_t)storage, element, leftDataType, rightDataType, dataType);
+ }
+
+ static bool is_valid_element(const_thread_storage_t storage, index_t idx)
+ {
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_is_valid_element(
+ descriptor, (__tensor_ops_detail::__thread_void_t)storage, reduction_dim, idx,
+ leftDataType, rightDataType, dataType, threads);
+ }
+
+ template <typename index_t, __tensor_ops_detail::__rank_t rank = 1>
+ static metal::array<index_t, rank>
+ get_multidimensional_index(const_thread_storage_t storage, index_t idx)
+ {
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+ if constexpr (__tensor_ops_detail::__is_same_v<coord_t, ushort>)
+ {
+ ushort coords[1];
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+ descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
+ threads, leftDataType, rightDataType, elementDataType);
+ return { coords[0] };
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, short>)
+ {
+ short coords[1];
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+ descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
+ threads, leftDataType, rightDataType, elementDataType);
+ return { coords[0] };
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, uint>)
+ {
+ uint coords[1];
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+ descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
+ threads, leftDataType, rightDataType, elementDataType);
+ return { coords[0] };
+ }
+ else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, int>)
+ {
+ int coords[1];
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+ descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+ coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
+ threads, leftDataType, rightDataType, elementDataType);
+ return { coords[0] };
+ }
+ else {
+ static_assert(__tensor_ops_detail::__assert_false_v<coord_t>,
+ "unsupported coordinate data type");
+ }
+ }
+
+ static void construct(thread_storage_t storage)
+ {
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+ __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_t>();
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+ __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_init(
+ (__tensor_ops_detail::__thread_void_t)storage, descriptor,
+ reduction_dim, leftDataType, rightDataType, elementDataType, threads);
+ }
+
+ template <class FromIterator, class ToIterator>
+ static uint16_t map_index(const thread void *from_storage, uint16_t from_idx,
+ const thread void *to_storage)
+ {
+ using sourceLayout = typename FromIterator::layout;
+ using destLayout = typename ToIterator::layout;
+
+ static_assert(sourceLayout::is_matmul2d_cooperative_destination_layout,
+ "Source must be a matmul2d destination cooperative tensor");
+ static_assert(destLayout::is_matmul2d_reduction_cooperative_destination_layout,
+ "Destination must be a matmul2d reduction destination cooperative tensor");
+ static_assert(__tensor_ops_detail::__is_same_v<typename sourceLayout::scope_t, metal::execution_simdgroup>,
+ "map_index requires a single SIMD group");
+ static_assert(__tensor_ops_detail::__is_same_v<typename destLayout::scope_t, metal::execution_simdgroup>,
+ "map_index requires a single SIMD group");
+
+ metal::execution_threads t = scope();
+ int threads = t.size();
+
+ constexpr __matmul2d_descriptor sourceDesc = sourceLayout::matmul2d_desc;
+ constexpr __matmul2d_descriptor destDesc = destLayout::matmul2d_desc;
+
+ static_assert(reduction_dim == 0 || sourceDesc.n == destDesc.n, "Source and destination must have matching N dimension if reduction_dim = 1");
+ static_assert(reduction_dim == 1 || sourceDesc.m == destDesc.m, "Source and destination must have matching M dimension if reduction_dim = 0");
+
+ static_assert(__tensor_ops_detail::__is_same_v<typename sourceLayout::element_t, typename destLayout::element_t>, "Source and destination element types must match");
+
+ __tensor_ops_detail::__tensor_ops_datatype srcLeftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename sourceLayout::left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype srcRightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename sourceLayout::right_value_t>::value;
+
+ return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_map_index(
+ from_storage, sourceDesc,
+ to_storage, destDesc,
+ reduction_dim, threads, from_idx,
+ srcLeftDataType, srcRightDataType);
+ }
+};
+
template <__matmul2d_descriptor descriptor, typename scope,
typename left_operand, typename right_operand,
+ typename element_type, typename coord_type, typename... args>
+using __cooperative_tensor_row_reduction_destination_t =
+ typename __reduction_operand_layout<descriptor, 0, scope, left_operand, right_operand,
+ element_type, coord_type, args...>::cooperative_tensor_t;
+
+template <__matmul2d_descriptor descriptor, typename scope,
+ typename left_operand, typename right_operand,
+ typename element_type, typename coord_type, typename... args>
+using __cooperative_tensor_column_reduction_destination_t =
+ typename __reduction_operand_layout<descriptor, 1, scope, left_operand, right_operand,
+ element_type, coord_type, args...>::cooperative_tensor_t;
+
+template <__matmul2d_descriptor descriptor, typename scope,
+ typename left_operand, typename right_operand,
+ typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_row_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+ element_type, coord_type, args...>
+__get_row_reduction_destination_cooperative_tensor()
+{
+ static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+ "scope should be of type __tensorops_scope");
+ return __cooperative_tensor_row_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+ element_type, coord_type, args...>();
+}
+
+template <__matmul2d_descriptor descriptor, typename scope,
+ typename left_operand, typename right_operand,
+ typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_column_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+ element_type, coord_type, args...>
+__get_column_reduction_destination_cooperative_tensor()
+{
+ static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+ "scope should be of type __tensorops_scope");
+ return __cooperative_tensor_column_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+ element_type, coord_type, args...>();
+}
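+// Illustrative sketch (assumed public spellings of the internal row/column
+// helpers above): the row reduction destination is rank 1 and holds one
+// element per row (the descriptors' m extents must match, per the asserts
+// above); the column reduction destination holds one element per column
+// (matching n extents).
+//
+// auto rowT = op.get_row_reduction_destination_cooperative_tensor<
+//     decltype(left), decltype(right), float>();
+// auto colT = op.get_column_reduction_destination_cooperative_tensor<
+//     decltype(left), decltype(right), float>();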
+
+template <__matmul2d_descriptor descriptor, typename scope,
+ typename left_operand, typename right_operand,
typename destination_operand, typename... args>
void __run(thread left_operand &leftIn, thread right_operand &rightIn,
thread destination_operand &destinationT)
@@ -4566,7 +5401,7 @@
}
else
{
- thread void *destination = (thread void *)&destinationT[0];
+ thread void *destination = (thread void *)&destinationT[__tensor_ops_detail::__tensor_ops_reserved_index];
if constexpr (__tensor_ops_detail::__is_same_v<leftValueType, half> &&
__tensor_ops_detail::__is_same_v<rightValueType, half> &&
@@ -5054,72 +5889,159 @@
}
}
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
inline void __reduce_rows(
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &sourceT,
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &destT,
+ thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &sourceT,
+ thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destT,
ElementType identity = (ElementType)0,
__reduction_operation op = reduction_operation::sum)
{
- static_assert(Layout::is_matmul2d_cooperative_destination_layout,
- "Source and destination must be matmul2d cooperative "
- "destination tensors");
- static_assert(__tensor_ops_detail::__is_same_v<typename Layout::scope_t,
+  static_assert(SrcLayout::is_matmul2d_cooperative_destination_layout,
+                "Source must be a matmul2d cooperative destination tensor");
+  static_assert(DstLayout::is_matmul2d_reduction_cooperative_destination_layout,
+                "Destination must be a matmul2d row reduction cooperative destination tensor");
+  static_assert(DstLayout::__reduction_dim == 0,
+                "Destination must be a matmul2d row reduction cooperative destination tensor");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t,
metal::execution_simdgroup>,
"reduce_rows requires a single SIMD group");
- static_assert(Extents::rank() == 2, "Rank must be 2");
+ static_assert(__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t,
+ metal::execution_simdgroup>,
+ "reduce_rows requires a single SIMD group");
+ static_assert(SrcExtents::rank() == 2, "Source rank must be 2");
+ static_assert(DstExtents::rank() == 1, "Destination rank must be 1");
- thread void *src = (thread void *)&sourceT[0];
- thread void *dst = (thread void *)&destT[0];
+ constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+ constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
- __matmul2d_descriptor desc = Layout::matmul2d_desc;
+ static_assert(matmul2d_descriptor_is_equal(sourceDesc, destDesc), "Source and destination matmul2d descriptor must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::left_t, typename DstLayout::left_t>, "Source and destination operand types must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::right_t, typename DstLayout::right_t>, "Source and destination operand types must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::element_t, typename DstLayout::element_t>, "Source and destination element types must match");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+
+ thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+ thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+ __matmul2d_descriptor desc = SrcLayout::matmul2d_desc;
+
if constexpr (__tensor_ops_detail::__is_same_v<ElementType, half>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f16(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, int32_t>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_i32(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, float>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f32(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else
static_assert(__tensor_ops_detail::__assert_false_v<ElementType>,
"Unsupported type");
}
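// Illustrative sketch (assuming a public reduce_rows wrapper over the
// __reduce_rows helper above; names outside this header are assumptions):
// summing each row of a cooperative matmul2d destination into a rank-1 row
// reduction destination.
//
// auto dT = op.get_destination_cooperative_tensor<decltype(left),
//                                                 decltype(right), float>();
// op.run(left, right, dT);
// auto rT = op.get_row_reduction_destination_cooperative_tensor<
//     decltype(left), decltype(right), float>();
// reduce_rows(dT, rT, 0.0f, reduction_operation::sum);
// rT.store(rowSums); // rowSums: rank-1 device tensor sized to the m extent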
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
inline void __reduce_columns(
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &sourceT,
- thread metal::cooperative_tensor<ElementType, Extents, Layout> &destT,
+ thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &sourceT,
+ thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destT,
ElementType identity = (ElementType)0,
__reduction_operation op = reduction_operation::sum)
{
- static_assert(Layout::is_matmul2d_cooperative_destination_layout,
- "Source and destination must be matmul2d cooperative "
- "destination tensors");
- static_assert(__tensor_ops_detail::__is_same_v<typename Layout::scope_t,
+  static_assert(SrcLayout::is_matmul2d_cooperative_destination_layout,
+                "Source must be a matmul2d cooperative destination tensor");
+  static_assert(DstLayout::is_matmul2d_reduction_cooperative_destination_layout,
+                "Destination must be a matmul2d column reduction cooperative destination tensor");
+  static_assert(DstLayout::__reduction_dim == 1,
+                "Destination must be a matmul2d column reduction cooperative destination tensor");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t,
metal::execution_simdgroup>,
"reduce_columns requires a single SIMD group");
- static_assert(Extents::rank() == 2, "Rank must be 2");
+ static_assert(__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t,
+ metal::execution_simdgroup>,
+ "reduce_columns requires a single SIMD group");
+ static_assert(SrcExtents::rank() == 2, "Source rank must be 2");
+ static_assert(DstExtents::rank() == 1, "Destination rank must be 1");
- thread void *src = (thread void *)&sourceT[0];
- thread void *dst = (thread void *)&destT[0];
+ constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+ constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
- __matmul2d_descriptor desc = Layout::matmul2d_desc;
+ static_assert(matmul2d_descriptor_is_equal(sourceDesc, destDesc), "Source and destination matmul2d descriptor must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::left_t, typename DstLayout::left_t>, "Source and destination operand types must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::right_t, typename DstLayout::right_t>, "Source and destination operand types must match");
+ static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::element_t, typename DstLayout::element_t>, "Source and destination element types must match");
+ __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+
+ thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+ thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+ __matmul2d_descriptor desc = SrcLayout::matmul2d_desc;
+
if constexpr (__tensor_ops_detail::__is_same_v<ElementType, half>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f16(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, int32_t>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_i32(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, float>)
__tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f32(
- desc, src, dst, identity, op);
+ desc, src, dst, identity, op, leftDataType, rightDataType);
else
static_assert(__tensor_ops_detail::__assert_false_v<ElementType>,
"Unsupported type");
+}
+
+template <class SrcElementType, class DstElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
+inline bool __is_iterator_compatible(
+ thread metal::cooperative_tensor<SrcElementType, SrcExtents, SrcLayout> &sourceT,
+ thread metal::cooperative_tensor<DstElementType, DstExtents, DstLayout> &destT)
+{
+ if (!SrcLayout::is_matmul2d_cooperative_destination_layout ||
+ !DstLayout::is_matmul2d_reduction_cooperative_destination_layout ||
+ !__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t, metal::execution_simdgroup> ||
+ !__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t, metal::execution_simdgroup> ||
+ !__tensor_ops_detail::__is_same_v<SrcElementType, DstElementType> ||
+ SrcExtents::rank() != 2 || DstExtents::rank() != 1)
+ {
+ return false;
+ }
+
+ constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+ constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
+
+ constexpr int reduction_dim = DstLayout::__reduction_dim;
+
+ if ((reduction_dim == 0 && sourceDesc.m != destDesc.m) ||
+      (reduction_dim == 1 && sourceDesc.n != destDesc.n))
+ {
+ return false;
+ }
+
+ thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+ thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+ __tensor_ops_detail::__tensor_ops_datatype srcLeftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype srcRightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype srcElemDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<SrcElementType>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dstLeftDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename DstLayout::left_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dstRightDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<typename DstLayout::right_value_t>::value;
+ __tensor_ops_detail::__tensor_ops_datatype dstElemDataType =
+ __tensor_ops_detail::__type_to_tensor_ops_datatype<DstElementType>::value;
+
+ return __tensorops_impl_matmul2d_op_cooperative_destination_is_iterator_compatible(
+ sourceDesc, destDesc, src, dst, srcLeftDataType, srcRightDataType,
+ srcElemDataType, dstLeftDataType, dstRightDataType, dstElemDataType);
}
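// Illustrative sketch (assuming public is_iterator_compatible and map_index
// wrappers over the internals above; both names and the size() accessor are
// assumptions): map_index translates the index of an element a thread owns in
// the 2-D destination into the index of the rank-1 reduction element it folds
// into, enabling per-thread accumulation:
//
// if (is_iterator_compatible(dT, rT))
// {
//   for (uint16_t i = 0, e = dT.size(); i != e; ++i)
//   {
//     uint16_t j = map_index(dT, i, rT); // hypothetical wrapper signature
//     rT[j] += dT[i];
//   }
// }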
#undef EXTERNALLY_DEFINED_ATTR
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h 2025-07-11 22:52:25
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h 2025-07-26 22:15:31
@@ -23,10 +23,7 @@
template <typename T>
constexpr inline int __get_rank()
{
- if constexpr (__is_cooperative_tensor_type_v<T>)
- return T::rank();
- else // tensor
- return T::get_rank();
+ return T::get_rank();
}
using __rank_t = ushort;
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h 2025-07-11 23:30:03
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h 2025-07-26 06:00:32
@@ -33,7 +33,6 @@
template <typename T>
struct __type_to_tensor_ops_datatype
{
- static constant __tensor_ops_datatype value = __tensor_ops_datatype_invalid;
};
template <>
struct __type_to_tensor_ops_datatype<float>