/src/Simd/src/Simd/SimdSynetMergedConvolution8i.h
Line | Count | Source |
1 | | /* |
2 | | * Simd Library (http://ermig1979.github.io/Simd). |
3 | | * |
4 | | * Copyright (c) 2011-2024 Yermalayeu Ihar. |
5 | | * |
6 | | * Permission is hereby granted, free of charge, to any person obtaining a copy |
7 | | * of this software and associated documentation files (the "Software"), to deal |
8 | | * in the Software without restriction, including without limitation the rights |
9 | | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
10 | | * copies of the Software, and to permit persons to whom the Software is |
11 | | * furnished to do so, subject to the following conditions: |
12 | | * |
13 | | * The above copyright notice and this permission notice shall be included in |
14 | | * all copies or substantial portions of the Software. |
15 | | * |
16 | | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
17 | | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
18 | | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
19 | | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
20 | | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
21 | | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
22 | | * SOFTWARE. |
23 | | */ |
24 | | #ifndef __SimdSynetMergedConvolution8i_h__ |
25 | | #define __SimdSynetMergedConvolution8i_h__ |
26 | | |
27 | | #include "Simd/SimdArray.h" |
28 | | #include "Simd/SimdPerformance.h" |
29 | | #include "Simd/SimdRuntime.h" |
30 | | #include "Simd/SimdSynetConvolution8i.h" |
31 | | |
32 | | #ifdef _N |
33 | | #undef _N |
34 | | #endif |
35 | | |
36 | | namespace Simd |
37 | | { |
38 | | struct MergConvParam8i |
39 | | { |
40 | | size_t count; |
41 | | ConvParam conv[3]; |
42 | | |
43 | | MergConvParam8i(size_t batch, const SimdConvolutionParameters * convs, size_t count, SimdSynetCompatibilityType compatibility) |
44 | 0 | { |
45 | 0 | assert(count <= 3); |
46 | 0 | this->count = count; |
47 | 0 | for (size_t i = 0; i < count; ++i) |
48 | 0 | this->conv[i] = ConvParam(batch, convs + i, compatibility); |
49 | 0 | } |
50 | | |
51 | | bool Valid() |
52 | 0 | { |
53 | 0 | if (count < 2 || count > 3) |
54 | 0 | return false; |
55 | 0 | for (size_t i = 0; i < count; ++i) |
56 | 0 | { |
57 | 0 | ConvParam& c = conv[i]; |
58 | 0 | if (!c.Valid(SimdTensorData32f, SimdTensorData8u)) |
59 | 0 | return false; |
60 | 0 | if (c.srcF != SimdTensorFormatNhwc) |
61 | 0 | return false; |
62 | 0 | if (c.kernelY != c.kernelX || !(c.kernelY == 1 || c.kernelY == 3 || c.kernelY == 5 || c.kernelY == 7)) |
63 | 0 | return false; |
64 | 0 | if (c.strideY != c.strideX || !(c.strideY == 1 || c.strideY == 2 || c.strideY == 3)) |
65 | 0 | return false; |
66 | 0 | if (c.dilationY != 1 || c.dilationX != 1) |
67 | 0 | return false; |
68 | | |
69 | 0 | if (c.dstH == (c.srcH + c.padY + c.padH - (c.dilationY * (c.kernelY - 1) + 1) - 1) / c.strideY + 1) |
70 | 0 | c.padH--; |
71 | 0 | if (c.dstW == (c.srcW + c.padX + c.padW - (c.dilationY * (c.kernelX - 1) + 1) - 1) / c.strideX + 1) |
72 | 0 | c.padW--; |
73 | 0 | if (c.IsDepthwise() && i != count - 1) |
74 | 0 | c.dstT = SimdTensorData8u; |
75 | 0 | } |
76 | 0 | if (count == 3) |
77 | 0 | { |
78 | 0 | if (conv[0].group != 1 || (conv[0].kernelY != 1 && conv[0].kernelY != 3)) |
79 | 0 | return false; |
80 | 0 | if (conv[1].group != conv[1].srcC || conv[1].group != conv[1].dstC || (conv[1].kernelY != 3 && conv[1].kernelY != 5 && conv[1].kernelY != 7)) |
81 | 0 | return false; |
82 | 0 | if (conv[2].group != 1 || conv[2].kernelY != 1 || conv[2].strideY != 1) |
83 | 0 | return false; |
84 | 0 | } |
85 | 0 | else |
86 | 0 | { |
87 | 0 | if (conv[0].group == 1) |
88 | 0 | { |
89 | 0 | if (conv[0].kernelY != 1 && conv[0].kernelY != 3) |
90 | 0 | return false; |
91 | 0 | if (conv[1].group != conv[1].srcC || conv[1].group != conv[1].dstC || (conv[1].kernelY != 3 && conv[1].kernelY != 5 && conv[1].kernelY != 7)) |
92 | 0 | return false; |
93 | 0 | } |
94 | 0 | else |
95 | 0 | { |
96 | 0 | if (conv[0].group != conv[0].srcC || conv[0].group != conv[0].dstC || (conv[0].kernelY != 3 && conv[0].kernelY != 5 && conv[0].kernelY != 7)) |
97 | 0 | return false; |
98 | 0 | if (conv[1].group != 1 || conv[1].kernelY != 1 || conv[1].strideY != 1) |
99 | 0 | return false; |
100 | 0 | } |
101 | 0 | } |
102 | 0 | return true; |
103 | 0 | } |
104 | | |
105 | | #ifdef SIMD_PERFORMANCE_STATISTIC |
106 | | String Info() const |
107 | | { |
108 | | std::stringstream ss; |
109 | | ss << count << ":" << conv[0].batch << "x" << conv[0].srcC << "x" << conv[0].srcH << "x" << conv[0].srcW; |
110 | | for (size_t i = 0; i < count; ++i) |
111 | | ss << "-" << (conv[i].group != 1 ? String("") : ToStr(conv[i].dstC) + "x") << conv[i].kernelY << "x" << conv[i].strideY; |
112 | | ss << "-" << (conv[0].srcT == SimdTensorData32f ? "f" : "u") << (conv[count - 1].dstT == SimdTensorData32f ? "f" : "u"); |
113 | | return ss.str(); |
114 | | } |
115 | | |
116 | | int64_t Flop() const |
117 | | { |
118 | | int64_t flop = 0; |
119 | | for (size_t i = 0; i < count; ++i) |
120 | | flop += conv[i].Flop(); |
121 | | return flop; |
122 | | } |
123 | | #endif |
124 | | }; |
125 | | |
126 | | class SynetMergedConvolution8i : public Deletable |
127 | | { |
128 | | public: |
129 | | virtual const MergConvParam8i & Param() const = 0; |
130 | | |
131 | | virtual size_t ExternalBufferSize() const = 0; |
132 | | |
133 | | virtual size_t InternalBufferSize() const = 0; |
134 | | |
135 | | virtual void SetParams(const float * const * weight, SimdBool * internal, const float * const * bias, const float * const * params, const float* const* stats) = 0; |
136 | | |
137 | | virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst) = 0; |
138 | | |
139 | | #if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG)) |
140 | | virtual Base::PerformanceMeasurer* Perf(const char *func) = 0; |
141 | | #endif |
142 | | |
143 | | virtual const char* Info() const = 0; |
144 | | }; |
145 | | |
146 | | namespace Base |
147 | | { |
148 | | class SynetMergedConvolution8i : public Simd::SynetMergedConvolution8i |
149 | | { |
150 | | public: |
151 | | SynetMergedConvolution8i(const MergConvParam8i& p); |
152 | | |
153 | 0 | virtual String Desc() const { return Ext(); } |
154 | 0 | virtual String Ext() const { return "Base"; } |
155 | 0 | virtual const MergConvParam8i& Param() const { return _param; } |
156 | | virtual size_t ExternalBufferSize() const; |
157 | | virtual size_t InternalBufferSize() const; |
158 | | virtual void SetParams(const float * const * weight, SimdBool * internal, const float * const * bias, const float * const * params, const float* const* stats); |
159 | | virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst); |
160 | | |
161 | | #if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG)) |
162 | | virtual Base::PerformanceMeasurer* Perf(const char* func); |
163 | | #endif |
164 | | |
165 | | virtual const char* Info() const |
166 | 0 | { |
167 | 0 | _info = Desc(); |
168 | 0 | return _info.c_str(); |
169 | 0 | } |
170 | | |
171 | | struct AlgParam |
172 | | { |
173 | | size_t miC, maC, yStep[3], yStart[3], bufH[3], dp[2], dw[3], size; |
174 | | int32_t zero, upper; |
175 | | }; |
176 | | |
177 | | typedef void(*Convert8uTo32fPtr)(const uint8_t* src, size_t maC, size_t yBeg, size_t yEnd, size_t width, size_t channels, |
178 | | const float* scale, const float* shift, float* dst, size_t bufH, SimdSynetCompatibilityType compatibility); |
179 | | |
180 | | typedef void(*Convert32fTo8uPtr)(const float* src, size_t yBeg, size_t yEnd, size_t width, size_t channels, |
181 | | const float* scale, const float* shift, uint8_t* dst, size_t bufH, SimdSynetCompatibilityType compatibility); |
182 | | |
183 | | typedef void(*InputConvolutionPtr)(const uint8_t* src, const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, |
184 | | const int8_t* weight, const float* norm, const float* bias, const float* params, float* dst); |
185 | | |
186 | | typedef void(*DepthwiseConvolutionPtr)(const float* src, const ConvParam& p, const AlgParam & a, size_t maC, size_t yBeg, size_t yEnd, |
187 | | const float* weight, const float* bias, const float* params, const float* scale, const float* shift, uint8_t * dst); |
188 | | |
189 | | typedef void(*OutputConvolutionPtr)(const uint8_t* src, const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, |
190 | | const int8_t* weight, const float* norm, const float* bias, const float* params, const float* scale, const float* shift, int32_t* buf, uint8_t* dst, int first); |
191 | | |
192 | | protected: |
193 | | uint8_t* GetBuffer(uint8_t* buffer); |
194 | | void Quantize(const float* weight, const float* bias, size_t i, size_t q); |
195 | | void ReorderInputWeight(const ConvParam& p, Array8i & weight); |
196 | | void ReorderDepthwiseWeight(const ConvParam& p, Array32f & weight); |
197 | | void ReorderOutputWeight(const ConvParam& p, Array8i& weight); |
198 | | void DirectConvolution8i(const uint8_t* src, size_t i, size_t q, uint8_t* buf, int32_t* sum, float* dst); |
199 | | |
200 | | MergConvParam8i _param; |
201 | | bool _s8u, _d8u, _dw0, _1x1; |
202 | | size_t _sizeS, _sizeD, _sizeI[2], _sizeB[5]; |
203 | | CvtParam _cvt[3]; |
204 | | Array8u _buffer; |
205 | | Array8i _weight8i[2]; |
206 | | Array32f _weight32f, _norm[2], _bias[3], _params[3]; |
207 | | AlgParam _alg; |
208 | | Convert8uTo32fPtr _cvt8uTo32f; |
209 | | Convert32fTo8uPtr _cvt32fTo8u; |
210 | | InputConvolutionPtr _input; |
211 | | DepthwiseConvolutionPtr _depthwise; |
212 | | OutputConvolutionPtr _output[2]; |
213 | | |
214 | | private: |
215 | | #if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG)) |
216 | | Base::PerformanceMeasurer * _perf; |
217 | | #endif |
218 | | mutable String _info; |
219 | | }; |
220 | | |
221 | | class SynetMergedConvolution8iCdc : public SynetMergedConvolution8i |
222 | | { |
223 | | public: |
224 | | SynetMergedConvolution8iCdc(const MergConvParam8i& p); |
225 | | |
226 | | virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst); |
227 | | |
228 | | static bool Preferable(const MergConvParam8i& p); |
229 | | |
230 | | protected: |
231 | | void SetSize(size_t F); |
232 | | }; |
233 | | |
234 | | class SynetMergedConvolution8iCd : public SynetMergedConvolution8i |
235 | | { |
236 | | public: |
237 | | SynetMergedConvolution8iCd(const MergConvParam8i& p); |
238 | | |
239 | | virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst); |
240 | | |
241 | | static bool Preferable(const MergConvParam8i& p); |
242 | | |
243 | | protected: |
244 | | void SetSize(size_t F); |
245 | | }; |
246 | | |
247 | | class SynetMergedConvolution8iDc : public SynetMergedConvolution8i |
248 | | { |
249 | | public: |
250 | | SynetMergedConvolution8iDc(const MergConvParam8i& p); |
251 | | |
252 | | virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst); |
253 | | |
254 | | static bool Preferable(const MergConvParam8i& p); |
255 | | |
256 | | protected: |
257 | | void SetSize(size_t F); |
258 | | }; |
259 | | |
260 | | void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters * convs, size_t count, SimdSynetCompatibilityType compatibility); |
261 | | } |
262 | | |
#ifdef SIMD_SSE41_ENABLE
namespace Sse41
{
    // Kernel selectors: store the SSE4.1 routine for convolution `p` into the given pointer(s).
    void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);

    void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);

    void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);

    // SSE4.1 variants of the three pipelines; only the reported extension name differs here.
    class SynetMergedConvolution8iCdc : public Base::SynetMergedConvolution8iCdc
    {
    public:
        SynetMergedConvolution8iCdc(const MergConvParam8i& p);

        virtual String Ext() const { return "Sse41"; }
    };

    class SynetMergedConvolution8iCd : public Base::SynetMergedConvolution8iCd
    {
    public:
        SynetMergedConvolution8iCd(const MergConvParam8i& p);

        virtual String Ext() const { return "Sse41"; }
    };

    class SynetMergedConvolution8iDc : public Base::SynetMergedConvolution8iDc
    {
    public:
        SynetMergedConvolution8iDc(const MergConvParam8i& p);

        virtual String Ext() const { return "Sse41"; }
    };

    // Factory for the SSE4.1 implementation; returns an opaque handle.
    void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
}
#endif
299 | | |
#ifdef SIMD_AVX2_ENABLE
namespace Avx2
{
    // Kernel selectors: store the AVX2 routine for convolution `p` into the given pointer(s).
    void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);

    void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);

    void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);

    // AVX2 variants, layered on top of the SSE4.1 classes.
    class SynetMergedConvolution8iCdc : public Sse41::SynetMergedConvolution8iCdc
    {
    public:
        SynetMergedConvolution8iCdc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx2"; }
    };

    class SynetMergedConvolution8iCd : public Sse41::SynetMergedConvolution8iCd
    {
    public:
        SynetMergedConvolution8iCd(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx2"; }
    };

    class SynetMergedConvolution8iDc : public Sse41::SynetMergedConvolution8iDc
    {
    public:
        SynetMergedConvolution8iDc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx2"; }
    };

    // Factory for the AVX2 implementation; returns an opaque handle.
    void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
}
#endif
336 | | |
#ifdef SIMD_AVX512BW_ENABLE
namespace Avx512bw
{
    // AVX-512BW conversion routines matching Base::SynetMergedConvolution8i::Convert8uTo32fPtr
    // and Convert32fTo8uPtr.
    void Convert8uTo32f(const uint8_t* src, size_t maC, size_t yBeg, size_t yEnd, size_t width, size_t channels,
        const float* scale, const float* shift, float* dst, size_t bufH, SimdSynetCompatibilityType compatibility);

