/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

#ifndef XTENSOR_STRIDES_HPP
#define XTENSOR_STRIDES_HPP

#include <cstddef>
#include <functional>
#include <limits>
#include <numeric>

#include <xtl/xsequence.hpp>

#include "xexception.hpp"
#include "xshape.hpp"
#include "xtensor_config.hpp"
#include "xtensor_forward.hpp"

namespace xt
{
    template <class shape_type>
    std::size_t compute_size(const shape_type& shape) noexcept;

    /**
     * @defgroup xt_xstrides Support functions to switch between array indices and flat indices
     */

    /***************
     * data offset *
     ***************/

    template <class offset_type, class S>
    offset_type data_offset(const S& strides) noexcept;

    /**
     * @brief Return the flat index for an array index.
     *
     * Given ``m`` arguments and the dimension ``n`` of the array (``n == strides.size()``):
     *
     * - If ``m == n``, the index is
     *   ``strides[0] * index[0] + ... + strides[n - 1] * index[n - 1]``.
     *
     * - If ``m < n`` and the last argument is ``xt::missing``, the indices are zero-padded at
     *   the end to match the dimension of the array. The index is then
     *   ``strides[0] * index[0] + ... + strides[m - 1] * index[m - 1]``.
     *
     * - If ``m < n`` (and the last argument is not ``xt::missing``), the index is
     *   ``strides[n - m] * index[0] + ... + strides[n - 1] * index[m - 1]``.
     *
     * - If ``m > n``, then the first ``m - n`` arguments are ignored. The index is then
     *   ``strides[0] * index[m - n] + ... + strides[n - 1] * index[m - 1]``.
     *
     * @ingroup xt_xstrides
     * @param strides Strides of the array.
     * @param args Array index.
     * @return The flat index.
     */
    template <class offset_type, class S, class Arg, class... Args>
    offset_type data_offset(const S& strides, Arg arg, Args... args) noexcept;

    template <class offset_type, layout_type L, class S, class... Args>
    offset_type unchecked_data_offset(const S& strides, Args... args) noexcept;

    template <class offset_type, class S, class It>
    offset_type element_offset(const S& strides, It first, It last) noexcept;
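
    // Illustrative sketch (not part of the library): how the rules above translate into
    // arithmetic for a contiguous row-major array of shape (2, 3, 4), whose strides are
    // (12, 4, 1). The strides container used here is an arbitrary choice for the example.
    //
    //     std::array<std::ptrdiff_t, 3> strides = {12, 4, 1};
    //
    //     // m == n: full index (1, 2, 3) -> 1 * 12 + 2 * 4 + 3 * 1 = 23
    //     auto full = xt::data_offset<std::ptrdiff_t>(strides, 1, 2, 3);
    //
    //     // m < n, last argument xt::missing: (1, missing) is padded to (1, 0, 0) -> 12
    //     auto padded = xt::data_offset<std::ptrdiff_t>(strides, 1, xt::missing);
    //
    //     // m < n, no xt::missing: (2, 3) pairs with the last two strides -> 2 * 4 + 3 = 11
    //     auto trailing = xt::data_offset<std::ptrdiff_t>(strides, 2, 3);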

    /*******************
     * strides builder *
     *******************/

    /**
     * @brief Compute the strides given the shape and the layout of an array.
     *
     * @ingroup xt_xstrides
     * @param shape Shape of the array.
     * @param l Layout type, see xt::layout_type().
     * @param strides (output) Strides of the array.
     * @return The size: the product of the shape.
     */
    template <class shape_type, class strides_type>
    std::size_t compute_strides(const shape_type& shape, layout_type l, strides_type& strides);

    template <class shape_type, class strides_type, class backstrides_type>
    std::size_t
    compute_strides(const shape_type& shape, layout_type l, strides_type& strides, backstrides_type& backstrides);

    template <class shape_type, class strides_type>
    void adapt_strides(const shape_type& shape, strides_type& strides) noexcept;

    template <class shape_type, class strides_type, class backstrides_type>
    void adapt_strides(const shape_type& shape, strides_type& strides, backstrides_type& backstrides) noexcept;

    /*****************
     * unravel_index *
     *****************/

    template <class S>
    S unravel_from_strides(typename S::value_type index, const S& strides, layout_type l = layout_type::row_major);

    template <class S>
    get_strides_t<S> unravel_index(typename S::value_type index, const S& shape, layout_type l = layout_type::row_major);

    template <class S, class T>
    std::vector<get_strides_t<S>>
    unravel_indices(const T& indices, const S& shape, layout_type l = layout_type::row_major);

    /***********************
     * broadcast functions *
     ***********************/

    template <class S, class size_type>
    S uninitialized_shape(size_type size);

    template <class S1, class S2>
    bool broadcast_shape(const S1& input, S2& output);

    template <class S1, class S2>
    bool broadcastable(const S1& s1, S2& s2);

    /*************************
     * check strides overlap *
     *************************/

    template <layout_type L>
    struct check_strides_overlap;

    /**********************************
     * check bounds, without throwing *
     **********************************/

    /**
     * @brief Check if the index is within the bounds of the array.
     *
     * @param shape Shape of the array.
     * @param args Array index.
     * @return true If the index is within the bounds of the array.
     * @return false Otherwise.
     */
    template <class S, class... Args>
    bool in_bounds(const S& shape, Args&... args);

    /********************************
     * apply periodicity to indices *
     *******************************/

    /**
     * @brief Normalise an index of a periodic array.
     * For example, if the shape is ``(3, 4)`` and the index is ``(3, -4)``, the result is ``(0, 0)``.
     *
     * @ingroup xt_xstrides
     * @param shape Shape of the array.
     * @param args (input/output) Array index.
     */
    template <class S, class... Args>
    void normalize_periodic(const S& shape, Args&... args);

    /********************************************
     * utility functions for strided containers *
     ********************************************/

    template <class C, class It, class size_type>
    It strided_data_end(const C& c, It begin, layout_type l, size_type offset)
    {
        using difference_type = typename std::iterator_traits<It>::difference_type;
        if (c.dimension() == 0)
        {
            ++begin;
        }
        else
        {
            for (std::size_t i = 0; i != c.dimension(); ++i)
            {
                begin += c.strides()[i] * difference_type(c.shape()[i] - 1);
            }
            if (l == layout_type::row_major)
            {
                begin += c.strides().back();
            }
            else
            {
                if (offset == 0)
                {
                    begin += c.strides().front();
                }
            }
        }
        return begin;
    }

    /***********
     * strides *
     ***********/

    namespace detail
    {
        template <class return_type, class S, class T, class D>
        inline return_type compute_stride_impl(layout_type layout, const S& shape, T axis, D default_stride)
        {
            if (layout == layout_type::row_major)
            {
                return std::accumulate(
                    shape.cbegin() + axis + 1,
                    shape.cend(),
                    static_cast<return_type>(1),
                    std::multiplies<return_type>()
                );
            }
            if (layout == layout_type::column_major)
            {
                return std::accumulate(
                    shape.cbegin(),
                    shape.cbegin() + axis,
                    static_cast<return_type>(1),
                    std::multiplies<return_type>()
                );
            }
            return default_stride;
        }
    }

    /**
     * @brief Choose stride type
     * @ingroup xt_xstrides
     */
    enum class stride_type
    {
        internal = 0,  ///< As used internally (with `stride(axis) == 0` if `shape(axis) == 1`)
        normal = 1,    ///< Normal stride corresponding to storage.
        bytes = 2,     ///< Normal stride in bytes.
    };
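
    // Illustrative note (not from the library documentation): for a row-major array of
    // shape (3, 1, 4) holding double values, the three conventions above would give
    //
    //     internal : (4, 0, 1)    -- axes of extent 1 get a zero stride
    //     normal   : (4, 4, 1)    -- strides implied by the contiguous storage
    //     bytes    : (32, 32, 8)  -- normal strides multiplied by sizeof(double)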

    /**
     * @brief Get strides of an object.
     *
     * @ingroup xt_xstrides
     * @param e an array
     * @return array
     */
    template <class E>
    inline auto strides(const E& e, stride_type type = stride_type::normal) noexcept
    {
        using strides_type = typename E::strides_type;
        using return_type = typename strides_type::value_type;

        strides_type ret = e.strides();
        auto shape = e.shape();

        if (type == stride_type::internal)
        {
            return ret;
        }

        for (std::size_t i = 0; i < ret.size(); ++i)
        {
            if (shape[i] == 1)
            {
                ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]);
            }
        }

        if (type == stride_type::bytes)
        {
            return_type f = static_cast<return_type>(sizeof(typename E::value_type));
            std::for_each(
                ret.begin(),
                ret.end(),
                [f](auto& c)
                {
                    c *= f;
                }
            );
        }

        return ret;
    }

    /**
     * @brief Get stride of an object along an axis.
     *
     * @ingroup xt_xstrides
     * @param e an array
     * @return integer
     */
    template <class E>
    inline auto strides(const E& e, std::size_t axis, stride_type type = stride_type::normal) noexcept
    {
        using strides_type = typename E::strides_type;
        using return_type = typename strides_type::value_type;

        return_type ret = e.strides()[axis];

        if (type == stride_type::internal)
        {
            return ret;
        }

        if (ret == 0)
        {
            if (e.shape(axis) == 1)
            {
                ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret);
            }
        }

        if (type == stride_type::bytes)
        {
            return_type f = static_cast<return_type>(sizeof(typename E::value_type));
            ret *= f;
        }

        return ret;
    }
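
    // Usage sketch (illustrative only). It assumes a concrete container such as xt::xarray
    // from the rest of the library; the include path may differ between xtensor versions.
    //
    //     #include <xtensor/xarray.hpp>
    //
    //     xt::xarray<double> a = {{1., 2., 3., 4.}, {5., 6., 7., 8.}};  // shape (2, 4)
    //     auto s  = xt::strides(a);                          // element strides, e.g. {4, 1}
    //     auto b  = xt::strides(a, xt::stride_type::bytes);  // byte strides, e.g. {32, 8}
    //     auto s0 = xt::strides(a, 0);                       // stride along axis 0, e.g. 4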

    /******************
     * Implementation *
     ******************/

    namespace detail
    {
        template <class shape_type>
        inline std::size_t compute_size_impl(const shape_type& shape, std::true_type /* is signed */)
        {
            using size_type = std::decay_t<typename shape_type::value_type>;
            return static_cast<std::size_t>(std::abs(
                std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
            ));
        }

        template <class shape_type>
        inline std::size_t compute_size_impl(const shape_type& shape, std::false_type /* is not signed */)
        {
            using size_type = std::decay_t<typename shape_type::value_type>;
            return static_cast<std::size_t>(
                std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
            );
        }
    }

    template <class shape_type>
    inline std::size_t compute_size(const shape_type& shape) noexcept
    {
        return detail::compute_size_impl(
            shape,
            xtl::is_signed<std::decay_t<typename std::decay_t<shape_type>::value_type>>()
        );
    }

    namespace detail
    {
        template <std::size_t dim, class S>
        inline auto raw_data_offset(const S&) noexcept
        {
            using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
            return strides_value_type(0);
        }

        template <std::size_t dim, class S>
        inline auto raw_data_offset(const S&, missing_type) noexcept
        {
            using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
            return strides_value_type(0);
        }

        template <std::size_t dim, class S, class Arg, class... Args>
        inline auto raw_data_offset(const S& strides, Arg arg, Args... args) noexcept
        {
            return static_cast<std::ptrdiff_t>(arg) * strides[dim] + raw_data_offset<dim + 1>(strides, args...);
        }

        template <layout_type L, std::ptrdiff_t static_dim>
        struct layout_data_offset
        {
            template <std::size_t dim, class S, class Arg, class... Args>
            inline static auto run(const S& strides, Arg arg, Args... args) noexcept
            {
                return raw_data_offset<dim>(strides, arg, args...);
            }
        };

        template <std::ptrdiff_t static_dim>
        struct layout_data_offset<layout_type::row_major, static_dim>
        {
            using self_type = layout_data_offset<layout_type::row_major, static_dim>;

            template <std::size_t dim, class S, class Arg>
            inline static auto run(const S& strides, Arg arg) noexcept
            {
                if (std::ptrdiff_t(dim) + 1 == static_dim)
                {
                    return arg;
                }
                else
                {
                    return arg * strides[dim];
                }
            }

            template <std::size_t dim, class S, class Arg, class... Args>
            inline static auto run(const S& strides, Arg arg, Args... args) noexcept
            {
                return arg * strides[dim] + self_type::template run<dim + 1>(strides, args...);
            }
        };

        template <std::ptrdiff_t static_dim>
        struct layout_data_offset<layout_type::column_major, static_dim>
        {
            using self_type = layout_data_offset<layout_type::column_major, static_dim>;

            template <std::size_t dim, class S, class Arg>
            inline static auto run(const S& strides, Arg arg) noexcept
            {
                if (dim == 0)
                {
                    return arg;
                }
                else
                {
                    return arg * strides[dim];
                }
            }

            template <std::size_t dim, class S, class Arg, class... Args>
            inline static auto run(const S& strides, Arg arg, Args... args) noexcept
            {
                if (dim == 0)
                {
                    return arg + self_type::template run<dim + 1>(strides, args...);
                }
                else
                {
                    return arg * strides[dim] + self_type::template run<dim + 1>(strides, args...);
                }
            }
        };
    }

    template <class offset_type, class S>
    inline offset_type data_offset(const S&) noexcept
    {
        return offset_type(0);
    }

    template <class offset_type, class S, class Arg, class... Args>
    inline offset_type data_offset(const S& strides, Arg arg, Args... args) noexcept
    {
        constexpr std::size_t nargs = sizeof...(Args) + 1;
        if (nargs == strides.size())
        {
            // Correct number of arguments: iterate
            return static_cast<offset_type>(detail::raw_data_offset<0>(strides, arg, args...));
        }
        else if (nargs > strides.size())
        {
            // Too many arguments: drop the first
            return data_offset<offset_type>(strides, args...);
        }
        else if (detail::last_type_is_missing<Args...>)
        {
            // Too few arguments & last argument xt::missing: postfix index with zeros
            return static_cast<offset_type>(detail::raw_data_offset<0>(strides, arg, args...));
        }
        else
        {
            // Too few arguments: right to left scalar product
            auto view = strides.cend() - nargs;
            return static_cast<offset_type>(detail::raw_data_offset<0>(view, arg, args...));
        }
    }

    template <class offset_type, layout_type L, class S, class... Args>
    inline offset_type unchecked_data_offset(const S& strides, Args... args) noexcept
    {
        return static_cast<offset_type>(
            detail::layout_data_offset<L, static_dimension<S>::value>::template run<0>(strides.cbegin(), args...)
        );
    }

    template <class offset_type, class S, class It>
    inline offset_type element_offset(const S& strides, It first, It last) noexcept
    {
        using difference_type = typename std::iterator_traits<It>::difference_type;
        auto size = static_cast<difference_type>(
            (std::min)(static_cast<typename S::size_type>(std::distance(first, last)), strides.size())
        );
        return std::inner_product(last - size, last, strides.cend() - size, offset_type(0));
    }

    namespace detail
    {
        template <class shape_type, class strides_type, class bs_ptr>
        inline void adapt_strides(
            const shape_type& shape,
            strides_type& strides,
            bs_ptr backstrides,
            typename strides_type::size_type i
        ) noexcept
        {
            if (shape[i] == 1)
            {
                strides[i] = 0;
            }
            (*backstrides)[i] = strides[i] * std::ptrdiff_t(shape[i] - 1);
        }

        template <class shape_type, class strides_type>
        inline void adapt_strides(
            const shape_type& shape,
            strides_type& strides,
            std::nullptr_t,
            typename strides_type::size_type i
        ) noexcept
        {
            if (shape[i] == 1)
            {
                strides[i] = 0;
            }
        }

        template <layout_type L = layout_type::dynamic, class shape_type, class strides_type, class bs_ptr>
        inline std::size_t compute_strides(const shape_type& shape, layout_type l, strides_type& strides, bs_ptr bs)
        {
            using strides_value_type = typename std::decay_t<strides_type>::value_type;
            strides_value_type data_size = 1;

#if defined(_MSC_VER) && (1931 <= _MSC_VER)
            // Workaround MSVC compiler optimization bug, xtensor#2568
            if (0 == shape.size())
            {
                return static_cast<std::size_t>(data_size);
            }
#endif

            if (L == layout_type::row_major || l == layout_type::row_major)
            {
                for (std::size_t i = shape.size(); i != 0; --i)
                {
                    strides[i - 1] = data_size;
                    data_size = strides[i - 1] * static_cast<strides_value_type>(shape[i - 1]);
                    adapt_strides(shape, strides, bs, i - 1);
                }
            }
            else
            {
                for (std::size_t i = 0; i < shape.size(); ++i)
                {
                    strides[i] = data_size;
                    data_size = strides[i] * static_cast<strides_value_type>(shape[i]);
                    adapt_strides(shape, strides, bs, i);
                }
            }
            return static_cast<std::size_t>(data_size);
        }
    }

    template <class shape_type, class strides_type>
    inline std::size_t compute_strides(const shape_type& shape, layout_type l, strides_type& strides)
    {
        return detail::compute_strides(shape, l, strides, nullptr);
    }

    template <class shape_type, class strides_type, class backstrides_type>
    inline std::size_t
    compute_strides(const shape_type& shape, layout_type l, strides_type& strides, backstrides_type& backstrides)
    {
        return detail::compute_strides(shape, l, strides, &backstrides);
    }
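
    // Usage sketch (illustrative, not part of the library): computing contiguous strides
    // for a given shape. The shape and strides container types are an arbitrary choice.
    //
    //     std::vector<std::size_t> shape = {2, 3, 4};
    //     std::vector<std::ptrdiff_t> strides(shape.size());
    //
    //     std::size_t size = xt::compute_strides(shape, xt::layout_type::row_major, strides);
    //     // size == 24, strides == {12, 4, 1}
    //
    //     xt::compute_strides(shape, xt::layout_type::column_major, strides);
    //     // strides == {1, 2, 6}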
    template <class T1, class T2>
    inline bool stride_match_condition(const T1& stride, const T2& shape, const T1& data_size, bool zero_strides)
    {
        return (shape == T2(1) && stride == T1(0) && zero_strides) || (stride == data_size);
    }

    // zero_strides should be true when strides are set to 0 if the corresponding dimensions are 1
    template <class shape_type, class strides_type>
    inline bool do_strides_match(const shape_type& shape, const strides_type& strides, layout_type l, bool zero_strides)
    {
        using value_type = typename strides_type::value_type;
        value_type data_size = 1;
        if (l == layout_type::row_major)
        {
            for (std::size_t i = strides.size(); i != 0; --i)
            {
                if (!stride_match_condition(strides[i - 1], shape[i - 1], data_size, zero_strides))
                {
                    return false;
                }
                data_size *= static_cast<value_type>(shape[i - 1]);
            }
            return true;
        }
        else if (l == layout_type::column_major)
        {
            for (std::size_t i = 0; i < strides.size(); ++i)
            {
                if (!stride_match_condition(strides[i], shape[i], data_size, zero_strides))
                {
                    return false;
                }
                data_size *= static_cast<value_type>(shape[i]);
            }
            return true;
        }
        else
        {
            return false;
        }
    }

    template <class shape_type, class strides_type>
    inline void adapt_strides(const shape_type& shape, strides_type& strides) noexcept
    {
        for (typename shape_type::size_type i = 0; i < shape.size(); ++i)
        {
            detail::adapt_strides(shape, strides, nullptr, i);
        }
    }

    template <class shape_type, class strides_type, class backstrides_type>
    inline void adapt_strides(const shape_type& shape, strides_type& strides, backstrides_type& backstrides) noexcept
    {
        for (typename shape_type::size_type i = 0; i < shape.size(); ++i)
        {
            detail::adapt_strides(shape, strides, &backstrides, i);
        }
    }
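
    // Illustrative note (not from the library documentation): adapt_strides() zeroes the
    // stride of every axis whose extent is 1 (the "internal" convention of stride_type)
    // and fills backstrides[i] = strides[i] * (shape[i] - 1). For example:
    //
    //     std::vector<std::size_t> shape = {3, 1, 4};
    //     std::vector<std::ptrdiff_t> strides = {4, 4, 1};
    //     std::vector<std::ptrdiff_t> backstrides(3);
    //
    //     xt::adapt_strides(shape, strides, backstrides);
    //     // strides     == {4, 0, 1}
    //     // backstrides == {8, 0, 3}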

    namespace detail
    {
        template <class S>
        inline S unravel_noexcept(typename S::value_type idx, const S& strides, layout_type l) noexcept
        {
            using value_type = typename S::value_type;
            using size_type = typename S::size_type;
            S result = xtl::make_sequence<S>(strides.size(), 0);
            if (l == layout_type::row_major)
            {
                for (size_type i = 0; i < strides.size(); ++i)
                {
                    value_type str = strides[i];
                    value_type quot = str != 0 ? idx / str : 0;
                    idx = str != 0 ? idx % str : idx;
                    result[i] = quot;
                }
            }
            else
            {
                for (size_type i = strides.size(); i != 0; --i)
                {
                    value_type str = strides[i - 1];
                    value_type quot = str != 0 ? idx / str : 0;
                    idx = str != 0 ? idx % str : idx;
                    result[i - 1] = quot;
                }
            }
            return result;
        }
    }

    template <class S>
    inline S unravel_from_strides(typename S::value_type index, const S& strides, layout_type l)
    {
        if (l != layout_type::row_major && l != layout_type::column_major)
        {
            XTENSOR_THROW(std::runtime_error, "unravel_index: dynamic layout not supported");
        }
        return detail::unravel_noexcept(index, strides, l);
    }

    template <class S, class T>
    inline get_value_type_t<T> ravel_from_strides(const T& index, const S& strides)
    {
        return element_offset<get_value_type_t<T>>(strides, index.begin(), index.end());
    }

    template <class S>
    inline get_strides_t<S> unravel_index(typename S::value_type index, const S& shape, layout_type l)
    {
        using strides_type = get_strides_t<S>;
        using strides_value_type = typename strides_type::value_type;
        strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
        compute_strides(shape, l, strides);
        return unravel_from_strides(static_cast<strides_value_type>(index), strides, l);
    }

    template <class S, class T>
    inline std::vector<get_strides_t<S>> unravel_indices(const T& idx, const S& shape, layout_type l)
    {
        using strides_type = get_strides_t<S>;
        using strides_value_type = typename strides_type::value_type;
        strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
        compute_strides(shape, l, strides);
        std::vector<get_strides_t<S>> out(idx.size());
        auto out_iter = out.begin();
        auto idx_iter = idx.begin();
        for (; out_iter != out.end(); ++out_iter, ++idx_iter)
        {
            *out_iter = unravel_from_strides(static_cast<strides_value_type>(*idx_iter), strides, l);
        }
        return out;
    }

    template <class S, class T>
    inline get_value_type_t<T> ravel_index(const T& index, const S& shape, layout_type l)
    {
        using strides_type = get_strides_t<S>;
        strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
        compute_strides(shape, l, strides);
        return ravel_from_strides(index, strides);
    }

    template <class S, class stype>
    inline S uninitialized_shape(stype size)
    {
        using value_type = typename S::value_type;
        using size_type = typename S::size_type;
        return xtl::make_sequence<S>(static_cast<size_type>(size), std::numeric_limits<value_type>::max());
    }
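
    // Usage sketch (illustrative, not part of the library): converting between flat and
    // array indices for a shape of (3, 4); the shape container type is an assumption.
    //
    //     std::vector<std::size_t> shape = {3, 4};
    //
    //     auto idx = xt::unravel_index(7, shape, xt::layout_type::row_major);
    //     // idx == {1, 3}, since 7 == 1 * 4 + 3
    //
    //     std::vector<std::size_t> index = {1, 3};
    //     auto flat = xt::ravel_index(index, shape, xt::layout_type::row_major);
    //     // flat == 7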

    template <class S1, class S2>
    inline bool broadcast_shape(const S1& input, S2& output)
    {
        bool trivial_broadcast = (input.size() == output.size());
        // Indices are faster than reverse iterators
        using value_type = typename S2::value_type;
        auto output_index = output.size();
        auto input_index = input.size();

        if (output_index < input_index)
        {
            throw_broadcast_error(output, input);
        }
        for (; input_index != 0; --input_index, --output_index)
        {
            // First case: output = (MAX, MAX, ...., MAX)
            // output is a new shape that has not been through
            // the broadcast process yet; broadcast is trivial
            if (output[output_index - 1] == std::numeric_limits<value_type>::max())
            {
                output[output_index - 1] = static_cast<value_type>(input[input_index - 1]);
            }
            // Second case: output has been initialized to 1. Broadcast is trivial
            // only if input is 1 too.
            else if (output[output_index - 1] == 1)
            {
                output[output_index - 1] = static_cast<value_type>(input[input_index - 1]);
                trivial_broadcast = trivial_broadcast && (input[input_index - 1] == 1);
            }
            // Third case: output has been initialized to something different from 1.
            // If input is 1, then the broadcast is not trivial
            else if (input[input_index - 1] == 1)
            {
                trivial_broadcast = false;
            }
            // Last case: input and output must have the same value, else
            // the shapes are not compatible and an exception is thrown
            else if (static_cast<value_type>(input[input_index - 1]) != output[output_index - 1])
            {
                throw_broadcast_error(output, input);
            }
        }
        return trivial_broadcast;
    }

    template <class S1, class S2>
    inline bool broadcastable(const S1& src_shape, const S2& dst_shape)
    {
        auto src_iter = src_shape.crbegin();
        auto dst_iter = dst_shape.crbegin();
        bool res = dst_shape.size() >= src_shape.size();
        for (; src_iter != src_shape.crend() && res; ++src_iter, ++dst_iter)
        {
            res = (static_cast<std::size_t>(*src_iter) == static_cast<std::size_t>(*dst_iter)) || (*src_iter == 1);
        }
        return res;
    }

    template <>
    struct check_strides_overlap<layout_type::row_major>
    {
        template <class S1, class S2>
        static std::size_t get(const S1& s1, const S2& s2)
        {
            using value_type = typename S1::value_type;
            // Indices are faster than reverse iterators
            auto s1_index = s1.size();
            auto s2_index = s2.size();

            for (; s2_index != 0; --s1_index, --s2_index)
            {
                if (static_cast<value_type>(s1[s1_index - 1]) != static_cast<value_type>(s2[s2_index - 1]))
                {
                    break;
                }
            }
            return s1_index;
        }
    };

    template <>
    struct check_strides_overlap<layout_type::column_major>
    {
        template <class S1, class S2>
        static std::size_t get(const S1& s1, const S2& s2)
        {
            // Indices are faster than reverse iterators
            using size_type = typename S1::size_type;
            using value_type = typename S1::value_type;
            size_type index = 0;

            // This check is necessary as column major "broadcasting" is still
            // performed in a row major fashion
            if (s1.size() != s2.size())
            {
                return 0;
            }

            auto size = s2.size();
            for (; index < size; ++index)
            {
                if (static_cast<value_type>(s1[index]) != static_cast<value_type>(s2[index]))
                {
                    break;
                }
            }
            return index;
        }
    };

    namespace detail
    {
        template <std::size_t dim = 0, class S>
        inline bool check_in_bounds_impl(const S&)
        {
            return true;
        }

        template <std::size_t dim = 0, class S>
        inline bool check_in_bounds_impl(const S&, missing_type)
        {
            return true;
        }

        template <std::size_t dim = 0, class S, class T, class... Args>
        inline bool check_in_bounds_impl(const S& shape, T& arg, Args&... args)
        {
            if (sizeof...(Args) + 1 > shape.size())
            {
                return check_in_bounds_impl<dim>(shape, args...);
            }
            else
            {
                return arg >= T(0) && arg < static_cast<T>(shape[dim]) && check_in_bounds_impl<dim + 1>(shape, args...);
            }
        }
    }

    template <class S, class... Args>
    inline bool check_in_bounds(const S& shape, Args&... args)
    {
        return detail::check_in_bounds_impl(shape, args...);
    }

    namespace detail
    {
        template <std::size_t dim = 0, class S>
        inline void normalize_periodic_impl(const S&)
        {
        }

        template <std::size_t dim = 0, class S>
        inline void normalize_periodic_impl(const S&, missing_type)
        {
        }

        template <std::size_t dim = 0, class S, class T, class... Args>
        inline void normalize_periodic_impl(const S& shape, T& arg, Args&... args)
        {
            if (sizeof...(Args) + 1 > shape.size())
            {
                normalize_periodic_impl<dim>(shape, args...);
            }
            else
            {
                T n = static_cast<T>(shape[dim]);
                arg = (n + (arg % n)) % n;
                normalize_periodic_impl<dim + 1>(shape, args...);
            }
        }
    }

    template <class S, class... Args>
    inline void normalize_periodic(const S& shape, Args&... args)
    {
        check_dimension(shape, args...);
        detail::normalize_periodic_impl(shape, args...);
    }
}

#endif