/src/spirv-tools/source/opt/scalar_analysis.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2018 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_ |
16 | | #define SOURCE_OPT_SCALAR_ANALYSIS_H_ |
17 | | |
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/basic_block.h"
#include "source/opt/instruction.h"
#include "source/opt/scalar_analysis_nodes.h"
29 | | |
30 | | namespace spvtools { |
31 | | namespace opt { |
32 | | |
33 | | class IRContext; |
34 | | class Loop; |
35 | | |
36 | | // Manager for the Scalar Evolution analysis. Creates and maintains a DAG of |
37 | | // scalar operations generated from analysing the use def graph from incoming |
38 | | // instructions. Each node is hashed as it is added so like node (for instance, |
39 | | // two induction variables i=0,i++ and j=0,j++) become the same node. After |
40 | | // creating a DAG with AnalyzeInstruction it can the be simplified into a more |
41 | | // usable form with SimplifyExpression. |
42 | | class ScalarEvolutionAnalysis { |
43 | | public: |
44 | | explicit ScalarEvolutionAnalysis(IRContext* context); |
45 | | |
46 | | // Create a unary negative node on |operand|. |
47 | | SENode* CreateNegation(SENode* operand); |
48 | | |
49 | | // Creates a subtraction between the two operands by adding |operand_1| to the |
50 | | // negation of |operand_2|. |
51 | | SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2); |
52 | | |
53 | | // Create an addition node between two operands. The |simplify| when set will |
54 | | // allow the function to return an SEConstant instead of an addition if the |
55 | | // two input operands are also constant. |
56 | | SENode* CreateAddNode(SENode* operand_1, SENode* operand_2); |
57 | | |
58 | | // Create a multiply node between two operands. |
59 | | SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2); |
60 | | |
61 | | // Create a node representing a constant integer. |
62 | | SENode* CreateConstant(int64_t integer); |
63 | | |
64 | | // Create a value unknown node, such as a load. |
65 | | SENode* CreateValueUnknownNode(const Instruction* inst); |
66 | | |
67 | | // Create a CantComputeNode. Used to exit out of analysis. |
68 | | SENode* CreateCantComputeNode(); |
69 | | |
70 | | // Create a new recurrent node with |offset| and |coefficient|, with respect |
71 | | // to |loop|. |
72 | | SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset, |
73 | | SENode* coefficient); |
74 | | |
75 | | // Construct the DAG by traversing use def chain of |inst|. |
76 | | SENode* AnalyzeInstruction(const Instruction* inst); |
77 | | |
78 | | // Simplify the |node| by grouping like terms or if contains a recurrent |
79 | | // expression, rewrite the graph so the whole DAG (from |node| down) is in |
80 | | // terms of that recurrent expression. |
81 | | // |
82 | | // For example. |
83 | | // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be |
84 | | // transformed into Rec(1,1). |
85 | | // |
86 | | // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are |
87 | | // ValueUnknown nodes (such as a load instruction). |
88 | | SENode* SimplifyExpression(SENode* node); |
89 | | |
90 | | // Add |prospective_node| into the cache and return a raw pointer to it. If |
91 | | // |prospective_node| is already in the cache just return the raw pointer. |
92 | | SENode* GetCachedOrAdd(std::unique_ptr<SENode> prospective_node); |
93 | | |
94 | | // Checks that the graph starting from |node| is invariant to the |loop|. |
95 | | bool IsLoopInvariant(const Loop* loop, const SENode* node) const; |
96 | | |
97 | | // Sets |is_gt_zero| to true if |node| represent a value always strictly |
98 | | // greater than 0. The result of |is_gt_zero| is valid only if the function |
99 | | // returns true. |
100 | | bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const; |
101 | | |
102 | | // Sets |is_ge_zero| to true if |node| represent a value greater or equals to |
103 | | // 0. The result of |is_ge_zero| is valid only if the function returns true. |
104 | | bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const; |
105 | | |
106 | | // Find the recurrent term belonging to |loop| in the graph starting from |
107 | | // |node| and return the coefficient of that recurrent term. Constant zero |
108 | | // will be returned if no recurrent could be found. |node| should be in |
109 | | // simplest form. |
110 | | SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop); |
111 | | |
112 | | // Return a rebuilt graph starting from |node| with the recurrent expression |
113 | | // belonging to |loop| being zeroed out. Returned node will be simplified. |
114 | | SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop); |
115 | | |
116 | | // Return the recurrent term belonging to |loop| if it appears in the graph |
117 | | // starting at |node| or null if it doesn't. |
118 | | SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop); |
119 | | |
120 | | SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child); |
121 | | |
122 | | // The loops in |loop_pair| will be considered the same when constructing |
123 | | // SERecurrentNode objects. This enables analysing dependencies that will be |
124 | | // created during loop fusion. |
125 | | void AddLoopsToPretendAreTheSame( |
126 | 0 | const std::pair<const Loop*, const Loop*>& loop_pair) { |
127 | 0 | pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair); |
128 | 0 | } |
129 | | |
130 | | private: |
131 | | SENode* AnalyzeConstant(const Instruction* inst); |
132 | | |
133 | | // Handles both addition and subtraction. If the |instruction| is OpISub |
134 | | // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then |
135 | | // the result will be op1+op2. |instruction| must be OpIAdd or OpISub. |
136 | | SENode* AnalyzeAddOp(const Instruction* instruction); |
137 | | |
138 | | SENode* AnalyzeMultiplyOp(const Instruction* multiply); |
139 | | |
140 | | SENode* AnalyzePhiInstruction(const Instruction* phi); |
141 | | |
142 | | IRContext* context_; |
143 | | |
144 | | // A map of instructions to SENodes. This is used to track recurrent |
145 | | // expressions as they are added when analyzing instructions. Recurrent |
146 | | // expressions come from phi nodes which by nature can include recursion so we |
147 | | // check if nodes have already been built when analyzing instructions. |
148 | | std::map<const Instruction*, SENode*> recurrent_node_map_; |
149 | | |
150 | | // On creation we create and cache the CantCompute node so we not need to |
151 | | // perform a needless create step. |
152 | | SENode* cached_cant_compute_; |
153 | | |
154 | | // Helper functor to allow two unique_ptr to nodes to be compare. Only |
155 | | // needed |
156 | | // for the unordered_set implementation. |
157 | | struct NodePointersEquality { |
158 | | bool operator()(const std::unique_ptr<SENode>& lhs, |
159 | 0 | const std::unique_ptr<SENode>& rhs) const { |
160 | 0 | return *lhs == *rhs; |
161 | 0 | } |
162 | | }; |
163 | | |
164 | | // Cache of nodes. All pointers to the nodes are references to the memory |
165 | | // managed by they set. |
166 | | std::unordered_set<std::unique_ptr<SENode>, SENodeHash, NodePointersEquality> |
167 | | node_cache_; |
168 | | |
169 | | // Loops that should be considered the same for performing analysis for loop |
170 | | // fusion. |
171 | | std::map<const Loop*, const Loop*> pretend_equal_; |
172 | | }; |
173 | | |
174 | | // Wrapping class to manipulate SENode pointer using + - * / operators. |
175 | | class SExpression { |
176 | | public: |
177 | | // Implicit on purpose ! |
178 | | SExpression(SENode* node) |
179 | 0 | : node_(node->GetParentAnalysis()->SimplifyExpression(node)), |
180 | 0 | scev_(node->GetParentAnalysis()) {} |
181 | | |
182 | 0 | inline operator SENode*() const { return node_; } |
183 | 0 | inline SENode* operator->() const { return node_; } |
184 | 0 | const SENode& operator*() const { return *node_; } |
185 | | |
186 | 0 | inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const { |
187 | 0 | return scev_; |
188 | 0 | } |
189 | | |
190 | | inline SExpression operator+(SENode* rhs) const; |
191 | | template <typename T, |
192 | | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
193 | | inline SExpression operator+(T integer) const; |
194 | | inline SExpression operator+(SExpression rhs) const; |
195 | | |
196 | | inline SExpression operator-() const; |
197 | | inline SExpression operator-(SENode* rhs) const; |
198 | | template <typename T, |
199 | | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
200 | | inline SExpression operator-(T integer) const; |
201 | | inline SExpression operator-(SExpression rhs) const; |
202 | | |
203 | | inline SExpression operator*(SENode* rhs) const; |
204 | | template <typename T, |
205 | | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
206 | | inline SExpression operator*(T integer) const; |
207 | | inline SExpression operator*(SExpression rhs) const; |
208 | | |
209 | | template <typename T, |
210 | | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
211 | | inline std::pair<SExpression, int64_t> operator/(T integer) const; |
212 | | // Try to perform a division. Returns the pair <this.node_ / rhs, division |
213 | | // remainder>. If it fails to simplify it, the function returns a |
214 | | // CanNotCompute node. |
215 | | std::pair<SExpression, int64_t> operator/(SExpression rhs) const; |
216 | | |
217 | | private: |
218 | | SENode* node_; |
219 | | ScalarEvolutionAnalysis* scev_; |
220 | | }; |
221 | | |
222 | 0 | inline SExpression SExpression::operator+(SENode* rhs) const { |
223 | 0 | return scev_->CreateAddNode(node_, rhs); |
224 | 0 | } |
225 | | |
226 | | template <typename T, |
227 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
228 | | inline SExpression SExpression::operator+(T integer) const { |
229 | | return *this + scev_->CreateConstant(integer); |
230 | | } |
231 | | |
232 | 0 | inline SExpression SExpression::operator+(SExpression rhs) const { |
233 | 0 | return *this + rhs.node_; |
234 | 0 | } |
235 | | |
236 | 0 | inline SExpression SExpression::operator-() const { |
237 | 0 | return scev_->CreateNegation(node_); |
238 | 0 | } |
239 | | |
240 | 0 | inline SExpression SExpression::operator-(SENode* rhs) const { |
241 | 0 | return *this + scev_->CreateNegation(rhs); |
242 | 0 | } |
243 | | |
244 | | template <typename T, |
245 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
246 | | inline SExpression SExpression::operator-(T integer) const { |
247 | | return *this - scev_->CreateConstant(integer); |
248 | | } |
249 | | |
250 | 0 | inline SExpression SExpression::operator-(SExpression rhs) const { |
251 | 0 | return *this - rhs.node_; |
252 | 0 | } |
253 | | |
254 | 0 | inline SExpression SExpression::operator*(SENode* rhs) const { |
255 | 0 | return scev_->CreateMultiplyNode(node_, rhs); |
256 | 0 | } |
257 | | |
258 | | template <typename T, |
259 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
260 | 0 | inline SExpression SExpression::operator*(T integer) const { |
261 | 0 | return *this * scev_->CreateConstant(integer); |
262 | 0 | } |
263 | | |
264 | 0 | inline SExpression SExpression::operator*(SExpression rhs) const { |
265 | 0 | return *this * rhs.node_; |
266 | 0 | } |
267 | | |
268 | | template <typename T, |
269 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
270 | | inline std::pair<SExpression, int64_t> SExpression::operator/(T integer) const { |
271 | | return *this / scev_->CreateConstant(integer); |
272 | | } |
273 | | |
274 | | template <typename T, |
275 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
276 | | inline SExpression operator+(T lhs, SExpression rhs) { |
277 | | return rhs + lhs; |
278 | | } |
279 | 0 | inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; } |
280 | | |
281 | | template <typename T, |
282 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
283 | | inline SExpression operator-(T lhs, SExpression rhs) { |
284 | | // NOLINTNEXTLINE(whitespace/braces) |
285 | | return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} - |
286 | | rhs; |
287 | | } |
288 | 0 | inline SExpression operator-(SENode* lhs, SExpression rhs) { |
289 | 0 | // NOLINTNEXTLINE(whitespace/braces) |
290 | 0 | return SExpression{lhs} - rhs; |
291 | 0 | } |
292 | | |
293 | | template <typename T, |
294 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
295 | | inline SExpression operator*(T lhs, SExpression rhs) { |
296 | | return rhs * lhs; |
297 | | } |
298 | 0 | inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; } |
299 | | |
300 | | template <typename T, |
301 | | typename std::enable_if<std::is_integral<T>::value, int>::type> |
302 | | inline std::pair<SExpression, int64_t> operator/(T lhs, SExpression rhs) { |
303 | | // NOLINTNEXTLINE(whitespace/braces) |
304 | | return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} / |
305 | | rhs; |
306 | | } |
307 | 0 | inline std::pair<SExpression, int64_t> operator/(SENode* lhs, SExpression rhs) { |
308 | 0 | // NOLINTNEXTLINE(whitespace/braces) |
309 | 0 | return SExpression{lhs} / rhs; |
310 | 0 | } |
311 | | |
312 | | } // namespace opt |
313 | | } // namespace spvtools |
314 | | #endif // SOURCE_OPT_SCALAR_ANALYSIS_H_ |