/*************************************************************************** * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * * Copyright (c) QuantStack * * * * Distributed under the terms of the BSD 3-Clause License. * * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ #ifndef XTENSOR_ITERATOR_HPP #define XTENSOR_ITERATOR_HPP #include #include #include #include #include #include #include #include #include #include #include "xexception.hpp" #include "xlayout.hpp" #include "xshape.hpp" #include "xutils.hpp" namespace xt { /*********************** * iterator meta utils * ***********************/ template class xscalar; template class xscalar_stepper; namespace detail { template struct get_stepper_iterator_impl { using type = typename C::container_iterator; }; template struct get_stepper_iterator_impl { using type = typename C::const_container_iterator; }; template struct get_stepper_iterator_impl> { using type = typename xscalar::dummy_iterator; }; template struct get_stepper_iterator_impl> { using type = typename xscalar::const_dummy_iterator; }; } template using get_stepper_iterator = typename detail::get_stepper_iterator_impl::type; /******************************** * xindex_type_t implementation * ********************************/ namespace detail { template struct index_type_impl { using type = dynamic_shape; }; template struct index_type_impl> { using type = std::array; }; template struct index_type_impl> { using type = std::array; }; } template using xindex_type_t = typename detail::index_type_impl::type; /************ * xstepper * ************/ template class xstepper { public: using storage_type = C; using subiterator_type = get_stepper_iterator; using subiterator_traits = std::iterator_traits; using value_type = typename subiterator_traits::value_type; using reference = typename subiterator_traits::reference; using pointer = typename subiterator_traits::pointer; using difference_type = typename subiterator_traits::difference_type; using size_type = typename storage_type::size_type; using shape_type = typename storage_type::shape_type; using simd_value_type = xt_simd::simd_type; template using simd_return_type = xt_simd::simd_return_type; xstepper() = default; xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept; reference operator*() const; void step(size_type dim, size_type n = 1); void step_back(size_type dim, size_type n = 1); void reset(size_type dim); void reset_back(size_type dim); void to_begin(); void to_end(layout_type l); template simd_return_type step_simd(); void step_leading(); template void store_simd(const R& vec); private: storage_type* p_c; subiterator_type m_it; size_type m_offset; }; template struct stepper_tools { // For performance reasons, increment_stepper and decrement_stepper are // specialized for the case where n=1, which underlies operator++ and // operator-- on xiterators. template static void increment_stepper(S& stepper, IT& index, const ST& shape); template static void decrement_stepper(S& stepper, IT& index, const ST& shape); template static void increment_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n); template static void decrement_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n); }; /******************** * xindexed_stepper * ********************/ template class xindexed_stepper { public: using self_type = xindexed_stepper; using xexpression_type = std::conditional_t; using value_type = typename xexpression_type::value_type; using reference = std:: conditional_t; using pointer = std:: conditional_t; using size_type = typename xexpression_type::size_type; using difference_type = typename xexpression_type::difference_type; using shape_type = typename xexpression_type::shape_type; using index_type = xindex_type_t; xindexed_stepper() = default; xindexed_stepper(xexpression_type* e, size_type offset, bool end = false) noexcept; reference operator*() const; void step(size_type dim, size_type n = 1); void step_back(size_type dim, size_type n = 1); void reset(size_type dim); void reset_back(size_type dim); void to_begin(); void to_end(layout_type l); private: xexpression_type* p_e; index_type m_index; size_type m_offset; }; template struct is_indexed_stepper { static const bool value = false; }; template struct is_indexed_stepper> { static const bool value = true; }; template struct enable_indexed_stepper : std::enable_if::value, R> { }; template using enable_indexed_stepper_t = typename enable_indexed_stepper::type; template struct disable_indexed_stepper : std::enable_if::value, R> { }; template using disable_indexed_stepper_t = typename disable_indexed_stepper::type; /************* * xiterator * *************/ namespace detail { template class shape_storage { public: using shape_type = S; using param_type = const S&; shape_storage() = default; shape_storage(param_type shape); const S& shape() const; private: S m_shape; }; template class shape_storage { public: using shape_type = S; using param_type = const S*; shape_storage(param_type shape = 0); const S& shape() const; private: const S* p_shape; }; template struct LAYOUT_FORBIDEN_FOR_XITERATOR; } template class xiterator : public xtl::xrandom_access_iterator_base< xiterator, typename St::value_type, typename St::difference_type, typename St::pointer, typename St::reference>, private detail::shape_storage { public: using self_type = xiterator; using stepper_type = St; using value_type = typename stepper_type::value_type; using reference = typename stepper_type::reference; using pointer = typename stepper_type::pointer; using difference_type = typename stepper_type::difference_type; using size_type = typename stepper_type::size_type; using iterator_category = std::random_access_iterator_tag; using private_base = detail::shape_storage; using shape_type = typename private_base::shape_type; using shape_param_type = typename private_base::param_type; using index_type = xindex_type_t; xiterator() = default; // end_index means either reverse_iterator && !end or !reverse_iterator && end xiterator(St st, shape_param_type shape, bool end_index); self_type& operator++(); self_type& operator--(); self_type& operator+=(difference_type n); self_type& operator-=(difference_type n); difference_type operator-(const self_type& rhs) const; reference operator*() const; pointer operator->() const; bool equal(const xiterator& rhs) const; bool less_than(const xiterator& rhs) const; private: stepper_type m_st; index_type m_index; difference_type m_linear_index; using checking_type = typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR::type; }; template bool operator==(const xiterator& lhs, const xiterator& rhs); template bool operator<(const xiterator& lhs, const xiterator& rhs); template struct is_contiguous_container> : std::false_type { }; /********************* * xbounded_iterator * *********************/ template class xbounded_iterator : public xtl::xrandom_access_iterator_base< xbounded_iterator, typename std::iterator_traits::value_type, typename std::iterator_traits::difference_type, typename std::iterator_traits::pointer, typename std::iterator_traits::reference> { public: using self_type = xbounded_iterator; using subiterator_type = It; using bound_iterator_type = BIt; using value_type = typename std::iterator_traits::value_type; using reference = typename std::iterator_traits::reference; using pointer = typename std::iterator_traits::pointer; using difference_type = typename std::iterator_traits::difference_type; using iterator_category = std::random_access_iterator_tag; xbounded_iterator() = default; xbounded_iterator(It it, BIt bound_it); self_type& operator++(); self_type& operator--(); self_type& operator+=(difference_type n); self_type& operator-=(difference_type n); difference_type operator-(const self_type& rhs) const; value_type operator*() const; bool equal(const self_type& rhs) const; bool less_than(const self_type& rhs) const; private: subiterator_type m_it; bound_iterator_type m_bound_it; }; template bool operator==(const xbounded_iterator& lhs, const xbounded_iterator& rhs); template bool operator<(const xbounded_iterator& lhs, const xbounded_iterator& rhs); /***************************** * linear_begin / linear_end * *****************************/ namespace detail { template > struct has_linear_iterator : std::false_type { }; template struct has_linear_iterator().linear_cbegin())>> : std::true_type { }; } template XTENSOR_CONSTEXPR_RETURN auto linear_begin(C& c) noexcept { return xtl::mpl::static_if::value>( [&](auto self) { return self(c).linear_begin(); }, /*else*/ [&](auto self) { return self(c).begin(); } ); } template XTENSOR_CONSTEXPR_RETURN auto linear_end(C& c) noexcept { return xtl::mpl::static_if::value>( [&](auto self) { return self(c).linear_end(); }, /*else*/ [&](auto self) { return self(c).end(); } ); } template XTENSOR_CONSTEXPR_RETURN auto linear_begin(const C& c) noexcept { return xtl::mpl::static_if::value>( [&](auto self) { return self(c).linear_cbegin(); }, /*else*/ [&](auto self) { return self(c).cbegin(); } ); } template XTENSOR_CONSTEXPR_RETURN auto linear_end(const C& c) noexcept { return xtl::mpl::static_if::value>( [&](auto self) { return self(c).linear_cend(); }, /*else*/ [&](auto self) { return self(c).cend(); } ); } /*************************** * xstepper implementation * ***************************/ template inline xstepper::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept : p_c(c) , m_it(it) , m_offset(offset) { } template inline auto xstepper::operator*() const -> reference { return *m_it; } template inline void xstepper::step(size_type dim, size_type n) { if (dim >= m_offset) { using strides_value_type = typename std::decay_tstrides())>::value_type; m_it += difference_type(static_cast(n) * p_c->strides()[dim - m_offset]); } } template inline void xstepper::step_back(size_type dim, size_type n) { if (dim >= m_offset) { using strides_value_type = typename std::decay_tstrides())>::value_type; m_it -= difference_type(static_cast(n) * p_c->strides()[dim - m_offset]); } } template inline void xstepper::reset(size_type dim) { if (dim >= m_offset) { m_it -= difference_type(p_c->backstrides()[dim - m_offset]); } } template inline void xstepper::reset_back(size_type dim) { if (dim >= m_offset) { m_it += difference_type(p_c->backstrides()[dim - m_offset]); } } template inline void xstepper::to_begin() { m_it = p_c->data_xbegin(); } template inline void xstepper::to_end(layout_type l) { m_it = p_c->data_xend(l, m_offset); } namespace detail { template struct step_simd_invoker { template static R apply(const It& it) { R reg; return reg.load_unaligned(&(*it)); // return reg; } }; template struct step_simd_invoker, S, L>> { template static R apply(const xiterator, S, L>& it) { return R(*it); } }; } template template inline auto xstepper::step_simd() -> simd_return_type { using simd_type = simd_return_type; simd_type reg = detail::step_simd_invoker::template apply(m_it); m_it += xt_simd::revert_simd_traits::size; return reg; } template template inline void xstepper::store_simd(const R& vec) { vec.store_unaligned(&(*m_it)); m_it += xt_simd::revert_simd_traits::size; ; } template void xstepper::step_leading() { ++m_it; } template <> template void stepper_tools::increment_stepper(S& stepper, IT& index, const ST& shape) { using size_type = typename S::size_type; const size_type size = index.size(); size_type i = size; while (i != 0) { --i; if (index[i] != shape[i] - 1) { ++index[i]; stepper.step(i); return; } else { index[i] = 0; if (i != 0) { stepper.reset(i); } } } if (i == 0) { if (size != size_type(0)) { std::transform( shape.cbegin(), shape.cend() - 1, index.begin(), [](const auto& v) { return v - 1; } ); index[size - 1] = shape[size - 1]; } stepper.to_end(layout_type::row_major); } } template <> template void stepper_tools::increment_stepper( S& stepper, IT& index, const ST& shape, typename S::size_type n ) { using size_type = typename S::size_type; const size_type size = index.size(); const size_type leading_i = size - 1; size_type i = size; while (i != 0 && n != 0) { --i; size_type inc = (i == leading_i) ? n : 1; if (xtl::cmp_less(index[i] + inc, shape[i])) { index[i] += inc; stepper.step(i, inc); n -= inc; if (i != leading_i || index.size() == 1) { i = index.size(); } } else { if (i == leading_i) { size_type off = shape[i] - index[i] - 1; stepper.step(i, off); n -= off; } index[i] = 0; if (i != 0) { stepper.reset(i); } } } if (i == 0 && n != 0) { if (size != size_type(0)) { std::transform( shape.cbegin(), shape.cend() - 1, index.begin(), [](const auto& v) { return v - 1; } ); index[leading_i] = shape[leading_i]; } stepper.to_end(layout_type::row_major); } } template <> template void stepper_tools::decrement_stepper(S& stepper, IT& index, const ST& shape) { using size_type = typename S::size_type; size_type i = index.size(); while (i != 0) { --i; if (index[i] != 0) { --index[i]; stepper.step_back(i); return; } else { index[i] = shape[i] - 1; if (i != 0) { stepper.reset_back(i); } } } if (i == 0) { stepper.to_begin(); } } template <> template void stepper_tools::decrement_stepper( S& stepper, IT& index, const ST& shape, typename S::size_type n ) { using size_type = typename S::size_type; size_type i = index.size(); size_type leading_i = index.size() - 1; while (i != 0 && n != 0) { --i; size_type inc = (i == leading_i) ? n : 1; if (xtl::cmp_greater_equal(index[i], inc)) { index[i] -= inc; stepper.step_back(i, inc); n -= inc; if (i != leading_i || index.size() == 1) { i = index.size(); } } else { if (i == leading_i) { size_type off = index[i]; stepper.step_back(i, off); n -= off; } index[i] = shape[i] - 1; if (i != 0) { stepper.reset_back(i); } } } if (i == 0 && n != 0) { stepper.to_begin(); } } template <> template void stepper_tools::increment_stepper(S& stepper, IT& index, const ST& shape) { using size_type = typename S::size_type; const size_type size = index.size(); size_type i = 0; while (i != size) { if (index[i] != shape[i] - 1) { ++index[i]; stepper.step(i); return; } else { index[i] = 0; if (i != size - 1) { stepper.reset(i); } } ++i; } if (i == size) { if (size != size_type(0)) { std::transform( shape.cbegin() + 1, shape.cend(), index.begin() + 1, [](const auto& v) { return v - 1; } ); index[0] = shape[0]; } stepper.to_end(layout_type::column_major); } } template <> template void stepper_tools::increment_stepper( S& stepper, IT& index, const ST& shape, typename S::size_type n ) { using size_type = typename S::size_type; const size_type size = index.size(); const size_type leading_i = 0; size_type i = 0; while (i != size && n != 0) { size_type inc = (i == leading_i) ? n : 1; if (index[i] + inc < shape[i]) { index[i] += inc; stepper.step(i, inc); n -= inc; if (i != leading_i || size == 1) { i = 0; continue; } } else { if (i == leading_i) { size_type off = shape[i] - index[i] - 1; stepper.step(i, off); n -= off; } index[i] = 0; if (i != size - 1) { stepper.reset(i); } } ++i; } if (i == size && n != 0) { if (size != size_type(0)) { std::transform( shape.cbegin() + 1, shape.cend(), index.begin() + 1, [](const auto& v) { return v - 1; } ); index[leading_i] = shape[leading_i]; } stepper.to_end(layout_type::column_major); } } template <> template void stepper_tools::decrement_stepper(S& stepper, IT& index, const ST& shape) { using size_type = typename S::size_type; size_type size = index.size(); size_type i = 0; while (i != size) { if (index[i] != 0) { --index[i]; stepper.step_back(i); return; } else { index[i] = shape[i] - 1; if (i != size - 1) { stepper.reset_back(i); } } ++i; } if (i == size) { stepper.to_begin(); } } template <> template void stepper_tools::decrement_stepper( S& stepper, IT& index, const ST& shape, typename S::size_type n ) { using size_type = typename S::size_type; size_type size = index.size(); size_type i = 0; size_type leading_i = 0; while (i != size && n != 0) { size_type inc = (i == leading_i) ? n : 1; if (index[i] >= inc) { index[i] -= inc; stepper.step_back(i, inc); n -= inc; if (i != leading_i || index.size() == 1) { i = 0; continue; } } else { if (i == leading_i) { size_type off = index[i]; stepper.step_back(i, off); n -= off; } index[i] = shape[i] - 1; if (i != size - 1) { stepper.reset_back(i); } } ++i; } if (i == size && n != 0) { stepper.to_begin(); } } /*********************************** * xindexed_stepper implementation * ***********************************/ template inline xindexed_stepper::xindexed_stepper(xexpression_type* e, size_type offset, bool end) noexcept : p_e(e) , m_index(xtl::make_sequence(e->shape().size(), size_type(0))) , m_offset(offset) { if (end) { // Note: the layout here doesn't matter (unused) but using default traversal looks more "correct". to_end(XTENSOR_DEFAULT_TRAVERSAL); } } template inline auto xindexed_stepper::operator*() const -> reference { return p_e->element(m_index.cbegin(), m_index.cend()); } template inline void xindexed_stepper::step(size_type dim, size_type n) { if (dim >= m_offset) { m_index[dim - m_offset] += static_cast(n); } } template inline void xindexed_stepper::step_back(size_type dim, size_type n) { if (dim >= m_offset) { m_index[dim - m_offset] -= static_cast(n); } } template inline void xindexed_stepper::reset(size_type dim) { if (dim >= m_offset) { m_index[dim - m_offset] = 0; } } template inline void xindexed_stepper::reset_back(size_type dim) { if (dim >= m_offset) { m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1; } } template inline void xindexed_stepper::to_begin() { std::fill(m_index.begin(), m_index.end(), size_type(0)); } template inline void xindexed_stepper::to_end(layout_type l) { const auto& shape = p_e->shape(); std::transform( shape.cbegin(), shape.cend(), m_index.begin(), [](const auto& v) { return v - 1; } ); size_type l_dim = (l == layout_type::row_major) ? shape.size() - 1 : 0; m_index[l_dim] = shape[l_dim]; } /**************************** * xiterator implementation * ****************************/ namespace detail { template inline shape_storage::shape_storage(param_type shape) : m_shape(shape) { } template inline const S& shape_storage::shape() const { return m_shape; } template inline shape_storage::shape_storage(param_type shape) : p_shape(shape) { } template inline const S& shape_storage::shape() const { return *p_shape; } template <> struct LAYOUT_FORBIDEN_FOR_XITERATOR { using type = int; }; template <> struct LAYOUT_FORBIDEN_FOR_XITERATOR { using type = int; }; } template inline xiterator::xiterator(St st, shape_param_type shape, bool end_index) : private_base(shape) , m_st(st) , m_index( end_index ? xtl::forward_sequence(this->shape()) : xtl::make_sequence(this->shape().size(), size_type(0)) ) , m_linear_index(0) { // end_index means either reverse_iterator && !end or !reverse_iterator && end if (end_index) { if (m_index.size() != size_type(0)) { auto iter_begin = (L == layout_type::row_major) ? m_index.begin() : m_index.begin() + 1; auto iter_end = (L == layout_type::row_major) ? m_index.end() - 1 : m_index.end(); std::transform( iter_begin, iter_end, iter_begin, [](const auto& v) { return v - 1; } ); } m_linear_index = difference_type(std::accumulate( this->shape().cbegin(), this->shape().cend(), size_type(1), std::multiplies() )); } } template inline auto xiterator::operator++() -> self_type& { stepper_tools::increment_stepper(m_st, m_index, this->shape()); ++m_linear_index; return *this; } template inline auto xiterator::operator--() -> self_type& { stepper_tools::decrement_stepper(m_st, m_index, this->shape()); --m_linear_index; return *this; } template inline auto xiterator::operator+=(difference_type n) -> self_type& { if (n >= 0) { stepper_tools::increment_stepper(m_st, m_index, this->shape(), static_cast(n)); } else { stepper_tools::decrement_stepper(m_st, m_index, this->shape(), static_cast(-n)); } m_linear_index += n; return *this; } template inline auto xiterator::operator-=(difference_type n) -> self_type& { if (n >= 0) { stepper_tools::decrement_stepper(m_st, m_index, this->shape(), static_cast(n)); } else { stepper_tools::increment_stepper(m_st, m_index, this->shape(), static_cast(-n)); } m_linear_index -= n; return *this; } template inline auto xiterator::operator-(const self_type& rhs) const -> difference_type { return m_linear_index - rhs.m_linear_index; } template inline auto xiterator::operator*() const -> reference { return *m_st; } template inline auto xiterator::operator->() const -> pointer { return &(*m_st); } template inline bool xiterator::equal(const xiterator& rhs) const { XTENSOR_ASSERT(this->shape() == rhs.shape()); return m_linear_index == rhs.m_linear_index; } template inline bool xiterator::less_than(const xiterator& rhs) const { XTENSOR_ASSERT(this->shape() == rhs.shape()); return m_linear_index < rhs.m_linear_index; } template inline bool operator==(const xiterator& lhs, const xiterator& rhs) { return lhs.equal(rhs); } template bool operator<(const xiterator& lhs, const xiterator& rhs) { return lhs.less_than(rhs); } /************************************ * xbounded_iterator implementation * ************************************/ template xbounded_iterator::xbounded_iterator(It it, BIt bound_it) : m_it(it) , m_bound_it(bound_it) { } template inline auto xbounded_iterator::operator++() -> self_type& { ++m_it; ++m_bound_it; return *this; } template inline auto xbounded_iterator::operator--() -> self_type& { --m_it; --m_bound_it; return *this; } template inline auto xbounded_iterator::operator+=(difference_type n) -> self_type& { m_it += n; m_bound_it += n; return *this; } template inline auto xbounded_iterator::operator-=(difference_type n) -> self_type& { m_it -= n; m_bound_it -= n; return *this; } template inline auto xbounded_iterator::operator-(const self_type& rhs) const -> difference_type { return m_it - rhs.m_it; } template inline auto xbounded_iterator::operator*() const -> value_type { using type = decltype(*m_bound_it); return (static_cast(*m_it) < *m_bound_it) ? *m_it : static_cast((*m_bound_it) - 1); } template inline bool xbounded_iterator::equal(const self_type& rhs) const { return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it; } template inline bool xbounded_iterator::less_than(const self_type& rhs) const { return m_it < rhs.m_it; } template inline bool operator==(const xbounded_iterator& lhs, const xbounded_iterator& rhs) { return lhs.equal(rhs); } template inline bool operator<(const xbounded_iterator& lhs, const xbounded_iterator& rhs) { return lhs.less_than(rhs); } } #endif