/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

// NOTE(review): this copy of the header appears to have lost all text that was
// enclosed in angle brackets: the first three #include directives have no
// target, every `template` keyword is missing its parameter list, and many
// type arguments are gone (e.g. `std::decay_t;`). It will not compile as-is;
// restore the missing text from the upstream xtensor sources.

#ifndef XTENSOR_STRIDED_VIEW_BASE_HPP
#define XTENSOR_STRIDED_VIEW_BASE_HPP

// NOTE(review): include targets missing (see note above).
#include
#include
#include

#include "xaccessible.hpp"
#include "xslice.hpp"
#include "xstrides.hpp"
#include "xtensor_config.hpp"
#include "xtensor_forward.hpp"
#include "xutils.hpp"

namespace xt
{
    namespace detail
    {
        /**
         * Adapts an expression so that it can be used where flat (1-D indexable)
         * storage is expected. operator[] converts a flat index into a
         * multi-dimensional index using m_strides and the layout L, then forwards
         * to the expression's element(); the iterator functions forward to the
         * expression's begin/end family.
         */
        template
        class flat_expression_adaptor
        {
        public:

            using xexpression_type = std::decay_t;
            using shape_type = typename xexpression_type::shape_type;
            using inner_strides_type = get_strides_t;
            // The multi-dimensional index has the same container type as the strides.
            using index_type = inner_strides_type;
            using size_type = typename xexpression_type::size_type;
            using value_type = typename xexpression_type::value_type;
            using const_reference = typename xexpression_type::const_reference;
            // Mutable access degrades to const_reference when the adapted expression is const.
            using reference = std::conditional_t< std::is_const>::value,
                                                 typename xexpression_type::const_reference,
                                                 typename xexpression_type::reference>;

            using iterator = decltype(std::declval>().template begin());
            using const_iterator = decltype(std::declval>().template cbegin());
            using reverse_iterator = decltype(std::declval>().template rbegin());
            using const_reverse_iterator = decltype(std::declval>().template crbegin());

            // Computes strides for layout L and the size from the expression's shape.
            explicit flat_expression_adaptor(CT* e);

            // Uses caller-supplied strides; the size is taken from e->size().
            template
            flat_expression_adaptor(CT* e, FST&& strides);

            // Re-targets the adaptor at another expression (used after copying/moving a view).
            void update_pointer(CT* ptr) const;

            size_type size() const;
            reference operator[](size_type idx);
            const_reference operator[](size_type idx) const;

            iterator begin();
            iterator end();
            const_iterator begin() const;
            const_iterator end() const;
            const_iterator cbegin() const;
            const_iterator cend() const;

        private:

            // Scratch buffer that receives the unravelled index in operator[].
            static index_type& get_index();

            // mutable so that update_pointer() can be a const member function.
            mutable CT* m_e;
            inner_strides_type m_strides;
            size_type m_size;
        };

        // Trait: false by default, true for the flat_expression_adaptor specialization below.
        template
        struct is_flat_expression_adaptor : std::false_type
        {
        };

        template
        struct is_flat_expression_adaptor> : std::true_type
        {
        };

        // Conjunction of a condition and a negated condition.
        // NOTE(review): template arguments are stripped here; presumably "has a data
        // interface and is not a flat_expression_adaptor" - confirm upstream.
        template
        struct provides_data_interface
            : xtl::conjunction>, xtl::negation>>
        {
        };
    }

    /**
     * Common base for strided views. It stores the underlying expression (m_e),
     * a flat storage handle (m_storage) and the view's shape, strides,
     * backstrides, offset and layout. Element access computes
     * m_offset + dot(strides, indices) and indexes m_storage with it.
     */
    template
    class xstrided_view_base : public xaccessible
    {
    public:

        using base_type = xaccessible;
        using inner_types = xcontainer_inner_types;
        using xexpression_type = typename inner_types::xexpression_type;
        using undecay_expression = typename inner_types::undecay_expression;
        static constexpr bool is_const = std::is_const>::value;

