Coverage Report

Created: 2025-07-23 06:18

/src/spirv-tools/source/opt/scalar_analysis.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2018 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/scalar_analysis.h"
16
17
#include <functional>
18
#include <string>
19
#include <utility>
20
21
#include "source/opt/ir_context.h"
22
23
// Transforms a given scalar operation instruction into a DAG representation.
24
//
25
// 1. Take an instruction and traverse its operands until we reach a
26
// constant node or an instruction which we do not know how to compute the
27
// value, such as a load.
28
//
29
// 2. Create a new node for each instruction traversed and build the nodes for
30
// the in operands of that instruction as well.
31
//
32
// 3. Add the operand nodes as children of the new node and hash it. Use the
33
// hash to see if the node is already in the cache. We ensure the children are
34
// always in sorted order so that two nodes with the same children but inserted
35
// in a different order have the same hash and so that the overloaded operator==
36
// will return true. If the node is already in the cache return the cached
37
// version instead.
38
//
39
// 4. The created DAG can then be simplified by
40
// ScalarEvolutionAnalysis::SimplifyExpression, implemented in
41
// scalar_analysis_simplification.cpp. See that file for further information on
42
// the simplification process.
43
//
44
45
namespace spvtools {
46
namespace opt {
47
48
// Out-of-class definition of the static SENode::NumberOfNodes counter,
// starting at zero. NOTE(review): presumably incremented as SENodes are
// constructed — confirm against the SENode declaration.
uint32_t SENode::NumberOfNodes = 0;
49
50
ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
    : context_(context), pretend_equal_{} {
  // Create and cache the single CantCompute node up front. Every later request
  // for a CantCompute node (CreateCantComputeNode) hands back this same cached
  // pointer.
  std::unique_ptr<SENode> cant_compute_node{new SECantCompute(this)};
  cached_cant_compute_ = GetCachedOrAdd(std::move(cant_compute_node));
}
56
57
0
// Returns a node for -|operand|. Constants are folded; a CantCompute operand
// makes the result CantCompute.
SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
  // If operand is can't compute then the whole graph is can't compute.
  if (operand->IsCantCompute()) return CreateCantComputeNode();

  if (operand->GetType() == SENode::Constant) {
    // Fold the negation. It is done in uint64_t so that negating INT64_MIN
    // wraps (two's complement) instead of being signed-overflow undefined
    // behaviour.
    const uint64_t bits = static_cast<uint64_t>(
        operand->AsSEConstantNode()->FoldToSingleValue());
    return CreateConstant(static_cast<int64_t>(uint64_t{0} - bits));
  }

