cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
truncate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
4#define pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
5
11
12namespace pRC
13{
14 template<Size C, Position P, class X, class R = RemoveReference<X>,
15 If<IsTensorish<R>> = 0, class V = typename R::Value, class VT = V,
16 If<All<IsFloat<V>, IsFloat<VT>>> = 0,
17 If<IsSatisfied<(P == Position::Left || P == Position::Right)>> = 0>
18 static inline constexpr auto truncate(X &&a,
19 VT const &tolerance = NumericLimits<VT>::epsilon())
20 {
21 using DF = RemoveReference<decltype(folding<!P>(forward<X>(a)))>;
22
23 if constexpr(C < min(DF::size(0), DF::size(1)))
24 {
25 if constexpr(P == Position::Right)
26 {
27 return expand(makeSeries<Index, typename R::Dimension() - 1>(),
28 [&](auto const... seq)
29 {
30 auto const [u, s, v] =
31 svd<C>(folding<!P>(forward<X>(a)), tolerance);
32 Tensor const lambda = norm(s);
33
34 return tuple(lambda,
35 eval(reshape<R::size(seq)...,
36 RemoveReference<decltype(u)>::size(1)>(u)),
37 eval(fromDiagonal(s / lambda) * adjoint(v)));
38 });
39 }
40 else
41 {
42 return expand(makeRange<Index, 1, typename R::Dimension{}>(),
43 [&](auto const... seq)
44 {
45 auto const [u, s, v] =
46 svd<C>(folding<!P>(forward<X>(a)), tolerance);
47 Tensor const lambda = norm(s);
48
49 return tuple(lambda, eval(u * fromDiagonal(s / lambda)),
50 eval(reshape<RemoveReference<decltype(v)>::size(1),
51 R::size(seq)...>(adjoint(v))));
52 });
53 }
54 }
55 else
56 {
57 return orthogonalize<!P>(forward<X>(a));
58 }
59 }
60}
61#endif // pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
Definition tensor.hpp:28
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto truncate(X &&a, VT const &tolerance=NumericLimits< VT >::epsilon())
Definition truncate.hpp:18
std::remove_reference_t< T > RemoveReference
Definition type_traits.hpp:56
static constexpr auto makeRange()
Definition sequence.hpp:379
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:325
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:21
static constexpr auto reshape(X &&a)
Definition reshape.hpp:17
static constexpr auto norm(Complex< T > const &a)
Definition norm.hpp:11
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344
Definition limits.hpp:13