@@ -483,27 +483,25 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
     }
 };
 
-template <typename _Size, std::uint16_t _ElemsPerItem, std::uint16_t _WGSize, typename _KernelName>
+template <typename _Size, typename _KernelName>
 struct __parallel_copy_if_static_single_group_submitter;
 
-template <typename _Size, std::uint16_t _ElemsPerItem, std::uint16_t _WGSize, typename... _ScanKernelName>
-struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _WGSize,
-                                                        __internal::__optional_kernel_name<_ScanKernelName...>>
+template <typename _Size, typename... _ScanKernelName>
+struct __parallel_copy_if_static_single_group_submitter<_Size, __internal::__optional_kernel_name<_ScanKernelName...>>
 {
     template <typename _InRng, typename _OutRng, typename _UnaryOp, typename _Assign>
     __future<sycl::event, __result_and_scratch_storage<_Size>>
     operator()(sycl::queue& __q, _InRng&& __in_rng, _OutRng&& __out_rng, std::size_t __n, _UnaryOp __unary_op,
-               _Assign __assign)
+               _Assign __assign, std::uint16_t __n_uniform, std::uint16_t __wg_size)
     {
-        using _ValueType = ::std::uint16_t;
+        using _ValueType = std::uint16_t;
 
-        // This type is used as a workaround for when an internal tuple is assigned to ::std::tuple, such as
+        // This type is used as a workaround for when an internal tuple is assigned to std::tuple, such as
         // with zip_iterator
         using __tuple_type =
             typename ::oneapi::dpl::__internal::__get_tuple_type<std::decay_t<decltype(__in_rng[0])>,
                                                                  std::decay_t<decltype(__out_rng[0])>>::__type;
 
-        constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;
         using __result_and_scratch_storage_t = __result_and_scratch_storage<_Size>;
         __result_and_scratch_storage_t __result{__q, 0};
 
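Note on the change above: the compile-time `_ElemsPerItem`/`_WGSize` template parameters become runtime arguments (`__n_uniform`, `__wg_size`), so one kernel instantiation can serve every size the single-group path accepts. This relies on SYCL 2020 local accessors being sized at submission time. A minimal standalone sketch of that mechanism, with illustrative values rather than oneDPL code:

```cpp
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>

int main()
{
    sycl::queue __q;
    // Runtime values playing the role of __n_uniform and __wg_size.
    const std::uint16_t __n_uniform = 2048, __wg_size = 1024;
    __q.submit([&](sycl::handler& __hdl) {
        // SLM sized at submission time: two halves of __n_uniform elements each,
        // mirroring the flags/offsets layout used by the submitter below.
        sycl::local_accessor<std::uint16_t, 1> __lacc(sycl::range<1>(std::size_t(__n_uniform) * 2), __hdl);
        __hdl.parallel_for(sycl::nd_range<1>(__wg_size, __wg_size), [=](sycl::nd_item<1> __item) {
            __lacc[__item.get_local_linear_id()] = 0; // kernel body elided
        });
    });
    __q.wait();
}
```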
@@ -513,37 +511,37 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
             // Local memory is split into two parts. The first half stores the result of applying the
             // predicate on each element of the input range. The second half stores the index of the output
             // range to copy elements of the input range.
-            auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg * 2}, __hdl);
+            auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>(std::size_t(__n_uniform) * 2), __hdl);
             auto __res_acc =
                 __result.template __get_result_acc<sycl::access_mode::write>(__hdl, __dpl_sycl::__no_init{});
 
             __hdl.parallel_for<_ScanKernelName...>(
-                sycl::nd_range<1>(_WGSize, _WGSize), [=](sycl::nd_item<1> __self_item) {
+                sycl::nd_range<1>(__wg_size, __wg_size), [=](sycl::nd_item<1> __self_item) {
                     auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__res_acc);
                     const auto& __group = __self_item.get_group();
                     // This kernel is only launched for sizes less than 2^16
-                    const ::std::uint16_t __item_id = __self_item.get_local_linear_id();
+                    const std::uint16_t __item_id = __self_item.get_local_linear_id();
                     auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
-                    for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
+                    for (std::uint16_t __idx = __item_id; __idx < __n; __idx += __wg_size)
                     {
                         __lacc[__idx] = __unary_op(__in_rng[__idx]);
                     }
 
                     __scan_work_group<_ValueType, /* _Inclusive */ false>(
-                        __group, __lacc_ptr, __lacc_ptr + __elems_per_wg, __lacc_ptr + __elems_per_wg,
+                        __group, __lacc_ptr, __lacc_ptr + __n, __lacc_ptr + __n_uniform,
                         sycl::plus<_ValueType>{});
 
-                    for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
+                    for (std::uint16_t __idx = __item_id; __idx < __n; __idx += __wg_size)
                     {
                         if (__lacc[__idx])
                             __assign(static_cast<__tuple_type>(__in_rng[__idx]),
-                                     __out_rng[__lacc[__idx + __elems_per_wg]]);
+                                     __out_rng[__lacc[__idx + __n_uniform]]);
                     }
 
                     if (__item_id == 0)
                     {
                         // Add predicate of last element to account for the scan's exclusivity
-                        *__res_ptr = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
+                        *__res_ptr = __lacc[__n_uniform + __n - 1] + __lacc[__n - 1];
                     }
                 });
         });
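For readers following the kernel logic: the first half of `__lacc` holds predicate flags, the second half holds exclusive-scan offsets, and the scatter plus the final `flags[n-1]` correction together reproduce copy_if. A host-side model of the same three steps, a sketch for illustration rather than library code:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Host-side model of the single-work-group kernel: __flags mirrors the first
// half of __lacc, __offsets mirrors the second half.
template <typename _T, typename _Pred>
std::uint16_t
copy_if_model(const std::vector<_T>& __in, std::vector<_T>& __out, _Pred __pred)
{
    const std::uint16_t __n = static_cast<std::uint16_t>(__in.size());
    std::vector<std::uint16_t> __flags(__n), __offsets(__n);
    for (std::uint16_t __i = 0; __i < __n; ++__i)
        __flags[__i] = __pred(__in[__i]) ? 1 : 0; // predicate pass
    // Exclusive scan: __offsets[i] is the output slot of a kept element.
    std::exclusive_scan(__flags.begin(), __flags.end(), __offsets.begin(), std::uint16_t{0});
    for (std::uint16_t __i = 0; __i < __n; ++__i)
        if (__flags[__i])
            __out[__offsets[__i]] = __in[__i]; // scatter
    // The scan is exclusive, so add the last flag back to get the count.
    return __n == 0 ? 0 : __offsets[__n - 1] + __flags[__n - 1];
}
```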
@@ -726,30 +724,6 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
         unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init});
 }
 
-template <typename _CustomName, typename _SizeType>
-struct __invoke_single_group_copy_if
-{
-    // Specialization for devices that have a max work-group size of at least 1024
-    static constexpr ::std::uint16_t __targeted_wg_size = 1024;
-
-    template <std::uint16_t _Size, typename _InRng, typename _OutRng, typename _Pred,
-              typename _Assign = oneapi::dpl::__internal::__pstl_assign>
-    auto
-    operator()(sycl::queue& __q, std::size_t __n, _InRng&& __in_rng, _OutRng&& __out_rng, _Pred __pred,
-               _Assign __assign)
-    {
-        constexpr ::std::uint16_t __wg_size = ::std::min(_Size, __targeted_wg_size);
-        constexpr ::std::uint16_t __num_elems_per_item = ::oneapi::dpl::__internal::__dpl_ceiling_div(_Size, __wg_size);
-
-        using _KernelName = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
-            __scan_copy_single_wg_kernel<std::integral_constant<std::uint16_t, __wg_size>,
-                                         std::integral_constant<std::uint16_t, __num_elems_per_item>, _CustomName>>;
-        return __par_backend_hetero::__parallel_copy_if_static_single_group_submitter<
-            _SizeType, __num_elems_per_item, __wg_size, _KernelName>()
-            (__q, std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng), __n, __pred, __assign);
-    }
-};
-
 template <typename _CustomName, typename _InRng, typename _OutRng, typename _Size, typename _GenMask, typename _WriteOp,
           typename _IsUniquePattern>
 __future<sycl::event, __result_and_scratch_storage<_Size>>
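Context on the deletion above: `__invoke_single_group_copy_if` existed only to feed `__static_monotonic_dispatcher`, which walks a list of compile-time size breakpoints and instantiates the callee (and therefore a kernel) once per breakpoint. A simplified sketch of that dispatch idiom, with illustrative names rather than the library's API, shows why the old path compiled one kernel per breakpoint where the new path compiles one total:

```cpp
#include <cstddef>
#include <cstdint>

// Pick the first compile-time breakpoint that covers __n and instantiate the
// callee with it; every breakpoint yields a distinct template instantiation.
// Assumes the callee returns the same type for every breakpoint.
template <std::uint16_t _First, std::uint16_t... _Rest, typename _F>
auto
dispatch_by_size(std::size_t __n, _F __f)
{
    if constexpr (sizeof...(_Rest) == 0)
        return __f.template operator()<_First>();
    else if (__n <= _First)
        return __f.template operator()<_First>();
    else
        return dispatch_by_size<_Rest...>(__n, __f);
}
// Usage: dispatch_by_size<16, 32, 64, 128, 256, 512, 1024, 2048>(__n, __invoker);
```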
@@ -920,8 +894,6 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
 {
     using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
 
-    using _SingleGroupInvoker = __invoke_single_group_copy_if<_CustomName, _Size>;
-
     // Next power of 2 greater than or equal to __n
     auto __n_uniform = ::oneapi::dpl::__internal::__dpl_bit_ceil(static_cast<std::make_unsigned_t<_Size>>(__n));
 
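Aside: `__dpl_bit_ceil` rounds `__n` up to the next power of two, so the padded `__n_uniform` keeps the work-group scan operating on a power-of-two range. Its behavior matches C++20 `std::bit_ceil`, used here as an assumed equivalent:

```cpp
#include <bit>

static_assert(std::bit_ceil(1u) == 1u);       // already a power of two
static_assert(std::bit_ceil(1500u) == 2048u); // rounded up
static_assert(std::bit_ceil(2048u) == 2048u); // exact powers are unchanged
```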
@@ -934,18 +906,18 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
     // The kernel stores n integers for the predicate and another n integers for the offsets
     const auto __req_slm_size = sizeof(std::uint16_t) * __n_uniform * 2;
 
-    constexpr std::uint16_t __single_group_upper_limit = 2048;
+    // constexpr std::uint16_t __single_group_upper_limit = 2048;
+    constexpr std::uint16_t __max_elem_per_item = 2;
 
     std::size_t __max_wg_size = oneapi::dpl::__internal::__max_work_group_size(__q_local);
 
-    if (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size &&
-        __max_wg_size >= _SingleGroupInvoker::__targeted_wg_size)
+    if (__n <= __max_wg_size * __max_elem_per_item && __max_slm_size >= __req_slm_size)
     {
-        using _SizeBreakpoints = std::integer_sequence<std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048>;
-
-        return __par_backend_hetero::__static_monotonic_dispatcher<_SizeBreakpoints>::__dispatch(
-            _SingleGroupInvoker{}, __n, __q_local, __n, std::forward<_InRng>(__in_rng),
-            std::forward<_OutRng>(__out_rng), __pred, __assign);
+        using _KernelName = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
+            __scan_copy_single_wg_kernel<_CustomName>>;
+        return __par_backend_hetero::__parallel_copy_if_static_single_group_submitter<_Size, _KernelName>()(
+            __q_local, std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng), __n, __pred, __assign,
+            static_cast<std::uint16_t>(__n_uniform), static_cast<std::uint16_t>(std::min(__n_uniform, __max_wg_size)));
     }
     else if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
     {
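Worked example of the new launch gate: with `__max_elem_per_item = 2`, the single-group path now covers any `__n` up to twice the device's maximum work-group size, provided the padded SLM footprint fits. The numbers below assume a device with a 1024-item work-group limit and 64 KiB of SLM (illustrative values, not queried from a real device):

```cpp
#include <algorithm>
#include <bit>
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    const std::size_t __max_wg_size = 1024;       // assumed device limit
    const std::size_t __max_slm_size = 64 * 1024; // assumed SLM capacity, bytes
    const std::size_t __n = 1500;
    const std::size_t __n_uniform = std::bit_ceil(__n);                         // 2048
    const std::size_t __req_slm_size = sizeof(std::uint16_t) * __n_uniform * 2; // 8192 bytes
    const bool __single_group =
        __n <= __max_wg_size * 2 /* __max_elem_per_item */ && __max_slm_size >= __req_slm_size;
    std::cout << std::boolalpha << __single_group << '\n'; // prints: true
    // The submitter then receives __wg_size = min(__n_uniform, __max_wg_size) = 1024,
    // so each work-item strides over at most two of the 2048 padded slots.
}
```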