cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
enumerate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
5
9
10namespace pRC
11{
// enumerate(f, args...)
//
// Builds a tensor-train object from the functor f and the tensor-train
// arguments. The type computations below assume that the N-th core of the
// result is
//     f.template operator()<N>(args.template core<N>()...)
// i.e. f is called with the N-th cores of all arguments, with N passed as a
// template argument.
//
// Constraints (requires-clause):
//   - every RemoveReference<Xs> is TensorTrain::IsTensorish or
//     TensorTrain::IsOperatorish;
//   - at least one argument is passed;
//   - all arguments have the same Dimension (checked with isSame);
//   - f.template operator()<0>(args.template core<0>()...) is well-formed.
//
// If every Xs satisfies IsInvocable<View, Xs>, the result is constructed from
// forward<F>(f) and view(forward<Xs>(args))... . It therefore presumably
// refers to the arguments rather than copying them (TODO confirm).
// Otherwise the call is re-dispatched and its result is passed through
// eval(...).
//
// NOTE(review): this listing skips original lines 26-27, 45, 56, 67 and 82
// (the embedded numbering jumps). The following parts are therefore not
// visible here:
//   - the Dimension initializer;
//   - the start of the call that receives the lambda below;
//   - the first operand of each Ranks definition;
//   - the name of each returned type.
// Check the real header before relying on those parts.
12 template<class F, class... Xs>
13 requires((TensorTrain::IsTensorish<RemoveReference<Xs>> ||
14 TensorTrain::IsOperatorish<RemoveReference<Xs>>) &&
15 ...) &&
16 (sizeof...(Xs) > 0) && (isSame(RemoveReference<Xs>::Dimension...)) &&
17 requires {
18 declval<F>().template operator()<0>(
19 declval<Xs>().template core<0>()...);
20 }
21 static inline constexpr auto enumerate(F &&f, Xs &&...args)
22 {
// All arguments can be wrapped in a view: build the result directly from
// f and the views.
23 if constexpr((IsInvocable<View, Xs> && ...))
24 {
// NOTE(review): lines 26-27 are missing here. Presumably they complete the
// Dimension initializer and open an expand(makeSeries<...>(), ...) call.
// That call would invoke the lambda below with the core indices as seq...
// (expand and makeSeries are referenced on this page). Confirm.
25 constexpr auto Dimension =
28 [&f, &args...](auto const... seq)
29 {
// core<N>() stands for the N-th core of the result. It only ever appears
// inside decltype below (unevaluated operands), so f is never actually called
// here. It is used solely to derive types and compile-time sizes.
30 auto core = [&f, &args...]<Index N>()
31 {
32 return forward<F>(f).template operator()<N>(
33 forward<Xs>(args).template core<N>()...);
34 };
35
// Element type of the result: the common ::Type of all cores.
36 using T =
37 Common<typename decltype(core.template
38 operator()<seq>())::Type...>;
39
// Cores of dimension 3 (tensor case):
//   - size(0) and size(2) feed Ranks, and size(1) gives the mode sizes N.
//   - chip<Dimension - 1> is applied to the size(2) sequence, presumably to
//     drop the last core's trailing rank.
//   - The other Ranks operand is on missing line 45.
40 if constexpr(Common<Constant<Size,
41 decltype(core.template operator()<
42 seq>())::Dimension>...>() == 3)
43 {
44 using Ranks =
46 Sizes<decltype(core.template operator()<
47 seq>())::size(0)...>())),
48 decltype(chip<Dimension - 1>(
49 Sizes<decltype(core.template operator()<
50 seq>())::size(2)...>()))>;
51
52 using N =
53 pRC::Sizes<decltype(core.template
54 operator()<seq>())::size(1)...>;
55
57 F,
58 RemoveReference<decltype(view(
59 forward<Xs>(args)))>...>(forward<F>(f),
60 view(forward<Xs>(args))...);
61 }
// Cores of dimension 4 (operator case):
//   - size(0) and size(3) feed Ranks, size(1) gives M and size(2) gives N.
//   - chip<Dimension - 1> is applied to the size(3) sequence.
//   - The other Ranks operand is on missing line 67.
// NOTE(review): the two if-constexpr blocks are not chained with else and
// there is no fallback. Cores of any other dimension make this lambda
// return void.
62 if constexpr(Common<Constant<Size,
63 decltype(core.template operator()<
64 seq>())::Dimension>...>() == 4)
65 {
66 using Ranks =
68 Sizes<decltype(core.template operator()<
69 seq>())::size(0)...>())),
70 decltype(chip<Dimension - 1>(
71 Sizes<decltype(core.template operator()<
72 seq>())::size(3)...>()))>;
73
74 using M =
75 Sizes<decltype(core.template
76 operator()<seq>())::size(1)...>;
77
78 using N =
79 Sizes<decltype(core.template
80 operator()<seq>())::size(2)...>;
81
83 Ranks, F,
84 RemoveReference<decltype(view(
85 forward<Xs>(args)))>...>(forward<F>(f),
86 view(forward<Xs>(args))...);
87 }
88 });
89 }
// Some argument cannot be viewed (presumably an rvalue temporary).
// - args... is passed without forward, so the recursive call sees lvalue
//   references. Presumably that makes the arguments viewable and selects the
//   branch above.
// - eval then materializes the result, so nothing refers to arguments that
//   may die at the end of the caller's full-expression.
// Presumed intent; TODO confirm.
90 else
91 {
92 return eval(enumerate(forward<F>(f), args...));
93 }
94 }
95
96 template<class F, class... Xs>
97 requires((TensorTrain::IsTensorish<RemoveReference<Xs>> ||
98 TensorTrain::IsOperatorish<RemoveReference<Xs>>) &&
99 ...) &&
100 requires { enumerate(declval<F>(), declval<Xs>()...); }
101 static inline constexpr auto enumerate(Xs &&...args)
102 {
103 return enumerate(F(), forward<Xs>(args)...);
104 }
105}
106#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
Definition value.hpp:12
Definition sequence.hpp:29
Definition enumerate.hpp:20
Definition concepts.hpp:31
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition cholesky.hpp:10
std::size_t Size
Definition basics.hpp:31
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr decltype(auto) view(X &&a)
Definition view.hpp:13
std::common_type_t< Ts... > Common
Definition basics.hpp:53
static constexpr auto isSame(X &&arg, Xs &&...args)
Definition is_same.hpp:9
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
std::integral_constant< T, V > Constant
Definition basics.hpp:38
static constexpr auto enumerate(F &&f, Xs &&...args)
Definition enumerate.hpp:21
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12