Coverage Report

Created: 2026-02-14 07:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/Simd/src/Simd/SimdSynetMergedConvolution8i.h
Line
Count
Source
1
/*
2
* Simd Library (http://ermig1979.github.io/Simd).
3
*
4
* Copyright (c) 2011-2024 Yermalayeu Ihar.
5
*
6
* Permission is hereby granted, free of charge, to any person obtaining a copy
7
* of this software and associated documentation files (the "Software"), to deal
8
* in the Software without restriction, including without limitation the rights
9
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
* copies of the Software, and to permit persons to whom the Software is
11
* furnished to do so, subject to the following conditions:
12
*
13
* The above copyright notice and this permission notice shall be included in
14
* all copies or substantial portions of the Software.
15
*
16
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
* SOFTWARE.
23
*/
24
#ifndef __SimdSynetMergedConvolution8i_h__
25
#define __SimdSynetMergedConvolution8i_h__
26
27
#include "Simd/SimdArray.h"
28
#include "Simd/SimdPerformance.h"
29
#include "Simd/SimdRuntime.h"
30
#include "Simd/SimdSynetConvolution8i.h"
31
32
#ifdef _N
33
#undef _N
34
#endif
35
36
namespace Simd
37
{
38
    // Parameters of a merged chain of convolutions: 'count' entries of 'conv' are used.
    // The constructor accepts up to 3 entries; Valid() accepts only chains of 2 or 3.
    struct MergConvParam8i
39
    {
40
        size_t count;
41
        ConvParam conv[3];
42
43
        // Builds one ConvParam per input descriptor (convs[0..count-1]) using the given batch and compatibility flags.
        // The bound count <= 3 is enforced only by assert.
        MergConvParam8i(size_t batch, const SimdConvolutionParameters * convs, size_t count, SimdSynetCompatibilityType compatibility)
44
0
        {
45
0
            assert(count <= 3);
46
0
            this->count = count;
47
0
            for (size_t i = 0; i < count; ++i)
48
0
                this->conv[i] = ConvParam(batch, convs + i, compatibility);
49
0
        }
50
51
        // Returns true when the chain is one of the supported layouts.
        // NOTE: this method is not const; it may modify padH, padW and dstT of the stored convolutions (see below).
        bool Valid()
52
0
        {
53
0
            if (count < 2 || count > 3)
54
0
                return false;
55
0
            for (size_t i = 0; i < count; ++i)
56
0
            {
57
0
                // Non-const reference: the adjustments below are written back into conv[i].
                ConvParam& c = conv[i];
58
0
                // NOTE(review): presumably the arguments list the accepted tensor data types - confirm in ConvParam::Valid.
                if (!c.Valid(SimdTensorData32f, SimdTensorData8u))
59
0
                    return false;
60
0
                if (c.srcF != SimdTensorFormatNhwc)
61
0
                    return false;
62
0
                // Only square kernels of size 1, 3, 5 or 7.
                if (c.kernelY != c.kernelX || !(c.kernelY == 1 || c.kernelY == 3 || c.kernelY == 5 || c.kernelY == 7))
63
0
                    return false;
64
0
                // Only equal strides of 1, 2 or 3.
                if (c.strideY != c.strideX || !(c.strideY == 1 || c.strideY == 2 || c.strideY == 3))
65
0
                    return false;
66
0
                // Dilation is not supported.
                if (c.dilationY != 1 || c.dilationX != 1)
67
0
                    return false;
68
69
0
                // The right-hand side has an extra "- 1" in the numerator: if dstH is unchanged with one row less
                // of padding, the bottom padding is decremented.
                if (c.dstH == (c.srcH + c.padY + c.padH - (c.dilationY * (c.kernelY - 1) + 1) - 1) / c.strideY + 1)
70
0
                    c.padH--;
71
0
                // Same check for the width.
                // NOTE(review): this uses c.dilationY where c.dilationX looks intended; harmless here because
                // both dilations were required to be 1 above - confirm against the upstream source.
                if (c.dstW == (c.srcW + c.padX + c.padW - (c.dilationY * (c.kernelX - 1) + 1) - 1) / c.strideX + 1)
72
0
                    c.padW--;
73
0
                // A depthwise convolution that is not the last one in the chain gets 8-bit unsigned output.
                if (c.IsDepthwise() && i != count - 1)
74
0
                    c.dstT = SimdTensorData8u;
75
0
            }
76
0
            if (count == 3)
77
0
            {
78
0
                // conv[0]: group == 1, kernel 1 or 3.
                if (conv[0].group != 1 || (conv[0].kernelY != 1 && conv[0].kernelY != 3))
79
0
                    return false;
80
0
                // conv[1]: group == srcC == dstC, kernel 3, 5 or 7.
                if (conv[1].group != conv[1].srcC || conv[1].group != conv[1].dstC || (conv[1].kernelY != 3 && conv[1].kernelY != 5 && conv[1].kernelY != 7))
81
0
                    return false;
82
0
                // conv[2]: group == 1, kernel 1, stride 1.
                if (conv[2].group != 1 || conv[2].kernelY != 1 || conv[2].strideY != 1)
83
0
                    return false;
84
0
            }
85
0
            else
86
0
            {
87
0
                // Two-convolution chain: the layout is selected by conv[0].group.
                if (conv[0].group == 1)
88
0
                {
89
0
                    // conv[0] with group == 1 (kernel 1 or 3) followed by conv[1] with group == srcC == dstC (kernel 3, 5 or 7).
                    if (conv[0].kernelY != 1 && conv[0].kernelY != 3)
90
0
                        return false;
91
0
                    if (conv[1].group != conv[1].srcC || conv[1].group != conv[1].dstC || (conv[1].kernelY != 3 && conv[1].kernelY != 5 && conv[1].kernelY != 7))
92
0
                        return false;
93
0
                }
94
0
                else
95
0
                {
96
0
                    // conv[0] with group == srcC == dstC (kernel 3, 5 or 7) followed by conv[1] with group == 1, kernel 1, stride 1.
                    if (conv[0].group != conv[0].srcC || conv[0].group != conv[0].dstC || (conv[0].kernelY != 3 && conv[0].kernelY != 5 && conv[0].kernelY != 7))
97
0
                        return false;
98
0
                    if (conv[1].group != 1 || conv[1].kernelY != 1 || conv[1].strideY != 1)
99
0
                        return false;
100
0
                }
101
0
            }
102
0
            return true;
103
0
        }
104
105
#ifdef SIMD_PERFORMANCE_STATISTIC
106
        // Builds a short text description of the chain:
        // "<count>:<batch>x<srcC>x<srcH>x<srcW>", then per convolution "-[<dstC>x]<kernelY>x<strideY>"
        // (dstC is printed only when group == 1), then "-" and two letters ('f' for 32f, otherwise 'u')
        // for the source type of conv[0] and the destination type of the last convolution.
        String Info() const
107
        {
108
            std::stringstream ss;
109
            ss << count << ":" << conv[0].batch << "x" << conv[0].srcC << "x" << conv[0].srcH << "x" << conv[0].srcW;
110
            for (size_t i = 0; i < count; ++i)
111
                ss << "-" << (conv[i].group != 1 ? String("") : ToStr(conv[i].dstC) + "x") << conv[i].kernelY << "x" << conv[i].strideY;
112
            ss << "-" << (conv[0].srcT == SimdTensorData32f ? "f" : "u") << (conv[count - 1].dstT == SimdTensorData32f ? "f" : "u");
113
            return ss.str();
114
        }
115
116
        // Returns the sum of ConvParam::Flop() over the 'count' convolutions of the chain.
        int64_t Flop() const
117
        {
118
            int64_t flop = 0;
119
            for (size_t i = 0; i < count; ++i)
120
                flop += conv[i].Flop();
121
            return flop;
122
        }
123
#endif
124
    };
125
126
    // Abstract interface of a merged 8-bit convolution; every member function is pure virtual.
    // The Base::SynetMergedConvolution8i class declared later in this header derives from it.
    class SynetMergedConvolution8i : public Deletable
127
    {
128
    public:
129
        // Returns the merged-convolution parameters.
        virtual const MergConvParam8i & Param() const = 0;
130
131
        // NOTE(review): presumably the size of the caller-supplied buffer passed to Forward() as 'buf' - units not visible here, confirm.
        virtual size_t ExternalBufferSize() const = 0;
132
133
        // NOTE(review): presumably the size of the storage owned by the object itself - confirm in the implementation.
        virtual size_t InternalBufferSize() const = 0;
134
135
        // Takes arrays of pointers (weight, bias, params, stats) plus an 'internal' flag array.
        // NOTE(review): presumably one array entry per convolution of the chain - confirm in the implementation.
        virtual void SetParams(const float * const * weight, SimdBool * internal, const float * const * bias, const float * const * params, const float* const* stats) = 0;
136
137
        // Runs the merged convolution from 'src' to 'dst'; 'buf' is an additional byte buffer.
        virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst) = 0;
138
139
        // Declared only when performance statistics are enabled (release builds, or debug builds with SIMD_PERF_STAT_IN_DEBUG).
#if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG))
140
        virtual Base::PerformanceMeasurer* Perf(const char *func) = 0;
141
#endif
142
143
        // Returns a C string describing the implementation.
        virtual const char* Info() const = 0;
144
    };
