cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
loop.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
5
10
11namespace pRC
12{
// loop(f, args...) -- view-building overload.
//
// Applies the callable f core-by-core to the tensor-train arguments and
// returns either a TensorTrain::TensorViews::Loop (when the per-core results
// have common Dimension 3) or a TensorTrain::OperatorViews::Loop (common
// Dimension 4), constructed from forward<F>(f) and view(forward<Xs>(args))....
//
// Constraint: the visible clause requires TensorTrain::IsOperatorish on each
// RemoveReference<Xs> as part of a pack-expanded condition; the rest of the
// template header is cut from this rendering.
//
// f    : callable invoked as f(args.template core<N>()...) for a core index N.
// args : tensor-train-like objects exposing a member template core<N>().
//
// In the visible code f is only named inside decltype (to deduce the per-core
// result types), so it is not invoked here; presumably the returned view
// applies it when its cores are accessed -- confirm.
13 template<class F, class... Xs,
15 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
19 static inline constexpr auto loop(F &&f, Xs &&...args)
20 {
// Expand over the index series from makeSeries<Index, Dimension{}>()
// (presumably one index per core; Dimension's definition is not visible in
// this rendering -- confirm).
22 return expand(makeSeries<Index, Dimension{}>(),
23 [&f, &args...](auto const... seq)
24 {
// core<N>(): result of applying f to the N-th core of every argument.
// In the visible code it is used only inside decltype, i.e. for type
// deduction.
25 auto core = [&f, &args...]<Index N>()
26 {
27 return forward<F>(f)(
28 forward<Xs>(args).template core<N>()...);
29 };
30
// T: common element type (::Type) of all per-core results.
31 using T = Common<typename decltype(core.template
32 operator()<seq>())::Type...>;
33
// Per-core results of common Dimension 3: build a tensor view.
34 if constexpr(Common<
35 typename decltype(core.template operator()<
36 seq>())::Dimension...>() ==
37 3)
38 {
// Ranks: assembled from each result core's size(0) and, via
// chip<Dimension() - 1>, each result core's size(2); the outer part
// of this expression is cut from this rendering.
39 using Ranks =
41 Sizes<decltype(core.template operator()<
42 seq>())::size(0)...>())),
43 decltype(chip<Dimension() - 1>(Sizes<
44 decltype(core.template operator()<seq>())::size(
45 2)...>()))>;
46
// N: each result core's size(1).
47 using N = pRC::Sizes<
48 decltype(core.template operator()<seq>())::size(1)...>;
49
50 return TensorTrain::TensorViews::Loop<T, N, Ranks, F,
52 forward<F>(f), view(forward<Xs>(args))...);
53 }
// Per-core results of common Dimension 4: build an operator view.
54 if constexpr(Common<
55 typename decltype(core.template operator()<
56 seq>())::Dimension...>() ==
57 4)
58 {
// Ranks: assembled as in the Dimension-3 branch, but from size(0)
// and (via chip<Dimension() - 1>) size(3) of each result core.
59 using Ranks =
61 Sizes<decltype(core.template operator()<
62 seq>())::size(0)...>())),
63 decltype(chip<Dimension() - 1>(Sizes<
64 decltype(core.template operator()<seq>())::size(
65 3)...>()))>;
66
// M: each result core's size(1).
67 using M = pRC::Sizes<
68 decltype(core.template operator()<seq>())::size(1)...>;
69
// N: each result core's size(2).
70 using N = pRC::Sizes<
71 decltype(core.template operator()<seq>())::size(2)...>;
72
73 return TensorTrain::OperatorViews::Loop<T, M, N, Ranks, F,
75 forward<F>(f), view(forward<Xs>(args))...);
76 }
// NOTE(review): neither branch has an else; if the common Dimension is
// neither 3 nor 4 this lambda has no return statement and its deduced
// result would be void. Confirm a guard exists in the lines cut from this
// rendering, or consider a static_assert.
77 });
78 }
79
// loop<F>(args...) -- convenience form for a functor type F.
// Default-constructs F and forwards to loop(F(), args...), so F must be
// default-constructible. The template header is partly cut from this
// rendering (visible clause: TensorTrain::IsOperatorish on each
// RemoveReference<Xs>).
80 template<class F, class... Xs,
82 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
86 static inline constexpr auto loop(Xs &&...args)
87 {
88 return loop(F(), forward<Xs>(args)...);
89 }
90
// loop(f, args...) -- eval-wrapping overload.
// Enabled when Loop<F> (defined elsewhere) is invocable with (F, Xs &...).
// Further enabling conditions are cut from this rendering; presumably they
// keep this overload from colliding with the view-building one above --
// confirm.
// args are passed as lvalues (not forwarded), matching the Xs & in the
// IsInvocable check. The resulting loop is passed through eval()
// (presumably to produce a concrete, non-view result -- confirm).
91 template<class F, class... Xs,
93 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
95 If<IsInvocable<Loop<F>, F, Xs &...>> = 0>
96 static inline constexpr auto loop(F &&f, Xs &&...args)
97 {
98 return eval(loop(forward<F>(f), args...));
99 }
100
// loop<F>(args...) -- eval-wrapping overload for a functor type F.
// Enabled when Loop<F> is invocable with (Xs &...); other enabling
// conditions are cut from this rendering. Calls loop<F>(args...) with args
// as lvalues and passes the result through eval().
101 template<class F, class... Xs,
103 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
105 If<IsInvocable<Loop<F>, Xs &...>> = 0>
106 static inline constexpr auto loop(Xs &&...args)
107 {
108 return eval(loop<F>(args...));
109 }
110}
111#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
Definition sequence.hpp:56
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr X view(X &&a)
Definition view.hpp:12
typename CommonTypes< Ts... >::Type Common
Definition common.hpp:55
static constexpr auto loop(F &&f, Xs &&...args)
Definition loop.hpp:22
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:551
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344