Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "clang/Support/RISCVVIntrinsicUtils.h"
10
#include "llvm/ADT/ArrayRef.h"
11
#include "llvm/ADT/SmallSet.h"
12
#include "llvm/ADT/StringExtras.h"
13
#include "llvm/ADT/StringSet.h"
14
#include "llvm/ADT/Twine.h"
15
#include "llvm/Support/ErrorHandling.h"
16
#include "llvm/Support/raw_ostream.h"
17
#include <numeric>
18
#include <optional>
19
20
using namespace llvm;
21
22
namespace clang {
23
namespace RISCV {
24
25
// Frequently used descriptors, built once:
//   Mask   - a vector operand carrying the mask-vector modifier,
//   VL     - the size_t vector-length operand,
//   Vector - a plain vector operand.
const PrototypeDescriptor PrototypeDescriptor::Mask(BaseTypeModifier::Vector,
                                                    VectorTypeModifier::MaskVector);
const PrototypeDescriptor PrototypeDescriptor::VL(BaseTypeModifier::SizeT);
const PrototypeDescriptor PrototypeDescriptor::Vector(BaseTypeModifier::Vector);
31
32
//===----------------------------------------------------------------------===//
33
// Type implementation
34
//===----------------------------------------------------------------------===//
35
36
0
/// Build an LMUL from its log2 value; only -3..3 (LMUL 1/8 .. 8) is accepted.
LMULType::LMULType(int NewLog2LMUL) : Log2LMUL(NewLog2LMUL) {
  assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
}
41
42
0
std::string LMULType::str() const {
43
0
  if (Log2LMUL < 0)
44
0
    return "mf" + utostr(1ULL << (-Log2LMUL));
45
0
  return "m" + utostr(1ULL << Log2LMUL);
46
0
}
47
48
0
/// Return the vscale multiplier for this LMUL at the given element width,
/// i.e. (64 / SEW) * LMUL, computed in log2 space. Element widths other than
/// 8/16/32/64 fall back to a log2 result of 0 (multiplier 1). Returns
/// std::nullopt when the multiplier would be below 1.
VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  int Log2Scale = 0;
  if (ElementBitwidth == 8)
    Log2Scale = Log2LMUL + 3;
  else if (ElementBitwidth == 16)
    Log2Scale = Log2LMUL + 2;
  else if (ElementBitwidth == 32)
    Log2Scale = Log2LMUL + 1;
  else if (ElementBitwidth == 64)
    Log2Scale = Log2LMUL;