  std::unique_ptr<SENode> negation_node{new SENegative(this)};
  negation_node->AddChild(operand);
  return GetCachedOrAdd(std::move(negation_node));
}
68
69
0
// Returns the (cached) constant node holding |integer|. If a constant node with
// the same value already exists, that node is returned instead of a new one.
SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
  std::unique_ptr<SENode> constant_node{new SEConstantNode(this, integer)};
  return GetCachedOrAdd(std::move(constant_node));
}
73
74
// Builds the recurrent expression {offset, +, coefficient} for |loop|. If
// |loop| has been registered in |pretend_equal_| as equal to another loop, the
// node is attached to that loop instead.
SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
    const Loop* loop, SENode* offset, SENode* coefficient) {
  assert(loop && "Recurrent add expressions must have a valid loop.");

  // If operands are can't compute then the whole graph is can't compute.
  if (offset->IsCantCompute() || coefficient->IsCantCompute())
    return CreateCantComputeNode();

  // Substitute the loop |loop| pretends to be equal to, if any. find() avoids
  // operator[]'s side effect of inserting a null mapping for every loop that
  // is merely queried.
  const Loop* loop_to_use = loop;
  auto pretend_itr = pretend_equal_.find(loop);
  if (pretend_itr != pretend_equal_.end() && pretend_itr->second) {
    loop_to_use = pretend_itr->second;
  }

  std::unique_ptr<SERecurrentNode> phi_node{
      new SERecurrentNode(this, loop_to_use)};
  phi_node->AddOffset(offset);
  phi_node->AddCoefficient(coefficient);

  return GetCachedOrAdd(std::move(phi_node));
}
96
97
// Builds the node for an OpIMul instruction: the product of the nodes built
// for its two in-operands.
SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
    const Instruction* multiply) {
  assert(multiply->opcode() == spv::Op::OpIMul &&
         "Multiply node did not come from a multiply instruction");
  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // Analyse the left operand fully before the right one; analysis creates and
  // caches nodes, so the order is kept fixed.
  const Instruction* lhs_def =
      def_use->GetDef(multiply->GetSingleWordInOperand(0));
  SENode* lhs = AnalyzeInstruction(lhs_def);

  const Instruction* rhs_def =
      def_use->GetDef(multiply->GetSingleWordInOperand(1));
  SENode* rhs = AnalyzeInstruction(rhs_def);

  return CreateMultiplyNode(lhs, rhs);
}
110
111
// Returns a node for |operand_1| * |operand_2|. Two constants are folded; a
// CantCompute operand makes the result CantCompute.
SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
                                                    SENode* operand_2) {
  // If operands are can't compute then the whole graph is can't compute.
  if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
    return CreateCantComputeNode();

  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    // Fold the product in uint64_t so that overflow wraps modulo 2^64 instead
    // of being signed-overflow undefined behaviour.
    const uint64_t lhs = static_cast<uint64_t>(
        operand_1->AsSEConstantNode()->FoldToSingleValue());
    const uint64_t rhs = static_cast<uint64_t>(
        operand_2->AsSEConstantNode()->FoldToSingleValue());
    return CreateConstant(static_cast<int64_t>(lhs * rhs));
  }

  std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};

  multiply_node->AddChild(operand_1);
  multiply_node->AddChild(operand_2);

  return GetCachedOrAdd(std::move(multiply_node));
}
130
131
// Returns a node for |operand_1| - |operand_2|.
SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
                                                   SENode* operand_2) {
  // Fold if both operands are constant. The difference is computed in
  // uint64_t so that overflow wraps modulo 2^64 instead of being
  // signed-overflow undefined behaviour.
  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    const uint64_t lhs = static_cast<uint64_t>(
        operand_1->AsSEConstantNode()->FoldToSingleValue());
    const uint64_t rhs = static_cast<uint64_t>(
        operand_2->AsSEConstantNode()->FoldToSingleValue());
    return CreateConstant(static_cast<int64_t>(lhs - rhs));
  }

  // Otherwise represent a - b as a + (-b). CreateNegation and CreateAddNode
  // both propagate CantCompute operands.
  return CreateAddNode(operand_1, CreateNegation(operand_2));
}
142
143
// Returns a node for |operand_1| + |operand_2|. Two constants are folded; a
// CantCompute operand makes the result CantCompute.
SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
                                               SENode* operand_2) {
  // Fold if both operands are constant. The sum is computed in uint64_t so
  // that overflow wraps modulo 2^64 instead of being signed-overflow
  // undefined behaviour.
  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    const uint64_t lhs = static_cast<uint64_t>(
        operand_1->AsSEConstantNode()->FoldToSingleValue());
    const uint64_t rhs = static_cast<uint64_t>(
        operand_2->AsSEConstantNode()->FoldToSingleValue());
    return CreateConstant(static_cast<int64_t>(lhs + rhs));
  }

  // If operands are can't compute then the whole graph is can't compute.
  if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
    return CreateCantComputeNode();

  std::unique_ptr<SENode> add_node{new SEAddNode(this)};

  add_node->AddChild(operand_1);
  add_node->AddChild(operand_2);

  return GetCachedOrAdd(std::move(add_node));
}
163
164
0
// Builds (or fetches) the scalar-evolution node for |inst|, dispatching on its
// opcode. Opcodes that are not modelled become value-unknown nodes.
SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
  // Phis registered in |recurrent_node_map_| (including those still being
  // built) are returned directly; this is what breaks recursion through
  // loop-carried values.
  const auto known = recurrent_node_map_.find(inst);
  if (known != recurrent_node_map_.end()) return known->second;

  switch (inst->opcode()) {
    case spv::Op::OpPhi:
      return AnalyzePhiInstruction(inst);
    case spv::Op::OpConstant:
    case spv::Op::OpConstantNull:
      return AnalyzeConstant(inst);
    case spv::Op::OpISub:
    case spv::Op::OpIAdd:
      return AnalyzeAddOp(inst);
    case spv::Op::OpIMul:
      return AnalyzeMultiplyOp(inst);
    default:
      return CreateValueUnknownNode(inst);
  }
}
196
197
0
// Builds a constant node for an OpConstant / OpConstantNull instruction.
// OpConstantNull is treated as 0. Only single-word integer constants known to
// the constant manager are supported; anything else is CantCompute.
SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
  if (inst->opcode() == spv::Op::OpConstantNull) return CreateConstant(0);

  assert(inst->opcode() == spv::Op::OpConstant);
  assert(inst->NumInOperands() == 1);

  // Look up the instruction in the constant manager.
  const analysis::Constant* constant =
      context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
  const analysis::IntConstant* int_constant =
      constant ? constant->AsIntConstant() : nullptr;

  // Bail out on unknown or non-integer constants, and on integers wider than
  // one word (e.g. 64-bit).
  if (int_constant == nullptr || int_constant->words().size() != 1) {
    return CreateCantComputeNode();
  }

  // Sign- or zero-extend the 32-bit payload according to the integer type.
  const int64_t value =
      int_constant->type()->AsInteger()->IsSigned()
          ? static_cast<int64_t>(int_constant->GetS32BitValue())
          : static_cast<int64_t>(int_constant->GetU32BitValue());
  return CreateConstant(value);
}
224
225
// Handles both addition and subtraction. If the |sub| flag is set then the
226
// addition will be op1+(-op2) otherwise op1+op2.
227
0
SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
228
0
  assert((inst->opcode() == spv::Op::OpIAdd ||
229
0
          inst->opcode() == spv::Op::OpISub) &&
230
0
         "Add node must be created from a OpIAdd or OpISub instruction");
