Skip to content

Commit

Permalink
Forwarded linear iterator for linearly iterable expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Drew Hubley committed Jul 12, 2024
1 parent 8c0a484 commit b4f7e3d
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 11 deletions.
46 changes: 41 additions & 5 deletions include/xtensor/xstrided_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,36 @@ namespace xt

using inner_storage_type = typename base_type::inner_storage_type;
using storage_type = typename base_type::storage_type;
using linear_iterator = typename storage_type::iterator;
using const_linear_iterator = typename storage_type::const_iterator;
using reverse_linear_iterator = std::reverse_iterator<linear_iterator>;
using const_reverse_linear_iterator = std::reverse_iterator<const_linear_iterator>;

template <class C, class = void_t<>>
struct get_linear_iterator : std::false_type
{
using iterator = typename C::iterator;
};

template <typename C>
struct get_linear_iterator<C, void_t<decltype(std::declval<C>().linear_begin())>> : std::true_type
{
using iterator = typename C::linear_iterator;
};

template <class C, class = void_t<>>
struct get_const_linear_iterator : std::false_type
{
using iterator = typename C::const_iterator;
};

template <typename C>
struct get_const_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
{
using iterator = typename C::const_linear_iterator;
};

using linear_iterator = typename get_linear_iterator<storage_type>::iterator;
using const_linear_iterator = typename get_const_linear_iterator<storage_type>::iterator;
using reverse_linear_iterator = std::reverse_iterator<typename get_linear_iterator<storage_type>::iterator>;
using const_reverse_linear_iterator = std::reverse_iterator<
typename get_const_linear_iterator<storage_type>::iterator>;

using iterable_base = select_iterable_base_t<L, xexpression_type::static_layout, self_type>;
using inner_shape_type = typename base_type::inner_shape_type;
Expand Down Expand Up @@ -222,6 +248,7 @@ namespace xt
const_linear_iterator linear_begin() const;
const_linear_iterator linear_end() const;
const_linear_iterator linear_cbegin() const;

const_linear_iterator linear_cend() const;

reverse_linear_iterator linear_rbegin();
Expand Down Expand Up @@ -511,7 +538,16 @@ namespace xt
template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::linear_cbegin() const -> const_linear_iterator
{
return this->storage().cbegin() + static_cast<std::ptrdiff_t>(data_offset());
return xtl::mpl::static_if<get_const_linear_iterator<storage_type>::value>(
[&](auto self)
{
return self(this->storage()).linear_cbegin() + static_cast<std::ptrdiff_t>(data_offset());
},
[&](auto self)
{
return self(this->storage()).cbegin() + static_cast<std::ptrdiff_t>(data_offset());
}
);
}

template <class CT, class S, layout_type L, class FST>
Expand Down
132 changes: 126 additions & 6 deletions include/xtensor/xstrided_view_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ namespace xt
using reverse_iterator = decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
using const_reverse_iterator = decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());


explicit flat_expression_adaptor(CT* e);

template <class FST>
Expand Down Expand Up @@ -75,6 +76,52 @@ namespace xt
size_type m_size;
};

template <class CT, layout_type L>
class linear_flat_expression_adaptor : public flat_expression_adaptor<CT, L>
{
public:

using xexpression_type = std::decay_t<CT>;
using shape_type = typename xexpression_type::shape_type;
using inner_strides_type = get_strides_t<shape_type>;
using index_type = inner_strides_type;
using size_type = typename xexpression_type::size_type;
using value_type = typename xexpression_type::value_type;
using const_reference = typename xexpression_type::const_reference;
using reference = std::conditional_t<
std::is_const<std::remove_reference_t<CT>>::value,
typename xexpression_type::const_reference,
typename xexpression_type::reference>;


using linear_iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_begin());
using const_linear_iterator = decltype(std::declval<std::decay_t<CT>>().linear_cbegin());
using reverse_linear_iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_rbegin()
);
using const_reverse_linear_iterator = decltype(std::declval<std::decay_t<CT>>().linear_crbegin());


explicit linear_flat_expression_adaptor(CT* e);

