Skip to content

Commit e452faf

Browse files
committed
Make the single WG implementation without static dispatch
1 parent a4ca202 commit e452faf

File tree

1 file changed

+22
-50
lines changed

1 file changed

+22
-50
lines changed

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

Lines changed: 22 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -483,27 +483,25 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
483483
}
484484
};
485485

486-
template <typename _Size, std::uint16_t _ElemsPerItem, std::uint16_t _WGSize, typename _KernelName>
486+
template <typename _Size, typename _KernelName>
487487
struct __parallel_copy_if_static_single_group_submitter;
488488

489-
template <typename _Size, std::uint16_t _ElemsPerItem, std::uint16_t _WGSize, typename... _ScanKernelName>
490-
struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _WGSize,
491-
__internal::__optional_kernel_name<_ScanKernelName...>>
489+
template <typename _Size, typename... _ScanKernelName>
490+
struct __parallel_copy_if_static_single_group_submitter<_Size, __internal::__optional_kernel_name<_ScanKernelName...>>
492491
{
493492
template <typename _InRng, typename _OutRng, typename _UnaryOp, typename _Assign>
494493
__future<sycl::event, __result_and_scratch_storage<_Size>>
495494
operator()(sycl::queue& __q, _InRng&& __in_rng, _OutRng&& __out_rng, std::size_t __n, _UnaryOp __unary_op,
496-
_Assign __assign)
495+
_Assign __assign, std::uint16_t __n_uniform, std::uint16_t __wg_size)
497496
{
498-
using _ValueType = ::std::uint16_t;
497+
using _ValueType = std::uint16_t;
499498

500-
// This type is used as a workaround for when an internal tuple is assigned to ::std::tuple, such as
499+
// This type is used as a workaround for when an internal tuple is assigned to std::tuple, such as
501500
// with zip_iterator
502501
using __tuple_type =
503502
typename ::oneapi::dpl::__internal::__get_tuple_type<std::decay_t<decltype(__in_rng[0])>,
504503
std::decay_t<decltype(__out_rng[0])>>::__type;
505504

506-
constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;
507505
using __result_and_scratch_storage_t = __result_and_scratch_storage<_Size>;
508506
__result_and_scratch_storage_t __result{__q, 0};
509507

@@ -513,37 +511,37 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
513511
// Local memory is split into two parts. The first half stores the result of applying the
514512
// predicate on each element of the input range. The second half stores the index of the output
515513
// range to copy elements of the input range.
516-
auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg * 2}, __hdl);
514+
auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>(std::size_t(__n_uniform) * 2), __hdl);
517515
auto __res_acc =
518516
__result.template __get_result_acc<sycl::access_mode::write>(__hdl, __dpl_sycl::__no_init{});
519517

520518
__hdl.parallel_for<_ScanKernelName...>(
521-
sycl::nd_range<1>(_WGSize, _WGSize), [=](sycl::nd_item<1> __self_item) {
519+
sycl::nd_range<1>(__wg_size, __wg_size), [=](sycl::nd_item<1> __self_item) {
522520
auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__res_acc);
523521
const auto& __group = __self_item.get_group();
524522
// This kernel is only launched for sizes less than 2^16
525-
const ::std::uint16_t __item_id = __self_item.get_local_linear_id();
523+
const std::uint16_t __item_id = __self_item.get_local_linear_id();
526524
auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
527-
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
525+
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += __wg_size)
528526
{
529527
__lacc[__idx] = __unary_op(__in_rng[__idx]);
530528
}
531529

532530
__scan_work_group<_ValueType, /* _Inclusive */ false>(
533-
__group, __lacc_ptr, __lacc_ptr + __elems_per_wg, __lacc_ptr + __elems_per_wg,
531+
__group, __lacc_ptr, __lacc_ptr + __n, __lacc_ptr + __n_uniform,
534532
sycl::plus<_ValueType>{});
535533

536-
for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
534+
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += __wg_size)
537535
{
538536
if (__lacc[__idx])
539537
__assign(static_cast<__tuple_type>(__in_rng[__idx]),
540-
__out_rng[__lacc[__idx + __elems_per_wg]]);
538+
__out_rng[__lacc[__idx + __n_uniform]]);
541539
}
542540

