Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/plugins/pr_curve/metadata.py: 59%
22 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Internal information about the pr_curves plugin."""
18from tensorboard.compat.proto import summary_pb2
19from tensorboard.plugins.pr_curve import plugin_data_pb2
PLUGIN_NAME = "pr_curves"

# Indices used to look up each value in the tensor stored in a pr_curves
# summary. The order matters: each name is bound to its position below
# (0 through 5).
(
    TRUE_POSITIVES_INDEX,
    FALSE_POSITIVES_INDEX,
    TRUE_NEGATIVES_INDEX,
    FALSE_NEGATIVES_INDEX,
    PRECISION_INDEX,
    RECALL_INDEX,
) = range(6)

# Latest value of the `version` field of the `PrCurvePluginData` proto;
# this is the version stamped into newly created summary metadata.
PROTO_VERSION = 0
def create_summary_metadata(display_name, description, num_thresholds):
    """Build the `summary_pb2.SummaryMetadata` proto for pr_curves data.

    The plugin-specific payload is a serialized `PrCurvePluginData` message
    carrying the current `PROTO_VERSION` and the given threshold count.

    Arguments:
      display_name: The display name used in TensorBoard.
      description: The description to show in TensorBoard.
      num_thresholds: The number of thresholds to use for PR curves.

    Returns:
      A `summary_pb2.SummaryMetadata` protobuf object whose `plugin_data`
      has `plugin_name` set to `PLUGIN_NAME`.
    """
    serialized_payload = plugin_data_pb2.PrCurvePluginData(
        version=PROTO_VERSION,
        num_thresholds=num_thresholds,
    ).SerializeToString()
    plugin_data = summary_pb2.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=serialized_payload,
    )
    return summary_pb2.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data,
    )
def parse_plugin_metadata(content):
    """Parse the pr_curves plugin content of a summary into a proto.

    Arguments:
      content: The `content` field of a `SummaryMetadata` proto
        corresponding to the pr_curves plugin.

    Returns:
      A `plugin_data_pb2.PrCurvePluginData` protobuf object.

    Raises:
      TypeError: If `content` is not `bytes`.
    """
    if not isinstance(content, bytes):
        raise TypeError("Content type must be bytes")
    result = plugin_data_pb2.PrCurvePluginData.FromString(content)
    # Version 0 (`PROTO_VERSION`) is the only version known at this time, so
    # no migrations are needed; handling for older versions would go here.
    return result