        using value_type = typename xexpression_type::value_type;
        using reference = typename inner_types::reference;
        using const_reference = typename inner_types::const_reference;
        using pointer = std::
            conditional_t;
        using const_pointer = typename xexpression_type::const_pointer;
        using size_type = typename inner_types::size_type;
        using difference_type = typename xexpression_type::difference_type;

        using storage_getter = typename inner_types::storage_getter;
        using inner_storage_type = typename inner_types::inner_storage_type;
        using storage_type = std::remove_reference_t;

        using shape_type = typename inner_types::shape_type;
        using strides_type = get_strides_t;
        using backstrides_type = strides_type;

        using inner_shape_type = shape_type;
        using inner_strides_type = strides_type;
        using inner_backstrides_type = backstrides_type;

        using undecay_shape = typename inner_types::undecay_shape;

        using simd_value_type = xt_simd::simd_type;
        using bool_load_type = typename xexpression_type::bool_load_type;

        static constexpr layout_type static_layout = inner_types::layout;
        // Statically contiguous only when the layout is fixed at compile time and
        // the underlying expression is itself contiguous.
        static constexpr bool contiguous_layout = static_layout != layout_type::dynamic
                                                  && xexpression_type::contiguous_layout;

        template
        xstrided_view_base(CTA&& e, SA&& shape, strides_type&& strides, size_type offset, layout_type layout) noexcept;

        xstrided_view_base(xstrided_view_base&& rhs);

        xstrided_view_base(const xstrided_view_base& rhs);

        const inner_shape_type& shape() const noexcept;
        const inner_strides_type& strides() const noexcept;
        const inner_backstrides_type& backstrides() const noexcept;
        layout_type layout() const noexcept;
        bool is_contiguous() const noexcept;
        // Re-expose base_type's shape overloads, which the shape() declaration
        // above would otherwise hide.
        using base_type::shape;

        reference operator()();
        const_reference operator()() const;

        template
        reference operator()(Args... args);

        template
        const_reference operator()(Args... args) const;

        template
        reference unchecked(Args... args);

        template
        const_reference unchecked(Args... args) const;

        template
        reference element(It first, It last);

        template
        const_reference element(It first, It last) const;

        storage_type& storage() noexcept;
        const storage_type& storage() const noexcept;

        template
        std::enable_if_t::value, pointer> data() noexcept;
        template
        std::enable_if_t::value, const_pointer> data() const noexcept;
        size_type data_offset() const noexcept;

        xexpression_type& expression() noexcept;
        const xexpression_type& expression() const noexcept;

        template
        bool broadcast_shape(O& shape, bool reuse_cache = false) const;

        template
        bool has_linear_assign(const O& strides) const noexcept;

    protected:

        using offset_type = typename strides_type::value_type;

        template
        offset_type compute_index(Args... args) const;

        template
        offset_type compute_unchecked_index(Args... args) const;

        template
        offset_type compute_element_index(It first, It last) const;

        void set_offset(size_type offset);

    private:

        // Declaration order matters: m_storage is initialized from m_e in every
        // constructor, so m_e must be declared (and thus constructed) first.
        undecay_expression m_e;
        inner_storage_type m_storage;
        inner_shape_type m_shape;
        inner_strides_type m_strides;
        inner_backstrides_type m_backstrides;
        size_type m_offset;
        layout_type m_layout;
    };

    /***************************
     * flat_expression_adaptor *
     ***************************/

    namespace detail
    {
        // Storage getter for expressions that expose their own storage(): the
        // view uses that storage directly, with the expression's offset and strides.
        template
        struct inner_storage_getter
        {
            using type = decltype(std::declval().storage());
            using reference = std::add_lvalue_reference_t;

            template
            using rebind_t = inner_storage_getter;

            static decltype(auto) get_flat_storage(reference e)
            {
                return e.storage();
            }

            static auto get_offset(reference e)
            {
                return e.data_offset();
            }

            static decltype(auto) get_strides(reference e)
            {
                return e.strides();
            }
        };