145
146
    namespace Base
147
    {
148
        // Base implementation of the Simd::SynetMergedConvolution8i interface.
        // It declares the shared state (parameters, buffers, weights, kernel function pointers)
        // that the Cdc/Cd/Dc classes below inherit.
        class SynetMergedConvolution8i : public Simd::SynetMergedConvolution8i
149
        {
150
        public:
151
            SynetMergedConvolution8i(const MergConvParam8i& p);
152
153
0
            // Description string; by default the same as Ext().
            virtual String Desc() const { return Ext(); }
154
0
            // Name of the implementation level; the classes in the ISA namespaces of this header override it.
            virtual String Ext() const { return "Base"; }
155
0
            virtual const MergConvParam8i& Param() const { return _param; }
156
            virtual size_t ExternalBufferSize() const;
157
            virtual size_t InternalBufferSize() const;
158
            virtual void SetParams(const float * const * weight, SimdBool * internal, const float * const * bias, const float * const * params, const float* const* stats);
159
            virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst);
160
161
#if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG))
162
            virtual Base::PerformanceMeasurer* Perf(const char* func);
163
#endif
164
165
            // Stores Desc() in the mutable member _info and returns a pointer to its characters.
            // NOTE(review): assuming String is std::string-like, the pointer stays valid only while this
            // object lives and until _info is assigned again (e.g. by the next Info() call).
            virtual const char* Info() const
166
0
            {
167
0
                _info = Desc();
168
0
                return _info.c_str();
169
0
            }
170
171
            // Algorithm parameters passed to the convolution kernels below.
            // NOTE(review): field meanings are set by the implementation, which is not visible in this header -
            // miC/maC presumably are micro/macro channel block sizes, TODO confirm.
            struct AlgParam
172
            {
173
                size_t miC, maC, yStep[3], yStart[3], bufH[3], dp[2], dw[3], size;
174
                int32_t zero, upper;
175
            };
176
177
            // Signature of a routine converting rows [yBeg, yEnd) of 8-bit data to 32-bit float using scale/shift.
            typedef void(*Convert8uTo32fPtr)(const uint8_t* src, size_t maC, size_t yBeg, size_t yEnd, size_t width, size_t channels, 
178
                const float* scale, const float* shift, float* dst, size_t bufH, SimdSynetCompatibilityType compatibility);
179
180
            // Signature of a routine converting rows [yBeg, yEnd) of 32-bit float data to 8-bit using scale/shift.
            typedef void(*Convert32fTo8uPtr)(const float* src, size_t yBeg, size_t yEnd, size_t width, size_t channels, 
181
                const float* scale, const float* shift, uint8_t* dst, size_t bufH, SimdSynetCompatibilityType compatibility);
182
183
            // Signature of the input convolution kernel: 8-bit source, int8 weights, float destination.
            typedef void(*InputConvolutionPtr)(const uint8_t* src, const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd,
184
                const int8_t* weight, const float* norm, const float* bias, const float* params, float* dst);
185
186
            // Signature of the depthwise convolution kernel: float source and weights, byte destination pointer.
            typedef void(*DepthwiseConvolutionPtr)(const float* src, const ConvParam& p, const AlgParam & a, size_t maC, size_t yBeg, size_t yEnd,
187
                const float* weight, const float* bias, const float* params, const float* scale, const float* shift, uint8_t * dst);
188
189
            // Signature of the output convolution kernel: 8-bit source, int8 weights, int32 buffer, byte destination pointer.
            typedef void(*OutputConvolutionPtr)(const uint8_t* src, const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd,
190
                const int8_t* weight, const float* norm, const float* bias, const float* params, const float* scale, const float* shift, int32_t* buf, uint8_t* dst, int first);
191
192
        protected:
193
            // Helpers defined in the implementation file (not visible here).
            uint8_t* GetBuffer(uint8_t* buffer);
194
            void Quantize(const float* weight, const float* bias, size_t i, size_t q);
195
            void ReorderInputWeight(const ConvParam& p, Array8i & weight);
196
            void ReorderDepthwiseWeight(const ConvParam& p, Array32f & weight);
197
            void ReorderOutputWeight(const ConvParam& p, Array8i& weight);
198
            void DirectConvolution8i(const uint8_t* src, size_t i, size_t q, uint8_t* buf, int32_t* sum, float* dst);
199
200
            // State shared with derived classes.
            MergConvParam8i _param;
201
            bool _s8u, _d8u, _dw0, _1x1;
202
            size_t _sizeS, _sizeD, _sizeI[2], _sizeB[5];
203
            CvtParam _cvt[3];
204
            Array8u _buffer;
205
            Array8i _weight8i[2];
206
            Array32f _weight32f, _norm[2], _bias[3], _params[3];
207
            AlgParam _alg;
208
            // Kernel entry points, one per function pointer type declared above.
            Convert8uTo32fPtr _cvt8uTo32f;
209
            Convert32fTo8uPtr _cvt32fTo8u;
210
            InputConvolutionPtr _input;
211
            DepthwiseConvolutionPtr _depthwise;
212
            OutputConvolutionPtr _output[2];
213
214
        private:
215
#if defined(SIMD_PERFORMANCE_STATISTIC) && (defined(NDEBUG) || defined(SIMD_PERF_STAT_IN_DEBUG))
216
            Base::PerformanceMeasurer * _perf;
217
#endif        
218
            // Mutable so that the const method Info() can store the description string.
            mutable String _info;
219
        };
