/src/Simd/src/Test/TestTensor.h
Line | Count | Source |
1 | | /* |
2 | | * Tests for Simd Library (http://ermig1979.github.io/Simd). |
3 | | * |
4 | | * Copyright (c) 2011-2025 Yermalayeu Ihar. |
5 | | * |
6 | | * Permission is hereby granted, free of charge, to any person obtaining a copy |
7 | | * of this software and associated documentation files (the "Software"), to deal |
8 | | * in the Software without restriction, including without limitation the rights |
9 | | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
10 | | * copies of the Software, and to permit persons to whom the Software is |
11 | | * furnished to do so, subject to the following conditions: |
12 | | * |
13 | | * The above copyright notice and this permission notice shall be included in |
14 | | * all copies or substantial portions of the Software. |
15 | | * |
16 | | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
17 | | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
18 | | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
19 | | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
20 | | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
21 | | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
22 | | * SOFTWARE. |
23 | | */ |
24 | | #ifndef __TestTensor_h__ |
25 | | #define __TestTensor_h__ |
26 | | |
27 | | #include "Test/TestLog.h" |
28 | | |
29 | | namespace Test |
30 | | { |
/// Multi-dimensional element index: one coordinate per tensor axis.
using Index = std::vector<size_t>;
32 | | |
33 | | //------------------------------------------------------------------------------------------------- |
34 | | |
35 | | template<class T> SimdTensorDataType DataType(); |
36 | | |
37 | 0 | template<> SIMD_INLINE SimdTensorDataType DataType<float>() { return SimdTensorData32f; }; |
38 | 0 | template<> SIMD_INLINE SimdTensorDataType DataType<uint16_t>() { return SimdTensorData16f; }; |
39 | 0 | template<> SIMD_INLINE SimdTensorDataType DataType<uint8_t>() { return SimdTensorData8u; }; |
40 | | |
41 | | //------------------------------------------------------------------------------------------------- |
42 | | |
43 | | template<class T> class Tensor |
44 | | { |
45 | | public: |
46 | | typedef T Type; |
47 | | |
48 | | SIMD_INLINE Tensor() |
49 | | : _size(0) |
50 | | , _format(SimdTensorFormatUnknown) |
51 | | { |
52 | | } |
53 | | |
54 | | SIMD_INLINE Tensor(const Test::Shape & shape, SimdTensorFormatType format = SimdTensorFormatUnknown, const Type & value = Type()) |
55 | | : _shape(shape) |
56 | | , _format(format) |
57 | | { |
58 | | Resize(value); |
59 | | } |
60 | | |
61 | | SIMD_INLINE Tensor(std::initializer_list<size_t> shape, SimdTensorFormatType format = SimdTensorFormatUnknown, const Type & value = Type()) |
62 | | : _shape(shape.begin(), shape.end()) |
63 | | , _format(format) |
64 | | { |
65 | | Resize(value); |
66 | | } |
67 | | |
68 | | SIMD_INLINE ~Tensor() |
69 | | { |
70 | | } |
71 | | |
72 | | SIMD_INLINE void Reshape(const Test::Shape & shape, SimdTensorFormatType format = SimdTensorFormatUnknown, const Type & value = Type()) |
73 | | { |
74 | | _shape = shape; |
75 | | _format = format; |
76 | | Resize(value); |
77 | | } |
78 | | |
79 | | SIMD_INLINE void Reshape(std::initializer_list<size_t> shape, SimdTensorFormatType format = SimdTensorFormatUnknown, const Type & value = Type()) |
80 | | { |
81 | | _shape.assign(shape.begin(), shape.end()); |
82 | | _format = format; |
83 | | Resize(value); |
84 | | } |
85 | | |
86 | | SIMD_INLINE void Extend(const Test::Shape & shape, const Type& value = Type()) |
87 | | { |
88 | | _shape = shape; |
89 | | Extend(value); |
90 | | } |
91 | | |
92 | | SIMD_INLINE void Extend(std::initializer_list<size_t> shape, const Type& value = Type()) |
93 | | { |
94 | | _shape.assign(shape.begin(), shape.end()); |
95 | | Extend(value); |
96 | | } |
97 | | |
98 | | SIMD_INLINE void Clone(const Tensor& tensor) |
99 | | { |
100 | | _shape = tensor._shape; |
101 | | _format = tensor._format; |
102 | | _size = tensor._size; |
103 | | _data = tensor._data; |
104 | | } |
105 | | |
106 | | static SIMD_INLINE SimdTensorDataType DataType() |
107 | | { |
108 | | return Test::DataType<Type>(); |
109 | | } |
110 | | |
111 | | SIMD_INLINE SimdTensorFormatType Format() const |
112 | | { |
113 | | return _format; |
114 | | } |
115 | | |
116 | | SIMD_INLINE const Test::Shape & Shape() const |
117 | | { |
118 | | return _shape; |
119 | | } |
120 | | |
121 | | SIMD_INLINE size_t Count() const |
122 | 0 | { |
123 | 0 | return _shape.size(); |
124 | 0 | } |
125 | | |
126 | | SIMD_INLINE size_t Index(ptrdiff_t axis) const |
127 | 0 | { |
128 | 0 | if (axis < 0) |
129 | 0 | axis += _shape.size(); |
130 | 0 | return axis; |
131 | 0 | } |
132 | | |
133 | | SIMD_INLINE size_t Axis(ptrdiff_t axis) const |
134 | 0 | { |
135 | 0 | return _shape[Index(axis)]; |
136 | 0 | } |
137 | | |
138 | | SIMD_INLINE size_t Size(ptrdiff_t startAxis, ptrdiff_t endAxis) const |
139 | | { |
140 | | startAxis = Index(startAxis); |
141 | | endAxis = Index(endAxis); |
142 | | assert(startAxis <= endAxis && (size_t)endAxis <= _shape.size()); |
143 | | size_t size = 1; |
144 | | for (ptrdiff_t axis = startAxis; axis < endAxis; ++axis) |
145 | | size *= _shape[axis]; |
146 | | return size; |
147 | | } |
148 | | |
149 | | SIMD_INLINE size_t Size(ptrdiff_t startAxis) const |
150 | | { |
151 | | return Size(startAxis, _shape.size()); |
152 | | } |
153 | | |
154 | | SIMD_INLINE size_t Size() const |
155 | 0 | { |
156 | 0 | return _size; |
157 | 0 | } |
158 | | |
159 | | SIMD_INLINE size_t Offset(const Test::Index & index) const |
160 | 0 | { |
161 | 0 | assert(_shape.size() == index.size()); |
162 | 0 |
|
163 | 0 | size_t offset = 0; |
164 | 0 | for (size_t axis = 0; axis < _shape.size(); ++axis) |
165 | 0 | { |
166 | 0 | assert(_shape[axis] > 0); |
167 | 0 | assert(index[axis] < _shape[axis]); |
168 | 0 |
|
169 | 0 | offset *= _shape[axis]; |
170 | 0 | offset += index[axis]; |
171 | 0 | } |
172 | 0 | return offset; |
173 | 0 | } |
174 | | |
175 | | SIMD_INLINE size_t Offset(std::initializer_list<size_t> index) const |
176 | | { |
177 | | assert(_shape.size() == index.size()); |
178 | | |
179 | | size_t offset = 0; |
180 | | for (const size_t * s = _shape.data(), *i = index.begin(); i < index.end(); ++s, ++i) |
181 | | { |
182 | | assert(*s > 0); |
183 | | assert(*i < *s); |
184 | | |
185 | | offset *= *s; |
186 | | offset += *i; |
187 | | } |
188 | | return offset; |
189 | | } |
190 | | |
191 | | SIMD_INLINE Type * Data() |
192 | 0 | { |
193 | 0 | return _data.data(); |
194 | 0 | } |
195 | | |
196 | | SIMD_INLINE const Type * Data() const |
197 | 0 | { |
198 | 0 | return _data.data(); |
199 | 0 | } |
200 | | |
201 | | SIMD_INLINE Type * Data(const Test::Index & index) |
202 | 0 | { |
203 | 0 | return Data() + Offset(index); |
204 | 0 | } |
205 | | |
206 | | SIMD_INLINE const Type * Data(const Test::Index & index) const |
207 | 0 | { |
208 | 0 | return Data() + Offset(index); |
209 | 0 | } |
210 | | |
211 | | SIMD_INLINE Type * Data(std::initializer_list<size_t> index) |
212 | | { |
213 | | return Data() + Offset(index); |
214 | | } |
215 | | |
216 | | SIMD_INLINE const Type * Data(std::initializer_list<size_t> index) const |
217 | | { |
218 | | return Data() + Offset(index); |
219 | | } |
220 | | |
221 | | SIMD_INLINE size_t Batch() const |
222 | | { |
223 | | assert(_shape.size() == 4 && (_format == SimdTensorFormatNchw || _format == SimdTensorFormatNhwc)); |
224 | | return _shape[0]; |
225 | | } |
226 | | |
227 | | SIMD_INLINE size_t Channels() const |
228 | | { |
229 | | assert(_shape.size() == 4 && (_format == SimdTensorFormatNchw || _format == SimdTensorFormatNhwc)); |
230 | | return _format == SimdTensorFormatNchw ? _shape[1] : _shape[3]; |
231 | | } |
232 | | |
233 | | SIMD_INLINE size_t Height() const |
234 | | { |
235 | | assert(_shape.size() == 4 && (_format == SimdTensorFormatNchw || _format == SimdTensorFormatNhwc)); |
236 | | return _format == SimdTensorFormatNchw ? _shape[2] : _shape[1]; |
237 | | } |
238 | | |
239 | | SIMD_INLINE size_t Width() const |
240 | | { |
241 | | assert(_shape.size() == 4 && (_format == SimdTensorFormatNchw || _format == SimdTensorFormatNhwc)); |
242 | | return _format == SimdTensorFormatNchw ? _shape[3] : _shape[2]; |
243 | | } |
244 | | |
245 | | void DebugPrint(std::ostream & os, const String & name, size_t first = 5, size_t last = 2) const |
246 | | { |
247 | | os << name << " { "; |
248 | | for (size_t i = 0; i < _shape.size(); ++i) |
249 | | os << _shape[i] << " "; |
250 | | os << "} " << std::endl; |
251 | | |
252 | | if (_size == 0) |
253 | | return; |
254 | | |
255 | | size_t n = _shape.size(); |
256 | | Test::Shape firsts(n), lasts(n), index(n, 0); |
257 | | Strings separators(n); |
258 | | for (ptrdiff_t i = n - 1; i >= 0; --i) |
259 | | { |
260 | | if (i == n - 1) |
261 | | { |
262 | | firsts[i] = first; |
263 | | lasts[i] = last; |
264 | | separators[i] = "\t"; |
265 | | } |
266 | | else |
267 | | { |
268 | | firsts[i] = std::max<size_t>(firsts[i + 1] - 1, 1); |
269 | | lasts[i] = std::max<size_t>(lasts[i + 1] - 1, 1); |
270 | | separators[i] = separators[i + 1] + "\n"; |
271 | | } |
272 | | } |
273 | | DebugPrint(os, firsts, lasts, separators, index, 0); |
274 | | if (n == 1 || n == 0) |
275 | | os << "\n"; |
276 | | } |
277 | | |
278 | | private: |
279 | | |
280 | | void DebugPrint(std::ostream & os, const Test::Shape & firsts, const Test::Shape & lasts, const Strings & separators, Test::Shape index, size_t order) const |
281 | | { |
282 | | if (order == _shape.size()) |
283 | | { |
284 | | std::cout << std::fixed << std::setprecision(4); |
285 | | os << *Data(index); |
286 | | return; |
287 | | } |
288 | | if (firsts[order] + lasts[order] < _shape[order]) |
289 | | { |
290 | | size_t lo = firsts[order], hi = _shape[order] - lasts[order]; |
291 | | for (index[order] = 0; index[order] < lo; ++index[order]) |
292 | | { |
293 | | DebugPrint(os, firsts, lasts, separators, index, order + 1); |
294 | | os << separators[order]; |
295 | | } |
296 | | os << "..." << separators[order]; |
297 | | for (index[order] = hi; index[order] < _shape[order]; ++index[order]) |
298 | | { |
299 | | DebugPrint(os, firsts, lasts, separators, index, order + 1); |
300 | | os << separators[order]; |
301 | | } |
302 | | } |
303 | | else |
304 | | { |
305 | | for (index[order] = 0; index[order] < _shape[order]; ++index[order]) |
306 | | { |
307 | | DebugPrint(os, firsts, lasts, separators, index, order + 1); |
308 | | os << separators[order]; |
309 | | } |
310 | | } |
311 | | } |
312 | | |
313 | | SIMD_INLINE void Resize(const Type & value) |
314 | | { |
315 | | _size = Size(0, _shape.size()); |
316 | | _data.resize(_size, value); |
317 | | SetDebugPtr(); |
318 | | } |
319 | | |
320 | | SIMD_INLINE void Extend(const Type& value) |
321 | | { |
322 | | _size = Size(0, _shape.size()); |
323 | | if (_size > _data.size()) |
324 | | _data.resize(_size, value); |
325 | | SetDebugPtr(); |
326 | | } |
327 | | |
328 | | #if defined(_DEBUG) && defined(_MSC_VER) |
329 | | const Type * _ptr; |
330 | | |
331 | | SIMD_INLINE void SetDebugPtr() |
332 | | { |
333 | | _ptr = _data.data(); |
334 | | } |
335 | | #else |
336 | | SIMD_INLINE void SetDebugPtr() |
337 | | { |
338 | | } |
339 | | #endif |
340 | | |
341 | | typedef std::vector<Type, Simd::Allocator<Type>> Vector; |
342 | | |
343 | | SimdTensorFormatType _format; |
344 | | Test::Shape _shape; |
345 | | size_t _size; |
346 | | Vector _data; |
347 | | }; |
348 | | |
349 | | typedef Tensor<float> Tensor32f; |
350 | | typedef Tensor<uint8_t> Tensor8u; |
351 | | typedef Tensor<uint16_t> Tensor16u; |
352 | | |
353 | | //------------------------------------------------------------------------------------------------- |
354 | | |
355 | | inline Shape Shp() |
356 | 0 | { |
357 | 0 | return Shape(); |
358 | 0 | } |
359 | | |
360 | | inline Shape Shp(size_t axis0) |
361 | 0 | { |
362 | 0 | return Shape({ axis0 }); |
363 | 0 | } |
364 | | |
365 | | inline Shape Shp(size_t axis0, size_t axis1) |
366 | 0 | { |
367 | 0 | return Shape({ axis0, axis1 }); |
368 | 0 | } |
369 | | |
370 | | inline Shape Shp(size_t axis0, size_t axis1, size_t axis2) |
371 | 0 | { |
372 | 0 | return Shape({ axis0, axis1, axis2 }); |
373 | 0 | } |
374 | | |
375 | | inline Shape Shp(size_t axis0, size_t axis1, size_t axis2, size_t axis3) |
376 | 0 | { |
377 | 0 | return Shape({ axis0, axis1, axis2, axis3 }); |
378 | 0 | } |
379 | | |
380 | | inline Shape Shp(size_t axis0, size_t axis1, size_t axis2, size_t axis3, size_t axis4) |
381 | 0 | { |
382 | 0 | return Shape({ axis0, axis1, axis2, axis3, axis4 }); |
383 | 0 | } |
384 | | |
385 | | inline Shape Shp(size_t axis0, size_t axis1, size_t axis2, size_t axis3, size_t axis4, size_t axis5) |
386 | 0 | { |
387 | 0 | return Shape({ axis0, axis1, axis2, axis3, axis4, axis5 }); |
388 | 0 | } |
389 | | |
390 | | //------------------------------------------------------------------------------------------------- |
391 | | |
392 | | template<class T> inline void Copy(const Tensor<T> & src, Tensor<T> & dst) |
393 | | { |
394 | | assert(src.Size() == dst.Size()); |
395 | | memcpy(dst.Data(), src.Data(), src.Size() * sizeof(T)); |
396 | | } |
397 | | |
398 | | template<class T> inline void Fill(Tensor<T>& tensor, T value) |
399 | | { |
400 | | for (size_t i = 0; i < tensor.Size(); ++i) |
401 | | tensor.Data()[i] = value; |
402 | | } |
403 | | |
404 | | //------------------------------------------------------------------------------------------------- |
405 | | |
406 | | inline void Compare(const Tensor32f & a, const Tensor32f & b, float differenceMax, bool printError, int errorCountMax, DifferenceType differenceType, const String & description, |
407 | | Shape index, size_t order, int & errorCount, std::stringstream & message) |
408 | 0 | { |
409 | 0 | if (order == a.Count()) |
410 | 0 | { |
411 | 0 | float _a = *a.Data(index); |
412 | 0 | float _b = *b.Data(index); |
413 | 0 | float absolute = ::fabs(_a - _b); |
414 | 0 | float relative = ::fabs(_a - _b) / Simd::Max(::fabs(_a), ::fabs(_b)); |
415 | 0 | bool aNan = _a != _a; |
416 | 0 | bool bNan = _b != _b; |
417 | 0 | bool error = false; |
418 | 0 | switch (differenceType) |
419 | 0 | { |
420 | 0 | case DifferenceAbsolute: error = absolute > differenceMax || aNan || bNan; break; |
421 | 0 | case DifferenceRelative: error = relative > differenceMax; break; |
422 | 0 | case DifferenceBoth: error = (absolute > differenceMax && relative > differenceMax) || aNan || bNan; break; |
423 | 0 | case DifferenceAny: error = absolute > differenceMax || relative > differenceMax || aNan || bNan; break; |
424 | 0 | case DifferenceLogical: error = aNan != bNan || (aNan == false && _a != _b); break; |
425 | 0 | default: |
426 | 0 | assert(0); |
427 | 0 | } |
428 | 0 | if (error) |
429 | 0 | { |
430 | 0 | errorCount++; |
431 | 0 | if (printError) |
432 | 0 | { |
433 | 0 | if (errorCount == 1) |
434 | 0 | message << std::endl << "Fail comparison: " << description << std::endl; |
435 | 0 | message << "Error at ["; |
436 | 0 | for(size_t i = 0; i < index.size() - 1; ++i) |
437 | 0 | message << index[i] << ", "; |
438 | 0 | message << index[index.size() - 1] << "] : " << _a << " != " << _b << ";" |
439 | 0 | << " (absolute = " << absolute << ", relative = " << relative << ", threshold = " << differenceMax << ")!" << std::endl; |
440 | 0 | } |
441 | 0 | if (errorCount > errorCountMax) |
442 | 0 | { |
443 | 0 | if (printError) |
444 | 0 | message << "Stop comparison." << std::endl; |
445 | 0 | } |
446 | 0 | } |
447 | 0 | } |
448 | 0 | else |
449 | 0 | { |
450 | 0 | for (index[order] = 0; index[order] < a.Axis(order) && errorCount < errorCountMax; ++index[order]) |
451 | 0 | Compare(a, b, differenceMax, printError, errorCountMax, differenceType, description, index, order + 1, errorCount, message); |
452 | 0 | } |
453 | 0 | } |
454 | | |
455 | | inline bool Compare(const Tensor32f & a, const Tensor32f & b, float differenceMax, bool printError, int errorCountMax, DifferenceType differenceType, const String & description = "") |
456 | 0 | { |
457 | 0 | std::stringstream message; |
458 | 0 | message << std::fixed << std::setprecision(6); |
459 | 0 | int errorCount = 0; |
460 | 0 | if (memcmp(a.Data(), b.Data(), a.Size() * sizeof(float)) == 0) |
461 | 0 | return true; |
462 | 0 | Index index(a.Count(), 0); |
463 | 0 | Compare(a, b, differenceMax, printError, errorCountMax, differenceType, description, index, 0, errorCount, message); |
464 | 0 | if (printError && errorCount > 0) |
465 | 0 | TEST_LOG_SS(Error, message.str()); |
466 | 0 | return errorCount == 0; |
467 | 0 | } |
468 | | |
469 | | template<class T> inline void Compare(const Tensor<T>& a, const Tensor<T>& b, int differenceMax, bool printError, int errorCountMax, const String& description, |
470 | | Shape index, size_t order, int& errorCount, std::stringstream& message) |
471 | | { |
472 | | if (order == a.Count()) |
473 | | { |
474 | | int _a = *a.Data(index); |
475 | | int _b = *b.Data(index); |
476 | | int difference = Simd::Abs(_a - _b); |
477 | | bool error = difference > differenceMax; |
478 | | if (error) |
479 | | { |
480 | | errorCount++; |
481 | | if (printError) |
482 | | { |
483 | | if (errorCount == 1) |
484 | | message << std::endl << "Fail comparison: " << description << std::endl; |
485 | | message << "Error at ["; |
486 | | for (size_t i = 0; i < index.size() - 1; ++i) |
487 | | message << index[i] << ", "; |
488 | | message << index[index.size() - 1] << "] : " << _a << " != " << _b << ";" |
489 | | << " (difference = " << difference << ")!" << std::endl; |
490 | | } |
491 | | if (errorCount > errorCountMax) |
492 | | { |
493 | | if (printError) |
494 | | message << "Stop comparison." << std::endl; |
495 | | } |
496 | | } |
497 | | } |
498 | | else |
499 | | { |
500 | | for (index[order] = 0; index[order] < a.Axis(order) && errorCount < errorCountMax; ++index[order]) |
501 | | Compare(a, b, differenceMax, printError, errorCountMax, description, index, order + 1, errorCount, message); |
502 | | } |
503 | | } |
504 | | |
505 | | template<class T> inline bool Compare(const Tensor<T>& a, const Tensor<T>& b, int differenceMax, bool printError, int errorCountMax, const String& description = "") |
506 | | { |
507 | | std::stringstream message; |
508 | | int errorCount = 0; |
509 | | Index index(a.Count(), 0); |
510 | | Compare(a, b, differenceMax, printError, errorCountMax, description, index, 0, errorCount, message); |
511 | | if (printError && errorCount > 0) |
512 | | TEST_LOG_SS(Error, message.str()); |
513 | | return errorCount == 0; |
514 | | } |
515 | | |
516 | | //------------------------------------------------------------------------------------------------- |
517 | | |
518 | | inline void FillDebug(Tensor32f & dst, Shape index, size_t order) |
519 | 0 | { |
520 | 0 | if (order == dst.Count()) |
521 | 0 | { |
522 | 0 | float value = 0.0f; |
523 | 0 | for (int i = (int)index.size() - 1, n = 1; i >= 0; --i, n *= 10) |
524 | 0 | value += float(index[i]%10*n); |
525 | 0 | *dst.Data(index) = value; |
526 | 0 | } |
527 | 0 | else |
528 | 0 | { |
529 | 0 | for (index[order] = 0; index[order] < dst.Axis(order); ++index[order]) |
530 | 0 | FillDebug(dst, index, order + 1); |
531 | 0 | } |
532 | 0 | } |
533 | | |
534 | | inline void FillDebug(Tensor32f & dst) |
535 | 0 | { |
536 | 0 | Index index(dst.Count(), 0); |
537 | 0 | FillDebug(dst, index, 0); |
538 | 0 | } |
539 | | |
540 | | //------------------------------------------------------------------------------------------------- |
541 | | |
542 | | inline String ToString(SimdTensorFormatType format) |
543 | 0 | { |
544 | 0 | switch (format) |
545 | 0 | { |
546 | 0 | case SimdTensorFormatUnknown: return "Unknown"; |
547 | 0 | case SimdTensorFormatNchw: return "Nchw"; |
548 | 0 | case SimdTensorFormatNhwc: return "Nhwc"; |
549 | 0 | default: assert(0); return "Assert"; |
550 | 0 | } |
551 | 0 | } |
552 | | |
553 | | inline String ToString(SimdTensorDataType data) |
554 | 0 | { |
555 | 0 | switch (data) |
556 | 0 | { |
557 | 0 | case SimdTensorDataUnknown: return "Unknown"; |
558 | 0 | case SimdTensorData32f: return "32f"; |
559 | 0 | case SimdTensorData32i: return "32i"; |
560 | 0 | case SimdTensorData8i: return "8i"; |
561 | 0 | case SimdTensorData8u: return "8u"; |
562 | 0 | case SimdTensorData64i: return "64i"; |
563 | 0 | case SimdTensorData64u: return "64u"; |
564 | 0 | case SimdTensorDataBool: return "Bool"; |
565 | 0 | case SimdTensorData16b: return "16b"; |
566 | 0 | case SimdTensorData16f: return "16f"; |
567 | 0 | default: assert(0); return "Assert"; |
568 | 0 | } |
569 | 0 | } |
570 | | |
571 | | inline String ToChar(SimdTensorDataType data) |
572 | 0 | { |
573 | 0 | switch (data) |
574 | 0 | { |
575 | 0 | case SimdTensorDataUnknown: return "?"; |
576 | 0 | case SimdTensorData32f: return "f"; |
577 | 0 | case SimdTensorData32i: return "i"; |
578 | 0 | case SimdTensorData8i: return "u"; |
579 | 0 | case SimdTensorData8u: return "u"; |
580 | 0 | case SimdTensorData64i: return "l"; |
581 | 0 | case SimdTensorData64u: return "l"; |
582 | 0 | case SimdTensorDataBool: return "~"; |
583 | 0 | case SimdTensorData16b: return "b"; |
584 | 0 | case SimdTensorData16f: return "h"; |
585 | 0 | default: assert(0); return "Assert"; |
586 | 0 | } |
587 | 0 | } |
588 | | |
589 | | //------------------------------------------------------------------------------------------------- |
590 | | |
591 | | inline Shape ToShape(size_t batch, size_t channels, size_t height, size_t width, SimdTensorFormatType format) |
592 | 0 | { |
593 | 0 | switch (format) |
594 | 0 | { |
595 | 0 | case SimdTensorFormatNchw: return Shp(batch, channels, height, width); |
596 | 0 | case SimdTensorFormatNhwc: return Shp(batch, height, width, channels); |
597 | 0 | default: assert(0); return Shape(0); |
598 | 0 | } |
599 | 0 | } |
600 | | |
601 | | inline Shape ToShape(size_t channels, size_t height, size_t width, SimdTensorFormatType format) |
602 | 0 | { |
603 | 0 | switch (format) |
604 | 0 | { |
605 | 0 | case SimdTensorFormatNchw: return Shp(channels, height, width); |
606 | 0 | case SimdTensorFormatNhwc: return Shp(height, width, channels); |
607 | 0 | default: assert(0); return Shape(0); |
608 | 0 | } |
609 | 0 | } |
610 | | |
611 | | inline Shape ToShape(size_t channels, size_t spatial, SimdTensorFormatType format) |
612 | 0 | { |
613 | 0 | switch (format) |
614 | 0 | { |
615 | 0 | case SimdTensorFormatNchw: return Shp(channels, spatial); |
616 | 0 | case SimdTensorFormatNhwc: return Shp(spatial, channels); |
617 | 0 | default: assert(0); return Shape(0); |
618 | 0 | } |
619 | 0 | } |
620 | | |
621 | | inline Shape ToShape(size_t value) |
622 | 0 | { |
623 | 0 | return Shape(1, value); |
624 | 0 | } |
625 | | |
626 | | inline Shape ToShape(size_t channels, SimdTensorFormatType format) |
627 | 0 | { |
628 | 0 | return Shp(channels); |
629 | 0 | } |
630 | | |
631 | | //------------------------------------------------------------------------------------------------- |
632 | | |
633 | | inline bool IsCompatible(const Shape& a, const Shape& b) |
634 | 0 | { |
635 | 0 | for (size_t i = 0, n = std::max(a.size(), b.size()), a0 = n - a.size(), b0 = n - b.size(); i < n; ++i) |
636 | 0 | { |
637 | 0 | size_t ai = i < a0 ? 1 : a[i - a0]; |
638 | 0 | size_t bi = i < b0 ? 1 : b[i - b0]; |
639 | 0 | if (!(ai == bi || ai == 1 || bi == 1)) |
640 | 0 | return false; |
641 | 0 | } |
642 | 0 | return true; |
643 | 0 | } |
644 | | |
645 | | inline Shape OutputShape(const Shape& a, const Shape& b) |
646 | 0 | { |
647 | 0 | Shape d(std::max(a.size(), b.size()), 1); |
648 | 0 | for (size_t i = 0, n = d.size(), a0 = n - a.size(), b0 = n - b.size(); i < n; ++i) |
649 | 0 | { |
650 | 0 | size_t ai = i < a0 ? 1 : a[i - a0]; |
651 | 0 | size_t bi = i < b0 ? 1 : b[i - b0]; |
652 | 0 | d[i] = std::max(ai, bi); |
653 | 0 | } |
654 | 0 | return d; |
655 | 0 | } |
656 | | } |
657 | | |
658 | | #endif// __TestTensor_h__ |