  // A multiplier smaller than 1 cannot be represented.
  if (Log2Scale < 0)
    return std::nullopt;
  return 1 << Log2Scale;
}
71
72
0
/// Multiply LMUL by 2^log2LMUL, which is an addition in log2 space.
void LMULType::MulLog2LMUL(int log2LMUL) {
  Log2LMUL = Log2LMUL + log2LMUL;
}
73
74
RVVType::RVVType(BasicType BT, int Log2LMUL,
                 const PrototypeDescriptor &prototype)
    : BT(BT), LMUL(Log2LMUL) {
  // Derive element kind/width from BT, then apply the prototype's modifiers.
  applyBasicType();
  applyModifier(prototype);

  // Only valid types get their string spellings computed.
  Valid = verifyType();
  if (!Valid)
    return;

  initBuiltinStr();
  initTypeStr();
  if (isVector())
    initClangBuiltinStr();
}
88
89
// clang-format off
90
// Boolean (mask) types are encoded by the ratio n = SEW/LMUL:
91
// SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
92
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
93
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
94
95
// type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
96
// --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
97
// i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
98
// i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
99
// i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
100
// i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
101
// double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
102
// float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
103
// half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
104
// bfloat16  | N/A    | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
105
// clang-format on
106
107
0
bool RVVType::verifyType() const {
108
0
  if (ScalarType == Invalid)
109
0
    return false;
110
0
  if (isScalar())
111
0
    return true;
112
0
  if (!Scale)
113
0
    return false;
114
0
  if (isFloat() && ElementBitwidth == 8)
115
0
    return false;
116
0
  if (isBFloat() && ElementBitwidth != 16)
117
0
    return false;
118
0
  if (IsTuple && (NF == 1 || NF > 8))
119
0
    return false;
120
0
  if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
121
0
    return false;
122
0
  unsigned V = *Scale;
123
0
  switch (ElementBitwidth) {
124
0
  case 1:
125
0
  case 8:
126
    // Check Scale is 1,2,4,8,16,32,64
127
0
    return (V <= 64 && isPowerOf2_32(V));
128
0
  case 16:
129
    // Check Scale is 1,2,4,8,16,32
130
0
    return (V <= 32 && isPowerOf2_32(V));
131
0
  case 32:
132
    // Check Scale is 1,2,4,8,16
133
0
    return (V <= 16 && isPowerOf2_32(V));
134
0
  case 64:
135
    // Check Scale is 1,2,4,8
136
0
    return (V <= 8 && isPowerOf2_32(V));
137
0
  }
138
0
  return false;
139
0
}
140
141
0
void RVVType::initBuiltinStr() {
142
0
  assert(isValid() && "RVVType is invalid");
143
0
  switch (ScalarType) {
144
0
  case ScalarTypeKind::Void:
145
0
    BuiltinStr = "v";
146
0
    return;
147
0
  case ScalarTypeKind::Size_t:
148
0
    BuiltinStr = "z";
149
0
    if (IsImmediate)
150
0
      BuiltinStr = "I" + BuiltinStr;
151
0
    if (IsPointer)
152
0
      BuiltinStr += "*";
153
0
    return;
154
0
  case ScalarTypeKind::Ptrdiff_t:
155
0
    BuiltinStr = "Y";
156
0
    return;
157
0
  case ScalarTypeKind::UnsignedLong:
158
0
    BuiltinStr = "ULi";
159
0
    return;
160
0
  case ScalarTypeKind::SignedLong:
161
0
    BuiltinStr = "Li";
162
0
    return;
163
0
  case ScalarTypeKind::Boolean:
164
0
    assert(ElementBitwidth == 1);
165
0
    BuiltinStr += "b";
166
0
    break;
167
0
  case ScalarTypeKind::SignedInteger:
168
0
  case ScalarTypeKind::UnsignedInteger:
169
0
    switch (ElementBitwidth) {
170
0
    case 8:
171
0
      BuiltinStr += "c";
172
0
      break;
173
0
    case 16:
174
0
      BuiltinStr += "s";
175
0
      break;
176
0
    case 32:
177
0
      BuiltinStr += "i";
178
0
      break;
179
0
    case 64:
180
0
      BuiltinStr += "Wi";
181
0
      break;
182
0
    default:
183
0
      llvm_unreachable("Unhandled ElementBitwidth!");
184
0
    }
185
0
    if (isSignedInteger())
186
0
      BuiltinStr = "S" + BuiltinStr;
187
0
    else
188
0
      BuiltinStr = "U" + BuiltinStr;
189
0
    break;
190
0
  case ScalarTypeKind::Float:
191
0
    switch (ElementBitwidth) {
192
0
    case 16:
193
0
      BuiltinStr += "x";
194
0
      break;
195
0
    case 32:
196
0
      BuiltinStr += "f";
197
0
      break;
198
0
    case 64:
199
0
      BuiltinStr += "d";
200
0
      break;
201
0
    default:
202
0
      llvm_unreachable("Unhandled ElementBitwidth!");
203
0
    }
204
0
    break;
205
0
  case ScalarTypeKind::BFloat:
206
0
    BuiltinStr += "y";
207
0
    break;
208
0
  default:
209
0
    llvm_unreachable("ScalarType is invalid!");
210
0
  }
211
0
  if (IsImmediate)
212
0
    BuiltinStr = "I" + BuiltinStr;
213
0
  if (isScalar()) {
214
0
    if (IsConstant)
215
0
      BuiltinStr += "C";
216
0
    if (IsPointer)
217
0
      BuiltinStr += "*";
218
0
    return;
219
0
  }
220
0
  BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
221
  // Pointer to vector types. Defined for segment load intrinsics.
222
  // segment load intrinsics have pointer type arguments to store the loaded
223
  // vector values.
224
0
  if (IsPointer)
225
0
    BuiltinStr += "*";
226
227
0
  if (IsTuple)
228
0
    BuiltinStr = "T" + utostr(NF) + BuiltinStr;
229
0
}
230
231
0
void RVVType::initClangBuiltinStr() {
232
0
  assert(isValid() && "RVVType is invalid");
233
0
  assert(isVector() && "Handle Vector type only");
234
235
0
  ClangBuiltinStr = "__rvv_";
236
0
  switch (ScalarType) {
237
0
  case ScalarTypeKind::Boolean:
238
0
    ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
239
0
    return;
240
0
  case ScalarTypeKind::Float:
241
0
    ClangBuiltinStr += "float";
242
0
    break;
243
0
  case ScalarTypeKind::BFloat:
244
0
    ClangBuiltinStr += "bfloat";
245
0
    break;
246
0
  case ScalarTypeKind::SignedInteger:
247
0
    ClangBuiltinStr += "int";
248
0
    break;
249
0
  case ScalarTypeKind::UnsignedInteger:
250
0
    ClangBuiltinStr += "uint";
251
0
    break;
252
0
  default:
253
0
    llvm_unreachable("ScalarTypeKind is invalid");
254
0
  }
255
0
  ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() +
256
0
                     (IsTuple ? "x" + utostr(NF) : "") + "_t";
257
0
}
258
259
0
void RVVType::initTypeStr() {
260
0
  assert(isValid() && "RVVType is invalid");
261
262
0
  if (IsConstant)
263
0
    Str += "const ";
264
265
0
  auto getTypeString = [&](StringRef TypeStr) {
266
0
    if (isScalar())
267
0
      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
268
0
    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() +
269
0
                 (IsTuple ? "x" + utostr(NF) : "") + "_t")
270
0
        .str();
271
0
  };
272
273
0
  switch (ScalarType) {
274
0
  case ScalarTypeKind::Void:
275
0
    Str = "void";
276
0
    return;
277
0
  case ScalarTypeKind::Size_t:
278
0
    Str = "size_t";
279
0
    if (IsPointer)
280
0
      Str += " *";
281
0
    return;
282
0
  case ScalarTypeKind::Ptrdiff_t:
283
0
    Str = "ptrdiff_t";
284
0
    return;
285
0
  case ScalarTypeKind::UnsignedLong:
286
0
    Str = "unsigned long";
287
0
    return;
288
0
  case ScalarTypeKind::SignedLong:
289
0
    Str = "long";
290
0
    return;
291
0
  case ScalarTypeKind::Boolean:
292
0
    if (isScalar())
293
0
      Str += "bool";
294
0
    else
295
      // Vector bool is special case, the formulate is
296
      // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
297
0
      Str += "vbool" + utostr(64 / *Scale) + "_t";
298
0
    break;
299
0
  case ScalarTypeKind::Float:
300
0
    if (isScalar()) {
301
0
      if (ElementBitwidth == 64)
302
0
        Str += "double";
303
0
      else if (ElementBitwidth == 32)
304
0
        Str += "float";
305
0
      else if (ElementBitwidth == 16)
306
0
        Str += "_Float16";
307
0
      else
308
0
        llvm_unreachable("Unhandled floating type.");
309
0
    } else
310
0
      Str += getTypeString("float");
311
0
    break;
312
0
  case ScalarTypeKind::BFloat:
313
0
    if (isScalar()) {
314
0
      if (ElementBitwidth == 16)
315
0
        Str += "__bf16";
316
0
      else
317
0
        llvm_unreachable("Unhandled floating type.");
318
0
    } else
319
0
      Str += getTypeString("bfloat");
320
0
    break;
321
0
  case ScalarTypeKind::SignedInteger:
322
0
    Str += getTypeString("int");
323
0
    break;
324
0
  case ScalarTypeKind::UnsignedInteger:
325
0
    Str += getTypeString("uint");
326
0
    break;
327
0
  default:
328
0
    llvm_unreachable("ScalarType is invalid!");
329
0
  }
330
0
  if (IsPointer)
331
0
    Str += " *";
332
0
}
333
334
0
void RVVType::initShortStr() {
335
0
  switch (ScalarType) {
336
0
  case ScalarTypeKind::Boolean:
337
0
    assert(isVector());
338
0
    ShortStr = "b" + utostr(64 / *Scale);
339
0
    return;
340
0
  case ScalarTypeKind::Float:
341
0
    ShortStr = "f" + utostr(ElementBitwidth);
342
0
    break;
343
0
  case ScalarTypeKind::BFloat:
344
0
    ShortStr = "bf" + utostr(ElementBitwidth);
345
0
    break;
346
0
  case ScalarTypeKind::SignedInteger:
347
0
    ShortStr = "i" + utostr(ElementBitwidth);
348
0
    break;
349
0
  case ScalarTypeKind::UnsignedInteger:
350
0
    ShortStr = "u" + utostr(ElementBitwidth);
351
0
    break;
352
0
  default:
353
0
    llvm_unreachable("Unhandled case!");
354
0
  }
355
0
  if (isVector())
356
0
    ShortStr += LMUL.str();
357
0
  if (isTuple())
358
0
    ShortStr += "x" + utostr(NF);
359
0
}
360
361
0
/// Map a field count NF (2..8) to the matching TupleN vector type modifier.
/// Relies on Tuple2..Tuple8 being consecutive enumerators.
static VectorTypeModifier getTupleVTM(unsigned NF) {
  assert(2 <= NF && NF <= 8 && "2 <= NF <= 8");
  const uint8_t First = static_cast<uint8_t>(VectorTypeModifier::Tuple2);
  const unsigned Offset = NF - 2;
  return static_cast<VectorTypeModifier>(First + Offset);
}
366
367
0
void RVVType::applyBasicType() {
368
0
  switch (BT) {
369
0
  case BasicType::Int8:
370
0
    ElementBitwidth = 8;
371
0
    ScalarType = ScalarTypeKind::SignedInteger;
372
0
    break;
373
0
  case BasicType::Int16:
374
0
    ElementBitwidth = 16;
375
0
    ScalarType = ScalarTypeKind::SignedInteger;
376
0
    break;
377
0
  case BasicType::Int32:
378
0
    ElementBitwidth = 32;
379
0
    ScalarType = ScalarTypeKind::SignedInteger;
380
0
    break;
381
0
  case BasicType::Int64:
382
0
    ElementBitwidth = 64;
383
0
    ScalarType = ScalarTypeKind::SignedInteger;
384
0
    break;
385
0
  case BasicType::Float16:
386
0
    ElementBitwidth = 16;
387
0
    ScalarType = ScalarTypeKind::Float;
388
0
    break;
389
0
  case BasicType::Float32:
390
0
    ElementBitwidth = 32;
391
0
    ScalarType = ScalarTypeKind::Float;
392
0
    break;
393
0
  case BasicType::Float64:
394
0
    ElementBitwidth = 64;
395
0
    ScalarType = ScalarTypeKind::Float;
396
0
    break;
397
0
  case BasicType::BFloat16:
398
0
    ElementBitwidth = 16;
399
0
    ScalarType = ScalarTypeKind::BFloat;
400
0
    break;
401
0
  default:
402
0
    llvm_unreachable("Unhandled type code!");
403
0
  }
404
0
  assert(ElementBitwidth != 0 && "Bad element bitwidth!");
405
0
}
406
407
std::optional<PrototypeDescriptor>
408
PrototypeDescriptor::parsePrototypeDescriptor(
409
0
    llvm::StringRef PrototypeDescriptorStr) {
410
0
  PrototypeDescriptor PD;
411
0
  BaseTypeModifier PT = BaseTypeModifier::Invalid;
412
0
  VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
413
414
0
  if (PrototypeDescriptorStr.empty())
415
0
    return PD;
416
417
  // Handle base type modifier
418
0
  auto PType = PrototypeDescriptorStr.back();
419
0
  switch (PType) {
420
0
  case 'e':
421
0
    PT = BaseTypeModifier::Scalar;
422
0
    break;
423
0
  case 'v':
424
0
    PT = BaseTypeModifier::Vector;
425
0
    break;
426
0
  case 'w':
427
0
    PT = BaseTypeModifier::Vector;
428
0
    VTM = VectorTypeModifier::Widening2XVector;
429
0
    break;
430
0
  case 'q':
431
0
    PT = BaseTypeModifier::Vector;
432
0
    VTM = VectorTypeModifier::Widening4XVector;
433
0
    break;
434
0
  case 'o':
435
0
    PT = BaseTypeModifier::Vector;
436
0
    VTM = VectorTypeModifier::Widening8XVector;
437
0
    break;
438
0
  case 'm':
439
0
    PT = BaseTypeModifier::Vector;
440
0
    VTM = VectorTypeModifier::MaskVector;
441
0
    break;
442
0
  case '0':
443
0
    PT = BaseTypeModifier::Void;
444
0
    break;
445
0
  case 'z':
446
0
    PT = BaseTypeModifier::SizeT;
447
0
    break;
448
0
  case 't':
449
0
    PT = BaseTypeModifier::Ptrdiff;
450
0
    break;
451
0
  case 'u':
452
0
    PT = BaseTypeModifier::UnsignedLong;
453
0
    break;
454
0
  case 'l':
455
0
    PT = BaseTypeModifier::SignedLong;
456
0
    break;
457
0
  case 'f':
458
0
    PT = BaseTypeModifier::Float32;
459
0
    break;
460
0
  default:
461
0
    llvm_unreachable("Illegal primitive type transformers!");
462
0
  }
463
0
  PD.PT = static_cast<uint8_t>(PT);
464
0
  PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
465
466
  // Compute the vector type transformers, it can only appear one time.
467
0
  if (PrototypeDescriptorStr.starts_with("(")) {
468
0
    assert(VTM == VectorTypeModifier::NoModifier &&
469
0
           "VectorTypeModifier should only have one modifier");
470
0
    size_t Idx = PrototypeDescriptorStr.find(')');
471
0
    assert(Idx != StringRef::npos);
472
0
    StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
473
0
    PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
474
0
    assert(!PrototypeDescriptorStr.contains('(') &&
475
0
           "Only allow one vector type modifier");
476
477
0
    auto ComplexTT = ComplexType.split(":");
478
0
    if (ComplexTT.first == "Log2EEW") {
479
0
      uint32_t Log2EEW;
480
0
      if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
481
0
        llvm_unreachable("Invalid Log2EEW value!");
482
0
        return std::nullopt;
483
0
      }
484
0
      switch (Log2EEW) {
485
0
      case 3:
486
0
        VTM = VectorTypeModifier::Log2EEW3;
487
0
        break;
488
0
      case 4:
489
0
        VTM = VectorTypeModifier::Log2EEW4;
490
0
        break;
491
0
      case 5:
492
0
        VTM = VectorTypeModifier::Log2EEW5;
493
0
        break;
494
0
      case 6:
495
0
        VTM = VectorTypeModifier::Log2EEW6;
496
0
        break;
497
0
      default:
498
0
        llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
499
0
        return std::nullopt;
500
0
      }
501
0
    } else if (ComplexTT.first == "FixedSEW") {
502
0
      uint32_t NewSEW;
503
0
      if (ComplexTT.second.getAsInteger(10, NewSEW)) {
504
0
        llvm_unreachable("Invalid FixedSEW value!");
505
0
        return std::nullopt;
506
0
      }
507
0
      switch (NewSEW) {
508
0
      case 8:
509
0
        VTM = VectorTypeModifier::FixedSEW8;
510
0
        break;
511
0
      case 16:
512
0
        VTM = VectorTypeModifier::FixedSEW16;
513
0
        break;
514
0
      case 32:
515
0
        VTM = VectorTypeModifier::FixedSEW32;
516
0
        break;
517
0
      case 64:
518
0
        VTM = VectorTypeModifier::FixedSEW64;
519
0
        break;
520
0
      default:
521
0
        llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
522
0
        return std::nullopt;
523
0
      }
524
0
    } else if (ComplexTT.first == "LFixedLog2LMUL") {
525
0
      int32_t Log2LMUL;
526
0
      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
527
0
        llvm_unreachable("Invalid LFixedLog2LMUL value!");
528
0
        return std::nullopt;
529
0
      }
530
0
      switch (Log2LMUL) {
531
0
      case -3:
532
0
        VTM = VectorTypeModifier::LFixedLog2LMULN3;
533
0
        break;
534
0
      case -2:
535
0
        VTM = VectorTypeModifier::LFixedLog2LMULN2;
536
0
        break;
537
0
      case -1:
538
0
        VTM = VectorTypeModifier::LFixedLog2LMULN1;
539
0
        break;
540
0
      case 0:
541
0
        VTM = VectorTypeModifier::LFixedLog2LMUL0;
542
0
        break;
543
0
      case 1:
544
0
        VTM = VectorTypeModifier::LFixedLog2LMUL1;
545
0
        break;
546
0
      case 2:
547
0
        VTM = VectorTypeModifier::LFixedLog2LMUL2;
548
0
        break;
549
0
      case 3:
550
0
        VTM = VectorTypeModifier::LFixedLog2LMUL3;
551
0
        break;
552
0
      default:
553
0
        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
554
0
        return std::nullopt;
555
0
      }
556
0
    } else if (ComplexTT.first == "SFixedLog2LMUL") {
557
0
      int32_t Log2LMUL;
558
0
      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
559
0
        llvm_unreachable("Invalid SFixedLog2LMUL value!");
560
0
        return std::nullopt;
561
0
      }
