/*************************************************************************** * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * * Copyright (c) QuantStack * * * * Distributed under the terms of the BSD 3-Clause License. * * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ #ifndef XTENSOR_FUNCTION_HPP #define XTENSOR_FUNCTION_HPP #include #include #include #include #include #include #include #include #include #include "xaccessible.hpp" #include "xexpression_traits.hpp" #include "xiterable.hpp" #include "xiterator.hpp" #include "xlayout.hpp" #include "xscalar.hpp" #include "xshape.hpp" #include "xstrides.hpp" #include "xtensor_simd.hpp" #include "xutils.hpp" namespace xt { namespace detail { template using conjunction_c = xtl::conjunction...>; /************************ * xfunction_cache_impl * ************************/ template struct xfunction_cache_impl { S shape; bool is_trivial; bool is_initialized; xfunction_cache_impl() : shape(xtl::make_sequence(0, std::size_t(0))) , is_trivial(false) , is_initialized(false) { } }; template struct xfunction_cache_impl, is_shape_trivial> { XTENSOR_CONSTEXPR_ENHANCED_STATIC fixed_shape shape = fixed_shape(); XTENSOR_CONSTEXPR_ENHANCED_STATIC bool is_trivial = is_shape_trivial::value; XTENSOR_CONSTEXPR_ENHANCED_STATIC bool is_initialized = true; }; #ifdef XTENSOR_HAS_CONSTEXPR_ENHANCED // Out of line definitions to prevent linker errors prior to C++17 template constexpr fixed_shape xfunction_cache_impl, is_shape_trivial>::shape; template constexpr bool xfunction_cache_impl, is_shape_trivial>::is_trivial; template constexpr bool xfunction_cache_impl, is_shape_trivial>::is_initialized; #endif template struct xfunction_bool_load_type { using type = xtl::promote_type_t::bool_load_type...>; }; template struct xfunction_bool_load_type { using type = typename std::decay_t::bool_load_type; }; template using xfunction_bool_load_type_t = typename xfunction_bool_load_type::type; } /************************ * xfunction extensions * ************************/ namespace extension { template struct xfunction_base_impl; template struct xfunction_base_impl { using type = xtensor_empty_base; }; template struct xfunction_base : xfunction_base_impl, F, CT...> { }; template using xfunction_base_t = typename xfunction_base::type; } template struct xfunction_cache : detail::xfunction_cache_impl { }; template class xfunction_iterator; template class xfunction_stepper; template class xfunction; template struct xiterable_inner_types> { using inner_shape_type = promote_shape_t::shape_type...>; using const_stepper = xfunction_stepper; using stepper = const_stepper; }; template struct xcontainer_inner_types> { // Added indirection for MSVC 2017 bug with the operator value_type() using func_return_type = typename meta_identity< decltype(std::declval()(std::declval>>()...))>::type; using value_type = std::decay_t; using reference = func_return_type; using const_reference = reference; using size_type = common_size_type_t...>; }; template struct has_simd_interface, T> : xtl::conjunction< has_simd_type, has_simd_apply>, has_simd_interface, T>...> { }; /************************************* * overlapping_memory_checker_traits * *************************************/ template struct overlapping_memory_checker_traits< E, std::enable_if_t::value && is_specialization_of::value>> { template = 0> static bool check_tuple(const std::tuple&, const memory_range&) { return false; } template = 0> static bool check_tuple(const std::tuple& t, const memory_range& dst_range) { using ChildE = std::decay_t(t))>; return overlapping_memory_checker_traits::check_overlap(std::get(t), dst_range) || check_tuple(t, dst_range); } static bool check_overlap(const E& expr, const memory_range& dst_range) { if (expr.size() == 0) { return false; } else { return check_tuple(expr.arguments(), dst_range); } } }; /************* * xfunction * *************/ /** * @class xfunction * @brief Multidimensional function operating on * xtensor expressions. * * The xfunction class implements a multidimensional function * operating on xtensor expressions. * * @tparam F the function type * @tparam CT the closure types for arguments of the function */ template class xfunction : private xconst_iterable>, public xsharable_expression>, private xconst_accessible>, public extension::xfunction_base_t { public: using self_type = xfunction; using accessible_base = xconst_accessible; using extension_base = extension::xfunction_base_t; using expression_tag = typename extension_base::expression_tag; using only_scalar = all_xscalar; using functor_type = typename std::remove_reference::type; using tuple_type = std::tuple; using inner_types = xcontainer_inner_types; using value_type = typename inner_types::value_type; using reference = typename inner_types::reference; using const_reference = typename inner_types::const_reference; using pointer = value_type*; using const_pointer = const value_type*; using size_type = typename inner_types::size_type; using difference_type = common_difference_type_t...>; using simd_value_type = xt_simd::simd_type; // xtl::promote_type_t::bool_load_type...>; using bool_load_type = detail::xfunction_bool_load_type_t; template using simd_return_type = xt_simd::simd_return_type; using iterable_base = xconst_iterable>; using inner_shape_type = typename iterable_base::inner_shape_type; using shape_type = inner_shape_type; using stepper = typename iterable_base::stepper; using const_stepper = typename iterable_base::const_stepper; static constexpr layout_type static_layout = compute_layout(std::decay_t::static_layout...); static constexpr bool contiguous_layout = static_layout != layout_type::dynamic; template using layout_iterator = typename iterable_base::template layout_iterator; template using const_layout_iterator = typename iterable_base::template const_layout_iterator; template using reverse_layout_iterator = typename iterable_base::template reverse_layout_iterator; template using const_reverse_layout_iterator = typename iterable_base::template const_reverse_layout_iterator; template using broadcast_iterator = typename iterable_base::template broadcast_iterator; template using const_broadcast_iterator = typename iterable_base::template const_broadcast_iterator; template using reverse_broadcast_iterator = typename iterable_base::template reverse_broadcast_iterator; template using const_reverse_broadcast_iterator = typename iterable_base::template const_reverse_broadcast_iterator; using const_linear_iterator = xfunction_iterator; using linear_iterator = const_linear_iterator; using const_reverse_linear_iterator = std::reverse_iterator; using reverse_linear_iterator = std::reverse_iterator; using iterator = typename iterable_base::iterator; using const_iterator = typename iterable_base::const_iterator; using reverse_iterator = typename iterable_base::reverse_iterator; using const_reverse_iterator = typename iterable_base::const_reverse_iterator; template , self_type>::value>> xfunction(Func&& f, CTA&&... e) noexcept; template xfunction(xfunction xf) noexcept; ~xfunction() = default; xfunction(const xfunction&) = default; xfunction& operator=(const xfunction&) = default; xfunction(xfunction&&) = default; xfunction& operator=(xfunction&&) = default; using accessible_base::size; size_type dimension() const noexcept; const inner_shape_type& shape() const; layout_type layout() const noexcept; bool is_contiguous() const noexcept; using accessible_base::shape; template const_reference operator()(Args... args) const; template const_reference unchecked(Args... args) const; using accessible_base::at; using accessible_base::operator[]; using accessible_base::back; using accessible_base::front; using accessible_base::in_bounds; using accessible_base::periodic; template const_reference element(It first, It last) const; template bool broadcast_shape(S& shape, bool reuse_cache = false) const; template bool has_linear_assign(const S& strides) const noexcept; using iterable_base::begin; using iterable_base::cbegin; using iterable_base::cend; using iterable_base::crbegin; using iterable_base::crend; using iterable_base::end; using iterable_base::rbegin; using iterable_base::rend; const_linear_iterator linear_begin() const noexcept; const_linear_iterator linear_end() const noexcept; const_linear_iterator linear_cbegin() const noexcept; const_linear_iterator linear_cend() const noexcept; const_reverse_linear_iterator linear_rbegin() const noexcept; const_reverse_linear_iterator linear_rend() const noexcept; const_reverse_linear_iterator linear_crbegin() const noexcept; const_reverse_linear_iterator linear_crend() const noexcept; template const_stepper stepper_begin(const S& shape) const noexcept; template const_stepper stepper_end(const S& shape, layout_type l) const noexcept; const_reference data_element(size_type i) const; const_reference flat(size_type i) const; template ::type> operator value_type() const; template ::size> simd_return_type load_simd(size_type i) const; const tuple_type& arguments() const noexcept; const functor_type& functor() const noexcept; private: template layout_type layout_impl(std::index_sequence) const noexcept; template const_reference access_impl(std::index_sequence, Args... args) const; template const_reference unchecked_impl(std::index_sequence, Args... args) const; template const_reference element_access_impl(std::index_sequence, It first, It last) const; template const_reference data_element_impl(std::index_sequence, size_type i) const; template auto load_simd_impl(std::index_sequence, size_type i) const; template const_stepper build_stepper(Func&& f, std::index_sequence) const noexcept; template auto build_iterator(Func&& f, std::index_sequence) const noexcept; size_type compute_dimension() const noexcept; void compute_cached_shape() const; tuple_type m_e; functor_type m_f; mutable xfunction_cache::shape_type...>> m_cache; friend class xfunction_iterator; friend class xfunction_stepper; friend class xconst_iterable; friend class xconst_accessible; }; /********************** * xfunction_iterator * **********************/ template class xfunction_iterator : public xtl::xrandom_access_iterator_base< xfunction_iterator, typename xfunction::value_type, typename xfunction::difference_type, typename xfunction::pointer, typename xfunction::reference> { public: using self_type = xfunction_iterator; using functor_type = typename std::remove_reference::type; using xfunction_type = xfunction; using value_type = typename xfunction_type::value_type; using reference = typename xfunction_type::value_type; using pointer = typename xfunction_type::const_pointer; using difference_type = typename xfunction_type::difference_type; using iterator_category = std::random_access_iterator_tag; template xfunction_iterator(const xfunction_type* func, It&&... it) noexcept; self_type& operator++(); self_type& operator--(); self_type& operator+=(difference_type n); self_type& operator-=(difference_type n); difference_type operator-(const self_type& rhs) const; reference operator*() const; bool equal(const self_type& rhs) const; bool less_than(const self_type& rhs) const; private: using data_type = std::tuple>()))...>; template reference deref_impl(std::index_sequence) const; template difference_type tuple_max_diff(std::index_sequence, const data_type& lhs, const data_type& rhs) const; const xfunction_type* p_f; data_type m_it; }; template bool operator==(const xfunction_iterator& it1, const xfunction_iterator& it2); template bool operator<(const xfunction_iterator& it1, const xfunction_iterator& it2); /********************* * xfunction_stepper * *********************/ template class xfunction_stepper { public: using self_type = xfunction_stepper; using functor_type = typename std::remove_reference::type; using xfunction_type = xfunction; using value_type = typename xfunction_type::value_type; using reference = typename xfunction_type::reference; using pointer = typename xfunction_type::const_pointer; using size_type = typename xfunction_type::size_type; using difference_type = typename xfunction_type::difference_type; using shape_type = typename xfunction_type::shape_type; template using simd_return_type = xt_simd::simd_return_type; template xfunction_stepper(const xfunction_type* func, St&&... st) noexcept; void step(size_type dim); void step_back(size_type dim); void step(size_type dim, size_type n); void step_back(size_type dim, size_type n); void reset(size_type dim); void reset_back(size_type dim); void to_begin(); void to_end(layout_type l); reference operator*() const; template simd_return_type step_simd(); void step_leading(); private: template reference deref_impl(std::index_sequence) const; template simd_return_type step_simd_impl(std::index_sequence); const xfunction_type* p_f; std::tuple::const_stepper...> m_st; }; /********************************* * xfunction implementation * *********************************/ /** * @name Constructor */ //@{ /** * Constructs an xfunction applying the specified function to the given * arguments. * @param f the function to apply * @param e the \ref xexpression arguments */ template template inline xfunction::xfunction(Func&& f, CTA&&... e) noexcept : m_e(std::forward(e)...) , m_f(std::forward(f)) { } /** * Constructs an xfunction applying the specified function given by another * xfunction with its arguments. * @param xf the xfunction to apply */ template template inline xfunction::xfunction(xfunction xf) noexcept : m_e(xf.arguments()) , m_f(xf.functor()) { } //@} /** * @name Size and shape */ //@{ /** * Returns the number of dimensions of the function. */ template inline auto xfunction::dimension() const noexcept -> size_type { size_type dimension = m_cache.is_initialized ? m_cache.shape.size() : compute_dimension(); return dimension; } template inline void xfunction::compute_cached_shape() const { static_assert(!detail::is_fixed::value, "Calling compute_cached_shape on fixed!"); m_cache.shape = uninitialized_shape>(compute_dimension()); m_cache.is_trivial = broadcast_shape(m_cache.shape, false); m_cache.is_initialized = true; } /** * Returns the shape of the xfunction. */ template inline auto xfunction::shape() const -> const inner_shape_type& { xtl::mpl::static_if::value>( [&](auto self) { if (!m_cache.is_initialized) { self(this)->compute_cached_shape(); } }, [](auto /*self*/) {} ); return m_cache.shape; } /** * Returns the layout_type of the xfunction. */ template inline layout_type xfunction::layout() const noexcept { return layout_impl(std::make_index_sequence()); } template inline bool xfunction::is_contiguous() const noexcept { return layout() != layout_type::dynamic && accumulate( [](bool r, const auto& exp) { return r && exp.is_contiguous(); }, true, m_e ); } //@} /** * @name Data */ /** * Returns a constant reference to the element at the specified position in the function. * @param args a list of indices specifying the position in the function. Indices * must be unsigned integers, the number of indices should be equal or greater than * the number of dimensions of the function. */ template template inline auto xfunction::operator()(Args... args) const -> const_reference { // The static cast prevents the compiler from instantiating the template methods with signed integers, // leading to warning about signed/unsigned conversions in the deeper layers of the access methods return access_impl(std::make_index_sequence(), static_cast(args)...); } /** * @name Data */ /** * Returns a constant reference to the element at the specified position of the underlying * contiguous storage of the function. * @param index index to underlying flat storage. */ template inline auto xfunction::flat(size_type index) const -> const_reference { return data_element_impl(std::make_index_sequence(), index); } /** * Returns a constant reference to the element at the specified position in the expression. * @param args a list of indices specifying the position in the expression. Indices * must be unsigned integers, the number of indices must be equal to the number of * dimensions of the expression, else the behavior is undefined. * * @warning This method is meant for performance, for expressions with a dynamic * number of dimensions (i.e. not known at compile time). Since it may have * undefined behavior (see parameters), operator() should be preferred whenever * it is possible. * @warning This method is NOT compatible with broadcasting, meaning the following * code has undefined behavior: * @code{.cpp} * xt::xarray a = {{0, 1}, {2, 3}}; * xt::xarray b = {0, 1}; * auto fd = a + b; * double res = fd.unchecked(0, 1); * @endcode */ template template inline auto xfunction::unchecked(Args... args) const -> const_reference { // The static cast prevents the compiler from instantiating the template methods with signed integers, // leading to warning about signed/unsigned conversions in the deeper layers of the access methods return unchecked_impl(std::make_index_sequence(), static_cast(args)...); } /** * Returns a constant reference to the element at the specified position in the function. * @param first iterator starting the sequence of indices * @param last iterator ending the sequence of indices * The number of indices in the sequence should be equal to or greater * than the number of dimensions of the container. */ template template inline auto xfunction::element(It first, It last) const -> const_reference { return element_access_impl(std::make_index_sequence(), first, last); } //@} /** * @name Broadcasting */ //@{ /** * Broadcast the shape of the function to the specified parameter. * @param shape the result shape * @param reuse_cache boolean for reusing a previously computed shape * @return a boolean indicating whether the broadcasting is trivial */ template template inline bool xfunction::broadcast_shape(S& shape, bool reuse_cache) const { if (m_cache.is_initialized && reuse_cache) { std::copy(m_cache.shape.cbegin(), m_cache.shape.cend(), shape.begin()); return m_cache.is_trivial; } else { // e.broadcast_shape must be evaluated even if b is false auto func = [&shape](bool b, auto&& e) { return e.broadcast_shape(shape) && b; }; return accumulate(func, true, m_e); } } /** * Checks whether the xfunction can be linearly assigned to an expression * with the specified strides. * @return a boolean indicating whether a linear assign is possible */ template template inline bool xfunction::has_linear_assign(const S& strides) const noexcept { auto func = [&strides](bool b, auto&& e) { return b && e.has_linear_assign(strides); }; return accumulate(func, true, m_e); } //@} template inline auto xfunction::linear_begin() const noexcept -> const_linear_iterator { return linear_cbegin(); } template inline auto xfunction::linear_end() const noexcept -> const_linear_iterator { return linear_cend(); } template inline auto xfunction::linear_cbegin() const noexcept -> const_linear_iterator { auto f = [](const auto& e) noexcept { return xt::linear_begin(e); }; return build_iterator(f, std::make_index_sequence()); } template inline auto xfunction::linear_cend() const noexcept -> const_linear_iterator { auto f = [](const auto& e) noexcept { return xt::linear_end(e); }; return build_iterator(f, std::make_index_sequence()); } template inline auto xfunction::linear_rbegin() const noexcept -> const_reverse_linear_iterator { return linear_crbegin(); } template inline auto xfunction::linear_rend() const noexcept -> const_reverse_linear_iterator { return linear_crend(); } template inline auto xfunction::linear_crbegin() const noexcept -> const_reverse_linear_iterator { return const_reverse_linear_iterator(linear_cend()); } template inline auto xfunction::linear_crend() const noexcept -> const_reverse_linear_iterator { return const_reverse_linear_iterator(linear_cbegin()); } template template inline auto xfunction::stepper_begin(const S& shape) const noexcept -> const_stepper { auto f = [&shape](const auto& e) noexcept { return e.stepper_begin(shape); }; return build_stepper(f, std::make_index_sequence()); } template template inline auto xfunction::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper { auto f = [&shape, l](const auto& e) noexcept { return e.stepper_end(shape, l); }; return build_stepper(f, std::make_index_sequence()); } template inline auto xfunction::data_element(size_type i) const -> const_reference { return data_element_impl(std::make_index_sequence(), i); } template template inline xfunction::operator value_type() const { return operator()(); } template template inline auto xfunction::load_simd(size_type i) const -> simd_return_type { return load_simd_impl(std::make_index_sequence(), i); } template inline auto xfunction::arguments() const noexcept -> const tuple_type& { return m_e; } template inline auto xfunction::functor() const noexcept -> const functor_type& { return m_f; } template template inline layout_type xfunction::layout_impl(std::index_sequence) const noexcept { return compute_layout(std::get(m_e).layout()...); } template template inline auto xfunction::access_impl(std::index_sequence, Args... args) const -> const_reference { XTENSOR_TRY(check_index(shape(), args...)); XTENSOR_CHECK_DIMENSION(shape(), args...); return m_f(std::get(m_e)(args...)...); } template template inline auto xfunction::unchecked_impl(std::index_sequence, Args... args) const -> const_reference { return m_f(std::get(m_e).unchecked(args...)...); } template template inline auto xfunction::element_access_impl(std::index_sequence, It first, It last) const -> const_reference { XTENSOR_TRY(check_element_index(shape(), first, last)); return m_f((std::get(m_e).element(first, last))...); } template template inline auto xfunction::data_element_impl(std::index_sequence, size_type i) const -> const_reference { return m_f((std::get(m_e).data_element(i))...); } template template inline auto xfunction::load_simd_impl(std::index_sequence, size_type i) const { return m_f.simd_apply((std::get(m_e).template load_simd(i))...); } template template inline auto xfunction::build_stepper(Func&& f, std::index_sequence) const noexcept -> const_stepper { return const_stepper(this, f(std::get(m_e))...); } template template inline auto xfunction::build_iterator(Func&& f, std::index_sequence) const noexcept { return const_linear_iterator(this, f(std::get(m_e))...); } template inline auto xfunction::compute_dimension() const noexcept -> size_type { auto func = [](size_type d, auto&& e) noexcept { return (std::max)(d, e.dimension()); }; return accumulate(func, size_type(0), m_e); } /************************************* * xfunction_iterator implementation * *************************************/ template template inline xfunction_iterator::xfunction_iterator(const xfunction_type* func, It&&... it) noexcept : p_f(func) , m_it(std::forward(it)...) { } template inline auto xfunction_iterator::operator++() -> self_type& { auto f = [](auto& it) { ++it; }; for_each(f, m_it); return *this; } template inline auto xfunction_iterator::operator--() -> self_type& { auto f = [](auto& it) { return --it; }; for_each(f, m_it); return *this; } template inline auto xfunction_iterator::operator+=(difference_type n) -> self_type& { auto f = [n](auto& it) { it += n; }; for_each(f, m_it); return *this; } template inline auto xfunction_iterator::operator-=(difference_type n) -> self_type& { auto f = [n](auto& it) { it -= n; }; for_each(f, m_it); return *this; } template inline auto xfunction_iterator::operator-(const self_type& rhs) const -> difference_type { return tuple_max_diff(std::make_index_sequence(), m_it, rhs.m_it); } template inline auto xfunction_iterator::operator*() const -> reference { return deref_impl(std::make_index_sequence()); } template inline bool xfunction_iterator::equal(const self_type& rhs) const { // Optimization: no need to compare each subiterator since they all // are incremented decremented together. constexpr std::size_t temp = xtl::mpl::find_if::value; constexpr std::size_t index = (temp == std::tuple_size::value) ? 0 : temp; return std::get(m_it) == std::get(rhs.m_it); } template inline bool xfunction_iterator::less_than(const self_type& rhs) const { // Optimization: no need to compare each subiterator since they all // are incremented decremented together. constexpr std::size_t temp = xtl::mpl::find_if::value; constexpr std::size_t index = (temp == std::tuple_size::value) ? 0 : temp; return std::get(m_it) < std::get(rhs.m_it); } template template inline auto xfunction_iterator::deref_impl(std::index_sequence) const -> reference { return (p_f->m_f)(*std::get(m_it)...); } template template inline auto xfunction_iterator::tuple_max_diff( std::index_sequence, const data_type& lhs, const data_type& rhs ) const -> difference_type { auto diff = std::make_tuple((std::get(lhs) - std::get(rhs))...); auto func = [](difference_type n, auto&& v) { return (std::max)(n, v); }; return accumulate(func, difference_type(0), diff); } template inline bool operator==(const xfunction_iterator& it1, const xfunction_iterator& it2) { return it1.equal(it2); } template inline bool operator<(const xfunction_iterator& it1, const xfunction_iterator& it2) { return it1.less_than(it2); } /************************************ * xfunction_stepper implementation * ************************************/ template template inline xfunction_stepper::xfunction_stepper(const xfunction_type* func, St&&... st) noexcept : p_f(func) , m_st(std::forward(st)...) { } template inline void xfunction_stepper::step(size_type dim) { auto f = [dim](auto& st) { st.step(dim); }; for_each(f, m_st); } template inline void xfunction_stepper::step_back(size_type dim) { auto f = [dim](auto& st) { st.step_back(dim); }; for_each(f, m_st); } template inline void xfunction_stepper::step(size_type dim, size_type n) { auto f = [dim, n](auto& st) { st.step(dim, n); }; for_each(f, m_st); } template inline void xfunction_stepper::step_back(size_type dim, size_type n) { auto f = [dim, n](auto& st) { st.step_back(dim, n); }; for_each(f, m_st); } template inline void xfunction_stepper::reset(size_type dim) { auto f = [dim](auto& st) { st.reset(dim); }; for_each(f, m_st); } template inline void xfunction_stepper::reset_back(size_type dim) { auto f = [dim](auto& st) { st.reset_back(dim); }; for_each(f, m_st); } template inline void xfunction_stepper::to_begin() { auto f = [](auto& st) { st.to_begin(); }; for_each(f, m_st); } template inline void xfunction_stepper::to_end(layout_type l) { auto f = [l](auto& st) { st.to_end(l); }; for_each(f, m_st); } template inline auto xfunction_stepper::operator*() const -> reference { return deref_impl(std::make_index_sequence()); } template template inline auto xfunction_stepper::deref_impl(std::index_sequence) const -> reference { return (p_f->m_f)(*std::get(m_st)...); } template template inline auto xfunction_stepper::step_simd_impl(std::index_sequence) -> simd_return_type { return (p_f->m_f.simd_apply)(std::get(m_st).template step_simd()...); } template template inline auto xfunction_stepper::step_simd() -> simd_return_type { return step_simd_impl(std::make_index_sequence()); } template inline void xfunction_stepper::step_leading() { auto step_leading_lambda = [](auto&& st) { st.step_leading(); }; for_each(step_leading_lambda, m_st); } } #endif