Skip to content

Commit

Permalink
Make SIMD compile on EMScripten
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-cukic committed Feb 9, 2023
1 parent 6935c46 commit 8015c0b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 17 deletions.
17 changes: 0 additions & 17 deletions bench/bm_case1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,23 +416,6 @@ inline const boost::ut::suite _constexpr_bm = [] {
};
}

{
auto gen_mult_block_float = [] {
return merge<"out", "in">(converting_multiply<float, double>(2.0f), merge<"out", "in">(converting_multiply<double, float>(0.5f), add<float, -1>()));
};
auto merged_node = merge<"out", "in">(
merge<"out", "in">(test::source<float>(N_SAMPLES), gen_mult_block_float()), test::sink<float>());
"constexpr src->mult(2.0)->mult(0.5)->add(-1)->sink"_benchmark.repeat<N_ITER>(N_SAMPLES) = [&merged_node]() {
test::n_samples_produced = 0LU;
test::n_samples_consumed = 0LU;
for (std::size_t i = 0; i < N_SAMPLES; i++) {
merged_node.process_one();
}
expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough samples";
expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough samples";
};
}

{
auto gen_mult_block_float = [] {
return merge<"out", "in">(multiply<float>(2.0f), merge<"out", "in">(multiply<float>(0.5f),
Expand Down
1 change: 1 addition & 0 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ class node : protected std::tuple<Arguments...> {
auto &&out_range = out.request_write(n);
// if SIMD makes sense (i.e. input and output ranges are contiguous and all types are
// vectorizable)

if constexpr ((std::ranges::contiguous_range<decltype(out_range)> && ... && std::ranges::contiguous_range<Ins>) &&
detail::vectorizable<return_type> && detail::node_can_process_simd<Derived>
&& input_port_types ::template transform<stdx::native_simd>::template all_of<std::is_constructible>) {
Expand Down
20 changes: 20 additions & 0 deletions include/node_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <port_traits.hpp> // localinclude
#include <utils.hpp> // localinclude

#include <vir/simd.h>

namespace fair::graph::traits::node {

namespace detail {
Expand Down Expand Up @@ -141,6 +143,7 @@ using get_port_member_descriptor =
typename meta::to_typelist<refl::descriptor::member_list<Node>>
::template filter<detail::member_descriptor_has_type<PortType>::template matches>::template at<0>;

#ifndef __EMSCRIPTEN__
template<typename Node>
concept can_process_simd =
traits::node::input_ports<Node>::size() > 0 &&
Expand All @@ -153,6 +156,23 @@ concept can_process_simd =
};
};

#else

template<typename Node>
concept can_process_simd =
traits::node::input_ports<Node>::size() > 0 &&
traits::node::input_ports<Node>::template all_same<> &&
traits::node::output_ports<Node>::size() > 0 &&
requires (Node& node,
typename traits::node::input_port_types<Node>::template transform<vir::stdx::native_simd>::template apply<std::tuple>& input_simds) {
{
[]<std::size_t... Is>(Node &node, auto const &input, std::index_sequence<Is...>) -> decltype(node.process_one(std::get<Is>(input)...)) { return {}; }
(node, input_simds, std::make_index_sequence<traits::node::input_ports<Node>::size()>())
};
};

#endif

} // namespace node

#endif // include guard
12 changes: 12 additions & 0 deletions include/typelist.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,18 @@ struct typelist {
template<template<typename...> typename Pred>
constexpr static bool none_of = (!Pred<Ts>::value && ...);

using safe_head = std::remove_pointer_t<decltype([] {
if constexpr (sizeof...(Ts) > 0) {
return static_cast<this_t::at<0>*>(nullptr);
} else {
return static_cast<void*>(nullptr);
}
}())>;

template<typename Matcher = typename this_t::safe_head>
constexpr static bool all_same =
((std::is_same_v<Matcher, Ts> && ...));

template<template<typename...> typename Predicate>
using filter = concat<std::conditional_t<Predicate<Ts>::value, typelist<Ts>, typelist<>>...>;

Expand Down

0 comments on commit 8015c0b

Please sign in to comment.