Coverage Report

Created: 2026-01-15 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/opencv/modules/dnn/src/layers/recurrent_layers.cpp
Line
Count
Source
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
//  By downloading, copying, installing or using the software you agree to this license.
6
//  If you do not agree to this license, do not download, install,
7
//  copy or use the software.
8
//
9
//
10
//                           License Agreement
11
//                For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14
// Copyright (C) 2017, Intel Corporation, all rights reserved.
15
// Third party copyrights are property of their respective owners.
16
//
17
// Redistribution and use in source and binary forms, with or without modification,
18
// are permitted provided that the following conditions are met:
19
//
20
//   * Redistribution's of source code must retain the above copyright notice,
21
//     this list of conditions and the following disclaimer.
22
//
23
//   * Redistribution's in binary form must reproduce the above copyright notice,
24
//     this list of conditions and the following disclaimer in the documentation
25
//     and/or other materials provided with the distribution.
26
//
27
//   * The name of the copyright holders may not be used to endorse or promote products
28
//     derived from this software without specific prior written permission.
29
//
30
// This software is provided by the copyright holders and contributors "as is" and
31
// any express or implied warranties, including, but not limited to, the implied
32
// warranties of merchantability and fitness for a particular purpose are disclaimed.
33
// In no event shall the Intel Corporation or contributors be liable for any direct,
34
// indirect, incidental, special, exemplary, or consequential damages
35
// (including, but not limited to, procurement of substitute goods or services;
36
// loss of use, data, or profits; or business interruption) however caused
37
// and on any theory of liability, whether in contract, strict liability,
38
// or tort (including negligence or otherwise) arising in any way out of
39
// the use of this software, even if advised of the possibility of such damage.
40
//
41
//M*/
42
43
#include "../precomp.hpp"
44
#include <iostream>
45
#include <cmath>
46
#include <opencv2/dnn/shape_utils.hpp>
47
48
#ifdef HAVE_CUDA
49
#include "../cuda4dnn/primitives/recurrent_cells.hpp"
50
using namespace cv::dnn::cuda4dnn;
51
#endif
52
53
#include "layers_common.hpp"
54
55
namespace cv
56
{
57
namespace dnn
58
{
59
60
template<typename Dtype>
61
static void tanh(const Mat &src, Mat &dst)
62
0
{
63
0
    MatConstIterator_<Dtype> itSrc = src.begin<Dtype>();
64
0
    MatIterator_<Dtype> itDst = dst.begin<Dtype>();
65
66
0
    for (; itSrc != src.end<Dtype>(); itSrc++, itDst++)
67
0
        *itDst = std::tanh(*itSrc);
68
0
}
Unexecuted instantiation: recurrent_layers.cpp:void cv::dnn::tanh<float>(cv::Mat const&, cv::Mat&)
Unexecuted instantiation: recurrent_layers.cpp:void cv::dnn::tanh<double>(cv::Mat const&, cv::Mat&)
69
70
//TODO: make utils method
71
static void tanh(const Mat &src, Mat &dst)
72
0
{
73
0
    dst.create(src.dims, (const int*)src.size, src.type());
74
75
0
    if (src.type() == CV_32F)
76
0
        tanh<float>(src, dst);
77
0
    else if (src.type() == CV_64F)
78
0
        tanh<double>(src, dst);
79
0
    else
80
0
        CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");
81
0
}
82
83
static void sigmoid(const Mat &src, Mat &dst)
84
0
{
85
0
    cv::exp(-src, dst);
86
0
    cv::pow(1 + dst, -1, dst);
87
0
}
88
89
typedef void (*ActivationFunction)(const Mat &src, Mat &dst);
90
0
static ActivationFunction get_activation_function(const String& activation) {
91
    // most used activations for PyTorch and TF : Tanh, Sigmoid
92
    // if you need to support more optional activations use std::map instead
93
0
    if (activation == "Tanh")
94
0
    {
95
0
        return tanh;
96
0
    }
97
0
    else if (activation == "Sigmoid")
98
0
    {
99
0
        return sigmoid;
100
0
    }
101
0
    else
102
0
    {
103
0
        CV_Error(Error::StsNotImplemented,
104
0
                 cv::format("Activation function [%s] for layer LSTM  is not supported", activation.c_str()));
105
0
    }
106
0
}
107
108
class LSTMLayerImpl CV_FINAL : public LSTMLayer
109
{
110
    int numTimeStamps, numSamples, numHidden;
111
    bool allocated;
112
113
    MatShape outTailShape;  //shape of single output sample
114
    MatShape outTsShape;    //shape of N output samples
115
116
    enum layout_t : int {
117
        SEQ_BATCH_HID = 0,
118
        BATCH_SEQ_HID = 1
119
    };
120
121
    bool useTimestampDim;
122
    bool produceCellOutput;
123
    float forgetBias, cellClip;
124
    bool useCellClip, usePeephole;
125
    bool reverse;   // If true, go in negative direction along the time axis
126
    bool bidirectional;  // If true, produces both forward and reversed directions along time axis
127
    layout_t layout;  // If layout == BATCH_SEQ_HID, uses batch_size x seq_length x num_hidden for input and output
128
                      // else uses seq_length x batch_size x num_hidden
129
130
    ActivationFunction f_activation;
131
    ActivationFunction g_activation;
132
    ActivationFunction h_activation;
133
    bool isDefaultActivations{true};
134
135
#if CV_TRY_AVX
136
    bool useAVX;
137
#endif
138
#if CV_TRY_AVX2
139
    bool useAVX2;
140
#endif
141
#if CV_TRY_SVE
142
    bool useSVE;
143
#endif
144
#if CV_TRY_NEON
145
    bool useNEON;
146
#endif
147
148
    // CUDA needs input blobs to be rearranged in a specific way, but some transformations
149
    // in ONNXImporter are destructive, so we keep a copy.
150
    std::vector<Mat> originalBlobs;
151
152
public:
153
154
    LSTMLayerImpl(const LayerParams& params)
155
0
        : numTimeStamps(0), numSamples(0)
156
#if CV_TRY_AVX
157
0
          , useAVX(checkHardwareSupport(CPU_AVX))
158
#endif
159
#if CV_TRY_AVX2
160
0
          , useAVX2(checkHardwareSupport(CPU_AVX2))
161
#endif
162
#if CV_TRY_SVE
163
          , useSVE(checkHardwareSupport(CPU_SVE))
164
#endif
165
#if CV_TRY_NEON
166
          , useNEON(checkHardwareSupport(CPU_NEON))
167
#endif
168
0
    {
169
0
        setParamsFrom(params);
170
171
0
        if (params.get<bool>("is_onnx", false))
172
0
        {
173
            // collect copies of onnx blobs
174
0
            originalBlobs.insert(originalBlobs.begin(), blobs.begin(), blobs.begin() + 3);
175
0
            blobs.erase(blobs.begin(), blobs.begin() + 3);
176
0
        }
177
178
0
        bidirectional = params.get<bool>("bidirectional", false);
179
0
        if (!blobs.empty())
180
0
        {
181
0
            CV_Assert(blobs.size() >= 3);
182
183
0
            blobs[2] = blobs[2].reshape(1, 1);
184
185
0
            const Mat& Wh = blobs[0];
186
0
            const Mat& Wx = blobs[1];
187
0
            const Mat& bias = blobs[2];
188
0
            const Mat& hInternal = blobs[3];
189
0
            const Mat& cInternal = blobs[4];
190
0
            CV_CheckEQ(Wh.dims, 2, "");
191
0
            CV_CheckEQ(Wx.dims, 2, "");
192
0
            CV_CheckEQ(Wh.rows, Wx.rows, "");
193
0
            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
194
0
            CV_CheckEQ(Wh.rows, (int)bias.total(), "");
195
            // Only perform these checks if hInternal and cInternal are not empty matrices
196
            // e.g. inputs are not given by a user
197
0
            if(!hInternal.empty()){
198
0
                CV_CheckEQ(hInternal.cols, Wh.cols, "");
199
0
            }
200
0
            if(!cInternal.empty()){
201
0
                CV_CheckEQ(cInternal.cols, Wh.cols, "");
202
0
            }
203
0
            if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward
204
0
                CV_CheckEQ(hInternal.rows, cInternal.rows, "");
205
0
            }
206
0
            CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
207
208
            // Peephole weights.
209
0
            if (blobs.size() > 5)
210
0
            {
211
0
                CV_Assert(blobs.size() == 8);
212
0
                const int N = Wh.cols;
213
0
                for (int i = 5; i < 8; ++i)
214
0
                {
215
0
                    CV_Assert(blobs[i].rows == N && blobs[i].cols == N);
216
0
                    CV_Assert(blobs[i].type() == bias.type());
217
0
                }
218
0
            }
219
0
        }
220
0
        layout = (layout_t) params.get<int>("layout", SEQ_BATCH_HID);
221
0
        useTimestampDim = params.get<bool>("use_timestamp_dim", true);
222
0
        produceCellOutput = params.get<bool>("produce_cell_output", false);
223
0
        forgetBias = params.get<float>("forget_bias", 0.0f);
224
0
        cellClip = params.get<float>("cell_clip", 0.0f);
225
0
        useCellClip = params.get<bool>("use_cell_clip", false);
226
0
        usePeephole = params.get<bool>("use_peephole", false);
227
0
        reverse = params.get<bool>("reverse", false);
228
0
        numHidden = params.get<int>("hidden_size", 1);
229
0
        CV_Assert(!reverse || !bidirectional);
230
231
        // read activations
232
0
        DictValue activations = params.get<DictValue>("activations", DictValue(String()));
233
0
        if (activations.size() == 1) // if activations wasn't specified use default
234
0
        {
235
0
            f_activation = sigmoid;
236
0
            g_activation = tanh;
237
0
            h_activation = tanh;
238
0
            isDefaultActivations = true;
239
0
        } else {
240
0
            CV_Assert(activations.size() == 3);
241
0
            f_activation = get_activation_function(activations.getStringValue(0));
242
0
            g_activation = get_activation_function(activations.getStringValue(1));
243
0
            h_activation = get_activation_function(activations.getStringValue(2));
244
0
            isDefaultActivations = activations.getStringValue(0) == "Sigmoid"
245
0
                                   && activations.getStringValue(1) == "Tanh"
246
0
                                   && activations.getStringValue(2) == "Tanh";
247
0
        }
248
249
0
        allocated = false;
250
0
        outTailShape.clear();
251
0
    }
252
253
    void setUseTimstampsDim(bool use) CV_OVERRIDE
254
0
    {
255
0
        CV_Assert(!allocated);
256
0
        useTimestampDim = use;
257
0
    }
258
259
    void setProduceCellOutput(bool produce) CV_OVERRIDE
260
0
    {
261
0
        CV_Assert(!allocated);
262
0
        produceCellOutput = produce;
263
0
    }
264
265
    void setOutShape(const MatShape &outTailShape_) CV_OVERRIDE
266
0
    {
267
0
        CV_Assert(!allocated || total(outTailShape) == total(outTailShape_));
268
0
        outTailShape = outTailShape_;
269
0
    }
270
271
    void setWeights(const Mat &Wh, const Mat &Wx, const Mat &bias) CV_OVERRIDE
272
0
    {
273
0
        CV_Assert(Wh.dims == 2 && Wx.dims == 2);
274
0
        CV_Assert(Wh.rows == Wx.rows);
275
0
        CV_Assert(Wh.rows == 4*Wh.cols);
276
0
        CV_Assert(Wh.rows == (int)bias.total());
277
0
        CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
278
279
0
        blobs.resize(3);
280
0
        blobs[0] = Mat(Wh.clone());
281
0
        blobs[1] = Mat(Wx.clone());
282
0
        blobs[2] = Mat(bias.clone()).reshape(1, 1);
283
0
    }
284
285
    bool supportBackend(int backendId) CV_OVERRIDE
286
0
    {
287
0
        return backendId == DNN_BACKEND_OPENCV
288
0
               || (backendId == DNN_BACKEND_CUDA && isDefaultActivations && !reverse && !usePeephole);
289
0
    }
290
291
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
292
                         const int requiredOutputs,
293
                         std::vector<MatShape> &outputs,
294
                         std::vector<MatShape> &internals) const CV_OVERRIDE
295
0
    {
296
0
        CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
297
0
        CV_Assert((inputs.size() == 1 || inputs.size() == 3));
298
0
        const MatShape& inp0 = inputs[0];
299
300
0
        const Mat &Wh = blobs[0], &Wx = blobs[1];
301
0
        int _numOut = Wh.size[1];
302
0
        int _numInp = Wx.size[1];
303
0
        MatShape outTailShape_(outTailShape), outResShape;
304
305
0
        if (!outTailShape_.empty())
306
0
            CV_Assert(total(outTailShape_) == _numOut);
307
0
        else
308
0
            outTailShape_.assign(1, _numOut);
309
310
0
        int _numSamples;
311
0
        if (useTimestampDim)
312
0
        {
313
0
            CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);
314
0
            if (layout == SEQ_BATCH_HID) {
315
0
                _numSamples = inp0[1];
316
0
                outResShape.push_back(inp0[0]);
317
0
            } else {
318
0
                _numSamples = inp0[0];
319
0
                outResShape.push_back(inp0[1]);
320
0
            }
321
0
        }
322
0
        else
323
0
        {
324
0
            CV_Assert(inp0.size() >= 2 && total(inp0, 1) == _numInp);
325
0
            _numSamples = inp0[0];
326
0
        }
327
328
0
        outResShape.push_back(_numSamples);
329
0
        outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
330
0
        outResShape.back() *= (1 + static_cast<int>(bidirectional));
331
332
0
        outputs.assign(1, outResShape);
333
0
        if (produceCellOutput)
334
0
        {
335
            // the producer is ONNX, so CellState is different
336
0
            if (!originalBlobs.empty())
337
0
            {
338
0
                int shp[] = {(1 + static_cast<int>(bidirectional)), _numSamples, numHidden};
339
0
                MatShape newShape(shp, shp + sizeof(shp)/sizeof(shp[0]));
340
0
                outputs.push_back(newShape);
341
0
            }
342
0
            else
343
0
            {
344
0
                outputs.push_back(outResShape);
345
0
            }
346
0
        }
347
348
0
        internals.assign(1, shape(_numSamples, _numOut)); // hInternal
349
0
        internals.push_back(shape(_numSamples, _numOut)); // cInternal
350
0
        internals.push_back(shape(_numSamples, 1)); // dummyOnes
351
0
        internals.push_back(shape(_numSamples, 4*_numOut)); // gates
352
353
0
        return false;
354
0
    }
