pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
mul.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
4#define pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
5
9
10namespace pRC
11{
// Tensor Train operator-times-tensor product `a * b`.
//
// Participates in overload resolution only if:
//   - RA (XA without reference) is TensorTrain::IsOperatorish,
//   - RB (XB without reference) is TensorTrain::IsTensorish,
//   - both XA and XB are invocable with View, and
//   - typename RA::N is the same type as typename RB::N
//     (presumably: operator column dimensions == tensor dimensions).
//
// The work is delegated to `loop`, which applies a function element-wise;
// here the lambda presumably receives matching cores of `a` and `b` —
// TODO confirm against the full argument list.
12 template<class XA, class XB, class RA = RemoveReference<XA>,
13 class RB = RemoveReference<XB>, If<TensorTrain::IsOperatorish<RA>> = 0,
14 If<IsInvocable<View, XA>> = 0, If<TensorTrain::IsTensorish<RB>> = 0,
15 If<IsInvocable<View, XB>> = 0,
16 If<IsSame<typename RA::N, typename RB::N>> = 0>
17 static inline constexpr auto operator*(XA &&a, XB &&b)
18 {
19 return loop(
20 []<class XLA, class XLB>(XLA &&a, XLB &&b)
21 {
24
// Result core extents:
//   (CA::size(0) * CB::size(0), CA::size(1), CA::size(3) * CB::size(2)).
// Left ranks multiply, the single mode comes from CA's dimension 1, and
// right ranks multiply.
// NOTE(review): this reading assumes operator cores are laid out as
// (left rank, row mode, column mode, right rank) and tensor cores as
// (left rank, mode, right rank). It is inferred from these extents only.
// CA/CB themselves are not declared in the visible text — confirm.
// permute<0, 3, 1, 2, 4> reorders a rank-5 intermediate into that order
// before the reshape.
25 return reshape<CA::size(0) * CB::size(0), CA::size(1),
26 CA::size(3) * CB::size(2)>(permute<0, 3, 1, 2, 4>(
28 },
30 }
31
// Tensor Train tensor-times-operator product `a * b`
// (a tensor train applied from the left to an operator).
//
// Participates in overload resolution only if:
//   - RA (XA without reference) is TensorTrain::IsTensorish,
//   - RB (XB without reference) is TensorTrain::IsOperatorish,
//   - both XA and XB are invocable with View, and
//   - typename RA::N is the same type as typename RB::M
//     (presumably: tensor dimensions == operator row dimensions).
//
// The work is delegated to `loop`, which applies a function element-wise;
// the lambda presumably receives matching cores of `a` and `b` —
// TODO confirm against the full argument list.
32 template<class XA, class XB, class RA = RemoveReference<XA>,
33 class RB = RemoveReference<XB>, If<TensorTrain::IsTensorish<RA>> = 0,
34 If<IsInvocable<View, XA>> = 0, If<TensorTrain::IsOperatorish<RB>> = 0,
35 If<IsInvocable<View, XB>> = 0,
36 If<IsSame<typename RA::N, typename RB::M>> = 0>
37 static inline constexpr auto operator*(XA &&a, XB &&b)
38 {
39 return loop(
40 []<class XLA, class XLB>(XLA &&a, XLB &&b)
41 {
44
// Result core extents:
//   (CA::size(0) * CB::size(0), CB::size(2), CA::size(2) * CB::size(3)).
// Left ranks multiply, the single mode comes from CB's dimension 2, and
// the right ranks CA::size(2) and CB::size(3) multiply.
// NOTE(review): this assumes tensor cores are (left rank, mode, right
// rank) and operator cores are (left rank, row mode, column mode, right
// rank). It is inferred from these extents only. CA/CB are not declared
// in the visible text — confirm.
// permute<0, 2, 3, 1, 4> reorders a rank-5 intermediate into that order
// before the reshape.
45 return reshape<CA::size(0) * CB::size(0), CB::size(2),
46 CA::size(2) * CB::size(3)>(permute<0, 2, 3, 1, 4>(
48 },
50 }
51
// Tensor Train operator-times-operator product `a * b`.
//
// Participates in overload resolution only if:
//   - both RA and RB are TensorTrain::IsOperatorish,
//   - both XA and XB are invocable with View, and
//   - typename RA::N is the same type as typename RB::M
//     (presumably: columns of `a` == rows of `b`).
//
// The work is delegated to `loop`, which applies a function element-wise;
// the lambda presumably receives matching cores of `a` and `b` —
// TODO confirm against the full argument list.
52 template<class XA, class XB, class RA = RemoveReference<XA>,
53 class RB = RemoveReference<XB>, If<TensorTrain::IsOperatorish<RA>> = 0,
54 If<IsInvocable<View, XA>> = 0, If<TensorTrain::IsOperatorish<RB>> = 0,
55 If<IsInvocable<View, XB>> = 0,
56 If<IsSame<typename RA::N, typename RB::M>> = 0>
57 static inline constexpr auto operator*(XA &&a, XB &&b)
58 {
59 return loop(
60 []<class XLA, class XLB>(XLA &&a, XLB &&b)
61 {
64
// Result core is rank 4 with extents:
//   (CA::size(0) * CB::size(0), CA::size(1), CB::size(2),
//    CA::size(3) * CB::size(3)).
// Left ranks multiply, the row mode comes from CA's dimension 1, the
// column mode from CB's dimension 2, and right ranks multiply.
// NOTE(review): this assumes cores are (left rank, row mode, column mode,
// right rank), inferred from these extents only. CA/CB are not declared
// in the visible text — confirm.
65 return reshape<CA::size(0) * CB::size(0), CA::size(1),
66 CB::size(2), CA::size(3) * CB::size(3)>(
69 },
71 }
72
// Fallback Tensor Train product `a * b` for operands that cannot both be
// viewed as given.
//
// Participates in overload resolution only if:
//   - each of RA and RB is Operatorish or Tensorish,
//   - at least one of XA / XB (with its original value category) is NOT
//     invocable with View, and
//   - Mul is invocable with the lvalue forms XA & and XB &.
// The View-based operator* overloads require both operands to be
// viewable, so the second condition keeps this overload disjoint from them.
73 template<class XA, class XB, class RA = RemoveReference<XA>,
74 class RB = RemoveReference<XB>,
75 If<Any<TensorTrain::IsOperatorish<RA>, TensorTrain::IsTensorish<RA>>> =
76 0,
77 If<Any<TensorTrain::IsOperatorish<RB>, TensorTrain::IsTensorish<RB>>> =
78 0,
79 If<Not<All<IsInvocable<View, XA>, IsInvocable<View, XB>>>> = 0,
80 If<IsInvocable<Mul, XA &, XB &>> = 0>
81 static inline constexpr auto operator*(XA &&a, XB &&b)
82 {
// Inside the body `a` and `b` are named and therefore lvalues, so this
// `a * b` matches the lvalue form checked by IsInvocable<Mul, XA &, XB &>.
// NOTE(review): the product is passed through `eval`, presumably to
// materialize it so the result does not refer to operands that may be
// temporaries — confirm against eval.hpp.
83 return eval(a * b);
84 }
85}
86#endif // pRC_TENSOR_TRAIN_OPERATOR_FUNCTIONS_MUL_H
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
std::remove_reference_t< T > RemoveReference
Definition type_traits.hpp:56
static constexpr auto loop(F &&f, Xs &&...args)
Applies a function element-wise to Tensors.
Definition loop.hpp:31
static constexpr auto operator*(JacobiRotation< TA > const &a, JacobiRotation< TB > const &b)
Definition jacobi_rotation.hpp:311
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr auto reshape(X &&a)
Reshapes a Tensor.
Definition reshape.hpp:29