Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp
Line
Count
Source (jump to first uncovered line)
1
//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This provides an abstract class for HLSL code generation.  Concrete
10
// subclasses of this implement code generation for specific HLSL
11
// runtime libraries.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#include "CGHLSLRuntime.h"
16
#include "CGDebugInfo.h"
17
#include "CodeGenModule.h"
18
#include "clang/AST/Decl.h"
19
#include "clang/Basic/TargetOptions.h"
20
#include "llvm/IR/IntrinsicsDirectX.h"
21
#include "llvm/IR/Metadata.h"
22
#include "llvm/IR/Module.h"
23
#include "llvm/Support/FormatVariadic.h"
24
25
using namespace clang;
26
using namespace CodeGen;
27
using namespace clang::hlsl;
28
using namespace llvm;
29
30
namespace {
31
32
0
void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
33
  // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34
  // Assume ValVersionStr is legal here.
35
0
  VersionTuple Version;
36
0
  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
37
0
      Version.getSubminor() || !Version.getMinor()) {
38
0
    return;
39
0
  }
40
41
0
  uint64_t Major = Version.getMajor();
42
0
  uint64_t Minor = *Version.getMinor();
43
44
0
  auto &Ctx = M.getContext();
45
0
  IRBuilder<> B(M.getContext());
46
0
  MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
47
0
                                  ConstantAsMetadata::get(B.getInt32(Minor))});
48
0
  StringRef DXILValKey = "dx.valver";
49
0
  auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50
0
  DXILValMD->addOperand(Val);
51
0
}
52
0
void addDisableOptimizations(llvm::Module &M) {
53
0
  StringRef Key = "dx.disable_optimizations";
54
0
  M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55
0
}
56
// cbuffer will be translated into global variable in special address space.
57
// If translate into C,
58
// cbuffer A {
59
//   float a;
60
//   float b;
61
// }
62
// float foo() { return a + b; }
63
//
64
// will be translated into
65
//
66
// struct A {
67
//   float a;
68
//   float b;
69
// } cbuffer_A __attribute__((address_space(4)));
70
// float foo() { return cbuffer_A.a + cbuffer_A.b; }
71
//
72
// layoutBuffer will create the struct A type.
73
// replaceBuffer will replace use of global variable a and b with cbuffer_A.a
74
// and cbuffer_A.b.
75
//
76
0
void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77
0
  if (Buf.Constants.empty())
78
0
    return;
79
80
0
  std::vector<llvm::Type *> EltTys;
81
0
  for (auto &Const : Buf.Constants) {
82
0
    GlobalVariable *GV = Const.first;
83
0
    Const.second = EltTys.size();
84
0
    llvm::Type *Ty = GV->getValueType();
85
0
    EltTys.emplace_back(Ty);
86
0
  }
87
0
  Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88
0
}
89
90
0
GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91
  // Create global variable for CB.
92
0
  GlobalVariable *CBGV = new GlobalVariable(
93
0
      Buf.LayoutStruct, /*isConstant*/ true,
94
0
      GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95
0
      llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96
0
      GlobalValue::NotThreadLocal);
97
98
0
  IRBuilder<> B(CBGV->getContext());
99
0
  Value *ZeroIdx = B.getInt32(0);
100
  // Replace Const use with CB use.
101
0
  for (auto &[GV, Offset] : Buf.Constants) {
102
0
    Value *GEP =
103
0
        B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104
105
0
    assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106
0
           "constant type mismatch");
107
108
    // Replace.
109
0
    GV->replaceAllUsesWith(GEP);
110
    // Erase GV.
111
0
    GV->removeDeadConstantUsers();
112
0
    GV->eraseFromParent();
113
0
  }
114
0
  return CBGV;
