Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/protobuf/compare.py: 16%

116 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

"""Utility functions for comparing proto2 messages in Python.

ProtoEq() compares two proto2 messages for equality.

ClearDefaultValuedFields() recursively clears the fields that are set to their
default values. This is useful for comparing protocol buffers where the
semantics of unset fields and default valued fields are the same.

assertProtoEqual() is useful for unit tests. It produces much more helpful
output than assertEqual() for proto2 messages, e.g. this:

  outer {
    inner {
-     strings: "x"
?               ^
+     strings: "y"
?               ^
    }
  }

...compared to the default output from assertEqual() that looks like this:

AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>

Call it inside your unit test's googletest.TestCase subclasses like this:

  from tensorflow.python.util.protobuf import compare

  class MyTest(googletest.TestCase):
    ...
    def testXXX(self):
      ...
      compare.assertProtoEqual(self, a, b)

Alternatively:

  from tensorflow.python.util.protobuf import compare

  class MyTest(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testXXX(self):
      ...
      self.assertProtoEqual(a, b)
"""

60 

61import difflib 

62import math 

63 

64from ..compat import collections_abc 

65import six 

66 

67from google.protobuf import descriptor 

68from google.protobuf import descriptor_pool 

69from google.protobuf import message 

70from google.protobuf import text_format 

71 

72 

73# TODO(alankelly): Distinguish between signalling and quiet NaNs. 

def isClose(x, y, relative_tolerance):  # pylint: disable=invalid-name
  """Returns True if x and y are equal within a relative tolerance.

  Two NaNs are considered equal to each other, and an infinity is only equal
  to an infinity of the same sign. This function does not distinguish between
  signalling and non-signalling NaN.

  Args:
    x: float value to be compared
    y: float value to be compared
    relative_tolerance: float. The allowable difference between the two values
      being compared is determined by multiplying the relative tolerance by
      the larger of the two absolute values.

  Returns:
    True if the values are considered equal, False otherwise.
  """
  x_is_nan = math.isnan(x)
  y_is_nan = math.isnan(y)
  # NaNs are considered equal to each other, but to nothing else.
  if x_is_nan or y_is_nan:
    return x_is_nan and y_is_nan

  # An infinite bound would make the tolerance test below vacuous, so
  # infinities must match exactly (including sign).
  if math.isinf(x) or math.isinf(y):
    return x == y

  return abs(x - y) <= relative_tolerance * max(abs(x), abs(y))

95 

96 

def checkFloatEqAndReplace(self, expected, actual, relative_tolerance):  # pylint: disable=invalid-name
  """Recursively replaces the floats in actual with those in expected iff they are approximately equal.

  This is done because string equality will consider values such as
  5.0999999999 and 5.1 as not being equal, despite being extremely close.

  Fields are paired by field number and map entries by key, so a field or key
  that is set on only one of the two messages is skipped here and left for the
  caller's subsequent textual comparison to report.

  Modifies actual in place.

  Args:
    self: googletest.TestCase
    expected: expected values
    actual: actual values
    relative_tolerance: float, relative tolerance.
  """
  field_descriptor = descriptor.FieldDescriptor
  float_types = (field_descriptor.TYPE_FLOAT, field_descriptor.TYPE_DOUBLE)
  message_types = (field_descriptor.TYPE_MESSAGE, field_descriptor.TYPE_GROUP)

  # Index the fields set on `actual` by field number. Zipping the two
  # ListFields() results would misalign every later pair as soon as one
  # message has a field set that the other does not.
  actual_by_number = {
      desc.number: (desc, values) for desc, values in actual.ListFields()
  }

  for expected_desc, expected_values in expected.ListFields():
    if expected_desc.number not in actual_by_number:
      continue
    actual_desc, actual_values = actual_by_number[expected_desc.number]
    # Same number but a different kind of field means the messages are of
    # different types; leave the mismatch for the textual diff.
    if (
        actual_desc.type != expected_desc.type
        or actual_desc.label != expected_desc.label
    ):
      continue

    is_repeated = expected_desc.label == field_descriptor.LABEL_REPEATED
    if not is_repeated:
      expected_values = [expected_values]
      actual_values = [actual_values]

    if expected_desc.type in float_types:
      for i, (x, y) in enumerate(zip(expected_values, actual_values)):
        # Replace the actual value with the expected value if the test passes,
        # otherwise leave it and let it fail in the next test so that the error
        # message is nicely formatted
        if isClose(x, y, relative_tolerance):
          if is_repeated:
            getattr(actual, actual_desc.name)[i] = x
          else:
            setattr(actual, actual_desc.name, x)
    elif expected_desc.type in message_types:
      is_map = (
          expected_desc.type == field_descriptor.TYPE_MESSAGE
          and expected_desc.message_type.has_options
          and expected_desc.message_type.GetOptions().map_entry
      )
      if is_map:
        # This is a map, only recurse if its values have a message type
        # (field number 2 of the synthesized map-entry message is the value).
        if (
            expected_desc.message_type.fields_by_number[2].type
            == field_descriptor.TYPE_MESSAGE
        ):
          for key, expected_value in six.iteritems(expected_values):
            # Test membership before indexing so that a key missing from
            # `actual` is not looked up (and thereby possibly created).
            if key in actual_values:
              checkFloatEqAndReplace(
                  self,
                  expected=expected_value,
                  actual=actual_values[key],
                  relative_tolerance=relative_tolerance,
              )
      else:
        for expected_value, actual_value in zip(expected_values, actual_values):
          # recursive step
          checkFloatEqAndReplace(
              self,
              expected=expected_value,
              actual=actual_value,
              relative_tolerance=relative_tolerance,
          )

