cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
truncate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
4#define pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
5
9
10namespace pRC
11{
// truncate, Position::Right overload.
// Viable only when C is strictly smaller than the minimum (reduce<Min>) of the
// sizes of folding<!P>(a), and P == Position::Right.
// Computes svd<C>(folding<!P>(a), tolerance) -> [u, s, v] and returns a Tuple:
//   1. u reshaped to the sizes R::size(seq)... followed by u's second extent
//      (RemoveReference<decltype(u)>::size(1)), wrapped in eval(...);
//   2. fromDiagonal(s) * adjoint(v), wrapped in eval(...).
// NOTE(review): presumably svd<C> keeps at most C singular values and uses
// `tolerance` to drop small ones -- confirm against svd.hpp.
12 template<Size C, Position P, class X, IsTensorish R = RemoveReference<X>,
13 IsFloat VT = Value<R>>
14 requires(C < reduce<Min>(RemoveReference<decltype(folding<!P>(
15 declval<X>()))>::Sizes())) &&
16 (P == Position::Right)
17 static inline constexpr auto truncate(X &&a,
18 VT const &tolerance = NumericLimits<VT>::epsilon())
19 {
// NOTE(review): listing line 20 is absent here (numbering jumps 19 -> 21).
// The call that invokes the lambda below and supplies `seq` is therefore not
// visible. Presumably it is an expand(...) over the leading mode indices of R,
// so that reshape<R::size(seq)..., rank> restores all but the last mode --
// confirm against the actual header.
21 [&](auto const... seq)
22 {
23 auto const [u, s, v] =
24 svd<C>(folding<!P>(forward<X>(a)), tolerance);
25
// First factor: u with its row index unfolded back into R's mode sizes.
// Second factor: the remaining s * v^H part of the decomposition.
26 return Tuple(eval(reshape<R::size(seq)...,
27 RemoveReference<decltype(u)>::size(1)>(u)),
28 eval(fromDiagonal(s) * adjoint(v)));
29 });
30 }
31
// truncate, Position::Left overload.
// Viable only when C is strictly smaller than the minimum (reduce<Min>) of the
// sizes of folding<!P>(a), and P == Position::Left.
// Computes svd<C>(folding<!P>(a), tolerance) -> [u, s, v] and returns a Tuple:
//   1. u * fromDiagonal(s), wrapped in eval(...);
//   2. adjoint(v) reshaped to v's second extent
//      (RemoveReference<decltype(v)>::size(1)) followed by R::size(seq)...,
//      wrapped in eval(...).
// NOTE(review): presumably svd<C> keeps at most C singular values and uses
// `tolerance` to drop small ones -- confirm against svd.hpp.
32 template<Size C, Position P, class X, IsTensorish R = RemoveReference<X>,
33 IsFloat VT = Value<R>>
34 requires(C < reduce<Min>(RemoveReference<decltype(folding<!P>(
35 declval<X>()))>::Sizes())) &&
36 (P == Position::Left)
37 static inline constexpr auto truncate(X &&a,
38 VT const &tolerance = NumericLimits<VT>::epsilon())
39 {
// NOTE(review): listing line 40 is absent here (numbering jumps 39 -> 41).
// The call that invokes the lambda below and supplies `seq` is therefore not
// visible. Presumably it is an expand(...) over the trailing mode indices of R,
// so that reshape<rank, R::size(seq)...> restores all but the first mode --
// confirm against the actual header.
41 [&](auto const... seq)
42 {
43 auto const [u, s, v] =
44 svd<C>(folding<!P>(forward<X>(a)), tolerance);
45
// First factor: the u * s part of the decomposition.
// Second factor: v^H with its column index unfolded back into R's mode sizes.
46 return Tuple(eval(u * fromDiagonal(s)),
47 eval(reshape<RemoveReference<decltype(v)>::size(1),
48 R::size(seq)...>(adjoint(v))));
49 });
50 }
51
52 template<Size C, Position P, class X, IsTensorish R = RemoveReference<X>,
53 IsFloat VT = Value<R>>
54 static inline constexpr auto truncate(X &&a,
55 VT const &tolerance = NumericLimits<VT>::epsilon())
56 {
57 return orthogonalize<!P>(forward<X>(a));
58 }
59}
60#endif // pRC_TENSOR_TRAIN_ALGORITHMS_TRUNCATE_H
Definition cholesky.hpp:10
static constexpr auto svd(X &&input)
Definition svd.hpp:15
std::tuple< Ts... > Tuple
Definition basics.hpp:23
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto truncate(X &&a, VT const &tolerance=NumericLimits< VT >::epsilon())
Definition truncate.hpp:17
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:17
static constexpr auto reshape(X &&a)
Definition reshape.hpp:14
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:312
Sequence< Size, Ns... > Sizes
Definition sequence.hpp:100
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto makeRange()
Definition sequence.hpp:421
static constexpr auto orthogonalize(X &&a)
Definition orthogonalize.hpp:13
static constexpr auto folding(X &&a)
Definition folding.hpp:15
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
Definition limits.hpp:13