/src/llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "clang/Support/RISCVVIntrinsicUtils.h" |
10 | | #include "llvm/ADT/ArrayRef.h" |
11 | | #include "llvm/ADT/SmallSet.h" |
12 | | #include "llvm/ADT/StringExtras.h" |
13 | | #include "llvm/ADT/StringSet.h" |
14 | | #include "llvm/ADT/Twine.h" |
15 | | #include "llvm/Support/ErrorHandling.h" |
16 | | #include "llvm/Support/raw_ostream.h" |
17 | | #include <numeric> |
18 | | #include <optional> |
19 | | |
20 | | using namespace llvm; |
21 | | |
22 | | namespace clang { |
23 | | namespace RISCV { |
24 | | |
25 | | const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( |
26 | | BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); |
27 | | const PrototypeDescriptor PrototypeDescriptor::VL = |
28 | | PrototypeDescriptor(BaseTypeModifier::SizeT); |
29 | | const PrototypeDescriptor PrototypeDescriptor::Vector = |
30 | | PrototypeDescriptor(BaseTypeModifier::Vector); |
31 | | |
32 | | //===----------------------------------------------------------------------===// |
33 | | // Type implementation |
34 | | //===----------------------------------------------------------------------===// |
35 | | |
36 | 0 | LMULType::LMULType(int NewLog2LMUL) { |
37 | | // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 |
38 | 0 | assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); |
39 | 0 | Log2LMUL = NewLog2LMUL; |
40 | 0 | } |
41 | | |
42 | 0 | std::string LMULType::str() const { |
43 | 0 | if (Log2LMUL < 0) |
44 | 0 | return "mf" + utostr(1ULL << (-Log2LMUL)); |
45 | 0 | return "m" + utostr(1ULL << Log2LMUL); |
46 | 0 | } |
47 | | |
48 | 0 | VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { |
49 | 0 | int Log2ScaleResult = 0; |
50 | 0 | switch (ElementBitwidth) { |
51 | 0 | default: |
52 | 0 | break; |
53 | 0 | case 8: |
54 | 0 | Log2ScaleResult = Log2LMUL + 3; |
55 | 0 | break; |
56 | 0 | case 16: |
57 | 0 | Log2ScaleResult = Log2LMUL + 2; |
58 | 0 | break; |
59 | 0 | case 32: |
60 | 0 | Log2ScaleResult = Log2LMUL + 1; |
61 | 0 | break; |
62 | 0 | case 64: |
63 | 0 | Log2ScaleResult = Log2LMUL; |
64 | 0 | break; |
65 | 0 | } |
66 | | // Illegal vscale result would be less than 1 |
67 | 0 | if (Log2ScaleResult < 0) |
68 | 0 | return std::nullopt; |
69 | 0 | return 1 << Log2ScaleResult; |
70 | 0 | } |
71 | | |
72 | 0 | void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } |
73 | | |
74 | | RVVType::RVVType(BasicType BT, int Log2LMUL, |
75 | | const PrototypeDescriptor &prototype) |
76 | 0 | : BT(BT), LMUL(LMULType(Log2LMUL)) { |
77 | 0 | applyBasicType(); |
78 | 0 | applyModifier(prototype); |
79 | 0 | Valid = verifyType(); |
80 | 0 | if (Valid) { |
81 | 0 | initBuiltinStr(); |
82 | 0 | initTypeStr(); |
83 | 0 | if (isVector()) { |
84 | 0 | initClangBuiltinStr(); |
85 | 0 | } |
86 | 0 | } |
87 | 0 | } |
88 | | |
89 | | // clang-format off |
90 | | // Boolean types are encoded by the ratio n = SEW/LMUL |
91 | | // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 |
92 | | // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t |
93 | | // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 |
94 | | |
95 | | // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 |
96 | | // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- |
97 | | // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 |
98 | | // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 |
99 | | // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 |
100 | | // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 |
101 | | // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 |
102 | | // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 |
103 | | // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 |
104 | | // bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16 |
105 | | // clang-format on |
106 | | |
107 | 0 | bool RVVType::verifyType() const { |
108 | 0 | if (ScalarType == Invalid) |
109 | 0 | return false; |
110 | 0 | if (isScalar()) |
111 | 0 | return true; |
112 | 0 | if (!Scale) |
113 | 0 | return false; |
114 | 0 | if (isFloat() && ElementBitwidth == 8) |
115 | 0 | return false; |
116 | 0 | if (isBFloat() && ElementBitwidth != 16) |
117 | 0 | return false; |
118 | 0 | if (IsTuple && (NF == 1 || NF > 8)) |
119 | 0 | return false; |
120 | 0 | if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8) |
121 | 0 | return false; |
122 | 0 | unsigned V = *Scale; |
123 | 0 | switch (ElementBitwidth) { |
124 | 0 | case 1: |
125 | 0 | case 8: |
126 | | // Check Scale is 1,2,4,8,16,32,64 |
127 | 0 | return (V <= 64 && isPowerOf2_32(V)); |
128 | 0 | case 16: |
129 | | // Check Scale is 1,2,4,8,16,32 |
130 | 0 | return (V <= 32 && isPowerOf2_32(V)); |
131 | 0 | case 32: |
132 | | // Check Scale is 1,2,4,8,16 |
133 | 0 | return (V <= 16 && isPowerOf2_32(V)); |
134 | 0 | case 64: |
135 | | // Check Scale is 1,2,4,8 |
136 | 0 | return (V <= 8 && isPowerOf2_32(V)); |
137 | 0 | } |
138 | 0 | return false; |
139 | 0 | } |
140 | | |
/// Computes BuiltinStr, the encoding of this type in builtin prototype
/// strings. The string is assembled from pieces:
///  - an element code (e.g. "b", "Sc", "Ui", "SWi", "x", "f", "d", "y");
///  - an "I" prefix for immediates;
///  - for scalars, "C"/"*" suffixes for const/pointer;
///  - for vectors, a "q<Scale>" prefix, a "*" suffix for pointers, and a
///    "T<NF>" prefix for tuples.
/// NOTE(review): Void, Ptrdiff_t, UnsignedLong and SignedLong return a fixed
/// code and ignore the Immediate/Pointer/Const flags; Size_t honours only
/// Immediate and Pointer.
void RVVType::initBuiltinStr() {
  assert(isValid() && "RVVType is invalid");
  switch (ScalarType) {
  case ScalarTypeKind::Void:
    BuiltinStr = "v";
    return;
  case ScalarTypeKind::Size_t:
    BuiltinStr = "z";
    if (IsImmediate)
      BuiltinStr = "I" + BuiltinStr;
    if (IsPointer)
      BuiltinStr += "*";
    return;
  case ScalarTypeKind::Ptrdiff_t:
    BuiltinStr = "Y";
    return;
  case ScalarTypeKind::UnsignedLong:
    BuiltinStr = "ULi";
    return;
  case ScalarTypeKind::SignedLong:
    BuiltinStr = "Li";
    return;
  case ScalarTypeKind::Boolean:
    assert(ElementBitwidth == 1);
    BuiltinStr += "b";
    break;
  case ScalarTypeKind::SignedInteger:
  case ScalarTypeKind::UnsignedInteger:
    // Width code first; the signedness prefix is prepended afterwards.
    switch (ElementBitwidth) {
    case 8:
      BuiltinStr += "c";
      break;
    case 16:
      BuiltinStr += "s";
      break;
    case 32:
      BuiltinStr += "i";
      break;
    case 64:
      BuiltinStr += "Wi";
      break;
    default:
      llvm_unreachable("Unhandled ElementBitwidth!");
    }
    if (isSignedInteger())
      BuiltinStr = "S" + BuiltinStr;
    else
      BuiltinStr = "U" + BuiltinStr;
    break;
  case ScalarTypeKind::Float:
    switch (ElementBitwidth) {
    case 16:
      BuiltinStr += "x";
      break;
    case 32:
      BuiltinStr += "f";
      break;
    case 64:
      BuiltinStr += "d";
      break;
    default:
      llvm_unreachable("Unhandled ElementBitwidth!");
    }
    break;
  case ScalarTypeKind::BFloat:
    BuiltinStr += "y";
    break;
  default:
    llvm_unreachable("ScalarType is invalid!");
  }
  // The "I" prefix sits directly in front of the element code, inside any
  // vector/tuple prefixes added below.
  if (IsImmediate)
    BuiltinStr = "I" + BuiltinStr;
  if (isScalar()) {
    if (IsConstant)
      BuiltinStr += "C";
    if (IsPointer)
      BuiltinStr += "*";
    return;
  }
  // Vector types: prefix with "q" and the vscale multiplier.
  BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
  // Pointer to vector types. Defined for segment load intrinsics.
  // segment load intrinsics have pointer type arguments to store the loaded
  // vector values.
  if (IsPointer)
    BuiltinStr += "*";

  // Tuples wrap the whole vector spelling with "T<NF>".
  if (IsTuple)
    BuiltinStr = "T" + utostr(NF) + BuiltinStr;
}
230 | | |
231 | 0 | void RVVType::initClangBuiltinStr() { |
232 | 0 | assert(isValid() && "RVVType is invalid"); |
233 | 0 | assert(isVector() && "Handle Vector type only"); |
234 | | |
235 | 0 | ClangBuiltinStr = "__rvv_"; |
236 | 0 | switch (ScalarType) { |
237 | 0 | case ScalarTypeKind::Boolean: |
238 | 0 | ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t"; |
239 | 0 | return; |
240 | 0 | case ScalarTypeKind::Float: |
241 | 0 | ClangBuiltinStr += "float"; |
242 | 0 | break; |
243 | 0 | case ScalarTypeKind::BFloat: |
244 | 0 | ClangBuiltinStr += "bfloat"; |
245 | 0 | break; |
246 | 0 | case ScalarTypeKind::SignedInteger: |
247 | 0 | ClangBuiltinStr += "int"; |
248 | 0 | break; |
249 | 0 | case ScalarTypeKind::UnsignedInteger: |
250 | 0 | ClangBuiltinStr += "uint"; |
251 | 0 | break; |
252 | 0 | default: |
253 | 0 | llvm_unreachable("ScalarTypeKind is invalid"); |
254 | 0 | } |
255 | 0 | ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + |
256 | 0 | (IsTuple ? "x" + utostr(NF) : "") + "_t"; |
257 | 0 | } |
258 | | |
259 | 0 | void RVVType::initTypeStr() { |
260 | 0 | assert(isValid() && "RVVType is invalid"); |
261 | | |
262 | 0 | if (IsConstant) |
263 | 0 | Str += "const "; |
264 | |
|
265 | 0 | auto getTypeString = [&](StringRef TypeStr) { |
266 | 0 | if (isScalar()) |
267 | 0 | return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); |
268 | 0 | return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + |
269 | 0 | (IsTuple ? "x" + utostr(NF) : "") + "_t") |
270 | 0 | .str(); |
271 | 0 | }; |
272 | |
|
273 | 0 | switch (ScalarType) { |
274 | 0 | case ScalarTypeKind::Void: |
275 | 0 | Str = "void"; |
276 | 0 | return; |
277 | 0 | case ScalarTypeKind::Size_t: |
278 | 0 | Str = "size_t"; |
279 | 0 | if (IsPointer) |
280 | 0 | Str += " *"; |
281 | 0 | return; |
282 | 0 | case ScalarTypeKind::Ptrdiff_t: |
283 | 0 | Str = "ptrdiff_t"; |
284 | 0 | return; |
285 | 0 | case ScalarTypeKind::UnsignedLong: |
286 | 0 | Str = "unsigned long"; |
287 | 0 | return; |
288 | 0 | case ScalarTypeKind::SignedLong: |
289 | 0 | Str = "long"; |
290 | 0 | return; |
291 | 0 | case ScalarTypeKind::Boolean: |
292 | 0 | if (isScalar()) |
293 | 0 | Str += "bool"; |
294 | 0 | else |
295 | | // Vector bool is special case, the formulate is |
296 | | // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 |
297 | 0 | Str += "vbool" + utostr(64 / *Scale) + "_t"; |
298 | 0 | break; |
299 | 0 | case ScalarTypeKind::Float: |
300 | 0 | if (isScalar()) { |
301 | 0 | if (ElementBitwidth == 64) |
302 | 0 | Str += "double"; |
303 | 0 | else if (ElementBitwidth == 32) |
304 | 0 | Str += "float"; |
305 | 0 | else if (ElementBitwidth == 16) |
306 | 0 | Str += "_Float16"; |
307 | 0 | else |
308 | 0 | llvm_unreachable("Unhandled floating type."); |
309 | 0 | } else |
310 | 0 | Str += getTypeString("float"); |
311 | 0 | break; |
312 | 0 | case ScalarTypeKind::BFloat: |
313 | 0 | if (isScalar()) { |
314 | 0 | if (ElementBitwidth == 16) |
315 | 0 | Str += "__bf16"; |
316 | 0 | else |
317 | 0 | llvm_unreachable("Unhandled floating type."); |
318 | 0 | } else |
319 | 0 | Str += getTypeString("bfloat"); |
320 | 0 | break; |
321 | 0 | case ScalarTypeKind::SignedInteger: |
322 | 0 | Str += getTypeString("int"); |
323 | 0 | break; |
324 | 0 | case ScalarTypeKind::UnsignedInteger: |
325 | 0 | Str += getTypeString("uint"); |
326 | 0 | break; |
327 | 0 | default: |
328 | 0 | llvm_unreachable("ScalarType is invalid!"); |
329 | 0 | } |
330 | 0 | if (IsPointer) |
331 | 0 | Str += " *"; |
332 | 0 | } |
333 | | |
334 | 0 | void RVVType::initShortStr() { |
335 | 0 | switch (ScalarType) { |
336 | 0 | case ScalarTypeKind::Boolean: |
337 | 0 | assert(isVector()); |
338 | 0 | ShortStr = "b" + utostr(64 / *Scale); |
339 | 0 | return; |
340 | 0 | case ScalarTypeKind::Float: |
341 | 0 | ShortStr = "f" + utostr(ElementBitwidth); |
342 | 0 | break; |
343 | 0 | case ScalarTypeKind::BFloat: |
344 | 0 | ShortStr = "bf" + utostr(ElementBitwidth); |
345 | 0 | break; |
346 | 0 | case ScalarTypeKind::SignedInteger: |
347 | 0 | ShortStr = "i" + utostr(ElementBitwidth); |
348 | 0 | break; |
349 | 0 | case ScalarTypeKind::UnsignedInteger: |
350 | 0 | ShortStr = "u" + utostr(ElementBitwidth); |
351 | 0 | break; |
352 | 0 | default: |
353 | 0 | llvm_unreachable("Unhandled case!"); |
354 | 0 | } |
355 | 0 | if (isVector()) |
356 | 0 | ShortStr += LMUL.str(); |
357 | 0 | if (isTuple()) |
358 | 0 | ShortStr += "x" + utostr(NF); |
359 | 0 | } |
360 | | |
361 | 0 | static VectorTypeModifier getTupleVTM(unsigned NF) { |
362 | 0 | assert(2 <= NF && NF <= 8 && "2 <= NF <= 8"); |
363 | 0 | return static_cast<VectorTypeModifier>( |
364 | 0 | static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2)); |
365 | 0 | } |
366 | | |
367 | 0 | void RVVType::applyBasicType() { |
368 | 0 | switch (BT) { |
369 | 0 | case BasicType::Int8: |
370 | 0 | ElementBitwidth = 8; |
371 | 0 | ScalarType = ScalarTypeKind::SignedInteger; |
372 | 0 | break; |
373 | 0 | case BasicType::Int16: |
374 | 0 | ElementBitwidth = 16; |
375 | 0 | ScalarType = ScalarTypeKind::SignedInteger; |
376 | 0 | break; |
377 | 0 | case BasicType::Int32: |
378 | 0 | ElementBitwidth = 32; |
379 | 0 | ScalarType = ScalarTypeKind::SignedInteger; |
380 | 0 | break; |
381 | 0 | case BasicType::Int64: |
382 | 0 | ElementBitwidth = 64; |
383 | 0 | ScalarType = ScalarTypeKind::SignedInteger; |
384 | 0 | break; |
385 | 0 | case BasicType::Float16: |
386 | 0 | ElementBitwidth = 16; |
387 | 0 | ScalarType = ScalarTypeKind::Float; |
388 | 0 | break; |
389 | 0 | case BasicType::Float32: |
390 | 0 | ElementBitwidth = 32; |
391 | 0 | ScalarType = ScalarTypeKind::Float; |
392 | 0 | break; |
393 | 0 | case BasicType::Float64: |
394 | 0 | ElementBitwidth = 64; |
395 | 0 | ScalarType = ScalarTypeKind::Float; |
396 | 0 | break; |
397 | 0 | case BasicType::BFloat16: |
398 | 0 | ElementBitwidth = 16; |
399 | 0 | ScalarType = ScalarTypeKind::BFloat; |
400 | 0 | break; |
401 | 0 | default: |
402 | 0 | llvm_unreachable("Unhandled type code!"); |
403 | 0 | } |
404 | 0 | assert(ElementBitwidth != 0 && "Bad element bitwidth!"); |
405 | 0 | } |
406 | | |
407 | | std::optional<PrototypeDescriptor> |
408 | | PrototypeDescriptor::parsePrototypeDescriptor( |
409 | 0 | llvm::StringRef PrototypeDescriptorStr) { |
410 | 0 | PrototypeDescriptor PD; |
411 | 0 | BaseTypeModifier PT = BaseTypeModifier::Invalid; |
412 | 0 | VectorTypeModifier VTM = VectorTypeModifier::NoModifier; |
413 | |
|
414 | 0 | if (PrototypeDescriptorStr.empty()) |
415 | 0 | return PD; |
416 | | |
417 | | // Handle base type modifier |
418 | 0 | auto PType = PrototypeDescriptorStr.back(); |
419 | 0 | switch (PType) { |
420 | 0 | case 'e': |
421 | 0 | PT = BaseTypeModifier::Scalar; |
422 | 0 | break; |
423 | 0 | case 'v': |
424 | 0 | PT = BaseTypeModifier::Vector; |
425 | 0 | break; |
426 | 0 | case 'w': |
427 | 0 | PT = BaseTypeModifier::Vector; |
428 | 0 | VTM = VectorTypeModifier::Widening2XVector; |
429 | 0 | break; |
430 | 0 | case 'q': |
431 | 0 | PT = BaseTypeModifier::Vector; |
432 | 0 | VTM = VectorTypeModifier::Widening4XVector; |
433 | 0 | break; |
434 | 0 | case 'o': |
435 | 0 | PT = BaseTypeModifier::Vector; |
436 | 0 | VTM = VectorTypeModifier::Widening8XVector; |
437 | 0 | break; |
438 | 0 | case 'm': |
439 | 0 | PT = BaseTypeModifier::Vector; |
440 | 0 | VTM = VectorTypeModifier::MaskVector; |
441 | 0 | break; |
442 | 0 | case '0': |
443 | 0 | PT = BaseTypeModifier::Void; |
444 | 0 | break; |
445 | 0 | case 'z': |
446 | 0 | PT = BaseTypeModifier::SizeT; |
447 | 0 | break; |
448 | 0 | case 't': |
449 | 0 | PT = BaseTypeModifier::Ptrdiff; |
450 | 0 | break; |
451 | 0 | case 'u': |
452 | 0 | PT = BaseTypeModifier::UnsignedLong; |
453 | 0 | break; |
454 | 0 | case 'l': |
455 | 0 | PT = BaseTypeModifier::SignedLong; |
456 | 0 | break; |
457 | 0 | case 'f': |
458 | 0 | PT = BaseTypeModifier::Float32; |
459 | 0 | break; |
460 | 0 | default: |
461 | 0 | llvm_unreachable("Illegal primitive type transformers!"); |
462 | 0 | } |
463 | 0 | PD.PT = static_cast<uint8_t>(PT); |
464 | 0 | PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); |
465 | | |
466 | | // Compute the vector type transformers, it can only appear one time. |
467 | 0 | if (PrototypeDescriptorStr.starts_with("(")) { |
468 | 0 | assert(VTM == VectorTypeModifier::NoModifier && |
469 | 0 | "VectorTypeModifier should only have one modifier"); |
470 | 0 | size_t Idx = PrototypeDescriptorStr.find(')'); |
471 | 0 | assert(Idx != StringRef::npos); |
472 | 0 | StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); |
473 | 0 | PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); |
474 | 0 | assert(!PrototypeDescriptorStr.contains('(') && |
475 | 0 | "Only allow one vector type modifier"); |
476 | | |
477 | 0 | auto ComplexTT = ComplexType.split(":"); |
478 | 0 | if (ComplexTT.first == "Log2EEW") { |
479 | 0 | uint32_t Log2EEW; |
480 | 0 | if (ComplexTT.second.getAsInteger(10, Log2EEW)) { |
481 | 0 | llvm_unreachable("Invalid Log2EEW value!"); |
482 | 0 | return std::nullopt; |
483 | 0 | } |
484 | 0 | switch (Log2EEW) { |
485 | 0 | case 3: |
486 | 0 | VTM = VectorTypeModifier::Log2EEW3; |
487 | 0 | break; |
488 | 0 | case 4: |
489 | 0 | VTM = VectorTypeModifier::Log2EEW4; |
490 | 0 | break; |
491 | 0 | case 5: |
492 | 0 | VTM = VectorTypeModifier::Log2EEW5; |
493 | 0 | break; |
494 | 0 | case 6: |
495 | 0 | VTM = VectorTypeModifier::Log2EEW6; |
496 | 0 | break; |
497 | 0 | default: |
498 | 0 | llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); |
499 | 0 | return std::nullopt; |
500 | 0 | } |
501 | 0 | } else if (ComplexTT.first == "FixedSEW") { |
502 | 0 | uint32_t NewSEW; |
503 | 0 | if (ComplexTT.second.getAsInteger(10, NewSEW)) { |
504 | 0 | llvm_unreachable("Invalid FixedSEW value!"); |
505 | 0 | return std::nullopt; |
506 | 0 | } |
507 | 0 | switch (NewSEW) { |
508 | 0 | case 8: |
509 | 0 | VTM = VectorTypeModifier::FixedSEW8; |
510 | 0 | break; |
511 | 0 | case 16: |
512 | 0 | VTM = VectorTypeModifier::FixedSEW16; |
513 | 0 | break; |
514 | 0 | case 32: |
515 | 0 | VTM = VectorTypeModifier::FixedSEW32; |
516 | 0 | break; |
517 | 0 | case 64: |
518 | 0 | VTM = VectorTypeModifier::FixedSEW64; |
519 | 0 | break; |
520 | 0 | default: |
521 | 0 | llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); |
522 | 0 | return std::nullopt; |
523 | 0 | } |
524 | 0 | } else if (ComplexTT.first == "LFixedLog2LMUL") { |
525 | 0 | int32_t Log2LMUL; |
526 | 0 | if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { |
527 | 0 | llvm_unreachable("Invalid LFixedLog2LMUL value!"); |
528 | 0 | return std::nullopt; |
529 | 0 | } |
530 | 0 | switch (Log2LMUL) { |
531 | 0 | case -3: |
532 | 0 | VTM = VectorTypeModifier::LFixedLog2LMULN3; |
533 | 0 | break; |
534 | 0 | case -2: |
535 | 0 | VTM = VectorTypeModifier::LFixedLog2LMULN2; |
536 | 0 | break; |
537 | 0 | case -1: |
538 | 0 | VTM = VectorTypeModifier::LFixedLog2LMULN1; |
539 | 0 | break; |
540 | 0 | case 0: |
541 | 0 | VTM = VectorTypeModifier::LFixedLog2LMUL0; |
542 | 0 | break; |
543 | 0 | case 1: |
544 | 0 | VTM = VectorTypeModifier::LFixedLog2LMUL1; |
545 | 0 | break; |
546 | 0 | case 2: |
547 | 0 | VTM = VectorTypeModifier::LFixedLog2LMUL2; |
548 | 0 | break; |
549 | 0 | case 3: |
550 | 0 | VTM = VectorTypeModifier::LFixedLog2LMUL3; |
551 | 0 | break; |
552 | 0 | default: |
553 | 0 | llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); |
554 | 0 | return std::nullopt; |
555 | 0 | } |
556 | 0 | } else if (ComplexTT.first == "SFixedLog2LMUL") { |
557 | 0 | int32_t Log2LMUL; |
558 | 0 | if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { |
559 | 0 | llvm_unreachable("Invalid SFixedLog2LMUL value!"); |
560 | 0 | return std::nullopt; |
561 | 0 | } |
562 | 0 | switch (Log2LMUL) { |
563 | 0 | case -3: |
564 | 0 | VTM = VectorTypeModifier::SFixedLog2LMULN3; |
565 | 0 | break; |
566 | 0 | case -2: |
567 | 0 | VTM = VectorTypeModifier::SFixedLog2LMULN2; |
568 | 0 | break; |
569 | 0 | case -1: |
570 | 0 | VTM = VectorTypeModifier::SFixedLog2LMULN1; |
571 | 0 | break; |
572 | 0 | case 0: |
573 | 0 | VTM = VectorTypeModifier::SFixedLog2LMUL0; |
574 | 0 | break; |
575 | 0 | case 1: |
576 | 0 | VTM = VectorTypeModifier::SFixedLog2LMUL1; |
577 | 0 | break; |
578 | 0 | case 2: |
579 | 0 | VTM = VectorTypeModifier::SFixedLog2LMUL2; |
580 | 0 | break; |
581 | 0 | case 3: |
582 | 0 | VTM = VectorTypeModifier::SFixedLog2LMUL3; |
583 | 0 | break; |
584 | 0 | default: |
585 | 0 | llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); |
586 | 0 | return std::nullopt; |
587 | 0 | } |
588 | |
|
589 | 0 | } else if (ComplexTT.first == "SEFixedLog2LMUL") { |
590 | 0 | int32_t Log2LMUL; |
591 | 0 | if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { |
592 | 0 | llvm_unreachable("Invalid SEFixedLog2LMUL value!"); |
593 | 0 | return std::nullopt; |
594 | 0 | } |
595 | 0 | switch (Log2LMUL) { |
596 | 0 | case -3: |
597 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMULN3; |
598 | 0 | break; |
599 | 0 | case -2: |
600 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMULN2; |
601 | 0 | break; |
602 | 0 | case -1: |
603 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMULN1; |
604 | 0 | break; |
605 | 0 | case 0: |
606 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMUL0; |
607 | 0 | break; |
608 | 0 | case 1: |
609 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMUL1; |
610 | 0 | break; |
611 | 0 | case 2: |
612 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMUL2; |
613 | 0 | break; |
614 | 0 | case 3: |
615 | 0 | VTM = VectorTypeModifier::SEFixedLog2LMUL3; |
616 | 0 | break; |
617 | 0 | default: |
618 | 0 | llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); |
619 | 0 | return std::nullopt; |
620 | 0 | } |
621 | 0 | } else if (ComplexTT.first == "Tuple") { |
622 | 0 | unsigned NF = 0; |
623 | 0 | if (ComplexTT.second.getAsInteger(10, NF)) { |
624 | 0 | llvm_unreachable("Invalid NF value!"); |
625 | 0 | return std::nullopt; |
626 | 0 | } |
627 | 0 | VTM = getTupleVTM(NF); |
628 | 0 | } else { |
629 | 0 | llvm_unreachable("Illegal complex type transformers!"); |
630 | 0 | } |
631 | 0 | } |
632 | 0 | PD.VTM = static_cast<uint8_t>(VTM); |
633 | | |
634 | | // Compute the remain type transformers |
635 | 0 | TypeModifier TM = TypeModifier::NoModifier; |
636 | 0 | for (char I : PrototypeDescriptorStr) { |
637 | 0 | switch (I) { |
638 | 0 | case 'P': |
639 | 0 | if ((TM & TypeModifier::Const) == TypeModifier::Const) |
640 | 0 | llvm_unreachable("'P' transformer cannot be used after 'C'"); |
641 | 0 | if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) |
642 | 0 | llvm_unreachable("'P' transformer cannot be used twice"); |
643 | 0 | TM |= TypeModifier::Pointer; |
644 | 0 | break; |
645 | 0 | case 'C': |
646 | 0 | TM |= TypeModifier::Const; |
647 | 0 | break; |
648 | 0 | case 'K': |
649 | 0 | TM |= TypeModifier::Immediate; |
650 | 0 | break; |
651 | 0 | case 'U': |
652 | 0 | TM |= TypeModifier::UnsignedInteger; |
653 | 0 | break; |
654 | 0 | case 'I': |
655 | 0 | TM |= TypeModifier::SignedInteger; |
656 | 0 | break; |
657 | 0 | case 'F': |
658 | 0 | TM |= TypeModifier::Float; |
659 | 0 | break; |
660 | 0 | case 'S': |
661 | 0 | TM |= TypeModifier::LMUL1; |
662 | 0 | break; |
663 | 0 | default: |
664 | 0 | llvm_unreachable("Illegal non-primitive type transformer!"); |
665 | 0 | } |
666 | 0 | } |
667 | 0 | PD.TM = static_cast<uint8_t>(TM); |
668 | |
|
669 | 0 | return PD; |
670 | 0 | } |
671 | | |
/// Applies a parsed PrototypeDescriptor to this type in three stages:
///  1. the base type modifier (PT): scalar vs. vector, or a fixed scalar type
///     (void, size_t, ptrdiff_t, unsigned long, long, float32);
///  2. the vector type modifier (VTM): widening, mask, EEW/SEW/LMUL overrides,
///     or a tuple field count;
///  3. the type-modifier flag bits (TM): pointer, const, immediate,
///     element-kind overrides and LMUL1.
/// Any stage may mark the type invalid by setting ScalarType to Invalid. The
/// constructor's subsequent verifyType() call then rejects it.
void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
  // Handle primitive type transformer
  switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
  case BaseTypeModifier::Scalar:
    // A scale of 0 marks a scalar (presumably what isScalar() tests).
    Scale = 0;
    break;
  case BaseTypeModifier::Vector:
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case BaseTypeModifier::Void:
    ScalarType = ScalarTypeKind::Void;
    break;
  case BaseTypeModifier::SizeT:
    ScalarType = ScalarTypeKind::Size_t;
    break;
  case BaseTypeModifier::Ptrdiff:
    ScalarType = ScalarTypeKind::Ptrdiff_t;
    break;
  case BaseTypeModifier::UnsignedLong:
    ScalarType = ScalarTypeKind::UnsignedLong;
    break;
  case BaseTypeModifier::SignedLong:
    ScalarType = ScalarTypeKind::SignedLong;
    break;
  case BaseTypeModifier::Float32:
    ElementBitwidth = 32;
    ScalarType = ScalarTypeKind::Float;
    break;
  case BaseTypeModifier::Invalid:
    ScalarType = ScalarTypeKind::Invalid;
    return;
  }

  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
  // Widening by 2^k multiplies both SEW and LMUL by 2^k.
  case VectorTypeModifier::Widening2XVector:
    ElementBitwidth *= 2;
    LMUL.MulLog2LMUL(1);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening4XVector:
    ElementBitwidth *= 4;
    LMUL.MulLog2LMUL(2);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening8XVector:
    ElementBitwidth *= 8;
    LMUL.MulLog2LMUL(3);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::MaskVector:
    // The mask's scale comes from the current SEW/LMUL, computed before the
    // element width is switched to a single bit.
    ScalarType = ScalarTypeKind::Boolean;
    Scale = LMUL.getScale(ElementBitwidth);
    ElementBitwidth = 1;
    break;
  case VectorTypeModifier::Log2EEW3:
    applyLog2EEW(3);
    break;
  case VectorTypeModifier::Log2EEW4:
    applyLog2EEW(4);
    break;
  case VectorTypeModifier::Log2EEW5:
    applyLog2EEW(5);
    break;
  case VectorTypeModifier::Log2EEW6:
    applyLog2EEW(6);
    break;
  case VectorTypeModifier::FixedSEW8:
    applyFixedSEW(8);
    break;
  case VectorTypeModifier::FixedSEW16:
    applyFixedSEW(16);
    break;
  case VectorTypeModifier::FixedSEW32:
    applyFixedSEW(32);
    break;
  case VectorTypeModifier::FixedSEW64:
    applyFixedSEW(64);
    break;
  // LFixed*: the fixed LMUL must be strictly larger than the current one.
  case VectorTypeModifier::LFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
    break;
  // SFixed*: the fixed LMUL must be strictly smaller than the current one.
  case VectorTypeModifier::SFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
    break;
  // SEFixed*: the fixed LMUL must be smaller than or equal to the current one.
  case VectorTypeModifier::SEFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::SEFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual);
    break;
  case VectorTypeModifier::Tuple2:
  case VectorTypeModifier::Tuple3:
  case VectorTypeModifier::Tuple4:
  case VectorTypeModifier::Tuple5:
  case VectorTypeModifier::Tuple6:
  case VectorTypeModifier::Tuple7:
  case VectorTypeModifier::Tuple8: {
    // Tuple2..Tuple8 are treated as consecutive: NF = 2 + offset from Tuple2.
    IsTuple = true;
    NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
         static_cast<uint8_t>(VectorTypeModifier::Tuple2);
    break;
  }
  case VectorTypeModifier::NoModifier:
    break;
  }

  // Early return if the current type modifier is already invalid.
  if (ScalarType == Invalid)
    return;

  // Visit each flag bit from bit 0 up to MaxOffset. Kind-changing flags
  // overwrite ScalarType, so if several are set, the highest bit wins.
  for (unsigned TypeModifierMaskShift = 0;
       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
       ++TypeModifierMaskShift) {
    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
        TypeModifierMask)
      continue;
    switch (static_cast<TypeModifier>(TypeModifierMask)) {
    case TypeModifier::Pointer:
      IsPointer = true;
      break;
    case TypeModifier::Const:
      IsConstant = true;
      break;
    case TypeModifier::Immediate:
      // Immediates are also treated as constants.
      IsImmediate = true;
      IsConstant = true;
      break;
    case TypeModifier::UnsignedInteger:
      ScalarType = ScalarTypeKind::UnsignedInteger;
      break;
    case TypeModifier::SignedInteger:
      ScalarType = ScalarTypeKind::SignedInteger;
      break;
    case TypeModifier::Float:
      ScalarType = ScalarTypeKind::Float;
      break;
    case TypeModifier::BFloat:
      ScalarType = ScalarTypeKind::BFloat;
      break;
    case TypeModifier::LMUL1:
      LMUL = LMULType(0);
      // Changing LMUL changes the scale, so recompute it.
      Scale = LMUL.getScale(ElementBitwidth);
      break;
    default:
      llvm_unreachable("Unknown type modifier mask!");
    }
  }
}
873 | | |
874 | 0 | void RVVType::applyLog2EEW(unsigned Log2EEW) { |
875 | | // update new elmul = (eew/sew) * lmul |
876 | 0 | LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); |
877 | | // update new eew |
878 | 0 | ElementBitwidth = 1 << Log2EEW; |
879 | 0 | ScalarType = ScalarTypeKind::SignedInteger; |
880 | 0 | Scale = LMUL.getScale(ElementBitwidth); |
881 | 0 | } |
882 | | |
883 | 0 | void RVVType::applyFixedSEW(unsigned NewSEW) { |
884 | | // Set invalid type if src and dst SEW are same. |
885 | 0 | if (ElementBitwidth == NewSEW) { |
886 | 0 | ScalarType = ScalarTypeKind::Invalid; |
887 | 0 | return; |
888 | 0 | } |
889 | | // Update new SEW |
890 | 0 | ElementBitwidth = NewSEW; |
891 | 0 | Scale = LMUL.getScale(ElementBitwidth); |
892 | 0 | } |
893 | | |
894 | 0 | void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { |
895 | 0 | switch (Type) { |
896 | 0 | case FixedLMULType::LargerThan: |
897 | 0 | if (Log2LMUL <= LMUL.Log2LMUL) { |
898 | 0 | ScalarType = ScalarTypeKind::Invalid; |
899 | 0 | return; |
900 | 0 | } |
901 | 0 | break; |
902 | 0 | case FixedLMULType::SmallerThan: |
903 | 0 | if (Log2LMUL >= LMUL.Log2LMUL) { |
904 | 0 | ScalarType = ScalarTypeKind::Invalid; |
905 | 0 | return; |
906 | 0 | } |
907 | 0 | break; |
908 | 0 | case FixedLMULType::SmallerOrEqual: |
909 | 0 | if (Log2LMUL > LMUL.Log2LMUL) { |
910 | 0 | ScalarType = ScalarTypeKind::Invalid; |
911 | 0 | return; |
912 | 0 | } |
913 | 0 | break; |
914 | 0 | } |
915 | | |
916 | | // Update new LMUL |
917 | 0 | LMUL = LMULType(Log2LMUL); |
918 | 0 | Scale = LMUL.getScale(ElementBitwidth); |
919 | 0 | } |
920 | | |
921 | | std::optional<RVVTypes> |
922 | | RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, |
923 | 0 | ArrayRef<PrototypeDescriptor> Prototype) { |
924 | 0 | RVVTypes Types; |
925 | 0 | for (const PrototypeDescriptor &Proto : Prototype) { |
926 | 0 | auto T = computeType(BT, Log2LMUL, Proto); |
927 | 0 | if (!T) |
928 | 0 | return std::nullopt; |
929 | | // Record legal type index |
930 | 0 | Types.push_back(*T); |
931 | 0 | } |
932 | 0 | return Types; |
933 | 0 | } |
934 | | |
935 | | // Compute the hash value of RVVType, used for cache the result of computeType. |
936 | | static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, |
937 | 0 | PrototypeDescriptor Proto) { |
938 | | // Layout of hash value: |
939 | | // 0 8 16 24 32 40 |
940 | | // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | |
941 | 0 | assert(Log2LMUL >= -3 && Log2LMUL <= 3); |
942 | 0 | return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 | |
943 | 0 | ((uint64_t)(Proto.PT & 0xff) << 16) | |
944 | 0 | ((uint64_t)(Proto.TM & 0xff) << 24) | |
945 | 0 | ((uint64_t)(Proto.VTM & 0xff) << 32); |
946 | 0 | } |
947 | | |
948 | | std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL, |
949 | 0 | PrototypeDescriptor Proto) { |
950 | 0 | uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); |
951 | | // Search first |
952 | 0 | auto It = LegalTypes.find(Idx); |
953 | 0 | if (It != LegalTypes.end()) |
954 | 0 | return &(It->second); |
955 | | |
956 | 0 | if (IllegalTypes.count(Idx)) |
957 | 0 | return std::nullopt; |
958 | | |
959 | | // Compute type and record the result. |
960 | 0 | RVVType T(BT, Log2LMUL, Proto); |
961 | 0 | if (T.isValid()) { |
962 | | // Record legal type index and value. |
963 | 0 | std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool> |
964 | 0 | InsertResult = LegalTypes.insert({Idx, T}); |
965 | 0 | return &(InsertResult.first->second); |
966 | 0 | } |
967 | | // Record illegal type index. |
968 | 0 | IllegalTypes.insert(Idx); |
969 | 0 | return std::nullopt; |
970 | 0 | } |
971 | | |
972 | | //===----------------------------------------------------------------------===// |
973 | | // RVVIntrinsic implementation |
974 | | //===----------------------------------------------------------------------===// |
975 | | RVVIntrinsic::RVVIntrinsic( |
976 | | StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, |
977 | | StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, |
978 | | bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, |
979 | | bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen, |
980 | | const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, |
981 | | const std::vector<StringRef> &RequiredFeatures, unsigned NF, |
982 | | Policy NewPolicyAttrs, bool HasFRMRoundModeOp) |
983 | | : IRName(IRName), IsMasked(IsMasked), |
984 | | HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), |
985 | | SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), |
986 | 0 | ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { |
987 | | |
988 | | // Init BuiltinName, Name and OverloadedName |
989 | 0 | BuiltinName = NewName.str(); |
990 | 0 | Name = BuiltinName; |
991 | 0 | if (NewOverloadedName.empty()) |
992 | 0 | OverloadedName = NewName.split("_").first.str(); |
993 | 0 | else |
994 | 0 | OverloadedName = NewOverloadedName.str(); |
995 | 0 | if (!Suffix.empty()) |
996 | 0 | Name += "_" + Suffix.str(); |
997 | 0 | if (!OverloadedSuffix.empty()) |
998 | 0 | OverloadedName += "_" + OverloadedSuffix.str(); |
999 | |
|
1000 | 0 | updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, |
1001 | 0 | PolicyAttrs, HasFRMRoundModeOp); |
1002 | | |
1003 | | // Init OutputType and InputTypes |
1004 | 0 | OutputType = OutInTypes[0]; |
1005 | 0 | InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); |
1006 | | |
1007 | | // IntrinsicTypes is unmasked TA version index. Need to update it |
1008 | | // if there is merge operand (It is always in first operand). |
1009 | 0 | IntrinsicTypes = NewIntrinsicTypes; |
1010 | 0 | if ((IsMasked && hasMaskedOffOperand()) || |
1011 | 0 | (!IsMasked && hasPassthruOperand())) { |
1012 | 0 | for (auto &I : IntrinsicTypes) { |
1013 | 0 | if (I >= 0) |
1014 | 0 | I += NF; |
1015 | 0 | } |
1016 | 0 | } |
1017 | 0 | } |
1018 | | |
1019 | 0 | std::string RVVIntrinsic::getBuiltinTypeStr() const { |
1020 | 0 | std::string S; |
1021 | 0 | S += OutputType->getBuiltinStr(); |
1022 | 0 | for (const auto &T : InputTypes) { |
1023 | 0 | S += T->getBuiltinStr(); |
1024 | 0 | } |
1025 | 0 | return S; |
1026 | 0 | } |
1027 | | |
1028 | | std::string RVVIntrinsic::getSuffixStr( |
1029 | | RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL, |
1030 | 0 | llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { |
1031 | 0 | SmallVector<std::string> SuffixStrs; |
1032 | 0 | for (auto PD : PrototypeDescriptors) { |
1033 | 0 | auto T = TypeCache.computeType(Type, Log2LMUL, PD); |
1034 | 0 | SuffixStrs.push_back((*T)->getShortStr()); |
1035 | 0 | } |
1036 | 0 | return join(SuffixStrs, "_"); |
1037 | 0 | } |
1038 | | |
/// Expand a base intrinsic prototype into the full builtin prototype.
///
/// \p Prototype lists the result type first, followed by the operands. Based
/// on masking (\p IsMasked), maskedoff/passthru support, the policy
/// (\p PolicyAttrs), the segment count \p NF and whether segments are passed
/// as a tuple (\p IsTuple), this inserts:
///  - a maskedoff / passthru operand (or NF of them, or one tuple operand),
///  - the mask operand (masked forms only),
///  - a trailing VL operand when \p HasVL is set.
/// A new vector is returned; \p Prototype is not modified.
llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
    llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, unsigned NF,
    PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
  // Work on a mutable copy of the incoming prototype.
  SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
                                                Prototype.end());
  bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
  if (IsMasked) {
    // If HasMaskedOffOperand, insert result type as first input operand if
    // need.
    // (Under the TAMA policy no maskedoff operand is inserted.)
    if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
      if (NF == 1) {
        // Single segment: the maskedoff operand has the result type.
        NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
      } else if (NF > 1) {
        if (IsTuple) {
          // One tuple-typed maskedoff operand, built from the type modifiers
          // of Prototype[1] with the Pointer bit cleared (presumably the base
          // pointer operand, going by the variable name — confirm).
          PrototypeDescriptor BasePtrOperand = Prototype[1];
          PrototypeDescriptor MaskoffType = PrototypeDescriptor(
              static_cast<uint8_t>(BaseTypeModifier::Vector),
              static_cast<uint8_t>(getTupleVTM(NF)),
              BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
          NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType);
        } else {
          // Convert
          // (void, op0 address, op1 address, ...)
          // to
          // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
          // Each maskedoff operand is the first address operand's descriptor
          // with the Pointer bit cleared.
          PrototypeDescriptor MaskoffType = NewPrototype[1];
          MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
          NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
        }
      }
    }
    if (HasMaskedOffOperand && NF > 1) {
      // Convert
      // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
      // to
      // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
      // ...)
      // For tuples the mask instead becomes the first input operand.
      if (IsTuple)
        NewPrototype.insert(NewPrototype.begin() + 1,
                            PrototypeDescriptor::Mask);
      else
        NewPrototype.insert(NewPrototype.begin() + NF + 1,
                            PrototypeDescriptor::Mask);
    } else {
      // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
      NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
    }
  } else {
    // Unmasked forms only gain a passthru operand under the TU policy when the
    // scheme provides one.
    if (NF == 1) {
      // Duplicating element 0 at the front adds a passthru operand of the
      // result type directly after the (identical) result entry.
      if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
        NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
    } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
      if (IsTuple) {
        // NOTE(review): unlike the masked tuple path (Prototype[1]), this
        // builds the tuple descriptor from Prototype[0] and inserts it at the
        // very front — presumably because the unmasked tuple layout differs;
        // confirm against the segment-load definitions.
        PrototypeDescriptor BasePtrOperand = Prototype[0];
        PrototypeDescriptor MaskoffType = PrototypeDescriptor(
            static_cast<uint8_t>(BaseTypeModifier::Vector),
            static_cast<uint8_t>(getTupleVTM(NF)),
            BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
        NewPrototype.insert(NewPrototype.begin(), MaskoffType);
      } else {
        // NF > 1 cases for segment load operations.
        // Convert
        // (void, op0 address, op1 address, ...)
        // to
        // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
        PrototypeDescriptor MaskoffType = Prototype[1];
        MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
        NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
      }
    }
  }

  // If HasVL, append PrototypeDescriptor:VL to last operand
  if (HasVL)
    NewPrototype.push_back(PrototypeDescriptor::VL);

  return NewPrototype;
}
1118 | | |
1119 | 0 | llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() { |
1120 | 0 | return {Policy(Policy::PolicyType::Undisturbed)}; // TU |
1121 | 0 | } |
1122 | | |
1123 | | llvm::SmallVector<Policy> |
1124 | | RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy, |
1125 | 0 | bool HasMaskPolicy) { |
1126 | 0 | if (HasTailPolicy && HasMaskPolicy) |
1127 | 0 | return {Policy(Policy::PolicyType::Undisturbed, |
1128 | 0 | Policy::PolicyType::Agnostic), // TUM |
1129 | 0 | Policy(Policy::PolicyType::Undisturbed, |
1130 | 0 | Policy::PolicyType::Undisturbed), // TUMU |
1131 | 0 | Policy(Policy::PolicyType::Agnostic, |
1132 | 0 | Policy::PolicyType::Undisturbed)}; // MU |
1133 | 0 | if (HasTailPolicy && !HasMaskPolicy) |
1134 | 0 | return {Policy(Policy::PolicyType::Undisturbed, |
1135 | 0 | Policy::PolicyType::Agnostic)}; // TU |
1136 | 0 | if (!HasTailPolicy && HasMaskPolicy) |
1137 | 0 | return {Policy(Policy::PolicyType::Agnostic, |
1138 | 0 | Policy::PolicyType::Undisturbed)}; // MU |
1139 | 0 | llvm_unreachable("An RVV instruction should not be without both tail policy " |
1140 | 0 | "and mask policy"); |
1141 | 0 | } |
1142 | | |
1143 | | void RVVIntrinsic::updateNamesAndPolicy( |
1144 | | bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, |
1145 | 0 | std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) { |
1146 | |
|
1147 | 0 | auto appendPolicySuffix = [&](const std::string &suffix) { |
1148 | 0 | Name += suffix; |
1149 | 0 | BuiltinName += suffix; |
1150 | 0 | OverloadedName += suffix; |
1151 | 0 | }; |
1152 | | |
1153 | | // This follows the naming guideline under riscv-c-api-doc to add the |
1154 | | // `__riscv_` suffix for all RVV intrinsics. |
1155 | 0 | Name = "__riscv_" + Name; |
1156 | 0 | OverloadedName = "__riscv_" + OverloadedName; |
1157 | |
|
1158 | 0 | if (HasFRMRoundModeOp) { |
1159 | 0 | Name += "_rm"; |
1160 | 0 | BuiltinName += "_rm"; |
1161 | 0 | } |
1162 | |
|
1163 | 0 | if (IsMasked) { |
1164 | 0 | if (PolicyAttrs.isTUMUPolicy()) |
1165 | 0 | appendPolicySuffix("_tumu"); |
1166 | 0 | else if (PolicyAttrs.isTUMAPolicy()) |
1167 | 0 | appendPolicySuffix("_tum"); |
1168 | 0 | else if (PolicyAttrs.isTAMUPolicy()) |
1169 | 0 | appendPolicySuffix("_mu"); |
1170 | 0 | else if (PolicyAttrs.isTAMAPolicy()) { |
1171 | 0 | Name += "_m"; |
1172 | 0 | BuiltinName += "_m"; |
1173 | 0 | } else |
1174 | 0 | llvm_unreachable("Unhandled policy condition"); |
1175 | 0 | } else { |
1176 | 0 | if (PolicyAttrs.isTUPolicy()) |
1177 | 0 | appendPolicySuffix("_tu"); |
1178 | 0 | else if (PolicyAttrs.isTAPolicy()) // no suffix needed |
1179 | 0 | return; |
1180 | 0 | else |
1181 | 0 | llvm_unreachable("Unhandled policy condition"); |
1182 | 0 | } |
1183 | 0 | } |
1184 | | |
1185 | 0 | SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { |
1186 | 0 | SmallVector<PrototypeDescriptor> PrototypeDescriptors; |
1187 | 0 | const StringRef Primaries("evwqom0ztulf"); |
1188 | 0 | while (!Prototypes.empty()) { |
1189 | 0 | size_t Idx = 0; |
1190 | | // Skip over complex prototype because it could contain primitive type |
1191 | | // character. |
1192 | 0 | if (Prototypes[0] == '(') |
1193 | 0 | Idx = Prototypes.find_first_of(')'); |
1194 | 0 | Idx = Prototypes.find_first_of(Primaries, Idx); |
1195 | 0 | assert(Idx != StringRef::npos); |
1196 | 0 | auto PD = PrototypeDescriptor::parsePrototypeDescriptor( |
1197 | 0 | Prototypes.slice(0, Idx + 1)); |
1198 | 0 | if (!PD) |
1199 | 0 | llvm_unreachable("Error during parsing prototype."); |
1200 | 0 | PrototypeDescriptors.push_back(*PD); |
1201 | 0 | Prototypes = Prototypes.drop_front(Idx + 1); |
1202 | 0 | } |
1203 | 0 | return PrototypeDescriptors; |
1204 | 0 | } |
1205 | | |
1206 | 0 | raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) { |
1207 | 0 | OS << "{"; |
1208 | 0 | OS << "\"" << Record.Name << "\","; |
1209 | 0 | if (Record.OverloadedName == nullptr || |
1210 | 0 | StringRef(Record.OverloadedName).empty()) |
1211 | 0 | OS << "nullptr,"; |
1212 | 0 | else |
1213 | 0 | OS << "\"" << Record.OverloadedName << "\","; |
1214 | 0 | OS << Record.PrototypeIndex << ","; |
1215 | 0 | OS << Record.SuffixIndex << ","; |
1216 | 0 | OS << Record.OverloadedSuffixIndex << ","; |
1217 | 0 | OS << (int)Record.PrototypeLength << ","; |
1218 | 0 | OS << (int)Record.SuffixLength << ","; |
1219 | 0 | OS << (int)Record.OverloadedSuffixSize << ","; |
1220 | 0 | OS << Record.RequiredExtensions << ","; |
1221 | 0 | OS << (int)Record.TypeRangeMask << ","; |
1222 | 0 | OS << (int)Record.Log2LMULMask << ","; |
1223 | 0 | OS << (int)Record.NF << ","; |
1224 | 0 | OS << (int)Record.HasMasked << ","; |
1225 | 0 | OS << (int)Record.HasVL << ","; |
1226 | 0 | OS << (int)Record.HasMaskedOffOperand << ","; |
1227 | 0 | OS << (int)Record.HasTailPolicy << ","; |
1228 | 0 | OS << (int)Record.HasMaskPolicy << ","; |
1229 | 0 | OS << (int)Record.HasFRMRoundModeOp << ","; |
1230 | 0 | OS << (int)Record.IsTuple << ","; |
1231 | 0 | OS << (int)Record.UnMaskedPolicyScheme << ","; |
1232 | 0 | OS << (int)Record.MaskedPolicyScheme << ","; |
1233 | 0 | OS << "},\n"; |
1234 | 0 | return OS; |
1235 | 0 | } |
1236 | | |
1237 | | } // end namespace RISCV |
1238 | | } // end namespace clang |