231
232
0
  analysis::DefUseManager* def_use = context_->get_def_use_mgr();
233
234
0
  SENode* op1 =
235
0
      AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
236
237
0
  SENode* op2 =
238
0
      AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
239
240
  // To handle subtraction we wrap the second operand in a unary negation node.
241
0
  if (inst->opcode() == spv::Op::OpISub) {
242
0
    op2 = CreateNegation(op2);
243
0
  }
244
245
0
  return CreateAddNode(op1, op2);
246
0
}
247
248
0
// Builds a recurrent expression {offset, +, step} for a loop-header phi of the
// form phi(init from preheader, phi + step from latch), where step is
// invariant in the loop. Any other shape yields CantCompute.
SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
  // The phi should only have two incoming value pairs.
  if (phi->NumInOperands() != 4) {
    return CreateCantComputeNode();
  }

  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // Get the basic block this instruction belongs to.
  BasicBlock* basic_block =
      context_->get_instr_block(const_cast<Instruction*>(phi));

  // And then the function that the basic block belongs to.
  Function* function = basic_block->GetParent();

  // Use the function to get the loop descriptor.
  LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);

  // We only handle phis in loops at the moment.
  if (!loop_descriptor) return CreateCantComputeNode();

  // Get the innermost loop which this block belongs to.
  Loop* loop = (*loop_descriptor)[basic_block->id()];

  // The phi must be in the header of a loop that has both a preheader and a
  // latch block; otherwise exit out.
  if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
      loop->GetHeaderBlock() != basic_block)
    return recurrent_node_map_[phi] = CreateCantComputeNode();

  // Attach the node to the loop |loop| pretends to be equal to, if any. find()
  // avoids operator[]'s side effect of inserting a null mapping for every loop
  // that is merely queried.
  const Loop* loop_to_use = loop;
  auto pretend_itr = pretend_equal_.find(loop);
  if (pretend_itr != pretend_equal_.end() && pretend_itr->second) {
    loop_to_use = pretend_itr->second;
  }
  std::unique_ptr<SERecurrentNode> phi_node{
      new SERecurrentNode(this, loop_to_use)};

  // We add the node to this map to allow it to be returned before the node is
  // fully built. This is needed as the subsequent call to AnalyzeInstruction
  // could lead back to this |phi| instruction so we return the pointer
  // immediately in AnalyzeInstruction to break the recursion.
  // NOTE(review): the early returns below destroy |phi_node|, as does the final
  // GetCachedOrAdd when an equal node is already cached. Any node cached during
  // the recursive analysis with phi_node.get() as a child (e.g. the step + phi
  // Add node) would then hold a dangling pointer. Confirm this is unreachable
  // in practice.
  recurrent_node_map_[phi] = phi_node.get();

  // Traverse the (value, incoming block) operand pairs of the instruction and
  // create new nodes for each value.
  for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
    uint32_t value_id = phi->GetSingleWordInOperand(i);
    uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);

    Instruction* value_inst = def_use->GetDef(value_id);
    SENode* value_node = AnalyzeInstruction(value_inst);

    // If any operand is CantCompute then the whole graph is CantCompute.
    if (value_node->IsCantCompute())
      return recurrent_node_map_[phi] = CreateCantComputeNode();

    // If the value is coming from the preheader block then the value is the
    // initial value of the phi.
    if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
      phi_node->AddOffset(value_node);
    } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
      // Assumed to be in the form of step + phi.
      if (value_node->GetType() != SENode::Add)
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      SENode* step_node = nullptr;
      SENode* phi_operand = nullptr;
      SENode* operand_1 = value_node->GetChild(0);
      SENode* operand_2 = value_node->GetChild(1);

      // Find which node is the step term.
      if (!operand_1->AsSERecurrentNode())
        step_node = operand_1;
      else if (!operand_2->AsSERecurrentNode())
        step_node = operand_2;

      // Find which node is the recurrent expression.
      if (operand_1->AsSERecurrentNode())
        phi_operand = operand_1;
      else if (operand_2->AsSERecurrentNode())
        phi_operand = operand_2;

      // If it is not in the form step + phi exit out.
      if (!(step_node && phi_operand))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      // If the phi operand is not the same phi node exit out.
      if (phi_operand != phi_node.get())
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      // The step must not vary inside the loop.
      if (!IsLoopInvariant(loop, step_node))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      phi_node->AddCoefficient(step_node);
    }
  }

  // Once the node is fully built we update the map with the version from the
  // cache (if it has already been added to the cache).
  return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
}
350
351
// Returns the node standing for the opaque value produced by |inst|. The node
// is identified by the instruction's result id, so repeated requests for the
// same instruction yield the same cached node.
SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
    const Instruction* inst) {
  return GetCachedOrAdd(
      std::unique_ptr<SENode>(new SEValueUnknown(this, inst->result_id())));
}
357
358
0
// Returns the analysis-wide CantCompute node. It is built and cached once by
// the constructor, so "creating" one just hands back that instance.
SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
  SENode* const shared_node = cached_cant_compute_;
  return shared_node;
}
361
362
// Add the created node into the cache of nodes. If it already exists return it.
363
SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
364
0
    std::unique_ptr<SENode> prospective_node) {
365
0
  auto itr = node_cache_.find(prospective_node);
366
0
  if (itr != node_cache_.end()) {
367
0
    return (*itr).get();
368
0
  }
369
370
0
  SENode* raw_ptr_to_node = prospective_node.get();
371
0
  node_cache_.insert(std::move(prospective_node));
372
0
  return raw_ptr_to_node;
373
0
}
374
375
bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
376
0
                                              const SENode* node) const {
377
0
  for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
378
0
    if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
379
0
      const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
380
381
      // If the loop which the recurrent expression belongs to is either |loop
382
      // or a nested loop inside |loop| then we assume it is variant.
383
0
      if (loop->IsInsideLoop(header)) {
384
0
        return false;
385
0
      }
386
0
    } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
