/*************************************************************************** * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * * Copyright (c) QuantStack * * * * Distributed under the terms of the BSD 3-Clause License. * * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ #ifndef XTENSOR_XREPEAT #define XTENSOR_XREPEAT #include #include #include "xaccessible.hpp" #include "xexpression.hpp" #include "xiterable.hpp" namespace xt { template class xrepeat; template class xrepeat_stepper; /********************* * xrepeat extension * *********************/ namespace extension { template struct xrepeat_base_impl; template struct xrepeat_base_impl { using type = xtensor_empty_base; }; template struct xrepeat_base : xrepeat_base_impl, CT, X> { }; template using xrepeat_base_t = typename xrepeat_base::type; } /*********** * xrepeat * ***********/ template struct xcontainer_inner_types> { using xexpression_type = std::decay_t; using reference = typename xexpression_type::const_reference; using const_reference = typename xexpression_type::const_reference; using size_type = typename xexpression_type::size_type; using temporary_type = typename xexpression_type::temporary_type; static constexpr bool is_const = std::is_const>::value; using extract_storage_type = xtl::mpl::eval_if_t< has_data_interface, detail::expr_storage_type, make_invalid_type<>>; using storage_type = std::conditional_t; }; template struct xiterable_inner_types> { using xexpression_type = std::decay_t; using repeats_type = std::decay_t; using inner_shape_type = typename xexpression_type::inner_shape_type; using const_stepper = xrepeat_stepper; using stepper = const_stepper; }; /** * @class xrepeat * @brief Expression with repeated values along an axis. * * The xrepeat class implements the repetition of the elements of * an \ref xexpression along a given axis. xrepeat is not meant * to be used directly, but only with the \ref repeat helper * functions. * * @sa repeat */ template class xrepeat : public xconst_iterable>, public xconst_accessible>, public xsharable_expression>, public extension::xrepeat_base_t { public: using self_type = xrepeat; using xexpression_type = std::decay_t; using accessible_base = xconst_accessible; using extension_base = extension::xrepeat_base_t; using expression_tag = typename extension_base::expression_tag; using value_type = typename xexpression_type::value_type; using shape_type = typename xexpression_type::shape_type; using repeats_type = xtl::const_closure_type_t; using container_type = xcontainer_inner_types>; using reference = typename container_type::reference; using const_reference = typename container_type::const_reference; using size_type = typename container_type::size_type; using temporary_type = typename container_type::temporary_type; static constexpr layout_type static_layout = xexpression_type::static_layout; static constexpr bool contiguous_layout = false; using bool_load_type = typename xexpression_type::bool_load_type; using pointer = typename xexpression_type::pointer; using const_pointer = typename xexpression_type::const_pointer; using difference_type = typename xexpression_type::difference_type; using iterable_type = xiterable>; using stepper = typename iterable_type::stepper; using const_stepper = typename iterable_type::const_stepper; template explicit xrepeat(CTA&& e, R&& repeats, size_type axis); using accessible_base::size; const shape_type& shape() const noexcept; layout_type layout() const noexcept; bool is_contiguous() const noexcept; using accessible_base::shape; template const_reference operator()(Args... args) const; template const_reference unchecked(Args... args) const; template const_reference element(It first, It last) const; const xexpression_type& expression() const noexcept; template bool broadcast_shape(S& shape, bool reuse_cache = false) const; template bool has_linear_assign(const S& strides) const noexcept; const_stepper stepper_begin() const; const_stepper stepper_begin(const shape_type& s) const; const_stepper stepper_end(layout_type l) const; const_stepper stepper_end(const shape_type& s, layout_type l) const; private: CT m_e; size_type m_repeating_axis; repeats_type m_repeats; shape_type m_shape; const_reference access() const; template const_reference access(Arg arg, Args... args) const; template const_reference access_impl(stepper&& s, Arg arg, Args... args) const; template const_reference access_impl(stepper&& s) const; }; /******************* * xrepeat_stepper * *******************/ template class xrepeat_stepper { public: using repeats_type = R; using storage_type = typename S::storage_type; using subiterator_type = typename S::subiterator_type; using subiterator_traits = typename S::subiterator_traits; using value_type = typename subiterator_traits::value_type; using reference = typename subiterator_traits::reference; using pointer = typename subiterator_traits::pointer; using difference_type = typename subiterator_traits::difference_type; using size_type = typename storage_type::size_type; using shape_type = typename storage_type::shape_type; using simd_value_type = xt_simd::simd_type; template using simd_return_type = xt_simd::simd_return_type; xrepeat_stepper(S&& s, const shape_type& shape, const repeats_type& repeats, size_type axis); reference operator*() const; void step(size_type dim, size_type n = 1); void step_back(size_type dim, size_type n = 1); void reset(size_type dim); void reset_back(size_type dim); void to_begin(); void to_end(layout_type l); template simd_return_type step_simd(); void step_leading(); template void store_simd(const V& vec); private: S m_substepper; const shape_type& m_shape; std::ptrdiff_t m_repeating_steps; std::vector m_positions; size_type m_subposition; size_type m_repeating_axis; const repeats_type& m_repeats; void make_step(size_type dim, size_type n); void make_step_back(size_type dim, size_type n); std::vector get_next_positions(size_type dim, size_type steps_to_go) const; std::vector get_next_positions_back(size_type dim, size_type steps_to_go) const; }; /************************** * xrepeat implementation * **************************/ /** * Constructs an xrepeat expression repeating the element of the specified * \ref xexpression. * * @param e the input expression * @param repeats The number of repetitions for each elements * @param axis The axis along which to repeat the value */ template template xrepeat::xrepeat(CTA&& e, R&& repeats, size_type axis) : m_e(std::forward(e)) , m_repeating_axis(axis) , m_repeats(std::forward(repeats)) , m_shape(e.shape()) { using shape_value_type = typename shape_type::value_type; m_shape[axis] = static_cast( std::accumulate(m_repeats.begin(), m_repeats.end(), shape_value_type(0)) ); } /** * @name Size and shape */ //@{ /** * Returns the shape of the expression. */ template inline auto xrepeat::shape() const noexcept -> const shape_type& { return m_shape; } /** * Returns the layout_type of the expression. */ template inline auto xrepeat::layout() const noexcept -> layout_type { return m_e.layout(); } template inline bool xrepeat::is_contiguous() const noexcept { return false; } //@} /** * @name Data */ //@{ /** * Returns a constant reference to the element at the specified position in the expression. * @param args a list of indices specifying the position in the function. Indices * must be unsigned integers, the number of indices should be equal or greater than * the number of dimensions of the expression. */ template template inline auto xrepeat::operator()(Args... args) const -> const_reference { return access(args...); } /** * Returns a constant reference to the element at the specified position in the expression. * @param args a list of indices specifying the position in the expression. Indices * must be unsigned integers, the number of indices must be equal to the number of * dimensions of the expression, else the behavior is undefined. * * @warning This method is meant for performance, for expressions with a dynamic * number of dimensions (i.e. not known at compile time). Since it may have * undefined behavior (see parameters), operator() should be preferred whenever * it is possible. * @warning This method is NOT compatible with broadcasting, meaning the following * code has undefined behavior: * @code{.cpp} * xt::xarray a = {{0, 1}, {2, 3}}; * xt::xarray b = {0, 1}; * auto fd = a + b; * double res = fd.uncheked(0, 1); * @endcode */ template template inline auto xrepeat::unchecked(Args... args) const -> const_reference { return this->operator()(args...); } /** * Returns a constant reference to the element at the specified position in the view. * @param first iterator starting the sequence of indices * @param last iterator ending the sequence of indices * The number of indices in the sequence should be equal to or greater than the the number * of dimensions of the view.. */ template template inline auto xrepeat::element(It first, It last) const -> const_reference { auto s = stepper_begin(m_e.shape()); std::size_t dimension = 0; auto iter = first; while (iter != last) { s.step(dimension, *iter); ++dimension; ++iter; } return access_impl<0>(std::forward(s)); } /** * Returns a constant reference to the underlying expression of the broadcast expression. */ template inline auto xrepeat::expression() const noexcept -> const xexpression_type& { return m_e; } //@} /** * @name Broadcasting */ //@{ /** * Broadcast the shape of the function to the specified parameter. * @param shape the result shape * @param reuse_cache parameter for internal optimization * @return a boolean indicating whether the broadcasting is trivial */ template template inline bool xrepeat::broadcast_shape(S& shape, bool) const { return xt::broadcast_shape(m_shape, shape); } /** * Checks whether the xbroadcast can be linearly assigned to an expression * with the specified strides. * @return a boolean indicating whether a linear assign is possible */ template template inline bool xrepeat::has_linear_assign(const S&) const noexcept { return false; } //@} template inline auto xrepeat::access() const -> const_reference { return access_impl<0>(stepper_begin(m_e.shape())); } template template inline auto xrepeat::access(Arg arg, Args... args) const -> const_reference { constexpr size_t number_of_arguments = 1 + sizeof...(Args); if (number_of_arguments > this->dimension()) { return access(args...); } return access_impl<0>(stepper_begin(m_e.shape()), arg, args...); } template inline auto xrepeat::stepper_begin() const -> const_stepper { return stepper_begin(m_e.shape()); } template inline auto xrepeat::stepper_begin(const shape_type& s) const -> const_stepper { return const_stepper(m_e.stepper_begin(s), m_shape, m_repeats, m_repeating_axis); } template inline auto xrepeat::stepper_end(layout_type l) const -> const_stepper { return stepper_end(m_e.shape(), l); } template inline auto xrepeat::stepper_end(const shape_type& s, layout_type l) const -> const_stepper { auto st = const_stepper(m_e.stepper_begin(s), m_shape, m_repeats, m_repeating_axis); st.to_end(l); return st; } template template inline auto xrepeat::access_impl(stepper&& s, Arg arg, Args... args) const -> const_reference { s.step(I, static_cast(arg)); return access_impl(std::forward(s), args...); } template template inline auto xrepeat::access_impl(stepper&& s) const -> const_reference { return *s; } /********************************** * xrepeat_stepper implementation * **********************************/ template xrepeat_stepper::xrepeat_stepper(S&& s, const shape_type& shape, const repeats_type& repeats, size_type axis) : m_substepper(std::forward(s)) , m_shape(shape) , m_repeating_steps(0) , m_positions(shape.size()) , m_subposition(0) , m_repeating_axis(axis) , m_repeats(repeats) { } template inline auto xrepeat_stepper::operator*() const -> reference { return m_substepper.operator*(); } template inline void xrepeat_stepper::step(size_type dim, size_type steps_to_go) { if (m_positions[dim] + steps_to_go >= m_shape[dim]) { const auto next_positions = get_next_positions(dim, steps_to_go); if (next_positions[dim] > m_positions[dim]) { make_step(dim, next_positions[dim] - m_positions[dim]); } else { make_step_back(dim, m_positions[dim] - next_positions[dim]); } for (size_type d = 0; d < dim; ++d) { make_step(d, next_positions[d] - m_positions[d]); } } else { make_step(dim, steps_to_go); } } template inline void xrepeat_stepper::step_back(size_type dim, size_type steps_to_go) { if (m_positions[dim] < steps_to_go) { const auto next_positions = get_next_positions_back(dim, steps_to_go); if (next_positions[dim] < m_positions[dim]) { make_step_back(dim, m_positions[dim] - next_positions[dim]); } else { make_step(dim, next_positions[dim] - m_positions[dim]); } for (size_type d = 0; d < dim; ++d) { make_step_back(d, m_positions[d] - next_positions[d]); } } else { make_step_back(dim, steps_to_go); } } template inline void xrepeat_stepper::reset(size_type dim) { m_substepper.reset(dim); m_positions[dim] = 0; if (dim == m_repeating_axis) { m_subposition = 0; m_repeating_steps = 0; } } template inline void xrepeat_stepper::reset_back(size_type dim) { m_substepper.reset_back(dim); m_positions[dim] = m_shape[dim] - 1; if (dim == m_repeating_axis) { m_subposition = m_repeats.size() - 1; m_repeating_steps = static_cast(m_repeats.back()) - 1; } } template inline void xrepeat_stepper::to_begin() { m_substepper.to_begin(); std::fill(m_positions.begin(), m_positions.end(), 0); m_subposition = 0; m_repeating_steps = 0; } template inline void xrepeat_stepper::to_end(layout_type l) { m_substepper.to_end(l); std::transform( m_shape.begin(), m_shape.end(), m_positions.begin(), [](auto value) { return value - 1; } ); if (layout_type::row_major == l) { ++m_positions.front(); } else { ++m_positions.back(); } m_subposition = m_repeats.size(); m_repeating_steps = 0; } template inline void xrepeat_stepper::step_leading() { step(m_shape.size() - 1, 1); } template inline void xrepeat_stepper::make_step(size_type dim, size_type steps_to_go) { if (steps_to_go > 0) { if (dim == m_repeating_axis) { size_type subposition = m_subposition; m_repeating_steps += static_cast(steps_to_go); while (m_repeating_steps >= static_cast(m_repeats[subposition])) { m_repeating_steps -= static_cast(m_repeats[subposition]); ++subposition; } m_substepper.step(dim, subposition - m_subposition); m_subposition = subposition; } else { m_substepper.step(dim, steps_to_go); } m_positions[dim] += steps_to_go; } } template inline void xrepeat_stepper::make_step_back(size_type dim, size_type steps_to_go) { if (steps_to_go > 0) { if (dim == m_repeating_axis) { size_type subposition = m_subposition; m_repeating_steps -= static_cast(steps_to_go); while (m_repeating_steps < 0) { --subposition; m_repeating_steps += static_cast(m_repeats[subposition]); } m_substepper.step_back(dim, m_subposition - subposition); m_subposition = subposition; } else { m_substepper.step_back(dim, steps_to_go); } m_positions[dim] -= steps_to_go; } } template inline auto xrepeat_stepper::get_next_positions(size_type dim, size_type steps_to_go) const -> std::vector { size_type next_position_for_dim = m_positions[dim] + steps_to_go; if (dim > 0) { size_type steps_in_previous_dim = 0; while (next_position_for_dim >= m_shape[dim]) { next_position_for_dim -= m_shape[dim]; ++steps_in_previous_dim; } if (steps_in_previous_dim > 0) { auto next_positions = get_next_positions(dim - 1, steps_in_previous_dim); next_positions[dim] = next_position_for_dim; return next_positions; } } std::vector next_positions = m_positions; next_positions[dim] = next_position_for_dim; return next_positions; } template inline auto xrepeat_stepper::get_next_positions_back(size_type dim, size_type steps_to_go) const -> std::vector { auto next_position_for_dim = static_cast(m_positions[dim] - steps_to_go); if (dim > 0) { size_type steps_in_previous_dim = 0; while (next_position_for_dim < 0) { next_position_for_dim += static_cast(m_shape[dim]); ++steps_in_previous_dim; } if (steps_in_previous_dim > 0) { auto next_positions = get_next_positions_back(dim - 1, steps_in_previous_dim); next_positions[dim] = static_cast(next_position_for_dim); return next_positions; } } std::vector next_positions = m_positions; next_positions[dim] = static_cast(next_position_for_dim); return next_positions; } } #endif