220
221
        // Specialization of Base::SynetMergedConvolution8i that overrides Forward().
        // NOTE(review): the "Cdc" suffix presumably means a convolution-depthwise-convolution chain (count == 3) - confirm in the implementation.
        class SynetMergedConvolution8iCdc : public SynetMergedConvolution8i
222
        {
223
        public:
224
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
225
226
            virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst);
227
228
            // Static predicate over the merged parameters; its criteria are defined in the implementation file.
            static bool Preferable(const MergConvParam8i& p);
229
230
        protected:
231
            // NOTE(review): F presumably is the SIMD vector width in elements - TODO confirm.
            void SetSize(size_t F);
232
        };
233
234
        // Specialization of Base::SynetMergedConvolution8i that overrides Forward().
        // NOTE(review): the "Cd" suffix presumably means a convolution-depthwise chain (count == 2) - confirm in the implementation.
        class SynetMergedConvolution8iCd : public SynetMergedConvolution8i
235
        {
236
        public:
237
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
238
239
            virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst);
240
241
            // Static predicate over the merged parameters; its criteria are defined in the implementation file.
            static bool Preferable(const MergConvParam8i& p);
242
243
        protected:
244
            // NOTE(review): F presumably is the SIMD vector width in elements - TODO confirm.
            void SetSize(size_t F);