562
0
      switch (Log2LMUL) {
563
0
      case -3:
564
0
        VTM = VectorTypeModifier::SFixedLog2LMULN3;
565
0
        break;
566
0
      case -2:
567
0
        VTM = VectorTypeModifier::SFixedLog2LMULN2;
568
0
        break;
569
0
      case -1:
570
0
        VTM = VectorTypeModifier::SFixedLog2LMULN1;
571
0
        break;
572
0
      case 0:
573
0
        VTM = VectorTypeModifier::SFixedLog2LMUL0;
574
0
        break;
575
0
      case 1:
576
0
        VTM = VectorTypeModifier::SFixedLog2LMUL1;
577
0
        break;
578
0
      case 2:
579
0
        VTM = VectorTypeModifier::SFixedLog2LMUL2;
580
0
        break;
581
0
      case 3:
582
0
        VTM = VectorTypeModifier::SFixedLog2LMUL3;
583
0
        break;
584
0
      default:
585
0
        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
586
0
        return std::nullopt;
587
0
      }
588
589
0
    } else if (ComplexTT.first == "SEFixedLog2LMUL") {
590
0
      int32_t Log2LMUL;
591
0
      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
592
0
        llvm_unreachable("Invalid SEFixedLog2LMUL value!");
593
0
        return std::nullopt;
594
0
      }
