Coverage Report

Created: 2026-05-27 07:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/proc/self/cwd/eval/compiler/qualified_reference_resolver.cc
Line
Count
Source
1
// Copyright 2020 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "eval/compiler/qualified_reference_resolver.h"
16
17
#include <cstdint>
18
#include <memory>
19
#include <string>
20
#include <utility>
21
#include <vector>
22
23
#include "absl/container/flat_hash_map.h"
24
#include "absl/container/flat_hash_set.h"
25
#include "absl/status/status.h"
26
#include "absl/status/statusor.h"
27
#include "absl/strings/str_cat.h"
28
#include "absl/strings/string_view.h"
29
#include "absl/types/optional.h"
30
#include "base/ast.h"
31
#include "base/builtins.h"
32
#include "common/ast.h"
33
#include "common/ast_rewrite.h"
34
#include "common/expr.h"
35
#include "common/kind.h"
36
#include "eval/compiler/flat_expr_builder_extensions.h"
37
#include "eval/compiler/resolver.h"
38
#include "runtime/internal/issue_collector.h"
39
#include "runtime/runtime_issue.h"
40
41
namespace google::api::expr::runtime {
42
43
namespace {
44
45
using ::cel::Expr;
46
using ::cel::Reference;
47
using ::cel::RuntimeIssue;
48
using ::cel::runtime_internal::IssueCollector;
49
50
// Optional types are opt-in but require special handling in the evaluator.
51
constexpr absl::string_view kOptionalOr = "or";
52
constexpr absl::string_view kOptionalOrValue = "orValue";
53
54
// Determines if function is implemented with custom evaluation step instead of
55
// registered.
56
0
bool IsSpecialFunction(absl::string_view function_name) {
57
0
  return function_name == cel::builtin::kAnd ||
58
0
         function_name == cel::builtin::kOr ||
59
0
         function_name == cel::builtin::kIndex ||
60
0
         function_name == cel::builtin::kTernary ||
61
0
         function_name == kOptionalOr || function_name == kOptionalOrValue ||
62
0
         function_name == cel::builtin::kEqual ||
63
0
         function_name == cel::builtin::kInequal ||
64
0
         function_name == cel::builtin::kNot ||
65
0
         function_name == cel::builtin::kNotStrictlyFalse ||
66
0
         function_name == cel::builtin::kNotStrictlyFalseDeprecated ||
67
0
         function_name == cel::builtin::kIn ||
68
0
         function_name == cel::builtin::kInDeprecated ||
69
0
         function_name == cel::builtin::kInFunction ||
70
0
         function_name == "cel.@block";
71
0
}
72
73
bool OverloadExists(const Resolver& resolver, absl::string_view name,
74
                    const std::vector<cel::Kind>& arguments_matcher,
75
0
                    bool receiver_style = false) {
76
0
  return !resolver.FindOverloads(name, receiver_style, arguments_matcher)
77
0
              .empty() ||
78
0
         !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher)
79
0
              .empty();
80
0
}
81
82
// Return the qualified name of the most qualified matching overload, or
83
// nullopt if no matches are found.
84
std::optional<std::string> BestOverloadMatch(const Resolver& resolver,
85
                                             absl::string_view base_name,
86
0
                                             int argument_count) {
87
0
  if (IsSpecialFunction(base_name)) {
88
0
    return std::string(base_name);
89
0
  }
90
0
  auto arguments_matcher = ArgumentsMatcher(argument_count);
91
  // Check from most qualified to least qualified for a matching overload.
92
0
  auto names = resolver.FullyQualifiedNames(base_name);
93
0
  for (auto name = names.begin(); name != names.end(); ++name) {
94
0
    if (OverloadExists(resolver, *name, arguments_matcher)) {
95
0
      if (base_name[0] == '.') {
96
        // Preserve leading '.' to prevent re-resolving at plan time.
97
0
        return std::string(base_name);
98
0
      }
99
0
      return *name;
100
0
    }
101
0
  }
102
0
  return absl::nullopt;
103
0
}
104
105
// Rewriter visitor for resolving references.
106
//
107
// On previsit pass, replace (possibly qualified) identifier branches with the
108
// canonical name in the reference map (most qualified references considered
109
// first).
110
//
111
// On post visit pass, update function calls to determine whether the function
112
// target is a namespace for the function or a receiver for the call.
113
class ReferenceResolver : public cel::AstRewriterBase {
114
 public:
115
  ReferenceResolver(
116
      const absl::flat_hash_map<int64_t, Reference>& reference_map,
117
      const Resolver& resolver, IssueCollector& issue_collector)
118
0
      : reference_map_(reference_map),
119
0
        resolver_(resolver),
120
0
        issues_(issue_collector),
121
0
        progress_status_(absl::OkStatus()) {}
122
123
  // Attempt to resolve references in expr. Return true if part of the
124
  // expression was rewritten.
125
  // TODO(issues/95): If possible, it would be nice to write a general utility
126
  // for running the preprocess steps when traversing the AST instead of having
127
  // one pass per transform.
128
0
  bool PreVisitRewrite(Expr& expr) override {
129
0
    const Reference* reference = GetReferenceForId(expr.id());
130
131
    // Fold compile time constant (e.g. enum values)
132
0
    if (reference != nullptr && reference->has_value()) {
133
0
      if (reference->value().has_int64_value()) {
134
        // Replace enum idents with const reference value.
135
0
        expr.mutable_const_expr().set_int64_value(
136
0
            reference->value().int64_value());
137
0
        return true;
138
0
      } else if (expr.has_ident_expr()) {
139
        // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes
140
        // it is interpreted as null value and sometimes as an enum constant.
141
0
        if (reference->value().has_null_value() &&
142
0
            expr.ident_expr().name() ==
143
0
                "google.protobuf.NullValue.NULL_VALUE") {
144
0
          return false;
145
0
        }
146
0
        expr.set_const_expr(reference->value());
147
0
        return true;
148
0
      } else {
149
0
        return false;
150
0
      }
151
0
    }
152
153
0
    if (reference != nullptr) {
154
0
      if (expr.has_ident_expr()) {
155
0
        return MaybeUpdateIdentNode(&expr, *reference);
156
0
      } else if (expr.has_select_expr()) {
157
0
        return MaybeUpdateSelectNode(&expr, *reference);
158
0
      } else {
159
        // Call nodes are updated on post visit so they will see any select
160
        // path rewrites.
161
0
        return false;
162
0
      }
163
0
    }
164
0
    return false;
165
0
  }
166
167
0
  bool PostVisitRewrite(Expr& expr) override {
168
0
    const Reference* reference = GetReferenceForId(expr.id());
169
0
    if (expr.has_call_expr()) {
170
0
      return MaybeUpdateCallNode(&expr, reference);
171
0
    }
172
0
    return false;
173
0
  }
174
175
0
  const absl::Status& GetProgressStatus() const { return progress_status_; }
176
177
 private:
178
  // Attempt to update a function call node. This disambiguates
179
  // receiver call verses namespaced names in parse if possible.
180
  //
181
  // TODO(issues/95): This duplicates some of the overload matching behavior
182
  // for parsed expressions. We should refactor to consolidate the code.
183
0
  bool MaybeUpdateCallNode(Expr* out, const Reference* reference) {
184
0
    auto& call_expr = out->mutable_call_expr();
185
0
    const std::string& function = call_expr.function();
186
0
    if (reference != nullptr && reference->overload_id().empty()) {
187
0
      UpdateStatus(issues_.AddIssue(
188
0
          RuntimeIssue::CreateWarning(absl::InvalidArgumentError(
189
0
              absl::StrCat("Reference map doesn't provide overloads for ",
190
0
                           out->call_expr().function())))));
191
0
    }
192
0
    bool receiver_style = call_expr.has_target();
193
0
    int arg_num = call_expr.args().size();
194
0
    if (receiver_style) {
195
0
      auto maybe_namespace = ToNamespace(call_expr.target());
196
0
      if (maybe_namespace.has_value()) {
197
0
        std::string resolved_name =
198
0
            absl::StrCat(*maybe_namespace, ".", function);
199
0
        auto resolved_function =
200
0
            BestOverloadMatch(resolver_, resolved_name, arg_num);
201
0
        if (resolved_function.has_value()) {
202
0
          call_expr.set_function(*resolved_function);
203
0
          call_expr.set_target(nullptr);
204
0
          return true;
205
0
        }
206
0
      }
207
0
    } else {
208
      // Not a receiver style function call. Check to see if it is a namespaced
209
      // function using a shorthand inside the expression container.
210
0
      auto maybe_resolved_function =
211
0
          BestOverloadMatch(resolver_, function, arg_num);
212
0
      if (!maybe_resolved_function.has_value()) {
213
0
        UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning(
214
0
            absl::InvalidArgumentError(absl::StrCat(
215
0
                "No overload found in reference resolve step for ", function)),
216
0
            RuntimeIssue::ErrorCode::kNoMatchingOverload)));
217
0
      } else if (maybe_resolved_function.value() != function) {
218
0
        call_expr.set_function(maybe_resolved_function.value());
219
0
        return true;
220
0
      }
221
0
    }
222
    // For parity, if we didn't rewrite the receiver call style function,
223
    // check that an overload is provided in the builder.
224
0
    if (call_expr.has_target() && !IsSpecialFunction(function) &&
225
0
        !OverloadExists(resolver_, function, ArgumentsMatcher(arg_num + 1),
226
0
                        /* receiver_style= */ true)) {
227
0
      UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning(
228
0
          absl::InvalidArgumentError(absl::StrCat(
229
0
              "No overload found in reference resolve step for ", function)),
230
0
          RuntimeIssue::ErrorCode::kNoMatchingOverload)));
