Coverage Report

Created: 2024-05-04 12:45

/proc/self/cwd/external/gemmlowp/internal/kernel.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
// kernel.h: general definitions for kernels.
16
17
#ifndef GEMMLOWP_INTERNAL_KERNEL_H_
18
#define GEMMLOWP_INTERNAL_KERNEL_H_
19
20
#include "../public/bit_depth.h"
21
#include "common.h"
22
23
namespace gemmlowp {
24
25
// Explanation of general gemmlowp terminology
26
// ===========================================
27
//
28
// We use the following abbreviations:
29
// LHS = "left-hand side"
30
// RHS = "right-hand side"
31
// Sometimes when referring to either LHS or RHS, we just say a "Side".
32
//
33
// In a matrix product of a MxK matrix times a KxN matrix,
34
// we call K the 'depth'. Note that M is the number of rows
35
// of the result (and of the LHS), and N is the number of columns
36
// of the result (and of the RHS).
37
//
38
// In each of the LHS and RHS matrices, we call 'width' the
39
// other dimension, besides the depth. So in the LHS, 'width'
40
// is the number of rows, while in the RHS, 'width' is the number
41
// of columns.
42
//
43
// So in the LHS MxK matrix, the depth is K and the width is M.
44
// And in the RHS KxN matrix, the depth is K and the width is N.
45
//
46
// This is illustrated in this picture:
47
//
48
//                             RHS width
49
//                        <----------------->
50
//                        +-----------------+ ^
51
//                        |       RHS       | | Depth
52
//                        +-----------------+ v
53
//                 ^ +--+ +-----------------+
54
//                 | |L | |                 |
55
//       LHS width | |H | |      Result     |
56
//                 | |S | |                 |
57
//                 v +--+ +-----------------+
58
//                   <-->
59
//                   Depth
60
61
// Explanation of gemmlowp kernel formats and "cells"
62
// ==================================================
63
//
64
// Kernels operate on small LHS and RHS blocks that fit in registers.
65
// These blocks are stored contiguously in memory, but not always
66
// in a traditional column-major or row-major order; instead,
67
// they consist of a number of sub-blocks, which we call "cells",
68
// that are stored in column-major or row-major order. However,
69
// what really matters to us is not so much rows vs columns, but
70
// rather width vs depth. So we refer to "width-major" and "depth-major"
71
// storage orders. In the LHS, width-major means row-major,
72
// while in the RHS, width-major means column-major.
73
// There is also a third possibility, "diagonal order",
74
// which is unused at the moment.
75
//
76
// We aim to treat both sides, LHS and RHS, on an equal footing,
77
// so we call them both 'sides'. A KernelFormat thus is just a pair
78
// of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
79
// contains a CellFormat and a number of cells; cells are only ever
80
// stacked in the width dimension, which means stacked vertically in the
81
// LHS and stacked horizontally in the RHS.
82
//
83
// Example
84
// =======
85
//
86
// Let's work out the data layout expected by a kernel having the
87
// following format (the struct names here are defined below in this file):
88
//
89
// KernelFormat<
90
//   KernelSideFormat<CellFormat<3, 4>, 3>,
91
//   KernelSideFormat<CellFormat<5, 4>, 2>
92
// >
93
//
94
// The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
95
// 3 cells, each cell having dimensions (width=3, depth=4), laid out in
96
// DepthMajor order (the default value, see CellFormat). In the LHS,
97
// DepthMajor means column-major, so the LHS cells are of size 3x4 in
98
// column-major order, so the LHS layout is:
99
//
100
// 0  3  6  9
101
// 1  4  7  10
102
// 2  5  8  11
103
// 12 15 18 21
104
// 13 16 19 22
105
// 14 17 20 23
106
// 24 27 30 33
107
// 25 28 31 34
108
// 26 29 32 35
109
//
110
// The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
111
// 2 cells each having dimensions (width=5, depth=4), laid out in
112
// DepthMajor order (the default value, see CellFormat). In the RHS,
113
// DepthMajor means row-major, so the RHS cells are of size 4x5 in
114
// row-major order, so the RHS layout is:
115
//
116
// 0  1  2  3  4  20 21 22 23 24
117
// 5  6  7  8  9  25 26 27 28 29
118
// 10 11 12 13 14 30 31 32 33 34
119
// 15 16 17 18 19 35 36 37 38 39
120
121
// CellOrder lists the storage orders (i.e. layouts) a cell may use; see the
// explanation of "cells" above. For a coefficient at width index w and depth
// index d of a cell, OffsetIntoCell maps each order to a storage offset.
enum class CellOrder {
  // Depth is the outer (major) index: offset = w + d * width.
  // On the LHS this is column-major; on the RHS it is row-major.
  DepthMajor,
  // Width is the outer (major) index: offset = d + w * depth.
  WidthMajor,
  // Diagonal storage; only valid for square cells (width == depth).
  // Currently unused.
  Diagonal
};
124
125
// CellFormat describes how data is laid
126
// out in a cell. That is, a CellOrder together with actual dimensions.
127
template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
128
struct CellFormat {
129
  static constexpr int kWidth = tWidth;
130
  static constexpr int kDepth = tDepth;
131
  static constexpr CellOrder kOrder = tOrder;
132
133
  static constexpr int kSize = kWidth * kDepth;
134
};
135
136
// KernelSideFormat describes the data layout of one kernel side (LHS or RHS):
// a CellFormat plus the number of such cells. Cells are always stacked along
// the Width dimension and never along Depth. For example, on the LHS the
// Width dimension is the rows dimension, so we're saying that on the LHS the
// cells are stacked vertically.
template <typename tCellFormat, int tCells>
struct KernelSideFormat {
  using Cell = tCellFormat;
  static constexpr int kCells = tCells;
  // Stacking only happens along the width, so depth is that of a single cell.
  static constexpr int kDepth = Cell::kDepth;
  // Total width covered by all stacked cells.
  static constexpr int kWidth = Cell::kWidth * kCells;
  using Scalar = std::uint8_t;       // Scalar type stored in this format.
  using InputScalar = std::uint8_t;  // Scalar type of the original input.
};
151
152
// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
153
// packs converts it to int8.
154
template <typename tCellFormat, int tCells>
155
struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
156
  typedef std::int8_t Scalar;
157
  typedef std::uint8_t InputScalar;
158
};
159
160
// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
161
// pack conversion.
162
template <typename tCellFormat, int tCells>
163
struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
164
  typedef std::int8_t Scalar;
