Coverage Report

Created: 2024-02-28 06:46

/src/tesseract/src/arch/intsimdmatrixavx2.cpp
Line
Count
Source
1
///////////////////////////////////////////////////////////////////////
2
// File:        intsimdmatrixavx2.cpp
3
// Description: matrix-vector product for 8-bit data on avx2.
4
// Author:      Ray Smith
5
//
6
// (C) Copyright 2017, Google Inc.
7
// Licensed under the Apache License, Version 2.0 (the "License");
8
// you may not use this file except in compliance with the License.
9
// You may obtain a copy of the License at
10
// http://www.apache.org/licenses/LICENSE-2.0
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
///////////////////////////////////////////////////////////////////////
17
18
#include "intsimdmatrix.h"
19
20
#if !defined(__AVX2__)
21
#  if defined(__i686__) || defined(__x86_64__)
22
#    error Implementation only for AVX2 capable architectures
23
#  endif
24
#else
25
#  include <immintrin.h>
26
#  include <algorithm>
27
#  include <cstdint>
28
#  include <vector>
29
30
#  if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
31
      defined(_WIN32) && !defined(_WIN64)
32
// Optimize for size (/Os) instead of using the default optimization for some
33
// versions of the 32 bit Visual Studio compiler which generate buggy code.
34
#    pragma optimize("", off)
35
#    pragma optimize("s", on)
36
#  endif
37
38
namespace tesseract {
39
40
// Number of outputs held in each register. 8 x 32 bit ints.
41
constexpr int kNumOutputsPerRegister = 8;
42
// Maximum number of registers that we will use.
43
constexpr int kMaxOutputRegisters = 8;
44
// Number of inputs in the inputs register.
45
constexpr int kNumInputsPerRegister = 32;
46
// Number of inputs in each weight group.
47
constexpr int kNumInputsPerGroup = 4;
48
// Number of groups of inputs to be broadcast.
49
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
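
For orientation, these constants say that one 256-bit register carries 32 signed 8-bit inputs, consumed as 8 broadcast groups of 4, and that each accumulator register holds 8 32-bit outputs, with at most 8 such accumulators in flight (64 outputs per block). A standalone restatement of those relationships (illustrative only, not part of the source):

constexpr int kCheckOutputsPerRegister = 8;   // 8 x int32 fit in one __m256i
constexpr int kCheckMaxOutputRegisters = 8;
constexpr int kCheckInputsPerRegister = 32;   // 32 x int8 fit in one __m256i
constexpr int kCheckInputsPerGroup = 4;
static_assert(kCheckInputsPerRegister / kCheckInputsPerGroup == 8,
              "32 int8 inputs split into 8 groups of 4");
static_assert(kCheckOutputsPerRegister * kCheckMaxOutputRegisters == 64,
              "the largest partial kernel below produces 64 outputs per call");
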
50
51
// Functions to compute part of a matrix.vector multiplication. The weights
52
// are in a very specific order (see above) in w, which is multiplied by
53
// u of length num_in, to produce output v after scaling the integer results
54
// by the corresponding member of scales.
55
// The amount of w and scales consumed is fixed and not available to the
56
// caller. The number of outputs written to v will be at most num_out.
57
58
// Computes one set of 4x8 products of inputs and weights, adding to result.
59
// Horizontally adds 4 adjacent results, making 8x32-bit results.
60
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
61
// Note that wi must previously have been re-organized with blocks of 4x8
62
// weights in contiguous memory.
63
// ones is a register of 16x16-bit values all equal to 1.
64
// Note: wi is incremented by the amount of data read.
65
// weights and reps are scratch registers.
66
// This function must be inlined with references in order for the compiler to
67
// correctly use the registers declared in the caller.
68
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
69
230G
                                 __m256i &weights, __m256i &reps, __m256i &result) {
70
  // Load a 4x8 block of weights.
71
230G
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
72
230G
  wi += kNumInputsPerRegister;
73
  // Normalize the signs on rep_input, weights, so weights is always +ve.
74
230G
  reps = _mm256_sign_epi8(rep_input, weights);
75
230G
  weights = _mm256_sign_epi8(weights, weights);
76
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
77
  // with adjacent pairs added.
78
230G
  weights = _mm256_maddubs_epi16(weights, reps);
79
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
80
  // with adjacent pairs added. What we really want is a horizontal add of
81
  // 16+16=32 bit result, but there is no such instruction, so multiply by
82
  // 16-bit ones instead. It is probably faster than all the sign-extending,
83
  // permuting and adding that would otherwise be required.
84
230G
  weights = _mm256_madd_epi16(weights, ones);
85
230G
  result = _mm256_add_epi32(result, weights);
86
230G
}
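
As a reading aid, here is a scalar sketch (illustrative only, not the shipped code) of what one MultiplyGroup call accumulates; it ignores the int16 saturation that _mm256_maddubs_epi16 can introduce for extreme weight/input combinations:

#include <cstdint>

// Each call consumes one 32-byte weight block (8 lanes x 4 int8 weights) and
// adds to each of the 8 running 32-bit results the dot product of the same 4
// inputs with that lane's 4 contiguous weights.
static void MultiplyGroupScalar(const int8_t input4[4], const int8_t *&wi,
                                int32_t result[8]) {
  for (int lane = 0; lane < 8; ++lane) {
    int32_t sum = 0;
    for (int k = 0; k < 4; ++k) {
      sum += int32_t{input4[k]} * int32_t{wi[lane * 4 + k]};
    }
    result[lane] += sum;
  }
  wi += 32; // kNumInputsPerRegister bytes of weights consumed, as in the SIMD code
}
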
87
88
// Load 64 bits into the bottom of a 128bit register.
89
// We don't actually care what the top 64bits are, but this ends
90
// up with them being zero.
91
0
static inline __m128i load64_to_128(const int8_t *wi_) {
92
0
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
93
0
  return _mm_set_epi64x(0, wi[0]);
94
0
}
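
A portable restatement of the intent (a sketch, not the shipped code): copy the 8 weight bytes into the low half of a zeroed 16-byte buffer, which is what the _mm_set_epi64x call achieves for the SSE register.

#include <cstdint>
#include <cstring>

static void Load64ToLow128(const int8_t *wi, uint8_t out[16]) {
  std::memset(out, 0, 16); // the top 64 bits end up zero, as the comment notes
  std::memcpy(out, wi, 8); // the bottom 64 bits hold the 8 int8 weights
}
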
95
96
#if defined(FAST_FLOAT)
97
98
static inline void ExtractResults8(__m256i result, const int8_t *wi,
99
0
                                   const float *scales, float *v) {
100
0
  __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
101
0
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
102
0
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
103
0
  __m256 scale01234567 = _mm256_loadu_ps(scales);
104
0
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
105
0
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
106
0
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
107
0
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
108
0
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
109
0
  _mm256_storeu_ps(v, res01234567);
110
0
}
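
In scalar terms (a sketch, not the source), the extraction step takes the 8 integer dot products, folds in the 8 trailing int8 bias weights scaled by 127 (the bias_scale constant above), and applies the per-output float scale:

#include <cstdint>

static void ExtractResults8Scalar(const int32_t result[8], const int8_t *bias,
                                  const float *scales, float *v) {
  for (int i = 0; i < 8; ++i) {
    v[i] = static_cast<float>(result[i] + 127 * int32_t{bias[i]}) * scales[i];
  }
}

ExtractResults16 below is the same operation applied twice, producing 16 outputs from a single 16-byte load of bias weights.
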
111
112
static inline void ExtractResults16(__m256i result0, __m256i result1,
113
                                    const int8_t *&wi, const float *&scales,
114
2.10G
                                    float *&v) {
115
2.10G
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
116
  // 8x8bit vals in bottom of 128bit reg
117
2.10G
  const __m256i bias_scale =
118
2.10G
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
119
2.10G
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
120
2.10G
  __m256 scale01234567 = _mm256_loadu_ps(scales);
121
2.10G
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
122
2.10G
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
123
2.10G
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
124
2.10G
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
125
2.10G
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
126
2.10G
  _mm256_storeu_ps(v, res01234567);
127
2.10G
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
128
2.10G
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
129
2.10G
  scale01234567 = _mm256_loadu_ps(scales + 8);
130
2.10G
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
131
2.10G
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
132
2.10G
  res01234567 = _mm256_cvtepi32_ps(result1);
133
2.10G
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
134
2.10G
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
135
2.10G
  _mm256_storeu_ps(v + 8, res01234567);
136
2.10G
  wi += 16;
137
2.10G
  scales += 16;
138
2.10G
  v += 16;
139
2.10G
}
140
141
// Computes part of matrix.vector v = Wu. Computes N=64 results.
142
// The weights *must* be arranged so that consecutive reads from wi
143
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
144
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
145
// bias weights, before continuing with any more weights.
146
// u must be padded out with zeros to
147
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
148
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
149
385M
                                     int num_in, float *v) {
150
  // Register containing 16-bit ones for horizontal add with 16->32 bit
151
  // conversion.
152
385M
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
153
385M
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
154
  // Initialize all the results to 0.
155
385M
  __m256i result0 = _mm256_setzero_si256();
156
385M
  __m256i result1 = _mm256_setzero_si256();
157
385M
  __m256i result2 = _mm256_setzero_si256();
158
385M
  __m256i result3 = _mm256_setzero_si256();
159
385M
  __m256i result4 = _mm256_setzero_si256();
160
385M
  __m256i result5 = _mm256_setzero_si256();
161
385M
  __m256i result6 = _mm256_setzero_si256();
162
385M
  __m256i result7 = _mm256_setzero_si256();
163
  // Iterate over the input (u), one registerful at a time.
164
3.90G
  for (int j = 0; j < num_in;) {
165
3.51G
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
166
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
167
    // kNumInputGroups times.
168
30.8G
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
169
      // Replicate the low 32 bits (4 inputs) 8 times.
170
27.3G
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
171
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
172
27.3G
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
173
27.3G
      __m256i weights, reps;
174
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
175
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
176
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
177
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
178
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
179
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
180
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
181
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
182
27.3G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
183
27.3G
    }
184
3.51G
  }
185
385M
  ExtractResults16(result0, result1, wi, scales, v);
186
385M
  ExtractResults16(result2, result3, wi, scales, v);
187
385M
  ExtractResults16(result4, result5, wi, scales, v);
188
385M
  ExtractResults16(result6, result7, wi, scales, v);
189
385M
}
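
The weight order demanded by the comment above can be written out as a packing routine. The sketch below (illustrative; the actual reshaping happens outside this file, in the generic IntSimdMatrix setup) packs one 64-output block of a row-major int8 matrix whose last column holds the biases:

#include <cstdint>
#include <vector>

static std::vector<int8_t> PackBlock64(const std::vector<std::vector<int8_t>> &W,
                                       int num_in) {
  const int N = 64;                    // outputs produced by this block
  const int groups = (num_in + 3) / 4; // ceil(num_in / kNumInputsPerGroup)
  std::vector<int8_t> wi;
  wi.reserve(groups * N * 4 + N);
  for (int g = 0; g < groups; ++g) {   // outer: groups of 4 inputs
    for (int o = 0; o < N; ++o) {      // middle: the 64 outputs
      for (int k = 0; k < 4; ++k) {    // inner: the 4 inputs of the group
        const int in = g * 4 + k;
        wi.push_back(in < num_in ? W[o][in] : 0); // zero-pad the last group
      }
    }
  }
  for (int o = 0; o < N; ++o) {        // the block ends with N bias weights
    wi.push_back(W[o][num_in]);
  }
  return wi;
}
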
190
191
// Computes part of matrix.vector v = Wu. Computes N=32 results.
192
// For details see PartialMatrixDotVector64 with N=32.
193
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
194
38.9M
                                     int num_in, float *v) {
195
  // Register containing 16-bit ones for horizontal add with 16->32 bit
196
  // conversion.
197
38.9M
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
198
38.9M
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
199
  // Initialize all the results to 0.
200
38.9M
  __m256i result0 = _mm256_setzero_si256();
201
38.9M
  __m256i result1 = _mm256_setzero_si256();
202
38.9M
  __m256i result2 = _mm256_setzero_si256();
203
38.9M
  __m256i result3 = _mm256_setzero_si256();
204
  // Iterate over the input (u), one registerful at a time.
205
298M
  for (int j = 0; j < num_in;) {
206
259M
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
207
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
208
    // kNumInputGroups times.
209
2.33G
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
210
      // Replicate the low 32 bits (4 inputs) 8 times.
211
2.07G
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
212
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
213
2.07G
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
214
2.07G
      __m256i weights, reps;
215
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
216
2.07G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
217
2.07G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
218
2.07G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
219
2.07G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
220
2.07G
    }
221
259M
  }
222
38.9M
  ExtractResults16(result0, result1, wi, scales, v);
223
38.9M
  ExtractResults16(result2, result3, wi, scales, v);
224
38.9M
}
225
226
// Computes part of matrix.vector v = Wu. Computes N=16 results.
227
// For details see PartialMatrixDotVector64 with N=16.
228
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
229
484M
                                     int num_in, float *v) {
230
  // Register containing 16-bit ones for horizontal add with 16->32 bit
231
  // conversion.
232
484M
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
233
484M
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
234
  // Initialize all the results to 0.
235
484M
  __m256i result0 = _mm256_setzero_si256();
236
484M
  __m256i result1 = _mm256_setzero_si256();
237
  // Iterate over the input (u), one registerful at a time.
238
1.03G
  for (int j = 0; j < num_in;) {
239
549M
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
240
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
241
    // kNumInputGroups times.
242
2.54G
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
243
      // Replicate the low 32 bits (4 inputs) 8 times.
244
1.99G
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
245
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
246
1.99G
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
247
1.99G
      __m256i weights, reps;
248
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
249
1.99G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
250
1.99G
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
251
1.99G
    }
252
549M
  }
253
484M
  ExtractResults16(result0, result1, wi, scales, v);
254
484M
}
255
256
// Computes part of matrix.vector v = Wu. Computes N=8 results.
257
// For details see PartialMatrixDotVector64 with N=8.
258
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
259
0
                                           int num_in, float *v) {
260
  // Register containing 16-bit ones for horizontal add with 16->32 bit
261
  // conversion.
262
0
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
263
0
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
264
  // Initialize all the results to 0.
265
0
  __m256i result0 = _mm256_setzero_si256();
266
  // Iterate over the input (u), one registerful at a time.
267
0
  for (int j = 0; j < num_in;) {
268
0
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
269
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
270
    // kNumInputGroups times.
271
0
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
272
      // Replicate the low 32 bits (4 inputs) 8 times.
273
0
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
274
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
275
0
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
276
0
      __m256i weights, reps;
277
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
278
0
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
279
0
    }
280
0
  }
281
0
  ExtractResults8(result0, wi, scales, v);
282
0
}
283
284
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
285
743M
                            const int8_t *u, float *v) {
286
743M
  const int num_out = dim1;
287
743M
  const int num_in = dim2 - 1;
288
  // Each call to a partial_func_ produces group_size outputs, except the
289
  // last one, which can produce less.
290
743M
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
291
743M
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
292
743M
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
293
743M
  int output = 0;
294
295
743M
  int w_step = (rounded_num_in + 1) * group_size;
296
297
  // Run with this group size, until it would produce too much output, then
298
  // switch to a smaller size.
299
1.12G
  for (; output + group_size <= rounded_num_out; output += group_size) {
300
385M
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
301
385M
    wi += w_step;
302
385M
    scales += group_size;
303
385M
    v += group_size;
304
385M
  }
305
743M
  group_size /= 2;
306
743M
  w_step /= 2;
307
308
743M
  if (output + group_size <= rounded_num_out) {
309
38.9M
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
310
38.9M
    wi += w_step;
311
38.9M
    scales += group_size;
312
38.9M
    v += group_size;
313
38.9M
    output += group_size;
314
38.9M
  }
315
743M
  group_size /= 2;
316
743M
  w_step /= 2;
317
318
743M
  if (output + group_size <= rounded_num_out) {
319
484M
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
320
484M
    wi += w_step;
321
484M
    scales += group_size;
322
484M
    v += group_size;
323
484M
    output += group_size;
324
484M
  }
325
743M
  group_size /= 2;
326
743M
  w_step /= 2;
327
328
743M
  if (output + group_size <= rounded_num_out) {
329
0
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
330
0
  }
331
743M
}
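
Taken together, for each output o this computes v[o] = scales[o] * (sum over i of W[o][i] * u[i] + 127 * bias[o]), where num_out is dim1 and num_in is dim2 - 1. A self-contained scalar reference (a sketch for checking results, not the shipped code):

#include <cstdint>
#include <vector>

static void MatrixDotVectorRef(int num_out, int num_in,
                               const std::vector<std::vector<int8_t>> &W,
                               const std::vector<int8_t> &bias,
                               const std::vector<float> &scales,
                               const std::vector<int8_t> &u,
                               std::vector<float> &v) {
  v.resize(num_out);
  for (int o = 0; o < num_out; ++o) {
    int32_t total = 0;
    for (int i = 0; i < num_in; ++i) {
      total += int32_t{W[o][i]} * int32_t{u[i]};
    }
    total += 127 * int32_t{bias[o]};              // bias folded in as bias * 127
    v[o] = static_cast<float>(total) * scales[o]; // per-output float scale
  }
}
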
332
#else
333
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
334
                                   double *v) {
335
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
336
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
337
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
338
  __m256d scale0123 = _mm256_loadu_pd(scales);
339
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
340
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
341
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
342
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
343
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
344
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
345
  res0123 = _mm256_mul_pd(res0123, scale0123);
346
  res4567 = _mm256_mul_pd(res4567, scale4567);
347
  _mm256_storeu_pd(v, res0123);
348
  _mm256_storeu_pd(v + 4, res4567);
349
}
350
351
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
352
                                    const double *&scales, double *&v) {
353
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
354
  // 8x8bit vals in bottom of 128bit reg
355
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
356
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
357
  __m256d scale0123 = _mm256_loadu_pd(scales);
358
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
359
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
360
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
361
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
362
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
363
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
364
  res0123 = _mm256_mul_pd(res0123, scale0123);
365
  res4567 = _mm256_mul_pd(res4567, scale4567);
366
  _mm256_storeu_pd(v, res0123);
367
  _mm256_storeu_pd(v + 4, res4567);
368
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
369
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
370
  scale0123 = _mm256_loadu_pd(scales + 8);
371
  scale4567 = _mm256_loadu_pd(scales + 12);
372
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
373
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
374
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
375
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
376
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
377
  res0123 = _mm256_mul_pd(res0123, scale0123);
378
  res4567 = _mm256_mul_pd(res4567, scale4567);
379
  _mm256_storeu_pd(v + 8, res0123);
380
  _mm256_storeu_pd(v + 12, res4567);
381
  wi += 16;
382
  scales += 16;
383
  v += 16;
384
}
385
386
// Computes part of matrix.vector v = Wu. Computes N=64 results.
387
// The weights *must* be arranged so that consecutive reads from wi
388
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
389
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
390
// bias weights, before continuing with any more weights.
391
// u must be padded out with zeros to
392
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
393
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
394
                                     int num_in, double *v) {
395
  // Register containing 16-bit ones for horizontal add with 16->32 bit
396
  // conversion.
397
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
398
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
399
  // Initialize all the results to 0.
400
  __m256i result0 = _mm256_setzero_si256();
401
  __m256i result1 = _mm256_setzero_si256();
402
  __m256i result2 = _mm256_setzero_si256();
403
  __m256i result3 = _mm256_setzero_si256();
404
  __m256i result4 = _mm256_setzero_si256();
405
  __m256i result5 = _mm256_setzero_si256();
406
  __m256i result6 = _mm256_setzero_si256();
407
  __m256i result7 = _mm256_setzero_si256();
408
  // Iterate over the input (u), one registerful at a time.
409
  for (int j = 0; j < num_in;) {
410
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
411
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
412
    // kNumInputGroups times.
413
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
414
      // Replicate the low 32 bits (4 inputs) 8 times.
415
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
416
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
417
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
418
      __m256i weights, reps;
419
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
420
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
421
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
422
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
423
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
424
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
425
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
426
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
427
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
428
    }
429
  }
430
  ExtractResults16(result0, result1, wi, scales, v);
431
  ExtractResults16(result2, result3, wi, scales, v);
432
  ExtractResults16(result4, result5, wi, scales, v);
433
  ExtractResults16(result6, result7, wi, scales, v);
434
}
435
436
// Computes part of matrix.vector v = Wu. Computes N=32 results.
437
// For details see PartialMatrixDotVector64 with N=32.
438
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
439
                                     int num_in, double *v) {
440
  // Register containing 16-bit ones for horizontal add with 16->32 bit
441
  // conversion.
442
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
443
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
444
  // Initialize all the results to 0.
445
  __m256i result0 = _mm256_setzero_si256();
446
  __m256i result1 = _mm256_setzero_si256();
447
  __m256i result2 = _mm256_setzero_si256();
448
  __m256i result3 = _mm256_setzero_si256();
449
  // Iterate over the input (u), one registerful at a time.
450
  for (int j = 0; j < num_in;) {
451
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
452
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
453
    // kNumInputGroups times.
454
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
455
      // Replicate the low 32 bits (4 inputs) 8 times.
456
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
457
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
458
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
459
      __m256i weights, reps;
460
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
461
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
462
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
463
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
464
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
465
    }
466
  }
467
  ExtractResults16(result0, result1, wi, scales, v);
468
  ExtractResults16(result2, result3, wi, scales, v);
469
}
470
471
// Computes part of matrix.vector v = Wu. Computes N=16 results.
472
// For details see PartialMatrixDotVector64 with N=16.
473
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
474
                                     int num_in, double *v) {
475
  // Register containing 16-bit ones for horizontal add with 16->32 bit
476
  // conversion.
477
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
478
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
479
  // Initialize all the results to 0.
480
  __m256i result0 = _mm256_setzero_si256();
481
  __m256i result1 = _mm256_setzero_si256();
482
  // Iterate over the input (u), one registerful at a time.
483
  for (int j = 0; j < num_in;) {
484
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
485
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
486
    // kNumInputGroups times.
487
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
488
      // Replicate the low 32 bits (4 inputs) 8 times.
489
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
490
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
491
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
492
      __m256i weights, reps;
493
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
494
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
495
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
496
    }
497
  }
498
  ExtractResults16(result0, result1, wi, scales, v);
499
}
500
501
// Computes part of matrix.vector v = Wu. Computes N=8 results.
502
// For details see PartialMatrixDotVector64 with N=8.
503
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
504
                                           int num_in, double *v) {
505
  // Register containing 16-bit ones for horizontal add with 16->32 bit
506
  // conversion.
507
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
508
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
509
  // Initialize all the results to 0.
510
  __m256i result0 = _mm256_setzero_si256();
511
  // Iterate over the input (u), one registerful at a time.
512
  for (int j = 0; j < num_in;) {
513
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
514
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
515
    // kNumInputGroups times.
516
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
517
      // Replicate the low 32 bits (4 inputs) 8 times.
518
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
519
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
520
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
521
      __m256i weights, reps;
522
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
523
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
524
    }
525
  }
526
  ExtractResults8(result0, wi, scales, v);
527
}
528
529
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
530
                            const int8_t *u, double *v) {
531
  const int num_out = dim1;
532
  const int num_in = dim2 - 1;
533
  // Each call to a partial_func_ produces group_size outputs, except the
534
  // last one, which can produce less.
535
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
536
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
537
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
538
  int output = 0;
539
540
  int w_step = (rounded_num_in + 1) * group_size;
541
542
  // Run with this group size, until it would produce too much output, then
543
  // switch to a smaller size.
544
  for (; output + group_size <= rounded_num_out; output += group_size) {
545
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
546
    wi += w_step;
547
    scales += group_size;
548
    v += group_size;
549
  }
550
  group_size /= 2;
551
  w_step /= 2;
552
553
  if (output + group_size <= rounded_num_out) {
554
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
555
    wi += w_step;
556
    scales += group_size;
557
    v += group_size;
558
    output += group_size;
559
  }
560
  group_size /= 2;
561
  w_step /= 2;
562
563
  if (output + group_size <= rounded_num_out) {
564
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
565
    wi += w_step;
566
    scales += group_size;
567
    v += group_size;
568
    output += group_size;
569
  }
570
  group_size /= 2;
571
  w_step /= 2;
572
573
  if (output + group_size <= rounded_num_out) {
574
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
575
  }
576
}
577
#endif
578
579
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
580
    // Function.
581
    matrixDotVector,
582
    // Number of 32 bit outputs held in each register.
583
    kNumOutputsPerRegister,
584
    // Maximum number of registers that we will use to hold outputs.
585
    kMaxOutputRegisters,
586
    // Number of 8 bit inputs in the inputs register.
587
    kNumInputsPerRegister,
588
    // Number of inputs in each weight group.
589
    kNumInputsPerGroup
590
};
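
A minimal usage sketch, assuming intsimdmatrix.h exposes the first descriptor field as a callable matrixDotVector pointer; the field name and calling convention here are assumptions read off the initializer above, not verified against the header:

#include "intsimdmatrix.h"
#include <cstdint>
#include <vector>

// Hypothetical caller: wi must already be in the shuffled block order described
// earlier, u zero-padded to a multiple of kNumInputsPerGroup, and dim2 passed
// as num_in + 1 to account for the bias column.
void RunAvx2Sketch(int num_out, int num_in, const std::vector<int8_t> &shaped_wi,
                   const std::vector<float> &scales, const std::vector<int8_t> &u,
                   std::vector<float> &v) {
  v.resize(num_out);
  tesseract::IntSimdMatrix::intSimdMatrixAVX2.matrixDotVector(
      num_out, num_in + 1, shaped_wi.data(), scales.data(), u.data(), v.data());
}
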
591
592
} // namespace tesseract.
593
594
#endif