595
0
      switch (Log2LMUL) {
596
0
      case -3:
597
0
        VTM = VectorTypeModifier::SEFixedLog2LMULN3;
598
0
        break;
599
0
      case -2:
600
0
        VTM = VectorTypeModifier::SEFixedLog2LMULN2;
601
0
        break;
602
0
      case -1:
603
0
        VTM = VectorTypeModifier::SEFixedLog2LMULN1;
604
0
        break;
605
0
      case 0:
606
0
        VTM = VectorTypeModifier::SEFixedLog2LMUL0;
607
0
        break;
608
0
      case 1:
609
0
        VTM = VectorTypeModifier::SEFixedLog2LMUL1;
610
0
        break;
611
0
      case 2:
612
0
        VTM = VectorTypeModifier::SEFixedLog2LMUL2;
613
0
        break;
614
0
      case 3:
615
0
        VTM = VectorTypeModifier::SEFixedLog2LMUL3;
616
0
        break;
617
0
      default:
618
0
        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
619
0
        return std::nullopt;
620
0
      }
621
0
    } else if (ComplexTT.first == "Tuple") {
622
0
      unsigned NF = 0;
623
0
      if (ComplexTT.second.getAsInteger(10, NF)) {
624
0
        llvm_unreachable("Invalid NF value!");
625
0
        return std::nullopt;
626
0
      }
627
0
      VTM = getTupleVTM(NF);
628
0
    } else {
629
0
      llvm_unreachable("Illegal complex type transformers!");
630
0
    }
631
0
  }
632
0
  PD.VTM = static_cast<uint8_t>(VTM);
633
634
  // Compute the remain type transformers
635
0
  TypeModifier TM = TypeModifier::NoModifier;