165
  typedef std::int8_t InputScalar;
166
};
167
168
// KernelFormat describes fully the input data layout that a kernel expects.
// It consists of two KernelSideFormat's, one for LHS and one for RHS.
//
// Both sides are traversed along the shared depth dimension (K in an
// MxK * KxN product), so their cells must have the same depth; this is
// enforced at compile time.
template <typename tLhs, typename tRhs>
struct KernelFormat {
  typedef tLhs Lhs;
  typedef tRhs Rhs;

  static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth,
                "KernelFormat: LHS and RHS cells must have the same depth");
  // Depth shared by the LHS and RHS cells.
  static constexpr int kDepth = Lhs::Cell::kDepth;
  // Total LHS width, i.e. the number of rows of the result block.
  static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells;
  // Total RHS width, i.e. the number of columns of the result block.
  static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells;
};
180
181
0
// Returns a string literal naming the given CellOrder, e.g. "DepthMajor".
// For a value that is not one of the enumerators, this asserts in debug
// builds and returns nullptr otherwise.
inline const char* CellOrderName(CellOrder o) {
  const char* name = nullptr;
  switch (o) {
    case CellOrder::DepthMajor:
      name = "DepthMajor";
      break;
    case CellOrder::WidthMajor:
      name = "WidthMajor";
      break;
    case CellOrder::Diagonal:
      name = "Diagonal";
      break;
    default:
      assert(false);
      break;
  }
  return name;
}
194
195
// Returns the offset into a cell, at which a given coefficient is stored.
196
template <typename CellFormat>
197
0
inline int OffsetIntoCell(int w, int d) {
198
0
  const int size = CellFormat::kWidth;
199
0
  switch (CellFormat::kOrder) {
200
0
    case CellOrder::DepthMajor:
201
0
      return w + d * CellFormat::kWidth;
202
0
    case CellOrder::WidthMajor:
203
0
      return d + w * CellFormat::kDepth;
204
0
    case CellOrder::Diagonal:
205
0
      assert(CellFormat::kWidth == CellFormat::kDepth);
206
0
      return ((size + w - d) * size + d) % (size * size);
207
0
    default:
208
0
      assert(false);
209
0
      return 0;
210
0
  }
211
0
}
212
213
// KernelBase is the abstract base class from which all kernels derive.
// Only the kernel *format* needs to be a template parameter of the
// surrounding code, not the exact kernel type, so kernels that share a
// format can share the same packing/unpacking code.
struct KernelBase {
  // Returns a human-readable name for this kernel.
  virtual const char* Name() const = 0;

  // The kernel implementation itself. Throughout gemmlowp, "run" means an
  // inner loop whose implementation is provided by a separate optimized
  // function.
  //   dst_ptr:         int32 destination block, addressed through
  //                    dst_row_stride / dst_col_stride (presumably in
  //                    elements; TODO confirm against implementations).
  //   lhs_ptr/rhs_ptr: packed uint8 LHS and RHS blocks.
  //   start_depth,
  //   run_depth:       presumably the depth offset and the depth length this
  //                    call covers; verify against callers.
  virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
                   std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
                   const std::uint8_t* rhs_ptr, std::size_t start_depth,
                   std::size_t run_depth) const = 0;

  // Virtual so that kernels can be destroyed through a KernelBase pointer.
  virtual ~KernelBase() = default;
};
230
231
// ZeroPointInputValue<InputScalar, KernelScalar>::kValue is a per-combination
// constant (stored as uint8) for an input scalar type and a kernel scalar type.
// The primary template is intentionally empty, so any combination without a
// specialization below has no kValue and fails to compile when it is used.
// NOTE(review): the exact role of kValue is defined by its users outside this
// header; for uint8 input with an int8 kernel it is 128, presumably tied to
// the uint8 -> int8 conversion done in packing. Confirm with callers.
template <typename InputKernelScalarType, typename KernelScalarType>
struct ZeroPointInputValue {};

// uint8 input, uint8 kernel: no conversion involved.
template <>
struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
  static constexpr std::uint8_t kValue{0};
};

// uint8 input, int8 kernel (the int8 fast-kernel trick).
template <>
struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
  static constexpr std::uint8_t kValue{128};
};

// int8 input, int8 kernel: no conversion involved.
template <>
struct ZeroPointInputValue<std::int8_t, std::int8_t> {
  static constexpr std::uint8_t kValue{0};
};
248
249
}  // namespace gemmlowp
250
251
#endif  // GEMMLOWP_INTERNAL_KERNEL_H_