        // Storage getter for other expressions: wraps the expression in a
        // flat_expression_adaptor, uses offset 0, and computes strides for layout
        // L from the expression's shape.
        template
        struct flat_adaptor_getter
        {
            using type = flat_expression_adaptor, L>;
            using reference = std::add_lvalue_reference_t;

            template
            using rebind_t = flat_adaptor_getter;

            static type get_flat_storage(reference e)
            {
                // moved to addressof because ampersand on xview returns a closure pointer
                return type(std::addressof(e));
            }

            static auto get_offset(reference)
            {
                return typename std::decay_t::size_type(0);
            }

            static auto get_strides(reference e)
            {
                dynamic_shape strides;
                strides.resize(e.shape().size());
                compute_strides(e.shape(), L, strides);
                return strides;
            }
        };

        // Chooses inner_storage_getter when the expression has a data interface,
        // flat_adaptor_getter otherwise.
        template
        using flat_storage_getter = std::conditional_t<
            has_data_interface>::value,
            inner_storage_getter,
            flat_adaptor_getter>;

        template
        inline auto get_offset(E& e)
        {
            return flat_storage_getter::get_offset(e);
        }

        template
        inline decltype(auto) get_strides(E& e)
        {
            return flat_storage_getter::get_strides(e);
        }
    }

    /*************************************
     * xstrided_view_base implementation *
     *************************************/

    /**
     * @name Constructor
     */
    //@{
    /**
     * Constructs an xstrided_view_base. The backstrides are derived from the
     * given shape and strides.
     *
     * @param e the underlying xexpression for this view
     * @param shape the shape of the view
     * @param strides the strides of the view
     * @param offset the offset of the first element in the underlying container
     * @param layout the layout of the view
     */
    // NOTE(review): declared noexcept although the body builds a sequence with
    // xtl::make_sequence, which may allocate; an allocation failure here would
    // call std::terminate. Confirm this is intended.
    template
    template
    inline xstrided_view_base::xstrided_view_base(
        CTA&& e,
        SA&& shape,
        strides_type&& strides,
        size_type offset,
        layout_type layout
    ) noexcept
        : m_e(std::forward(e))
        ,
        // NOTE(review): dead code left over from an earlier implementation;
        // the storage is now obtained through storage_getter.
        // m_storage(detail::get_flat_storage(m_e)),
        m_storage(storage_getter::get_flat_storage(m_e))
        , m_shape(std::forward(shape))
        , m_strides(std::move(strides))
        , m_offset(offset)
        , m_layout(layout)
    {
        m_backstrides = xtl::make_sequence(m_shape.size(), 0);
        adapt_strides(m_shape, m_strides, m_backstrides);
    }

    namespace detail
    {
        // Generic case: the new view's storage is simply the storage of its own
        // (already copied/moved) expression; the source storage is ignored.
        template
        auto& copy_move_storage(T& expr, const S& /*storage*/)
        {
            return expr.storage();
        }

        // Flat-adaptor case: copy the adaptor, then re-point it at the new
        // expression so it does not keep referring to the source view's expression.
        template
        auto copy_move_storage(T& expr, const detail::flat_expression_adaptor& storage)
        {
            detail::flat_expression_adaptor new_storage = storage;  // copy storage
            new_storage.update_pointer(std::addressof(expr));
            return new_storage;
        }
    }

    // Move constructor: m_storage is rebuilt from this object's m_e (initialized
    // just before it) rather than moved from rhs.
    template
    inline xstrided_view_base::xstrided_view_base(xstrided_view_base&& rhs)
        : base_type(std::move(rhs))
        , m_e(std::forward(rhs.m_e))
        , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
        , m_shape(std::move(rhs.m_shape))
        , m_strides(std::move(rhs.m_strides))
        , m_backstrides(std::move(rhs.m_backstrides))
        , m_offset(std::move(rhs.m_offset))
        , m_layout(std::move(rhs.m_layout))
    {
    }