245
        };
246
247
        // Specialization of Base::SynetMergedConvolution8i that overrides Forward().
        // NOTE(review): the "Dc" suffix presumably means a depthwise-convolution chain (count == 2) - confirm in the implementation.
        class SynetMergedConvolution8iDc : public SynetMergedConvolution8i
248
        {
249
        public:
250
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
251
252
            virtual void Forward(const uint8_t* src, uint8_t* buf, uint8_t* dst);
253
254
            // Static predicate over the merged parameters; its criteria are defined in the implementation file.
            static bool Preferable(const MergConvParam8i& p);
255
256
        protected:
257
            // NOTE(review): F presumably is the SIMD vector width in elements - TODO confirm.
            void SetSize(size_t F);
258
        };
259
260
        // Factory function for the Base level; takes the same arguments as the MergConvParam8i constructor.
        // NOTE(review): presumably returns a new merged convolution object, or NULL when the parameters are not supported - confirm in the implementation.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters * convs, size_t count, SimdSynetCompatibilityType compatibility);
261
    }
262
263
#ifdef SIMD_SSE41_ENABLE    
264
    // SSE4.1 level: kernel selectors plus Cdc/Cd/Dc classes derived from their Base counterparts.
    namespace Sse41
265
    {
266
        // Selectors that fill the given kernel function pointer(s) for convolution 'p'.
        // SetOutput takes a pointer; NOTE(review): presumably to a 2-element array like Base::SynetMergedConvolution8i::_output - confirm.
        void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);
