Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit d336fdd

Browse files
committed
Fix regression in transform_inclusive_scan.
We were deducing a reference type when we needed a value type for the transform iterator instantition. Fixes #1332 and adds a regression test.
1 parent 84219dc commit d336fdd

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

testing/transform_scan.cu

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,55 @@ struct TestTransformScanToDiscardIterator
347347
};
348348
VariableUnitTest<TestTransformScanToDiscardIterator, IntegralTypes> TestTransformScanToDiscardIteratorInstance;
349349

350+
// Regression test for https://github.com/NVIDIA/thrust/issues/1332
351+
// The issue was the internal transform_input_iterator_t created by the
352+
// transform_inclusive_scan implementation was instantiated using a reference
353+
// type for the value_type.
354+
template <typename T>
355+
void TestValueCategoryDeduction()
356+
{
357+
thrust::device_vector<T> vec;
358+
359+
T a_h[10] = {5, 0, 5, 8, 6, 7, 5, 3, 0, 9};
360+
vec.assign((T*)a_h, a_h + 10);
361+
362+
363+
thrust::transform_inclusive_scan(thrust::device,
364+
vec.cbegin(),
365+
vec.cend(),
366+
vec.begin(),
367+
thrust::identity<>{},
368+
thrust::maximum<>{});
369+
370+
ASSERT_EQUAL(T{5}, vec[0]);
371+
ASSERT_EQUAL(T{5}, vec[1]);
372+
ASSERT_EQUAL(T{5}, vec[2]);
373+
ASSERT_EQUAL(T{8}, vec[3]);
374+
ASSERT_EQUAL(T{8}, vec[4]);
375+
ASSERT_EQUAL(T{8}, vec[5]);
376+
ASSERT_EQUAL(T{8}, vec[6]);
377+
ASSERT_EQUAL(T{8}, vec[7]);
378+
ASSERT_EQUAL(T{8}, vec[8]);
379+
ASSERT_EQUAL(T{9}, vec[9]);
380+
381+
vec.assign((T*)a_h, a_h + 10);
382+
thrust::transform_exclusive_scan(thrust::device,
383+
vec.cbegin(),
384+
vec.cend(),
385+
vec.begin(),
386+
thrust::identity<>{},
387+
T{},
388+
thrust::maximum<>{});
389+
390+
ASSERT_EQUAL(T{0}, vec[0]);
391+
ASSERT_EQUAL(T{5}, vec[1]);
392+
ASSERT_EQUAL(T{5}, vec[2]);
393+
ASSERT_EQUAL(T{5}, vec[3]);
394+
ASSERT_EQUAL(T{8}, vec[4]);
395+
ASSERT_EQUAL(T{8}, vec[5]);
396+
ASSERT_EQUAL(T{8}, vec[6]);
397+
ASSERT_EQUAL(T{8}, vec[7]);
398+
ASSERT_EQUAL(T{8}, vec[8]);
399+
ASSERT_EQUAL(T{8}, vec[9]);
400+
}
401+
DECLARE_GENERIC_UNITTEST(TestValueCategoryDeduction);

thrust/system/cuda/detail/transform_scan.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,19 @@ transform_inclusive_scan(execution_policy<Derived> &policy,
5050
TransformOp transform_op,
5151
ScanOp scan_op)
5252
{
53-
// Use the input iterator's value type per https://wg21.link/P0571
53+
// Use the transformed input iterator's value type per https://wg21.link/P0571
5454
using input_type = typename thrust::iterator_value<InputIt>::type;
5555
#if THRUST_CPP_DIALECT < 2017
5656
using result_type = typename std::result_of<TransformOp(input_type)>::type;
5757
#else
5858
using result_type = std::invoke_result_t<TransformOp, input_type>;
5959
#endif
6060

61+
using value_type = typename std::remove_reference<result_type>::type;
62+
6163
typedef typename iterator_traits<InputIt>::difference_type size_type;
6264
size_type num_items = static_cast<size_type>(thrust::distance(first, last));
63-
typedef transform_input_iterator_t<result_type,
65+
typedef transform_input_iterator_t<value_type,
6466
InputIt,
6567
TransformOp>
6668
transformed_iterator_t;
@@ -88,7 +90,7 @@ transform_exclusive_scan(execution_policy<Derived> &policy,
8890
ScanOp scan_op)
8991
{
9092
// Use the initial value type per https://wg21.link/P0571
91-
using result_type = InitialValueType;
93+
using result_type = typename std::remove_reference<InitialValueType>::type;
9294

9395
typedef typename iterator_traits<InputIt>::difference_type size_type;
9496
size_type num_items = static_cast<size_type>(thrust::distance(first, last));

thrust/system/detail/generic/transform_scan.inl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ __host__ __device__
5151
// Use the input iterator's value type per https://wg21.link/P0571
5252
using InputType = typename thrust::iterator_value<InputIterator>::type;
5353
#if THRUST_CPP_DIALECT < 2017
54-
using ValueType = typename std::result_of<UnaryFunction(InputType)>::type;
54+
using ResultType = typename std::result_of<UnaryFunction(InputType)>::type;
5555
#else
56-
using ValueType = std::invoke_result_t<UnaryFunction, InputType>;
56+
using ResultType = std::invoke_result_t<UnaryFunction, InputType>;
5757
#endif
58+
using ValueType = typename std::remove_reference<ResultType>::type;
5859

5960
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
6061
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);
@@ -79,7 +80,7 @@ __host__ __device__
7980
AssociativeOperator binary_op)
8081
{
8182
// Use the initial value type per https://wg21.link/P0571
82-
using ValueType = InitialValueType;
83+
using ValueType = typename std::remove_reference<InitialValueType>::type;
8384

8485
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
8586
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);

0 commit comments

Comments
 (0)