115
0
}
116
117
} // namespace
118
119
0
/// Adds the variable D declared inside a cbuffer/tbuffer to buffer CB.
///
/// A `static` variable is emitted as an ordinary global and is not added to
/// the buffer. Any other variable has its global materialized, optionally
/// gets debug info, and is appended to CB.Constants with a placeholder
/// offset of UINT_MAX.
void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
  if (D->getStorageClass() == SC_Static) {
    CGM.EmitGlobal(D);
    return;
  }

  auto *ConstGV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));

  // Emit debug info for the constant when at least limited debug info is on.
  CGDebugInfo *DebugInfo = CGM.getModuleDebugInfo();
  if (DebugInfo && CGM.getCodeGenOpts().getDebugInfo() >=
                       codegenoptions::DebugInfoKind::LimitedDebugInfo)
    DebugInfo->EmitGlobalVariable(ConstGV, D);

  // FIXME: support packoffset.
  // See https://github.com/llvm/llvm-project/issues/57914.
  // Until then there is never a user-specified offset, so UINT_MAX is
  // recorded as the lower bound.
  const uint32_t UserOffset = 0;
  const bool HasUserOffset = false;
  const unsigned LowerBound = HasUserOffset ? UserOffset : UINT_MAX;
  CB.Constants.emplace_back(ConstGV, LowerBound);
}
142
143
0
void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144
0
  for (Decl *it : DC->decls()) {
145
0
    if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146
0
      addConstant(ConstDecl, CB);
147
0
    } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148
      // Nothing to do for this declaration.
149
0
    } else if (isa<FunctionDecl>(it)) {
150
      // A function within an cbuffer is effectively a top-level function,
151
      // as it only refers to globally scoped declarations.
152
0
      CGM.EmitTopLevelDecl(it);
153
0
    }
154
0
  }
155
0
}
156
157
0
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
158
0
  Buffers.emplace_back(Buffer(D));
159
0
  addBufferDecls(D, Buffers.back());
160
0
}
161
162
0
void CGHLSLRuntime::finishCodeGen() {
163
0
  auto &TargetOpts = CGM.getTarget().getTargetOpts();
164
0
  llvm::Module &M = CGM.getModule();
165
0
  Triple T(M.getTargetTriple());
166
0
  if (T.getArch() == Triple::ArchType::dxil)
167
0
    addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168
169
0
  generateGlobalCtorDtorCalls();
170
0
  if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171
0
    addDisableOptimizations(M);
172
173
0
  const DataLayout &DL = M.getDataLayout();
174
175
0
  for (auto &Buf : Buffers) {
176
0
    layoutBuffer(Buf, DL);
177
0
    GlobalVariable *GV = replaceBuffer(Buf);
178
0
    M.insertGlobalVariable(GV);
179
0
    llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180
0
                                       ? llvm::hlsl::ResourceClass::CBuffer
181
0
                                       : llvm::hlsl::ResourceClass::SRV;
182
0
    llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183
0
                                      ? llvm::hlsl::ResourceKind::CBuffer
184
0
                                      : llvm::hlsl::ResourceKind::TBuffer;
185
0
    addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
186
0
                                llvm::hlsl::ElementType::Invalid, Buf.Binding);
187
0
  }
188
0
}
189
190
/// Captures the buffer's name, its cbuffer/tbuffer kind, and the register
/// binding from D's HLSLResourceBindingAttr (which may be absent).
CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
    : Name(D->getName()),
      IsCBuffer(D->isCBuffer()),
      Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
193
194
void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
195
                                                llvm::hlsl::ResourceClass RC,
196
                                                llvm::hlsl::ResourceKind RK,
197
                                                bool IsROV,
198
                                                llvm::hlsl::ElementType ET,
199
0
                                                BufferResBinding &Binding) {
200
0
  llvm::Module &M = CGM.getModule();
201
202
0
  NamedMDNode *ResourceMD = nullptr;
203
0
  switch (RC) {
204
0
  case llvm::hlsl::ResourceClass::UAV:
205
0
    ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
206
0
    break;
207
0
  case llvm::hlsl::ResourceClass::SRV:
208
0
    ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
209
0
    break;
210
0
  case llvm::hlsl::ResourceClass::CBuffer:
211
0
    ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
212
0
    break;
213
0
  default:
214
0
    assert(false && "Unsupported buffer type!");
215
0
    return;
216
0
  }
217
0
  assert(ResourceMD != nullptr &&
218
0
         "ResourceMD must have been set by the switch above.");
219
220
0
  llvm::hlsl::FrontendResource Res(
221
0
      GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
222
0
  ResourceMD->addOperand(Res.getMetadata());
223
0
}
224
225
/// Maps a resource's element type to llvm::hlsl::ElementType.
///
/// The element type is the first template argument of ResourceTy; for a
/// vector, its scalar element type is used. Supported types are 16/32/64-bit
/// signed and unsigned integers, half, float and double. Anything else is
/// unreachable.
static llvm::hlsl::ElementType
calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
  using llvm::hlsl::ElementType;

  // TODO: We may need to update this when we add things like ByteAddressBuffer
  // that don't have a template parameter (or, indeed, an element type).
  const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
  assert(TST && "Resource types must be template specializations");
  ArrayRef<TemplateArgument> Args = TST->template_arguments();
  assert(!Args.empty() && "Resource has no element type");

  // Invalid element types would already have been diagnosed, so the first
  // argument is assumed to be a basic type or a vector of one.
  QualType ElTy = Args[0].getAsType();
  if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
    ElTy = VecTy->getElementType();

  const bool IsSigned = ElTy->isSignedIntegerType();
  if (IsSigned || ElTy->isUnsignedIntegerType()) {
    switch (Context.getTypeSize(ElTy)) {
    case 16:
      return IsSigned ? ElementType::I16 : ElementType::U16;
    case 32:
      return IsSigned ? ElementType::I32 : ElementType::U32;
    case 64:
      return IsSigned ? ElementType::I64 : ElementType::U64;
    }
  } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half)) {
    return ElementType::F16;
  } else if (ElTy->isSpecificBuiltinType(BuiltinType::Float)) {
    return ElementType::F32;
  } else if (ElTy->isSpecificBuiltinType(BuiltinType::Double)) {
    return ElementType::F64;
  }

  // TODO: We need to handle unorm/snorm float types here once we support them
  llvm_unreachable("Invalid element type for resource");
}
272
273
0
void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
274
0
  const Type *Ty = D->getType()->getPointeeOrArrayElementType();
