cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
mul.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
4#define pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
5
8
9namespace pRC
10{
11 template<class XA, class XB,
12 TensorTrain::IsOperatorish RA = RemoveReference<XA>,
13 TensorTrain::IsTensorish RB = RemoveReference<XB>>
14 requires IsSame<typename RA::N, typename RB::N>
15 static inline constexpr auto operator*(XA &&a, XB &&b)
16 {
17 return loop(
18 []<class XLA, class XLB>(XLA &&a, XLB &&b)
19 {
20 using CA = RemoveReference<XLA>;
21 using CB = RemoveReference<XLB>;
22
23 return reshape<CA::size(0) * CB::size(0), CA::size(1),
24 CA::size(3) * CB::size(2)>(permute<0, 3, 1, 2, 4>(
25 eval(contract<2, 1>(forward<XLA>(a), forward<XLB>(b)))));
26 },
27 forward<XA>(a), forward<XB>(b));
28 }
29
30 template<class XA, class XB,
31 TensorTrain::IsTensorish RA = RemoveReference<XA>,
32 TensorTrain::IsOperatorish RB = RemoveReference<XB>>
33 requires IsSame<typename RA::N, typename RB::M>
34 static inline constexpr auto operator*(XA &&a, XB &&b)
35 {
36 return loop(
37 []<class XLA, class XLB>(XLA &&a, XLB &&b)
38 {
39 using CA = RemoveReference<XLA>;
40 using CB = RemoveReference<XLB>;
41
42 return reshape<CA::size(0) * CB::size(0), CB::size(2),
43 CA::size(2) * CB::size(3)>(permute<0, 2, 3, 1, 4>(
44 eval(contract<1, 1>(forward<XLA>(a), forward<XLB>(b)))));
45 },
46 forward<XA>(a), forward<XB>(b));
47 }
48
49 template<class XA, class XB,
50 TensorTrain::IsOperatorish RA = RemoveReference<XA>,
51 TensorTrain::IsOperatorish RB = RemoveReference<XB>>
52 requires IsSame<typename RA::N, typename RB::M>
53 static inline constexpr auto operator*(XA &&a, XB &&b)
54 {
55 return loop(
56 []<class XLA, class XLB>(XLA &&a, XLB &&b)
57 {
58 using CA = RemoveReference<XLA>;
59 using CB = RemoveReference<XLB>;
60
61 return reshape<CA::size(0) * CB::size(0), CA::size(1),
62 CB::size(2), CA::size(3) * CB::size(3)>(
64 contract<2, 1>(forward<XLA>(a), forward<XLB>(b)))));
65 },
66 forward<XA>(a), forward<XB>(b));
67 }
68}
69#endif // pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
Definition cholesky.hpp:10
static constexpr auto contract(X &&a)
Definition contract.hpp:20
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto reshape(X &&a)
Definition reshape.hpp:14
static constexpr auto operator*(JacobiRotation< TA > const &a, JacobiRotation< TB > const &b)
Definition jacobi_rotation.hpp:298
static constexpr auto loop(F &&f, Xs &&...args)
Definition loop.hpp:20
static constexpr auto permute(Sequence< T, Is... > const)
Definition sequence.hpp:487
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12