    // Copy constructor: same storage rebinding as the move constructor.
    template
    inline xstrided_view_base::xstrided_view_base(const xstrided_view_base& rhs)
        : base_type(rhs)
        , m_e(rhs.m_e)
        , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
        , m_shape(rhs.m_shape)
        , m_strides(rhs.m_strides)
        , m_backstrides(rhs.m_backstrides)
        , m_offset(rhs.m_offset)
        , m_layout(rhs.m_layout)
    {
    }

    //@}

    /**
     * @name Size and shape
     */
    //@{
    /**
     * Returns the shape of the xstrided_view_base.
     */
    template
    inline auto xstrided_view_base::shape() const noexcept -> const inner_shape_type&
    {
        return m_shape;
    }

    /**
     * Returns the strides of the xstrided_view_base.
     */
    template
    inline auto xstrided_view_base::strides() const noexcept -> const inner_strides_type&
    {
        return m_strides;
    }

    /**
     * Returns the backstrides of the xstrided_view_base.
     */
    template
    inline auto xstrided_view_base::backstrides() const noexcept -> const inner_backstrides_type&
    {
        return m_backstrides;
    }

    /**
     * Returns the layout of the xstrided_view_base.
     */
    template
    inline auto xstrided_view_base::layout() const noexcept -> layout_type
    {
        return m_layout;
    }

    /**
     * Returns true when the view has a non-dynamic layout and the underlying
     * expression reports itself as contiguous.
     */
    template
    inline bool xstrided_view_base::is_contiguous() const noexcept
    {
        return m_layout != layout_type::dynamic && m_e.is_contiguous();
    }

    //@}

    /**
     * @name Data
     */
    //@{

    /**
     * Returns a reference to the element at the view's offset (zero indices).
     */
    template
    inline auto xstrided_view_base::operator()() -> reference
    {
        return m_storage[static_cast(m_offset)];
    }

    /**
     * Returns a constant reference to the element at the view's offset (zero indices).
     */
    template
    inline auto xstrided_view_base::operator()() const -> const_reference
    {
        return m_storage[static_cast(m_offset)];
    }

    /**
     * Returns a reference to the element at the specified position in the view.
     * @param args a list of indices specifying the position in the view. Indices
     * must be unsigned integers, the number of indices should be equal or greater than
     * the number of dimensions of the view.
     */
    template
    template
    inline auto xstrided_view_base::operator()(Args... args) -> reference
    {
        // Bounds check (compiled in only with XTENSOR_TRY enabled), then dimension check.
        XTENSOR_TRY(check_index(shape(), args...));
        XTENSOR_CHECK_DIMENSION(shape(), args...);
        offset_type index = compute_index(args...);
        return m_storage[static_cast(index)];
    }

    /**
     * Returns a constant reference to the element at the specified position in the view.
     * @param args a list of indices specifying the position in the view. Indices
     * must be unsigned integers, the number of indices should be equal or greater than
     * the number of dimensions of the view.
     */
    template
    template
    inline auto xstrided_view_base::operator()(Args... args) const -> const_reference
    {
        XTENSOR_TRY(check_index(shape(), args...));
        XTENSOR_CHECK_DIMENSION(shape(), args...);
        offset_type index = compute_index(args...);
        return m_storage[static_cast(index)];
    }

    /**
     * Returns a reference to the element at the specified position in the view.
     * @param args a list of indices specifying the position in the view. Indices
     * must be unsigned integers, the number of indices must be equal to the number of
     * dimensions of the view, else the behavior is undefined.
     *
     * @warning This method is meant for performance, for expressions with a dynamic
     * number of dimensions (i.e. not known at compile time). Since it may have
     * undefined behavior (see parameters), operator() should be preferred whenever
     * it is possible.
     * @warning This method is NOT compatible with broadcasting, meaning the following
     * code has undefined behavior:
     * @code{.cpp}
     * xt::xarray<double> a = {{0, 1}, {2, 3}};
     * xt::xarray<double> b = {0, 1};
     * auto fd = a + b;
     * double res = fd.unchecked(0, 1);
     * @endcode
     */
    template
    template
    inline auto xstrided_view_base::unchecked(Args... args) -> reference
    {
        // No bounds or dimension checks on this path.
        offset_type index = compute_unchecked_index(args...);
        return m_storage[static_cast(index)];
    }