636
0
  for (char I : PrototypeDescriptorStr) {
637
0
    switch (I) {
638
0
    case 'P':
639
0
      if ((TM & TypeModifier::Const) == TypeModifier::Const)
640
0
        llvm_unreachable("'P' transformer cannot be used after 'C'");
641
0
      if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
642
0
        llvm_unreachable("'P' transformer cannot be used twice");
643
0
      TM |= TypeModifier::Pointer;
644
0
      break;
645
0
    case 'C':
646
0
      TM |= TypeModifier::Const;
647
0
      break;
648
0
    case 'K':
649
0
      TM |= TypeModifier::Immediate;
650
0
      break;
651
0
    case 'U':
652
0
      TM |= TypeModifier::UnsignedInteger;
653
0
      break;
654
0
    case 'I':
655
0
      TM |= TypeModifier::SignedInteger;
656
0
      break;
657
0
    case 'F':
658
0
      TM |= TypeModifier::Float;
659
0
      break;
660
0
    case 'S':
661
0
      TM |= TypeModifier::LMUL1;
662
0
      break;
663
0
    default:
664
0
      llvm_unreachable("Illegal non-primitive type transformer!");
665
0
    }
666
0
  }
667
0
  PD.TM = static_cast<uint8_t>(TM);
668
669
0
  return PD;
670
0
}
671
672
0
void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
673
  // Handle primitive type transformer
674
0
  switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
675
0
  case BaseTypeModifier::Scalar:
676
0
    Scale = 0;
677
0
    break;
678
0
  case BaseTypeModifier::Vector:
679
0
    Scale = LMUL.getScale(ElementBitwidth);
680
0
    break;
681
0
  case BaseTypeModifier::Void:
682
0
    ScalarType = ScalarTypeKind::Void;
683
0
    break;
684
0
  case BaseTypeModifier::SizeT:
685
0
    ScalarType = ScalarTypeKind::Size_t;
686
0
    break;
687
0
  case BaseTypeModifier::Ptrdiff:
688
0
    ScalarType = ScalarTypeKind::Ptrdiff_t;
689
0
    break;
690
0
  case BaseTypeModifier::UnsignedLong:
691
0
    ScalarType = ScalarTypeKind::UnsignedLong;
692
0
    break;
693
0
  case BaseTypeModifier::SignedLong:
694
0
    ScalarType = ScalarTypeKind::SignedLong;
695
0
    break;
696
0
  case BaseTypeModifier::Float32:
697
0
    ElementBitwidth = 32;
698
0
    ScalarType = ScalarTypeKind::Float;
699
0
    break;
700
0
  case BaseTypeModifier::Invalid:
701
0
    ScalarType = ScalarTypeKind::Invalid;
702
0
    return;
703
0
  }
704
705
0
  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
706
0
  case VectorTypeModifier::Widening2XVector:
707
0
    ElementBitwidth *= 2;
708
0
    LMUL.MulLog2LMUL(1);
709
0
    Scale = LMUL.getScale(ElementBitwidth);
710
0
    break;
711
0
  case VectorTypeModifier::Widening4XVector:
712
0
    ElementBitwidth *= 4;
713
0
    LMUL.MulLog2LMUL(2);
714
0
    Scale = LMUL.getScale(ElementBitwidth);
715
0
    break;
716
0
  case VectorTypeModifier::Widening8XVector:
717
0
    ElementBitwidth *= 8;
718
0
    LMUL.MulLog2LMUL(3);
719
0
    Scale = LMUL.getScale(ElementBitwidth);
720
0
    break;
721
0
  case VectorTypeModifier::MaskVector:
722
0
    ScalarType = ScalarTypeKind::Boolean;
723
0
    Scale = LMUL.getScale(ElementBitwidth);
724
0
    ElementBitwidth = 1;
725
0
    break;
726
0
  case VectorTypeModifier::Log2EEW3:
727
0
    applyLog2EEW(3);
728
0
    break;
729
0
  case VectorTypeModifier::Log2EEW4:
730
0
    applyLog2EEW(4);
731
0
    break;
732
0
  case VectorTypeModifier::Log2EEW5:
733
0
    applyLog2EEW(5);
734
0
    break;
735
0
  case VectorTypeModifier::Log2EEW6:
736
0
    applyLog2EEW(6);
737
0
    break;
738
0
  case VectorTypeModifier::FixedSEW8:
739
0
    applyFixedSEW(8);
740
0
    break;
741
0
  case VectorTypeModifier::FixedSEW16:
742
0
    applyFixedSEW(16);
743
0
    break;
744
0
  case VectorTypeModifier::FixedSEW32:
745
0
    applyFixedSEW(32);
746
0
    break;
747
0
  case VectorTypeModifier::FixedSEW64:
748
0
    applyFixedSEW(64);
749
0
    break;
750
0
  case VectorTypeModifier::LFixedLog2LMULN3:
751
0
    applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
752
0
    break;
753
0
  case VectorTypeModifier::LFixedLog2LMULN2:
754
0
    applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
755
0
    break;
756
0
  case VectorTypeModifier::LFixedLog2LMULN1:
757
0
    applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
758
0
    break;
759
0
  case VectorTypeModifier::LFixedLog2LMUL0:
760
0
    applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
761
0
    break;
762
0
  case VectorTypeModifier::LFixedLog2LMUL1:
763
0
    applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
764
0
    break;
765
0
  case VectorTypeModifier::LFixedLog2LMUL2:
766
0
    applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
767
0
    break;
768
0
  case VectorTypeModifier::LFixedLog2LMUL3:
769
0
    applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
770
0
    break;
771
0
  case VectorTypeModifier::SFixedLog2LMULN3:
772
0
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
773
0
    break;
774
0
  case VectorTypeModifier::SFixedLog2LMULN2:
775
0
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
776
0
    break;
777
0
  case VectorTypeModifier::SFixedLog2LMULN1:
778
0
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
779
0
    break;
780
0
  case VectorTypeModifier::SFixedLog2LMUL0:
781
0
    applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
782
0
    break;
783
0
  case VectorTypeModifier::SFixedLog2LMUL1:
784
0
    applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
785
0
    break;
786
0
  case VectorTypeModifier::SFixedLog2LMUL2:
787
0
    applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
788
0
    break;
789
0
  case VectorTypeModifier::SFixedLog2LMUL3:
