3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
13 template<IsSizes Ranks,
class X,
class R = RemoveReference<X>,
14 IsFloat VT = Value<R>>
15 requires(TensorTrain::IsTensorish<R> ||
16 TensorTrain::IsOperatorish<R>) &&
17 IsFloat<Value<R>> && (Ranks::Dimension == R::Dimension - 1)
18 static inline constexpr auto round(X &&a,
22 [&a, &tolerance](
auto const... seq)
28 Index C = 0,
class XA,
class XB,
class... Xs>(
29 auto const &self, XA &&a, XB &&b, Xs &&...cores)
35 constexpr auto RAL = RA::size(0);
36 constexpr auto RAR = RA::size(RA::Dimension - 1);
40 if constexpr(C == R::Dimension - 1)
44 [&self, &tolerance, &totalNorm](
45 auto const cores,
auto const... seq)
52 forward<Xs>(cores)..., forward<XA>(a)));
60 auto const newLeft =
expand(
62 [&q = q](
auto const... seq)
64 return reshape<RAL, RA::size(seq)...,
70 using RN =
decltype(rNorm);
72 auto const isCloseToZero = rNorm <=
77 isCloseToZero ?
eval(
zero<
decltype(r)>())
87 totalNorm *=
exp(
log(rNorm) /
91 return self.template operator()<
D, C + 1>(
92 newRight, forward<Xs>(cores)..., newLeft);
100 using RF =
decltype(firstNorm);
102 auto const isCloseToZero = firstNorm <=
106 auto const newFirst = isCloseToZero
108 :
eval(forward<XB>(b) / firstNorm);
116 totalNorm *=
exp(
log(firstNorm) /
122 [&totalNorm](
auto const cores,
126 get<seq>(cores) * totalNorm...);
129 forward<Xs>(cores)..., forward<XA>(a)));
133 constexpr auto cutoff =
min(Ranks::size(C - 1),
134 RAL, RA::size() / RAL);
136 auto const [u, s, v] =
141 auto const newRight =
expand(
143 [&v = v](
auto const... seq)
145 return reshape<cutoff, RA::size(seq)...,
154 return self.template operator()<
D, C - 1>(
155 newLeft, forward<Xs>(cores)..., newRight);
158 })(forward<X>(a).template core<seq>()...);
162 template<
class X,
class R = RemoveReference<X>, IsFloat VT = Value<R>>
163 requires(TensorTrain::IsTensorish<R> ||
164 TensorTrain::IsOperatorish<R>) &&
166 static inline constexpr auto round(X &&a,
172 template<Size C,
class X,
class R = RemoveReference<X>,
173 IsFloat VT = Value<R>>
174 requires(TensorTrain::IsTensorish<R> ||
175 TensorTrain::IsOperatorish<R>) &&
177 static inline constexpr auto round(X &&a,
182 forward<X>(a), tolerance);
186 IsFloat VT = Value<R>>
187 requires(TensorTrain::IsTensorish<R> ||
188 TensorTrain::IsOperatorish<R>) &&
189 IsFloat<Value<R>> && (
sizeof...(Cs) > 1) &&
190 (
sizeof...(Cs) == R::Dimension - 1)
191 static inline constexpr auto round(X &&a,
194 return round<
Sizes<Cs...>>(forward<X>(a), tolerance);
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition sequence.hpp:29
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
Definition cholesky.hpp:10
static constexpr auto contract(X &&a)
Definition contract.hpp:20
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto svd(X &&input)
Definition svd.hpp:15
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto reverse(Direction const D)
Definition direction.hpp:24
static constexpr auto qr(X &&input)
Definition qr.hpp:13
RecursiveLambda(X &&) -> RecursiveLambda< RemoveReference< X > >
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:17
static constexpr auto reshape(X &&a)
Definition reshape.hpp:14
static constexpr auto forwardAsTuple(Xs &&...args)
Definition basics.hpp:65
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:312
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto makeRange()
Definition sequence.hpp:421
Direction
Definition direction.hpp:9
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto log(T const &a)
Definition log.hpp:11
static constexpr auto round(T const &a)
Definition round.hpp:11
static constexpr auto exp(T const &a)
Definition exp.hpp:11
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12