164 

165 

def assertProtoEqual(
    self,
    a,
    b,
    check_initialized=True,
    normalize_numbers=False,
    msg=None,
    relative_tolerance=None,
):  # pylint: disable=invalid-name
  """Fails with a useful error if a and b aren't equal.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.

  Note that the inputs may be modified in place: normalize_numbers normalizes
  both messages, and relative_tolerance overwrites floats in a with the
  corresponding floats from b when they are close enough. When a is given as
  text, the message parsed from it is what gets modified.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison.
    msg: if specified, is used as the error message on failure.
    relative_tolerance: float, relative tolerance. If this is not provided, then
      all floats are compared using string comparison otherwise, floating point
      comparisons are done using the relative tolerance provided.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    # Text input is parsed into a fresh message of b's type.
    a = text_format.Parse(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  if relative_tolerance is not None:
    checkFloatEqAndReplace(
        self, expected=b, actual=a, relative_tolerance=relative_tolerance
    )

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions would perform regular diff instead of multi-line
  # diff if string is longer than 2**16. We substitute this behavior
  # with a call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
    return

  diff = ''.join(
      difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
  if diff:
    if msg is None:
      # Avoid prefixing the diff with a literal "None :" when the caller did
      # not supply a message.
      self.fail(diff)
    else:
      self.fail('%s :\n%s' % (msg, diff))

224 

225 

def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This function
  converts integer and enum values to int, rounds 32-bit floats to six decimal
  places to account for python always storing them as 64-bit, and converts
  doubles to float (for when they're set to integers) rounded to seven decimal
  places.

  Modifies pb in place. Recurses into nested objects.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  field_descriptor = descriptor.FieldDescriptor
  # 32-bit and 64-bit integer kinds (and enums) are all normalized the same
  # way, to a plain int.
  integer_types = (
      field_descriptor.TYPE_INT64,
      field_descriptor.TYPE_UINT64,
      field_descriptor.TYPE_SINT64,
      field_descriptor.TYPE_INT32,
      field_descriptor.TYPE_UINT32,
      field_descriptor.TYPE_SINT32,
      field_descriptor.TYPE_ENUM,
  )
  message_types = (field_descriptor.TYPE_MESSAGE, field_descriptor.TYPE_GROUP)

  for desc, values in pb.ListFields():
    is_repeated = desc.label == field_descriptor.LABEL_REPEATED
    if not is_repeated:
      # Treat a singular field as a one-element list so both cases share code.
      values = [values]

    normalized_values = None
    if desc.type in integer_types:
      normalized_values = [int(x) for x in values]
    elif desc.type == field_descriptor.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == field_descriptor.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if is_repeated:
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if desc.type in message_types:
      is_map = (
          desc.type == field_descriptor.TYPE_MESSAGE
          and desc.message_type.has_options
          and desc.message_type.GetOptions().map_entry
      )
      if is_map:
        # This is a map, only recurse if the values have a message type
        # (field number 2 of the map-entry message is the value).
        if (desc.message_type.fields_by_number[2].type ==
            field_descriptor.TYPE_MESSAGE):
          for v in six.itervalues(values):
            NormalizeNumberFields(v)
      else:
        for v in values:
          # recursive step
          NormalizeNumberFields(v)

  return pb

292 

293 

def _IsMap(value):
  """Returns True if value is an instance of collections_abc.Mapping."""
  return isinstance(value, collections_abc.Mapping)

296 

297 

def _IsRepeatedContainer(value):
  """Returns True if value is iterable and is not a text string.

  Strings are iterable but are treated as scalar values, so they are excluded
  explicitly.
  """
  if isinstance(value, six.string_types):
    return False
  try:
    iter(value)
  except TypeError:
    return False
  return True

306 

307 

def ProtoEq(a, b):
  """Compares two proto2 objects for equality.

  Recurses into nested messages. Uses list (not set) semantics for comparing
  repeated fields, ie duplicates and order matter.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal.
  """

  def Format(pb):
    """Returns a dictionary or unchanged pb based on its type.

    Specifically, this function returns a dictionary that maps tag
    number (for messages), key (for maps) or element index (for repeated
    fields) to value, or just pb unchanged if it's none of those.

    Args:
      pb: A proto2 message or a primitive.

    Returns:
      A dict or unchanged pb.
    """
    if isinstance(pb, message.Message):
      return {desc.number: value for desc, value in pb.ListFields()}
    if _IsMap(pb):
      return dict(pb.items())
    if _IsRepeatedContainer(pb):
      return dict(enumerate(pb))
    return pb

  a, b = Format(a), Format(b)

  # Base case
  if not isinstance(a, dict) or not isinstance(b, dict):
    return a == b

  # This comparison performs double duty: it compares two messages by tag
  # value *or* two repeated fields by element index. The magic is in the
  # Format() function, which converts them both to the same easily comparable
  # format. Comparing the key sets first means the keys never have to be
  # sorted, so keys that cannot be ordered against each other are fine.
  if set(a) != set(b):
    return False

  # Recursive step: every shared entry must compare equal.
  return all(ProtoEq(a[tag], b[tag]) for tag in a)

361 

362 

class ProtoAssertions(object):
  """Mix this into a googletest.TestCase class to get proto2 assertions.

  Usage:

  class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProtoEqual(a, b)

  See module-level definitions for method documentation.
  """

  # pylint: disable=invalid-name
  def assertProtoEqual(self, *args, **kwargs):
    """Calls the module-level assertProtoEqual() with self as the test case."""
    return assertProtoEqual(self, *args, **kwargs)