Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/protobuf/compare.py: 16%
116 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utility functions for comparing proto2 messages in Python.
18ProtoEq() compares two proto2 messages for equality.
20ClearDefaultValuedFields() recursively clears the fields that are set to their
21default values. This is useful for comparing protocol buffers where the
22semantics of unset fields and default valued fields are the same.
24assertProtoEqual() is useful for unit tests. It produces much more helpful
25output than assertEqual() for proto2 messages, e.g. this:
27 outer {
28 inner {
29- strings: "x"
30? ^
31+ strings: "y"
32? ^
33 }
34 }
36...compared to the default output from assertEqual() that looks like this:
38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
40Call it inside your unit test's googletest.TestCase subclasses like this:
42 from tensorflow.python.util.protobuf import compare
44 class MyTest(googletest.TestCase):
45 ...
46 def testXXX(self):
47 ...
48 compare.assertProtoEqual(self, a, b)
50Alternatively:
52 from tensorflow.python.util.protobuf import compare
54 class MyTest(compare.ProtoAssertions, googletest.TestCase):
55 ...
56 def testXXX(self):
57 ...
58 self.assertProtoEqual(a, b)
59"""
61import difflib
62import math
64from ..compat import collections_abc
65import six
67from google.protobuf import descriptor
68from google.protobuf import descriptor_pool
69from google.protobuf import message
70from google.protobuf import text_format
# TODO(alankelly): Distinguish between signalling and quiet NaNs.
def isClose(x, y, relative_tolerance):  # pylint: disable=invalid-name
  """Returns True if x and y are equal within a relative tolerance.

  Special values are handled explicitly: a NaN is considered close to another
  NaN (and to nothing else), and an infinity is only close to an infinity of
  the same sign. This function does not distinguish between signalling and
  non-signalling NaN.

  Args:
    x: float value to be compared
    y: float value to be compared
    relative_tolerance: float. The allowable difference between the two values
      is relative_tolerance multiplied by the larger of abs(x) and abs(y).

  Returns:
    bool, whether x and y are considered close.
  """
  x_is_nan = math.isnan(x)
  y_is_nan = math.isnan(y)
  if x_is_nan or y_is_nan:
    # A NaN only matches another NaN.
    return x_is_nan and y_is_nan

  if math.isinf(x) or math.isinf(y):
    # Infinities must match exactly (including sign); a finite value is never
    # close to an infinite one.
    return x == y

  allowed_difference = relative_tolerance * max(abs(x), abs(y))
  return abs(x - y) <= allowed_difference
def checkFloatEqAndReplace(self, expected, actual, relative_tolerance):  # pylint: disable=invalid-name
  """Recursively replaces the floats in actual with those in expected iff they are approximately equal.

  This is done because string equality will consider values such as 5.0999999999
  and 5.1 as not being equal, despite being extremely close.

  Fields are paired by field number rather than by their position in
  ListFields(), so messages that populate different sets of fields are never
  cross-matched. Map entries are paired by key rather than by iteration order.
  Values that are not close, fields set in only one of the messages, and fields
  whose type or label differ between the two messages are left untouched so
  that the subsequent string comparison reports them.

  Args:
    self: googletest.TestCase
    expected: expected values
    actual: actual values; float values found to be close to their expected
      counterparts are overwritten in place with the expected values.
    relative_tolerance: float, relative tolerance.
  """
  float_types = (
      descriptor.FieldDescriptor.TYPE_FLOAT,
      descriptor.FieldDescriptor.TYPE_DOUBLE,
  )
  message_types = (
      descriptor.FieldDescriptor.TYPE_MESSAGE,
      descriptor.FieldDescriptor.TYPE_GROUP,
  )

  # Index the actual fields by number so each expected field is compared with
  # its real counterpart. Zipping the two ListFields() results positionally
  # misaligns as soon as one message sets a field the other does not.
  actual_fields = {
      desc.number: (desc, values) for desc, values in actual.ListFields()
  }

  for expected_desc, expected_values in expected.ListFields():
    counterpart = actual_fields.get(expected_desc.number)
    if counterpart is None:
      continue
    actual_desc, actual_values = counterpart
    if (
        actual_desc.type != expected_desc.type
        or actual_desc.label != expected_desc.label
    ):
      # Incompatible fields: leave them for the string diff to report.
      continue

    is_repeated = (
        expected_desc.label == descriptor.FieldDescriptor.LABEL_REPEATED
    )

    if expected_desc.type in float_types:
      # Replace the actual value with the expected value if the test passes,
      # otherwise leave it and let it fail in the next test so that the error
      # message is nicely formatted.
      if is_repeated:
        actual_container = getattr(actual, actual_desc.name)
        for i, (x, y) in enumerate(zip(expected_values, actual_values)):
          if isClose(x, y, relative_tolerance):
            actual_container[i] = x
      elif isClose(expected_values, actual_values, relative_tolerance):
        setattr(actual, actual_desc.name, expected_values)

    elif expected_desc.type in message_types:
      if (
          expected_desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
          and expected_desc.message_type.has_options
          and expected_desc.message_type.GetOptions().map_entry
      ):
        # This is a map: pair entries by key, not by iteration order.
        value_type = expected_desc.message_type.fields_by_number[2].type
        actual_map = getattr(actual, actual_desc.name)
        for key in expected_values:
          # Check membership first so that looking up a key that is missing
          # from actual cannot add an entry to it.
          if key not in actual_map:
            continue
          if value_type == descriptor.FieldDescriptor.TYPE_MESSAGE:
            # recursive step
            checkFloatEqAndReplace(
                self,
                expected=expected_values[key],
                actual=actual_map[key],
                relative_tolerance=relative_tolerance,
            )
          elif value_type in float_types:
            expected_value = expected_values[key]
            if isClose(expected_value, actual_map[key], relative_tolerance):
              actual_map[key] = expected_value
      else:
        if not is_repeated:
          expected_values = [expected_values]
          actual_values = [actual_values]
        for e_v, a_v in zip(expected_values, actual_values):
          # recursive step
          checkFloatEqAndReplace(
              self,
              expected=e_v,
              actual=a_v,
              relative_tolerance=relative_tolerance,
          )
def assertProtoEqual(
    self,
    a,
    b,
    check_initialized=True,
    normalize_numbers=False,
    msg=None,
    relative_tolerance=None,
):  # pylint: disable=invalid-name
  """Fails with a useful error if a and b aren't equal.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), i.e. order and duplicate elements matter.

  Note that the messages may be modified in place before they are compared:
  normalize_numbers normalizes both a (or the message parsed from it) and b,
  and relative_tolerance overwrites close float values in a with those of b.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison.
    msg: if specified, is used as the error message on failure.
    relative_tolerance: float, relative tolerance. If this is not provided, then
      all floats are compared using string comparison otherwise, floating point
      comparisons are done using the relative tolerance provided.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  if relative_tolerance is not None:
    checkFloatEqAndReplace(
        self, expected=b, actual=a, relative_tolerance=relative_tolerance
    )

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions would perform regular diff instead of multi-line
  # diff if string is longer than 2**16. We substitute this behavior
  # with a call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
  else:
    diff = ''.join(
        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
    if diff:
      # Only prefix the caller's message when one was given; otherwise the
      # failure text would start with a literal "None :".
      self.fail(diff if msg is None else '%s :\n%s' % (msg, diff))
def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This function:

    * converts int32, uint32, sint32, int64, uint64, sint64 and enum values to
      Python ints;
    * rounds 32-bit float values to 6 decimal places, to account for python
      always storing them as 64-bit;
    * converts double values to Python floats (for when they were set to
      integers) and rounds them to 7 decimal places.

  Other field types are left unchanged.

  Modifies pb in place. Recurses into nested messages, including the values of
  map fields whose value type is a message.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  # Python 3 has a single int type, so 32-bit and 64-bit integers (and enums)
  # all normalize the same way.
  integer_types = (
      descriptor.FieldDescriptor.TYPE_INT64,
      descriptor.FieldDescriptor.TYPE_UINT64,
      descriptor.FieldDescriptor.TYPE_SINT64,
      descriptor.FieldDescriptor.TYPE_INT32,
      descriptor.FieldDescriptor.TYPE_UINT32,
      descriptor.FieldDescriptor.TYPE_SINT32,
      descriptor.FieldDescriptor.TYPE_ENUM,
  )
  message_types = (
      descriptor.FieldDescriptor.TYPE_MESSAGE,
      descriptor.FieldDescriptor.TYPE_GROUP,
  )

  for desc, values in pb.ListFields():
    is_repeated = desc.label == descriptor.FieldDescriptor.LABEL_REPEATED
    if not is_repeated:
      # Treat singular fields as one-element lists so both cases share code.
      values = [values]

    normalized_values = None
    if desc.type in integer_types:
      normalized_values = [int(x) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if is_repeated:
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if desc.type in message_types:
      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
          desc.message_type.has_options and
          desc.message_type.GetOptions().map_entry):
        # This is a map, only recurse if the values have a message type.
        if (desc.message_type.fields_by_number[2].type ==
            descriptor.FieldDescriptor.TYPE_MESSAGE):
          for v in values.values():
            NormalizeNumberFields(v)
      else:
        for v in values:
          # recursive step
          NormalizeNumberFields(v)

  return pb
294def _IsMap(value):
295 return isinstance(value, collections_abc.Mapping)
def _IsRepeatedContainer(value):
  """Returns whether value is iterable, treating text strings as scalars."""
  if not isinstance(value, six.string_types):
    try:
      iter(value)
    except TypeError:
      # Not iterable, so it cannot be a repeated container.
      return False
    else:
      return True
  return False
def ProtoEq(a, b):
  """Compares two proto2 objects for equality.

  Recurses into nested messages. Uses list (not set) semantics for comparing
  repeated fields, ie duplicates and order matter.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal.
  """

  def Format(pb):
    """Returns a dictionary or unchanged pb based on its type.

    Specifically, this function returns a dictionary that maps tag number (for
    messages), key (for maps) or element index (for repeated fields) to value,
    or just pb unchanged if it's none of these.

    Args:
      pb: A proto2 message or a primitive.

    Returns:
      A dict or unchanged pb.
    """
    if isinstance(pb, message.Message):
      return {desc.number: value for desc, value in pb.ListFields()}
    elif _IsMap(pb):
      return dict(pb.items())
    elif _IsRepeatedContainer(pb):
      return dict(enumerate(pb))
    else:
      return pb

  a, b = Format(a), Format(b)

  # Base case
  if not isinstance(a, dict) or not isinstance(b, dict):
    return a == b

  # This comparison performs double duty: it compares two messages by tag value
  # *or* two repeated fields by element, in order. The magic is in Format(),
  # which converts them both to the same easily comparable format.
  # Comparing the key sets directly (instead of sorting their union) avoids a
  # TypeError when keys of different types meet and skips a needless sort.
  if a.keys() != b.keys():
    return False
  # Recursive step: equal iff every shared key maps to equal values.
  return all(ProtoEq(a[key], b[key]) for key in a)
class ProtoAssertions(object):
  """Mixin that adds proto2 assertion methods to a googletest.TestCase.

  Usage:

    class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
      ...
      def testSomething(self):
        ...
        self.assertProtoEqual(a, b)

  Each method forwards to the module-level function of the same name, passing
  the test case as its first argument; see that function for documentation.
  """

  def assertProtoEqual(self, *args, **kwargs):  # pylint: disable=invalid-name
    """Forwards to the module-level assertProtoEqual with this test case."""
    return assertProtoEqual(self, *args, **kwargs)