@@ -586,15 +586,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
586
586
RedOutVar RedOut)
587
587
: base(Identity, BinaryOp, Init), MRedOut(std::move(RedOut)){};
588
588
589
- // / Associates the reduction accessor to user's memory with \p CGH handler
590
- // / to keep the accessor alive until the command group finishes the work.
591
- // / This function does not do anything for USM reductions.
592
- void associateWithHandler (handler &CGH) {
593
- if constexpr (is_acc) {
594
- CGH.associateWithHandler (&MRedOut, access::target::device);
595
- }
596
- }
597
-
598
589
// / Creates and returns a local accessor with the \p Size elements.
599
590
// / By default the local accessor elements are of the same type as the
600
591
// / elements processed by the reduction, but may it be altered by specifying
@@ -632,7 +623,7 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
632
623
rw_accessor_type getWriteAccForPartialReds (size_t Size, handler &CGH) {
633
624
if constexpr (is_rw_acc) {
634
625
if (Size == 1 ) {
635
- associateWithHandler (CGH );
626
+ CGH. associateWithHandler (&MRedOut, access::target::device );
636
627
return MRedOut;
637
628
}
638
629
}
@@ -808,7 +799,7 @@ class reduction_impl
808
799
reduction_impl (RedOutVar &Acc, handler &CGH, bool InitializeToIdentity)
809
800
: algo(reducer_type::getIdentity(), BinaryOperation(),
810
801
InitializeToIdentity, Acc) {
811
- algo:: associateWithHandler (CGH);
802
+ associateWithHandler (CGH, &Acc, access::target::device );
812
803
if (Acc.size () != 1 )
813
804
throw sycl::runtime_error (errc::invalid,
814
805
" Reduction variable must be a scalar." ,
@@ -838,7 +829,7 @@ class reduction_impl
838
829
reduction_impl (RedOutVar &Acc, handler &CGH, const T &Identity,
839
830
BinaryOperation BOp, bool InitializeToIdentity)
840
831
: algo(chooseIdentity(Identity), BOp, InitializeToIdentity, Acc) {
841
- algo:: associateWithHandler (CGH);
832
+ associateWithHandler (CGH, &Acc, access::target::device );
842
833
if (Acc.size () != 1 )
843
834
throw sycl::runtime_error (errc::invalid,
844
835
" Reduction variable must be a scalar." ,
@@ -1561,7 +1552,7 @@ template <typename KernelName, class Reduction>
1561
1552
std::enable_if_t <!Reduction::is_usm>
1562
1553
reduSaveFinalResultToUserMem (handler &CGH, Reduction &Redu) {
1563
1554
auto InAcc = Redu.getReadAccToPreviousPartialReds (CGH);
1564
- Redu. associateWithHandler (CGH);
1555
+ associateWithHandler (CGH, &Redu. getUserRedVar (), access::target::device );
1565
1556
CGH.copy (InAcc, Redu.getUserRedVar ());
1566
1557
}
1567
1558
@@ -2089,26 +2080,16 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
2089
2080
CGH, KernelFunc, Range, Redu, Out);
2090
2081
}
2091
2082
2092
- inline void associateReduAccsWithHandlerHelper (handler &) {}
2093
-
2094
- template <typename ReductionT>
2095
- void associateReduAccsWithHandlerHelper (handler &CGH, ReductionT &Redu) {
2096
- Redu.associateWithHandler (CGH);
2097
- }
2098
-
2099
- template <typename ReductionT, typename ... RestT,
2100
- enable_if_t <(sizeof ...(RestT) > 0 ), int > Z = 0 >
2101
- void associateReduAccsWithHandlerHelper (handler &CGH, ReductionT &Redu,
2102
- RestT &...Rest) {
2103
- Redu.associateWithHandler (CGH);
2104
- associateReduAccsWithHandlerHelper (CGH, Rest...);
2105
- }
2106
-
2107
2083
template <typename ... Reductions, size_t ... Is>
2108
2084
void associateReduAccsWithHandler (handler &CGH,
2109
2085
std::tuple<Reductions...> &ReduTuple,
2110
2086
std::index_sequence<Is...>) {
2111
- associateReduAccsWithHandlerHelper (CGH, std::get<Is>(ReduTuple)...);
2087
+ auto ProcessOne = [&CGH](auto Redu) {
2088
+ if constexpr (decltype (Redu)::is_acc) {
2089
+ associateWithHandler (CGH, &Redu.getUserRedVar (), access::target::device);
2090
+ }
2091
+ };
2092
+ (ProcessOne (std::get<Is>(ReduTuple)), ...);
2112
2093
}
2113
2094
2114
2095
// / All scalar reductions are processed together; there is one loop of log2(N)
@@ -2378,7 +2359,8 @@ void reduSaveFinalResultToUserMemHelper(
2378
2359
event CopyEvent = withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
2379
2360
auto InAcc = Redu.getReadAccToPreviousPartialReds (CopyHandler);
2380
2361
auto OutAcc = Redu.getUserRedVar ();
2381
- Redu.associateWithHandler (CopyHandler);
2362
+ associateWithHandler (CopyHandler, &Redu.getUserRedVar (),
2363
+ access::target::device);
2382
2364
if (!Events.empty ())
2383
2365
CopyHandler.depends_on (Events.back ());
2384
2366
CopyHandler.copy (InAcc, OutAcc);
0 commit comments