/src/tesseract/src/arch/intsimdmatrixavx2.cpp
Line | Count | Source |
1 | | /////////////////////////////////////////////////////////////////////// |
2 | | // File: intsimdmatrixavx2.cpp |
3 | | // Description: matrix-vector product for 8-bit data on avx2. |
4 | | // Author: Ray Smith |
5 | | // |
6 | | // (C) Copyright 2017, Google Inc. |
7 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
8 | | // you may not use this file except in compliance with the License. |
9 | | // You may obtain a copy of the License at |
10 | | // http://www.apache.org/licenses/LICENSE-2.0 |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | /////////////////////////////////////////////////////////////////////// |
17 | | |
18 | | #include "intsimdmatrix.h" |
19 | | |
20 | | #if !defined(__AVX2__) |
21 | | # if defined(__i686__) || defined(__x86_64__) |
22 | | # error Implementation only for AVX2 capable architectures |
23 | | # endif |
24 | | #else |
25 | | # include <immintrin.h> |
26 | | # include <algorithm> |
27 | | # include <cstdint> |
28 | | # include <vector> |
29 | | |
30 | | # if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \ |
31 | | defined(_WIN32) && !defined(_WIN64) |
32 | | // Optimize for size (/Os) instead of the default optimization for some |
33 | | // versions of the 32-bit Visual Studio compiler which generate buggy code. |
34 | | # pragma optimize("", off) |
35 | | # pragma optimize("s", on) |
36 | | # endif |
37 | | |
38 | | namespace tesseract { |
39 | | |
40 | | // Number of outputs held in each register. 8 x 32 bit ints. |
41 | | constexpr int kNumOutputsPerRegister = 8; |
42 | | // Maximum number of registers that we will use. |
43 | | constexpr int kMaxOutputRegisters = 8; |
44 | | // Number of inputs in the inputs register. |
45 | | constexpr int kNumInputsPerRegister = 32; |
46 | | // Number of inputs in each weight group. |
47 | | constexpr int kNumInputsPerGroup = 4; |
48 | | // Number of groups of inputs to be broadcast. |
49 | | constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup; |
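
As a quick illustration of how these constants fit together (an editorial sketch, not part of the original file): the widest kernel below fills kMaxOutputRegisters accumulator registers of kNumOutputsPerRegister outputs each, and one 256-bit input load covers kNumInputGroups groups of kNumInputsPerGroup inputs.

// Illustrative sanity checks of the relationships the kernels below rely on.
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "PartialMatrixDotVector64 produces 64 outputs per call");
static_assert(kNumInputsPerRegister == kNumInputGroups * kNumInputsPerGroup,
              "one 32-byte input load feeds 8 groups of 4 inputs");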
50 | | |
51 | | // Functions to compute part of a matrix.vector multiplication. The weights |
52 | | // are in a very specific order (see the layout comment before PartialMatrixDotVector64) in w, which is multiplied by |
53 | | // u of length num_in, to produce output v after scaling the integer results |
54 | | // by the corresponding member of scales. |
55 | | // The amount of w and scales consumed is fixed and not available to the |
56 | | // caller. The number of outputs written to v will be at most num_out. |
57 | | |
58 | | // Computes one set of 4x8 products of inputs and weights, adding to result. |
59 | | // Horizontally adds 4 adjacent results, making 8x32-bit results. |
60 | | // rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers. |
61 | | // Note that wi must previously have been re-organized with blocks of 4x8 |
62 | | // weights in contiguous memory. |
63 | | // ones is a register of 16x16-bit values all equal to 1. |
64 | | // Note: wi is incremented by the amount of data read. |
65 | | // weights and reps are scratch registers. |
66 | | // This function must be inlined with references in order for the compiler to |
67 | | // correctly use the registers declared in the caller. |
68 | | static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi, |
69 | 230G | __m256i &weights, __m256i &reps, __m256i &result) { |
70 | | // Load a 4x8 block of weights. |
71 | 230G | weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi)); |
72 | 230G | wi += kNumInputsPerRegister; |
73 | | // Normalize the signs on rep_input, weights, so weights is always +ve. |
74 | 230G | reps = _mm256_sign_epi8(rep_input, weights); |
75 | 230G | weights = _mm256_sign_epi8(weights, weights); |
76 | | // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results, |
77 | | // with adjacent pairs added. |
78 | 230G | weights = _mm256_maddubs_epi16(weights, reps); |
79 | | // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results, |
80 | | // with adjacent pairs added. What we really want is a horizontal add of |
81 | | // 16+16=32 bit result, but there is no such instruction, so multiply by |
82 | | // 16-bit ones instead. It is probably faster than all the sign-extending, |
83 | | // permuting and adding that would otherwise be required. |
84 | 230G | weights = _mm256_madd_epi16(weights, ones); |
85 | 230G | result = _mm256_add_epi32(result, weights); |
86 | 230G | } |
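
Stripped of the SIMD mechanics, one MultiplyGroup call adds, to each of the 8 lanes of result, the dot product of the same 4 input bytes with that lane's own 4 consecutive weight bytes. A scalar model of that step (a sketch, not part of the original file; it ignores the 16-bit saturation in _mm256_maddubs_epi16 that the sign normalization above works around):

// Scalar model of MultiplyGroup. 'inputs' holds the 4 int8 values replicated
// in rep_input, 'w' points at 32 int8 weights (8 lanes x 4 weights) and is
// advanced by kNumInputsPerRegister, 'result' is the 8-lane accumulator.
static void MultiplyGroupScalar(const int8_t inputs[4], const int8_t *&w,
                                int32_t result[8]) {
  for (int lane = 0; lane < kNumOutputsPerRegister; ++lane) {
    for (int k = 0; k < kNumInputsPerGroup; ++k) {
      result[lane] += static_cast<int32_t>(inputs[k]) * static_cast<int32_t>(*w++);
    }
  }
}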
87 | | |
88 | | // Load 64 bits into the bottom of a 128bit register. |
89 | | // We don't actually care what the top 64bits are, but this ends |
90 | | // up with them being zero. |
91 | 0 | static inline __m128i load64_to_128(const int8_t *wi_) { |
92 | 0 | const auto *wi = reinterpret_cast<const int64_t *>(wi_); |
93 | 0 | return _mm_set_epi64x(0, wi[0]); |
94 | 0 | } |
95 | | |
96 | | #if defined(FAST_FLOAT) |
97 | | |
98 | | static inline void ExtractResults8(__m256i result, const int8_t *wi, |
99 | 0 | const float *scales, float *v) { |
100 | 0 | __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg |
101 | 0 | __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg |
102 | 0 | __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); |
103 | 0 | __m256 scale01234567 = _mm256_loadu_ps(scales); |
104 | 0 | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
105 | 0 | result = _mm256_add_epi32(result, w256); // result += bias * 127 |
106 | 0 | __m256 res01234567 = _mm256_cvtepi32_ps(result); |
107 | 0 | result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); |
108 | 0 | res01234567 = _mm256_mul_ps(res01234567, scale01234567); |
109 | 0 | _mm256_storeu_ps(v, res01234567); |
110 | 0 | } |
111 | | |
112 | | static inline void ExtractResults16(__m256i result0, __m256i result1, |
113 | | const int8_t *&wi, const float *&scales, |
114 | 2.10G | float *&v) { |
115 | 2.10G | __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi)); |
116 | | // 16x8bit bias vals; the low 8 are converted first, the high 8 after the shuffle below |
117 | 2.10G | const __m256i bias_scale = |
118 | 2.10G | _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); |
119 | 2.10G | __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg |
120 | 2.10G | __m256 scale01234567 = _mm256_loadu_ps(scales); |
121 | 2.10G | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
122 | 2.10G | result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 |
123 | 2.10G | __m256 res01234567 = _mm256_cvtepi32_ps(result0); |
124 | 2.10G | result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); |
125 | 2.10G | res01234567 = _mm256_mul_ps(res01234567, scale01234567); |
126 | 2.10G | _mm256_storeu_ps(v, res01234567); |
127 | 2.10G | w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); |
128 | 2.10G | w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg |
129 | 2.10G | scale01234567 = _mm256_loadu_ps(scales + 8); |
130 | 2.10G | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
131 | 2.10G | result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 |
132 | 2.10G | res01234567 = _mm256_cvtepi32_ps(result1); |
133 | 2.10G | result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); |
134 | 2.10G | res01234567 = _mm256_mul_ps(res01234567, scale01234567); |
135 | 2.10G | _mm256_storeu_ps(v + 8, res01234567); |
136 | 2.10G | wi += 16; |
137 | 2.10G | scales += 16; |
138 | 2.10G | v += 16; |
139 | 2.10G | } |
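
Ignoring the register shuffling, ExtractResults8 and ExtractResults16 compute the same thing for every output: the int8 bias stored after the weights is multiplied by 127 (presumably matching the scaling applied to the quantized inputs), added to the accumulator, and the sum is converted to float and scaled. A scalar equivalent (a sketch, not part of the original file):

// Scalar model of ExtractResults8/ExtractResults16 for n outputs.
static void ExtractResultsScalar(const int32_t *acc, const int8_t *bias,
                                 const float *scales, float *v, int n) {
  for (int i = 0; i < n; ++i) {
    v[i] = static_cast<float>(acc[i] + 127 * static_cast<int32_t>(bias[i])) * scales[i];
  }
}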
140 | | |
141 | | // Computes part of matrix.vector v = Wu. Computes N=64 results. |
142 | | // The weights *must* be arranged so that consecutive reads from wi |
143 | | // provide (num_in/kNumInputsPerGroup groups of (N output dim groups of |
144 | | // (kNumInputsPerGroup inputs))). After that there must be N consecutive |
145 | | // bias weights, before continuing with any more weights. |
146 | | // u must be padded out with zeros to |
147 | | // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. |
148 | | static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u, |
149 | 385M | int num_in, float *v) { |
150 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
151 | | // conversion. |
152 | 385M | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
153 | 385M | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
154 | | // Initialize all the results to 0. |
155 | 385M | __m256i result0 = _mm256_setzero_si256(); |
156 | 385M | __m256i result1 = _mm256_setzero_si256(); |
157 | 385M | __m256i result2 = _mm256_setzero_si256(); |
158 | 385M | __m256i result3 = _mm256_setzero_si256(); |
159 | 385M | __m256i result4 = _mm256_setzero_si256(); |
160 | 385M | __m256i result5 = _mm256_setzero_si256(); |
161 | 385M | __m256i result6 = _mm256_setzero_si256(); |
162 | 385M | __m256i result7 = _mm256_setzero_si256(); |
163 | | // Iterate over the input (u), one registerful at a time. |
164 | 3.90G | for (int j = 0; j < num_in;) { |
165 | 3.51G | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
166 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
167 | | // kNumInputGroups times. |
168 | 30.8G | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
169 | | // Replicate the low 32 bits (4 inputs) 8 times. |
170 | 27.3G | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
171 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
172 | 27.3G | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
173 | 27.3G | __m256i weights, reps; |
174 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
175 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
176 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
177 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result2); |
178 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result3); |
179 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result4); |
180 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result5); |
181 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result6); |
182 | 27.3G | MultiplyGroup(rep_input, ones, wi, weights, reps, result7); |
183 | 27.3G | } |
184 | 3.51G | } |
185 | 385M | ExtractResults16(result0, result1, wi, scales, v); |
186 | 385M | ExtractResults16(result2, result3, wi, scales, v); |
187 | 385M | ExtractResults16(result4, result5, wi, scales, v); |
188 | 385M | ExtractResults16(result6, result7, wi, scales, v); |
189 | 385M | } |
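
The required weight ordering is easier to see in scalar form. A hypothetical re-shuffling routine for one group of N outputs (illustrative only; the real reordering is done elsewhere when the weight matrix is prepared), where w is the row-major N x (num_in + 1) sub-matrix with the bias in the last column:

static void ShuffleWeightsForGroup(const int8_t *w, int num_in, int N,
                                   std::vector<int8_t> &shuffled) {
  const int rounded_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  // All N outputs' weights for one block of kNumInputsPerGroup inputs come
  // first, then the next input block, and so on...
  for (int in = 0; in < rounded_in; in += kNumInputsPerGroup) {
    for (int out = 0; out < N; ++out) {
      for (int k = 0; k < kNumInputsPerGroup; ++k) {
        const int src = in + k;
        shuffled.push_back(src < num_in ? w[out * (num_in + 1) + src] : 0);
      }
    }
  }
  // ...followed by the N bias weights consumed by ExtractResults8/16.
  for (int out = 0; out < N; ++out) {
    shuffled.push_back(w[out * (num_in + 1) + num_in]);
  }
}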
190 | | |
191 | | // Computes part of matrix.vector v = Wu. Computes N=32 results. |
192 | | // For details see PartialMatrixDotVector64 with N=32. |
193 | | static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u, |
194 | 38.9M | int num_in, float *v) { |
195 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
196 | | // conversion. |
197 | 38.9M | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
198 | 38.9M | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
199 | | // Initialize all the results to 0. |
200 | 38.9M | __m256i result0 = _mm256_setzero_si256(); |
201 | 38.9M | __m256i result1 = _mm256_setzero_si256(); |
202 | 38.9M | __m256i result2 = _mm256_setzero_si256(); |
203 | 38.9M | __m256i result3 = _mm256_setzero_si256(); |
204 | | // Iterate over the input (u), one registerful at a time. |
205 | 298M | for (int j = 0; j < num_in;) { |
206 | 259M | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
207 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
208 | | // kNumInputGroups times. |
209 | 2.33G | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
210 | | // Replicate the low 32 bits (4 inputs) 8 times. |
211 | 2.07G | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
212 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
213 | 2.07G | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
214 | 2.07G | __m256i weights, reps; |
215 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
216 | 2.07G | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
217 | 2.07G | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
218 | 2.07G | MultiplyGroup(rep_input, ones, wi, weights, reps, result2); |
219 | 2.07G | MultiplyGroup(rep_input, ones, wi, weights, reps, result3); |
220 | 2.07G | } |
221 | 259M | } |
222 | 38.9M | ExtractResults16(result0, result1, wi, scales, v); |
223 | 38.9M | ExtractResults16(result2, result3, wi, scales, v); |
224 | 38.9M | } |
225 | | |
226 | | // Computes part of matrix.vector v = Wu. Computes N=16 results. |
227 | | // For details see PartialMatrixDotVector64 with N=16. |
228 | | static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u, |
229 | 484M | int num_in, float *v) { |
230 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
231 | | // conversion. |
232 | 484M | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
233 | 484M | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
234 | | // Initialize all the results to 0. |
235 | 484M | __m256i result0 = _mm256_setzero_si256(); |
236 | 484M | __m256i result1 = _mm256_setzero_si256(); |
237 | | // Iterate over the input (u), one registerful at a time. |
238 | 1.03G | for (int j = 0; j < num_in;) { |
239 | 549M | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
240 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
241 | | // kNumInputGroups times. |
242 | 2.54G | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
243 | | // Replicate the low 32 bits (4 inputs) 8 times. |
244 | 1.99G | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
245 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
246 | 1.99G | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
247 | 1.99G | __m256i weights, reps; |
248 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
249 | 1.99G | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
250 | 1.99G | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
251 | 1.99G | } |
252 | 549M | } |
253 | 484M | ExtractResults16(result0, result1, wi, scales, v); |
254 | 484M | } |
255 | | |
256 | | // Computes part of matrix.vector v = Wu. Computes N=8 results. |
257 | | // For details see PartialMatrixDotVector64 with N=8. |
258 | | static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u, |
259 | 0 | int num_in, float *v) { |
260 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
261 | | // conversion. |
262 | 0 | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
263 | 0 | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
264 | | // Initialize all the results to 0. |
265 | 0 | __m256i result0 = _mm256_setzero_si256(); |
266 | | // Iterate over the input (u), one registerful at a time. |
267 | 0 | for (int j = 0; j < num_in;) { |
268 | 0 | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
269 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
270 | | // kNumInputGroups times. |
271 | 0 | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
272 | | // Replicate the low 32 bits (4 inputs) 8 times. |
273 | 0 | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
274 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
275 | 0 | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
276 | 0 | __m256i weights, reps; |
277 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
278 | 0 | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
279 | 0 | } |
280 | 0 | } |
281 | 0 | ExtractResults8(result0, wi, scales, v); |
282 | 0 | } |
283 | | |
284 | | static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales, |
285 | 743M | const int8_t *u, float *v) { |
286 | 743M | const int num_out = dim1; |
287 | 743M | const int num_in = dim2 - 1; |
288 | | // Each call to a partial_func_ produces group_size outputs, except the |
289 | | // last one, which can produce less. |
290 | 743M | const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); |
291 | 743M | const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); |
292 | 743M | int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; |
293 | 743M | int output = 0; |
294 | | |
295 | 743M | int w_step = (rounded_num_in + 1) * group_size; |
296 | | |
297 | | // Run with this group size, until it would produce too much output, then |
298 | | // switch to a smaller size. |
299 | 1.12G | for (; output + group_size <= rounded_num_out; output += group_size) { |
300 | 385M | PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v); |
301 | 385M | wi += w_step; |
302 | 385M | scales += group_size; |
303 | 385M | v += group_size; |
304 | 385M | } |
305 | 743M | group_size /= 2; |
306 | 743M | w_step /= 2; |
307 | | |
308 | 743M | if (output + group_size <= rounded_num_out) { |
309 | 38.9M | PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v); |
310 | 38.9M | wi += w_step; |
311 | 38.9M | scales += group_size; |
312 | 38.9M | v += group_size; |
313 | 38.9M | output += group_size; |
314 | 38.9M | } |
315 | 743M | group_size /= 2; |
316 | 743M | w_step /= 2; |
317 | | |
318 | 743M | if (output + group_size <= rounded_num_out) { |
319 | 484M | PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v); |
320 | 484M | wi += w_step; |
321 | 484M | scales += group_size; |
322 | 484M | v += group_size; |
323 | 484M | output += group_size; |
324 | 484M | } |
325 | 743M | group_size /= 2; |
326 | 743M | w_step /= 2; |
327 | | |
328 | 743M | if (output + group_size <= rounded_num_out) { |
329 | 0 | PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); |
330 | 0 | } |
331 | 743M | } |
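
To make the cascade concrete: with, say, 112 outputs, the loop runs the 64-wide kernel once (outputs 0-63), the 32-wide kernel covers 64-95, the 16-wide kernel covers 96-111, and the 8-wide kernel is skipped. A hypothetical caller (names and buffer handling invented for illustration; the real callers live elsewhere in Tesseract):

// 'shuffled_w' must already be in the grouped layout described before
// PartialMatrixDotVector64 (biases included) and 'padded_u' must be
// zero-padded to a multiple of kNumInputsPerGroup.
static void MatrixDotVectorExample(const std::vector<int8_t> &shuffled_w,
                                   const std::vector<float> &scales,
                                   const std::vector<int8_t> &padded_u,
                                   int num_out, int num_in, std::vector<float> &v) {
  // The kernels always store whole 8-float registers, so round the output up.
  v.resize(IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister));
  // dim2 counts the bias as an extra input, hence num_in + 1.
  matrixDotVector(num_out, num_in + 1, shuffled_w.data(), scales.data(),
                  padded_u.data(), v.data());
}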
332 | | #else |
333 | | static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales, |
334 | | double *v) { |
335 | | __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg |
336 | | __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg |
337 | | __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); |
338 | | __m256d scale0123 = _mm256_loadu_pd(scales); |
339 | | __m256d scale4567 = _mm256_loadu_pd(scales + 4); |
340 | | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
341 | | result = _mm256_add_epi32(result, w256); // result += bias * 127 |
342 | | __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); |
343 | | result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); |
344 | | __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); |
345 | | res0123 = _mm256_mul_pd(res0123, scale0123); |
346 | | res4567 = _mm256_mul_pd(res4567, scale4567); |
347 | | _mm256_storeu_pd(v, res0123); |
348 | | _mm256_storeu_pd(v + 4, res4567); |
349 | | } |
350 | | |
351 | | static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi, |
352 | | const double *&scales, double *&v) { |
353 | | __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi)); |
354 | | // 16x8bit bias vals; the low 8 are converted first, the high 8 after the shuffle below |
355 | | const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); |
356 | | __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg |
357 | | __m256d scale0123 = _mm256_loadu_pd(scales); |
358 | | __m256d scale4567 = _mm256_loadu_pd(scales + 4); |
359 | | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
360 | | result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 |
361 | | __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); |
362 | | result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); |
363 | | __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); |
364 | | res0123 = _mm256_mul_pd(res0123, scale0123); |
365 | | res4567 = _mm256_mul_pd(res4567, scale4567); |
366 | | _mm256_storeu_pd(v, res0123); |
367 | | _mm256_storeu_pd(v + 4, res4567); |
368 | | w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); |
369 | | w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg |
370 | | scale0123 = _mm256_loadu_pd(scales + 8); |
371 | | scale4567 = _mm256_loadu_pd(scales + 12); |
372 | | w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127> |
373 | | result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 |
374 | | res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); |
375 | | result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); |
376 | | res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); |
377 | | res0123 = _mm256_mul_pd(res0123, scale0123); |
378 | | res4567 = _mm256_mul_pd(res4567, scale4567); |
379 | | _mm256_storeu_pd(v + 8, res0123); |
380 | | _mm256_storeu_pd(v + 12, res4567); |
381 | | wi += 16; |
382 | | scales += 16; |
383 | | v += 16; |
384 | | } |
385 | | |
386 | | // Computes part of matrix.vector v = Wu. Computes N=64 results. |
387 | | // The weights *must* be arranged so that consecutive reads from wi |
388 | | // provide (num_in/kNumInputsPerGroup groups of (N output dim groups of |
389 | | // (kNumInputsPerGroup inputs))). After that there must be N consecutive |
390 | | // bias weights, before continuing with any more weights. |
391 | | // u must be padded out with zeros to |
392 | | // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. |
393 | | static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u, |
394 | | int num_in, double *v) { |
395 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
396 | | // conversion. |
397 | | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
398 | | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
399 | | // Initialize all the results to 0. |
400 | | __m256i result0 = _mm256_setzero_si256(); |
401 | | __m256i result1 = _mm256_setzero_si256(); |
402 | | __m256i result2 = _mm256_setzero_si256(); |
403 | | __m256i result3 = _mm256_setzero_si256(); |
404 | | __m256i result4 = _mm256_setzero_si256(); |
405 | | __m256i result5 = _mm256_setzero_si256(); |
406 | | __m256i result6 = _mm256_setzero_si256(); |
407 | | __m256i result7 = _mm256_setzero_si256(); |
408 | | // Iterate over the input (u), one registerful at a time. |
409 | | for (int j = 0; j < num_in;) { |
410 | | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
411 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
412 | | // kNumInputGroups times. |
413 | | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
414 | | // Replicate the low 32 bits (4 inputs) 8 times. |
415 | | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
416 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
417 | | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
418 | | __m256i weights, reps; |
419 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
420 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
421 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
422 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result2); |
423 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result3); |
424 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result4); |
425 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result5); |
426 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result6); |
427 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result7); |
428 | | } |
429 | | } |
430 | | ExtractResults16(result0, result1, wi, scales, v); |
431 | | ExtractResults16(result2, result3, wi, scales, v); |
432 | | ExtractResults16(result4, result5, wi, scales, v); |
433 | | ExtractResults16(result6, result7, wi, scales, v); |
434 | | } |
435 | | |
436 | | // Computes part of matrix.vector v = Wu. Computes N=32 results. |
437 | | // For details see PartialMatrixDotVector64 with N=32. |
438 | | static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u, |
439 | | int num_in, double *v) { |
440 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
441 | | // conversion. |
442 | | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
443 | | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
444 | | // Initialize all the results to 0. |
445 | | __m256i result0 = _mm256_setzero_si256(); |
446 | | __m256i result1 = _mm256_setzero_si256(); |
447 | | __m256i result2 = _mm256_setzero_si256(); |
448 | | __m256i result3 = _mm256_setzero_si256(); |
449 | | // Iterate over the input (u), one registerful at a time. |
450 | | for (int j = 0; j < num_in;) { |
451 | | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
452 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
453 | | // kNumInputGroups times. |
454 | | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
455 | | // Replicate the low 32 bits (4 inputs) 8 times. |
456 | | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
457 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
458 | | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
459 | | __m256i weights, reps; |
460 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
461 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
462 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
463 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result2); |
464 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result3); |
465 | | } |
466 | | } |
467 | | ExtractResults16(result0, result1, wi, scales, v); |
468 | | ExtractResults16(result2, result3, wi, scales, v); |
469 | | } |
470 | | |
471 | | // Computes part of matrix.vector v = Wu. Computes N=16 results. |
472 | | // For details see PartialMatrixDotVector64 with N=16. |
473 | | static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u, |
474 | | int num_in, double *v) { |
475 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
476 | | // conversion. |
477 | | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
478 | | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
479 | | // Initialize all the results to 0. |
480 | | __m256i result0 = _mm256_setzero_si256(); |
481 | | __m256i result1 = _mm256_setzero_si256(); |
482 | | // Iterate over the input (u), one registerful at a time. |
483 | | for (int j = 0; j < num_in;) { |
484 | | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
485 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
486 | | // kNumInputGroups times. |
487 | | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
488 | | // Replicate the low 32 bits (4 inputs) 8 times. |
489 | | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
490 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
491 | | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
492 | | __m256i weights, reps; |
493 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
494 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
495 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result1); |
496 | | } |
497 | | } |
498 | | ExtractResults16(result0, result1, wi, scales, v); |
499 | | } |
500 | | |
501 | | // Computes part of matrix.vector v = Wu. Computes N=8 results. |
502 | | // For details see PartialMatrixDotVector64 with N=8. |
503 | | static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u, |
504 | | int num_in, double *v) { |
505 | | // Register containing 16-bit ones for horizontal add with 16->32 bit |
506 | | // conversion. |
507 | | __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); |
508 | | __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); |
509 | | // Initialize all the results to 0. |
510 | | __m256i result0 = _mm256_setzero_si256(); |
511 | | // Iterate over the input (u), one registerful at a time. |
512 | | for (int j = 0; j < num_in;) { |
513 | | __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); |
514 | | // Inputs are processed in groups of kNumInputsPerGroup, replicated |
515 | | // kNumInputGroups times. |
516 | | for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { |
517 | | // Replicate the low 32 bits (4 inputs) 8 times. |
518 | | __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); |
519 | | // Rotate the inputs in groups of 4, so the next 4 inputs are ready. |
520 | | inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); |
521 | | __m256i weights, reps; |
522 | | // Mul-add, with horizontal add of the 4 inputs to each of the results. |
523 | | MultiplyGroup(rep_input, ones, wi, weights, reps, result0); |
524 | | } |
525 | | } |
526 | | ExtractResults8(result0, wi, scales, v); |
527 | | } |
528 | | |
529 | | static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales, |
530 | | const int8_t *u, double *v) { |
531 | | const int num_out = dim1; |
532 | | const int num_in = dim2 - 1; |
533 | | // Each call to a partial_func_ produces group_size outputs, except the |
534 | | // last one, which can produce less. |
535 | | const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); |
536 | | const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); |
537 | | int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; |
538 | | int output = 0; |
539 | | |
540 | | int w_step = (rounded_num_in + 1) * group_size; |
541 | | |
542 | | // Run with this group size, until it would produce too much output, then |
543 | | // switch to a smaller size. |
544 | | for (; output + group_size <= rounded_num_out; output += group_size) { |
545 | | PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v); |
546 | | wi += w_step; |
547 | | scales += group_size; |
548 | | v += group_size; |
549 | | } |
550 | | group_size /= 2; |
551 | | w_step /= 2; |
552 | | |
553 | | if (output + group_size <= rounded_num_out) { |
554 | | PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v); |
555 | | wi += w_step; |
556 | | scales += group_size; |
557 | | v += group_size; |
558 | | output += group_size; |
559 | | } |
560 | | group_size /= 2; |
561 | | w_step /= 2; |
562 | | |
563 | | if (output + group_size <= rounded_num_out) { |
564 | | PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v); |
565 | | wi += w_step; |
566 | | scales += group_size; |
567 | | v += group_size; |
568 | | output += group_size; |
569 | | } |
570 | | group_size /= 2; |
571 | | w_step /= 2; |
572 | | |
573 | | if (output + group_size <= rounded_num_out) { |
574 | | PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); |
575 | | } |
576 | | } |
577 | | #endif |
578 | | |
579 | | const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = { |
580 | | // Function. |
581 | | matrixDotVector, |
582 | | // Number of 32 bit outputs held in each register. |
583 | | kNumOutputsPerRegister, |
584 | | // Maximum number of registers that we will use to hold outputs. |
585 | | kMaxOutputRegisters, |
586 | | // Number of 8 bit inputs in the inputs register. |
587 | | kNumInputsPerRegister, |
588 | | // Number of inputs in each weight group. |
589 | | kNumInputsPerGroup |
590 | | }; |
591 | | |
592 | | } // namespace tesseract. |
593 | | |
594 | | #endif |