790
0
    applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
791
0
    break;
792
0
  case VectorTypeModifier::SEFixedLog2LMULN3:
793
0
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual);
794
0
    break;
795
0
  case VectorTypeModifier::SEFixedLog2LMULN2:
796
0
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual);
797
0
    break;
798
0
  case VectorTypeModifier::SEFixedLog2LMULN1:
799
0
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual);
800
0
    break;
801
0
  case VectorTypeModifier::SEFixedLog2LMUL0:
802
0
    applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual);
803
0
    break;
804
0
  case VectorTypeModifier::SEFixedLog2LMUL1:
805
0
    applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual);
806
0
    break;
807
0
  case VectorTypeModifier::SEFixedLog2LMUL2:
808
0
    applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual);
809
0
    break;
810
0
  case VectorTypeModifier::SEFixedLog2LMUL3:
811
0
    applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual);
812
0
    break;
813
0
  case VectorTypeModifier::Tuple2:
814
0
  case VectorTypeModifier::Tuple3:
815
0
  case VectorTypeModifier::Tuple4:
816
0
  case VectorTypeModifier::Tuple5:
817
0
  case VectorTypeModifier::Tuple6:
818
0
  case VectorTypeModifier::Tuple7:
819
0
  case VectorTypeModifier::Tuple8: {
820
0
    IsTuple = true;
821
0
    NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
822
0
         static_cast<uint8_t>(VectorTypeModifier::Tuple2);
823
0
    break;
824
0
  }
825
0
  case VectorTypeModifier::NoModifier:
826
0
    break;
827
0
  }
828
829
  // Early return if the current type modifier is already invalid.
830
0
  if (ScalarType == Invalid)
831
0
    return;
832
833
0
  for (unsigned TypeModifierMaskShift = 0;
834
0
       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
835
0
       ++TypeModifierMaskShift) {
836
0
    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
837
0
    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
838
0
        TypeModifierMask)
839
0
      continue;
840
0
    switch (static_cast<TypeModifier>(TypeModifierMask)) {
841
0
    case TypeModifier::Pointer:
842
0
      IsPointer = true;
843
0
      break;
844
0
    case TypeModifier::Const:
845
0
      IsConstant = true;
846
0
      break;
847
0
    case TypeModifier::Immediate:
848
0
      IsImmediate = true;
849
0
      IsConstant = true;
850
0
      break;
851
0
    case TypeModifier::UnsignedInteger:
852
0
      ScalarType = ScalarTypeKind::UnsignedInteger;
853
0
      break;
854
0
    case TypeModifier::SignedInteger:
855
0
      ScalarType = ScalarTypeKind::SignedInteger;
856
0
      break;
857
0
    case TypeModifier::Float:
858
0
      ScalarType = ScalarTypeKind::Float;
859
0
      break;
860
0
    case TypeModifier::BFloat:
861
0
      ScalarType = ScalarTypeKind::BFloat;
862
0
      break;
863
0
    case TypeModifier::LMUL1:
864
0
      LMUL = LMULType(0);
865
      // Update ElementBitwidth need to update Scale too.
866
0
      Scale = LMUL.getScale(ElementBitwidth);
867
0
      break;
868
0
    default:
869
0
      llvm_unreachable("Unknown type modifier mask!");
870
0
    }
871
0
  }
872
0
}
873
874
0
void RVVType::applyLog2EEW(unsigned Log2EEW) {
875
  // update new elmul = (eew/sew) * lmul
876
0
  LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
877
  // update new eew
878
0
  ElementBitwidth = 1 << Log2EEW;
879
0
  ScalarType = ScalarTypeKind::SignedInteger;
880
0
  Scale = LMUL.getScale(ElementBitwidth);
881
0
}
882
883
0
void RVVType::applyFixedSEW(unsigned NewSEW) {
884
  // Set invalid type if src and dst SEW are same.
885
0
  if (ElementBitwidth == NewSEW) {
886
0
    ScalarType = ScalarTypeKind::Invalid;
887
0
    return;
888
0
  }
889
  // Update new SEW
890
0
  ElementBitwidth = NewSEW;
891
0
  Scale = LMUL.getScale(ElementBitwidth);
892
0
}
893
894
0
void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
895
0
  switch (Type) {
896
0
  case FixedLMULType::LargerThan:
897
0
    if (Log2LMUL <= LMUL.Log2LMUL) {
898
0
      ScalarType = ScalarTypeKind::Invalid;
899
0
      return;
900
0
    }
901
0
    break;
902
0
  case FixedLMULType::SmallerThan:
903
0
    if (Log2LMUL >= LMUL.Log2LMUL) {
904
0
      ScalarType = ScalarTypeKind::Invalid;
905
0
      return;
906
0
    }
907
0
    break;
908
0
  case FixedLMULType::SmallerOrEqual:
909
0
    if (Log2LMUL > LMUL.Log2LMUL) {
910
0
      ScalarType = ScalarTypeKind::Invalid;
911
0
      return;
912
0
    }
913
0
    break;
914
0
  }
915
916
  // Update new LMUL
917
0
  LMUL = LMULType(Log2LMUL);
918
0
  Scale = LMUL.getScale(ElementBitwidth);