    /**
     * Returns a constant reference to the element at the specified position in the view.
     * @param args a list of indices specifying the position in the view. Indices
     * must be unsigned integers, the number of indices must be equal to the number of
     * dimensions of the view, else the behavior is undefined.
     *
     * @warning This method is meant for performance, for expressions with a dynamic
     * number of dimensions (i.e. not known at compile time). Since it may have
     * undefined behavior (see parameters), operator() should be preferred whenever
     * it is possible.
     * @warning This method is NOT compatible with broadcasting, meaning the following
     * code has undefined behavior:
     * @code{.cpp}
     * xt::xarray<double> a = {{0, 1}, {2, 3}};
     * xt::xarray<double> b = {0, 1};
     * auto fd = a + b;
     * double res = fd.unchecked(0, 1);
     * @endcode
     */
    template
    template
    inline auto xstrided_view_base::unchecked(Args... args) const -> const_reference
    {
        offset_type index = compute_unchecked_index(args...);
        return m_storage[static_cast(index)];
    }

    /**
     * Returns a reference to the element at the specified position in the view.
     * @param first iterator starting the sequence of indices
     * @param last iterator ending the sequence of indices
     * The number of indices in the sequence should be equal to or greater than the number
     * of dimensions of the view.
     */
    template
    template
    inline auto xstrided_view_base::element(It first, It last) -> reference
    {
        XTENSOR_TRY(check_element_index(shape(), first, last));
        return m_storage[static_cast(compute_element_index(first, last))];
    }

    /**
     * Returns a constant reference to the element at the specified position in the view.
     * @param first iterator starting the sequence of indices
     * @param last iterator ending the sequence of indices
     * The number of indices in the sequence should be equal to or greater than the number
     * of dimensions of the view.
     */
    template
    template
    inline auto xstrided_view_base::element(It first, It last) const -> const_reference
    {
        XTENSOR_TRY(check_element_index(shape(), first, last));
        return m_storage[static_cast(compute_element_index(first, last))];
    }

    /**
     * Returns a reference to the buffer containing the elements of the view.
     */
    template
    inline auto xstrided_view_base::storage() noexcept -> storage_type&
    {
        return m_storage;
    }

    /**
     * Returns a constant reference to the buffer containing the elements of the view.
     */
    template
    inline auto xstrided_view_base::storage() const noexcept -> const storage_type&
    {
        return m_storage;
    }

    /**
     * Returns a pointer to the underlying array serving as element storage.
     * The first element of the view is at data() + data_offset().
     */
    template
    template
    inline auto xstrided_view_base::data() noexcept -> std::enable_if_t::value, pointer>
    {
        return m_e.data();
    }

    /**
     * Returns a constant pointer to the underlying array serving as element storage.
     * The first element of the view is at data() + data_offset().
     */
    template
    template
    inline auto xstrided_view_base::data() const noexcept
        -> std::enable_if_t::value, const_pointer>
    {
        return m_e.data();
    }

    /**
     * Returns the offset to the first element in the view.
     */
    template
    inline auto xstrided_view_base::data_offset() const noexcept -> size_type
    {
        return m_offset;
    }

    /**
     * Returns a reference to the underlying expression of the view.
     */
    template
    inline auto xstrided_view_base::expression() noexcept -> xexpression_type&
    {
        return m_e;
    }

    /**
     * Returns a constant reference to the underlying expression of the view.
     */
    template
    inline auto xstrided_view_base::expression() const noexcept -> const xexpression_type&
    {
        return m_e;
    }

    //@}