355
356
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
357
0
    {
358
0
        std::vector<Mat> input;
359
0
        inputs_arr.getMatVector(input);
360
361
0
        CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
362
0
        CV_Assert((input.size() == 1 || input.size() == 3));
363
0
        const Mat& inp0 = input[0];
364
365
0
        Mat &Wh = blobs[0], &Wx = blobs[1];
366
0
        int numOut = Wh.size[1];
367
0
        int numInp = Wx.size[1];
368
369
0
        if (!outTailShape.empty())
370
0
            CV_Assert(total(outTailShape) == numOut);
371
0
        else
372
0
            outTailShape.assign(1, numOut);
373
374
0
        if (useTimestampDim)
375
0
        {
376
0
            CV_Assert(inp0.dims >= 2 && (int)inp0.total(2) == numInp);
377
0
            if (layout == SEQ_BATCH_HID){
378
0
                numTimeStamps = inp0.size[0];
379
0
                numSamples = inp0.size[1];
380
0
            }else{
381
0
                numTimeStamps = inp0.size[1];
382
0
                numSamples = inp0.size[0];
383
0
            }
384
0
        }
385
0
        else
386
0
        {
387
0
            CV_Assert(inp0.dims >= 2 && (int)inp0.total(1) == numInp);
388
0
            numTimeStamps = 1;
389
0
            numSamples = inp0.size[0];
390
0
        }
391
392
0
        outTsShape.clear();
393
0
        outTsShape.push_back(numSamples);
394
0
        outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
395
0
        outTsShape.back() *= (1 + static_cast<int>(bidirectional));
396
397
0
        allocated = true;
398
0
    }
