/src/spirv-tools/source/opt/pass_manager.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2016 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef SOURCE_OPT_PASS_MANAGER_H_ |
16 | | #define SOURCE_OPT_PASS_MANAGER_H_ |
17 | | |
18 | | #include <memory> |
19 | | #include <ostream> |
20 | | #include <utility> |
21 | | #include <vector> |
22 | | |
23 | | #include "source/opt/log.h" |
24 | | #include "source/opt/module.h" |
25 | | #include "source/opt/pass.h" |
26 | | |
27 | | #include "source/opt/ir_context.h" |
28 | | #include "spirv-tools/libspirv.hpp" |
29 | | |
30 | | namespace spvtools { |
31 | | namespace opt { |
32 | | |
33 | | // The pass manager, responsible for tracking and running passes. |
34 | | // Clients should first call AddPass() to add passes and then call Run() |
35 | | // to run on a module. Passes are executed in the exact order of addition. |
36 | | class PassManager { |
37 | | public: |
38 | | // Constructs a pass manager. |
39 | | // |
40 | | // The constructed instance will have an empty message consumer, which just |
41 | | // ignores all messages from the library. Use SetMessageConsumer() to supply |
42 | | // one if messages are of concern. |
43 | | PassManager() |
44 | 31.4k | : consumer_(nullptr), |
45 | 31.4k | print_all_stream_(nullptr), |
46 | 31.4k | time_report_stream_(nullptr), |
47 | 31.4k | target_env_(SPV_ENV_UNIVERSAL_1_2), |
48 | 31.4k | val_options_(nullptr), |
49 | 31.4k | validate_after_all_(false) {} |
50 | | |
51 | | // Sets the message consumer to the given |consumer|. |
52 | 31.4k | void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } |
53 | | |
54 | | // Adds an externally constructed pass. |
55 | | void AddPass(std::unique_ptr<Pass> pass); |
56 | | // Uses the argument |args| to construct a pass instance of type |T|, and adds |
57 | | // the pass instance to this pass manager. The pass added will use this pass |
58 | | // manager's message consumer. |
59 | | template <typename T, typename... Args> |
60 | | void AddPass(Args&&... args); |
61 | | |
62 | | // Returns the number of passes added. |
63 | | uint32_t NumPasses() const; |
64 | | // Returns a pointer to the |index|th pass added. |
65 | | inline Pass* GetPass(uint32_t index) const; |
66 | | |
67 | | // Returns the message consumer. |
68 | | inline const MessageConsumer& consumer() const; |
69 | | |
70 | | // Runs all passes on the given |module|. Returns Status::Failure if errors |
71 | | // occur when processing using one of the registered passes. All passes |
72 | | // registered after the error-reporting pass will be skipped. Returns the |
73 | | // corresponding Status::Success if processing is successful to indicate |
74 | | // whether changes are made to the module. |
75 | | // |
76 | | // After running all the passes, they are removed from the list. |
77 | | Pass::Status Run(IRContext* context); |
78 | | |
79 | | // Sets the option to print the disassembly before each pass and after the |
80 | | // last pass. Output is written to |out| if that is not null. No output |
81 | | // is generated if |out| is null. |
82 | 0 | PassManager& SetPrintAll(std::ostream* out) { |
83 | 0 | print_all_stream_ = out; |
84 | 0 | return *this; |
85 | 0 | } |
86 | | |
87 | | // Sets the option to print the resource utilization of each pass. Output is |
88 | | // written to |out| if that is not null. No output is generated if |out| is |
89 | | // null. |
90 | 0 | PassManager& SetTimeReport(std::ostream* out) { |
91 | 0 | time_report_stream_ = out; |
92 | 0 | return *this; |
93 | 0 | } |
94 | | |
95 | | // Sets the target environment for validation. |
96 | 14.8k | PassManager& SetTargetEnv(spv_target_env env) { |
97 | 14.8k | target_env_ = env; |
98 | 14.8k | return *this; |
99 | 14.8k | } |
100 | | |
101 | | // Sets the validation options. |
102 | 14.8k | PassManager& SetValidatorOptions(spv_validator_options options) { |
103 | 14.8k | val_options_ = options; |
104 | 14.8k | return *this; |
105 | 14.8k | } |
106 | | |
107 | | // Sets the option to validate after each pass. |
108 | 0 | PassManager& SetValidateAfterAll(bool validate) { |
109 | 0 | validate_after_all_ = validate; |
110 | 0 | return *this; |
111 | 0 | } |
112 | | |
113 | | private: |
114 | | // Consumer for messages. |
115 | | MessageConsumer consumer_; |
116 | | // A vector of passes. Order matters. |
117 | | std::vector<std::unique_ptr<Pass>> passes_; |
118 | | // The output stream to write disassembly to before each pass, and after |
119 | | // the last pass. If this is null, no output is generated. |
120 | | std::ostream* print_all_stream_; |
121 | | // The output stream to write the resource utilization of each pass. If this |
122 | | // is null, no output is generated. |
123 | | std::ostream* time_report_stream_; |
124 | | // The target environment. |
125 | | spv_target_env target_env_; |
126 | | // The validator options (used when validating each pass). |
127 | | spv_validator_options val_options_; |
128 | | // Controls whether validation occurs after every pass. |
129 | | bool validate_after_all_; |
130 | | }; |
131 | | |
132 | 906k | inline void PassManager::AddPass(std::unique_ptr<Pass> pass) { |
133 | 906k | passes_.push_back(std::move(pass)); |
134 | 906k | } |
135 | | |
136 | | template <typename T, typename... Args> |
137 | | inline void PassManager::AddPass(Args&&... args) { |
138 | | passes_.emplace_back(new T(std::forward<Args>(args)...)); |
139 | | passes_.back()->SetMessageConsumer(consumer_); |
140 | | } |
141 | | |
142 | 31.4k | inline uint32_t PassManager::NumPasses() const { |
143 | 31.4k | return static_cast<uint32_t>(passes_.size()); |
144 | 31.4k | } |
145 | | |
146 | 0 | inline Pass* PassManager::GetPass(uint32_t index) const { |
147 | 0 | SPIRV_ASSERT(consumer_, index < passes_.size(), "index out of bound"); |
148 | 0 | return passes_[index].get(); |
149 | 0 | } |
150 | | |
151 | 949k | inline const MessageConsumer& PassManager::consumer() const { |
152 | 949k | return consumer_; |
153 | 949k | } |
154 | | |
155 | | } // namespace opt |
156 | | } // namespace spvtools |
157 | | |
158 | | #endif // SOURCE_OPT_PASS_MANAGER_H_ |