543541
if (__item_id == 0)
544542
{
545543
// Add predicate of last element to account for the scan's exclusivity
546-
*__res_ptr = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
544+
*__res_ptr = __lacc[__n_uniform + __n - 1] + __lacc[__n - 1];
547545
}
548546
});
549547
});
@@ -726,30 +724,6 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
726724
unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init});
727725
}
728726

729-
template <typename _CustomName, typename _SizeType>
730-
struct __invoke_single_group_copy_if
731-
{
732-
// Specialization for devices that have a max work-group size of at least 1024
733-
static constexpr ::std::uint16_t __targeted_wg_size = 1024;
734-
735-
template <std::uint16_t _Size, typename _InRng, typename _OutRng, typename _Pred,
736-
typename _Assign = oneapi::dpl::__internal::__pstl_assign>
737-
auto
738-
operator()(sycl::queue& __q, std::size_t __n, _InRng&& __in_rng, _OutRng&& __out_rng, _Pred __pred,
739-
_Assign __assign)
740-
{
741-
constexpr ::std::uint16_t __wg_size = ::std::min(_Size, __targeted_wg_size);
742-
constexpr ::std::uint16_t __num_elems_per_item = ::oneapi::dpl::__internal::__dpl_ceiling_div(_Size, __wg_size);
743-
744-
using _KernelName = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
745-
__scan_copy_single_wg_kernel<std::integral_constant<std::uint16_t, __wg_size>,
746-
std::integral_constant<std::uint16_t, __num_elems_per_item>, _CustomName>>;
747-
return __par_backend_hetero::__parallel_copy_if_static_single_group_submitter<
748-
_SizeType, __num_elems_per_item, __wg_size, _KernelName>()
749-
(__q, std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng), __n, __pred, __assign);
750-
}
751-
};
752-
753727
template <typename _CustomName, typename _InRng, typename _OutRng, typename _Size, typename _GenMask, typename _WriteOp,
754728
typename _IsUniquePattern>
755729
__future<sycl::event, __result_and_scratch_storage<_Size>>
@@ -920,8 +894,6 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
920894
{
921895
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
922896

923-
using _SingleGroupInvoker = __invoke_single_group_copy_if<_CustomName, _Size>;
924-
925897
// Next power of 2 greater than or equal to __n
926898
auto __n_uniform = ::oneapi::dpl::__internal::__dpl_bit_ceil(static_cast<std::make_unsigned_t<_Size>>(__n));
927899

@@ -934,18 +906,18 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
934906
// The kernel stores n integers for the predicate and another n integers for the offsets
935907
const auto __req_slm_size = sizeof(std::uint16_t) * __n_uniform * 2;
936908

937-
constexpr std::uint16_t __single_group_upper_limit = 2048;
909+
// constexpr std::uint16_t __single_group_upper_limit = 2048;
910+
constexpr std::uint16_t __max_elem_per_item = 2;
938911

939912
std::size_t __max_wg_size = oneapi::dpl::__internal::__max_work_group_size(__q_local);
940913

941-
if (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size &&
942-
__max_wg_size >= _SingleGroupInvoker::__targeted_wg_size)
914+
if (__n <= __max_wg_size * __max_elem_per_item && __max_slm_size >= __req_slm_size)
943915
{
944-
using _SizeBreakpoints = std::integer_sequence<std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048>;
945-
946-
return __par_backend_hetero::__static_monotonic_dispatcher<_SizeBreakpoints>::__dispatch(
947-
_SingleGroupInvoker{}, __n, __q_local, __n, std::forward<_InRng>(__in_rng),
948-
std::forward<_OutRng>(__out_rng), __pred, __assign);
916+
using _KernelName = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
917+
__scan_copy_single_wg_kernel<_CustomName>>;
918+
return __par_backend_hetero::__parallel_copy_if_static_single_group_submitter<_Size, _KernelName>()(
919+
__q_local, std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng), __n, __pred, __assign,
920+
static_cast<std::uint16_t>(__n_uniform), static_cast<std::uint16_t>(std::min(__n_uniform, __max_wg_size)));
949921
}
950922
else if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
951923
{

0 commit comments

Comments (0)