cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
tensor_product.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_VIEWS_TENSOR_PRODUCT_H
4#define pRC_CORE_TENSOR_VIEWS_TENSOR_PRODUCT_H
5
8
9namespace pRC::TensorViews
10{
11 template<class T, class N, class VA, class VB>
12 class TensorProduct : public View<T, N, TensorProduct<T, N, VA, VB>>
13 {
14 static_assert(IsTensorView<VA>());
15 static_assert(IsTensorView<VB>());
16
17 private:
19
20 public:
21 template<class XA, class XB, If<IsSame<VA, RemoveReference<XA>>> = 0,
22 If<IsSame<VB, RemoveReference<XB>>> = 0>
24 : mA(forward<XA>(a))
25 , mB(forward<XB>(b))
26 {
27 }
28
29 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
30 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
31 constexpr decltype(auto) operator()(Is const... indices)
32 {
33 return expand(
34 makeSeries<Index, typename VA::Dimension{}>(),
35 [this](auto const &indices,
36 auto const... seq) -> decltype(auto)
37 {
38 return mA(indices[seq]...);
39 },
40 Indices<sizeof...(Is)>(indices...)) *
41 expand(
42 makeRange<Index, typename VA::Dimension{},
43 typename Base::Dimension{}>(),
44 [this](auto const &indices,
45 auto const... seq) -> decltype(auto)
46 {
47 return mB(indices[seq]...);
48 },
49 Indices<sizeof...(Is)>(indices...));
50 }
51
52 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
53 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
54 constexpr decltype(auto) operator()(Is const... indices) const
55 {
56 return expand(
57 makeSeries<Index, typename VA::Dimension{}>(),
58 [this](auto const &indices,
59 auto const... seq) -> decltype(auto)
60 {
61 return mA(indices[seq]...);
62 },
63 Indices<sizeof...(Is)>(indices...)) *
64 expand(
65 makeRange<Index, typename VA::Dimension{},
66 typename Base::Dimension{}>(),
67 [this](auto const &indices,
68 auto const... seq) -> decltype(auto)
69 {
70 return mB(indices[seq]...);
71 },
72 Indices<sizeof...(Is)>(indices...));
73 }
74
75 constexpr decltype(auto) operator()(
76 typename Base::Subscripts const &subscripts)
77 {
78 return this->call(subscripts);
79 }
80
81 constexpr decltype(auto) operator()(
82 typename Base::Subscripts const &subscripts) const
83 {
84 return this->call(subscripts);
85 }
86
87 private:
88 VA mA;
89 VB mB;
90 };
91}
92#endif // pRC_CORE_TENSOR_VIEWS_TENSOR_PRODUCT_H
Definition indices.hpp:15
Definition tensor_product.hpp:13
constexpr decltype(auto) operator()(Is const ... indices)
Definition tensor_product.hpp:31
TensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:23
constexpr decltype(auto) operator()(Is const ... indices) const
Definition tensor_product.hpp:54
Definition type_traits.hpp:32
Definition diagonal.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto makeRange()
Definition sequence.hpp:379
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344