cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
unit.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_OPERATOR_UNIT_H
4#define pRC_TENSOR_TRAIN_OPERATOR_UNIT_H
5
10
11namespace pRC
12{
// Unit<T> for tensor-train operator *views*: this specialization adds
// nothing of its own and inherits from Unit<ResultOf<Eval, T>>, i.e. the
// Unit of the type obtained by applying Eval to the view (presumably the
// owning tensor-train operator type -- TODO confirm against Eval).
13 template<class T>
14 struct Unit<T, If<TensorTrain::IsOperatorView<T>>> : Unit<ResultOf<Eval, T>>
15 {
16 };
17
// Unit<T> for tensor-train operators. Both call operators build a
// per-core generator lambda `f` (templated on the core index C) and pass
// it to TensorTrain::OperatorViews::Enumerate together with T's Type, M,
// N and Ranks. Enumerate presumably invokes f for each core index to
// produce the cores -- its definition is not shown here, verify.
//
// NOTE(review): this text is a Doxygen source listing. Source lines 6-9
// (presumably the #include directives) and source lines 31, 60 and 68
// (the opening of each `return ...(` expression that combines the unit
// block with the zero block) were dropped. The exact combining operation
// is therefore not visible here.
18 template<class T>
19 struct Unit<T, If<TensorTrain::IsOperator<T>>>
20 {
// No-argument form. Every core C is assembled from a unit<> block of
// sizes <1, CM, CN, 1> and a zero<> block of sizes
// <CRL - 1, CM, CN, CRR - 1>, where CRL/CM/CN/CRR are core C's sizes.
21 constexpr auto operator()()
22 {
23 auto const f = []<Index C>()
24 {
25 using Core = typename T::template Cores<C>;
// Sizes of core C along dimensions 0..3. The names suggest
// left rank, row mode, column mode, right rank -- assumption,
// confirm against TensorTrain::Operator.
26 constexpr auto CRL = Core::size(0);
27 constexpr auto CM = Core::size(1);
28 constexpr auto CN = Core::size(2);
29 constexpr auto CRR = Core::size(3);
30
// NOTE(review): the opening of this return expression (source
// line 31) is missing from the listing.
// NOTE(review): CRL - 1 and CRR - 1 assume both sizes are >= 1.
// Index is listed as Size in the footer (presumably unsigned),
// so a zero size would wrap around -- TODO confirm.
32 unit<typename Core::template ChangeSizes<1, CM, CN, 1>>(),
33 zero<typename Core::template ChangeSizes<CRL - 1, CM, CN,
34 CRR - 1>>()));
35 };
36
// Strip const from the closure type so that Enumerate is
// instantiated on the plain lambda type.
37 using F = RemoveConstReference<decltype(f)>;
38 using M = typename T::M;
39 using N = typename T::N;
40 using Ranks = typename T::Ranks;
41
42 return TensorTrain::OperatorViews::Enumerate<typename T::Type, M, N,
43 Ranks, F>(f);
44 }
45
// Scalar form. This overload only participates in overload resolution
// when T::Type is constructible from X. `value` is converted to T::Type
// once and captured by copy. Only core 0 uses unit<...>(value); every
// other core uses the argument-less unit<...>(), exactly as in
// operator()(). So `value` affects the first core only.
46 template<class X, If<IsConstructible<typename T::Type, X>> = 0>
47 constexpr auto operator()(X &&value)
48 {
49 auto const f =
50 [value = typename T::Type(forward<X>(value))]<Index C>()
51 {
52 using Core = typename T::template Cores<C>;
53 constexpr auto CRL = Core::size(0);
54 constexpr auto CM = Core::size(1);
55 constexpr auto CN = Core::size(2);
56 constexpr auto CRR = Core::size(3);
57
58 if constexpr(C == 0)
59 {
// NOTE(review): the opening of this return expression
// (source line 60) is missing from the listing.
61 unit<typename Core::template ChangeSizes<1, CM, CN, 1>>(
62 value),
63 zero<typename Core::template ChangeSizes<CRL - 1, CM,
64 CN, CRR - 1>>()));
65 }
66 else
67 {
// NOTE(review): the opening of this return expression
// (source line 68) is missing from the listing.
69 unit<typename Core::template ChangeSizes<1, CM, CN,
70 1>>(),
71 zero<typename Core::template ChangeSizes<CRL - 1, CM,
72 CN, CRR - 1>>()));
73 }
74 };
75
// Same hand-off as operator()(): strip const from the closure type
// before instantiating Enumerate.
76 using F = RemoveConstReference<decltype(f)>;
77 using M = typename T::M;
78 using N = typename T::N;
79 using Ranks = typename T::Ranks;
80
81 return TensorTrain::OperatorViews::Enumerate<typename T::Type, M, N,
82 Ranks, F>(f);
83 }
84 };
85}
86#endif // pRC_TENSOR_TRAIN_OPERATOR_UNIT_H
Definition cholesky.hpp:18
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
static constexpr auto unit()
Definition unit.hpp:12
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
constexpr auto operator()()
Definition unit.hpp:21
constexpr auto operator()(X &&value)
Definition unit.hpp:47
Definition type_traits.hpp:265