cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
identity.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_OPERATOR_IDENTITY_H
4#define pRC_TENSOR_TRAIN_OPERATOR_IDENTITY_H
5
10
11namespace pRC
12{
// Identity for tensor-train operator *views* (types T satisfying
// TensorTrain::IsOperatorView<T>). This specialization declares no members
// of its own. It inherits everything from Identity<ResultOf<Eval, T>>, so
// it behaves like the Identity functor of the type ResultOf<Eval, T>.
// NOTE(review): that type is presumably the concrete tensor-train operator
// obtained by evaluating the view -- confirm against Eval.
13 template<class T>
14 struct Identity<T, If<TensorTrain::IsOperatorView<T>>>
15 : Identity<ResultOf<Eval, T>>
16 {
17 };
18
// Identity for concrete tensor-train operators (types T satisfying
// TensorTrain::IsOperator<T>). Both call operators end the same way. They
// return TensorTrain::OperatorViews::Enumerate<typename T::Type, T::M, T::N,
// T::Ranks, F>(f), where f is a lambda templated on a core index C that
// produces core C.
// NOTE(review): judging by its name, Enumerate presumably builds an
// operator view that invokes f for each core index -- confirm in
// OperatorViews::Enumerate.
// NOTE(review): this rendering omits listing lines 32-33, 62-63 and 71-72.
// Those lines hold the beginning of each core's return expression. Only
// the tails (a block of type Core::ChangeSizes<CM, CN> and a zero<> block)
// are visible below.
19 template<class T>
20 struct Identity<T, If<TensorTrain::IsOperator<T>>>
21 {
// Unscaled identity. Every core is built the same way, with no value
// argument.
22 constexpr auto operator()()
23 {
// f<C>() builds core C. Its type is taken from T::Cores<C>.
24 auto const f = []<Index C>()
25 {
26 using Core = typename T::template Cores<C>;
// Read the four extents of core C. Judging by the names they are
// (left rank, row mode, column mode, right rank) -- assumption, confirm
// against the definition of Cores<C>.
27 constexpr auto CRL = Core::size(0);
28 constexpr auto CM = Core::size(1);
29 constexpr auto CN = Core::size(2);
30 constexpr auto CRR = Core::size(3);
31
// Visible tail of the core expression has two parts:
// - a function templated on the type Core::ChangeSizes<CM, CN> (sizes
//   CM, CN), called with no arguments; the function name is on an
//   omitted line, presumably identity<>, which the page's
//   cross-references list -- confirm;
// - zero<> of a type with sizes (CRL - 1, CM, CN, CRR - 1).
// NOTE(review): the CRL - 1 / CRR - 1 extents assume both ranks are >= 1.
// They become 0 when a rank is 1.
34 typename Core::template ChangeSizes<CM, CN>>()),
35 zero<typename Core::template ChangeSizes<CRL - 1, CM, CN,
36 CRR - 1>>()));
37 };
38
// F is the closure type with const/reference stripped. M, N and Ranks are
// taken from T and forwarded unchanged to Enumerate.
39 using F = RemoveConstReference<decltype(f)>;
40 using M = typename T::M;
41 using N = typename T::N;
42 using Ranks = typename T::Ranks;
43
44 return TensorTrain::OperatorViews::Enumerate<typename T::Type, M, N,
45 Ranks, F>(f);
46 }
47
// Scaled identity. This overload participates only when typename T::Type
// is constructible from X.
48 template<class X, If<IsConstructible<typename T::Type, X>> = 0>
49 constexpr auto operator()(X &&value)
50 {
51 auto const f =
// Init-capture: the argument is forwarded and converted to T::Type once.
// The closure stores that copy by value, and every f<C>() call reads it.
52 [value = typename T::Type(forward<X>(value))]<Index C>()
53 {
54 using Core = typename T::template Cores<C>;
55 constexpr auto CRL = Core::size(0);
56 constexpr auto CM = Core::size(1);
57 constexpr auto CN = Core::size(2);
58 constexpr auto CRR = Core::size(3);
59
// Only the first core (C == 0) receives the captured value. Every other
// core's visible tail is the same as in the unscaled overload.
// NOTE(review): presumably this scales the whole operator by value exactly
// once -- confirm against the omitted lines.
60 if constexpr(C == 0)
61 {
// Visible tail: the (CM, CN) block is called with `value`; the zero block
// is the same as in the unscaled case.
64 typename Core::template ChangeSizes<CM, CN>>(
65 value)),
66 zero<typename Core::template ChangeSizes<CRL - 1, CM,
67 CN, CRR - 1>>()));
68 }
69 else
70 {
73 typename Core::template ChangeSizes<CM, CN>>()),
74 zero<typename Core::template ChangeSizes<CRL - 1, CM,
75 CN, CRR - 1>>()));
76 }
77 };
78
// Same type plumbing and Enumerate call as the unscaled overload.
79 using F = RemoveConstReference<decltype(f)>;
80 using M = typename T::M;
81 using N = typename T::N;
82 using Ranks = typename T::Ranks;
83
84 return TensorTrain::OperatorViews::Enumerate<typename T::Type, M, N,
85 Ranks, F>(f);
86 }
87 };
88}
89#endif // pRC_TENSOR_TRAIN_OPERATOR_IDENTITY_H
Definition cholesky.hpp:18
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
static constexpr auto identity()
Definition identity.hpp:12
constexpr auto operator()()
Definition identity.hpp:22
constexpr auto operator()(X &&value)
Definition identity.hpp:49
Definition type_traits.hpp:262