cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
diagonal.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_OPERATOR_VIEWS_DIAGONAL_H
4#define pRC_CORE_TENSOR_OPERATOR_VIEWS_DIAGONAL_H
5
9
11{
12 template<class T, class N, class V>
13 class Diagonal : public View<T, N, Diagonal<T, N, V>>
14 {
15 static_assert(IsTensorView<V>());
16
17 private:
19
20 public:
21 template<class X, If<IsConstructible<V, X>> = 0>
23 : mA(forward<X>(a))
24 {
25 }
26
27 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
28 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
29 constexpr decltype(auto) operator()(Is const... indices)
30 {
31 return expand(
32 makeSeries<Index, typename Base::Dimension() / 2>(),
33 [this](auto const &indices, auto const... seq) -> T
34 {
35 if(((indices[seq] ==
36 indices[typename Base::Dimension() / 2 + seq]) &&
37 ...))
38 {
39 return mA(indices[seq]..., indices[seq]...);
40 }
41
42 return zero();
43 },
44 Indices<typename Base::Dimension{}>(indices...));
45 }
46
47 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
48 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
49 constexpr decltype(auto) operator()(Is const... indices) const
50 {
51 return expand(
52 makeSeries<Index, typename Base::Dimension() / 2>(),
53 [this](auto const &indices, auto const... seq) -> T
54 {
55 if(((indices[seq] ==
56 indices[typename Base::Dimension() / 2 + seq]) &&
57 ...))
58 {
59 return mA(indices[seq]..., indices[seq]...);
60 }
61
62 return zero();
63 },
64 Indices<typename Base::Dimension{}>(indices...));
65 }
66
67 constexpr decltype(auto) operator()(
68 typename Base::Subscripts const &subscripts)
69 {
70 return this->call(subscripts);
71 }
72
73 constexpr decltype(auto) operator()(
74 typename Base::Subscripts const &subscripts) const
75 {
76 return this->call(subscripts);
77 }
78
79 private:
80 V mA;
81 };
82}
83#endif // pRC_CORE_TENSOR_OPERATOR_VIEWS_DIAGONAL_H
Definition indices.hpp:15
Definition diagonal.hpp:14
constexpr decltype(auto) operator()(Is const ... indices)
Definition diagonal.hpp:29
Diagonal(X &&a)
Definition diagonal.hpp:22
constexpr decltype(auto) operator()(Is const ... indices) const
Definition diagonal.hpp:49
Definition type_traits.hpp:32
Definition diagonal.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344