Coverage Report

Created: 2026-05-27 07:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/proc/self/cwd/runtime/function_registry.cc
Line
Count
Source
1
// Copyright 2023 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "runtime/function_registry.h"
16
17
#include <cstddef>
18
#include <memory>
19
#include <string>
20
#include <utility>
21
#include <vector>
22
23
#include "absl/container/flat_hash_map.h"
24
#include "absl/container/node_hash_map.h"
25
#include "absl/status/status.h"
26
#include "absl/status/statusor.h"
27
#include "absl/strings/string_view.h"
28
#include "absl/types/optional.h"
29
#include "absl/types/span.h"
30
#include "common/function_descriptor.h"
31
#include "common/kind.h"
32
#include "runtime/activation_interface.h"
33
#include "runtime/function.h"
34
#include "runtime/function_overload_reference.h"
35
#include "runtime/function_provider.h"
36
37
namespace cel {
38
namespace {
39
40
// Impl for simple provider that looks up functions in an activation function
41
// registry.
42
class ActivationFunctionProviderImpl
43
    : public cel::runtime_internal::FunctionProvider {
44
 public:
45
0
  ActivationFunctionProviderImpl() = default;
46
47
  absl::StatusOr<std::optional<FunctionOverloadReference>> GetFunction(
48
      const cel::FunctionDescriptor& descriptor,
49
0
      const cel::ActivationInterface& activation) const override {
50
0
    std::vector<cel::FunctionOverloadReference> overloads =
51
0
        activation.FindFunctionOverloads(descriptor.name());
52
53
0
    std::optional<FunctionOverloadReference> matching_overload = absl::nullopt;
54
55
0
    for (const auto& overload : overloads) {
56
0
      if (overload.descriptor.ShapeMatches(descriptor)) {
57
0
        if (matching_overload.has_value()) {
58
0
          return absl::Status(absl::StatusCode::kInvalidArgument,
59
0
                              "Couldn't resolve function.");
60
0
        }
61
0
        matching_overload.emplace(overload);
62
0
      }
63
0
    }
64
65
0
    return matching_overload;
66
0
  }
67
};
68
69
// Create a CelFunctionProvider that just looks up the functions inserted in the
70
// Activation. This is a convenience implementation for a simple, common
71
// use-case.
72
std::unique_ptr<cel::runtime_internal::FunctionProvider>
73
0
CreateActivationFunctionProvider() {
74
0
  return std::make_unique<ActivationFunctionProviderImpl>();
75
0
}
76
77
}  // namespace
78
79
// Adds `implementation` as a static overload for `descriptor`.
//
// Returns kAlreadyExists when either of these holds:
//   - an overload with the same shape is already registered under this name;
//   - adding it would violate the one-overload rule for non-strict functions.
// On success, ownership of `implementation` moves into the registry.
absl::Status FunctionRegistry::Register(
    const cel::FunctionDescriptor& descriptor,
    std::unique_ptr<cel::Function> implementation) {
  if (DescriptorRegistered(descriptor)) {
    return absl::AlreadyExistsError(
        "CelFunction with specified parameters already registered");
  }
  if (!ValidateNonStrictOverload(descriptor)) {
    return absl::AlreadyExistsError(
        "Only one overload is allowed for non-strict function");
  }

  // Creates the per-name entry on first registration.
  RegistryEntry& registry_entry = functions_[descriptor.name()];
  registry_entry.static_overloads.push_back(
      StaticFunctionEntry(descriptor, std::move(implementation)));
  return absl::OkStatus();
}
97
98
// Declares a lazy overload for `descriptor`. No implementation is stored here.
// Instead, the overload is resolved at evaluation time through an
// activation-backed FunctionProvider.
//
// Returns kAlreadyExists under the same conditions as Register: either a
// same-shape overload already exists, or the non-strict one-overload rule would
// be broken.
absl::Status FunctionRegistry::RegisterLazyFunction(
    const cel::FunctionDescriptor& descriptor) {
  if (DescriptorRegistered(descriptor)) {
    return absl::AlreadyExistsError(
        "CelFunction with specified parameters already registered");
  }
  if (!ValidateNonStrictOverload(descriptor)) {
    return absl::AlreadyExistsError(
        "Only one overload is allowed for non-strict function");
  }

  // Creates the per-name entry on first registration.
  RegistryEntry& registry_entry = functions_[descriptor.name()];
  registry_entry.lazy_overloads.push_back(
      LazyFunctionEntry(descriptor, CreateActivationFunctionProvider()));
  return absl::OkStatus();
}
116
117
std::vector<cel::FunctionOverloadReference>
118
FunctionRegistry::FindStaticOverloads(absl::string_view name,
119
                                      bool receiver_style,
120
10.3k
                                      absl::Span<const cel::Kind> types) const {
121
10.3k
  std::vector<cel::FunctionOverloadReference> matched_funcs;
122
123
10.3k
  auto overloads = functions_.find(name);
124
10.3k
  if (overloads == functions_.end()) {
125
10.3k
    return matched_funcs;
126
10.3k
  }
127
128
0
  for (const auto& overload : overloads->second.static_overloads) {
129
0
    if (overload.descriptor->ShapeMatches(receiver_style, types)) {
130
0
      matched_funcs.push_back({*overload.descriptor, *overload.implementation});
131
0
    }
132
0
  }
133
134
0
  return matched_funcs;
135
10.3k
}
136
137
std::vector<cel::FunctionOverloadReference>
138
FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name,
139
                                             bool receiver_style,
140
138k
                                             size_t arity) const {
141
138k
  std::vector<cel::FunctionOverloadReference> matched_funcs;
142
143
138k
  auto overloads = functions_.find(name);
144
138k
  if (overloads == functions_.end()) {
145
115
    return matched_funcs;
146
115
  }
147
148
1.05M
  for (const auto& overload : overloads->second.static_overloads) {
149
1.05M
    if (overload.descriptor->receiver_style() == receiver_style &&
150
1.05M
        overload.descriptor->types().size() == arity) {
151
1.05M
      matched_funcs.push_back({*overload.descriptor, *overload.implementation});
152
1.05M
    }
153
1.05M
  }
154
155
138k
  return matched_funcs;
156
138k
}
157
158
std::vector<FunctionRegistry::LazyOverload> FunctionRegistry::FindLazyOverloads(
159
    absl::string_view name, bool receiver_style,
160
0
    absl::Span<const cel::Kind> types) const {
161
0
  std::vector<FunctionRegistry::LazyOverload> matched_funcs;
162
163
0
  auto overloads = functions_.find(name);
164
0
  if (overloads == functions_.end()) {
165
0
    return matched_funcs;
166
0
  }
167
168
0
  for (const auto& entry : overloads->second.lazy_overloads) {
169
0
    if (entry.descriptor->ShapeMatches(receiver_style, types)) {
170
0
      matched_funcs.push_back({*entry.descriptor, *entry.function_provider});
171
0
    }
172
0
  }
173
174
0
  return matched_funcs;
175
0
}
176
177
std::vector<FunctionRegistry::LazyOverload>
178
FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name,
179
                                           bool receiver_style,