399
400
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
401
0
    {
402
0
        CV_TRACE_FUNCTION();
403
0
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
404
405
0
        if (inputs_arr.depth() == CV_16F)
406
0
        {
407
0
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
408
0
            return;
409
0
        }
410
411
0
        std::vector<Mat> input, output, internals;
412
0
        inputs_arr.getMatVector(input);
413
0
        outputs_arr.getMatVector(output);
414
0
        internals_arr.getMatVector(internals);
415
416
0
        if (layout == BATCH_SEQ_HID){
417
            //swap axis 0 and 1 input x
418
0
            cv::Mat tmp;
419
            // Since python input is 4 dimensional and C++ input 3 dimensional
420
            // we need to process each differently
421
0
            if (input[0].dims == 4){
422
                // here !!!
423
0
                CV_Assert(input[0].size[3] == 1);
424
0
                cv::transposeND(input[0], {1, 0, 2, 3}, tmp); //back to seq_len, batch_size, hidden_size format
425
0
            }else{
426
0
                cv::transposeND(input[0], {1, 0, 2}, tmp); //back to seq_len, batch_size, hidden_size format
427
0
            }
428
0
            input[0] = tmp;
429
0
        }
430
431
0
        Mat cOut = produceCellOutput ? output[0].clone() : Mat();
432
0
        const bool needYcTransform = !originalBlobs.empty(); // if the producer is onnx
433
0
        const int numDirs = 1 + static_cast<int>(bidirectional);
434
0
        for (int i = 0; i < numDirs; ++i)
435
0
        {
436
0
            Mat Wh = blobs[0];
437
0
            Mat Wx = blobs[1];
438
0
            Mat bias = blobs[2];
439
440
0
            Mat h_0, c_0;
441
            // Handle h_0 and c_0 based on input size
442
0
            h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3];
443
0
            c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4];
444
445
            // Perform checks if input size is 2 or 3
446
0
            if (input.size() >= 2) {
447
0
                CV_CheckEQ(h_0.cols, Wh.cols, "");
448
0
                CV_CheckEQ(h_0.cols, c_0.cols, "");
449
0
                CV_CheckEQ(h_0.rows, c_0.rows, "");
450
0
            }
451
452
453
0
            Mat pI, pF, pO;
454
455
0
            Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
456
0
            Wx = Wx.rowRange(i * Wx.rows / numDirs, (i + 1) * Wx.rows / numDirs);
457
0
            bias = bias.colRange(i * bias.cols / numDirs, (i + 1) * bias.cols / numDirs);
458
0
            h_0 = h_0.rowRange(i * h_0.rows / numDirs, (i + 1) * h_0.rows / numDirs);
459
0
            c_0 = c_0.rowRange(i * c_0.rows / numDirs, (i + 1) * c_0.rows / numDirs);
460
461
0
            if (usePeephole)
462
0
            {
463
0
                pI = blobs[5];
464
0
                pF = blobs[6];
465
0
                pO = blobs[7];
466
467
0
                pI = pI.rowRange(i * pI.rows / numDirs, (i + 1) * pI.rows / numDirs);
468
0
                pI = pI.colRange(i * pI.cols / numDirs, (i + 1) * pI.cols / numDirs);
469
470
0
                pF = pF.rowRange(i * pF.rows / numDirs, (i + 1) * pF.rows / numDirs);
471
0
                pF = pF.colRange(i * pF.cols / numDirs, (i + 1) * pF.cols / numDirs);
472
473
0
                pO = pO.rowRange(i * pO.rows / numDirs, (i + 1) * pO.rows / numDirs);
474
0
                pO = pO.colRange(i * pO.cols / numDirs, (i + 1) * pO.cols / numDirs);
475
0
            }
476
477
0
            int numOut = Wh.size[1];
