Coverage Report

Created: 2026-05-27 07:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/proc/self/cwd/extensions/select_optimization.cc
Line
Count
Source
1
// Copyright 2023 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "extensions/select_optimization.h"
16
17
#include <cstddef>
18
#include <cstdint>
19
#include <iterator>
20
#include <memory>
21
#include <string>
22
#include <utility>
23
#include <vector>
24
25
#include "absl/algorithm/container.h"
26
#include "absl/base/nullability.h"
27
#include "absl/container/flat_hash_map.h"
28
#include "absl/functional/overload.h"
29
#include "absl/log/absl_check.h"
30
#include "absl/status/status.h"
31
#include "absl/status/statusor.h"
32
#include "absl/strings/match.h"
33
#include "absl/strings/string_view.h"
34
#include "absl/types/optional.h"
35
#include "absl/types/span.h"
36
#include "absl/types/variant.h"
37
#include "base/attribute.h"
38
#include "base/builtins.h"
39
#include "common/ast.h"
40
#include "common/ast_rewrite.h"
41
#include "common/casting.h"
42
#include "common/constant.h"
43
#include "common/expr.h"
44
#include "common/function_descriptor.h"
45
#include "common/kind.h"
46
#include "common/native_type.h"
47
#include "common/type.h"
48
#include "common/value.h"
49
#include "eval/compiler/flat_expr_builder.h"
50
#include "eval/compiler/flat_expr_builder_extensions.h"
51
#include "eval/eval/attribute_trail.h"
52
#include "eval/eval/direct_expression_step.h"
53
#include "eval/eval/evaluator_core.h"
54
#include "eval/eval/expression_step_base.h"
55
#include "internal/casts.h"
56
#include "internal/number.h"
57
#include "internal/status_macros.h"
58
#include "runtime/internal/errors.h"
59
#include "runtime/internal/runtime_friend_access.h"
60
#include "runtime/internal/runtime_impl.h"
61
#include "runtime/runtime_builder.h"
62
#include "google/protobuf/arena.h"
63
#include "google/protobuf/descriptor.h"
64
#include "google/protobuf/message.h"
65
66
namespace cel::extensions {
67
namespace {
68
69
using ::cel::Ast;
70
using ::cel::AstRewriterBase;
71
using ::cel::CallExpr;
72
using ::cel::ConstantKind;
73
using ::cel::Expr;
74
using ::cel::ExprKind;
75
using ::cel::SelectExpr;
76
using ::google::api::expr::runtime::AttributeTrail;
77
using ::google::api::expr::runtime::DirectExpressionStep;
78
using ::google::api::expr::runtime::ExecutionFrame;
79
using ::google::api::expr::runtime::ExecutionFrameBase;
80
using ::google::api::expr::runtime::ExpressionStepBase;
81
using ::google::api::expr::runtime::PlannerContext;
82
using ::google::api::expr::runtime::ProgramOptimizer;
83
84
// One struct-field hop in a select chain, identified both by the field's
// number and by its name as resolved from the operand's runtime struct type.
struct SelectInstruction {
  int64_t number;    // Field number.
  std::string name;  // Field name.
};
91
92
// Represents a single qualifier in a traversal path.
93
// TODO(uncreated-issue/51): support variable indexes.
94
using QualifierInstruction =
95
    std::variant<SelectInstruction, std::string, int64_t, uint64_t, bool>;
96
97
struct SelectPath {
98
  Expr* operand;
99
  std::vector<QualifierInstruction> select_instructions;
100
  bool test_only;
101
  // TODO(uncreated-issue/54): support for optionals.
102
};
103
104
// Generates the AST representation of the qualification path for the optimized
105
// select branch. I.e., the list-typed second argument of the cel.@attribute
106
// call.
107
Expr MakeSelectPathExpr(
108
0
    const std::vector<QualifierInstruction>& select_instructions) {
109
0
  Expr result;
110
0
  auto& ast_list = result.mutable_list_expr().mutable_elements();
111
0
  ast_list.reserve(select_instructions.size());
112
0
  auto visitor = absl::Overload(
113
0
      [&](const SelectInstruction& instruction) {
114
0
        Expr ast_instruction;
115
0
        Expr field_number;
116
0
        field_number.mutable_const_expr().set_int64_value(instruction.number);
117
0
        Expr field_name;
118
0
        field_name.mutable_const_expr().set_string_value(instruction.name);
119
0
        auto& field_specifier =
120
0
            ast_instruction.mutable_list_expr().mutable_elements();
121
0
        field_specifier.emplace_back().set_expr(std::move(field_number));
122
0
        field_specifier.emplace_back().set_expr(std::move(field_name));
123
124
0
        ast_list.emplace_back().set_expr(std::move(ast_instruction));
125
0
      },
126
0
      [&](absl::string_view instruction) {
127
0
        Expr const_expr;
128
0
        const_expr.mutable_const_expr().set_string_value(instruction);
129
0
        ast_list.emplace_back().set_expr(std::move(const_expr));
130
0
      },
131
0
      [&](int64_t instruction) {
132
0
        Expr const_expr;
133
0
        const_expr.mutable_const_expr().set_int64_value(instruction);
134
0
        ast_list.emplace_back().set_expr(std::move(const_expr));
135
0
      },
136
0
      [&](uint64_t instruction) {
137
0
        Expr const_expr;
138
0
        const_expr.mutable_const_expr().set_uint64_value(instruction);
139
0
        ast_list.emplace_back().set_expr(std::move(const_expr));
140
0
      },
141
0
      [&](bool instruction) {
142
0
        Expr const_expr;
143
0
        const_expr.mutable_const_expr().set_bool_value(instruction);
144
0
        ast_list.emplace_back().set_expr(std::move(const_expr));
145
0
      });
146
147
0
  for (const auto& instruction : select_instructions) {
148
0
    absl::visit(visitor, instruction);
149
0
  }
150
0
  return result;
151
0
}
152
153
// Returns a single select operation based on the inferred type of the operand
154
// and the field name. If the operand type doesn't define the field, returns
155
// nullopt.
156
std::optional<SelectInstruction> GetSelectInstruction(
157
    const StructType& runtime_type, PlannerContext& planner_context,
158
0
    absl::string_view field_name) {
159
0
  auto field_or = planner_context.type_reflector()
160
0
                      .FindStructTypeFieldByName(runtime_type, field_name)
161
0
                      .value_or(absl::nullopt);
162
0
  if (field_or.has_value()) {
163
0
    return SelectInstruction{field_or->number(), std::string(field_or->name())};
164
0
  }
165
0
  return absl::nullopt;
166
0
}
167
168
0
// Decodes a nested [field_number, field_name] element of a cel.attribute path
// back into a FieldSpecifier. The list must have exactly two elements: an
// int64 constant followed by a string constant. Anything else yields an
// InvalidArgumentError.
absl::StatusOr<SelectQualifier> SelectQualifierFromList(const ListExpr& list) {
  const auto& items = list.elements();
  if (items.size() != 2) {
    return absl::InvalidArgumentError("Invalid cel.attribute select list");
  }

  const Expr& number_expr = items[0].expr();
  const bool has_number = number_expr.has_const_expr() &&
                          number_expr.const_expr().has_int64_value();
  if (!has_number) {
    return absl::InvalidArgumentError(
        "Invalid cel.attribute field select number");
  }

  const Expr& name_expr = items[1].expr();
  const bool has_name = name_expr.has_const_expr() &&
                        name_expr.const_expr().has_string_value();
  if (!has_name) {
    return absl::InvalidArgumentError(
        "Invalid cel.attribute field select name");
  }

  return FieldSpecifier{number_expr.const_expr().int64_value(),
                        name_expr.const_expr().string_value()};
}
191
192
// Returns a qualifier instruction derived from a unoptimized ast.
193
absl::StatusOr<QualifierInstruction> SelectInstructionFromConstant(
194
0
    const Constant& constant) {
195
0
  if (constant.has_int_value()) {
196
0
    return QualifierInstruction(constant.int_value());
197
0
  } else if (constant.has_uint_value()) {
198
0
    return QualifierInstruction(constant.uint_value());
199
0
  } else if (constant.has_bool_value()) {
200
0
    return QualifierInstruction(constant.bool_value());
201
0
  } else if (constant.has_string_value()) {
202
0
    return QualifierInstruction(constant.string_value());
203
0
  } else if (constant.has_double_value()) {
204
0
    cel::internal::Number number(constant.double_value());
205
0
    if (number.LosslessConvertibleToInt()) {
206
0
      return QualifierInstruction(number.AsInt());
207
0
    } else if (number.LosslessConvertibleToUint()) {
208
0
      return QualifierInstruction(number.AsUint());
209
0
    }
210
0
  }
211
212
0
  return absl::InvalidArgumentError("invalid index constant for cel.attribute");
213
0
}
214
215
// Converts a constant element of a cel.attribute path into an
// AttributeQualifier. Only int, uint, bool and string constants are accepted.
// Anything else yields an InvalidArgumentError.
absl::StatusOr<SelectQualifier> SelectQualifierFromConstant(
    const Constant& constant) {
  if (constant.has_int_value()) {
    return AttributeQualifier::OfInt(constant.int_value());
  }
  if (constant.has_uint_value()) {
    return AttributeQualifier::OfUint(constant.uint_value());
  }
  if (constant.has_bool_value()) {
    return AttributeQualifier::OfBool(constant.bool_value());
  }
  if (constant.has_string_value()) {
    return AttributeQualifier::OfString(constant.string_value());
  }

  // TODO(uncreated-issue/51): double keys could possibly be valid selectors, but
  // the other stacks don't implement the optimization yet and we normalize the
  // key to a uint or int if we do the late AST rewrite during planning.
  return absl::InvalidArgumentError("invalid cel.attribute constant");
}
232
233
0
// Interprets an attribute qualifier as a list index. Only int qualifiers are
// accepted; any other kind is reported as a missing `_[_]` overload. Negative
// indexes yield an InvalidArgumentError.
absl::StatusOr<size_t> ListIndexFromQualifier(const AttributeQualifier& qual) {
  if (qual.kind() != Kind::kInt) {
    // TODO(uncreated-issue/51): type-checker will reject an unsigned literal, but
    // should be supported as a dyn / variable.
    return runtime_internal::CreateNoMatchingOverloadError(
        cel::builtin::kIndex);
  }

  const int64_t index = *qual.GetInt64Key();
  if (index < 0) {
    return absl::InvalidArgumentError("list index less than 0");
  }
  return static_cast<size_t>(index);
}
252
253
// Converts an attribute qualifier into a map key value. Int, uint, bool and
// string qualifiers are supported; string keys are created using `arena`.
// Any other kind is reported as a missing `_[_]` overload.
absl::StatusOr<Value> MapKeyFromQualifier(const AttributeQualifier& qual,
                                          google::protobuf::Arena* absl_nonnull arena) {
  const Kind key_kind = qual.kind();
  if (key_kind == Kind::kInt) {
    return cel::IntValue(*qual.GetInt64Key());
  }
  if (key_kind == Kind::kUint) {
    return cel::UintValue(*qual.GetUint64Key());
  }
  if (key_kind == Kind::kBool) {
    return cel::BoolValue(*qual.GetBoolKey());
  }
  if (key_kind == Kind::kString) {
    return StringValue::From(*qual.GetStringKey(), arena);
  }
  return runtime_internal::CreateNoMatchingOverloadError(cel::builtin::kIndex);
}
269
270
// Applies a single qualifier to `operand` using the generic value accessors.
// - FieldSpecifier: reads the named field from a struct operand.
// - AttributeQualifier: indexes a list operand (int index) or looks up a key
//   in a map operand.
// Operand/qualifier mismatches and invalid indexes or keys are returned as
// ErrorValue results rather than statuses. Statuses from the underlying
// accessors are passed through unchanged.
absl::StatusOr<Value> ApplyQualifier(
    const Value& operand, const SelectQualifier& qualifier,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  auto select_field =
      [&](const FieldSpecifier& field) -> absl::StatusOr<Value> {
    if (operand.Is<StructValue>()) {
      return operand.GetStruct().GetFieldByName(field.name, descriptor_pool,
                                                message_factory, arena);
    }
    return cel::ErrorValue(
        cel::runtime_internal::CreateNoMatchingOverloadError("<select>"));
  };

  auto index_into =
      [&](const AttributeQualifier& key) -> absl::StatusOr<Value> {
    if (operand.Is<ListValue>()) {
      absl::StatusOr<size_t> list_index = ListIndexFromQualifier(key);
      if (!list_index.ok()) {
        return cel::ErrorValue(list_index.status());
      }
      return operand.GetList().Get(*list_index, descriptor_pool,
                                   message_factory, arena);
    }
    if (operand.Is<MapValue>()) {
      absl::StatusOr<Value> map_key = MapKeyFromQualifier(key, arena);
      if (!map_key.ok()) {
        return cel::ErrorValue(map_key.status());
      }
      return operand.GetMap().Get(*map_key, descriptor_pool, message_factory,
                                  arena);
    }
    return cel::ErrorValue(
        cel::runtime_internal::CreateNoMatchingOverloadError(
            cel::builtin::kIndex));
  };

  return absl::visit(absl::Overload(select_field, index_into), qualifier);
}
308
309
// Evaluates `select_path` against `root` one qualifier at a time using the
// generic value accessors (no struct-specific optimization).
//
// Every qualifier but the last is applied in order. If any step produces an
// ErrorValue, that value is returned immediately as the result. The last
// qualifier is then either applied normally or, when `presence_test` is set,
// evaluated as a presence check:
// - for a struct field, whether the field is set;
// - for a string key on a map, whether the key is present.
// Other presence-test shapes produce a "has" no-matching-overload ErrorValue.
//
// Returns an InternalError status for an empty `select_path`. Callers are
// expected to always supply at least one qualifier.
absl::StatusOr<Value> FallbackSelect(
    const Value& root, absl::Span<const SelectQualifier> select_path,
    bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  if (select_path.empty()) {
    // The code below uses size() - 1 and back(), which are invalid on an
    // empty span; report the broken invariant instead of invoking UB.
    return absl::InternalError("empty select path for cel.attribute");
  }

  const Value* elem = &root;
  Value result;

  // Apply all qualifiers except the last, stopping at the first error value.
  for (const auto& instruction :
       select_path.subspan(0, select_path.size() - 1)) {
    CEL_ASSIGN_OR_RETURN(result,
                         ApplyQualifier(*elem, instruction, descriptor_pool,
                                        message_factory, arena));
    if (result.Is<ErrorValue>()) {
      return result;
    }
    elem = &result;
  }

  const auto& last_instruction = select_path.back();
  if (presence_test) {
    return absl::visit(
        absl::Overload(
            [&](const FieldSpecifier& field_specifier)
                -> absl::StatusOr<Value> {
              if (!elem->Is<StructValue>()) {
                return cel::ErrorValue(
                    cel::runtime_internal::CreateNoMatchingOverloadError(
                        "<select>"));
              }
              CEL_ASSIGN_OR_RETURN(
                  bool present,
                  elem->GetStruct().HasFieldByName(field_specifier.name));
              return cel::BoolValue(present);
            },
            [&](const AttributeQualifier& qualifier) -> absl::StatusOr<Value> {
              // Presence tests on maps are only supported for string keys.
              if (!elem->Is<MapValue>() || qualifier.kind() != Kind::kString) {
                return cel::ErrorValue(
                    cel::runtime_internal::CreateNoMatchingOverloadError(
                        "has"));
              }

              return elem->GetMap().Has(
                  StringValue(arena, *qualifier.GetStringKey()),
                  descriptor_pool, message_factory, arena);
            }),
        last_instruction);
  }

  return ApplyQualifier(*elem, last_instruction, descriptor_pool,
                        message_factory, arena);
}
362
363
// Decodes the qualifier list (the second argument) of a cel.attribute call
// into SelectQualifiers, in order. Each list element must be either a nested
// [field_number, field_name] list (an optimized field select) or a constant
// key/index. A malformed call or element yields an InvalidArgumentError.
absl::StatusOr<std::vector<SelectQualifier>> SelectInstructionsFromCall(
    const CallExpr& call) {
  const auto& args = call.args();
  if (args.size() < 2 || !args[1].has_list_expr()) {
    return absl::InvalidArgumentError("Invalid cel.attribute call");
  }
  const auto& path_elements = args[1].list_expr().elements();

  std::vector<SelectQualifier> qualifiers;
  qualifiers.reserve(path_elements.size());

  for (const ListExprElement& element : path_elements) {
    if (!element.has_expr()) {
      return absl::InvalidArgumentError("Invalid cel.attribute call");
    }
    const Expr& element_expr = element.expr();
    if (element_expr.has_list_expr()) {
      // Optimized field select.
      CEL_ASSIGN_OR_RETURN(qualifiers.emplace_back(),
                           SelectQualifierFromList(element_expr.list_expr()));
    } else if (element_expr.has_const_expr()) {
      CEL_ASSIGN_OR_RETURN(
          qualifiers.emplace_back(),
          SelectQualifierFromConstant(element_expr.const_expr()));
    } else {
      return absl::InvalidArgumentError("Invalid cel.attribute call");
    }
  }

  // TODO(uncreated-issue/54): support for optionals.
  return qualifiers;
}
395
396
class RewriterImpl : public AstRewriterBase {
397
 public:
398
  RewriterImpl(const Ast& ast, PlannerContext& planner_context)
399
0
      : ast_(ast), planner_context_(planner_context) {}
400
401
0
  void PreVisitExpr(const Expr& expr) override { path_.push_back(&expr); }
402
403
0
  void PreVisitSelect(const Expr& expr, const SelectExpr& select) override {
404
0
    const Expr& operand = select.operand();
405
0
    const std::string& field_name = select.field();
406
    // Select optimization can generalize to lists and maps, but for now only
407
    // support message traversal.
408
0
    const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id());
409
410
0
    std::optional<Type> rt_type =
411
0
        (checker_type.has_message_type())
412
0
            ? GetRuntimeType(checker_type.message_type().type())
413
0
            : absl::nullopt;
414
0
    if (rt_type.has_value() && (*rt_type).Is<StructType>()) {
415
0
      const StructType& runtime_type = rt_type->GetStruct();
416
0
      std::optional<SelectInstruction> field_or =
417
0
          GetSelectInstruction(runtime_type, planner_context_, field_name);
418
0
      if (field_or.has_value()) {
419
0
        candidates_[&expr] = std::move(field_or).value();
420
0
      }
421
0
    } else if (checker_type.has_map_type()) {
422
0
      candidates_[&expr] = QualifierInstruction(field_name);
423
0
    }
424
    // else
425
    // TODO(uncreated-issue/54): add support for either dyn or any. Excluded to
426
    // simplify program plan.
427
0
  }
428
429
0
  void PreVisitCall(const Expr& expr, const CallExpr& call) override {
430
0
    if (call.args().size() != 2 || call.function() != ::cel::builtin::kIndex) {
431
0
      return;
432
0
    }
433
434
0
    const auto& qualifier_expr = call.args()[1];
435
0
    if (qualifier_expr.has_const_expr()) {
436
0
      auto qualifier_or =
437
0
          SelectInstructionFromConstant(qualifier_expr.const_expr());
438
0
      if (!qualifier_or.ok()) {
439
        // TODO(uncreated-issue/54): should warn, but by default warnings fail overall
440
        // program planning.
441
0
        return;
442
0
      }
443
0
      candidates_[&expr] = std::move(qualifier_or).value();
444
0
    }
445
    // TODO(uncreated-issue/54): support variable indexes
446
0
  }
447
448
0
  bool PostVisitRewrite(Expr& expr) override {
449
0
    if (!progress_status_.ok()) {
450
0
      return false;
451
0
    }
452
0
    path_.pop_back();
453
0
    auto candidate_iter = candidates_.find(&expr);
454
0
    if (candidate_iter == candidates_.end()) {
455
0
      return false;
456
0
    }
457
458
    // On post visit, filter candidates that aren't rooted on a message or a
459
    // select chain.
460
0
    const QualifierInstruction& candidate = candidate_iter->second;
461
0
    if (!HasOptimizeableRoot(&expr, candidate)) {
462
0
      candidates_.erase(candidate_iter);
463
0
      return false;
464
0
    }
465
466
0
    if (!path_.empty() && candidates_.find(path_.back()) != candidates_.end()) {
467
      // parent is optimizeable, defer rewriting until we consider the parent.
468
0
      return false;
469
0
    }
470
471
0
    SelectPath path = GetSelectPath(&expr);
472
473
    // generate the new cel.attribute call.
474
0
    absl::string_view fn = path.test_only ? kCelHasField : kCelAttribute;
475
476
0
    Expr operand(std::move(*path.operand));
477
0
    Expr call;
478
0
    call.set_id(expr.id());
479
0
    call.mutable_call_expr().set_function(std::string(fn));
480
0
    call.mutable_call_expr().mutable_args().reserve(2);
481
482
0
    call.mutable_call_expr().mutable_args().push_back(std::move(operand));
483
0
    call.mutable_call_expr().mutable_args().push_back(
484
0
        MakeSelectPathExpr(path.select_instructions));
485
486
    // TODO(uncreated-issue/54): support for optionals.
487
0
    expr = std::move(call);
488
489
0
    return true;
490
0
  }
491
492
0
  absl::Status GetProgressStatus() const { return progress_status_; }
493
494
 private:
495
0
  SelectPath GetSelectPath(Expr* expr) {
496
0
    SelectPath result;
497
0
    result.test_only = false;
498
0
    Expr* operand = expr;
499
0
    auto candidate_iter = candidates_.find(operand);
500
0
    while (candidate_iter != candidates_.end()) {
501
0
      result.select_instructions.push_back(candidate_iter->second);
502
0
      if (operand->has_select_expr()) {
503
0
        if (operand->select_expr().test_only()) {
504
0
          result.test_only = true;
505
0
        }
506
0
        operand = &(operand->mutable_select_expr().mutable_operand());
507
0
      } else {
508
0
        ABSL_DCHECK(operand->has_call_expr());
509
0
        operand = &(operand->mutable_call_expr().mutable_args()[0]);
510
0
      }
511
0
      candidate_iter = candidates_.find(operand);
512
0
    }
513
0
    absl::c_reverse(result.select_instructions);
514
0
    result.operand = operand;
515
0
    return result;
516
0
  }
517
518
  // Check whether the candidate has a message type as a root (the operand for
519
  // the batched select operation).
520
  // Called on post visit.
521
  bool HasOptimizeableRoot(const Expr* expr,
522
0
                           const QualifierInstruction& candidate) {
523
0
    if (absl::holds_alternative<SelectInstruction>(candidate)) {
524
0
      return true;
525
0
    }
526
0
    const Expr* operand = nullptr;
527
0
    if (expr->has_call_expr() && expr->call_expr().args().size() == 2 &&
528
0
        expr->call_expr().function() == ::cel::builtin::kIndex) {
529
0
      operand = &expr->call_expr().args()[0];
530
0
    } else if (expr->has_select_expr()) {
531
0
      operand = &expr->select_expr().operand();
532
0
    }
533
534
0
    if (operand == nullptr) {
535
0
      return false;
536
0
    }
537
538
0
    return candidates_.find(operand) != candidates_.end();
539
0
  }
540
541
0
  std::optional<Type> GetRuntimeType(absl::string_view type_name) {
542
0
    return planner_context_.type_reflector().FindType(type_name).value_or(
543
0
        absl::nullopt);
544
0
  }
545
546
0
  void SetProgressStatus(const absl::Status& status) {
547
0
    if (progress_status_.ok() && !status.ok()) {
548
0
      progress_status_ = status;
549
0
    }
550
0
  }
551
552
  const Ast& ast_;
553
  PlannerContext& planner_context_;
554
  // ids of potentially optimizeable expr nodes.
555
  absl::flat_hash_map<const Expr*, QualifierInstruction> candidates_;
556
  std::vector<const Expr*> path_;
557
  absl::Status progress_status_;
558
};
559
560
class OptimizedSelectImpl {
561
 public:
562
  OptimizedSelectImpl(std::vector<SelectQualifier> select_path,
563
                      std::vector<AttributeQualifier> qualifiers,
564
                      bool presence_test, SelectOptimizationOptions options)
565
0
      : select_path_(std::move(select_path)),
566
0
        qualifiers_(std::move(qualifiers)),
567
0
        presence_test_(presence_test),
568
0
        options_(options)
569
570
0
  {
571
0
    ABSL_DCHECK(!select_path_.empty());
572
0
  }
573
574
  // Move constructible.
575
  OptimizedSelectImpl(const OptimizedSelectImpl&) = delete;
576
  OptimizedSelectImpl& operator=(const OptimizedSelectImpl&) = delete;
577
0
  OptimizedSelectImpl(OptimizedSelectImpl&&) = default;
578
  OptimizedSelectImpl& operator=(OptimizedSelectImpl&&) = delete;
579
580
  absl::StatusOr<Value> ApplySelect(ExecutionFrameBase& frame,
581
                                    const StructValue& struct_value) const;
582
583
  AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const;
584
585
0
  std::optional<Attribute> attribute() const { return attribute_; }
586
587
0
  const std::vector<AttributeQualifier>& qualifiers() const {
588
0
    return qualifiers_;
589
0
  }
590
591
 private:
592
  std::optional<Attribute> attribute_;
593
  std::vector<SelectQualifier> select_path_;
594
  std::vector<AttributeQualifier> qualifiers_;
595
  bool presence_test_;
596
  SelectOptimizationOptions options_;
597
};
598
599
// Check for unknowns or missing attributes.
600
absl::StatusOr<std::optional<Value>> CheckForMarkedAttributes(
601
0
    ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) {
602
0
  if (attribute_trail.empty()) {
603
0
    return absl::nullopt;
604
0
  }
605
606
0
  if (frame.unknown_processing_enabled() &&
607
0
      frame.attribute_utility().CheckForUnknownExact(attribute_trail)) {
608
    // Check if the inferred attribute is marked. Only matches if this attribute
609
    // or a parent is marked unknown (use_partial = false).
610
    // Partial matches (i.e. descendant of this attribute is marked) aren't
611
    // considered yet in case another operation would select an unmarked
612
    // descended attribute.
613
    //
614
    // TODO(uncreated-issue/51): this may return a more specific attribute than the
615
    // declared pattern. Follow up will truncate the returned attribute to match
616
    // the pattern.
617
0
    return frame.attribute_utility().CreateUnknownSet(
618
0
        attribute_trail.attribute());
619
0
  }
620
621
0
  if (frame.missing_attribute_errors_enabled() &&
622
0
      frame.attribute_utility().CheckForMissingAttribute(attribute_trail)) {
623
0
    return frame.attribute_utility().CreateMissingAttributeError(
624
0
        attribute_trail.attribute());
625
0
  }
626
627
0
  return absl::nullopt;
628
0
}
629
630
// Applies the select path to `struct_value`.
//
// Normally this delegates to StructValue::Qualify. The second element of its
// result is treated as the index of the first qualifier still to be applied:
// when that index is within the path, the remaining qualifiers are applied to
// the intermediate value with FallbackSelect.
//
// The whole path is evaluated with FallbackSelect instead when the options
// force the fallback implementation or when Qualify reports kUnimplemented.
// Any other Qualify error status is returned unchanged.
absl::StatusOr<Value> OptimizedSelectImpl::ApplySelect(
    ExecutionFrameBase& frame, const StructValue& struct_value) const {
  if (options_.force_fallback_implementation) {
    return FallbackSelect(struct_value, select_path_, presence_test_,
                          frame.descriptor_pool(), frame.message_factory(),
                          frame.arena());
  }

  auto value_or =
      struct_value.Qualify(select_path_, presence_test_,
                           frame.descriptor_pool(), frame.message_factory(),
                           frame.arena());
  if (!value_or.ok()) {
    if (value_or.status().code() == absl::StatusCode::kUnimplemented) {
      return FallbackSelect(struct_value, select_path_, presence_test_,
                            frame.descriptor_pool(), frame.message_factory(),
                            frame.arena());
    }
    return value_or.status();
  }

  // A negative or out-of-range index means the whole path was applied. The
  // cast is only reached for non-negative values, so comparing as size_t is
  // exact and avoids a signed/unsigned comparison.
  const auto next_index = value_or->second;
  if (next_index < 0 ||
      static_cast<size_t>(next_index) >= select_path_.size()) {
    return std::move(value_or->first);
  }

  return FallbackSelect(
      value_or->first,
      absl::MakeConstSpan(select_path_)
          .subspan(static_cast<size_t>(next_index)),
      presence_test_, frame.descriptor_pool(), frame.message_factory(),
      frame.arena());
}
659
660
// Returns the attribute trail of the select result: the operand's attribute
// (same variable name) with this select's qualifiers appended to its
// qualifier path. An empty operand trail yields an empty trail.
AttributeTrail OptimizedSelectImpl::GetAttributeTrail(
    const AttributeTrail& operand_trail) const {
  if (operand_trail.empty()) {
    return AttributeTrail();
  }

  const auto& operand_attribute = operand_trail.attribute();
  const auto& operand_path = operand_attribute.qualifier_path();

  // Size the vector once up front. The previous approach copied the operand
  // path and then reserved a larger capacity, which reallocated and copied a
  // second time.
  std::vector<AttributeQualifier> qualifiers;
  qualifiers.reserve(operand_path.size() + qualifiers_.size());
  qualifiers.insert(qualifiers.end(), operand_path.begin(), operand_path.end());
  qualifiers.insert(qualifiers.end(), qualifiers_.begin(), qualifiers_.end());

  return AttributeTrail(
      Attribute(std::string(operand_attribute.variable_name()),
                std::move(qualifiers)));
}
674
675
class StackMachineImpl : public ExpressionStepBase {
676
 public:
677
  StackMachineImpl(int expr_id, OptimizedSelectImpl impl)
678
0
      : ExpressionStepBase(expr_id), impl_(std::move(impl)) {}
679
680
  absl::Status Evaluate(ExecutionFrame* frame) const override;
681
682
 private:
683
  // Get the effective attribute for the optimized select expression.
684
  // Assumes the operand is the top of stack if the attribute wasn't known at
685
  // plan time.
686
  AttributeTrail GetAttributeTrail(ExecutionFrame* frame) const;
687
688
  OptimizedSelectImpl impl_;
689
};
690
691
// Extends the attribute of the top-of-stack operand with this select's
// qualifiers.
AttributeTrail StackMachineImpl::GetAttributeTrail(
    ExecutionFrame* frame) const {
  return impl_.GetAttributeTrail(frame->value_stack().PeekAttribute());
}
696
697
0
// Evaluates the optimized select against the operand on top of the value
// stack.
//
// - Error and unknown operands are left on the stack unchanged.
// - With attribute tracking enabled, the result's attribute trail is computed
//   first. If that attribute is marked unknown or missing, the operand is
//   replaced by the corresponding value.
// - A non-struct operand is reported as an InvalidArgumentError status.
// - Otherwise the operand is replaced by the select result.
absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const {
  // Number of value-stack slots consumed by this step (just the operand).
  // TODO(uncreated-issue/51): add support for variable qualifiers and string literal
  // variable names.
  constexpr size_t kStackInputs = 1;

  // For now, we expect the operand to be top of stack.
  const Value& operand = frame->value_stack().Peek();
  if (operand.Is<ErrorValue>() || operand.Is<UnknownValue>()) {
    // Forward the error/unknown, which is already top of stack.
    return absl::OkStatus();
  }

  // Stays empty unless attribute tracking is enabled.
  AttributeTrail attribute_trail;
  if (frame->enable_attribute_tracking()) {
    // TODO(uncreated-issue/51): add support variable qualifiers
    attribute_trail = GetAttributeTrail(frame);
    CEL_ASSIGN_OR_RETURN(std::optional<Value> marked,
                         CheckForMarkedAttributes(*frame, attribute_trail));
    if (marked.has_value()) {
      frame->value_stack().Pop(kStackInputs);
      frame->value_stack().Push(*std::move(marked),
                                std::move(attribute_trail));
      return absl::OkStatus();
    }
  }

  if (!operand.Is<StructValue>()) {
    return absl::InvalidArgumentError(
        "Expected struct type for select optimization.");
  }

  CEL_ASSIGN_OR_RETURN(Value selected,
                       impl_.ApplySelect(*frame, operand.GetStruct()));
  frame->value_stack().Pop(kStackInputs);
  frame->value_stack().Push(std::move(selected), std::move(attribute_trail));
  return absl::OkStatus();
}
740
741
class RecursiveImpl : public DirectExpressionStep {
742
 public:
743
  RecursiveImpl(int64_t expr_id, std::unique_ptr<DirectExpressionStep> operand,
744
                OptimizedSelectImpl impl)
745
0
      : DirectExpressionStep(expr_id),
746
0
        operand_(std::move(operand)),
747
0
        impl_(std::move(impl)) {}
748
749
  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
750
                        AttributeTrail& attribute) const override;
751
752
 private:
753
  // Get the effective attribute for the optimized select expression.
754
  // Assumes the operand is the top of stack if the attribute wasn't known at
755
  // plan time.
756
  AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const;
757
  std::unique_ptr<DirectExpressionStep> operand_;
758
  OptimizedSelectImpl impl_;
759
};
760
761
// Thin forwarding wrapper: the shared OptimizedSelectImpl owns the logic that
// maps the operand's trail to the trail of the whole select expression.
AttributeTrail RecursiveImpl::GetAttributeTrail(
    const AttributeTrail& operand_trail) const {
  const OptimizedSelectImpl& select = impl_;
  return select.GetAttributeTrail(operand_trail);
}
765
766
// Evaluates the operand step into `result`/`attribute`, then applies the
// optimized select in place.
//
// Errors and unknowns from the operand are returned unchanged. With attribute
// tracking enabled, `attribute` is replaced by the select's trail, and a
// marked-attribute value (if found) becomes the result without performing the
// select. Returns InvalidArgumentError if the operand is not a struct.
absl::Status RecursiveImpl::Evaluate(ExecutionFrameBase& frame, Value& result,
                                     AttributeTrail& attribute) const {
  CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute));

  const bool forward_as_is =
      InstanceOf<ErrorValue>(result) || InstanceOf<UnknownValue>(result);
  if (forward_as_is) {
    return absl::OkStatus();
  }

  if (frame.attribute_tracking_enabled()) {
    attribute = GetAttributeTrail(attribute);
    CEL_ASSIGN_OR_RETURN(auto marked,
                         CheckForMarkedAttributes(frame, attribute));
    if (marked.has_value()) {
      result = std::move(marked).value();
      return absl::OkStatus();
    }
  }

  if (InstanceOf<StructValue>(result)) {
    CEL_ASSIGN_OR_RETURN(Value selected,
                         impl_.ApplySelect(frame, Cast<StructValue>(result)));
    result = std::move(selected);
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(
      "Expected struct type for select optimization");
}
793
794
class SelectOptimizer : public ProgramOptimizer {
795
 public:
796
  explicit SelectOptimizer(const SelectOptimizationOptions& options)
797
0
      : options_(options) {}
798
799
0
  absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override {
800
0
    return absl::OkStatus();
801
0
  }
802
803
  absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override;
804
805
 private:
806
  SelectOptimizationOptions options_;
807
};
808
809
// Rewrites the plan for a kCelAttribute / kCelHasField call into an optimized
// select step.
//
// Non-matching nodes are left untouched. Returns InvalidArgumentError for
// malformed calls and UnimplementedError for unsupported forms (third optional
// argument, qualified identifier operands). If the subexpression is unknown
// or already flattened, or another extension already consumed the operand's
// subplan, the node is silently left unoptimized.
absl::Status SelectOptimizer::OnPostVisit(PlannerContext& context,
                                          const Expr& node) {
  if (!node.has_call_expr()) {
    return absl::OkStatus();
  }
  const auto& call = node.call_expr();
  const absl::string_view fn = call.function();
  // kCelHasField calls are planned as presence tests; kCelAttribute calls as
  // plain selects. Any other function is not ours.
  const bool presence_test = (fn == kCelHasField);
  if (!presence_test && fn != kCelAttribute) {
    return absl::OkStatus();
  }

  // Two arguments are required; a third (optionals) is not supported yet.
  const size_t arg_count = call.args().size();
  if (arg_count < 2 || arg_count > 3) {
    return absl::InvalidArgumentError("Invalid cel.attribute call");
  }
  if (arg_count == 3) {
    return absl::UnimplementedError("Optionals not yet supported");
  }

  CEL_ASSIGN_OR_RETURN(std::vector<SelectQualifier> instructions,
                       SelectInstructionsFromCall(call));
  if (instructions.empty()) {
    return absl::InvalidArgumentError("Invalid cel.attribute no select steps.");
  }

  // The first argument is the value being selected from. Dotted identifiers
  // would need namespace-aware resolution, which is not implemented.
  const Expr& operand = call.args()[0];
  if (operand.has_ident_expr() &&
      absl::StrContains(operand.ident_expr().name(), ".")) {
    return absl::UnimplementedError("qualified identifiers not supported.");
  }

  // Mirror each select instruction as an attribute qualifier: field specifiers
  // contribute their name as a string qualifier, attribute qualifiers are
  // copied as-is.
  const auto to_qualifier = absl::Overload(
      [](const FieldSpecifier& field) {
        return AttributeQualifier::OfString(field.name);
      },
      [](const AttributeQualifier& q) { return q; });
  std::vector<AttributeQualifier> qualifiers;
  qualifiers.reserve(instructions.size());
  for (const SelectQualifier& instruction : instructions) {
    qualifiers.push_back(absl::visit(to_qualifier, instruction));
  }

  // TODO(uncreated-issue/51): If the first argument is a string literal, the custom
  // step needs to handle variable lookup.
  auto* subexpression = context.program_builder().GetSubexpression(&node);
  if (subexpression == nullptr || subexpression->IsFlattened()) {
    // Without information on the subprogram there is nothing to optimize.
    return absl::OkStatus();
  }

  OptimizedSelectImpl impl(std::move(instructions), std::move(qualifiers),
                           presence_test, options_);

  if (subexpression->IsRecursive()) {
    // Recursive plan: wrap the operand step (first dependency of the original
    // call step) in a RecursiveImpl and install it with the original depth.
    auto program = subexpression->ExtractRecursiveProgram();
    auto deps = program.step->ExtractDependencies();
    if (!deps.has_value() || deps->empty()) {
      return absl::InvalidArgumentError("Unexpected cel.@attribute call");
    }
    auto step = std::make_unique<RecursiveImpl>(
        node.id(), std::move(deps->front()), std::move(impl));
    subexpression->set_recursive_program(std::move(step), program.depth);
    return absl::OkStatus();
  }

  // Stack-machine plan: keep the operand's original subplan and append the
  // optimized step after it.
  if (context.GetSubplan(operand).empty()) {
    // Another extension modified the step; leave it alone.
    return absl::OkStatus();
  }
  CEL_ASSIGN_OR_RETURN(auto operand_subplan, context.ExtractSubplan(operand));

  google::api::expr::runtime::ExecutionPath path;
  absl::c_move(operand_subplan, std::back_inserter(path));
  path.push_back(
      std::make_unique<StackMachineImpl>(node.id(), std::move(impl)));
  return context.ReplaceSubplan(node, std::move(path));
}
903
904
// Returns the FlatExprBuilder backing `builder`, or nullptr when the builder
// does not wrap the default RuntimeImpl (only that implementation exposes one).
google::api::expr::runtime::FlatExprBuilder* GetFlatExprBuilder(
    RuntimeBuilder& builder) {
  auto& runtime =
      runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder);
  const bool is_default_impl =
      runtime_internal::RuntimeFriendAccess::RuntimeTypeId(runtime) ==
      NativeTypeId::For<runtime_internal::RuntimeImpl>();
  if (!is_default_impl) {
    return nullptr;
  }
  // The type-id check above establishes the dynamic type for the downcast.
  auto& runtime_impl =
      cel::internal::down_cast<runtime_internal::RuntimeImpl&>(runtime);
  return &runtime_impl.expr_builder();
}
916
917
}  // namespace
918
919
// Runs the select rewriter over the whole AST (starting at the root) and
// reports the rewriter's progress status as the result.
absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context,
                                                     Ast& ast) const {
  RewriterImpl select_rewriter(ast, context);
  AstRewrite(ast.mutable_root_expr(), select_rewriter);
  return select_rewriter.GetProgressStatus();
}
925
926
// Returns a factory producing a fresh SelectOptimizer per planning pass. The
// factory keeps its own copy of `options`, so the caller's object need not
// outlive it.
google::api::expr::runtime::ProgramOptimizerFactory
CreateSelectOptimizationProgramOptimizer(
    const SelectOptimizationOptions& options) {
  return [options](PlannerContext& /*context*/, const Ast& /*ast*/) {
    return std::make_unique<SelectOptimizer>(options);
  };
}
933
934
// Enables the select optimization on `builder`.
//
// Registers placeholder overloads for kCelAttribute and kCelHasField, then
// installs the SelectOptimizationAstUpdater AST transform and the select
// program optimizer on the underlying FlatExprBuilder.
//
// Returns InvalidArgumentError if `builder` does not use the default runtime
// implementation, or any error from registering the placeholder overloads.
//
// The fallible registrations run before the planner is touched. A registration
// failure therefore cannot leave the builder with the AST transform installed
// but no program optimizer to plan the rewritten calls. Such a failure might
// occur, for example, if the overloads are already registered, which is
// presumably the case on a repeated call.
absl::Status EnableSelectOptimization(
    cel::RuntimeBuilder& builder, const SelectOptimizationOptions& options) {
  auto* flat_expr_builder = GetFlatExprBuilder(builder);
  if (flat_expr_builder == nullptr) {
    return absl::InvalidArgumentError(
        "SelectOptimization requires default runtime implementation");
  }

  // Add overloads for select optimization signature.
  // These are never bound, only used to prevent the builder from failing on
  // the overloads check.
  CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction(
      FunctionDescriptor(kCelAttribute, false, {Kind::kAny, Kind::kList})));

  CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction(
      FunctionDescriptor(kCelHasField, false, {Kind::kAny, Kind::kList})));

  // All fallible steps succeeded; now mutate the planner configuration.
  flat_expr_builder->AddAstTransform(
      std::make_unique<SelectOptimizationAstUpdater>());
  // Add runtime implementation.
  flat_expr_builder->AddProgramOptimizer(
      CreateSelectOptimizationProgramOptimizer(options));
  return absl::OkStatus();
}
957
958
}  // namespace cel::extensions