Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/plugins/pr_curve/metadata.py: 59%

22 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Internal information about the pr_curves plugin.""" 

16 

17 

18from tensorboard.compat.proto import summary_pb2 

19from tensorboard.plugins.pr_curve import plugin_data_pb2 

20 

# Plugin name written into `SummaryMetadata.plugin_data.plugin_name` for
# pr_curve summaries.
PLUGIN_NAME = "pr_curves"

# Index of each statistic within the tensor stored in a pr_curve summary,
# in order: true positives, false positives, true negatives, false
# negatives, precision, recall.
(
    TRUE_POSITIVES_INDEX,
    FALSE_POSITIVES_INDEX,
    TRUE_NEGATIVES_INDEX,
    FALSE_NEGATIVES_INDEX,
    PRECISION_INDEX,
    RECALL_INDEX,
) = range(6)

# Newest value of the `version` field of the `PrCurvePluginData` proto;
# this is the version stamped onto newly created summary metadata.
PROTO_VERSION = 0

34 

35 

def create_summary_metadata(display_name, description, num_thresholds):
    """Build the `SummaryMetadata` proto that tags a summary as pr_curves data.

    The plugin-specific payload is a serialized `PrCurvePluginData` message
    carrying the current `PROTO_VERSION` and the requested threshold count.

    Arguments:
      display_name: Display name to show in TensorBoard.
      description: Description to show in TensorBoard; stored in the
        proto's `summary_description` field.
      num_thresholds: Number of thresholds used for the PR curves.

    Returns:
      A `summary_pb2.SummaryMetadata` protobuf object.
    """
    serialized_plugin_data = plugin_data_pb2.PrCurvePluginData(
        version=PROTO_VERSION,
        num_thresholds=num_thresholds,
    ).SerializeToString()
    # Tag the payload with this plugin's name so TensorBoard routes the
    # summary to the pr_curves plugin.
    plugin_data = summary_pb2.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=serialized_plugin_data,
    )
    return summary_pb2.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data,
    )

58 

59 

def parse_plugin_metadata(content):
    """Parse summary metadata to a Python object.

    Arguments:
      content: The `content` field of a `SummaryMetadata` proto
        corresponding to the pr_curves plugin. Must be a `bytes` object.

    Returns:
      A `plugin_data_pb2.PrCurvePluginData` protobuf object.

    Raises:
      TypeError: If `content` is not a `bytes` object.
    """
    if not isinstance(content, bytes):
        raise TypeError("Content type must be bytes")
    result = plugin_data_pb2.PrCurvePluginData.FromString(content)
    # Version 0 (`PROTO_VERSION`) is the only version known at this time, so
    # there are no migrations to apply and the parsed message is returned
    # as-is regardless of its `version` field. Add per-version migration
    # steps here if the proto schema evolves.
    return result