478
0
            Mat hInternal = internals[0], cInternal = internals[1],
479
0
                    dummyOnes = internals[2], gates = internals[3];
480
0
            h_0.copyTo(hInternal);
481
0
            c_0.copyTo(cInternal);
482
0
            dummyOnes.setTo(1.);
483
484
0
            int numSamplesTotal = numTimeStamps*numSamples;
485
0
            Mat xTs = input[0].reshape(1, numSamplesTotal);
486
487
0
            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
488
0
            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
489
0
            Mat cOutTs;
490
0
            if (produceCellOutput)
491
0
            {
492
0
                cOutTs = cOut.reshape(1, numSamplesTotal);
493
0
                cOutTs = cOutTs.colRange(i * cOutTs.cols / numDirs, (i + 1) * cOutTs.cols / numDirs);
494
0
            }
495
496
0
#if CV_TRY_AVX2 || CV_TRY_AVX
497
0
            bool canUseAvx = gates.isContinuous() && bias.isContinuous()
498
0
                && Wx.depth() == CV_32F && gates.depth() == CV_32F
499
0
                && bias.depth() == CV_32F && Wx.cols >= 8;
500
0
            bool canUseAvx_hInternal = hInternal.isContinuous() && gates.isContinuous() && bias.isContinuous()
501
0
                && Wh.depth() == CV_32F && hInternal.depth() == CV_32F && gates.depth() == CV_32F
502
0
                && Wh.cols >= 8;
503
0
#endif
504
#if CV_TRY_SVE
505
            bool canUseSVE = gates.isContinuous() && bias.isContinuous()
506
                && Wx.depth() == CV_32F && gates.depth() == CV_32F
507
                && bias.depth() == CV_32F;
508
            bool canUseSVE_hInternal = hInternal.isContinuous() && gates.isContinuous() && bias.isContinuous()
509
                && Wh.depth() == CV_32F && hInternal.depth() == CV_32F && gates.depth() == CV_32F;
510
#endif
511
#if CV_TRY_NEON
512
            bool canUseNeon = gates.isContinuous() && bias.isContinuous()
513
                && Wx.depth() == CV_32F && gates.depth() == CV_32F
514
                && bias.depth() == CV_32F && Wx.cols >= 4;
515
            bool canUseNeon_hInternal = hInternal.isContinuous() && gates.isContinuous() && bias.isContinuous()
516
                && Wh.depth() == CV_32F && hInternal.depth() == CV_32F && gates.depth() == CV_32F
517
                && Wh.cols >= 4;
518
#endif
519
520
0
            int tsStart, tsEnd, tsInc;
521
0
            if (reverse || i == 1) {
522
0
                tsStart = numTimeStamps - 1;
523
0
                tsEnd = -1;
524
0
                tsInc = -1;
525
0
            }
526
0
            else {
527
0
                tsStart = 0;
528
0
                tsEnd = numTimeStamps;
529
0
                tsInc = 1;
530
0
            }
531
0
            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
532
0
            {
533
0
                Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
534
0
                Mat xCurr = xTs.rowRange(curRowRange);
535
536
0
#if CV_TRY_AVX2
537
0
                if (useAVX2 && canUseAvx && xCurr.isContinuous())
538
0
                {
539
0
                    for (int n = 0; n < xCurr.rows; n++) {
540
0
                        opt_AVX2::fastGEMM1T(
541
0
                            xCurr.ptr<float>(n),
542
0
                            Wx.ptr<float>(),
543
0
                            Wx.step1(),
544
0
                            bias.ptr<float>(),
545
0
                            gates.ptr<float>(n),
546
0
                            Wx.rows,
547
0
                            Wx.cols
548
0
                        );
549
0
                    }
550
0
                }
551
0
                else
552
0
#endif
553
0
#if CV_TRY_AVX
554
0
                if (useAVX && canUseAvx && xCurr.isContinuous())
555
0
                {
556
0
                    for (int n = 0; n < xCurr.rows; n++) {
557
0
                        opt_AVX::fastGEMM1T(
558
0
                            xCurr.ptr<float>(n),
559
0
                            Wx.ptr<float>(),
560
0
                            Wx.step1(),
561
0
                            bias.ptr<float>(),
562
0
                            gates.ptr<float>(n),
563
0
                            Wx.rows,
564
0
                            Wx.cols
565
0
                        );
566
0
                    }
567
0
                }
568
0
                else
569
0
#endif
570
#if CV_TRY_SVE
571
                if (useSVE && canUseSVE && xCurr.isContinuous())
572
                {
573
                    for (int n = 0; n < xCurr.rows; n++) {
574
                        opt_SVE::fastGEMM1T(
575
                            xCurr.ptr<float>(n),
576
                            Wx.ptr<float>(),
577
                            Wx.step1(),
578
                            bias.ptr<float>(),
579
                            gates.ptr<float>(n),
580
                            Wx.rows,
581
                            Wx.cols
582
                        );
583
                    }
584
                }
585
                else
586
#endif
587
#if CV_TRY_NEON
588
                if (useNEON && canUseNeon && xCurr.isContinuous())
589
                {
590
                    for (int n = 0; n < xCurr.rows; n++) {
591
                        opt_NEON::fastGEMM1T(
592
                            xCurr.ptr<float>(n),
593
                            Wx.ptr<float>(),
594
                            Wx.step1(),
595
                            bias.ptr<float>(),
596
                            gates.ptr<float>(n),
597
                            Wx.rows,
598
                            Wx.cols
599
                        );
600
                    }
601
                }
602
                else
603
#endif
604
0
                {
605
0
                    gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
606
0
                    gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b
607
0
                }
608
609
0
#if CV_TRY_AVX2
610
0
                if (useAVX2 && canUseAvx_hInternal)
611
0
                {
612
0
                    for (int n = 0; n < hInternal.rows; n++) {
613
0
                        opt_AVX2::fastGEMM1T(
614
0
                            hInternal.ptr<float>(n),
615
0
                            Wh.ptr<float>(),
616
0
                            Wh.step1(),
617
0
                            gates.ptr<float>(n),
618
0
                            gates.ptr<float>(n),
619
0
                            Wh.rows,
620
0
                            Wh.cols
621
0
                        );
622
0
                    }
623
0
                }
624
0
                else
625
0
#endif
626
0
#if CV_TRY_AVX
627
0
                if (useAVX && canUseAvx_hInternal)
628
0
                {
629
0
                    for (int n = 0; n < hInternal.rows; n++) {
630
0
                        opt_AVX::fastGEMM1T(
631
0
                            hInternal.ptr<float>(n),
632
0
                            Wh.ptr<float>(),
633
0
                            Wh.step1(),
634
0
                            gates.ptr<float>(n),
635
0
                            gates.ptr<float>(n),
636
0
                            Wh.rows,
637
0
                            Wh.cols
638
0
                        );
639
0
                    }
640
0
                }
641
0
                else
642
0
#endif
643
#if CV_TRY_SVE
644
                if (useSVE && canUseSVE_hInternal)
645
                {
646
                    for (int n = 0; n < hInternal.rows; n++) {
647
                        opt_SVE::fastGEMM1T(
648
                            hInternal.ptr<float>(n),
649
                            Wh.ptr<float>(),
650
                            Wh.step1(),
651
                            gates.ptr<float>(n),
652
                            gates.ptr<float>(n),
653
                            Wh.rows,
654
                            Wh.cols
655
                        );
656
                    }
657
                }
658
                else
659
#endif
660
#if CV_TRY_NEON
661
                if (useNEON && canUseNeon_hInternal)
662
                {
663
                    for (int n = 0; n < hInternal.rows; n++) {
664
                        opt_NEON::fastGEMM1T(
665
                            hInternal.ptr<float>(n),
666
                            Wh.ptr<float>(),
667
                            Wh.step1(),
668
                            gates.ptr<float>(n),
669
                            gates.ptr<float>(n),
670
                            Wh.rows,
671
                            Wh.cols
672
                        );
673
                    }
674
                }
675
                else
676
#endif
677
0
                {
678
0
                    gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
679
0
                }
680
681
0
                Mat gateI = gates.colRange(0*numOut, 1*numOut);
682
0
                Mat gateF = gates.colRange(1*numOut, 2*numOut);
683
0
                Mat gateO = gates.colRange(2*numOut, 3*numOut);
684
0
                Mat gateG = gates.colRange(3*numOut, 4*numOut);
685
686
0
                if (forgetBias)
687
0
                    add(gateF, forgetBias, gateF);
688
689
0
                if (usePeephole)
690
0
                {
691
0
                    Mat gatesIF = gates.colRange(0, 2*numOut);
692
0
                    gemm(cInternal, pI, 1, gateI, 1, gateI);
693
0
                    gemm(cInternal, pF, 1, gateF, 1, gateF);
694
0
                    f_activation(gatesIF, gatesIF);
695
0
                }
696
0
                else
697
0
                {
698
0
                    Mat gatesIFO = gates.colRange(0, 3*numOut);
699
0
                    f_activation(gatesIFO, gatesIFO);
700
0
                }
701
702
0
                g_activation(gateG, gateG);
703
704
                //compute c_t
705
0
                multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
706
0
                multiply(gateI, gateG, gateI);      // i_t (*) g_t
707
0
                add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t
708
709
0
                if (useCellClip)
710
0
                {
711
0
                    min(cInternal, cellClip, cInternal);
712
0
                    max(cInternal, -cellClip, cInternal);
713
0
                }
714
0
                if (usePeephole)
715
0
                {
716
0
                    gemm(cInternal, pO, 1, gateO, 1, gateO);
717
0
                    f_activation(gateO, gateO);
718
0
                }
719
720
                //compute h_t
721
0
                h_activation(cInternal, hInternal);
722
0
                multiply(gateO, hInternal, hInternal);
723
724
                //save results in output blobs
725
0
                hInternal.copyTo(hOutTs.rowRange(curRowRange));
726
0
                if (produceCellOutput)
727
0
                    cInternal.copyTo(cOutTs.rowRange(curRowRange));
728
0
            }
