/proc/self/cwd/eval/compiler/qualified_reference_resolver.cc
Line | Count | Source |
1 | | // Copyright 2020 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "eval/compiler/qualified_reference_resolver.h" |
16 | | |
17 | | #include <cstdint> |
18 | | #include <memory> |
19 | | #include <string> |
20 | | #include <utility> |
21 | | #include <vector> |
22 | | |
23 | | #include "absl/container/flat_hash_map.h" |
24 | | #include "absl/container/flat_hash_set.h" |
25 | | #include "absl/status/status.h" |
26 | | #include "absl/status/statusor.h" |
27 | | #include "absl/strings/str_cat.h" |
28 | | #include "absl/strings/string_view.h" |
29 | | #include "absl/types/optional.h" |
30 | | #include "base/ast.h" |
31 | | #include "base/builtins.h" |
32 | | #include "common/ast.h" |
33 | | #include "common/ast_rewrite.h" |
34 | | #include "common/expr.h" |
35 | | #include "common/kind.h" |
36 | | #include "eval/compiler/flat_expr_builder_extensions.h" |
37 | | #include "eval/compiler/resolver.h" |
38 | | #include "runtime/internal/issue_collector.h" |
39 | | #include "runtime/runtime_issue.h" |
40 | | |
41 | | namespace google::api::expr::runtime { |
42 | | |
43 | | namespace { |
44 | | |
45 | | using ::cel::Expr; |
46 | | using ::cel::Reference; |
47 | | using ::cel::RuntimeIssue; |
48 | | using ::cel::runtime_internal::IssueCollector; |
49 | | |
50 | | // Optional types are opt-in but require special handling in the evaluator. |
51 | | constexpr absl::string_view kOptionalOr = "or"; |
52 | | constexpr absl::string_view kOptionalOrValue = "orValue"; |
53 | | |
54 | | // Determines if function is implemented with custom evaluation step instead of |
55 | | // registered. |
56 | 0 | bool IsSpecialFunction(absl::string_view function_name) { |
57 | 0 | return function_name == cel::builtin::kAnd || |
58 | 0 | function_name == cel::builtin::kOr || |
59 | 0 | function_name == cel::builtin::kIndex || |
60 | 0 | function_name == cel::builtin::kTernary || |
61 | 0 | function_name == kOptionalOr || function_name == kOptionalOrValue || |
62 | 0 | function_name == cel::builtin::kEqual || |
63 | 0 | function_name == cel::builtin::kInequal || |
64 | 0 | function_name == cel::builtin::kNot || |
65 | 0 | function_name == cel::builtin::kNotStrictlyFalse || |
66 | 0 | function_name == cel::builtin::kNotStrictlyFalseDeprecated || |
67 | 0 | function_name == cel::builtin::kIn || |
68 | 0 | function_name == cel::builtin::kInDeprecated || |
69 | 0 | function_name == cel::builtin::kInFunction || |
70 | 0 | function_name == "cel.@block"; |
71 | 0 | } |
72 | | |
73 | | bool OverloadExists(const Resolver& resolver, absl::string_view name, |
74 | | const std::vector<cel::Kind>& arguments_matcher, |
75 | 0 | bool receiver_style = false) { |
76 | 0 | return !resolver.FindOverloads(name, receiver_style, arguments_matcher) |
77 | 0 | .empty() || |
78 | 0 | !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher) |
79 | 0 | .empty(); |
80 | 0 | } |
81 | | |
82 | | // Return the qualified name of the most qualified matching overload, or |
83 | | // nullopt if no matches are found. |
84 | | std::optional<std::string> BestOverloadMatch(const Resolver& resolver, |
85 | | absl::string_view base_name, |
86 | 0 | int argument_count) { |
87 | 0 | if (IsSpecialFunction(base_name)) { |
88 | 0 | return std::string(base_name); |
89 | 0 | } |
90 | 0 | auto arguments_matcher = ArgumentsMatcher(argument_count); |
91 | | // Check from most qualified to least qualified for a matching overload. |
92 | 0 | auto names = resolver.FullyQualifiedNames(base_name); |
93 | 0 | for (auto name = names.begin(); name != names.end(); ++name) { |
94 | 0 | if (OverloadExists(resolver, *name, arguments_matcher)) { |
95 | 0 | if (base_name[0] == '.') { |
96 | | // Preserve leading '.' to prevent re-resolving at plan time. |
97 | 0 | return std::string(base_name); |
98 | 0 | } |
99 | 0 | return *name; |
100 | 0 | } |
101 | 0 | } |
102 | 0 | return absl::nullopt; |
103 | 0 | } |
104 | | |
105 | | // Rewriter visitor for resolving references. |
106 | | // |
107 | | // On previsit pass, replace (possibly qualified) identifier branches with the |
108 | | // canonical name in the reference map (most qualified references considered |
109 | | // first). |
110 | | // |
111 | | // On post visit pass, update function calls to determine whether the function |
112 | | // target is a namespace for the function or a receiver for the call. |
113 | | class ReferenceResolver : public cel::AstRewriterBase { |
114 | | public: |
115 | | ReferenceResolver( |
116 | | const absl::flat_hash_map<int64_t, Reference>& reference_map, |
117 | | const Resolver& resolver, IssueCollector& issue_collector) |
118 | 0 | : reference_map_(reference_map), |
119 | 0 | resolver_(resolver), |
120 | 0 | issues_(issue_collector), |
121 | 0 | progress_status_(absl::OkStatus()) {} |
122 | | |
123 | | // Attempt to resolve references in expr. Return true if part of the |
124 | | // expression was rewritten. |
125 | | // TODO(issues/95): If possible, it would be nice to write a general utility |
126 | | // for running the preprocess steps when traversing the AST instead of having |
127 | | // one pass per transform. |
128 | 0 | bool PreVisitRewrite(Expr& expr) override { |
129 | 0 | const Reference* reference = GetReferenceForId(expr.id()); |
130 | | |
131 | | // Fold compile time constant (e.g. enum values) |
132 | 0 | if (reference != nullptr && reference->has_value()) { |
133 | 0 | if (reference->value().has_int64_value()) { |
134 | | // Replace enum idents with const reference value. |
135 | 0 | expr.mutable_const_expr().set_int64_value( |
136 | 0 | reference->value().int64_value()); |
137 | 0 | return true; |
138 | 0 | } else if (expr.has_ident_expr()) { |
139 | | // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes |
140 | | // it is interpreted as null value and sometimes as an enum constant. |
141 | 0 | if (reference->value().has_null_value() && |
142 | 0 | expr.ident_expr().name() == |
143 | 0 | "google.protobuf.NullValue.NULL_VALUE") { |
144 | 0 | return false; |
145 | 0 | } |
146 | 0 | expr.set_const_expr(reference->value()); |
147 | 0 | return true; |
148 | 0 | } else { |
149 | 0 | return false; |
150 | 0 | } |
151 | 0 | } |
152 | | |
153 | 0 | if (reference != nullptr) { |
154 | 0 | if (expr.has_ident_expr()) { |
155 | 0 | return MaybeUpdateIdentNode(&expr, *reference); |
156 | 0 | } else if (expr.has_select_expr()) { |
157 | 0 | return MaybeUpdateSelectNode(&expr, *reference); |
158 | 0 | } else { |
159 | | // Call nodes are updated on post visit so they will see any select |
160 | | // path rewrites. |
161 | 0 | return false; |
162 | 0 | } |
163 | 0 | } |
164 | 0 | return false; |
165 | 0 | } |
166 | | |
167 | 0 | bool PostVisitRewrite(Expr& expr) override { |
168 | 0 | const Reference* reference = GetReferenceForId(expr.id()); |
169 | 0 | if (expr.has_call_expr()) { |
170 | 0 | return MaybeUpdateCallNode(&expr, reference); |
171 | 0 | } |
172 | 0 | return false; |
173 | 0 | } |
174 | | |
175 | 0 | const absl::Status& GetProgressStatus() const { return progress_status_; } |
176 | | |
177 | | private: |
178 | | // Attempt to update a function call node. This disambiguates |
179 | | // receiver call verses namespaced names in parse if possible. |
180 | | // |
181 | | // TODO(issues/95): This duplicates some of the overload matching behavior |
182 | | // for parsed expressions. We should refactor to consolidate the code. |
183 | 0 | bool MaybeUpdateCallNode(Expr* out, const Reference* reference) { |
184 | 0 | auto& call_expr = out->mutable_call_expr(); |
185 | 0 | const std::string& function = call_expr.function(); |
186 | 0 | if (reference != nullptr && reference->overload_id().empty()) { |
187 | 0 | UpdateStatus(issues_.AddIssue( |
188 | 0 | RuntimeIssue::CreateWarning(absl::InvalidArgumentError( |
189 | 0 | absl::StrCat("Reference map doesn't provide overloads for ", |
190 | 0 | out->call_expr().function()))))); |
191 | 0 | } |
192 | 0 | bool receiver_style = call_expr.has_target(); |
193 | 0 | int arg_num = call_expr.args().size(); |
194 | 0 | if (receiver_style) { |
195 | 0 | auto maybe_namespace = ToNamespace(call_expr.target()); |
196 | 0 | if (maybe_namespace.has_value()) { |
197 | 0 | std::string resolved_name = |
198 | 0 | absl::StrCat(*maybe_namespace, ".", function); |
199 | 0 | auto resolved_function = |
200 | 0 | BestOverloadMatch(resolver_, resolved_name, arg_num); |
201 | 0 | if (resolved_function.has_value()) { |
202 | 0 | call_expr.set_function(*resolved_function); |
203 | 0 | call_expr.set_target(nullptr); |
204 | 0 | return true; |
205 | 0 | } |
206 | 0 | } |
207 | 0 | } else { |
208 | | // Not a receiver style function call. Check to see if it is a namespaced |
209 | | // function using a shorthand inside the expression container. |
210 | 0 | auto maybe_resolved_function = |
211 | 0 | BestOverloadMatch(resolver_, function, arg_num); |
212 | 0 | if (!maybe_resolved_function.has_value()) { |
213 | 0 | UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( |
214 | 0 | absl::InvalidArgumentError(absl::StrCat( |
215 | 0 | "No overload found in reference resolve step for ", function)), |
216 | 0 | RuntimeIssue::ErrorCode::kNoMatchingOverload))); |
217 | 0 | } else if (maybe_resolved_function.value() != function) { |
218 | 0 | call_expr.set_function(maybe_resolved_function.value()); |
219 | 0 | return true; |
220 | 0 | } |
221 | 0 | } |
222 | | // For parity, if we didn't rewrite the receiver call style function, |
223 | | // check that an overload is provided in the builder. |
224 | 0 | if (call_expr.has_target() && !IsSpecialFunction(function) && |
225 | 0 | !OverloadExists(resolver_, function, ArgumentsMatcher(arg_num + 1), |
226 | 0 | /* receiver_style= */ true)) { |
227 | 0 | UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( |
228 | 0 | absl::InvalidArgumentError(absl::StrCat( |
229 | 0 | "No overload found in reference resolve step for ", function)), |
230 | 0 | RuntimeIssue::ErrorCode::kNoMatchingOverload))); |
231 | 0 | } |
232 | 0 | return false; |
233 | 0 | } |
234 | | |
235 | | // Attempt to resolve a select node. If reference is valid, |
236 | | // replace the select node with the fully qualified ident node. |
237 | 0 | bool MaybeUpdateSelectNode(Expr* out, const Reference& reference) { |
238 | 0 | if (out->select_expr().test_only()) { |
239 | 0 | UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( |
240 | 0 | absl::InvalidArgumentError("Reference map points to a presence " |
241 | 0 | "test -- has(container.attr)")))); |
242 | 0 | } else if (!reference.name().empty()) { |
243 | 0 | out->mutable_ident_expr().set_name(reference.name()); |
244 | 0 | rewritten_reference_.insert(out->id()); |
245 | 0 | return true; |
246 | 0 | } |
247 | 0 | return false; |
248 | 0 | } |
249 | | |
250 | | // Attempt to resolve an ident node. If reference is valid, |
251 | | // replace the node with the fully qualified ident node. |
252 | 0 | bool MaybeUpdateIdentNode(Expr* out, const Reference& reference) { |
253 | 0 | if (!reference.name().empty() && |
254 | 0 | reference.name() != out->ident_expr().name()) { |
255 | 0 | out->mutable_ident_expr().set_name(reference.name()); |
256 | 0 | rewritten_reference_.insert(out->id()); |
257 | 0 | return true; |
258 | 0 | } |
259 | 0 | return false; |
260 | 0 | } |
261 | | |
262 | | // Convert a select expr sub tree into a namespace name if possible. |
263 | | // If any operand of the top element is a not a select or an ident node, |
264 | | // return nullopt. |
265 | 0 | std::optional<std::string> ToNamespace(const Expr& expr) { |
266 | 0 | std::optional<std::string> maybe_parent_namespace; |
267 | 0 | if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { |
268 | | // The target expr matches a reference (resolved to an ident decl). |
269 | | // This should not be treated as a function qualifier. |
270 | 0 | return absl::nullopt; |
271 | 0 | } |
272 | 0 | if (expr.has_ident_expr()) { |
273 | 0 | return expr.ident_expr().name(); |
274 | 0 | } else if (expr.has_select_expr()) { |
275 | 0 | if (expr.select_expr().test_only()) { |
276 | 0 | return absl::nullopt; |
277 | 0 | } |
278 | 0 | maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); |
279 | 0 | if (!maybe_parent_namespace.has_value()) { |
280 | 0 | return absl::nullopt; |
281 | 0 | } |
282 | 0 | return absl::StrCat(*maybe_parent_namespace, ".", |
283 | 0 | expr.select_expr().field()); |
284 | 0 | } else { |
285 | 0 | return absl::nullopt; |
286 | 0 | } |
287 | 0 | } |
288 | | |
289 | | // Find a reference for the given expr id. |
290 | | // |
291 | | // Returns nullptr if no reference is available. |
292 | 0 | const Reference* GetReferenceForId(int64_t expr_id) { |
293 | 0 | auto iter = reference_map_.find(expr_id); |
294 | 0 | if (iter == reference_map_.end()) { |
295 | 0 | return nullptr; |
296 | 0 | } |
297 | 0 | if (expr_id == 0) { |
298 | 0 | UpdateStatus(issues_.AddIssue( |
299 | 0 | RuntimeIssue::CreateWarning(absl::InvalidArgumentError( |
300 | 0 | "reference map entries for expression id 0 are not supported")))); |
301 | 0 | return nullptr; |
302 | 0 | } |
303 | 0 | return &iter->second; |
304 | 0 | } |
305 | | |
306 | 0 | void UpdateStatus(absl::Status status) { |
307 | 0 | if (progress_status_.ok() && !status.ok()) { |
308 | 0 | progress_status_ = std::move(status); |
309 | 0 | return; |
310 | 0 | } |
311 | 0 | status.IgnoreError(); |
312 | 0 | } |
313 | | |
314 | | const absl::flat_hash_map<int64_t, Reference>& reference_map_; |
315 | | const Resolver& resolver_; |
316 | | IssueCollector& issues_; |
317 | | absl::Status progress_status_; |
318 | | absl::flat_hash_set<int64_t> rewritten_reference_; |
319 | | }; |
320 | | |
321 | | class ReferenceResolverExtension : public AstTransform { |
322 | | public: |
323 | | explicit ReferenceResolverExtension(ReferenceResolverOption opt) |
324 | 14.5k | : opt_(opt) {} |
325 | | absl::Status UpdateAst(PlannerContext& context, |
326 | 10.3k | cel::Ast& ast) const override { |
327 | 10.3k | if (opt_ == ReferenceResolverOption::kCheckedOnly && |
328 | 10.3k | ast.reference_map().empty()) { |
329 | 10.3k | return absl::OkStatus(); |
330 | 10.3k | } |
331 | 0 | return ResolveReferences(context.resolver(), context.issue_collector(), ast) |
332 | 0 | .status(); |
333 | 10.3k | } |
334 | | |
335 | | private: |
336 | | ReferenceResolverOption opt_; |
337 | | }; |
338 | | |
339 | | } // namespace |
340 | | |
341 | | absl::StatusOr<bool> ResolveReferences(const Resolver& resolver, |
342 | 0 | IssueCollector& issues, cel::Ast& ast) { |
343 | 0 | ReferenceResolver ref_resolver(ast.reference_map(), resolver, issues); |
344 | | |
345 | | // Rewriting interface doesn't support failing mid traverse propagate first |
346 | | // error encountered if fail fast enabled. |
347 | 0 | bool was_rewritten = cel::AstRewrite(ast.mutable_root_expr(), ref_resolver); |
348 | 0 | if (!ref_resolver.GetProgressStatus().ok()) { |
349 | 0 | return ref_resolver.GetProgressStatus(); |
350 | 0 | } |
351 | 0 | return was_rewritten; |
352 | 0 | } |
353 | | |
354 | | std::unique_ptr<AstTransform> NewReferenceResolverExtension( |
355 | 14.5k | ReferenceResolverOption option) { |
356 | 14.5k | return std::make_unique<ReferenceResolverExtension>(option); |
357 | 14.5k | } |
358 | | |
359 | | } // namespace google::api::expr::runtime |