Skip to content

Commit 510c335

Browse files
committed
1. Adds a new masked API supporting both runtime and compile-time masks (store_masked and load_masked)
2. General use-case optimizations 3. New tests 4. x86 kernels
1 parent e392354 commit 510c335

File tree

13 files changed

+3774
-8
lines changed

13 files changed

+3774
-8
lines changed

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Data transfer
1010
From memory:
1111

1212
+---------------------------------------+----------------------------------------------------+
13-
| :cpp:func:`load` | load values from memory |
13+
| :cpp:func:`load` | load values from memory (optionally masked) |
1414
+---------------------------------------+----------------------------------------------------+
1515
| :cpp:func:`load_aligned` | load values from aligned memory |
1616
+---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
3030
To memory:
3131

3232
+---------------------------------------+----------------------------------------------------+
33-
| :cpp:func:`store` | store values to memory |
33+
| :cpp:func:`store` | store values to memory (optionally masked) |
3434
+---------------------------------------+----------------------------------------------------+
3535
| :cpp:func:`store_aligned` | store values to aligned memory |
3636
+---------------------------------------+----------------------------------------------------+

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <algorithm>
1616
#include <complex>
1717
#include <stdexcept>
18-
1918
#include "../../types/xsimd_batch_constant.hpp"
2019
#include "./xsimd_common_details.hpp"
2120

@@ -348,6 +347,140 @@ namespace xsimd
348347
return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
349348
}
350349

350+
template <class A, class T_in, class T_out>
351+
XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, typename batch<T_out, A>::batch_bool_type const& mask, convert<T_out>, requires_arch<common>) noexcept
352+
{
353+
constexpr std::size_t size = batch<T_out, A>::size;
354+
if (mask.none())
355+
return batch<T_out, A>(0);
356+
if (mask.all())
357+
return batch<T_out, A>::load(mem, unaligned_mode {});
358+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
359+
mask.for_each_set_index([&](std::size_t idx) noexcept { buffer[idx] = static_cast<T_out>(mem[idx]); });
360+
return batch<T_out, A>::load(buffer.data(), unaligned_mode {});
361+
}
362+
363+
template <class A, class T_in, class T_out>
364+
XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, typename batch<T_in, A>::batch_bool_type const& mask, requires_arch<common>) noexcept
365+
{
366+
if (mask.none())
367+
return;
368+
if (mask.all())
369+
{
370+
src.store(mem, unaligned_mode {});
371+
return;
372+
}
373+
mask.for_each_set_index([&](std::size_t idx) noexcept { mem[idx] = static_cast<T_out>(src.get(idx)); });
374+
}
375+
376+
// COMPILE-TIME (single version each, XSIMD_IF_CONSTEXPR)
377+
template <class A, class T_in, class T_out, bool... Values>
378+
XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, requires_arch<common>) noexcept
379+
{
380+
constexpr std::size_t size = batch<T_out, A>::size;
381+
constexpr std::size_t n = mask.countr_one();
382+
constexpr std::size_t l = mask.countl_one();
383+
384+
// All zeros / all ones fast paths
385+
XSIMD_IF_CONSTEXPR(mask.none())
386+
{
387+
return batch<T_out, A>(0);
388+
}
389+
else XSIMD_IF_CONSTEXPR(mask.all())
390+
{
391+
return batch<T_out, A>::load(mem, unaligned_mode {});
392+
}
393+
// Prefix-ones (n contiguous ones from LSB)
394+
else XSIMD_IF_CONSTEXPR(n > 0)
395+
{
396+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
397+
for (std::size_t i = 0; i < n; ++i)
398+
buffer[i] = static_cast<T_out>(mem[i]);
399+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
400+
}
401+
// Suffix-ones (l contiguous ones from MSB)
402+
else XSIMD_IF_CONSTEXPR(l > 0)
403+
{
404+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
405+
const std::size_t start = size - l;
406+
for (std::size_t i = 0; i < l; ++i)
407+
buffer[start + i] = static_cast<T_out>(mem[start + i]);
408+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
409+
}
410+
else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
411+
{
412+
constexpr std::size_t first = mask.first_one_index();
413+
constexpr std::size_t last = mask.last_one_index();
414+
constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
415+
XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
416+
{
417+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
418+
for (std::size_t i = 0; i < span; ++i)
419+
buffer[first + i] = static_cast<T_out>(mem[first + i]);
420+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
421+
}
422+
else
423+
{
424+
return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
425+
}
426+
}
427+
else
428+
{
429+
// Fallback to runtime path for non prefix/suffix masks
430+
return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
431+
}
432+
}
433+
434+
template <class A, class T_in, class T_out, bool... Values>
435+
XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, requires_arch<common>) noexcept
436+
{
437+
constexpr std::size_t size = batch<T_in, A>::size;
438+
constexpr std::size_t n = mask.countr_one();
439+
constexpr std::size_t l = mask.countl_one();
440+
441+
// All zeros / all ones fast paths
442+
XSIMD_IF_CONSTEXPR(mask.none())
443+
{
444+
return;
445+
}
446+
else XSIMD_IF_CONSTEXPR(mask.all())
447+
{
448+
src.store(mem, unaligned_mode {});
449+
}
450+
// Prefix-ones
451+
else XSIMD_IF_CONSTEXPR(n > 0)
452+
{
453+
for (std::size_t i = 0; i < n; ++i)
454+
mem[i] = static_cast<T_out>(src.get(i));
455+
}
456+
// Suffix-ones
457+
else XSIMD_IF_CONSTEXPR(l > 0)
458+
{
459+
const std::size_t start = size - l;
460+
for (std::size_t i = 0; i < l; ++i)
461+
mem[start + i] = static_cast<T_out>(src.get(start + i));
462+
}
463+
else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
464+
{
465+
constexpr std::size_t first = mask.first_one_index();
466+
constexpr std::size_t last = mask.last_one_index();
467+
constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
468+
XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
469+
{
470+
for (std::size_t i = 0; i < span; ++i)
471+
mem[first + i] = static_cast<T_out>(src.get(first + i));
472+
}
473+
else
474+
{
475+
store_masked<A>(mem, src, mask.as_batch_bool(), common {});
476+
}
477+
}
478+
else
479+
{
480+
store_masked<A>(mem, src, mask.as_batch_bool(), common {});
481+
}
482+
}
483+
351484
// rotate_right
352485
template <size_t N, class A, class T>
353486
XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept

0 commit comments

Comments
 (0)