729
0
        }
730
        // transpose to match batch first output
731
0
        if (layout == BATCH_SEQ_HID){
732
0
            cv::Mat tmp;
733
0
            cv::transposeND(output[0], {1, 0, 2}, tmp);
734
0
            output[0] = tmp;
735
0
        }
736
0
        if (needYcTransform && produceCellOutput)
737
0
        {
738
0
            fixCellState(cOut, numDirs);
739
0
        }
740
0
        if (produceCellOutput)
741
0
        {
742
0
            cOut.copyTo(output[1]);
743
0
        }
744
0
    }
745
746
    void fixCellState(Mat& cOut, int numDirs)
747
0
    {
748
        // seq, batch, dirs, hidden
749
0
        int shp[] = {0, numSamples, numDirs, numHidden};
750
0
        cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
751
752
        // permute to {0, 2, 1, 3};
753
0
        cv::Mat newCellState;
754
        // transpose to match batch first output
755
0
        if (layout == BATCH_SEQ_HID){
756
0
            cv::transposeND(cOut, {2, 0, 1, 3}, newCellState);
757
0
        }
758
0
        else{
759
0
            cv::transposeND(cOut, {0, 2, 1, 3}, newCellState);
760
0
        }
761
0
        cOut = newCellState;
762
763
0
        if (numDirs == 1)
764
0
        {
765
            // Slice: Yh = Y[-1, :, :, :]
766
0
            Range ranges[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range::all(), cv::Range::all(), cv::Range::all()};
767
0
            cOut = cOut(ranges);
768
            // Reshape: 1x1xBxH -> 1xBxH
769
0
            int shp[] = {1, numSamples, numHidden};
770
0
            cOut = cOut.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
771
0
        }
772
0
        else
773
0
        {
774
            // Slice: SxDxBxH -> last sequence, first direction
775
0
            Range ranges1[] = {cv::Range(cOut.size[0] - 1, cOut.size[0]), cv::Range(0, 1), cv::Range::all(), cv::Range::all()};
776
0
            Mat part1 = cOut(ranges1);
777
778
            // Slice: SxDxBxH -> first sequence, last direction
779
0
            Range ranges2[] = {cv::Range(0, 1), cv::Range(cOut.size[1] - 1, cOut.size[1]), cv::Range::all(), cv::Range::all()};
780
0
            Mat part2 = cOut(ranges2);
781
782
0
            int shp[] = {1, part1.size[2] * part1.size[3]};
783
0
            part1 = part1.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
784
0
            part2 = part2.reshape(1, sizeof(shp)/sizeof(shp[0]), shp);
785
786
0
            vconcat(part1, part2, cOut);
787
788
            // Reshape: 1x2xBxH -> 2xBxH
789
0
            int finalShape[] = {2, numSamples, numHidden};
790
0
            cOut = cOut.reshape(1, sizeof(finalShape)/sizeof(finalShape[0]), finalShape);
791
0
        }
792
0
    }
793
794
#ifdef HAVE_CUDA
795
    Ptr<BackendNode> initCUDA(void *context_, const std::vector<Ptr<BackendWrapper>> &inputs,
796
                              const std::vector<Ptr<BackendWrapper>> &outputs) override