387
      // If the instruction is inside the loop we conservatively assume it is
388
      // loop variant.
389
0
      if (loop->IsInsideLoop(unknown->ResultId())) return false;
390
0
    }
391
0
  }
392
393
0
  return true;
394
0
}
395
396
// Returns the coefficient (step) of the first recurrent expression belonging to
// |loop| found in the DAG rooted at |node|, or the constant 0 if |loop| has no
// recurrent term there. The search is the one performed by GetRecurrentTerm.
SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
    SENode* node, const Loop* loop) {
  SERecurrentNode* term = GetRecurrentTerm(node, loop);
  if (term == nullptr) return CreateConstant(0);
  return term->GetCoefficient();
}
407
408
// Returns a simplified copy of the add node |parent| in which every child equal
// (by pointer) to |old_child| is replaced by |new_child|. Parents that are not
// add nodes are returned unchanged.
SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
                                                 SENode* old_child,
                                                 SENode* new_child) {
  // Only handles add.
  if (parent->GetType() != SENode::Add) return parent;

  std::unique_ptr<SENode> rebuilt{new SEAddNode(this)};
  for (SENode* child : *parent) {
    rebuilt->AddChild(child == old_child ? new_child : child);
  }

  return SimplifyExpression(GetCachedOrAdd(std::move(rebuilt)));
}
430
431
// Rebuild the |node| eliminating, if it exists, the recurrent term which
432
// belongs to the |loop|.
433
SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
434
0
    SENode* node, const Loop* loop) {
435
  // If the node is already a recurrent expression belonging to loop then just
436
  // return the offset.
437
0
  SERecurrentNode* recurrent = node->AsSERecurrentNode();
438
0
  if (recurrent) {
439
0
    if (recurrent->GetLoop() == loop) {
440
0
      return recurrent->GetOffset();
441
0
    } else {
442
0
      return node;
443
0
    }
444
0
  }
445
446
0
  std::vector<SENode*> new_children;
447
  // Otherwise find the recurrent node in the children of this node.
448
0
  for (auto itr : *node) {
449
0
    recurrent = itr->AsSERecurrentNode();
450
0
    if (recurrent && recurrent->GetLoop() == loop) {
451
0
      new_children.push_back(recurrent->GetOffset());
452
0
    } else {
453
0
      new_children.push_back(itr);
454
0
    }
455
0
  }
456
457
0
  std::unique_ptr<SENode> add_node{new SEAddNode(this)};
458
0
  for (SENode* child : new_children) {
459
0
    add_node->AddChild(child);
460
0
  }
461
462
0
  return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
463
0
}
464
465
// Return the recurrent term belonging to |loop| if it appears in the graph
466
// starting at |node| or null if it doesn't.
467
SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
468
0
                                                           const Loop* loop) {
469
0
  for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
470
0
    SERecurrentNode* rec = itr->AsSERecurrentNode();
471
0
    if (rec && rec->GetLoop() == loop) {
472
0
      return rec;
473
0
    }
474
0
  }
