MetalPerformancePrimitives iOS xcode26.0 b5

# MetalPerformancePrimitives.framework https://github.com/dotnet/macios/issues/23418

diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h	2025-07-11 22:46:13
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/MPPTensorOpsMatMul2d.h	2025-07-26 22:15:31
@@ -113,7 +113,7 @@
 //
 // Above matrix multiplication implementation will do edge checking for all
 // thread groups against extents of original tensor although for large enough
-// matrices most of thread groups will be working on "inside" tiles, requring no
+// matrices most of thread groups will be working on "inside" tiles, requiring no
 // bounds check. In high performance code we can avoid edge checking for inside
 // thread groups and get better performance
 //
@@ -228,7 +228,7 @@
 // number of threads in opscope with which op was created. Note that
 // cooperative_tensor created from an op is only valid for threads that are part
 // of opscope on which op was created. Though the layout of cooperative_tensor
-// is implemtation defined, we provide accessor functions as shown in the
+// is implementation defined, we provide accessor functions as shown in the
 // example below
 //
 // kernel void simpleMatMulCooperative(tensor<device half,  dextents<int32_t,
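The example kernel continues past this hunk. Compressed to its essence, the pattern the comment describes looks roughly like this (a sketch only: `op`, `left`, `right`, `dest`, and the ReLU epilogue are assumptions; `begin()`/`end()` are the iterator accessors this header uses elsewhere):

```metal
// Assumes op is a matmul2d<desc, scope> and left/right/dest are the kernel's
// tensor arguments, as in the header's simpleMatMulCooperative example.
auto cT = op.get_destination_cooperative_tensor<decltype(left),
                                                decltype(right), float>();
op.run(left, right, cT);              // accumulate into per-thread storage
for (auto it = cT.begin(); it != cT.end(); ++it)
    *it = max(*it, 0.0f);             // elementwise epilogue (ReLU)
cT.store(dest);                       // write back; the layout is opaque
```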
@@ -347,6 +347,7 @@
 
 #include "__impl/MPPTensorOpsBase.h"
 #include "__impl/MPPTensorOpsTypes.h"
+#include "__impl/MPPTensorOpsUtility.h"
 #include <metal_numeric>
 
 #pragma METAL internals : enable
@@ -438,10 +439,10 @@
                              RunArgs...>(left, right, destination);
   }
 
-  template <typename ElementType, typename CoordType, typename... CoopArgs>
+  template <typename LeftOperand, typename RightOperand, typename ElementType, typename CoordType, typename... CoopArgs>
   using cooperative_tensor_destination_t =
       __mutmul2d_detail::__cooperative_tensor_destination_t<
-          Descriptor, Scope, ElementType, CoordType, CoopArgs...>;
+          Descriptor, Scope, LeftOperand, RightOperand, ElementType, CoordType, CoopArgs...>;
 
   template <typename LeftOperandType, typename RightOperandType,
             typename ElementType, typename CoordType = int,
@@ -451,38 +452,100 @@
                 __tensor_ops_detail::__is_thread_addrspace_v<ElementType> &&
                 __tensor_ops_detail::__is_integral_v<CoordType>>,
             typename... CoopArgs>
-  INLINE cooperative_tensor_destination_t<ElementType, CoordType, CoopArgs...>
+  INLINE cooperative_tensor_destination_t<LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>
   get_destination_cooperative_tensor() thread const
   {
 
     return __mutmul2d_detail::__get_destination_cooperative_tensor<
-        Descriptor, Scope, ElementType, CoordType, LeftOperandType,
-        RightOperandType, CoopArgs...>();
+        Descriptor, Scope, LeftOperandType, RightOperandType, ElementType,
+        CoordType, CoopArgs...>();
   }
+
+  template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+            typename CoordType, typename... CoopArgs>
+  using cooperative_tensor_row_reduction_destination_t =
+      __mutmul2d_detail::__cooperative_tensor_row_reduction_destination_t<
+          Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>;
+
+  template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+            typename CoordType = int,
+            typename U = __tensor_ops_detail::__enable_if_t<
+                __tensor_ops_detail::__is_integral_v<CoordType>>,
+            typename... CoopArgs>
+  INLINE cooperative_tensor_row_reduction_destination_t<LeftOperandType, RightOperandType,
+                                                        ElementType, CoordType, CoopArgs...>
+  get_row_reduction_destination_cooperative_tensor() thread const
+  {
+    return __mutmul2d_detail::__get_row_reduction_destination_cooperative_tensor<
+        Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>();
+  }
+
+
+  template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+            typename CoordType, typename... CoopArgs>
+  using cooperative_tensor_column_reduction_destination_t =
+      __mutmul2d_detail::__cooperative_tensor_column_reduction_destination_t<
+          Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>;
+
+  template <typename LeftOperandType, typename RightOperandType, typename ElementType,
+            typename CoordType = int, typename U = __tensor_ops_detail::__enable_if_t<
+                __tensor_ops_detail::__is_integral_v<CoordType>>,
+            typename... CoopArgs>
+  INLINE cooperative_tensor_column_reduction_destination_t<LeftOperandType, RightOperandType,
+                                                           ElementType, CoordType, CoopArgs...>
+  get_column_reduction_destination_cooperative_tensor() thread const
+  {
+    return __mutmul2d_detail::__get_column_reduction_destination_cooperative_tensor<
+        Descriptor, Scope, LeftOperandType, RightOperandType, ElementType, CoordType, CoopArgs...>();
+  }
 };
 
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
 inline void reduce_rows(
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &source,
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &destination,
+    thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &source,
+    thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destination,
     reduction_operation op = reduction_operation::sum,
     ElementType identity =
         reduction_operation_identity<ElementType>::sum_identity)
 {
-  __mutmul2d_detail::__reduce_rows<ElementType, Extents, Layout>(
+  __mutmul2d_detail::__reduce_rows<ElementType, SrcExtents, DstExtents, SrcLayout, DstLayout>(
       source, destination, identity, op);
 }
 
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
 inline void reduce_columns(
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &source,
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &destination,
+    thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &source,
+    thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destination,
     reduction_operation op = reduction_operation::sum,
     ElementType identity =
         reduction_operation_identity<ElementType>::sum_identity)
 {
-  __mutmul2d_detail::__reduce_columns<ElementType, Extents, Layout>(
+  __mutmul2d_detail::__reduce_columns<ElementType, SrcExtents, DstExtents, SrcLayout, DstLayout>(
       source, destination, identity, op);
+}
+
+// Returns whether the iterators are compatible between a source and destination cooperative tensor.
+//
+// Use this to check whether map_iterator will return a valid iterator. For example:
+//
+//     if (is_iterator_compatible(sourceCT, destCT)) {
+//         for (auto it = sourceCT.begin(); it != sourceCT.end(); it++) {
+//             auto dst_it = destCT.map_iterator(it);
+//
+//             *it += *dst_it;
+//         }
+//     }
+//     else {
+//          // Fall back to storing sourceCT to threadgroup memory and access via
+//          // destCT's multidimensional indices
+//     }
+template <class SrcElementType, class DstElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
+inline bool is_iterator_compatible(
+    const thread metal::cooperative_tensor<SrcElementType, SrcExtents, SrcLayout> &source,
+    const thread metal::cooperative_tensor<DstElementType, DstExtents, DstLayout> &destination)
+{
+  return __mutmul2d_detail::__is_iterator_compatible<SrcElementType, DstElementType, SrcExtents, DstExtents,
+      SrcLayout, DstLayout>(source, destination);
 }
 
 } // namespace tensor_ops
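Taken together, this hunk lets the reduction destinations carry their own extents and layout. A hedged end-to-end sketch of the beta5 surface (descriptor arguments, the `execution_simdgroups<4>` scope, and all tensor shapes are assumptions for illustration):

```metal
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace mpp::tensor_ops;

// Assumed shape: 64x64 output over a 64-deep reduction, no transposes.
constexpr matmul2d_descriptor desc(64, 64, 64,
                                   /*transpose_left*/ false,
                                   /*transpose_right*/ false,
                                   /*relaxed_precision*/ false);

kernel void matmulRowSums(tensor<device half, dextents<int32_t, 2>> left,
                          tensor<device half, dextents<int32_t, 2>> right,
                          tensor<device float, dextents<int32_t, 2>> rowSums)
{
    matmul2d<desc, execution_simdgroups<4>> op;

    auto acc = op.get_destination_cooperative_tensor<decltype(left),
                                                     decltype(right), float>();
    op.run(left, right, acc);

    // New in beta5: the reduction destination has its own extents/layout,
    // matching the widened reduce_rows(SrcExtents/DstExtents) signature.
    auto sums = op.get_row_reduction_destination_cooperative_tensor<
        decltype(left), decltype(right), float>();
    reduce_rows(acc, sums); // defaults: reduction_operation::sum, sum identity

    sums.store(rowSums);
}
```

`is_iterator_compatible(acc, sums)` is the guard for the iterator-mapping path shown in the comment above; when it returns false, the fallback is the threadgroup-memory round trip that comment describes.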
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h	2025-07-11 23:47:32
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsConvolution2dImpl.h	2025-07-26 06:00:32
@@ -4395,6 +4395,9 @@
   using a_elem_type = typename a_type::element_type;
   using w_elem_type = typename w_type::element_type;
 
+  using a_value_type = typename a_type::value_type;
+  using w_value_type = typename w_type::value_type;
+
   static size_t thread_storage_size()
   {
     metal::execution_threads t = scope();
@@ -4402,9 +4405,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
     return __tensorops_impl_conv2d_cooperative_destination_data_size(
         descriptor, d_data_type, a_data_type, w_data_type, threads);
   }
@@ -4414,7 +4417,7 @@
     return alignof(element_t);
   };
 
-  static uint16_t size(const_thread_storage_t storage)
+  static uint16_t get_capacity(const_thread_storage_t storage)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
@@ -4429,9 +4432,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
 
     __tensorops_impl_conv2d_cooperative_destination_tensor_init(
         this_, descriptor, d_data_type, a_data_type, w_data_type, threads);
@@ -4444,9 +4447,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
 
     __tensorops_impl_conv2d_cooperative_destination_tensor_copy(
         this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4460,9 +4463,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
 
     __tensorops_impl_conv2d_cooperative_destination_tensor_move(
         this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4476,9 +4479,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
 
     __tensorops_impl_conv2d_cooperative_destination_tensor_copy(
         this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4492,9 +4495,9 @@
     __tensor_ops_detail::__tensor_ops_datatype d_data_type =
         __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
     __tensor_ops_detail::__tensor_ops_datatype a_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<a_value_type>::value;
     __tensor_ops_detail::__tensor_ops_datatype w_data_type =
-        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_elem_type>::value;
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<w_value_type>::value;
 
     __tensorops_impl_conv2d_cooperative_destination_tensor_move(
         this_, other, descriptor, d_data_type, a_data_type, w_data_type,
@@ -4693,8 +4696,8 @@
                     "Unsupported type");
   };
 
-  static thread element_t *get_pointer_to(const_thread_storage_t storage,
-                                          index_t idx)
+  static thread element_t *get_element_pointer(const_thread_storage_t storage,
+                                               index_t idx)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
@@ -4716,9 +4719,14 @@
     return (thread element_t *)
         __tensorops_impl_conv2d_cooperative_destination_tensor_elements(
             (thread_storage_t)storage, idx, dataType, threads);
-  };
+  }
 
-  static bool mask(const_thread_storage_t storage, index_t idx)
+  static index_t get_element_index(thread_storage_t storage,
+                                   const thread element_type *) {
+    // TODO
+  }
+
+  static bool is_valid_element(const_thread_storage_t storage, index_t idx)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
@@ -4744,7 +4752,7 @@
 
   template <typename index_t, __tensor_ops_detail::__rank_t rank>
   static metal::array<index_t, rank>
-  multidimensional_indices(const_thread_storage_t storage, index_t idx)
+  get_multidimensional_index(const_thread_storage_t storage, index_t idx)
   {
 
     metal::execution_threads t = scope();
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h	2025-07-11 23:40:46
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsMatMul2dImpl.h	2025-07-25 22:24:17
@@ -27,131 +27,455 @@
 
 using __reduction_operation = reduction_operation;
 
+constexpr bool matmul2d_descriptor_is_equal(matmul2d_descriptor a, matmul2d_descriptor b) {
+  return a.m == b.m &&
+         a.n == b.n &&
+         a.k == b.k &&
+         a.transpose_left == b.transpose_left &&
+         a.transpose_right == b.transpose_right &&
+         a.relaxed_precision == b.relaxed_precision &&
+         a.matmul_mode == b.matmul_mode;
+}
+
 extern "C" EXTERNALLY_DEFINED_ATTR size_t
 __tensorops_impl_matmul2d_op_cooperative_destination_data_size(
-    const __matmul2d_descriptor descriptor, const int threads);
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_num_elements(
-    const __matmul2d_descriptor descriptor, const int threads);
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR thread void *
-__tensorops_impl_matmul2d_op_cooperative_destination_tensor_elements(
-    __tensor_ops_detail::__thread_void_t, uint16_t,
+__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_pointer(
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    uint16_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR thread uint16_t
+__tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_index(
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
-    const __matmul2d_descriptor descriptor,
-    __tensor_ops_detail::__thread_void_t, uint16_t,
-    __tensor_ops_detail::__tensor_ops_datatype, thread void *,
-    __tensor_ops_detail::__tensor_ops_datatype, const int);
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    uint16_t,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_init(
-    __tensor_ops_detail::__thread_void_t, __matmul2d_descriptor,
-    __tensor_ops_detail::__tensor_ops_datatype, const int);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR bool
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_is_valid_element(
-    const __matmul2d_descriptor descriptor,
-    __tensor_ops_detail::__thread_void_t, uint16_t,
-    __tensor_ops_detail::__tensor_ops_datatype, const int);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    uint16_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f16(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f16(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_i32(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_i32(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f32(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f32(
-    thread __matmul2d_descriptor &desc, thread void *storage,
-    const thread void *source,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType,
-    int sourceRank, int threads);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
 
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f16(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f16(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_i32(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_i32(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f32(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f32(
-    thread __matmul2d_descriptor &desc, const thread void *storage,
-    const thread void *destination,
-    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type destDescType,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
     int threads);
 
+extern "C" EXTERNALLY_DEFINED_ATTR size_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_data_size(
+    __matmul2d_descriptor,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
+extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_num_elements(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
+extern "C" EXTERNALLY_DEFINED_ATTR thread void *
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_pointer(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    uint16_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR thread uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_index(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+    __matmul2d_descriptor,
+    int,
+    __tensor_ops_detail::__const_thread_void_t,
+    uint16_t,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_init(
+    __tensor_ops_detail::__thread_void_t,
+    __matmul2d_descriptor,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
+extern "C" EXTERNALLY_DEFINED_ATTR bool
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_is_valid_element(
+    __matmul2d_descriptor descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    int,
+    uint16_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    int);
+extern "C" EXTERNALLY_DEFINED_ATTR uint16_t
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_map_index(
+    __tensor_ops_detail::__const_thread_void_t,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __matmul2d_descriptor,
+    int,
+    int,
+    uint16_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
+extern "C" EXTERNALLY_DEFINED_ATTR bool
+__tensorops_impl_matmul2d_op_cooperative_destination_is_iterator_compatible(
+    __matmul2d_descriptor,
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f16(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f16(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_i32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_i32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f16(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f16(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_i32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_i32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+extern "C" EXTERNALLY_DEFINED_ATTR void
+__tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f32(
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type,
+    int,
+    int,
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType,
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType);
+
+extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f16(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, half identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    half,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f32(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, float identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    float,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_i32(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, int identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    int,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f16(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, half identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    half,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f32(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, float identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    float,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_i32(
-    thread __matmul2d_descriptor &desc, const thread void *src,
-    thread void *dst, int identity, __reduction_operation op);
+    __matmul2d_descriptor,
+    __tensor_ops_detail::__const_thread_void_t,
+    __tensor_ops_detail::__thread_void_t,
+    int,
+    __reduction_operation,
+    __tensor_ops_detail::__tensor_ops_datatype,
+    __tensor_ops_detail::__tensor_ops_datatype);
 
 extern "C" EXTERNALLY_DEFINED_ATTR void
 __tensorops_impl_matmul2d_op_run_dv_f16_dv_f16_dv_f16(
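Beyond the ABI-level plumbing above, the one new user-visible helper in this hunk is `matmul2d_descriptor_is_equal`, a field-by-field `constexpr` comparison. A hypothetical compile-time use (descriptor arguments assumed):

```metal
// Sketch: verify two ops were configured for the same shape before sharing
// cooperative tensors between them. Values are assumptions for illustration.
constexpr matmul2d_descriptor a(32, 64, 16, false, false, false);
constexpr matmul2d_descriptor b(32, 64, 16, false, false, false);
static_assert(matmul2d_descriptor_is_equal(a, b),
              "matmul2d descriptors must match to share destinations");
```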
@@ -2318,7 +2642,8 @@
 
 template <__matmul2d_descriptor descriptor,
           __matmul2d_cooperative_operand_index operand_index, typename scope,
-          typename element_type, typename coord_type, typename... args>
+          typename left_operand, typename right_operand, typename element_type,
+          typename coord_type, typename... args>
 struct __operand_layout
 {
 
@@ -2327,12 +2652,9 @@
                 "only destination can be cooperative tensor");
   static_assert(__tensor_ops_detail::__is_same_v<element_type, float> ||
                     __tensor_ops_detail::__is_same_v<element_type, half> ||
-#if __HAVE_BFLOAT__
-                    __tensor_ops_detail::__is_same_v<element_type, bfloat> ||
-#endif
                     __tensor_ops_detail::__is_same_v<element_type, int32_t>,
                 "cooperative tensor data type can only be one of "
-                "float/half/bfloat/int32_t");
+                "float/half/int32_t");
 
   static constant constexpr __tensor_ops_detail::__rank_t rank = 2;
   using element_t = element_type;
@@ -2342,11 +2664,20 @@
   using const_thread_storage_t = const thread void *;
   using index_t = uint16_t;
   using operand_layout_t =
-      __operand_layout<descriptor, operand_index, scope, element_t, coord_t>;
+      __operand_layout<descriptor, operand_index, scope, left_operand, right_operand, element_t, coord_t>;
   using cooperative_tensor_t =
       metal::cooperative_tensor<element_t, extent_t, operand_layout_t>;
   using scope_t = scope;
 
+  using left_t  = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<left_operand>>;
+  using right_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<right_operand>>;
+
+  using left_elem_t  = typename left_t::element_type;
+  using right_elem_t = typename right_t::element_type;
+
+  using left_value_t  = typename left_t::value_type;
+  using right_value_t = typename right_t::value_type;
+
   static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
                 "scope should be of type __tensorops_scope");
 
@@ -2367,7 +2698,7 @@
   {
     thread element_t *this_e = (thread element_t *)(this_);
     thread element_t *other_e = (thread element_t *)(other);
-    for (size_t i = 0, e = size(this_); i != e; ++i)
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
     {
       other_e[i] = this_e[i];
     }
@@ -2385,7 +2716,7 @@
   {
     thread element_t *this_e = (thread element_t *)(this_);
     thread element_t *other_e = (thread element_t *)(other);
-    for (size_t i = 0, e = size(this_); i != e; ++i)
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
     {
       other_e[i] = this_e[i];
     }
@@ -2405,8 +2736,16 @@
   {
     metal::execution_threads t = scope();
     int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype elemDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
+
     return __tensorops_impl_matmul2d_op_cooperative_destination_data_size(
-        descriptor, threads);
+        descriptor, leftDataType, rightDataType, elemDataType, threads);
   }
 
   template <class ElemType, class Extents, class Descriptor, class... Tags>
@@ -2438,15 +2777,22 @@
 
     const thread void *source = (const thread void *)(&sourceT);
 
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
     if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
     {
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f16(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f16(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
                       "Unsupported address space");
@@ -2455,11 +2801,13 @@
     {
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_i32(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_i32(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
                       "Unsupported address space");
@@ -2468,11 +2816,13 @@
     {
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_dv_f32(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              sourcePtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_load_tg_f32(
-            desc, storage, source, sourceDescType, sourceRank, threads);
+            desc, storage, source, sourceDescType, sourceRank, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
                       "Unsupported address space");
@@ -2510,16 +2860,23 @@
 
     const thread void *destination = (const thread void *)(&destinationT);
 
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
     if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
     {
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
                         destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f16(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f16(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
                       "Unsupported address space");
@@ -2529,11 +2886,13 @@
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
                         destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_i32(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_i32(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
                       "Unsupported address space");
@@ -2543,11 +2902,13 @@
       if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
                         destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_dv_f32(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
                              destinationPtrType>)
         __tensorops_impl_matmul2d_op_cooperative_destination_tensor_store_tg_f32(
-            desc, storage, destination, destinationDescType, threads);
+            desc, storage, destination, destinationDescType, leftDataType,
+            rightDataType, threads);
       else
         static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
                       "Unsupported address space");
@@ -2557,107 +2918,108 @@
                     "Unsupported type");
   };
 
-  static uint16_t size(const_thread_storage_t storage)
+  static uint16_t get_capacity(const_thread_storage_t storage)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
     return __tensorops_impl_matmul2d_op_cooperative_destination_tensor_num_elements(
-        descriptor, threads);
+        descriptor, storage, leftDataType, rightDataType, threads);
   }
 
-  static thread element_t *get_pointer_to(const_thread_storage_t storage,
-                                          index_t idx)
+  static thread element_t *get_element_pointer(const_thread_storage_t storage,
+                                               index_t idx)
   {
-    __tensor_ops_detail::__tensor_ops_datatype dataType;
-    if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
-    else
-      static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
-                    "unsupported data type");
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
 
     return (thread element_t *)
-        __tensorops_impl_matmul2d_op_cooperative_destination_tensor_elements(
-            (thread_storage_t)storage, idx, dataType);
-  };
+        __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_pointer(
+            descriptor, (thread_storage_t)storage, idx, leftDataType,
+            rightDataType, dataType);
+  }
 
-  static bool mask(const_thread_storage_t storage, index_t idx)
+  static index_t get_element_index(const_thread_storage_t storage,
+                                   const thread element_type *element)
   {
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    return (index_t)
+        __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_element_index(
+            descriptor, (thread_storage_t)storage, element, leftDataType,
+            rightDataType, dataType);
+  }
+
+  static bool is_valid_element(const_thread_storage_t storage, index_t idx)
+  {
     metal::execution_threads t = scope();
     int threads = t.size();
-    __tensor_ops_detail::__tensor_ops_datatype dataType;
-    if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
-    else
-      static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
-                    "unsupported data type");
 
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
     return __tensorops_impl_matmul2d_op_cooperative_destination_tensor_is_valid_element(
         descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
-        dataType, threads);
+        leftDataType, rightDataType, dataType, threads);
   }
 
   template <typename index_t, __tensor_ops_detail::__rank_t rank = 2>
   static metal::array<index_t, rank>
-  multidimensional_indices(const_thread_storage_t storage, index_t idx)
+  get_multidimensional_index(const_thread_storage_t storage, index_t idx)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
-    __tensor_ops_detail::__tensor_ops_datatype dataType;
-    if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
-    else
-      static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
-                    "unsupported data type");
 
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
     if constexpr (__tensor_ops_detail::__is_same_v<coord_t, ushort>)
     {
       ushort coords[2];
       __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
-          descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
-          dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
-          threads);
+          descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
+          threads, leftDataType, rightDataType, dataType);
       return {coords[0], coords[1]};
     }
     else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, short>)
     {
       short coords[2];
       __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
-          descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
-          dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
-          threads);
+          descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
+          threads, leftDataType, rightDataType, dataType);
       return {coords[0], coords[1]};
     }
     else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, uint>)
     {
       uint coords[2];
       __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
-          descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
-          dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
-          threads);
+          descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
+          threads, leftDataType, rightDataType, dataType);
       return {coords[0], coords[1]};
     }
@@ -2665,67 +3027,540 @@
     {
       int coords[2];
       __tensorops_impl_matmul2d_op_cooperative_destination_tensor_get_coordinate(
-          descriptor, (__tensor_ops_detail::__thread_void_t)storage, idx,
-          dataType, coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
-          threads);
+          descriptor, (__tensor_ops_detail::__const_thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
+          threads, leftDataType, rightDataType, dataType);
       return {coords[0], coords[1]};
     }
+    else
+    {
+      static_assert(__tensor_ops_detail::__assert_false_v<coord_t>,
+                    "unsupported coordinate data type");
+    }
   }
 
   static void construct(thread_storage_t storage)
   {
     metal::execution_threads t = scope();
     int threads = t.size();
-    __tensor_ops_detail::__tensor_ops_datatype dataType;
-    if constexpr (__tensor_ops_detail::__is_same_v<element_t, float>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float32;
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, half>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_float16;
-#if __HAVE_BFLOAT__
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, bfloat>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_bfloat16;
-#endif
-    else if constexpr (__tensor_ops_detail::__is_same_v<element_t, int32_t>)
-      dataType = __tensor_ops_detail::__tensor_ops_datatype_int32;
-    else
-      static_assert(__tensor_ops_detail::__assert_false_v<element_t>,
-                    "unsupported data type");
 
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype elemDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<element_t>::value;
+
     __tensorops_impl_matmul2d_op_cooperative_destination_tensor_init(
-        (__tensor_ops_detail::__thread_void_t)storage, descriptor, dataType,
-        threads);
+        descriptor, (__tensor_ops_detail::__thread_void_t)storage, leftDataType,
+        rightDataType, elemDataType, threads);
   }
 };
 
 template <__matmul2d_descriptor descriptor,
           __matmul2d_cooperative_operand_index operand_index, typename scope,
-          typename element_type, typename coord_type, typename... args>
+          typename left_operand, typename right_operand, typename element_type, typename coord_type, typename... args>
 using __cooperative_tensor_t =
-    typename __operand_layout<descriptor, operand_index, scope, element_type,
+    typename __operand_layout<descriptor, operand_index, scope, left_operand, right_operand, element_type,
                               coord_type, args...>::cooperative_tensor_t;
 
-template <__matmul2d_descriptor descriptor, typename scope,
+template <__matmul2d_descriptor descriptor, typename scope, typename left_operand, typename right_operand,
           typename element_type, typename coord_type, typename... args>
 using __cooperative_tensor_destination_t =
     __cooperative_tensor_t<descriptor,
                            matmul2d_cooperative_operand_index::destination,
-                           scope, element_type, coord_type, args...>;
+                           scope, left_operand, right_operand, element_type, coord_type, args...>;
 
 template <__matmul2d_descriptor descriptor, typename scope,
-          typename element_type, typename coord_type, typename left_operand,
-          typename right_operand, typename... args>
-__cooperative_tensor_destination_t<descriptor, scope, element_type, coord_type,
+          typename left_operand, typename right_operand, typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_destination_t<descriptor, scope, left_operand, right_operand, element_type, coord_type,
                                    args...>
 __get_destination_cooperative_tensor()
 {
   static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
                 "scope should be of type __tensorops_scope");
-  return __cooperative_tensor_destination_t<descriptor, scope, element_type,
+  return __cooperative_tensor_destination_t<descriptor, scope, left_operand, right_operand, element_type,
                                             coord_type, args...>();
 }
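
// ---- Editor's sketch (not part of the diff) --------------------------------
// A minimal call-site sketch of the signature above: the destination
// cooperative tensor is typed on the left/right operand tensor types as well
// as the element type. The namespace, descriptor arguments, and the names
// `leftT`/`rightT` are assumptions for illustration, not code from this
// header.
//
//   constexpr mpp::tensor_ops::matmul2d_descriptor desc(/*m=*/64, /*n=*/32,
//                                                       /*k=*/16);
//
//   kernel void sketchMatMul(tensor<device half, dextents<int32_t, 2>> leftT,
//                            tensor<device half, dextents<int32_t, 2>> rightT)
//   {
//     mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> op;
//     auto dT = op.get_destination_cooperative_tensor<
//         decltype(leftT), decltype(rightT), float>();
//     op.run(leftT, rightT, dT);
//   }
// -----------------------------------------------------------------------------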
 
+template <__matmul2d_descriptor descriptor, int reduction_dim, typename scope,
+          typename left_operand, typename right_operand,
+          typename element_type, typename coord_type, typename... args>
+struct __reduction_operand_layout
+{
+  static_assert(__tensor_ops_detail::__is_same_v<element_type, float> ||
+                    __tensor_ops_detail::__is_same_v<element_type, half> ||
+                    __tensor_ops_detail::__is_same_v<element_type, int32_t>,
+                "cooperative tensor data type can only be one of "
+                "float/half/int32_t");
+
+  static constant constexpr __tensor_ops_detail::__rank_t rank = 1;
+  using element_t = element_type;
+  using coord_t = coord_type;
+  using extent_t = metal::dextents<coord_t, rank>;
+  using thread_storage_t = thread void *;
+  using const_thread_storage_t = const thread void *;
+  using index_t = uint16_t;
+  using operand_layout_t =
+      __reduction_operand_layout<descriptor, reduction_dim, scope, left_operand,
+                                 right_operand, element_t, coord_t>;
+  using cooperative_tensor_t =
+      metal::cooperative_tensor<element_t, extent_t, operand_layout_t>;
+  using scope_t = scope;
+
+  using left_t  = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<left_operand>>;
+  using right_t = __tensor_ops_detail::__remove_addrspace_t<__tensor_ops_detail::__remove_reference_t<right_operand>>;
+
+  using left_elem_t  = typename left_t::element_type;
+  using right_elem_t = typename right_t::element_type;
+
+  using left_value_t  = typename left_t::value_type;
+  using right_value_t = typename right_t::value_type;
+
+  static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+                "scope should be of type __tensorops_scope");
+  static_assert(reduction_dim == 0 || reduction_dim == 1, "Reduction dimension must be 0 or 1");
+
+  static constexpr constant bool is_matmul2d_reduction_cooperative_destination_layout =
+      true;
+  static constexpr constant int __reduction_dim = reduction_dim;
+
+  static constexpr constant __matmul2d_descriptor matmul2d_desc = descriptor;
+
+  // Returns the alignment of the storage allocated in each thread
+  // for this cooperative_tensor.
+  static constexpr size_t thread_storage_align()
+  {
+    return alignof(element_t);
+  };
+
+  // Copy-constructs from the cooperative_tensor `other`.
+  static void copy_construct(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+  // Move-constructs from the cooperative_tensor `other`.
+  static void move_construct(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+  // Copy-assigns from the cooperative_tensor `other`.
+  static void copy_assign(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+  // Move-assigns from the cooperative_tensor `other`.
+  static void move_assign(thread void *this_, thread void *other)
+  {
+    thread element_t *this_e = (thread element_t *)(this_);
+    thread element_t *other_e = (thread element_t *)(other);
+    for (size_t i = 0, e = get_capacity(this_); i != e; ++i)
+    {
+      this_e[i] = other_e[i];
+    }
+  };
+
+  // Destroys the per-thread object.
+  static void destroy(thread void *) {};
+
+  static size_t thread_storage_size()
+  {
+    metal::execution_threads t = scope();
+    int threads = t.size();
+    
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_data_size(
+        descriptor, reduction_dim, leftDataType, rightDataType, elementDataType, threads);
+  }
+
+  template <class ElemType, class Extents, class Descriptor, class... Tags>
+  static void load(thread_storage_t storage,
+                   const thread metal::tensor<ElemType, Extents, Descriptor,
+                                              Tags...> &sourceT)
+  {
+    using elem_t = __tensor_ops_detail::__remove_addrspace_t<ElemType>;
+
+    static_assert(__tensor_ops_detail::__is_same_v<elem_t, element_t>,
+                  "Source tensor datatype does not match cooperative tensor");
+    static_assert(Extents::rank() == 1,
+                  "Source tensor must be rank 1");
+
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    __matmul2d_descriptor desc = descriptor;
+
+    using tensorType = metal::tensor<ElemType, Extents, Descriptor, Tags...>;
+
+    using sourcePtrType = typename tensorType::data_handle_type;
+
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type sourceDescType =
+        __tensor_ops_detail::__tensor_type_to_tensor_descriptor_type<
+            tensorType>();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+    const thread void *source = (const thread void *)(&sourceT);
+
+    if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f16(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f16(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+                      "Unsupported address space");
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, int32_t>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_i32(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_i32(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+                      "Unsupported address space");
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, float>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_dv_f32(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             sourcePtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_load_tg_f32(
+            desc, storage, source, sourceDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<sourcePtrType>,
+                      "Unsupported address space");
+    }
+    else
+      static_assert(__tensor_ops_detail::__assert_false_v<elem_t>,
+                    "Unsupported type");
+  };
+
+  template <class ElemType, class Extents, class Descriptor, class... Tags>
+  static void store(const_thread_storage_t storage,
+                    const thread metal::tensor<ElemType, Extents, Descriptor,
+                                               Tags...> &destinationT)
+  {
+    using elem_t = __tensor_ops_detail::__remove_addrspace_t<ElemType>;
+
+    static_assert(__tensor_ops_detail::__is_same_v<elem_t, element_t>,
+                  "Tensor datatype does not match cooperative tensor");
+    static_assert(Extents::rank() == 1,
+                  "Tensor must be rank 1");
+
+    __matmul2d_descriptor desc = descriptor;
+
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    using tensorType = metal::tensor<ElemType, Extents, Descriptor, Tags...>;
+
+    using destinationPtrType = typename tensorType::data_handle_type;
+
+    __tensor_ops_detail::__tensor_ops_tensor_descriptor_type
+        destinationDescType =
+            __tensor_ops_detail::__tensor_type_to_tensor_descriptor_type<
+                tensorType>();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+    const thread void *destination = (const thread void *)(&destinationT);
+
+    if constexpr (__tensor_ops_detail::__is_same_v<elem_t, half>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+                        destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f16(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f16(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+                      "Unsupported address space");
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, int32_t>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+                        destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_i32(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_i32(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+                      "Unsupported address space");
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<elem_t, float>)
+    {
+      if constexpr (__tensor_ops_detail::__is_device_addrspace_v<
+                        destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_dv_f32(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else if constexpr (__tensor_ops_detail::__is_threadgroup_addrspace_v<
+                             destinationPtrType>)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_store_tg_f32(
+            desc, storage, destination, destinationDescType, reduction_dim, threads, leftDataType, rightDataType);
+      else
+        static_assert(__tensor_ops_detail::__assert_false_v<destinationPtrType>,
+                      "Unsupported address space");
+    }
+    else
+      static_assert(__tensor_ops_detail::__assert_false_v<elem_t>,
+                    "Unsupported type");
+  };
+
+  static uint16_t get_capacity(const_thread_storage_t storage)
+  {
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+    return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_num_elements(
+        descriptor, storage, reduction_dim, leftDataType, rightDataType, threads);
+  }
+
+  static thread element_t *get_element_pointer(const_thread_storage_t storage,
+                                               index_t idx)
+  {
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    return (thread element_t *)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_pointer(
+            descriptor, (thread_storage_t)storage, idx, leftDataType, rightDataType, dataType);
+  }
+
+  static index_t get_element_index(const_thread_storage_t storage,
+                                   const thread element_type *element)
+  {
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    return (index_t)
+        __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_element_index(
+            descriptor, (thread_storage_t)storage, element, leftDataType, rightDataType, dataType);
+  }
+
+  static bool is_valid_element(const_thread_storage_t storage, index_t idx)
+  {
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype dataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_is_valid_element(
+        descriptor, (__tensor_ops_detail::__thread_void_t)storage, reduction_dim, idx,
+        leftDataType, rightDataType, dataType, threads);
+  }
+
+  template <typename index_t, __tensor_ops_detail::__rank_t rank = 1>
+  static metal::array<index_t, rank>
+  get_multidimensional_index(const_thread_storage_t storage, index_t idx)
+  {
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_type>();
+
+    if constexpr (__tensor_ops_detail::__is_same_v<coord_t, ushort>)
+    {
+      ushort coords[1];
+      __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+          descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_uint16,
+          threads, leftDataType, rightDataType, elementDataType);
+      return { coords[0] };
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, short>)
+    {
+      short coords[1];
+      __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+          descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_int16,
+          threads, leftDataType, rightDataType, elementDataType);
+      return { coords[0] };
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, uint>)
+    {
+      uint coords[1];
+      __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+          descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_uint32,
+          threads, leftDataType, rightDataType, elementDataType);
+      return { coords[0] };
+    }
+    else if constexpr (__tensor_ops_detail::__is_same_v<coord_t, int>)
+    {
+      int coords[1];
+      __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_get_coordinate(
+          descriptor, reduction_dim, (__tensor_ops_detail::__thread_void_t)storage, idx,
+          coords, __tensor_ops_detail::__tensor_ops_datatype_int32,
+          threads, leftDataType, rightDataType, elementDataType);
+      return { coords[0] };
+    }
+    else
+    {
+      static_assert(__tensor_ops_detail::__assert_false_v<coord_t>,
+                    "unsupported coordinate data type");
+    }
+  }
+
+  static void construct(thread_storage_t storage)
+  {
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    __tensor_ops_detail::__tensor_ops_datatype elementDataType =
+        __tensor_ops_detail::__element_type_to_tensor_ops_datatype<element_t>();
+    __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<right_value_t>::value;
+
+    __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_init(
+        (__tensor_ops_detail::__thread_void_t)storage, descriptor, 
+        reduction_dim, leftDataType, rightDataType, elementDataType, threads);
+  }
+
+  template <class FromIterator, class ToIterator>
+  static uint16_t map_index(const thread void *from_storage, uint16_t from_idx,
+                            const thread void *to_storage)
+  {
+    using sourceLayout = typename FromIterator::layout;
+    using destLayout = typename ToIterator::layout;
+
+    static_assert(sourceLayout::is_matmul2d_cooperative_destination_layout,
+                  "Source must be a matmul2d destination cooperative tensor");
+    static_assert(destLayout::is_matmul2d_reduction_cooperative_destination_layout,
+                  "Destination must be a matmul2d reduction destination cooperative tensor");
+    static_assert(__tensor_ops_detail::__is_same_v<typename sourceLayout::scope_t, metal::execution_simdgroup>,
+                  "map_index requires a single SIMD group");
+    static_assert(__tensor_ops_detail::__is_same_v<typename destLayout::scope_t, metal::execution_simdgroup>,
+                  "map_index requires a single SIMD group");
+
+    metal::execution_threads t = scope();
+    int threads = t.size();
+
+    constexpr __matmul2d_descriptor sourceDesc = sourceLayout::matmul2d_desc;
+    constexpr __matmul2d_descriptor destDesc = destLayout::matmul2d_desc;
+
+    static_assert(reduction_dim == 0 || sourceDesc.n == destDesc.n, "Source and destination must have matching N dimension if reduction_dim = 1");
+    static_assert(reduction_dim == 1 || sourceDesc.m == destDesc.m, "Source and destination must have matching M dimension if reduction_dim = 0");
+    
+    static_assert(__tensor_ops_detail::__is_same_v<typename sourceLayout::element_t, typename destLayout::element_t>, "Source and destination element types must match");
+
+    __tensor_ops_detail::__tensor_ops_datatype srcLeftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename sourceLayout::left_value_t>::value;
+    __tensor_ops_detail::__tensor_ops_datatype srcRightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename sourceLayout::right_value_t>::value;
+
+    return __tensorops_impl_matmul2d_op_cooperative_reduction_destination_tensor_map_index(
+        from_storage, sourceDesc,
+        to_storage, destDesc,
+        reduction_dim, threads, from_idx,
+        srcLeftDataType, srcRightDataType);
+  }
+};
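
// ---- Editor's sketch (not part of the diff) --------------------------------
// The layout above describes a rank-1 cooperative tensor whose extent is the
// dimension preserved by `reduction_dim`: dim 0 keeps M (each destination row
// reduced to one value), dim 1 keeps N (each column reduced to one value), as
// the descriptor checks in `map_index` require. Spelled out with the alias
// defined just below, where `desc`, `LeftT`, and `RightT` are placeholders:
//
//   using row_reduction_t =
//       __mutmul2d_detail::__cooperative_tensor_row_reduction_destination_t<
//           desc, metal::execution_simdgroup, LeftT, RightT, float, int>;
// -----------------------------------------------------------------------------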
+
 template <__matmul2d_descriptor descriptor, typename scope,
           typename left_operand, typename right_operand,
+          typename element_type, typename coord_type, typename... args>
+using __cooperative_tensor_row_reduction_destination_t =
+    typename __reduction_operand_layout<descriptor, 0, scope, left_operand, right_operand,
+                              element_type, coord_type, args...>::cooperative_tensor_t;
+
+template <__matmul2d_descriptor descriptor, typename scope,
+          typename left_operand, typename right_operand,
+          typename element_type, typename coord_type, typename... args>
+using __cooperative_tensor_column_reduction_destination_t =
+    typename __reduction_operand_layout<descriptor, 1, scope, left_operand, right_operand,
+                              element_type, coord_type, args...>::cooperative_tensor_t;
+
+template <__matmul2d_descriptor descriptor, typename scope,
+          typename left_operand, typename right_operand,
+          typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_row_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+                                                 element_type, coord_type, args...>
+__get_row_reduction_destination_cooperative_tensor()
+{
+  static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+                "scope should be of type __tensorops_scope");
+  return __cooperative_tensor_row_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+                                                          element_type, coord_type, args...>();
+}
+
+template <__matmul2d_descriptor descriptor, typename scope,
+          typename left_operand, typename right_operand,
+          typename element_type, typename coord_type, typename... args>
+__cooperative_tensor_column_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+                                                    element_type, coord_type, args...>
+__get_column_reduction_destination_cooperative_tensor()
+{
+  static_assert(__tensor_ops_detail::__is_tensorops_execution_scope_v<scope>,
+                "scope should be of type __tensorops_scope");
+  return __cooperative_tensor_column_reduction_destination_t<descriptor, scope, left_operand, right_operand,
+                                                             element_type, coord_type, args...>();
+}
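
// ---- Editor's sketch (not part of the diff) --------------------------------
// A plausible end-to-end flow for a row reduction, under stated assumptions:
// the public spellings `get_row_reduction_destination_cooperative_tensor` and
// `reduce_rows` are inferred from the internal helpers above and the
// `__reduce_rows` implementation below; `op`, `leftT`, `rightT`, and
// `rowSumsT` (a rank-1 device tensor of extent M) are illustrative.
//
//   auto dT = op.get_destination_cooperative_tensor<
//       decltype(leftT), decltype(rightT), float>();
//   op.run(leftT, rightT, dT);
//
//   auto rowsT = op.get_row_reduction_destination_cooperative_tensor<
//       decltype(leftT), decltype(rightT), float>();
//   reduce_rows(dT, rowsT, /*identity=*/0.0f);  // reduction_operation::sum
//   rowsT.store(rowSumsT);
// -----------------------------------------------------------------------------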
+
+template <__matmul2d_descriptor descriptor, typename scope,
+          typename left_operand, typename right_operand,
           typename destination_operand, typename... args>
 void __run(thread left_operand &leftIn, thread right_operand &rightIn,
            thread destination_operand &destinationT)
@@ -4566,7 +5401,7 @@
     }
     else
     {
-      thread void *destination = (thread void *)&destinationT[0];
+      thread void *destination = (thread void *)&destinationT[__tensor_ops_detail::__tensor_ops_reserved_index];
 
       if constexpr (__tensor_ops_detail::__is_same_v<leftValueType, half> &&
                     __tensor_ops_detail::__is_same_v<rightValueType, half> &&
@@ -5054,72 +5889,159 @@
   }
 }
 
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
 inline void __reduce_rows(
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &sourceT,
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &destT,
+    thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &sourceT,
+    thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destT,
     ElementType identity = (ElementType)0,
     __reduction_operation op = reduction_operation::sum)
 {
-  static_assert(Layout::is_matmul2d_cooperative_destination_layout,
-                "Source and destination must be matmul2d cooperative "
-                "destination tensors");
-  static_assert(__tensor_ops_detail::__is_same_v<typename Layout::scope_t,
+  static_assert(SrcLayout::is_matmul2d_cooperative_destination_layout,
+                "Source must be a matmul2d cooperative destination tensor");
+  static_assert(DstLayout::is_matmul2d_reduction_cooperative_destination_layout,
+                "Destination must be a matmul2d row reduction cooperative destination tensor");
+  static_assert(DstLayout::__reduction_dim == 0,
+                "Destination must be a matmul2d row reduction cooperative destination tensor");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t,
                                                  metal::execution_simdgroup>,
                 "reduce_rows requires a single SIMD group");
-  static_assert(Extents::rank() == 2, "Rank must be 2");
+  static_assert(__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t,
+                                                 metal::execution_simdgroup>,
+                "reduce_rows requires a single SIMD group");
+  static_assert(SrcExtents::rank() == 2, "Source rank must be 2");
+  static_assert(DstExtents::rank() == 1, "Destination rank must be 1");
 
-  thread void *src = (thread void *)&sourceT[0];
-  thread void *dst = (thread void *)&destT[0];
+  constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+  constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
 
-  __matmul2d_descriptor desc = Layout::matmul2d_desc;
+  static_assert(matmul2d_descriptor_is_equal(sourceDesc, destDesc),
+                "Source and destination matmul2d descriptors must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::left_t, typename DstLayout::left_t>,
+                "Source and destination left operand types must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::right_t, typename DstLayout::right_t>,
+                "Source and destination right operand types must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::element_t, typename DstLayout::element_t>,
+                "Source and destination element types must match");
 
+  __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+
+  thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+  thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+  __matmul2d_descriptor desc = SrcLayout::matmul2d_desc;
+
   if constexpr (__tensor_ops_detail::__is_same_v<ElementType, half>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f16(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, int32_t>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_i32(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, float>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_rows_f32(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else
     static_assert(__tensor_ops_detail::__assert_false_v<ElementType>,
                   "Unsupported type");
 }
 
-template <class ElementType, class Extents, class Layout>
+template <class ElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
 inline void __reduce_columns(
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &sourceT,
-    thread metal::cooperative_tensor<ElementType, Extents, Layout> &destT,
+    thread metal::cooperative_tensor<ElementType, SrcExtents, SrcLayout> &sourceT,
+    thread metal::cooperative_tensor<ElementType, DstExtents, DstLayout> &destT,
     ElementType identity = (ElementType)0,
     __reduction_operation op = reduction_operation::sum)
 {
-  static_assert(Layout::is_matmul2d_cooperative_destination_layout,
-                "Source and destination must be matmul2d cooperative "
-                "destination tensors");
-  static_assert(__tensor_ops_detail::__is_same_v<typename Layout::scope_t,
+  static_assert(SrcLayout::is_matmul2d_cooperative_destination_layout,
+                "Source must be a matmul2d cooperative destination tensor");
+  static_assert(DstLayout::is_matmul2d_reduction_cooperative_destination_layout,
+                "Destination must be a matmul2d column reduction cooperative destination tensor");
+  static_assert(DstLayout::__reduction_dim == 1,
+                "Destination must be a matmul2d column reduction cooperative destination tensor");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t,
                                                  metal::execution_simdgroup>,
                 "reduce_columns requires a single SIMD group");
-  static_assert(Extents::rank() == 2, "Rank must be 2");
+  static_assert(__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t,
+                                                 metal::execution_simdgroup>,
+                "reduce_columns requires a single SIMD group");
+  static_assert(SrcExtents::rank() == 2, "Source rank must be 2");
+  static_assert(DstExtents::rank() == 1, "Destination rank must be 1");
 
-  thread void *src = (thread void *)&sourceT[0];
-  thread void *dst = (thread void *)&destT[0];
+  constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+  constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
 
-  __matmul2d_descriptor desc = Layout::matmul2d_desc;
+  static_assert(matmul2d_descriptor_is_equal(sourceDesc, destDesc),
+                "Source and destination matmul2d descriptors must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::left_t, typename DstLayout::left_t>,
+                "Source and destination left operand types must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::right_t, typename DstLayout::right_t>,
+                "Source and destination right operand types must match");
+  static_assert(__tensor_ops_detail::__is_same_v<typename SrcLayout::element_t, typename DstLayout::element_t>,
+                "Source and destination element types must match");
 
+  __tensor_ops_detail::__tensor_ops_datatype leftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype rightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+
+  thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+  thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+  __matmul2d_descriptor desc = SrcLayout::matmul2d_desc;
+
   if constexpr (__tensor_ops_detail::__is_same_v<ElementType, half>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f16(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, int32_t>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_i32(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else if constexpr (__tensor_ops_detail::__is_same_v<ElementType, float>)
     __tensorops_impl_matmul2d_op_cooperative_destination_reduce_columns_f32(
-        desc, src, dst, identity, op);
+        desc, src, dst, identity, op, leftDataType, rightDataType);
   else
     static_assert(__tensor_ops_detail::__assert_false_v<ElementType>,
                   "Unsupported type");
+}
+
+template <class SrcElementType, class DstElementType, class SrcExtents, class DstExtents, class SrcLayout, class DstLayout>
+inline bool __is_iterator_compatible(
+    thread metal::cooperative_tensor<SrcElementType, SrcExtents, SrcLayout> &sourceT,
+    thread metal::cooperative_tensor<DstElementType, DstExtents, DstLayout> &destT)
+{
+  if (!SrcLayout::is_matmul2d_cooperative_destination_layout ||
+      !DstLayout::is_matmul2d_reduction_cooperative_destination_layout ||
+      !__tensor_ops_detail::__is_same_v<typename SrcLayout::scope_t, metal::execution_simdgroup> ||
+      !__tensor_ops_detail::__is_same_v<typename DstLayout::scope_t, metal::execution_simdgroup> ||
+      !__tensor_ops_detail::__is_same_v<SrcElementType, DstElementType> ||
+      SrcExtents::rank() != 2 || DstExtents::rank() != 1)
+  {
+    return false;
+  }
+
+  constexpr __matmul2d_descriptor sourceDesc = SrcLayout::matmul2d_desc;
+  constexpr __matmul2d_descriptor destDesc = DstLayout::matmul2d_desc;
+
+  constexpr int reduction_dim = DstLayout::__reduction_dim;
+
+  if ((reduction_dim == 0 && sourceDesc.m != destDesc.m) ||
+      (reduction_dim == 1 && sourceDesc.n != destDesc.n))
+  {
+    return false;
+  }
+
+  thread void *src = (thread void *)&sourceT[__tensor_ops_detail::__tensor_ops_reserved_index];
+  thread void *dst = (thread void *)&destT[__tensor_ops_detail::__tensor_ops_reserved_index];
+
+  __tensor_ops_detail::__tensor_ops_datatype srcLeftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::left_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype srcRightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename SrcLayout::right_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype srcElemDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<SrcElementType>::value;
+  __tensor_ops_detail::__tensor_ops_datatype dstLeftDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename DstLayout::left_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype dstRightDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<typename DstLayout::right_value_t>::value;
+  __tensor_ops_detail::__tensor_ops_datatype dstElemDataType =
+        __tensor_ops_detail::__type_to_tensor_ops_datatype<DstElementType>::value;
+
+  return __tensorops_impl_matmul2d_op_cooperative_destination_is_iterator_compatible(
+      sourceDesc, destDesc, src, dst, srcLeftDataType, srcRightDataType,
+      srcElemDataType, dstLeftDataType, dstRightDataType, dstElemDataType);
 }
 
 #undef EXTERNALLY_DEFINED_ATTR
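
// ---- Editor's sketch (not part of the diff) --------------------------------
// `__is_iterator_compatible` rejects statically mismatched pairs (layout
// kind, scope, rank, element type, preserved extent) and then defers to the
// runtime implementation, so callers can branch on a single boolean before
// mapping indices between a 2D destination and its rank-1 reduction tensor.
// `dT` and `rowsT` are the illustrative tensors from the sketch above.
//
//   if (__mutmul2d_detail::__is_iterator_compatible(dT, rowsT))
//   {
//     // ...traverse dT and accumulate into the mapped slots of rowsT...
//   }
// -----------------------------------------------------------------------------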
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h	2025-07-11 22:52:25
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsTypes.h	2025-07-26 22:15:31
@@ -23,10 +23,7 @@
 template <typename T>
 constexpr inline int __get_rank()
 {
-  if constexpr (__is_cooperative_tensor_type_v<T>)
-    return T::rank();
-  else // tensor
-    return T::get_rank();
+  return T::get_rank();
 }
 
 using __rank_t = ushort;
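
// ---- Editor's sketch (not part of the diff) --------------------------------
// `__get_rank` now assumes every tensor-like type, cooperative tensors
// included, exposes a static `get_rank()`, so the trait is a single
// expression. Illustrative only:
//
//   struct example_rank2 { static constexpr int get_rank() { return 2; } };
//   static_assert(__tensor_ops_detail::__get_rank<example_rank2>() == 2,
//                 "rank query via the unified trait");
// -----------------------------------------------------------------------------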
diff -ruN /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h
--- /Applications/Xcode_26.0.0-beta4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h	2025-07-11 23:30:03
+++ /Applications/Xcode_26.0.0-beta5.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/MetalPerformancePrimitives.framework/Headers/__impl/MPPTensorOpsUtility.h	2025-07-26 06:00:32
@@ -33,7 +33,6 @@
 template <typename T>
 struct __type_to_tensor_ops_datatype
 {
-  static constant __tensor_ops_datatype value = __tensor_ops_datatype_invalid;
 };
 template <>
 struct __type_to_tensor_ops_datatype<float>
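
// ---- Editor's sketch (not part of the diff) --------------------------------
// With no default `value` in the primary template, mapping an unsupported
// type is now a compile-time error at the point of use rather than a runtime
// `__tensor_ops_datatype_invalid` sentinel. Illustrative only:
//
//   static_assert(__type_to_tensor_ops_datatype<float>::value ==
//                     __tensor_ops_datatype_float32,
//                 "specialized types still map to their datatype");
//   // __type_to_tensor_ops_datatype<char>::value  // ill-formed: no 'value'
// -----------------------------------------------------------------------------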