    /**
     * @name Broadcasting
     */
    //@{
    /**
     * Broadcast the shape of the view to the specified parameter.
     * @param shape the result shape
     * @param reuse_cache parameter for internal optimization (ignored by this implementation)
     * @return a boolean indicating whether the broadcasting is trivial
     */
    template
    template
    inline bool xstrided_view_base::broadcast_shape(O& shape, bool) const
    {
        return xt::broadcast_shape(m_shape, shape);
    }

    /**
     * Checks whether the xstrided_view_base can be linearly assigned to an expression
     * with the specified strides: the underlying expression must have a data
     * interface, and @p str must have the same length as, and be element-wise
     * equal to, the view's strides.
     * @return a boolean indicating whether a linear assign is possible
     */
    template
    template
    inline bool xstrided_view_base::has_linear_assign(const O& str) const noexcept
    {
        return has_data_interface::value && str.size() == strides().size()
               && std::equal(str.cbegin(), str.cend(), strides().begin());
    }

    //@}

    // Flat storage index of the element at (args...): view offset plus the
    // strided offset computed by xt::data_offset.
    template
    template
    inline auto xstrided_view_base::compute_index(Args... args) const -> offset_type
    {
        return static_cast(m_offset) + xt::data_offset(strides(), static_cast(args)...);
    }

    // Same as compute_index, but uses xt::unchecked_data_offset.
    template
    template
    inline auto xstrided_view_base::compute_unchecked_index(Args... args) const -> offset_type
    {
        return static_cast(m_offset)
               + xt::unchecked_data_offset(strides(), static_cast(args)...);
    }

    // Same as compute_index, for an index given as an iterator range.
    template
    template
    inline auto xstrided_view_base::compute_element_index(It first, It last) const -> offset_type
    {
        return static_cast(m_offset) + xt::element_offset(strides(), first, last);
    }

    // Replaces the view's offset into the underlying storage.
    template
    void xstrided_view_base::set_offset(size_type offset)
    {
        m_offset = offset;
    }

    /******************************************
     * flat_expression_adaptor implementation *
     ******************************************/

    namespace detail
    {
        // Sizes the scratch index and the strides to the expression's dimension,
        // then derives the size from the shape and the strides from (shape, L).
        template
        inline flat_expression_adaptor::flat_expression_adaptor(CT* e)
            : m_e(e)
        {
            resize_container(get_index(), m_e->dimension());
            resize_container(m_strides, m_e->dimension());
            m_size = compute_size(m_e->shape());
            compute_strides(m_e->shape(), L, m_strides);
        }

        // Uses the supplied strides verbatim; the size comes from e->size().
        template
        template
        inline flat_expression_adaptor::flat_expression_adaptor(CT* e, FST&& strides)
            : m_e(e)
            , m_strides(xtl::forward_sequence(strides))
        {
            resize_container(get_index(), m_e->dimension());
            m_size = m_e->size();
        }

        template
        inline void flat_expression_adaptor::update_pointer(CT* ptr) const
        {
            m_e = ptr;
        }

        template
        inline auto flat_expression_adaptor::size() const -> size_type
        {
            return m_size;
        }

        // Unravels the flat index into the scratch multi-index (using m_strides
        // and layout L), then fetches that element from the expression.
        template
        inline auto flat_expression_adaptor::operator[](size_type idx) -> reference
        {
            auto i = static_cast(idx);
            get_index() = detail::unravel_noexcept(i, m_strides, L);
            return m_e->element(get_index().cbegin(), get_index().cend());
        }

        template
        inline auto flat_expression_adaptor::operator[](size_type idx) const -> const_reference
        {
            auto i = static_cast(idx);
            get_index() = detail::unravel_noexcept(i, m_strides, L);
            return m_e->element(get_index().cbegin(), get_index().cend());
        }

        // Iteration forwards to the expression's own iterators for layout L.
        template
        inline auto flat_expression_adaptor::begin() -> iterator
        {
            return m_e->template begin();
        }
        template
        inline auto flat_expression_adaptor::end() -> iterator
        {
            return m_e->template end();
        }

        template
        inline auto flat_expression_adaptor::begin() const -> const_iterator
        {
            return m_e->template cbegin();
        }

