diff --git a/include/xsimd/arch/xsimd_avx.hpp b/include/xsimd/arch/xsimd_avx.hpp index 27f14b505..4af728e07 100644 --- a/include/xsimd/arch/xsimd_avx.hpp +++ b/include/xsimd/arch/xsimd_avx.hpp @@ -1613,7 +1613,7 @@ namespace xsimd } return split; } - constexpr auto lane_mask = mask % make_batch_constant(); + constexpr auto lane_mask = mask % std::integral_constant(); XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask)) { __m256 broadcast = _mm256_permute2f128_ps(self, self, 0x00); // [low | low] @@ -1632,7 +1632,7 @@ namespace xsimd __m256 swapped = _mm256_permute2f128_ps(self, self, 0x01); // [high | low] // normalize mask taking modulo 4 - constexpr auto half_mask = mask % make_batch_constant(); + constexpr auto half_mask = mask % std::integral_constant(); // permute within each lane __m256 r0 = _mm256_permutevar_ps(self, half_mask.as_batch()); @@ -1640,7 +1640,7 @@ namespace xsimd // select lane by the mask index divided by 4 constexpr auto lane = batch_constant {}; - constexpr int lane_idx = ((mask / make_batch_constant()) != lane).mask(); + constexpr int lane_idx = ((mask / std::integral_constant()) != lane).mask(); return _mm256_blend_ps(r0, r1, lane_idx); } @@ -1681,7 +1681,7 @@ namespace xsimd // select lane by the mask index divided by 2 constexpr auto lane = batch_constant {}; - constexpr int lane_idx = ((mask / make_batch_constant()) != lane).mask(); + constexpr int lane_idx = ((mask / std::integral_constant()) != lane).mask(); // blend the two permutes return _mm256_blend_pd(r0, r1, lane_idx); diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp index 699f40ec3..bf6d9e7de 100644 --- a/include/xsimd/arch/xsimd_avx2.hpp +++ b/include/xsimd/arch/xsimd_avx2.hpp @@ -1332,7 +1332,7 @@ namespace xsimd return self; } - constexpr auto lane_mask = mask % make_batch_constant(); + constexpr auto lane_mask = mask % std::integral_constant(); XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask)) { @@ -1409,7 +1409,7 @@ namespace xsimd } 
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask)) { - constexpr auto lane_mask = mask % make_batch_constant(); + constexpr auto lane_mask = mask % std::integral_constant(); // Cheaper intrinsics when not crossing lanes // Contrary to the uint64_t version, the limits of 8 bits for the immediate constant // cannot make different permutations across lanes diff --git a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp index 7959c494f..cccba8144 100644 --- a/include/xsimd/arch/xsimd_sse2.hpp +++ b/include/xsimd/arch/xsimd_sse2.hpp @@ -2074,7 +2074,7 @@ namespace xsimd __m128i hi = _mm_unpackhi_epi64(hil, hih); // mask to choose the right lane - constexpr auto blend_mask = mask < make_batch_constant(); + constexpr auto blend_mask = mask < std::integral_constant(); // blend the two permutes return select(blend_mask, batch(lo), batch(hi)); diff --git a/include/xsimd/types/xsimd_batch_constant.hpp b/include/xsimd/types/xsimd_batch_constant.hpp index 10fff0df0..4da2e3da6 100644 --- a/include/xsimd/types/xsimd_batch_constant.hpp +++ b/include/xsimd/types/xsimd_batch_constant.hpp @@ -316,11 +316,16 @@ namespace xsimd } public: -#define MAKE_BINARY_OP(OP, NAME) \ - template \ - constexpr auto operator OP(batch_constant other) const \ - { \ - return apply>(*this, other); \ +#define MAKE_BINARY_OP(OP, NAME) \ + template \ + constexpr auto operator OP(batch_constant other) const \ + { \ + return apply>(*this, other); \ + } \ + template \ + constexpr batch_constant operator OP(std::integral_constant) const \ + { \ + return {}; \ } MAKE_BINARY_OP(+, std::plus) @@ -350,11 +355,16 @@ namespace xsimd return apply_bool...>, std::tuple...>>(std::make_index_sequence()); } -#define MAKE_BINARY_BOOL_OP(OP, NAME) \ - template \ - constexpr auto operator OP(batch_constant other) const \ - { \ - return apply_bool>(*this, other); \ +#define MAKE_BINARY_BOOL_OP(OP, NAME) \ + template \ + constexpr auto operator OP(batch_constant other) const \ + { \ + return apply_bool>(*this, 
other); \ + } \ + template \ + constexpr batch_bool_constant operator OP(std::integral_constant) const \ + { \ + return {}; \ } MAKE_BINARY_BOOL_OP(==, std::equal_to) @@ -483,6 +493,29 @@ namespace xsimd #endif + namespace generator + { + template + struct iota + { + static constexpr T get(size_t index, size_t) + { + return static_cast(index); + } + }; + } + /** + * @brief Build a @c batch_constant as an enumerated range + * + * @tparam T type of the data held in the batch. + * @tparam A Architecture that will be used when converting to a regular batch. + */ + template + XSIMD_INLINE constexpr auto make_iota_batch_constant() noexcept + { + return make_batch_constant, A>(); + } + } // namespace xsimd #endif diff --git a/test/test_batch_constant.cpp b/test/test_batch_constant.cpp index 5b0e23028..2b99fc05b 100644 --- a/test/test_batch_constant.cpp +++ b/test/test_batch_constant.cpp @@ -82,6 +82,10 @@ struct constant_batch_test constexpr auto b = xsimd::make_batch_constant(); INFO("batch(value_type)"); CHECK_BATCH_EQ((batch_type)b, expected); + + constexpr auto b_p = xsimd::make_iota_batch_constant(); + INFO("batch(value_type)"); + CHECK_BATCH_EQ((batch_type)b_p, expected); } template @@ -106,43 +110,64 @@ struct constant_batch_test { constexpr auto n12 = xsimd::make_batch_constant, arch_type>(); constexpr auto n3 = xsimd::make_batch_constant, arch_type>(); + constexpr std::integral_constant c3; constexpr auto n12_add_n3 = n12 + n3; constexpr auto n15 = xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "n12 + n3 == n15"); + constexpr auto n12_add_c3 = n12 + c3; + static_assert(std::is_same::value, "n12 + c3 == n15"); constexpr auto n12_sub_n3 = n12 - n3; constexpr auto n9 = xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "n12 - n3 == n9"); + constexpr auto n12_sub_c3 = n12 - c3; + static_assert(std::is_same::value, "n12 - c3 == n9"); constexpr auto n12_mul_n3 = n12 * n3; constexpr auto n36 = 
xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "n12 * n3 == n36"); + constexpr auto n12_mul_c3 = n12 * c3; + static_assert(std::is_same::value, "n12 * c3 == n36"); constexpr auto n12_div_n3 = n12 / n3; constexpr auto n4 = xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "n12 / n3 == n4"); + constexpr auto n12_div_c3 = n12 / c3; + static_assert(std::is_same::value, "n12 / c3 == n4"); constexpr auto n12_mod_n3 = n12 % n3; constexpr auto n0 = xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "n12 % n3 == n0"); + constexpr auto n12_mod_c3 = n12 % c3; + static_assert(std::is_same::value, "n12 % c3 == n0"); constexpr auto n12_land_n3 = n12 & n3; static_assert(std::is_same::value, "n12 & n3 == n0"); + constexpr auto n12_land_c3 = n12 & c3; + static_assert(std::is_same::value, "n12 & c3 == n0"); constexpr auto n12_lor_n3 = n12 | n3; static_assert(std::is_same::value, "n12 | n3 == n15"); + constexpr auto n12_lor_c3 = n12 | c3; + static_assert(std::is_same::value, "n12 | c3 == n15"); constexpr auto n12_lxor_n3 = n12 ^ n3; static_assert(std::is_same::value, "n12 ^ n3 == n15"); + constexpr auto n12_lxor_c3 = n12 ^ c3; + static_assert(std::is_same::value, "n12 ^ c3 == n15"); constexpr auto n96 = xsimd::make_batch_constant, arch_type>(); constexpr auto n12_lshift_n3 = n12 << n3; static_assert(std::is_same::value, "n12 << n3 == n96"); + constexpr auto n12_lshift_c3 = n12 << c3; + static_assert(std::is_same::value, "n12 << c3 == n96"); constexpr auto n1 = xsimd::make_batch_constant, arch_type>(); constexpr auto n12_rshift_n3 = n12 >> n3; static_assert(std::is_same::value, "n12 >> n3 == n1"); + constexpr auto n12_rshift_c3 = n12 >> c3; + static_assert(std::is_same::value, "n12 >> c3 == n1"); constexpr auto n12_uadd = +n12; static_assert(std::is_same::value, "+n12 == n12"); @@ -163,21 +188,27 @@ struct constant_batch_test static_assert(std::is_same::value, "n12 == n12"); 
static_assert(std::is_same::value, "n12 == n3"); + static_assert(std::is_same::value, "n12 == c3"); static_assert(std::is_same::value, "n12 != n12"); static_assert(std::is_same::value, "n12 != n3"); + static_assert(std::is_same::value, "n12 != c3"); static_assert(std::is_same::value, "n12 < n12"); static_assert(std::is_same::value, "n12 < n3"); + static_assert(std::is_same::value, "n12 < c3"); static_assert(std::is_same n12), false_batch_type>::value, "n12 > n12"); static_assert(std::is_same n3), true_batch_type>::value, "n12 > n3"); + static_assert(std::is_same c3), true_batch_type>::value, "n12 > c3"); static_assert(std::is_same::value, "n12 <= n12"); static_assert(std::is_same::value, "n12 <= n3"); + static_assert(std::is_same::value, "n12 <= c3"); static_assert(std::is_same= n12), true_batch_type>::value, "n12 >= n12"); static_assert(std::is_same= n3), true_batch_type>::value, "n12 >= n3"); + static_assert(std::is_same= c3), true_batch_type>::value, "n12 >= c3"); } };