275
0
  if (!Ty)
276
0
    return;
277
0
  const auto *RD = Ty->getAsCXXRecordDecl();
278
0
  if (!RD)
279
0
    return;
280
0
  const auto *Attr = RD->getAttr<HLSLResourceAttr>();
281
0
  if (!Attr)
282
0
    return;
283
284
0
  llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
285
0
  llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
286
0
  bool IsROV = Attr->getIsROV();
287
0
  llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
288
289
0
  BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
290
0
  addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
291
0
}
292
293
/// Decodes a register binding attribute into a register number and a space.
///
/// Without an attribute, only Space is set (to 0) and Reg keeps its default.
/// With one, the slot string minus its first character and the space string
/// minus its first five characters are parsed as base-10 integers. These are
/// presumably forms like "b3" and "space1"; the parse results are not
/// checked.
CGHLSLRuntime::BufferResBinding::BufferResBinding(
    HLSLResourceBindingAttr *Binding) {
  if (!Binding) {
    Space = 0;
    return;
  }

  llvm::APInt ParsedReg(64, 0);
  Binding->getSlot().substr(1).getAsInteger(10, ParsedReg);
  Reg = ParsedReg.getLimitedValue();

  llvm::APInt ParsedSpace(64, 0);
  Binding->getSpace().substr(5).getAsInteger(10, ParsedSpace);
  Space = ParsedSpace.getLimitedValue();
}
306
307
void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
308
0
    const FunctionDecl *FD, llvm::Function *Fn) {
309
0
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
310
0
  assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
311
0
  const StringRef ShaderAttrKindStr = "hlsl.shader";
312
0
  Fn->addFnAttr(ShaderAttrKindStr,
313
0
                ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
314
0
  if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
315
0
    const StringRef NumThreadsKindStr = "hlsl.numthreads";
316
0
    std::string NumThreadsStr =
317
0
        formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
318
0
                NumThreadsAttr->getZ());
319
0
    Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
320
0
  }
321
0
}
322
323
0
/// Materializes an input value by calling F once per component.
///
/// For a fixed vector type, F(i) is called for each lane i and the results
/// are inserted into a poison vector. For any other type, the result is the
/// single call F(0).
static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
  const auto *VT = dyn_cast<FixedVectorType>(Ty);
  if (!VT)
    return B.CreateCall(F, {B.getInt32(0)});

  Value *Vec = PoisonValue::get(Ty);
  for (unsigned Lane = 0, NumLanes = VT->getNumElements(); Lane != NumLanes;
       ++Lane) {
    Value *Component = B.CreateCall(F, {B.getInt32(Lane)});
    Vec = B.CreateInsertElement(Vec, Component, Lane);
  }
  return Vec;
}
334
335
/// Produces the value of an entry-point parameter from its semantic
/// attribute.
///
/// SV_GroupIndex is lowered to a call to dx_flattened_thread_id_in_group.
/// SV_DispatchThreadID is built lane by lane from dx_thread_id. Any other
/// attribute asserts in debug builds and yields nullptr otherwise.
llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
                                              const ParmVarDecl &D,
                                              llvm::Type *Ty) {
  assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");

  if (D.hasAttr<HLSLSV_GroupIndexAttr>())
    return B.CreateCall(FunctionCallee(
        CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group)));

  if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>())
    return buildVectorInput(B, CGM.getIntrinsic(Intrinsic::dx_thread_id), Ty);

  assert(false && "Unhandled parameter attribute");
  return nullptr;
}
351
352
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
353
0
                                      llvm::Function *Fn) {
354
0
  llvm::Module &M = CGM.getModule();
355
0
  llvm::LLVMContext &Ctx = M.getContext();
356
0
  auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
357
0
  Function *EntryFn =
358
0
      Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
359
360
  // Copy function attributes over, we have no argument or return attributes
361
  // that can be valid on the real entry.
362
0
  AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
363
0
                                              Fn->getAttributes().getFnAttrs());