        template
        inline auto flat_expression_adaptor::end() const -> const_iterator
        {
            return m_e->template cend();
        }

        template
        inline auto flat_expression_adaptor::cbegin() const -> const_iterator
        {
            return m_e->template cbegin();
        }

        template
        inline auto flat_expression_adaptor::cend() const -> const_iterator
        {
            return m_e->template cend();
        }

        // Scratch multi-index used by operator[]. It is a function-local
        // thread_local static, so there is one buffer per thread per template
        // instantiation, shared by every adaptor of that instantiation.
        template
        inline auto flat_expression_adaptor::get_index() -> index_type&
        {
            thread_local static index_type index;
            return index;
        }
    }

    /**********************************
     * Builder helpers implementation *
     **********************************/

    namespace detail
    {
        /**
         * Visitor that turns a slice into a {start, size, step} triple.
         * The generic overload yields {0, 0, 0}; range overloads query the range.
         * `idx` is mutable so the caller can select the axis (whose extent is
         * read from m_shape) before each visit; fill_args sets it to the
         * current input axis.
         */
        template
        struct slice_getter_impl
        {
            const S& m_shape;
            mutable std::size_t idx;
            using array_type = std::array;

            explicit slice_getter_impl(const S& shape)
                : m_shape(shape)
                , idx(0)
            {
            }

            template
            array_type operator()(const T& /*t*/) const
            {
                return array_type{{0, 0, 0}};
            }

            // Range adaptors are first resolved against the extent of axis idx.
            template
            array_type operator()(const xrange_adaptor& range) const
            {
                auto sl = range.get(static_cast(m_shape[idx]));
                return array_type({sl(0), sl.size(), sl.step_size()});
            }

            // Unit-step range.
            template
            array_type operator()(const xrange& range) const
            {
                return array_type({range(T(0)), range.size(), T(1)});
            }

            template
            array_type operator()(const xstepped_range& range) const
            {
                return array_type({range(T(0)), range.size(), range.step_size(T(0))});
            }
        };

        /**
         * Computes the shape, strides, offset and layout of a strided view from
         * the source shape/strides and a list of slices. Results are stored in
         * the public members new_shape, new_strides, new_offset and new_layout.
         */
        template
        struct strided_view_args : adj_strides_policy
        {
            using base_type = adj_strides_policy;