    void Convert32fTo8u(const float* src, size_t yBeg, size_t yEnd, size_t width, size_t channels,
        const float* scale, const float* shift, uint8_t* dst, size_t bufH, SimdSynetCompatibilityType compatibility);

    // Kernel selectors: store the AVX-512BW routine for convolution `p` into the given pointer(s).
    void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);

    void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);

    void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);

    // AVX-512BW variants, layered on top of the AVX2 classes.
    class SynetMergedConvolution8iCdc : public Avx2::SynetMergedConvolution8iCdc
    {
    public:
        SynetMergedConvolution8iCdc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512bw"; }
    };

    class SynetMergedConvolution8iCd : public Avx2::SynetMergedConvolution8iCd
    {
    public:
        SynetMergedConvolution8iCd(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512bw"; }
    };

    class SynetMergedConvolution8iDc : public Avx2::SynetMergedConvolution8iDc
    {
    public:
        SynetMergedConvolution8iDc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512bw"; }
    };

    // Factory for the AVX-512BW implementation; returns an opaque handle.
    void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
}
#endif
379 | | |
#ifdef SIMD_AVX512VNNI_ENABLE
namespace Avx512vnni
{
    // Kernel selectors for the input/output convolutions; no depthwise selector is declared here.
    void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);

    void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);

    // AVX-512VNNI variants, layered on top of the AVX-512BW classes.
    class SynetMergedConvolution8iCdc : public Avx512bw::SynetMergedConvolution8iCdc
    {
    public:
        SynetMergedConvolution8iCdc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512vnni"; }
    };

    class SynetMergedConvolution8iCd : public Avx512bw::SynetMergedConvolution8iCd
    {
    public:
        SynetMergedConvolution8iCd(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512vnni"; }
    };

    class SynetMergedConvolution8iDc : public Avx512bw::SynetMergedConvolution8iDc
    {
    public:
        SynetMergedConvolution8iDc(const MergConvParam8i& p);

        virtual String Ext() const { return "Avx512vnni"; }
    };

    // Factory for the AVX-512VNNI implementation; returns an opaque handle.
    void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
}
#endif
414 | | |
#if defined(SIMD_AMXBF16_ENABLE) || (defined(SIMD_AVX512BW_ENABLE) && defined(SIMD_AMX_EMULATE))
namespace AmxBf16
{
    // Kernel selectors for the input/output convolutions; no depthwise selector is declared here.
    void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);

    void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);

    // AMX-BF16 variants. When SIMD_AMX_EMULATE is defined they derive from the Avx512bw classes,
    // otherwise from the Avx512vnni classes.