919
0
}
920
921
std::optional<RVVTypes>
922
RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
923
0
                           ArrayRef<PrototypeDescriptor> Prototype) {
924
0
  RVVTypes Types;
925
0
  for (const PrototypeDescriptor &Proto : Prototype) {
926
0
    auto T = computeType(BT, Log2LMUL, Proto);
927
0
    if (!T)
928
0
      return std::nullopt;
929
    // Record legal type index
930
0
    Types.push_back(*T);
931
0
  }
932
0
  return Types;
933
0
}
934
935
// Compute the hash value of RVVType, used for cache the result of computeType.
936
static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
937
0
                                        PrototypeDescriptor Proto) {
938
  // Layout of hash value:
939
  // 0               8    16          24        32          40
940
  // | Log2LMUL + 3  | BT  | Proto.PT | Proto.TM | Proto.VTM |
941
0
  assert(Log2LMUL >= -3 && Log2LMUL <= 3);
942
0
  return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
943
0
         ((uint64_t)(Proto.PT & 0xff) << 16) |
944
0
         ((uint64_t)(Proto.TM & 0xff) << 24) |
945
0
         ((uint64_t)(Proto.VTM & 0xff) << 32);
946
0
}
947
948
std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
949
0
                                                    PrototypeDescriptor Proto) {
950
0
  uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
951
  // Search first
952
0
  auto It = LegalTypes.find(Idx);
953
0
  if (It != LegalTypes.end())
954
0
    return &(It->second);
955
956
0
  if (IllegalTypes.count(Idx))
957
0
    return std::nullopt;
958
959
  // Compute type and record the result.
960
0
  RVVType T(BT, Log2LMUL, Proto);
961
0
  if (T.isValid()) {
962
    // Record legal type index and value.
963
0
    std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
964
0
        InsertResult = LegalTypes.insert({Idx, T});
965
0
    return &(InsertResult.first->second);
966
0
  }
967
  // Record illegal type index.
968
0
  IllegalTypes.insert(Idx);
969
0
  return std::nullopt;
970
0
}
971
972
//===----------------------------------------------------------------------===//
973
// RVVIntrinsic implementation
974
//===----------------------------------------------------------------------===//
975
RVVIntrinsic::RVVIntrinsic(
    StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
    StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
    bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
    const std::vector<StringRef> &RequiredFeatures, unsigned NF,
    Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
    : IRName(IRName), IsMasked(IsMasked),
      HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
      SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
      ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {

  // Names. The builtin name is the raw record name. The overloaded name
  // defaults to the part of NewName before the first '_'. Suffixes are
  // appended before policy/prefix decoration.
  BuiltinName = NewName.str();
  Name = BuiltinName;
  OverloadedName = NewOverloadedName.empty() ? NewName.split("_").first.str()
                                             : NewOverloadedName.str();
  if (!Suffix.empty())
    Name += "_" + Suffix.str();
  if (!OverloadedSuffix.empty())
    OverloadedName += "_" + OverloadedSuffix.str();

  updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
                       PolicyAttrs, HasFRMRoundModeOp);

  // OutInTypes[0] is the result type; the remaining entries are the operands.
  OutputType = OutInTypes[0];
  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());

  // NewIntrinsicTypes are indices for the unmasked TA form. When a merge
  // operand is present (masked-off for masked, passthru for unmasked), NF
  // leading operands are inserted, so every non-negative index shifts by NF.
  IntrinsicTypes = NewIntrinsicTypes;
  const bool HasLeadingMergeOperands =
      IsMasked ? hasMaskedOffOperand() : hasPassthruOperand();
  if (HasLeadingMergeOperands) {
    for (auto &TypeIndex : IntrinsicTypes)
      if (TypeIndex >= 0)
        TypeIndex += NF;
  }
}
1018
1019
0
std::string RVVIntrinsic::getBuiltinTypeStr() const {
1020
0
  std::string S;
1021
0
  S += OutputType->getBuiltinStr();
1022
0
  for (const auto &T : InputTypes) {
1023
0
    S += T->getBuiltinStr();
1024
0
  }
1025
0
  return S;
1026
0
}
1027
1028
std::string RVVIntrinsic::getSuffixStr(
1029
    RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
1030
0
    llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
1031
0
  SmallVector<std::string> SuffixStrs;
1032
0
  for (auto PD : PrototypeDescriptors) {
1033
0
    auto T = TypeCache.computeType(Type, Log2LMUL, PD);
1034
0
    SuffixStrs.push_back((*T)->getShortStr());
1035
0
  }
1036
0
  return join(SuffixStrs, "_");
1037
0
}
1038
1039
/// Expands a record prototype (result first, then operands) into the full
/// builtin operand list:
///  - masked: optionally the masked-off operand(s), then the mask operand;
///  - unmasked: for TU policy with a passthru scheme, the passthru operand(s);
///  - finally the VL operand when HasVL is set.
/// For NF > 1 non-tuple forms the merge operands are NF copies of the first
/// pointer operand's type with the pointer modifier removed. Tuple forms use a
/// single tuple vector operand instead.
llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
    llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, unsigned NF,
    PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
  SmallVector<PrototypeDescriptor> Result(Prototype.begin(), Prototype.end());
  const bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;

  // Copy of a descriptor with the pointer modifier cleared, i.e. the value
  // type an address operand points to.
  auto StripPointer = [](PrototypeDescriptor PD) {
    PD.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
    return PD;
  };

  // Tuple-of-NF vector operand whose element modifiers come from PtrOperand
  // with the pointer modifier cleared.
  auto TupleOperandFor = [NF](const PrototypeDescriptor &PtrOperand) {
    return PrototypeDescriptor(
        static_cast<uint8_t>(BaseTypeModifier::Vector),
        static_cast<uint8_t>(getTupleVTM(NF)),
        PtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
  };

  if (IsMasked) {
    // Masked-off operand(s) are skipped for the TAMA policy.
    if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
      if (NF == 1)
        Result.insert(Result.begin() + 1, Result[0]);
      else if (NF > 1 && IsTuple)
        Result.insert(Result.begin() + 1, TupleOperandFor(Prototype[1]));
      else if (NF > 1)
        // (void, op0 address, op1 address, ...) becomes
        // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
        Result.insert(Result.begin() + NF + 1, NF, StripPointer(Result[1]));
    }

    // The mask goes right after the NF addresses for non-tuple segment forms
    // with masked-off operands; otherwise it becomes the first input operand.
    auto MaskPos = Result.begin() + 1;
    if (HasMaskedOffOperand && NF > 1 && !IsTuple)
      MaskPos = Result.begin() + NF + 1;
    Result.insert(MaskPos, PrototypeDescriptor::Mask);
  } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
    if (NF == 1)
      Result.insert(Result.begin(), Result[0]);
    else if (IsTuple)
      Result.insert(Result.begin(), TupleOperandFor(Prototype[0]));
    else
      // Segment loads: (void, op0 address, op1 address, ...) becomes
      // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
      Result.insert(Result.begin() + NF + 1, NF, StripPointer(Prototype[1]));
  }

  if (HasVL)
    Result.push_back(PrototypeDescriptor::VL);

  return Result;
}
1118
1119
0
/// Unmasked intrinsics get exactly one explicit policy variant: tail
/// undisturbed (TU).
llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
  llvm::SmallVector<Policy> Policies;
  Policies.push_back(Policy(Policy::PolicyType::Undisturbed));
  return Policies;
}
1122
1123
/// Lists the explicit policy variants generated for a masked intrinsic.
/// Policies are written Policy(tail, mask), consistent with the TUM/TUMU/MU
/// labels below and the `_tum`/`_tumu`/`_mu` suffixes from
/// updateNamesAndPolicy.
///  - tail and mask policy: TUM, TUMU, MU
///  - tail policy only:     TUM (tail undisturbed, mask agnostic)
///  - mask policy only:     MU  (tail agnostic, mask undisturbed)
/// Having neither is a programming error.
llvm::SmallVector<Policy>
RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
                                         bool HasMaskPolicy) {
  using PT = Policy::PolicyType;
  const Policy TUM(PT::Undisturbed, PT::Agnostic);
  const Policy TUMU(PT::Undisturbed, PT::Undisturbed);
  const Policy MU(PT::Agnostic, PT::Undisturbed);

  if (HasTailPolicy && HasMaskPolicy)
    return {TUM, TUMU, MU};
  if (HasTailPolicy)
    return {TUM}; // Tail-undisturbed, mask-agnostic: the `_tum` variant.
  if (HasMaskPolicy)
    return {MU};
  llvm_unreachable("An RVV instruction should not be without both tail policy "
                   "and mask policy");
}
1142
1143
void RVVIntrinsic::updateNamesAndPolicy(
1144
    bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
1145
0
    std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
1146
1147
0
  auto appendPolicySuffix = [&](const std::string &suffix) {
1148
0
    Name += suffix;
1149
0
    BuiltinName += suffix;
1150
0
    OverloadedName += suffix;
1151
0
  };
1152
1153
  // This follows the naming guideline under riscv-c-api-doc to add the
1154
  // `__riscv_` suffix for all RVV intrinsics.
1155
0
  Name = "__riscv_" + Name;
1156
0
  OverloadedName = "__riscv_" + OverloadedName;
1157
1158
0
  if (HasFRMRoundModeOp) {
1159
0
    Name += "_rm";
1160
0
    BuiltinName += "_rm";
1161
0
  }
1162
1163
0
  if (IsMasked) {
1164
0
    if (PolicyAttrs.isTUMUPolicy())
1165
0
      appendPolicySuffix("_tumu");
1166
0
    else if (PolicyAttrs.isTUMAPolicy())
1167
0
      appendPolicySuffix("_tum");
1168
0
    else if (PolicyAttrs.isTAMUPolicy())
1169
0
      appendPolicySuffix("_mu");
1170
0
    else if (PolicyAttrs.isTAMAPolicy()) {
1171
0
      Name += "_m";
1172
0
      BuiltinName += "_m";
1173
0
    } else
1174
0
      llvm_unreachable("Unhandled policy condition");
1175
0
  } else {
1176
0
    if (PolicyAttrs.isTUPolicy())
1177
0
      appendPolicySuffix("_tu");
1178
0
    else if (PolicyAttrs.isTAPolicy()) // no suffix needed
1179
0
      return;
1180
0
    else
1181
0
      llvm_unreachable("Unhandled policy condition");
1182
0
  }
1183
0
}
1184
1185
0
/// Splits a packed prototype string into descriptors. Each descriptor ends at
/// its primary type character (one of "evwqom0ztulf"). A descriptor that
/// starts with '(' is searched for its primary only after the first ')',
/// because the parenthesized part may itself contain primary characters.
SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
  SmallVector<PrototypeDescriptor> Descriptors;
  const StringRef Primaries("evwqom0ztulf");
  for (StringRef Rest = Prototypes; !Rest.empty();) {
    const size_t SearchFrom =
        Rest.front() == '(' ? Rest.find_first_of(')') : 0;
    const size_t End = Rest.find_first_of(Primaries, SearchFrom);
    assert(End != StringRef::npos);
    auto PD =
        PrototypeDescriptor::parsePrototypeDescriptor(Rest.slice(0, End + 1));
    if (!PD)
      llvm_unreachable("Error during parsing prototype.");
    Descriptors.push_back(*PD);
    Rest = Rest.drop_front(End + 1);
  }
  return Descriptors;
}
1205
1206
0
/// Writes Record as one `{...},` initializer line. Every field, including
/// the last, is followed by a comma.
raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
  OS << "{" << "\"" << Record.Name << "\",";

  // A null or empty overloaded name is emitted as a `nullptr` literal.
  const bool HasOverloadedName = Record.OverloadedName != nullptr &&
                                 !StringRef(Record.OverloadedName).empty();
  if (HasOverloadedName)
    OS << "\"" << Record.OverloadedName << "\",";
  else
    OS << "nullptr,";

  OS << Record.PrototypeIndex << "," << Record.SuffixIndex << ","
     << Record.OverloadedSuffixIndex << ",";

  // The remaining fields are printed through an int conversion, except
  // RequiredExtensions, which is streamed as-is.
  auto EmitInt = [&OS](int Value) { OS << Value << ","; };
  EmitInt(static_cast<int>(Record.PrototypeLength));
  EmitInt(static_cast<int>(Record.SuffixLength));
  EmitInt(static_cast<int>(Record.OverloadedSuffixSize));
  OS << Record.RequiredExtensions << ",";
  EmitInt(static_cast<int>(Record.TypeRangeMask));
  EmitInt(static_cast<int>(Record.Log2LMULMask));
  EmitInt(static_cast<int>(Record.NF));
  EmitInt(static_cast<int>(Record.HasMasked));
  EmitInt(static_cast<int>(Record.HasVL));
  EmitInt(static_cast<int>(Record.HasMaskedOffOperand));
  EmitInt(static_cast<int>(Record.HasTailPolicy));
  EmitInt(static_cast<int>(Record.HasMaskPolicy));
  EmitInt(static_cast<int>(Record.HasFRMRoundModeOp));
  EmitInt(static_cast<int>(Record.IsTuple));
  EmitInt(static_cast<int>(Record.UnMaskedPolicyScheme));
  EmitInt(static_cast<int>(Record.MaskedPolicyScheme));

  OS << "},\n";
  return OS;
}
1236
1237
} // end namespace RISCV
1238
} // end namespace clang