            /**
             * @param shape       shape of the source expression
             * @param old_strides strides of the source expression
             * @param base_offset offset of the source's first element; the result offset starts here
             * @param layout      source layout; kept only if the new strides still match it
             * @param slices      slice list to apply
             * @throws std::runtime_error if more than one ellipsis appears, or if
             *         the slices address more axes than @p shape has
             */
            template
            void fill_args(const S& shape, ST&& old_strides, std::size_t base_offset, layout_type layout, const V& slices)
            {
                // Compute dimension
                // First pass: count the output rank. One alternative adds an axis
                // (and counts toward n_newaxis), one removes an axis, the ellipsis
                // is recorded, and every other slice consumes one input axis
                // (tracked by dimension_check).
                std::size_t dimension = shape.size(), n_newaxis = 0, n_add_all = 0;
                std::ptrdiff_t dimension_check = static_cast(shape.size());

                bool has_ellipsis = false;
                for (const auto& el : slices)
                {
                    if (xtl::get_if(&el) != nullptr)
                    {
                        ++dimension;
                        ++n_newaxis;
                    }
                    else if (xtl::get_if(&el) != nullptr)
                    {
                        --dimension;
                        --dimension_check;
                    }
                    else if (xtl::get_if(&el) != nullptr)
                    {
                        if (has_ellipsis == true)
                        {
                            XTENSOR_THROW(std::runtime_error, "Ellipsis can only appear once.");
                        }
                        has_ellipsis = true;
                    }
                    else
                    {
                        --dimension_check;
                    }
                }

                if (dimension_check < 0)
                {
                    XTENSOR_THROW(std::runtime_error, "Too many slices for view.");
                }

                if (has_ellipsis)
                {
                    // replace ellipsis with N * xt::all
                    // subtract 1 for the ellipsis slice itself
                    n_add_all = shape.size() - (slices.size() - 1 - n_newaxis);
                }

                // Compute strided view
                // Second pass: walk the slices, mapping slice i to input axis
                // i_ax = i - axis_skip and writing output axis idx.
                new_offset = base_offset;
                new_shape.resize(dimension);
                new_strides.resize(dimension);
                base_type::resize(dimension);

                auto old_shape = shape;
                using old_strides_value_type = typename std::decay_t::value_type;

                std::ptrdiff_t axis_skip = 0;
                std::size_t idx = 0, i = 0, i_ax = 0;

                auto slice_getter = detail::slice_getter_impl(shape);

                for (; i < slices.size(); ++i)
                {
                    i_ax = static_cast(static_cast(i) - axis_skip);
                    auto ptr = xtl::get_if(&slices[i]);
                    if (ptr != nullptr)
                    {
                        // Single index: drops the axis and only shifts the offset.
                        // NOTE(review): the index is not bounds-checked or
                        // normalized (e.g. negative values) here.
                        auto slice0 = static_cast(*ptr);
                        new_offset += static_cast(slice0 * old_strides[i_ax]);
                    }
                    else if (xtl::get_if(&slices[i]) != nullptr)
                    {
                        // New axis of extent 1 that consumes no input axis.
                        new_shape[idx] = 1;
                        base_type::set_fake_slice(idx);
                        ++axis_skip, ++idx;
                    }
                    else if (xtl::get_if(&slices[i]) != nullptr)
                    {
                        // Ellipsis: copy n_add_all input axes unchanged.
                        for (std::size_t j = 0; j < n_add_all; ++j)
                        {
                            new_shape[idx] = old_shape[i_ax];
                            new_strides[idx] = old_strides[i_ax];
                            base_type::set_fake_slice(idx);
                            ++idx, ++i_ax;
                        }
                        // The ellipsis consumed n_add_all input axes for one slice entry.
                        axis_skip = axis_skip - static_cast(n_add_all) + 1;
                    }
                    else if (xtl::get_if(&slices[i]) != nullptr)
                    {
                        // Whole axis kept unchanged.
                        new_shape[idx] = old_shape[i_ax];
                        new_strides[idx] = old_strides[i_ax];
                        base_type::set_fake_slice(idx);
                        ++idx;
                    }
                    else if (base_type::fill_args(slices, i, idx, old_shape[i_ax], old_strides[i_ax], new_shape, new_strides))
                    {
                        // The policy base handled this slice itself.
                        ++idx;
                    }
                    else
                    {
                        // Range-like slice: offset += start * stride,
                        // extent = size, stride = step * stride.
                        slice_getter.idx = i_ax;
                        auto info = xtl::visit(slice_getter, slices[i]);
                        new_offset += static_cast(info[0] * old_strides[i_ax]);
                        new_shape[idx] = static_cast(info[1]);
                        new_strides[idx] = info[2] * old_strides[i_ax];
                        base_type::set_fake_slice(idx);
                        ++idx;
                    }
                }

                // Input axes not addressed by any slice are appended unchanged.
                i_ax = static_cast(static_cast(i) - axis_skip);
                for (; i_ax < old_shape.size(); ++i_ax, ++idx)
                {
                    new_shape[idx] = old_shape[i_ax];
                    new_strides[idx] = old_strides[i_ax];
                    base_type::set_fake_slice(idx);
                }

                // Keep the source layout only if the resulting strides still match it.
                new_layout = do_strides_match(new_shape, new_strides, layout, true) ? layout : layout_type::dynamic;
            }

            // Outputs of fill_args.
            using shape_type = dynamic_shape;
            shape_type new_shape;
            using strides_type = dynamic_shape;
            strides_type new_strides;
            std::size_t new_offset;
            layout_type new_layout;
        };
    }
}

#endif