267
268
        void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);
269
270
        void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);
271
272
        class SynetMergedConvolution8iCdc : public Base::SynetMergedConvolution8iCdc
273
        {
274
        public:
275
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
276
277
0
            // Reports this implementation level by name.
            virtual String Ext() const { return "Sse41"; }
278
        };
279
280
        class SynetMergedConvolution8iCd : public Base::SynetMergedConvolution8iCd
281
        {
282
        public:
283
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
284
285
0
            virtual String Ext() const { return "Sse41"; }
286
        };
287
288
        class SynetMergedConvolution8iDc : public Base::SynetMergedConvolution8iDc
289
        {
290
        public:
291
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
292
293
0
            virtual String Ext() const { return "Sse41"; }
294
        };
295
296
        // Factory function for this level; same signature as Base::SynetMergedConvolution8iInit.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
297
    }
298
#endif
299
300
#ifdef SIMD_AVX2_ENABLE    
301
    // AVX2 level: kernel selectors plus Cdc/Cd/Dc classes derived from the Sse41 classes.
    namespace Avx2
302
    {
303
        // Selectors that fill the given kernel function pointer(s) for convolution 'p'.
        void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);
304
305
        void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);
306
307
        void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);
308
309
        class SynetMergedConvolution8iCdc : public Sse41::SynetMergedConvolution8iCdc
310
        {
311
        public:
312
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
313
314
0
            // Reports this implementation level by name.
            virtual String Ext() const { return "Avx2"; }
315
        };
316
317
        class SynetMergedConvolution8iCd : public Sse41::SynetMergedConvolution8iCd
318
        {
319
        public:
320
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
321
322
0
            virtual String Ext() const { return "Avx2"; }
323
        };
324
325
        class SynetMergedConvolution8iDc : public Sse41::SynetMergedConvolution8iDc
326
        {
327
        public:
328
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
329
330
0
            virtual String Ext() const { return "Avx2"; }
331
        };
332
333
        // Factory function for this level; same signature as Base::SynetMergedConvolution8iInit.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
334
    }
335
#endif
336
337
#ifdef SIMD_AVX512BW_ENABLE    
338
    // AVX-512BW level: conversion routines, kernel selectors and Cdc/Cd/Dc classes derived from the Avx2 classes.
    namespace Avx512bw
339
    {
340
        // The parameter list matches Base::SynetMergedConvolution8i::Convert8uTo32fPtr.
        void Convert8uTo32f(const uint8_t* src, size_t maC, size_t yBeg, size_t yEnd, size_t width, size_t channels,
341
            const float* scale, const float* shift, float* dst, size_t bufH, SimdSynetCompatibilityType compatibility);
342
343
        // The parameter list matches Base::SynetMergedConvolution8i::Convert32fTo8uPtr.
        void Convert32fTo8u(const float* src, size_t yBeg, size_t yEnd, size_t width, size_t channels,
344
            const float* scale, const float* shift, uint8_t* dst, size_t bufH, SimdSynetCompatibilityType compatibility);
345
346
        // Selectors that fill the given kernel function pointer(s) for convolution 'p'.
        void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);
347
348
        void SetDepthwise(const ConvParam& p, Base::SynetMergedConvolution8i::DepthwiseConvolutionPtr& depthwise);
349
350
        void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);
351
352
        class SynetMergedConvolution8iCdc : public Avx2::SynetMergedConvolution8iCdc
353
        {
354
        public:
355
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
356
357
0
            // Reports this implementation level by name.
            virtual String Ext() const { return "Avx512bw"; }
358
        };
359
360
        class SynetMergedConvolution8iCd : public Avx2::SynetMergedConvolution8iCd
361
        {
362
        public:
363
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
364
365
0
            virtual String Ext() const { return "Avx512bw"; }
366
        };
367
368
        class SynetMergedConvolution8iDc : public Avx2::SynetMergedConvolution8iDc