797
    {
798
        const int numDirs = 1 + static_cast<int>(bidirectional);
799
        auto toIFCO = [numDirs] (Mat& in) {
800
            int first = in.size[0];
801
            int rest = in.total() / first / 4;
802
            // every weight blob contains weights for Input, Output, Forget and Cell gates
803
            Mat m = in.reshape(1, {first, 4, rest});
804
            Mat outputGate = m.col(1);
805
            Mat forgetGate = m.col(2);
806
            Mat cellGate = m.col(3);
807
            // IOFC -> IFOC
808
            std::swap_ranges(outputGate.begin<float>(), outputGate.end<float>(), forgetGate.begin<float>());
809
            std::swap(outputGate, forgetGate);
810
            // IFOC -> IFCO
811
            std::swap_ranges(outputGate.begin<float>(), outputGate.end<float>(), cellGate.begin<float>());
812
            in = in.reshape(1, numDirs);
813
        };
814
815
        Mat& b = originalBlobs[2];
816
        // B is a concatenation of biases for Wh and Wx
817
        b = b.reshape(1, originalBlobs[2].size[0]*2);
818
819
        for (auto& m : originalBlobs)
820
        {
821
            toIFCO(m);
822
        }
823
824
        b = b.reshape(1, static_cast<int>(b.total()));
825
826
        Mat ordered_weights;
827
        // Wx_f, Wh_f, [Wx_b, Wh_b,] b
828
        for (int i = 0; i < numDirs; ++i)
829
        {
830
            for (size_t j = 0; j < 2; ++j) // Wx, Wh
831
            {
832
                Mat oneDirection = originalBlobs[j].row(i);
833
                ordered_weights.push_back(oneDirection.reshape(1, static_cast<int>(oneDirection.total())));
834
            }
835
        }
836
        ordered_weights.push_back(b);
837
838
        // Pass hidden states as is
839
        Mat h0 = blobs[3];
840
        Mat c0 = blobs[4];
841
842
        CV_Assert(!inputs.empty());
843
        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
844
        auto input_shape = input_wrapper->getShape();
845
846
        RNNConfiguration config
847
        {
848
            input_shape[0],      // seqLength;
849
            1,                   // numLayers;
850
            numHidden,           // hiddenSize;
851
            input_shape[2],      // inputSize;
852
            input_shape[1],      // miniBatch;
853
            bidirectional
854
        };
855
856
857
        auto *context = reinterpret_cast<cuda4dnn::csl::CSLContext *>(context_);
858
        return make_cuda_node<cuda4dnn::LSTMOp>(preferableTarget, std::move(context->stream),
859
                                                std::move(context->cudnn_handle),
860
                                                ordered_weights, h0, c0,
861
                                                config);
862
    }
863
#endif
864
};
865
866
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
{
    // Factory entry point: all LSTM functionality lives in LSTMLayerImpl.
    return makePtr<LSTMLayerImpl>(params);
}
870
871
int LSTMLayer::inputNameToIndex(String inputName)
{
    // Only the data input "x" is addressable by name; anything else is unknown.
    return (toLowerCase(inputName) == "x") ? 0 : -1;
}
877
878
int LSTMLayer::outputNameToIndex(const String& outputName)
879
0
{
880
0
    if (toLowerCase(outputName) == "h")
881
0
        return 0;
882
0
    else if (toLowerCase(outputName) == "c")
883
0
        return 1;
884
0
    return -1;
885
0
}
886
887
888
class RNNLayerImpl : public RNNLayer
889
{
890
    int numX, numH, numO;
891
    int numSamples, numTimestamps, numSamplesTotal;
892
    int dtype;
893
    Mat Whh, Wxh, bh;
894
    Mat Who, bo;
895
    bool produceH;
896
897
public:
898
899
    RNNLayerImpl(const LayerParams& params)
900
0
        : numX(0), numH(0), numO(0), numSamples(0), numTimestamps(0), numSamplesTotal(0), dtype(0)
901
0
    {
902
0
        setParamsFrom(params);
903
0
        type = "RNN";
904
0
        produceH = false;
905
0
    }
906
907
    void setProduceHiddenOutput(bool produce = false) CV_OVERRIDE
908
0
    {
909
0
        produceH = produce;
910
0
    }
911
912
    void setWeights(const Mat &W_xh, const Mat &b_h, const Mat &W_hh, const Mat &W_ho, const Mat &b_o) CV_OVERRIDE
913
0
    {
914
0
        CV_Assert(W_hh.dims == 2 && W_xh.dims == 2);
915
0
        CV_Assert(W_hh.size[0] == W_xh.size[0] && W_hh.size[0] == W_hh.size[1] && (int)b_h.total() == W_xh.size[0]);
916
0
        CV_Assert(W_ho.size[0] == (int)b_o.total());
917
0
        CV_Assert(W_ho.size[1] == W_hh.size[1]);
918
919
0
        blobs.resize(5);
920
0
        blobs[0] = Mat(W_xh.clone());
921
0
        blobs[1] = Mat(b_h.clone());
922
0
        blobs[2] = Mat(W_hh.clone());
923
0
        blobs[3] = Mat(W_ho.clone());
924
0
        blobs[4] = Mat(b_o.clone());
925
0
    }
926
927
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
928
                         const int requiredOutputs,
929
                         std::vector<MatShape> &outputs,
930
                         std::vector<MatShape> &internals) const CV_OVERRIDE
931
0
    {
932
0
        CV_Assert(inputs.size() >= 1 && inputs.size() <= 2);
933
934
0
        Mat Who_ = blobs[3];
935
0
        Mat Wxh_ = blobs[0];
936
937
0
        int numTimestamps_ = inputs[0][0];
938
0
        int numSamples_ = inputs[0][1];
939
940
0
        int numO_ = Who_.rows;
941
0
        int numH_ = Wxh_.rows;
942
943
0
        outputs.clear();
944
0
        int dims[] = {numTimestamps_, numSamples_, numO_};
945
0
        outputs.push_back(shape(dims, 3));
946
0
        dims[2] = numH_;
947
0
        if (produceH)
948
0
            outputs.push_back(shape(dims, 3));
949
950
0
        internals.assign(2, shape(numSamples_, numH_));
951
0
        internals.push_back(shape(numSamples_, 1));
952
953
0
        return false;
954
0
    }