231
0
    }
232
0
    return false;
233
0
  }
234
235
  // Attempt to resolve a select node. If reference is valid,
236
  // replace the select node with the fully qualified ident node.
237
0
  bool MaybeUpdateSelectNode(Expr* out, const Reference& reference) {
238
0
    if (out->select_expr().test_only()) {
239
0
      UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning(
240
0
          absl::InvalidArgumentError("Reference map points to a presence "
241
0
                                     "test -- has(container.attr)"))));
242
0
    } else if (!reference.name().empty()) {
243
0
      out->mutable_ident_expr().set_name(reference.name());
244
0
      rewritten_reference_.insert(out->id());
245
0
      return true;
246
0
    }
247
0
    return false;
248
0
  }
249
250
  // Attempt to resolve an ident node. If reference is valid,
251
  // replace the node with the fully qualified ident node.
252
0
  bool MaybeUpdateIdentNode(Expr* out, const Reference& reference) {
253
0
    if (!reference.name().empty() &&
254
0
        reference.name() != out->ident_expr().name()) {
255
0
      out->mutable_ident_expr().set_name(reference.name());
256
0
      rewritten_reference_.insert(out->id());
257
0
      return true;
258
0
    }
259
0
    return false;
260
0
  }
261
262
  // Convert a select expr sub tree into a namespace name if possible.