369
        {
370
        public:
371
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
372
373
0
            virtual String Ext() const { return "Avx512bw"; }
374
        };
375
376
        // Factory function for this level; same signature as Base::SynetMergedConvolution8iInit.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
377
    }
378
#endif
379
380
#ifdef SIMD_AVX512VNNI_ENABLE    
381
    // AVX-512VNNI level: Cdc/Cd/Dc classes derived from the Avx512bw classes.
    // Only SetInput and SetOutput are declared here; there is no SetDepthwise at this level.
    namespace Avx512vnni
382
    {
383
        // Selectors that fill the given kernel function pointer(s) for convolution 'p'.
        void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);
384
385
        void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);
386
387
        class SynetMergedConvolution8iCdc : public Avx512bw::SynetMergedConvolution8iCdc
388
        {
389
        public:
390
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
391
392
            // Reports this implementation level by name.
            virtual String Ext() const { return "Avx512vnni"; }
393
        };
394
395
        class SynetMergedConvolution8iCd : public Avx512bw::SynetMergedConvolution8iCd
396
        {
397
        public:
398
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
399
400
            virtual String Ext() const { return "Avx512vnni"; }
401
        };
402
403
        class SynetMergedConvolution8iDc : public Avx512bw::SynetMergedConvolution8iDc
404
        {
405
        public:
406
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
407
408
            virtual String Ext() const { return "Avx512vnni"; }
409
        };
410
411
        // Factory function for this level; same signature as Base::SynetMergedConvolution8iInit.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
412
    }
413
#endif
414
415
#if defined(SIMD_AMXBF16_ENABLE) || (defined(SIMD_AVX512BW_ENABLE) && defined(SIMD_AMX_EMULATE))
416
    // AMX-BF16 level: the base class of each Cdc/Cd/Dc class is chosen by the preprocessor -
    // the Avx512bw class when SIMD_AMX_EMULATE is defined, otherwise the Avx512vnni class.
    // Only SetInput and SetOutput are declared here; there is no SetDepthwise at this level.
    namespace AmxBf16
417
    {
418
        // Selectors that fill the given kernel function pointer(s) for convolution 'p'.
        void SetInput(const ConvParam& p, Base::SynetMergedConvolution8i::InputConvolutionPtr& input);
419
420
        void SetOutput(const ConvParam& p, Base::SynetMergedConvolution8i::OutputConvolutionPtr* output);
421
422
#if defined(SIMD_AMX_EMULATE)
423
        class SynetMergedConvolution8iCdc : public Avx512bw::SynetMergedConvolution8iCdc
424
#else
425
        class SynetMergedConvolution8iCdc : public Avx512vnni::SynetMergedConvolution8iCdc
426
#endif
427
        {
428
        public:
429
            SynetMergedConvolution8iCdc(const MergConvParam8i& p);
430
431
            // Reports this implementation level by name.
            virtual String Ext() const { return "AmxBf16"; }
432
        };
433
434
#if defined(SIMD_AMX_EMULATE)
435
        class SynetMergedConvolution8iCd : public Avx512bw::SynetMergedConvolution8iCd
436
#else
437
        class SynetMergedConvolution8iCd : public Avx512vnni::SynetMergedConvolution8iCd
438
#endif        
439
        {
440
        public:
441
            SynetMergedConvolution8iCd(const MergConvParam8i& p);
442
443
            virtual String Ext() const { return "AmxBf16"; }
444
        };
445
446
#if defined(SIMD_AMX_EMULATE)
447
        class SynetMergedConvolution8iDc : public Avx512bw::SynetMergedConvolution8iDc
448
#else
449
        class SynetMergedConvolution8iDc : public Avx512vnni::SynetMergedConvolution8iDc
450
#endif
451
        {
452
        public:
453
            SynetMergedConvolution8iDc(const MergConvParam8i& p);
454
455
            virtual String Ext() const { return "AmxBf16"; }
456
        };
457
458
        // Factory function for this level; same signature as Base::SynetMergedConvolution8iInit.
        void* SynetMergedConvolution8iInit(size_t batch, const SimdConvolutionParameters* convs, size_t count, SimdSynetCompatibilityType compatibility);
459
    }
460
#endif
461
}
462
#endif//__SimdSynetMergedConvolution8i_h__