cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
loop.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
5
9
10namespace pRC
11{
// loop(f, xs...): combines one or more tensor trains and/or operator trains of
// equal Dimension through the functor f. The result core types are derived by
// applying f to the N-th cores (core<N>()) of all arguments. The actual
// evaluation presumably happens inside the returned TensorViews::Loop /
// OperatorViews::Loop (not shown here).
//
// Constraints (from the requires-clause):
//  - every RemoveReference<Xs> satisfies TensorTrain::IsTensorish or
//    TensorTrain::IsOperatorish;
//  - at least one argument is given (sizeof...(Xs) > 0);
//  - all arguments have the same Dimension (checked with isSame);
//  - f is callable with core<0>() of every argument.
//
// Returns: a lazy Loop view when every argument can be viewed; otherwise the
// eval()-uated result of the same operation (see the else-branch).
//
// NOTE(review): this text is a Doxygen source listing, not the header itself.
// Source lines 6-8, 23-24, 42 and 63 are missing from the extraction (see the
// gaps in the embedded line numbers), so this definition is incomplete as
// shown. Restore those lines from loop.hpp before relying on it.
12 template<class F, class... Xs>
13 requires((TensorTrain::IsTensorish<RemoveReference<Xs>> ||
14 TensorTrain::IsOperatorish<RemoveReference<Xs>>) &&
15 ...) &&
16 (sizeof...(Xs) > 0) && (isSame(RemoveReference<Xs>::Dimension...)) &&
17 requires { declval<F>()(declval<Xs>().template core<0>()...); }
18 static inline constexpr auto loop(F &&f, Xs &&...args)
19 {
// Lazy path: every argument can be wrapped by View, so a view that
// references the arguments is returned instead of an evaluated result.
20 if constexpr((IsInvocable<View, Xs> && ...))
21 {
// NOTE(review): source lines 23-24 are missing below. They presumably
// finish the Dimension initializer and begin the call that receives the
// lambda, which takes one `seq` per core index. Judging by the referenced
// helpers, that call is likely expand(makeSeries<...>(), ...). Confirm
// against loop.hpp.
22 constexpr auto Dimension =
25 [&f, &args...](auto const... seq)
26 {
// core<N>() calls f on the N-th core of every argument. It is only
// used inside decltype below (an unevaluated context) to compute the
// result core types, so f is not actually invoked here.
27 auto core = [&f, &args...]<Index N>()
28 {
29 return forward<F>(f)(
30 forward<Xs>(args).template core<N>()...);
31 };
32
// T: common element Type of all result cores.
33 using T =
34 Common<typename decltype(core.template
35 operator()<seq>())::Type...>;
36
// Result cores of Dimension 3 -> tensor-train Loop view.
// Sizes: N from size(1); Ranks from size(0) and size(2).
// NOTE(review): this `if constexpr` and the one for Dimension 4 are
// not chained with `else`. If the result cores have any other
// Dimension, neither branch returns and the lambda yields void.
37 if constexpr(Common<Constant<Size,
38 decltype(core.template operator()<
39 seq>())::Dimension>...>() == 3)
40 {
// NOTE(review): source line 42 (the first half of this Ranks
// expression) is missing from the listing.
41 using Ranks =
43 Sizes<decltype(core.template operator()<
44 seq>())::size(0)...>())),
45 decltype(chip<Dimension - 1>(
46 Sizes<decltype(core.template operator()<
47 seq>())::size(2)...>()))>;
48
49 using N =
50 pRC::Sizes<decltype(core.template
51 operator()<seq>())::size(1)...>;
52
// The functor is forwarded into the view; every argument is
// wrapped with view().
53 return TensorTrain::TensorViews::Loop<T, N, Ranks, F,
54 RemoveReference<decltype(view(
55 forward<Xs>(args)))>...>(forward<F>(f),
56 view(forward<Xs>(args))...);
57 }
// Result cores of Dimension 4 -> operator-train Loop view.
// Sizes: M from size(1), N from size(2); Ranks from size(0) and size(3).
58 if constexpr(Common<Constant<Size,
59 decltype(core.template operator()<
60 seq>())::Dimension>...>() == 4)
61 {
// NOTE(review): source line 63 (the first half of this Ranks
// expression) is missing from the listing.
62 using Ranks =
64 Sizes<decltype(core.template operator()<
65 seq>())::size(0)...>())),
66 decltype(chip<Dimension - 1>(
67 Sizes<decltype(core.template operator()<
68 seq>())::size(3)...>()))>;
69
70 using M =
71 Sizes<decltype(core.template
72 operator()<seq>())::size(1)...>;
73
74 using N =
75 Sizes<decltype(core.template
76 operator()<seq>())::size(2)...>;
77
78 return TensorTrain::OperatorViews::Loop<T, M, N, Ranks,
79 F,
80 RemoveReference<decltype(view(
81 forward<Xs>(args)))>...>(forward<F>(f),
82 view(forward<Xs>(args))...);
83 }
84 });
85 }
86 else
87 {
// Fallback: at least one argument is not accepted by View
// (presumably an rvalue). Re-dispatch with `args` as lvalues (not
// forwarded) and eval() the result. Presumably this is so the returned
// object does not reference parameters that die when this function
// returns. TODO confirm eval semantics.
88 return eval(loop(forward<F>(f), args...));
89 }
90 }
91
// loop<F>(xs...): convenience overload for a functor given by type.
// F is not a function parameter, so it cannot be deduced and must be
// specified explicitly, e.g. loop<MyFunctor>(a, b). A value-initialized F()
// is passed to loop(f, xs...) together with the perfectly-forwarded
// arguments. The requires-clause admits only argument packs for which
// loop(declval<F>(), declval<Xs>()...) is well-formed.
92 template<class F, class... Xs>
93 requires((TensorTrain::IsTensorish<RemoveReference<Xs>> ||
94 TensorTrain::IsOperatorish<RemoveReference<Xs>>) &&
95 ...) &&
96 requires { loop(declval<F>(), declval<Xs>()...); }
97 static inline constexpr auto loop(Xs &&...args)
98 {
99 return loop(F(), forward<Xs>(args)...);
100 }
101}
102#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_LOOP_H
Definition value.hpp:12
Definition sequence.hpp:29
Definition concepts.hpp:31
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition cholesky.hpp:10
static constexpr auto loop(F &&f, Xs &&...args)
Definition loop.hpp:18
std::size_t Size
Definition basics.hpp:31
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr decltype(auto) view(X &&a)
Definition view.hpp:13
std::common_type_t< Ts... > Common
Definition basics.hpp:53
static constexpr auto isSame(X &&arg, Xs &&...args)
Definition is_same.hpp:9
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto loop(F &&f, Xs &&...args)
Definition loop.hpp:20
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
std::integral_constant< T, V > Constant
Definition basics.hpp:38
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12