/proc/self/cwd/eval/compiler/regex_precompilation_optimization.cc
Line | Count | Source |
1 | | // Copyright 2023 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "eval/compiler/regex_precompilation_optimization.h" |
16 | | |
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
23 | | |
24 | | #include "absl/base/nullability.h" |
25 | | #include "absl/container/flat_hash_map.h" |
26 | | #include "absl/status/status.h" |
27 | | #include "absl/status/statusor.h" |
28 | | #include "absl/strings/string_view.h" |
29 | | #include "absl/types/optional.h" |
30 | | #include "base/builtins.h" |
31 | | #include "common/ast.h" |
32 | | #include "common/casting.h" |
33 | | #include "common/expr.h" |
34 | | #include "common/native_type.h" |
35 | | #include "common/value.h" |
36 | | #include "eval/compiler/flat_expr_builder_extensions.h" |
37 | | #include "eval/eval/compiler_constant_step.h" |
38 | | #include "eval/eval/direct_expression_step.h" |
39 | | #include "eval/eval/evaluator_core.h" |
40 | | #include "eval/eval/regex_match_step.h" |
41 | | #include "internal/casts.h" |
42 | | #include "internal/re2_options.h" |
43 | | #include "internal/status_macros.h" |
44 | | #include "re2/re2.h" |
45 | | |
46 | | namespace google::api::expr::runtime { |
47 | | namespace { |
48 | | |
49 | | using ::cel::Ast; |
50 | | using ::cel::CallExpr; |
51 | | using ::cel::Cast; |
52 | | using ::cel::Expr; |
53 | | using ::cel::InstanceOf; |
54 | | using ::cel::NativeTypeId; |
55 | | using ::cel::Reference; |
56 | | using ::cel::StringValue; |
57 | | using ::cel::Value; |
58 | | using ::cel::internal::down_cast; |
59 | | |
60 | | using ReferenceMap = absl::flat_hash_map<int64_t, Reference>; |
61 | | |
62 | | bool IsFunctionOverload(const Expr& expr, absl::string_view function, |
63 | | absl::string_view overload, size_t arity, |
64 | 0 | const ReferenceMap& reference_map) { |
65 | 0 | if (!expr.has_call_expr()) { |
66 | 0 | return false; |
67 | 0 | } |
68 | 0 | const auto& call_expr = expr.call_expr(); |
69 | 0 | if (call_expr.function() != function) { |
70 | 0 | return false; |
71 | 0 | } |
72 | 0 | if (call_expr.args().size() + (call_expr.has_target() ? 1 : 0) != arity) { |
73 | 0 | return false; |
74 | 0 | } |
75 | | |
76 | | // If parse-only and opted in to the optimization, assume this is the intended |
77 | | // overload. This will still only change the evaluation plan if the second arg |
78 | | // is a constant string. |
79 | 0 | if (reference_map.empty()) { |
80 | 0 | return true; |
81 | 0 | } |
82 | | |
83 | 0 | auto reference = reference_map.find(expr.id()); |
84 | 0 | if (reference != reference_map.end() && |
85 | 0 | reference->second.overload_id().size() == 1 && |
86 | 0 | reference->second.overload_id().front() == overload) { |
87 | 0 | return true; |
88 | 0 | } |
89 | 0 | return false; |
90 | 0 | } |
91 | | |
92 | | // Abstraction for deduplicating regular expressions over the course of a single |
93 | | // create expression call. Should not be used during evaluation. Uses |
94 | | // std::shared_ptr and std::weak_ptr. |
95 | | class RegexProgramBuilder final { |
96 | | public: |
97 | | explicit RegexProgramBuilder(int max_program_size) |
98 | 0 | : max_program_size_(max_program_size) {} |
99 | | |
100 | | absl::StatusOr<std::shared_ptr<const RE2>> BuildRegexProgram( |
101 | 0 | std::string pattern) { |
102 | 0 | auto existing = programs_.find(pattern); |
103 | 0 | if (existing != programs_.end()) { |
104 | 0 | if (auto program = existing->second.lock(); program) { |
105 | 0 | return program; |
106 | 0 | } |
107 | 0 | programs_.erase(existing); |
108 | 0 | } |
109 | 0 | auto program = |
110 | 0 | std::make_shared<RE2>(pattern, cel::internal::MakeRE2Options()); |
111 | 0 | CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(*program, max_program_size_)); |
112 | 0 | programs_.insert({std::move(pattern), program}); |
113 | 0 | return program; |
114 | 0 | } |
115 | | |
116 | | private: |
117 | | const int max_program_size_; |
118 | | absl::flat_hash_map<std::string, std::weak_ptr<const RE2>> programs_; |
119 | | }; |
120 | | |
121 | | class RegexPrecompilationOptimization : public ProgramOptimizer { |
122 | | public: |
123 | | explicit RegexPrecompilationOptimization(const ReferenceMap& reference_map, |
124 | | int regex_max_program_size) |
125 | 0 | : reference_map_(reference_map), |
126 | 0 | regex_program_builder_(regex_max_program_size) {} |
127 | | |
128 | 0 | absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { |
129 | 0 | return absl::OkStatus(); |
130 | 0 | } |
131 | | |
132 | 0 | absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override { |
133 | | // Check that this is the correct matches overload instead of a user defined |
134 | | // overload. |
135 | 0 | if (!IsFunctionOverload(node, cel::builtin::kRegexMatch, "matches_string", |
136 | 0 | 2, reference_map_)) { |
137 | 0 | return absl::OkStatus(); |
138 | 0 | } |
139 | | |
140 | 0 | ProgramBuilder::Subexpression* subexpression = |
141 | 0 | context.program_builder().GetSubexpression(&node); |
142 | |
|
143 | 0 | const CallExpr& call_expr = node.call_expr(); |
144 | 0 | const Expr& pattern_expr = call_expr.args().back(); |
145 | | |
146 | | // Try to check if the regex is valid, whether or not we can actually update |
147 | | // the plan. |
148 | 0 | std::optional<std::string> pattern = |
149 | 0 | GetConstantString(context, subexpression, node, pattern_expr); |
150 | 0 | if (!pattern.has_value()) { |
151 | 0 | return absl::OkStatus(); |
152 | 0 | } |
153 | | |
154 | 0 | CEL_ASSIGN_OR_RETURN( |
155 | 0 | std::shared_ptr<const RE2> regex_program, |
156 | 0 | regex_program_builder_.BuildRegexProgram(std::move(pattern).value())); |
157 | |
|
158 | 0 | if (subexpression == nullptr || subexpression->IsFlattened()) { |
159 | | // Already modified, can't update further. |
160 | 0 | return absl::OkStatus(); |
161 | 0 | } |
162 | | |
163 | 0 | const Expr& subject_expr = |
164 | 0 | call_expr.has_target() ? call_expr.target() : call_expr.args().front(); |
165 | |
|
166 | 0 | return RewritePlan(context, subexpression, node, subject_expr, |
167 | 0 | std::move(regex_program)); |
168 | 0 | } |
169 | | |
170 | | private: |
171 | | std::optional<std::string> GetConstantString( |
172 | | PlannerContext& context, |
173 | | ProgramBuilder::Subexpression* absl_nullable subexpression, |
174 | 0 | const Expr& call_expr, const Expr& re_expr) const { |
175 | 0 | if (re_expr.has_const_expr() && re_expr.const_expr().has_string_value()) { |
176 | 0 | return re_expr.const_expr().string_value(); |
177 | 0 | } |
178 | | |
179 | 0 | if (subexpression == nullptr || subexpression->IsFlattened()) { |
180 | | // Already modified, can't recover the input pattern. |
181 | 0 | return absl::nullopt; |
182 | 0 | } |
183 | 0 | std::optional<Value> constant; |
184 | 0 | if (subexpression->IsRecursive()) { |
185 | 0 | const auto& program = subexpression->recursive_program(); |
186 | 0 | auto deps = program.step->GetDependencies(); |
187 | 0 | if (deps.has_value() && deps->size() == 2) { |
188 | 0 | const auto* re_plan = |
189 | 0 | TryDowncastDirectStep<DirectCompilerConstantStep>(deps->at(1)); |
190 | 0 | if (re_plan != nullptr) { |
191 | 0 | constant = re_plan->value(); |
192 | 0 | } |
193 | 0 | } |
194 | 0 | } else { |
195 | | // otherwise stack-machine program. |
196 | 0 | ExecutionPathView re_plan = context.GetSubplan(re_expr); |
197 | 0 | if (re_plan.size() == 1 && |
198 | 0 | re_plan[0]->GetNativeTypeId() == |
199 | 0 | NativeTypeId::For<CompilerConstantStep>()) { |
200 | 0 | constant = |
201 | 0 | down_cast<const CompilerConstantStep*>(re_plan[0].get())->value(); |
202 | 0 | } |
203 | 0 | } |
204 | |
|
205 | 0 | if (constant.has_value() && InstanceOf<StringValue>(*constant)) { |
206 | 0 | return Cast<StringValue>(*constant).ToString(); |
207 | 0 | } |
208 | | |
209 | 0 | return absl::nullopt; |
210 | 0 | } |
211 | | |
212 | | absl::Status RewritePlan( |
213 | | PlannerContext& context, |
214 | | ProgramBuilder::Subexpression* absl_nonnull subexpression, |
215 | | const Expr& call, const Expr& subject, |
216 | 0 | std::shared_ptr<const RE2> regex_program) { |
217 | 0 | if (subexpression->IsRecursive()) { |
218 | 0 | return RewriteRecursivePlan(subexpression, call, subject, |
219 | 0 | std::move(regex_program)); |
220 | 0 | } |
221 | 0 | return RewriteStackMachinePlan(context, call, subject, |
222 | 0 | std::move(regex_program)); |
223 | 0 | } |
224 | | |
225 | | absl::Status RewriteRecursivePlan( |
226 | | ProgramBuilder::Subexpression* absl_nonnull subexpression, |
227 | | const Expr& call, const Expr& subject, |
228 | 0 | std::shared_ptr<const RE2> regex_program) { |
229 | 0 | auto program = subexpression->ExtractRecursiveProgram(); |
230 | 0 | auto deps = program.step->ExtractDependencies(); |
231 | 0 | if (!deps.has_value() || deps->size() != 2) { |
232 | | // Possibly already const-folded, put the plan back. |
233 | 0 | subexpression->set_recursive_program(std::move(program.step), |
234 | 0 | program.depth); |
235 | 0 | return absl::OkStatus(); |
236 | 0 | } |
237 | 0 | subexpression->set_recursive_program( |
238 | 0 | CreateDirectRegexMatchStep(call.id(), std::move(deps->at(0)), |
239 | 0 | std::move(regex_program)), |
240 | 0 | program.depth); |
241 | 0 | return absl::OkStatus(); |
242 | 0 | } |
243 | | |
244 | | absl::Status RewriteStackMachinePlan( |
245 | | PlannerContext& context, const Expr& call, const Expr& subject, |
246 | 0 | std::shared_ptr<const RE2> regex_program) { |
247 | 0 | if (context.GetSubplan(subject).empty()) { |
248 | | // This subexpression was already optimized, nothing to do. |
249 | 0 | return absl::OkStatus(); |
250 | 0 | } |
251 | | |
252 | 0 | CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan, |
253 | 0 | context.ExtractSubplan(subject)); |
254 | 0 | CEL_ASSIGN_OR_RETURN( |
255 | 0 | new_plan.emplace_back(), |
256 | 0 | CreateRegexMatchStep(std::move(regex_program), call.id())); |
257 | |
|
258 | 0 | return context.ReplaceSubplan(call, std::move(new_plan)); |
259 | 0 | } |
260 | | |
261 | | const ReferenceMap& reference_map_; |
262 | | RegexProgramBuilder regex_program_builder_; |
263 | | }; |
264 | | |
265 | | } // namespace |
266 | | |
267 | | ProgramOptimizerFactory CreateRegexPrecompilationExtension( |
268 | 0 | int regex_max_program_size) { |
269 | 0 | return [=](PlannerContext& context, const Ast& ast) { |
270 | 0 | return std::make_unique<RegexPrecompilationOptimization>( |
271 | 0 | ast.reference_map(), regex_max_program_size); |
272 | 0 | }; |
273 | 0 | } |
274 | | } // namespace google::api::expr::runtime |