cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
enumerate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
5
10
11namespace pRC
12{
// Builds an Enumerate view over the arguments `args`, applying `f` position by
// position: for every index N of the expanded series, `f.template operator()<N>`
// is called with the N-th core (`core<N>()`) of each argument.
//
// f     Callable exposing a templated `operator()<N>`; forwarded into the result.
// args  Objects exposing `core<N>()`; each is passed through `view(...)` before
//       being handed to the returned object.
//
// Returns a `TensorTrain::TensorViews::Enumerate` when the common `Dimension` of
// the per-position results equals 3, or a `TensorTrain::OperatorViews::Enumerate`
// when it equals 4. There is no fallback branch for any other value.
//
// NOTE(review): this listing is missing lines 14, 16-18, 22, 41, 52, 61 and 76,
// so the complete constraint list, the declaration of `Dimension`, the first
// half of `Ranks` and the view constructor call are not visible here - confirm
// against the real header before relying on these notes.
13 template<class F, class... Xs,
15 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
19 declval<Xs>().template core<0>()...))>> = 0>
20 static inline constexpr auto enumerate(F &&f, Xs &&...args)
21 {
// `Dimension` is not declared in the visible lines (presumably on the absent
// listing line 22); `seq...` receives the indices produced by makeSeries.
23 return expand(makeSeries<Index, Dimension{}>(),
24 [&f, &args...](auto const... seq)
25 {
// Helper used only inside decltype below: the result of applying `f` to the
// N-th core of every argument.
26 auto core = [&f, &args...]<Index N>()
27 {
28 return forward<F>(f).template operator()<N>(
29 forward<Xs>(args).template core<N>()...);
30 };
31
// Common `Type` member across all per-position results.
32 using T = Common<typename decltype(core.template
33 operator()<seq>())::Type...>;
34
// Per-position results with common Dimension 3: tensor view.
35 if constexpr(Common<
36 typename decltype(core.template operator()<
37 seq>())::Dimension...>() ==
38 3)
39 {
// Ranks combines the size(0) extents with the size(2) extents after
// chip<Dimension() - 1>; the wrapper around the first half is on an absent line.
40 using Ranks =
42 Sizes<decltype(core.template operator()<
43 seq>())::size(0)...>())),
44 decltype(chip<Dimension() - 1>(Sizes<
45 decltype(core.template operator()<seq>())::size(
46 2)...>()))>;
47
// N collects size(1) of every per-position result.
48 using N = pRC::Sizes<
49 decltype(core.template operator()<seq>())::size(1)...>;
50
51 return TensorTrain::TensorViews::Enumerate<T, N, Ranks, F,
53 forward<F>(f), view(forward<Xs>(args))...);
54 }
// Per-position results with common Dimension 4: operator view.
55 if constexpr(Common<
56 typename decltype(core.template operator()<
57 seq>())::Dimension...>() ==
58 4)
59 {
// Same construction as above, but the trailing extent is size(3).
60 using Ranks =
62 Sizes<decltype(core.template operator()<
63 seq>())::size(0)...>())),
64 decltype(chip<Dimension() - 1>(Sizes<
65 decltype(core.template operator()<seq>())::size(
66 3)...>()))>;
67
// M collects size(1) and N collects size(2) of every per-position result.
68 using M = pRC::Sizes<
69 decltype(core.template operator()<seq>())::size(1)...>;
70
71 using N = pRC::Sizes<
72 decltype(core.template operator()<seq>())::size(2)...>;
73
74 return TensorTrain::OperatorViews::Enumerate<T, M, N, Ranks,
75 F,
77 forward<F>(f), view(forward<Xs>(args))...);
78 }
79 });
80 }
81
// Convenience overload for callers that name the functor type `F` only as a
// template argument: constructs an `F` with `F()` and delegates to the
// enumerate overload that takes the functor as its first parameter, forwarding
// `args` unchanged.
//
// NOTE(review): listing lines 83 and 85-87 are absent from this extract, so the
// full constraint list is not visible here.
82 template<class F, class... Xs,
84 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
88 declval<Xs>().template core<0>()...))>> = 0>
89 static inline constexpr auto enumerate(Xs &&...args)
90 {
91 return enumerate(F(), forward<Xs>(args)...);
92 }
93
// Overload enabled when `Enumerate<F>` is invocable with `F, Xs &...`: calls
// enumerate with the forwarded functor and returns the result wrapped in
// `eval(...)`.
//
// `args` is deliberately passed as lvalues (no forward), matching the `Xs &...`
// in the constraint. Presumably eval() materialises the view before any
// temporaries in `args` expire - TODO confirm against eval.hpp.
//
// NOTE(review): listing lines 95 and 97 are absent from this extract, so the
// full constraint list is not visible here.
94 template<class F, class... Xs,
96 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
98 If<IsInvocable<Enumerate<F>, F, Xs &...>> = 0>
99 static inline constexpr auto enumerate(F &&f, Xs &&...args)
100 {
101 return eval(enumerate(forward<F>(f), args...));
102 }
103
// Overload taking the functor type `F` only as a template argument: calls
// `enumerate<F>` with `args` passed as lvalues (no forward) and returns the
// result wrapped in `eval(...)`.
//
// NOTE(review): listing lines 105 and 107-108 are absent from this extract; the
// template parameter list below is therefore visibly truncated and its closing
// constraint cannot be documented from here.
104 template<class F, class... Xs,
106 TensorTrain::IsOperatorish<RemoveReference<Xs>>>...>> = 0,
109 static inline constexpr auto enumerate(Xs &&...args)
110 {
111 return eval(enumerate<F>(args...));
112 }
113}
114#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ENUMERATE_H
Definition sequence.hpp:56
Definition enumerate.hpp:19
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr X view(X &&a)
Definition view.hpp:12
typename CommonTypes< Ts... >::Type Common
Definition common.hpp:55
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr auto enumerate(F &&f, Xs &&...args)
Definition enumerate.hpp:20
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:551
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344