263
  // If any operand of the top element is a not a select or an ident node,
264
  // return nullopt.
265
0
  std::optional<std::string> ToNamespace(const Expr& expr) {
266
0
    std::optional<std::string> maybe_parent_namespace;
267
0
    if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) {
268
      // The target expr matches a reference (resolved to an ident decl).
269
      // This should not be treated as a function qualifier.
270
0
      return absl::nullopt;
271
0
    }
272
0
    if (expr.has_ident_expr()) {
273
0
      return expr.ident_expr().name();
274
0
    } else if (expr.has_select_expr()) {
275
0
      if (expr.select_expr().test_only()) {
276
0
        return absl::nullopt;
277
0
      }
278
0
      maybe_parent_namespace = ToNamespace(expr.select_expr().operand());
279
0
      if (!maybe_parent_namespace.has_value()) {
280
0
        return absl::nullopt;
281
0
      }
282
0
      return absl::StrCat(*maybe_parent_namespace, ".",
283
0
                          expr.select_expr().field());
284
0
    } else {
285
0
      return absl::nullopt;
286
0
    }
287
0
  }
288
289
  // Find a reference for the given expr id.
290
  //
291
  // Returns nullptr if no reference is available.
292
0
  const Reference* GetReferenceForId(int64_t expr_id) {
293
0
    auto iter = reference_map_.find(expr_id);
294
0
    if (iter == reference_map_.end()) {
295
0
      return nullptr;
296
0
    }
297
0
    if (expr_id == 0) {
298
0
      UpdateStatus(issues_.AddIssue(
299
0
          RuntimeIssue::CreateWarning(absl::InvalidArgumentError(
300
0
              "reference map entries for expression id 0 are not supported"))));
301
0
      return nullptr;
302
0
    }
303
0
    return &iter->second;
304
0
  }
305
306
0
  void UpdateStatus(absl::Status status) {
307
0
    if (progress_status_.ok() && !status.ok()) {
308
0
      progress_status_ = std::move(status);
309
0
      return;
310
0
    }
311
0
    status.IgnoreError();
312
0
  }
313
314
  const absl::flat_hash_map<int64_t, Reference>& reference_map_;
315
  const Resolver& resolver_;
316
  IssueCollector& issues_;
317
  absl::Status progress_status_;
318
  absl::flat_hash_set<int64_t> rewritten_reference_;
319
};
320
321
class ReferenceResolverExtension : public AstTransform {
322
 public:
323
  explicit ReferenceResolverExtension(ReferenceResolverOption opt)
324
14.5k
      : opt_(opt) {}
325
  absl::Status UpdateAst(PlannerContext& context,
326
10.3k
                         cel::Ast& ast) const override {
327
10.3k
    if (opt_ == ReferenceResolverOption::kCheckedOnly &&
328
10.3k
        ast.reference_map().empty()) {
329
10.3k
      return absl::OkStatus();
330
10.3k
    }
331
0
    return ResolveReferences(context.resolver(), context.issue_collector(), ast)
332
0
        .status();
333
10.3k
  }
334
335
 private:
336
  ReferenceResolverOption opt_;
337
};
338
339
}  // namespace
340
341
// Rewrites `ast` in place using its reference map and `resolver`.
//
// Returns whether any part of the expression was rewritten, or the first
// non-OK status produced while reporting issues to `issues`.
absl::StatusOr<bool> ResolveReferences(const Resolver& resolver,
                                       IssueCollector& issues, cel::Ast& ast) {
  ReferenceResolver ref_resolver(ast.reference_map(), resolver, issues);

  // The rewriting interface can't fail mid-traversal, so the visitor records
  // the first error it encounters and it is propagated here after the walk.
  const bool was_rewritten =
      cel::AstRewrite(ast.mutable_root_expr(), ref_resolver);
  if (const absl::Status& progress = ref_resolver.GetProgressStatus();
      !progress.ok()) {
    return progress;
  }
  return was_rewritten;
}
353
354
std::unique_ptr<AstTransform> NewReferenceResolverExtension(
355
14.5k
    ReferenceResolverOption option) {
356
14.5k
  return std::make_unique<ReferenceResolverExtension>(option);
357
14.5k
}
358
359
}  // namespace google::api::expr::runtime