/*************************************************************************** * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * * Copyright (c) QuantStack * * * * Distributed under the terms of the BSD 3-Clause License. * * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ #ifndef XTENSOR_XSHAPE_HPP #define XTENSOR_XSHAPE_HPP #include #include #include #include #include #include #include #include #include "xlayout.hpp" #include "xstorage.hpp" #include "xtensor_forward.hpp" namespace xt { template using dynamic_shape = svector; template using static_shape = std::array; template class fixed_shape; using xindex = dynamic_shape; template bool same_shape(const S1& s1, const S2& s2) noexcept; template struct initializer_dimension; template constexpr R shape(T t); template xt::static_shape shape(const T (&aList)[N]); template struct static_dimension; template struct select_layout; template struct promote_shape; template struct promote_strides; template struct index_from_shape; } namespace xtl { namespace detail { template struct sequence_builder; template struct sequence_builder> { using sequence_type = xt::fixed_shape; using value_type = typename sequence_type::value_type; inline static sequence_type make(std::size_t /*size*/) { return sequence_type{}; } inline static sequence_type make(std::size_t /*size*/, value_type /*v*/) { return sequence_type{}; } }; } } namespace xt { /** * @defgroup xt_xshape Support functions to get/check a shape array. */ /************** * same_shape * **************/ /** * Check if two objects have the same shape. * * @ingroup xt_xshape * @param s1 an array * @param s2 an array * @return bool */ template inline bool same_shape(const S1& s1, const S2& s2) noexcept { return s1.size() == s2.size() && std::equal(s1.begin(), s1.end(), s2.begin()); } /************* * has_shape * *************/ /** * Check if an object has a certain shape. * * @ingroup xt_xshape * @param a an array * @param shape the shape to test * @return bool */ template inline bool has_shape(const E& e, std::initializer_list shape) noexcept { return e.shape().size() == shape.size() && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin()); } /** * Check if an object has a certain shape. * * @ingroup has_shape * @param a an array * @param shape the shape to test * @return bool */ template ::value>> inline bool has_shape(const E& e, const S& shape) { return e.shape().size() == shape.size() && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin()); } /************************* * initializer_dimension * *************************/ namespace detail { template struct initializer_depth_impl { static constexpr std::size_t value = 0; }; template struct initializer_depth_impl> { static constexpr std::size_t value = 1 + initializer_depth_impl::value; }; } template struct initializer_dimension { static constexpr std::size_t value = detail::initializer_depth_impl::value; }; /********************* * initializer_shape * *********************/ namespace detail { template struct initializer_shape_impl { template static constexpr std::size_t value(T t) { return t.size() == 0 ? 0 : initializer_shape_impl::value(*t.begin()); } }; template <> struct initializer_shape_impl<0> { template static constexpr std::size_t value(T t) { return t.size(); } }; template constexpr R initializer_shape(U t, std::index_sequence) { using size_type = typename R::value_type; return {size_type(initializer_shape_impl::value(t))...}; } } template constexpr R shape(T t) { return detail::initializer_shape( t, std::make_index_sequence::value>() ); } /** @brief Generate an xt::static_shape of the given size. */ template xt::static_shape shape(const T (&list)[N]) { xt::static_shape shape; std::copy(std::begin(list), std::end(list), std::begin(shape)); return shape; } /******************** * static_dimension * ********************/ namespace detail { template struct static_dimension_impl { static constexpr std::ptrdiff_t value = -1; }; template struct static_dimension_impl::value)>> { static constexpr std::ptrdiff_t value = static_cast(std::tuple_size::value); }; } template struct static_dimension { static constexpr std::ptrdiff_t value = detail::static_dimension_impl::value; }; /** * Compute a layout based on a layout and a shape type. * * The main functionality of this function is that it reduces vectors to * ``xt::layout_type::any`` so that assigning a row major 1D container to another * row_major container becomes free. * * @ingroup xt_xshape */ template struct select_layout { static constexpr std::ptrdiff_t static_dimension = xt::static_dimension::value; static constexpr bool is_any = static_dimension != -1 && static_dimension <= 1 && L != layout_type::dynamic; static constexpr layout_type value = is_any ? layout_type::any : L; }; /************************************* * promote_shape and promote_strides * *************************************/ namespace detail { template constexpr std::common_type_t imax(const T1& a, const T2& b) { return a > b ? a : b; } // Variadic meta-function returning the maximal size of std::arrays. template struct max_array_size; template <> struct max_array_size<> { static constexpr std::size_t value = 0; }; template struct max_array_size : std::integral_constant::value, max_array_size::value)> { }; // Broadcasting for fixed shapes template struct at { static constexpr std::size_t arr[sizeof...(X)] = {X...}; static constexpr std::size_t value = (IDX < sizeof...(X)) ? arr[IDX] : 0; }; template struct broadcast_fixed_shape; template struct broadcast_fixed_shape_impl; template struct broadcast_fixed_shape_cmp_impl; template struct broadcast_fixed_shape_cmp_impl, fixed_shape> { // We line the shapes up from the last index // IX may underflow, thus being a very large number static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I)); // Out of bounds access gives value 0 static constexpr std::size_t I_v = at::value; static constexpr std::size_t J_v = at::value; // we're statically checking if the broadcast shapes are either one on either of them or equal static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match."); static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v; static constexpr bool value = (I_v == J_v); }; template struct broadcast_fixed_shape_impl, fixed_shape, fixed_shape> { static_assert(sizeof...(J) >= sizeof...(I), "broadcast shapes do not match."); using type = xt::fixed_shape< broadcast_fixed_shape_cmp_impl, fixed_shape>::ordinate...>; static constexpr bool value = xtl::conjunction< broadcast_fixed_shape_cmp_impl, fixed_shape>...>::value; }; /* broadcast_fixed_shape, fixed_shape> * Just like a call to broadcast_shape(cont S1& input, S2& output), * except that the result shape is alised as type, and the returned * bool is the member value. Asserts on an illegal broadcast, including * the case where pack I is strictly longer than pack J. */ template struct broadcast_fixed_shape, fixed_shape> : broadcast_fixed_shape_impl, fixed_shape, fixed_shape> { }; // Simple is_array and only_array meta-functions template struct is_array { static constexpr bool value = false; }; template struct is_array> { static constexpr bool value = true; }; template struct is_fixed : std::false_type { }; template struct is_fixed> : std::true_type { }; template struct is_scalar_shape { static constexpr bool value = false; }; template struct is_scalar_shape> { static constexpr bool value = true; }; template using only_array = xtl::conjunction, is_fixed>...>; // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or // scalar template using only_fixed = std::integral_constant< bool, xtl::disjunction...>::value && xtl::conjunction, is_scalar_shape>...>::value>; template using all_fixed = xtl::conjunction...>; // The promote_index meta-function returns std::vector in the // general case and an array of the promoted value type and maximal size if all // arguments are of type std::array template struct promote_array { using type = std:: array::type, max_array_size::value>; }; template <> struct promote_array<> { using type = std::array; }; template struct filter_scalar { using type = S; }; template struct filter_scalar> { using type = fixed_shape<1>; }; template using filter_scalar_t = typename filter_scalar::type; template struct promote_fixed : promote_fixed...> { }; template struct promote_fixed> { using type = fixed_shape; static constexpr bool value = true; }; template struct promote_fixed, fixed_shape, S...> { private: using intermediate = std::conditional_t< (sizeof...(I) > sizeof...(J)), broadcast_fixed_shape, fixed_shape>, broadcast_fixed_shape, fixed_shape>>; using result = promote_fixed; public: using type = typename result::type; static constexpr bool value = xtl::conjunction::value; }; template struct select_promote_index; template struct select_promote_index : promote_fixed { }; template <> struct select_promote_index { // todo correct? used in xvectorize using type = dynamic_shape; }; template struct select_promote_index : promote_array { }; template struct select_promote_index { using type = dynamic_shape::type>; }; template struct promote_index : select_promote_index::value, only_array::value, S...> { }; template struct index_from_shape_impl { using type = T; }; template struct index_from_shape_impl> { using type = std::array; }; } template struct promote_shape { using type = typename detail::promote_index::type; }; /** * @ingroup xt_xshape */ template using promote_shape_t = typename promote_shape::type; template struct promote_strides { using type = typename detail::promote_index::type; }; /** * @ingroup xt_xshape */ template using promote_strides_t = typename promote_strides::type; template struct index_from_shape { using type = typename detail::index_from_shape_impl::type; }; /** * @ingroup xt_xshape */ template using index_from_shape_t = typename index_from_shape::type; /********************** * filter_fixed_shape * **********************/ namespace detail { template struct filter_fixed_shape_impl { using type = S; }; template struct filter_fixed_shape_impl> { using type = std::array; }; } template struct filter_fixed_shape : detail::filter_fixed_shape_impl { }; /** * @ingroup xt_xshape */ template using filter_fixed_shape_t = typename filter_fixed_shape::type; } #endif