Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/lib/common.py: 42%
19 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Common values and methods for TensorFlow Debugger."""
16import collections
17import json
# URL scheme prefix string for gRPC debug URLs.
GRPC_URL_PREFIX = "grpc://"

# Identifies a Session.run() call by the names of what it feeds and fetches.
# Fields, in order: feed_names, fetch_names.
RunKey = collections.namedtuple("RunKey", "feed_names fetch_names")
def get_graph_element_name(elem):
  """Obtain the name or string representation of a graph element.

  If the graph element has the attribute "name", return that name. Otherwise,
  return a `str()` representation of the graph element. Certain graph
  elements, such as `SparseTensor`s, do not have the attribute "name".

  Args:
    elem: The graph element in question.

  Returns:
    `elem.name` if the attribute 'name' is available; otherwise `str(elem)`.
  """
  return elem.name if hasattr(elem, "name") else str(elem)
def get_flattened_names(feeds_or_fetches):
  """Flatten the feeds or fetches of a `Session.run()` call into names.

  Lists and tuples are walked in order, and dicts are walked through their
  values in iteration order; nesting of any depth is followed recursively.
  Every leaf is converted with `get_graph_element_name`.

  Args:
    feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It may be
      a Tensor, an Operation or a Variable, or nested lists, tuples or dicts
      of those. See the doc of `Session.run()` for more details.

  Returns:
    (list of str) A flattened list of names from `feeds_or_fetches`.
  """
  if isinstance(feeds_or_fetches, (list, tuple)):
    children = feeds_or_fetches
  elif isinstance(feeds_or_fetches, dict):
    # Only the dict's values are descended into; keys are not named.
    children = (feeds_or_fetches[key] for key in feeds_or_fetches)
  else:
    # Leaf: expected to be a Tensor, an Operation or a Variable.
    return [get_graph_element_name(feeds_or_fetches)]
  return [name for child in children for name in get_flattened_names(child)]
def get_run_key(feed_dict, fetches):
  """Summarize the names of feeds and fetches as a RunKey JSON string.

  Args:
    feed_dict: The feed_dict given to the `Session.run()` call. Its keys are
      the fed graph elements (possibly nested tuples of them); `None` means
      nothing is fed.
    fetches: The fetches from the `Session.run()` call.

  Returns:
    A JSON array consisting of two items. The first item is a flattened
    array of the names of the feeds. The second item is a flattened array of
    the names of the fetches.
  """
  if feed_dict is None:
    # No feeds: report an empty list (same as an empty feed_dict) rather
    # than the string "None".
    feed_names = []
  elif isinstance(feed_dict, dict):
    # The feeds are the dict's keys. Its values are the fed data (arrays,
    # scalars, ...), which carry no graph-element names. Flattening the keys
    # also handles nested-tuple keys.
    feed_names = get_flattened_names(list(feed_dict))
  else:
    # Not a dict: keep the previous behavior of flattening it directly.
    feed_names = get_flattened_names(feed_dict)
  return json.dumps(RunKey(feed_names, get_flattened_names(fetches)))