475
0
  return nullptr;
476
0
}
477
0
std::string SENode::AsString() const {
478
0
  switch (GetType()) {
479
0
    case Constant:
480
0
      return "Constant";
481
0
    case RecurrentAddExpr:
482
0
      return "RecurrentAddExpr";
483
0
    case Add:
484
0
      return "Add";
485
0
    case Negative:
486
0
      return "Negative";
487
0
    case Multiply:
488
0
      return "Multiply";
489
0
    case ValueUnknown:
490
0
      return "Value Unknown";
491
0
    case CanNotCompute:
492
0
      return "Can not compute";
493
0
  }
494
0
  return "NULL";
495
0
}
496
497
0
// Structural equality used by the node cache.
//  - Nodes must have the same kind and the same number of children.
//  - Children are compared by pointer identity.
//  - Recurrent nodes compare offset, coefficient and loop explicitly.
//  - Value-unknown nodes compare result ids.
//  - Constants compare their folded values.
bool SENode::operator==(const SENode& other) const {
  if (GetType() != other.GetType()) return false;
  if (other.GetChildren().size() != children_.size()) return false;

  const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
  if (this_as_recurrent == nullptr) {
    // Position-by-position pointer comparison of the children.
    if (other.GetChildren() != children_) return false;
  } else {
    // The child vector of a recurrent node is sorted, which loses which child
    // is the offset and which the coefficient, so compare those (and the
    // loop) explicitly.
    const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();

    // The types already match, so this cast cannot fail.
    assert(other_as_recurrent);

    const bool same_recurrence = this_as_recurrent->GetCoefficient() ==
                                     other_as_recurrent->GetCoefficient() &&
                                 this_as_recurrent->GetOffset() ==
                                     other_as_recurrent->GetOffset() &&
                                 this_as_recurrent->GetLoop() ==
                                     other_as_recurrent->GetLoop();
    if (!same_recurrence) return false;
  }

  // Value-unknown nodes are equal only if they come from the same instruction.
  if (GetType() == SENode::ValueUnknown &&
      AsSEValueUnknown()->ResultId() != other.AsSEValueUnknown()->ResultId()) {
    return false;
  }

  // Constant nodes must additionally hold the same value.
  const SEConstantNode* this_as_constant = AsSEConstantNode();
  return this_as_constant == nullptr ||
         this_as_constant->FoldToSingleValue() ==
             other.AsSEConstantNode()->FoldToSingleValue();
}
546
547
0
// Logical negation of operator==.
bool SENode::operator!=(const SENode& other) const {
  return !operator==(other);
}
548
549
namespace {
// Helpers that append an integral value to the 32-bit code-unit string used
// for hashing:
//  - 4-byte values are appended as one code unit.
//  - 8-byte values (e.g. int64_t, or pointers reinterpreted as uintptr_t) are
//    appended as two code units, high half first.
// PushToString deduces T and dispatches on sizeof(T).

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 64-bit values: split into two 32-bit code units, most significant first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) {
    const uint32_t high_word = static_cast<uint32_t>(id >> 32);
    const uint32_t low_word = static_cast<uint32_t>(id);
    str->push_back(high_word);
    str->push_back(low_word);
  }
};

// 32-bit values: a single code unit.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) {
    str->push_back(static_cast<uint32_t>(id));
  }
};

template <typename T>
void PushToString(T id, std::u32string* str) {
  using Impl = PushToStringImpl<T, sizeof(T)>;
  Impl pusher;
  pusher(id, str);
}

}  // namespace
580
581
// Implements the hashing of SENodes.
582
0
size_t SENodeHash::operator()(const SENode* node) const {
583
  // Concatenate the terms into a string which we can hash.
584
0
  std::u32string hash_string{};
585
586
  // Hashing the type as a string is safer than hashing the enum as the enum is
587
  // very likely to collide with constants.
588
0
  for (char ch : node->AsString()) {
589
0
    hash_string.push_back(static_cast<char32_t>(ch));
590
0
  }
591
592
  // We just ignore the literal value unless it is a constant.
593
0
  if (node->GetType() == SENode::Constant)
594
0
    PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
595
596
0
  const SERecurrentNode* recurrent = node->AsSERecurrentNode();
597
598
  // If we're dealing with a recurrent expression hash the loop as well so that
599
  // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
600
0
  if (recurrent) {
601
0
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
602
0
                 &hash_string);
603
604
    // Recurrent expressions can't be hashed using the normal method as the
605
    // order of coefficient and offset matters to the hash.
606
0
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
607
0
                 &hash_string);
608
0
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
609
0
                 &hash_string);