#if defined(SIMD_AMX_EMULATE)
    class SynetMergedConvolution8iCdc : public Avx512bw::SynetMergedConvolution8iCdc
#else
    class SynetMergedConvolution8iCdc : public Avx512vnni::SynetMergedConvolution8iCdc
#endif
    {
    public:
        SynetMergedConvolution8iCdc(const MergConvParam8i& p);

        virtual String Ext() const { return "AmxBf16"; }
    };

#if defined(SIMD_AMX_EMULATE)
    class SynetMergedConvolution8iCd : public Avx512bw::SynetMergedConvolution8iCd
#else
    class SynetMergedConvolution8iCd : public Avx512vnni::SynetMergedConvolution8iCd
#endif
    {
    public:
        SynetMergedConvolution8iCd(const MergConvParam8i& p);

        virtual String Ext() const { return "AmxBf16"; }
    };

#if defined(SIMD_AMX_EMULATE)
    class SynetMergedConvolution8iDc : public Avx512bw::SynetMergedConvolution8iDc
#else
    class SynetMergedConvolution8iDc : public Avx512vnni::SynetMergedConvolution8iDc
#endif
    {
    public:
        SynetMergedConvolution8iDc(const MergConvParam8i& p);

        virtual String Ext() const { return "AmxBf16"; }
    };

    // Factory for the AMX-BF16 implementation; returns an opaque handle.
    void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
}
#endif
461 | | } |
462 | | #endif//__SimdSynetMergedConvolution8i_h__ |