Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a plugin system for defining nodes #83

Merged
merged 1 commit into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ cmake_minimum_required(VERSION 3.19)
project(graph-prototype CXX)
set(CMAKE_CXX_STANDARD 20)

# Mainly for FMT
set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)

add_library(graph-prototype-options INTERFACE)
include(cmake/CompilerWarnings.cmake)
set_project_warnings(graph-prototype-options)
Expand Down
250 changes: 192 additions & 58 deletions include/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,21 @@ class dynamic_port {
}

public:
using value_type = void; // a sterile port
using value_type = void; // a sterile port

constexpr dynamic_port() = delete;
constexpr dynamic_port() = delete;

template<Port T>
constexpr dynamic_port(const T &arg) = delete;
dynamic_port(const dynamic_port &arg) = delete;
dynamic_port &
operator=(const dynamic_port &arg)
= delete;

dynamic_port(dynamic_port &&arg) = default;
dynamic_port &
operator=(dynamic_port &&arg)
= default;

// TODO: Make owning versus non-owning API more explicit
template<Port T>
explicit constexpr dynamic_port(T &arg) noexcept : _accessor{ std::make_unique<wrapper<T, false>>(arg) } {}

Expand Down Expand Up @@ -335,76 +343,148 @@ class dynamic_node {

#endif

class graph {
class node_model {
protected:
using dynamic_ports = std::vector<fair::graph::dynamic_port>;
bool _dynamic_ports_loaded = false;
dynamic_ports _dynamic_input_ports;
dynamic_ports _dynamic_output_ports;

node_model(){};

public:
class node_model {
public:
virtual ~node_model() = default;
node_model(const node_model &) = delete;
node_model &
operator=(const node_model &)
= delete;
node_model(node_model &&other) = delete;
node_model &
operator=(node_model &&other)
= delete;

fair::graph::dynamic_port &
dynamic_input_port(std::size_t index) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noexcept, constexpr, [[nodiscard]]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assert should be cautionary isn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No point in having dynamic_* marked as constexpr it is not meant to be evaluated at compile-time. [[nodiscard]] is fine, noexcept - it is not, find_node can throw.

assert(_dynamic_ports_loaded);
return _dynamic_input_ports[index];
}

virtual std::string_view
name() const
= 0;
fair::graph::dynamic_port &
dynamic_output_port(std::size_t index) {
assert(_dynamic_ports_loaded);
return _dynamic_output_ports[index];
}

virtual work_return_t
work() = 0;
auto
dynamic_input_ports_size() const {
assert(_dynamic_ports_loaded);
return _dynamic_input_ports.size();
}

virtual void *
raw() = 0;
};
auto
dynamic_output_ports_size() const {
assert(_dynamic_ports_loaded);
return _dynamic_output_ports.size();
}

std::vector<std::function<connection_result_t()>> _connection_definitions;
std::vector<std::unique_ptr<node_model>> _nodes;
private:
virtual ~node_model() = default;

template<typename T>
class node_wrapper final : public node_model {
private:
static_assert(std::is_same_v<T, std::remove_reference_t<T>>);
T _node;
virtual std::string_view
name() const
= 0;

public:
node_wrapper(const node_wrapper &other) = delete;
virtual work_return_t
work() = 0;

node_wrapper &
operator=(const node_wrapper &other)
= delete;
virtual void *
raw() = 0;
};

node_wrapper(node_wrapper &&other) : _node(std::exchange(other._node, nullptr)) {}
template<typename T>
class node_wrapper : public node_model {
private:
static_assert(std::is_same_v<T, std::remove_reference_t<T>>);
T _node;

[[nodiscard]] constexpr const auto &
node_ref() const noexcept {
if constexpr (requires { *_node; }) {
return *_node;
} else {
return _node;
}
}

node_wrapper &
operator=(node_wrapper &&other) {
auto tmp = std::move(other);
std::swap(_node, tmp._node);
return *this;
[[nodiscard]] constexpr auto &
node_ref() noexcept {
if constexpr (requires { *_node; }) {
return *_node;
} else {
return _node;
}
}

~node_wrapper() override = default;
void
init_dynamic_ports() {
using Node = std::remove_cvref_t<decltype(node_ref())>;

node_wrapper() {}
constexpr std::size_t input_port_count = fair::graph::traits::node::template input_port_types<Node>::size;
[this]<std::size_t... Is>(std::index_sequence<Is...>) { (this->_dynamic_input_ports.emplace_back(fair::graph::input_port<Is>(&node_ref())), ...); }
(std::make_index_sequence<input_port_count>());

template<typename Arg>
requires(!std::is_same_v<std::remove_cvref_t<Arg>, T>)
node_wrapper(Arg &&arg) : _node(std::forward<Arg>(arg)) {}
constexpr std::size_t output_port_count = fair::graph::traits::node::template output_port_types<Node>::size;
[this]<std::size_t... Is>(std::index_sequence<Is...>) { (this->_dynamic_output_ports.push_back(fair::graph::dynamic_port(fair::graph::output_port<Is>(&node_ref()))), ...); }
(std::make_index_sequence<output_port_count>());

template<typename... Args>
requires(sizeof...(Args) > 1)
node_wrapper(Args &&...args) : _node{ std::forward<Args>(args)... } {}
static_assert(input_port_count + output_port_count > 0);
_dynamic_ports_loaded = true;
}

constexpr work_return_t
work() override {
return _node.work();
}
public:
node_wrapper(const node_wrapper &other) = delete;
node_wrapper(node_wrapper &&other) = delete;
node_wrapper &
operator=(const node_wrapper &other)
= delete;
node_wrapper &
operator=(node_wrapper &&other)
= delete;

~node_wrapper() override = default;

node_wrapper() { init_dynamic_ports(); }

template<typename Arg>
requires(!std::is_same_v<std::remove_cvref_t<Arg>, T>)
node_wrapper(Arg &&arg) : _node(std::forward<Arg>(arg)) {
init_dynamic_ports();
}

std::string_view
name() const override {
return _node.name();
}
template<typename... Args>
requires(sizeof...(Args) > 1)
node_wrapper(Args &&...args) : _node{ std::forward<Args>(args)... } {
init_dynamic_ports();
}

void *
raw() override {
return std::addressof(_node);
}
};
constexpr work_return_t
work() override {
return node_ref().work();
}

std::string_view
name() const override {
return node_ref().name();
}

void *
raw() override {
return std::addressof(node_ref());
}
};

class graph {
private:
std::vector<std::function<connection_result_t()>> _connection_definitions;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N.B. OK ... as a reminder for us: we should beef the connection definitions up to the edge<T> definition including also priority, preferred buffer sizes, etc.

std::vector<std::unique_ptr<node_model>> _nodes;

class edge {
public:
Expand Down Expand Up @@ -462,7 +542,34 @@ class graph {
}
};

std::vector<edge> _edges;
std::vector<edge> _edges;

template<typename Node>
std::unique_ptr<node_model> &
find_node(Node &what) {
auto it = [&, this] {
if constexpr (std::is_same_v<Node, node_model>) {
return std::find_if(_nodes.begin(), _nodes.end(), [&](const auto &node) { return node.get() == &what; });
} else {
return std::find_if(_nodes.begin(), _nodes.end(), [&](const auto &node) { return node->raw() == &what; });
}
}();

if (it == _nodes.end()) throw fmt::format("No such node in this graph");
return *it;
}

template<typename Node>
[[nodiscard]] dynamic_port &
dynamic_output_port(Node &node, std::size_t index) {
return find_node(node)->dynamic_output_port(index);
}

template<typename Node>
[[nodiscard]] dynamic_port &
dynamic_input_port(Node &node, std::size_t index) {
return find_node(node)->dynamic_input_port(index);
}

template<std::size_t src_port_index, std::size_t dst_port_index, typename Source, typename SourcePort, typename Destination, typename DestinationPort>
[[nodiscard]] connection_result_t
Expand Down Expand Up @@ -575,6 +682,12 @@ class graph {
return _edges.size();
}

node_model &
add_node(std::unique_ptr<node_model> node) {
auto &new_node_ref = _nodes.emplace_back(std::move(node));
return *new_node_ref.get();
}

template<typename Node, typename... Args>
auto &
make_node(Args &&...args) { // TODO for review: do we still need this factory method or allow only pmt-map-type constructors (see below)
Expand Down Expand Up @@ -622,10 +735,31 @@ class graph {
return graph::source_connector<Source, Port>(*this, source, std::invoke(member_ptr, source));
}

[[nodiscard]] const std::vector<edge>&
[[nodiscard]] const std::vector<edge> &
get_edges() const {
return _edges;
}

template<typename Source, typename Sink>
connection_result_t
dynamic_connect(Source &source, std::size_t source_index, Sink &sink, std::size_t sink_index) {
return dynamic_output_port(source, source_index).connect(dynamic_input_port(sink, sink_index));
}

const std::vector<std::function<connection_result_t()>> &
connection_definitions() {
return _connection_definitions;
}

void
clear_connection_definitions() {
_connection_definitions.clear();
}

auto &
nodes() {
return _nodes;
}
};

// TODO: add nicer enum formatter
Expand Down
10 changes: 10 additions & 0 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,16 @@ static_assert(traits::node::can_process_simd<decltype(merge_by_index<0, 0>(copy(
} // namespace test
#endif

namespace detail {
template<template<typename> typename NodeTemplate, typename... AllowedTypes>
struct register_node {
template<typename RegisterInstance>
register_node(RegisterInstance *plugin_instance, std::string node_type) {
plugin_instance->template add_node_type<NodeTemplate, AllowedTypes...>(node_type);
}
};
} // namespace detail

} // namespace fair::graph

#endif // include guard
Loading