180
138k
                                           size_t arity) const {
181
138k
  std::vector<FunctionRegistry::LazyOverload> matched_funcs;
182
183
138k
  auto overloads = functions_.find(name);
184
138k
  if (overloads == functions_.end()) {
185
115
    return matched_funcs;
186
115
  }
187
188
138k
  for (const auto& entry : overloads->second.lazy_overloads) {
189
0
    if (entry.descriptor->receiver_style() == receiver_style &&
190
0
        entry.descriptor->types().size() == arity) {
191
0
      matched_funcs.push_back({*entry.descriptor, *entry.function_provider});
192
0
    }
193
0
  }
194
195
138k
  return matched_funcs;
196
138k
}
197
198
// Returns a map from every registered function name to its overload
// descriptors. For each name, the static overloads are listed first (in
// registration order), followed by the lazy ones.
//
// The pointers come from each entry's `descriptor.get()`, so they are owned by
// this registry and must not be used after the registry is destroyed.
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
FunctionRegistry::ListFunctions() const {
  absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
      descriptor_map;
  // Exactly one output entry per registered name, so size the map up front.
  descriptor_map.reserve(functions_.size());

  // Distinct loop-variable names: the previous version reused `entry` for both
  // the outer and inner loops, shadowing the outer variable.
  for (const auto& [name, function_entry] : functions_) {
    std::vector<const cel::FunctionDescriptor*> descriptors;
    descriptors.reserve(function_entry.static_overloads.size() +
                        function_entry.lazy_overloads.size());
    for (const auto& static_entry : function_entry.static_overloads) {
      descriptors.push_back(static_entry.descriptor.get());
    }
    for (const auto& lazy_entry : function_entry.lazy_overloads) {
      descriptors.push_back(lazy_entry.descriptor.get());
    }
    // Keys of functions_ are unique, so this always inserts a fresh entry.
    descriptor_map.emplace(name, std::move(descriptors));
  }

  return descriptor_map;
}
219
220
bool FunctionRegistry::DescriptorRegistered(
221
2.55M
    const cel::FunctionDescriptor& descriptor) const {
222
2.55M
  auto overloads = functions_.find(descriptor.name());
223
2.55M
  if (overloads == functions_.end()) {
224
609k
    return false;
225
609k
  }
226
1.94M
  const RegistryEntry& entry = overloads->second;
227
8.40M
  for (const auto& static_ovl : entry.static_overloads) {
228
8.40M
    if (static_ovl.descriptor->ShapeMatches(descriptor)) {
229
0
      return true;
230
0
    }
231
8.40M
  }
232
1.94M
  for (const auto& lazy_ovl : entry.lazy_overloads) {
233
0
    if (lazy_ovl.descriptor->ShapeMatches(descriptor)) {
234
0
      return true;
235
0
    }
236
0
  }
237
1.94M
  return false;
238
1.94M
}
239
240
bool FunctionRegistry::ValidateNonStrictOverload(
241
2.55M
    const cel::FunctionDescriptor& descriptor) const {
242
2.55M
  auto overloads = functions_.find(descriptor.name());
243
2.55M
  if (overloads == functions_.end()) {
244
609k
    return true;
245
609k
  }
246
1.94M
  const RegistryEntry& entry = overloads->second;
247
1.94M
  if (!descriptor.is_strict()) {
248
    // If the newly added overload is a non-strict function, we require that
249
    // there are no other overloads, which is not possible here.
250
0
    return false;
251
0
  }
252
  // If the newly added overload is a strict function, we need to make sure
253
  // that no previous overloads are registered non-strict. If the list of
254
  // overload is not empty, we only need to check the first overload. This is
255
  // because if the first overload is strict, other overloads must also be
256
  // strict by the rule.
257
1.94M
  return (entry.static_overloads.empty() ||
258
1.94M
          entry.static_overloads[0].descriptor->is_strict()) &&
259
1.94M
         (entry.lazy_overloads.empty() ||
260
0
          entry.lazy_overloads[0].descriptor->is_strict());
261
1.94M
}
262
263
}  // namespace cel