610
611
0
    return std::hash<std::u32string>{}(hash_string);
612
0
  }
613
614
  // Hash the result id of the original instruction which created this node if
615
  // it is a value unknown node.
616
0
  if (node->GetType() == SENode::ValueUnknown) {
617
0
    PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
618
0
  }
619
620
  // Hash the pointers of the child nodes, each SENode has a unique pointer
621
  // associated with it.
622
0
  const std::vector<SENode*>& children = node->GetChildren();
623
0
  for (const SENode* child : children) {
624
0
    PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
625
0
  }
626
627
0
  return std::hash<std::u32string>{}(hash_string);
628
0
}
629
630
// This overload is the actual overload used by the node_cache_ set.
631
0
size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
632
0
  return this->operator()(node.get());
633
0
}
634
635
0
void SENode::DumpDot(std::ostream& out, bool recurse) const {
636
0
  size_t unique_id = std::hash<const SENode*>{}(this);
637
0
  out << unique_id << " [label=\"" << AsString() << " ";
638
0
  if (GetType() == SENode::Constant) {
639
0
    out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
640
0
  }
641
0
  out << "\"]\n";
642
0
  for (const SENode* child : children_) {
643
0
    size_t child_unique_id = std::hash<const SENode*>{}(child);
644
0
    out << unique_id << " -> " << child_unique_id << " \n";
645
0
    if (recurse) child->DumpDot(out, true);
646
0
  }
647
0
}
648
649
namespace {
650
class IsGreaterThanZero {
651
 public:
652
0
  explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
653
654
  // Determine if the value of |node| is always strictly greater than zero if
655
  // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
656
  // true. It returns true is the evaluation was able to conclude something, in
657
  // which case the result is stored in |result|.
658
  // The algorithm work by going through all the nodes and determine the
659
  // sign of each of them.
660
0
  bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
661
0
    *result = false;
662
0
    switch (Visit(node)) {
663
0
      case Signedness::kPositiveOrNegative: {
664
0
        return false;
665
0
      }
666
0
      case Signedness::kStrictlyNegative: {
667
0
        *result = false;
668
0
        break;
669
0
      }
670
0
      case Signedness::kNegative: {
671
0
        if (!or_equal_zero) {
672
0
          return false;
673
0
        }
674
0
        *result = false;
675
0
        break;
676
0
      }
677
0
      case Signedness::kStrictlyPositive: {
678
0
        *result = true;
679
0
        break;
680
0
      }
681
0
      case Signedness::kPositive: {
682
0
        if (!or_equal_zero) {
683
0
          return false;
684
0
        }
685
0
        *result = true;
686
0
        break;
687
0
      }
688
0
    }
689
0
    return true;
690
0
  }
691
692
 private:
693
  // Sign information derived for a node while evaluating an expression.
  enum class Signedness {
    kPositiveOrNegative,  // Sign unknown: the value may be <0, 0 or >0.
    kStrictlyNegative,    // Yield a value strictly less than 0.
    kNegative,            // Yield a value less than or equal to 0.
    kStrictlyPositive,    // Yield a value strictly greater than 0.
    kPositive             // Yield a value greater than or equal to 0.
  };