955
956
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
957
0
    {
958
0
        std::vector<Mat> input, outputs;
959
0
        inputs_arr.getMatVector(input);
960
961
0
        CV_Assert(input.size() >= 1 && input.size() <= 2);
962
963
0
        Wxh = blobs[0];
964
0
        bh  = blobs[1];
965
0
        Whh = blobs[2];
966
0
        Who = blobs[3];
967
0
        bo  = blobs[4];
968
969
0
        numH = Wxh.rows;
970
0
        numX = Wxh.cols;
971
0
        numO = Who.rows;
972
973
0
        const Mat& inp0 = input[0];
974
975
0
        CV_Assert(inp0.dims >= 2);
976
0
        CV_Assert(inp0.total(2) == numX);
977
0
        dtype = CV_32F;
978
0
        CV_Assert(inp0.type() == dtype);
979
0
        numTimestamps = inp0.size[0];
980
0
        numSamples = inp0.size[1];
981
0
        numSamplesTotal = numTimestamps * numSamples;
982
983
0
        bh = bh.reshape(1, 1); //is 1 x numH Mat
984
0
        bo = bo.reshape(1, 1); //is 1 x numO Mat
985
0
    }
986
987
    void reshapeOutput(std::vector<Mat> &output)
988
0
    {
989
0
        output.resize(produceH ? 2 : 1);
990
0
        int sz0[] = { numTimestamps, numSamples, numO };
991
0
        output[0].create(3, sz0, dtype);
992
0
        if (produceH)
993
0
        {
994
0
            int sz1[] = { numTimestamps, numSamples, numH };
995
0
            output[1].create(3, sz1, dtype);
996
0
        }
997
0
    }
998
999
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
1000
0
    {
1001
0
        CV_TRACE_FUNCTION();
1002
0
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
1003
1004
0
        if (inputs_arr.depth() == CV_16F)
1005
0
        {
1006
0
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
1007
0
            return;
1008
0
        }
1009
1010
0
        std::vector<Mat> input, output, internals;
1011
0
        inputs_arr.getMatVector(input);
1012
0
        outputs_arr.getMatVector(output);
1013
0
        internals_arr.getMatVector(internals);
1014
1015
0
        Mat xTs = input[0].reshape(1, numSamplesTotal);
1016
0
        Mat oTs = output[0].reshape(1, numSamplesTotal);
1017
0
        Mat hTs = produceH ? output[1].reshape(1, numSamplesTotal) : Mat();
1018
0
        Mat hCurr = internals[0];
1019
0
        Mat hPrev = internals[1];
1020
0
        Mat dummyBiasOnes = internals[2];
1021
1022
0
        hPrev.setTo(0.);
1023
0
        dummyBiasOnes.setTo(1.);
1024
1025
0
        for (int ts = 0; ts < numTimestamps; ts++)
1026
0
        {
1027
0
            Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples);
1028
0
            Mat xCurr = xTs.rowRange(curRowRange);
1029
1030
0
            gemm(hPrev, Whh, 1, hCurr, 0, hCurr, GEMM_2_T); // W_{hh} * h_{prev}
1031
0
            gemm(xCurr, Wxh, 1, hCurr, 1, hCurr, GEMM_2_T); //+W_{xh} * x_{curr}
1032
0
            gemm(dummyBiasOnes, bh, 1, hCurr, 1, hCurr);    //+bh
1033
0
            tanh(hCurr, hPrev);
1034
1035
0
            Mat oCurr = oTs.rowRange(curRowRange);
1036
0
            gemm(hPrev, Who, 1, oCurr, 0, oCurr, GEMM_2_T); // W_{ho} * h_{prev}
1037
0
            gemm(dummyBiasOnes, bo, 1, oCurr, 1, oCurr);    //+b_o
1038
0
            tanh(oCurr, oCurr);
1039
1040
0
            if (produceH)
1041
0
                hPrev.copyTo(hTs.rowRange(curRowRange));
1042
0
        }
1043
0
    }
1044
};
1045
1046
CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create(const LayerParams& params)
{
    // Factory entry point: the concrete implementation is RNNLayerImpl.
    return makePtr<RNNLayerImpl>(params);
}
1050
1051
// GRU (Gated Recurrent Unit) layer implementation.
//
// For each time step t, forward() computes (see the inline comments there)
//   r_t = sigmoid(x * Wx_r + h_(t-1) * Wh_r + b_r)
//   z_t = sigmoid(x * Wx_z + h_(t-1) * Wh_z + b_z)
//   n_t = tanh(r_t (*) (h_(t-1) * Wh_n + b_hn) + x * Wx_n + b_in)
//   h_t = z_t (*) h_(t-1) + (1 - z_t) (*) n_t
//
// blobs layout (validated in the constructor):
//   [0] = Wh   (numDirs*3*numOut x numOut) hidden-to-hidden weights
//   [1] = Wx   (numDirs*3*numOut x numInp) input-to-hidden weights
//   [2] = bias (flattened to one row; input-side and hidden-side halves
//               per direction — sliced in forward())
//   [3] = h_0  initial hidden state, one row block per direction
class GRULayerImpl CV_FINAL : public GRULayer
{
    int numTimeStamps, numSamples;  // taken from the input blob in finalize()
    bool allocated;

    MatShape outTailShape;  //shape of single output sample
    MatShape outTsShape;    //shape of N output samples
    bool bidirectional;     // If true, produces both forward and reversed directions along time axis

public:

