Skip to content

Commit 8967bf3

Browse files
Update associateWithHandler uses
1 parent cf039fe commit 8967bf3

File tree

1 file changed

+12
-30
lines changed

1 file changed

+12
-30
lines changed

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -586,15 +586,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
586586
RedOutVar RedOut)
587587
: base(Identity, BinaryOp, Init), MRedOut(std::move(RedOut)){};
588588

589-
/// Associates the reduction accessor to user's memory with \p CGH handler
590-
/// to keep the accessor alive until the command group finishes the work.
591-
/// This function does not do anything for USM reductions.
592-
void associateWithHandler(handler &CGH) {
593-
if constexpr (is_acc) {
594-
CGH.associateWithHandler(&MRedOut, access::target::device);
595-
}
596-
}
597-
598589
/// Creates and returns a local accessor with the \p Size elements.
599590
/// By default the local accessor elements are of the same type as the
600591
/// elements processed by the reduction, but may it be altered by specifying
@@ -632,7 +623,7 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
632623
rw_accessor_type getWriteAccForPartialReds(size_t Size, handler &CGH) {
633624
if constexpr (is_rw_acc) {
634625
if (Size == 1) {
635-
associateWithHandler(CGH);
626+
CGH.associateWithHandler(&MRedOut, access::target::device);
636627
return MRedOut;
637628
}
638629
}
@@ -808,7 +799,7 @@ class reduction_impl
808799
reduction_impl(RedOutVar &Acc, handler &CGH, bool InitializeToIdentity)
809800
: algo(reducer_type::getIdentity(), BinaryOperation(),
810801
InitializeToIdentity, Acc) {
811-
algo::associateWithHandler(CGH);
802+
associateWithHandler(CGH, &Acc, access::target::device);
812803
if (Acc.size() != 1)
813804
throw sycl::runtime_error(errc::invalid,
814805
"Reduction variable must be a scalar.",
@@ -838,7 +829,7 @@ class reduction_impl
838829
reduction_impl(RedOutVar &Acc, handler &CGH, const T &Identity,
839830
BinaryOperation BOp, bool InitializeToIdentity)
840831
: algo(chooseIdentity(Identity), BOp, InitializeToIdentity, Acc) {
841-
algo::associateWithHandler(CGH);
832+
associateWithHandler(CGH, &Acc, access::target::device);
842833
if (Acc.size() != 1)
843834
throw sycl::runtime_error(errc::invalid,
844835
"Reduction variable must be a scalar.",
@@ -1561,7 +1552,7 @@ template <typename KernelName, class Reduction>
15611552
std::enable_if_t<!Reduction::is_usm>
15621553
reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
15631554
auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH);
1564-
Redu.associateWithHandler(CGH);
1555+
associateWithHandler(CGH, &Redu.getUserRedVar(), access::target::device);
15651556
CGH.copy(InAcc, Redu.getUserRedVar());
15661557
}
15671558

@@ -2089,26 +2080,16 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
20892080
CGH, KernelFunc, Range, Redu, Out);
20902081
}
20912082

2092-
inline void associateReduAccsWithHandlerHelper(handler &) {}
2093-
2094-
template <typename ReductionT>
2095-
void associateReduAccsWithHandlerHelper(handler &CGH, ReductionT &Redu) {
2096-
Redu.associateWithHandler(CGH);
2097-
}
2098-
2099-
template <typename ReductionT, typename... RestT,
2100-
enable_if_t<(sizeof...(RestT) > 0), int> Z = 0>
2101-
void associateReduAccsWithHandlerHelper(handler &CGH, ReductionT &Redu,
2102-
RestT &...Rest) {
2103-
Redu.associateWithHandler(CGH);
2104-
associateReduAccsWithHandlerHelper(CGH, Rest...);
2105-
}
2106-
21072083
template <typename... Reductions, size_t... Is>
21082084
void associateReduAccsWithHandler(handler &CGH,
21092085
std::tuple<Reductions...> &ReduTuple,
21102086
std::index_sequence<Is...>) {
2111-
associateReduAccsWithHandlerHelper(CGH, std::get<Is>(ReduTuple)...);
2087+
auto ProcessOne = [&CGH](auto Redu) {
2088+
if constexpr (decltype(Redu)::is_acc) {
2089+
associateWithHandler(CGH, &Redu.getUserRedVar(), access::target::device);
2090+
}
2091+
};
2092+
(ProcessOne(std::get<Is>(ReduTuple)), ...);
21122093
}
21132094

21142095
/// All scalar reductions are processed together; there is one loop of log2(N)
@@ -2378,7 +2359,8 @@ void reduSaveFinalResultToUserMemHelper(
23782359
event CopyEvent = withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
23792360
auto InAcc = Redu.getReadAccToPreviousPartialReds(CopyHandler);
23802361
auto OutAcc = Redu.getUserRedVar();
2381-
Redu.associateWithHandler(CopyHandler);
2362+
associateWithHandler(CopyHandler, &Redu.getUserRedVar(),
2363+
access::target::device);
23822364
if (!Events.empty())
23832365
CopyHandler.depends_on(Events.back());
23842366
CopyHandler.copy(InAcc, OutAcc);

0 commit comments

Comments
 (0)