  // Combine the signedness of two operands according to the arithmetic rules
  // of a given operator (see GetAddCombiner / GetMulCombiner).
  using Combiner = std::function<Signedness(Signedness, Signedness)>;
703
704
  // Returns a functor to interpret the signedness of 2 expressions as if they
705
  // were added.
706
0
  Combiner GetAddCombiner() const {
707
0
    return [](Signedness lhs, Signedness rhs) {
708
0
      switch (lhs) {
709
0
        case Signedness::kPositiveOrNegative:
710
0
          break;
711
0
        case Signedness::kStrictlyNegative:
712
0
          if (rhs == Signedness::kStrictlyNegative ||
713
0
              rhs == Signedness::kNegative)
714
0
            return lhs;
715
0
          break;
716
0
        case Signedness::kNegative: {
717
0
          if (rhs == Signedness::kStrictlyNegative)
718
0
            return Signedness::kStrictlyNegative;
719
0
          if (rhs == Signedness::kNegative) return Signedness::kNegative;
720
0
          break;
721
0
        }
722
0
        case Signedness::kStrictlyPositive: {
723
0
          if (rhs == Signedness::kStrictlyPositive ||
724
0
              rhs == Signedness::kPositive) {
725
0
            return Signedness::kStrictlyPositive;
726
0
          }
727
0
          break;
728
0
        }
729
0
        case Signedness::kPositive: {
730
0
          if (rhs == Signedness::kStrictlyPositive)
731
0
            return Signedness::kStrictlyPositive;
732
0
          if (rhs == Signedness::kPositive) return Signedness::kPositive;
733
0
          break;
734
0
        }
735
0
      }
736
0
      return Signedness::kPositiveOrNegative;
737
0
    };
738
0
  }
739
740
  // Returns a functor to interpret the signedness of 2 expressions as if they
741
  // were multiplied.
742
0
  Combiner GetMulCombiner() const {
743
0
    return [](Signedness lhs, Signedness rhs) {
744
0
      switch (lhs) {
745
0
        case Signedness::kPositiveOrNegative:
746
0
          break;
747
0
        case Signedness::kStrictlyNegative: {
748
0
          switch (rhs) {
749
0
            case Signedness::kPositiveOrNegative: {
750
0
              break;
751
0
            }
752
0
            case Signedness::kStrictlyNegative: {
753
0
              return Signedness::kStrictlyPositive;
754
0
            }
755
0
            case Signedness::kNegative: {
756
0
              return Signedness::kPositive;
757
0
            }
758
0
            case Signedness::kStrictlyPositive: {
759
0
              return Signedness::kStrictlyNegative;
760
0
            }
761
0
            case Signedness::kPositive: {
762
0
              return Signedness::kNegative;
763
0
            }
764
0
          }
765
0
          break;
766
0
        }
767
0
        case Signedness::kNegative: {
768
0
          switch (rhs) {
769
0
            case Signedness::kPositiveOrNegative: {
770
0
              break;
771
0
            }
772
0
            case Signedness::kStrictlyNegative:
773
0
            case Signedness::kNegative: {
774
0
              return Signedness::kPositive;
775
0
            }
776
0
            case Signedness::kStrictlyPositive:
777
0
            case Signedness::kPositive: {
778
0
              return Signedness::kNegative;
779
0
            }
780
0
          }
781
0
          break;
782
0
        }
783
0
        case Signedness::kStrictlyPositive: {
784
0
          return rhs;
785
0
        }
786
0
        case Signedness::kPositive: {
787
0
          switch (rhs) {
788
0
            case Signedness::kPositiveOrNegative: {
789
0
              break;
790
0
            }
791
0
            case Signedness::kStrictlyNegative:
792
0
            case Signedness::kNegative: {
793
0
              return Signedness::kNegative;
794
0
            }
795
0
            case Signedness::kStrictlyPositive:
796
0
            case Signedness::kPositive: {
797
0
              return Signedness::kPositive;
798
0
            }
799
0
          }
800
0
          break;
801
0
        }
802
0
      }
803
0
      return Signedness::kPositiveOrNegative;
804
0
    };
805
0
  }
806
807
0
  // Dispatches |node| to the Visit overload matching its concrete node type.
  // Add and Multiply nodes are folded over their children with the matching
  // combiner. Any node type not listed falls back to kPositiveOrNegative.
  Signedness Visit(const SENode* node) {
    switch (node->GetType()) {
      case SENode::Add:
        return VisitExpr(node, GetAddCombiner());
      case SENode::Multiply:
        return VisitExpr(node, GetMulCombiner());
      case SENode::Negative:
        return Visit(node->AsSENegative());
      case SENode::RecurrentAddExpr:
        return Visit(node->AsSERecurrentNode());
      case SENode::Constant:
        return Visit(node->AsSEConstantNode());
      case SENode::ValueUnknown:
        return Visit(node->AsSEValueUnknown());
      case SENode::CanNotCompute:
        return Visit(node->AsSECantCompute());
    }
    return Signedness::kPositiveOrNegative;
  }
833
834
  // Returns the signedness of a constant |node|.
835
0
  Signedness Visit(const SEConstantNode* node) {
836
0
    if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
837
0
    if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
838
0
    if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
839
0
    return Signedness::kPositiveOrNegative;
840
0
  }
841
842
  // Returns the signedness of an unknown |node| based on its type.
843
0
  Signedness Visit(const SEValueUnknown* node) {
844
0
    Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
845
0
    analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
846
0
    assert(type && "Can't retrieve a type for the instruction");
847
0
    analysis::Integer* int_type = type->AsInteger();
848
0
    assert(type && "Can't retrieve an integer type for the instruction");
849
0
    return int_type->IsSigned() ? Signedness::kPositiveOrNegative
850
0
                                : Signedness::kPositive;
851
0
  }
852
853
  // Returns the signedness of a recurring expression.
854
0
  Signedness Visit(const SERecurrentNode* node) {
855
0
    Signedness coeff_sign = Visit(node->GetCoefficient());
856
    // SERecurrentNode represent an affine expression in the range [0,
857
    // loop_bound], so the result cannot be strictly positive or negative.
858
0
    switch (coeff_sign) {
859
0
      default:
860
0
        break;
861
0
      case Signedness::kStrictlyNegative:
862
0
        coeff_sign = Signedness::kNegative;
863
0
        break;
864
0
      case Signedness::kStrictlyPositive:
865
0
        coeff_sign = Signedness::kPositive;
866
0
        break;
867
0
    }
868
0
    return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
869
0
  }
870
871
  // Returns the signedness of a negation |node|.
872
0
  Signedness Visit(const SENegative* node) {
873
0
    switch (Visit(*node->begin())) {
874
0
      case Signedness::kPositiveOrNegative: {
875
0
        return Signedness::kPositiveOrNegative;
876
0
      }
877
0
      case Signedness::kStrictlyNegative: {
878
0
        return Signedness::kStrictlyPositive;
879
0
      }
880
0
      case Signedness::kNegative: {
881
0
        return Signedness::kPositive;
882
0
      }
883
0
      case Signedness::kStrictlyPositive: {
884
0
        return Signedness::kStrictlyNegative;
885
0
      }
886
0
      case Signedness::kPositive: {
887
0
        return Signedness::kNegative;
888
0
      }
889
0
    }
890
0
    return Signedness::kPositiveOrNegative;
891
0
  }
892
893
0
  // Nothing is known about an expression that could not be computed, so its
  // sign is reported as unknown.
  Signedness Visit(const SECantCompute* /*node*/) {
    return Signedness::kPositiveOrNegative;
  }
896
897
  // Returns the signedness of a binary expression by using the combiner
898
  // |reduce|.
899
  Signedness VisitExpr(
900
      const SENode* node,
901
0
      std::function<Signedness(Signedness, Signedness)> reduce) {
902
0
    Signedness result = Visit(*node->begin());
903
0
    for (const SENode* operand : make_range(++node->begin(), node->end())) {
904
0
      if (result == Signedness::kPositiveOrNegative) {
905
0
        return Signedness::kPositiveOrNegative;
906
0
      }
907
0
      result = reduce(result, Visit(operand));
908
0
    }
909
0
    return result;
910
0
  }
911
912
  IRContext* context_;
913
};
914
}  // namespace
915
916
bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
917
0
                                                      bool* is_gt_zero) const {
918
0
  return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
919
0
}
920
921
bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
922
0
    SENode* node, bool* is_ge_zero) const {
923
0
  return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
924
0
}
925
926
namespace {
927
928
// Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
929
// if |node| is not in the chain, returns the original chain.
930
SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
931
0
                                       const SENode* node) {
932
0
  SENode* lhs = mul->GetChildren()[0];
933
0
  SENode* rhs = mul->GetChildren()[1];
934
0
  if (lhs == node) {
935
0
    return rhs;
936
0
  }
937
0
  if (rhs == node) {
938
0
    return lhs;
939
0
  }
940
0
  if (lhs->AsSEMultiplyNode()) {
941
0
    SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
942
0
    if (res != lhs)
943
0
      return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
944
0
  }
945
0
  if (rhs->AsSEMultiplyNode()) {
946
0
    SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
947
0
    if (res != rhs)
948
0
      return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
949
0
  }
950
951
0
  return mul;
952
0
}
953
}  // namespace
954
955
// Divides this expression by |rhs_wrapper| and returns {quotient, remainder}.
// Handled cases:
//  - both operands constant: integer quotient, remainder from operator%;
//  - this expression is a multiply chain that contains the divisor node
//    itself: the chain with that factor removed, remainder 0.
// In every other case, including division by a constant 0 and the
// unrepresentable INT64_MIN / -1, a can't-compute node with remainder 0 is
// returned.
std::pair<SExpression, int64_t> SExpression::operator/(
    SExpression rhs_wrapper) const {
  SENode* lhs = node_;
  SENode* rhs = rhs_wrapper.node_;
  // Check for division by 0.
  if (rhs->AsSEConstantNode() &&
      !rhs->AsSEConstantNode()->FoldToSingleValue()) {
    return {scev_->CreateCantComputeNode(), 0};
  }

  // Trivial case.
  if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
    int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
    int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
    // INT64_MIN / -1 (and INT64_MIN % -1) overflow int64_t: undefined
    // behavior that traps on common targets. Treat it as not computable.
    if (lhs_value == INT64_MIN && rhs_value == -1) {
      return {scev_->CreateCantComputeNode(), 0};
    }
    return {scev_->CreateConstant(lhs_value / rhs_value),
            lhs_value % rhs_value};
  }

  // Look for a "c U / U" pattern: dividing a product by one of its factors.
  if (lhs->AsSEMultiplyNode()) {
    assert(lhs->GetChildren().size() == 2 &&
           "More than 2 operand for a multiply node.");
    SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
    if (res != lhs) {
      return {res, 0};
    }
  }

  return {scev_->CreateCantComputeNode(), 0};
}
985
986
}  // namespace opt
987
}  // namespace spvtools