template <class FST>
linear_flat_expression_adaptor(CT* e, FST&& strides);

linear_iterator linear_begin();
linear_iterator linear_end();
const_linear_iterator linear_begin() const;
const_linear_iterator linear_end() const;
const_linear_iterator linear_cbegin() const;
const_linear_iterator linear_cend() const;

private:

static index_type& get_index();

mutable CT* m_e;
inner_strides_type m_strides;
size_type m_size;
};

template <class T>
struct is_flat_expression_adaptor : std::false_type
{
Expand All @@ -85,9 +132,21 @@ namespace xt
{
};

template <class T>
struct is_linear_flat_expression_adaptor : std::false_type
{
};

template <class CT, layout_type L>
struct is_linear_flat_expression_adaptor<linear_flat_expression_adaptor<CT, L>> : std::true_type
{
};

template <class E, class ST>
struct provides_data_interface
: xtl::conjunction<has_data_interface<std::decay_t<E>>, xtl::negation<is_flat_expression_adaptor<ST>>>
struct provides_data_interface : xtl::conjunction<
has_data_interface<std::decay_t<E>>,
xtl::negation<is_flat_expression_adaptor<ST>>,
xtl::negation<is_linear_flat_expression_adaptor<ST>>>
{
};
}
Expand Down Expand Up @@ -246,7 +305,10 @@ namespace xt
template <class CT, layout_type L>
struct flat_adaptor_getter
{
using type = flat_expression_adaptor<std::remove_reference_t<CT>, L>;
using type = std::conditional_t<
detail::has_linear_iterator<std::remove_reference_t<CT>>::value && (std::remove_reference_t<CT>::static_layout == L),
linear_flat_expression_adaptor<std::remove_reference_t<CT>, L>,
flat_expression_adaptor<std::remove_reference_t<CT>, L>>;
using reference = std::add_lvalue_reference_t<CT>;

template <class E>
Expand Down Expand Up @@ -318,9 +380,7 @@ namespace xt
layout_type layout
) noexcept
: m_e(std::forward<CTA>(e))
,
// m_storage(detail::get_flat_storage<undecay_expression>(m_e)),
m_storage(storage_getter::get_flat_storage(m_e))
, m_storage(storage_getter::get_flat_storage(m_e))
, m_shape(std::forward<SA>(shape))
, m_strides(std::move(strides))
, m_offset(offset)
Expand All @@ -345,6 +405,14 @@ namespace xt
new_storage.update_pointer(std::addressof(expr));
return new_storage;
}

template <class T, class E, layout_type L>
auto copy_move_storage(T& expr, const detail::linear_flat_expression_adaptor<E, L>& storage)
{
detail::linear_flat_expression_adaptor<E, L> new_storage = storage; // copy storage
new_storage.update_pointer(std::addressof(expr));
return new_storage;
}
}

template <class D>
Expand Down Expand Up @@ -783,6 +851,58 @@ namespace xt
thread_local static index_type index;
return index;
}

template <class CT, layout_type L>
inline linear_flat_expression_adaptor<CT, L>::linear_flat_expression_adaptor(CT* e)
: flat_expression_adaptor<CT, L>(e)
, m_e(e)
{
}

template <class CT, layout_type L>
template <class FST>
inline linear_flat_expression_adaptor<CT, L>::linear_flat_expression_adaptor(CT* e, FST&& strides)
: flat_expression_adaptor<CT, L>(e, strides)
, m_e(e)
, m_strides(xtl::forward_sequence<inner_strides_type, FST>(strides))
{
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_begin() -> linear_iterator
{
return m_e->linear_begin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_end() -> linear_iterator
{
return m_e->linear_end();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_begin() const -> const_linear_iterator
{
return m_e->linear_cbegin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_end() const -> const_linear_iterator
{
return m_e->linear_cend();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_cbegin() const -> const_linear_iterator
{
return m_e->linear_cbegin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::linear_cend() const -> const_linear_iterator
{
return m_e->linear_cend();
}
}

/**********************************
Expand Down

0 comments on commit b4f7e3d

Please sign in to comment.