    GRULayerImpl(const LayerParams& params) : numTimeStamps(0), numSamples(0)
    {
        setParamsFrom(params);

        bidirectional = params.get<bool>("bidirectional", false);
        if (!blobs.empty())
        {
            CV_Assert(blobs.size() >= 3);

            // Flatten the bias blob to a single row for easy column slicing.
            blobs[2] = blobs[2].reshape(1, 1);

            const Mat& Wh = blobs[0];
            const Mat& Wx = blobs[1];
            const Mat& bias = blobs[2];
            const Mat& hInternal = blobs[3];
            CV_CheckEQ(Wh.dims, 2, "");
            CV_CheckEQ(Wx.dims, 2, "");
            CV_CheckEQ(Wh.rows, Wx.rows, "");
            // Three gates (z, r, n) per direction.
            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional)) * 3 * Wh.cols, "");
            // Bias carries both the input-side and hidden-side halves.
            CV_CheckEQ(Wh.rows * 2, (int)bias.total(), "");
            CV_CheckEQ(hInternal.cols, Wh.cols, "");
            CV_CheckTypeEQ(Wh.type(), Wx.type(), "");
            CV_CheckTypeEQ(Wx.type(), bias.type(), "");
        }

        allocated = false;
        outTailShape.clear();
    }

    // Declares one output (T x S x numOut*numDirs) and six internal scratch
    // buffers used by forward() (see the trailing comments below).
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1);
        const MatShape& inp0 = inputs[0];

        const Mat &Wh = blobs[0], &Wx = blobs[1];
        int _numOut = Wh.size[1];
        int _numInp = Wx.size[1];
        MatShape outTailShape_(outTailShape), outResShape;

        if (!outTailShape_.empty())
            CV_Assert(total(outTailShape_) == _numOut);
        else
            outTailShape_.assign(1, _numOut);

        int _numSamples;
        CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);
        _numSamples = inp0[1];
        outResShape.push_back(inp0[0]);

        outResShape.push_back(_numSamples);
        outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
        // Both directions are concatenated along the feature axis.
        outResShape.back() *= (1 + static_cast<int>(bidirectional));

        outputs.assign(1, outResShape);

        internals.assign(1, shape(_numSamples, _numOut));     // hInternal
        internals.push_back(shape(_numSamples, 1));           // dummyOnes
        internals.push_back(shape(_numSamples, 2 * _numOut)); // gates
        internals.push_back(shape(_numSamples, 2 * _numOut)); // gates_b
        internals.push_back(shape(_numSamples, 1 * _numOut)); // h_linear
        internals.push_back(shape(_numSamples, _numOut));     // ones

        return false;
    }

    // Caches the input geometry and the per-timestep output shape.
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
    {
        std::vector<Mat> input;
        inputs_arr.getMatVector(input);

        CV_Assert(input.size() == 1);
        const Mat& inp0 = input[0];

        Mat &Wh = blobs[0], &Wx = blobs[1];
        int numOut = Wh.size[1];
        int numInp = Wx.size[1];

        if (!outTailShape.empty())
            CV_Assert(total(outTailShape) == numOut);
        else
            outTailShape.assign(1, numOut);

        CV_Assert(inp0.dims >= 2 && (int)inp0.total(2) == numInp);
        numTimeStamps = inp0.size[0];
        numSamples = inp0.size[1];

        outTsShape.clear();
        outTsShape.push_back(numSamples);
        outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
        outTsShape.back() *= (1 + static_cast<int>(bidirectional));

        allocated = true;
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        // Half-float inputs go through the generic convert-and-retry path.
        if (inputs_arr.depth() == CV_16F)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> input, output, internals;
        inputs_arr.getMatVector(input);
        outputs_arr.getMatVector(output);
        internals_arr.getMatVector(internals);

        const int numDirs = 1 + static_cast<int>(bidirectional);
        for (int i = 0; i < numDirs; ++i)
        {
            // Per-direction views into the stacked weight/bias/state blobs.
            const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
            const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
            const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
            const Mat &h_0 = blobs[3].rowRange(i * blobs[3].rows / numDirs, (i + 1) * blobs[3].rows / numDirs);

            // Input-side (bx) and hidden-side (bh) bias halves.
            const Mat &bx = bias.colRange(0, bias.cols / 2);
            const Mat &bh = bias.colRange(bias.cols / 2, bias.cols);

            Mat hInternal = internals[0], dummyOnes = internals[1], gates = internals[2],
                b_rz = internals[3], n_t = internals[4], ones = internals[5];
            h_0.copyTo(hInternal);
            dummyOnes.setTo(1.);  // ones column broadcasts biases via gemm
            ones.setTo(1.);       // used to form (1 - z)

            int numOut = Wh.size[1];
            // First 2*numOut rows hold the fused z/r gate weights, the last
            // numOut rows the candidate (n) gate.
            const Mat& wx_rz = Wx.rowRange(0, 2 * numOut);
            const Mat& wh_rz = Wh.rowRange(0, 2 * numOut);
            // NOTE(review): this rebinds b_rz to a freshly allocated Mat, so
            // the preallocated internals[3] buffer goes unused — confirm.
            b_rz = bx.colRange(0, 2 * numOut) + bh.colRange(0, 2 * numOut);
            const Mat& wx_n = Wx.rowRange(2 * numOut, 3 * numOut);
            const Mat& wh_n = Wh.rowRange(2 * numOut, 3 * numOut);
            const Mat& b_in = bx.colRange(2 * numOut, 3 * numOut);
            const Mat& b_hn = bh.colRange(2 * numOut, 3 * numOut);

            // Flatten time and batch dims: row block [ts*S, (ts+1)*S) is step ts.
            int numSamplesTotal = numTimeStamps * numSamples;
            Mat xTs = input[0].reshape(1, numSamplesTotal);

            // This direction's slice of the feature-concatenated output.
            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);

            // The reverse direction (i == 1) walks time backwards.
            int tsStart, tsEnd, tsInc;
            if (i == 1) {
                tsStart = numTimeStamps - 1;
                tsEnd = -1;
                tsInc = -1;
            }
            else {
                tsStart = 0;
                tsEnd = numTimeStamps;
                tsInc = 1;
            }
            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
            {
                Range curRowRange(ts * numSamples, (ts + 1) * numSamples);
                Mat xCurr = xTs.rowRange(curRowRange);

                // calculate r_t = sigmoid(x * Wx_r + h_(t-1) * Wh_r + b_r)
                // calculate z_t = sigmoid(x * Wx_z + h_(t-1) * Wh_z + b_z)
                gemm(xCurr, wx_rz, 1, gates, 0, gates, GEMM_2_T);      // x * Wx_rz
                gemm(hInternal, wh_rz, 1, gates, 1, gates, GEMM_2_T);  // + h_(t-1) * Wh_rz
                gemm(dummyOnes, b_rz, 1, gates, 1, gates);             // + b_rz
                sigmoid(gates, gates);                                 // sigmoid()

                Mat z = gates.colRange(0, gates.cols / 2);
                Mat r = gates.colRange(gates.cols / 2, gates.cols);

                // calculate n_t = tanh(r (*) (h_(t-1) * Wh_n + b_hn) + x * Wx_n + b_in)
                gemm(hInternal, wh_n, 1, n_t, 0, n_t, GEMM_2_T);       // h_(t-1) * Wh_n
                gemm(dummyOnes, b_hn, 1, n_t, 1, n_t);                 // + b_hn
                multiply(r, n_t, n_t);                                 // r (*) (h_(t-1) * Wh_n + b_hn)

                gemm(xCurr, wx_n, 1, n_t, 1, n_t, GEMM_2_T);          // + x * Wx_n
                gemm(dummyOnes, b_in, 1, n_t, 1, n_t);                // + b_in
                tanh(n_t, n_t);                                       // tanh()

                //compute next h_t = z (*) h_(t-1) + (1 - z) (*) n_t
                multiply(z, hInternal, hInternal);                    // z (*) h_{t-1}
                subtract(ones, z, z);                                 // 1 - z
                multiply(z, n_t, z);                                  // (1 - z) * n
                add(z, hInternal, hInternal);                         // z (*) h_(t-1) + (1 - z) (*) n_t

                //save results in output blobs
                hInternal.copyTo(hOutTs.rowRange(curRowRange));
            }
        }
    }
};
1255
1256
0
Ptr<GRULayer> GRULayer::create(const LayerParams &params) {
    // Factory entry point: the concrete implementation is GRULayerImpl.
    return makePtr<GRULayerImpl>(params);
}
1259
1260
}
1261
}