 #include <algorithm>
+#include <array>
 #include <complex>
 #include <stdexcept>
-
 #include "../../types/xsimd_batch_constant.hpp"
 #include "./xsimd_common_details.hpp"
 
@@ -348,6 +347,140 @@ namespace xsimd
             return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
         }
 
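+        // Masked load, runtime mask: inactive lanes are zero-filled. This common
+        // fallback gathers the active lanes into a zeroed, batch-aligned buffer
+        // and performs a single full-width load from it.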
+        template <class A, class T_in, class T_out>
+        XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, typename batch<T_out, A>::batch_bool_type const& mask, convert<T_out>, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_out, A>::size;
+            if (mask.none())
+                return batch<T_out, A>(0);
+            if (mask.all())
+                return batch<T_out, A>::load(mem, unaligned_mode {});
+            alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+            mask.for_each_set_index([&](std::size_t idx) noexcept { buffer[idx] = static_cast<T_out>(mem[idx]); });
+            // The buffer is aligned to A::alignment(), so an aligned load is safe.
+            return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+        }
+
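+        // Masked store, runtime mask: only the active lanes of src are written
+        // back; inactive lanes leave the destination memory untouched.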
+        template <class A, class T_in, class T_out>
+        XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, typename batch<T_in, A>::batch_bool_type const& mask, requires_arch<common>) noexcept
+        {
+            if (mask.none())
+                return;
+            if (mask.all())
+            {
+                src.store(mem, unaligned_mode {});
+                return;
+            }
+            mask.for_each_set_index([&](std::size_t idx) noexcept { mem[idx] = static_cast<T_out>(src.get(idx)); });
+        }
+
+        // Compile-time masks (batch_bool_constant): the mask shape is known
+        // statically, so XSIMD_IF_CONSTEXPR selects a specialized path.
+        template <class A, class T_in, class T_out, bool... Values>
+        XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_out, A>::size;
+            constexpr std::size_t n = mask.countr_one(); // contiguous ones from the LSB
+            constexpr std::size_t l = mask.countl_one(); // contiguous ones from the MSB
+
+            // All-zeros / all-ones fast paths
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return batch<T_out, A>(0);
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                return batch<T_out, A>::load(mem, unaligned_mode {});
+            }
+            // Mask is exactly the first n lanes (prefix of contiguous ones);
+            // n == popcount() ensures no stray higher set bits are dropped.
+            else XSIMD_IF_CONSTEXPR(n > 0 && n == mask.popcount())
+            {
+                alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+                for (std::size_t i = 0; i < n; ++i)
+                    buffer[i] = static_cast<T_out>(mem[i]);
+                return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+            }
+            // Mask is exactly the last l lanes (suffix of contiguous ones)
+            else XSIMD_IF_CONSTEXPR(l > 0 && l == mask.popcount())
+            {
+                alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+                constexpr std::size_t start = size - l;
+                for (std::size_t i = 0; i < l; ++i)
+                    buffer[start + i] = static_cast<T_out>(mem[start + i]);
+                return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+            }
+            // A single contiguous run of ones anywhere in the register
+            else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
+            {
+                constexpr std::size_t first = mask.first_one_index();
+                constexpr std::size_t last = mask.last_one_index();
+                constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
+                XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
+                {
+                    alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+                    for (std::size_t i = 0; i < span; ++i)
+                        buffer[first + i] = static_cast<T_out>(mem[first + i]);
+                    return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+                }
+                else
+                {
+                    // Non-contiguous mask: defer to the runtime kernel
+                    return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
+                }
+            }
+            else
+            {
+                // Fall back to the runtime kernel for any remaining mask shape
+                return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
+            }
+        }
+
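+        // Masked store, compile-time mask: mirrors the masked load above, but
+        // writes straight to memory since inactive lanes need no zero-fill.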
+        template <class A, class T_in, class T_out, bool... Values>
+        XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_in, A>::size;
+            constexpr std::size_t n = mask.countr_one();
+            constexpr std::size_t l = mask.countl_one();
+
+            // All-zeros / all-ones fast paths
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                src.store(mem, unaligned_mode {});
+            }
+            // Mask is exactly the first n lanes (prefix of contiguous ones)
+            else XSIMD_IF_CONSTEXPR(n > 0 && n == mask.popcount())
+            {
+                for (std::size_t i = 0; i < n; ++i)
+                    mem[i] = static_cast<T_out>(src.get(i));
+            }
+            // Mask is exactly the last l lanes (suffix of contiguous ones)
+            else XSIMD_IF_CONSTEXPR(l > 0 && l == mask.popcount())
+            {
+                constexpr std::size_t start = size - l;
+                for (std::size_t i = 0; i < l; ++i)
+                    mem[start + i] = static_cast<T_out>(src.get(start + i));
+            }
+            // A single contiguous run of ones anywhere in the register
+            else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
+            {
+                constexpr std::size_t first = mask.first_one_index();
+                constexpr std::size_t last = mask.last_one_index();
+                constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
+                XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
+                {
+                    for (std::size_t i = 0; i < span; ++i)
+                        mem[first + i] = static_cast<T_out>(src.get(first + i));
+                }
+                else
+                {
+                    // Non-contiguous mask: defer to the runtime kernel
+                    store_masked<A>(mem, src, mask.as_batch_bool(), common {});
+                }
+            }
+            else
+            {
+                // Fall back to the runtime kernel for any remaining mask shape
+                store_masked<A>(mem, src, mask.as_batch_bool(), common {});
+            }
+        }
+
         // rotate_right
         template <size_t N, class A, class T>
         XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept
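For reference, a minimal sketch of how the new kernels behave. It drives the common-architecture entry points directly, which user code would not normally do; the direct `xsimd::kernel::...` calls, `xsimd::default_arch`, and the `batch_bool::load_unaligned(bool const*)` loader are illustration-only assumptions, not part of this diff:

```cpp
#include <cstddef>
#include <iostream>
#include "xsimd/xsimd.hpp"

int main()
{
    using arch = xsimd::default_arch;
    using batch = xsimd::batch<float, arch>;
    using mask_t = batch::batch_bool_type;
    constexpr std::size_t n = batch::size;

    float src[n];
    for (std::size_t i = 0; i < n; ++i)
        src[i] = static_cast<float>(i + 1);

    // Activate only the first two lanes; the remaining flags stay false.
    bool flags[n] = { true, true };
    mask_t mask = mask_t::load_unaligned(flags); // assumed bool-array loader

    // Masked load: inactive lanes come back zero-filled.
    batch v = xsimd::kernel::load_masked<arch>(
        src, mask, xsimd::kernel::convert<float> {}, xsimd::common {});

    // Masked store: only active lanes are written; dst keeps its zeros elsewhere.
    float dst[n] = {};
    xsimd::kernel::store_masked<arch>(dst, v, mask, xsimd::common {});

    for (std::size_t i = 0; i < n; ++i)
        std::cout << dst[i] << ' '; // e.g. "1 2 0 0" on a 4-lane batch
    std::cout << '\n';
}
```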