364
0
  EntryFn->setAttributes(NewAttrs);
365
0
  setHLSLEntryAttributes(FD, EntryFn);
366
367
  // Set the called function as internal linkage.
368
0
  Fn->setLinkage(GlobalValue::InternalLinkage);
369
370
0
  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
371
0
  IRBuilder<> B(BB);
372
0
  llvm::SmallVector<Value *> Args;
373
  // FIXME: support struct parameters where semantics are on members.
374
  // See: https://github.com/llvm/llvm-project/issues/57874
375
0
  unsigned SRetOffset = 0;
376
0
  for (const auto &Param : Fn->args()) {
377
0
    if (Param.hasStructRetAttr()) {
378
      // FIXME: support output.
379
      // See: https://github.com/llvm/llvm-project/issues/57874
380
0
      SRetOffset = 1;
381
0
      Args.emplace_back(PoisonValue::get(Param.getType()));
382
0
      continue;
383
0
    }
384
0
    const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
385
0
    Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
386
0
  }
387
388
0
  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
389
0
  (void)CI;
390
  // FIXME: Handle codegen for return type semantics.
391
  // See: https://github.com/llvm/llvm-project/issues/57875
392
0
  B.CreateRetVoid();
393
0
}
394
395
static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
396
0
                            bool CtorOrDtor) {
397
0
  const auto *GV =
398
0
      M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
399
0
  if (!GV)
400
0
    return;
401
0
  const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
402
0
  if (!CA)
403
0
    return;
404
  // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
405
  // HLSL neither supports priorities or COMDat values, so we will check those
406
  // in an assert but not handle them.
407
408
0
  llvm::SmallVector<Function *> CtorFns;
409
0
  for (const auto &Ctor : CA->operands()) {
410
0
    if (isa<ConstantAggregateZero>(Ctor))
411
0
      continue;
412
0
    ConstantStruct *CS = cast<ConstantStruct>(Ctor);
413
414
0
    assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
415
0
           "HLSL doesn't support setting priority for global ctors.");
416
0
    assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
417
0
           "HLSL doesn't support COMDat for global ctors.");
418
0
    Fns.push_back(cast<Function>(CS->getOperand(1)));
419
0
  }
420
0
}
421
422
0
void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
423
0
  llvm::Module &M = CGM.getModule();
424
0
  SmallVector<Function *> CtorFns;
425
0
  SmallVector<Function *> DtorFns;
426
0
  gatherFunctions(CtorFns, M, true);
427
0
  gatherFunctions(DtorFns, M, false);
428
429
  // Insert a call to the global constructor at the beginning of the entry block
430
  // to externally exported functions. This is a bit of a hack, but HLSL allows
431
  // global constructors, but doesn't support driver initialization of globals.
432
0
  for (auto &F : M.functions()) {
433
0
    if (!F.hasFnAttribute("hlsl.shader"))
434
0
      continue;
435
0
    IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
436
0
    for (auto *Fn : CtorFns)
437
0
      B.CreateCall(FunctionCallee(Fn));
438
439
    // Insert global dtors before the terminator of the last instruction
440
0
    B.SetInsertPoint(F.back().getTerminator());
441
0
    for (auto *Fn : DtorFns)
442
0
      B.CreateCall(FunctionCallee(Fn));
443
0
  }
444
445
  // No need to keep global ctors/dtors for non-lib profile after call to
446
  // ctors/dtors added for entry.
447
0
  Triple T(M.getTargetTriple());
448
0
  if (T.getEnvironment() != Triple::EnvironmentType::Library) {
449
0
    if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
450
0
      GV->eraseFromParent();
451
0
    if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
452
0
      GV->eraseFromParent();
453
0
  }
454
0
}