@@ -42,39 +42,50 @@ class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
// Specialized blocking for quantized implementations.
// Used by TensorContractionThreadPool, inputs must have dimensions that are
// multiples of 32.
45- template <int KcFactor, typename Index>
46- struct ComputeGemmByColBlockingSizes <QInt8, QUInt8, KcFactor, Index> {
47- void operator ()(Index& k, Index& m, Index& n, Index num_threads)
45+ template <typename Index,
46+ typename LeftTensor,
47+ typename left_nocontract_t , typename left_contract_t ,
48+ bool left_inner_dim_contiguous, bool left_inner_dim_reordered, int LeftAlignment,
49+ typename RightTensor,
50+ typename right_nocontract_t , typename right_contract_t ,
51+ bool right_inner_dim_contiguous, bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
52+ class TensorContractionBlocking <TensorContractionInputMapper<QInt8, Index, Lhs, LeftTensor, left_nocontract_t , left_contract_t , 32 , left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor, right_nocontract_t , right_contract_t , 32 , right_inner_dim_contiguous, right_inner_dim_reordered, RightAlignment>, Index, ShardingType> {
53+ public:
54+
55+ typedef QInt8 LhsScalar;
56+ typedef QUInt8 RhsScalar;
57+
58+ TensorContractionBlocking (Index k, Index m, Index n, Index num_threads = 1 ) :
59+ kc_ (k), mc_(m), nc_(n)
4860 {
4961 eigen_assert (m % 32 == 0 );
50- eigen_assert (n % 32 == 0 );
5162 eigen_assert (k % 32 == 0 );
5263 if (!k || !m || !n) {
5364 return ;
5465 }
55- n = (((n / num_threads) + 31 ) / 32 ) * 32 ;
56- }
57- };
5866
59- // Specialized blocking for quantized implementations.
60- // Used by TensorContractionThreadPool, inputs must have dimensions that are
61- // multiples of 32.
62- template <int KcFactor, typename Index>
63- struct ComputeGemmByRowBlockingSizes <QInt8, QUInt8, KcFactor, Index> {
64- void operator ()(Index& k, Index& m, Index& n, Index num_threads)
65- {
66- eigen_assert (m % 32 == 0 );
67- eigen_assert (n % 32 == 0 || n == 1 );
68- eigen_assert (k % 32 == 0 );
69- if (!k || !m || !n) {
70- return ;
67+ if (ShardingType == ShardByCol) {
68+ eigen_assert (n % 32 == 0 );
69+ nc_ = (((n / num_threads) + 31 ) / 32 ) * 32 ;
7170 }
72- // Special case to avoid breaking the unimplemented matrix-vector case
73- if (n == 1 ) {
74- n = 32 ;
71+ else {
72+ eigen_assert (n % 32 == 0 || n == 1 );
73+ // Special case to avoid breaking the unimplemented matrix-vector case
74+ if (n == 1 ) {
75+ nc_ = 32 ;
76+ }
77+ mc_ = (((m / num_threads) + 31 ) / 32 ) * 32 ;
7578 }
76- m = (((m / num_threads) + 31 ) / 32 ) * 32 ;
7779 }
80+
81+ EIGEN_ALWAYS_INLINE Index kc () const { return kc_; }
82+ EIGEN_ALWAYS_INLINE Index mc () const { return mc_; }
83+ EIGEN_ALWAYS_INLINE Index nc () const { return nc_; }
84+
85+ private:
86+ Index kc_;
87+ Index mc_;
88+ Index nc_;
7889};
7990
// Specialized blocking for quantized implementations.
0 commit comments