Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py: 13%
448 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Class to hold a library of OpDefs and use it to create Brain operations."""
18from google.protobuf import text_format
19from tensorflow.core.config import flags
20from tensorflow.core.framework import attr_value_pb2
21from tensorflow.core.framework import tensor_pb2
22from tensorflow.core.framework import tensor_shape_pb2
23from tensorflow.core.framework import types_pb2
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import op_callbacks
26from tensorflow.python.framework import op_def_library_pybind
27from tensorflow.python.framework import op_def_registry
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import _pywrap_utils
32from tensorflow.python.util import compat
33from tensorflow.python.util import tf_contextlib
def _Attr(op_def, name):
  """Returns the entry of `op_def.attr` whose `name` equals `name`.

  Args:
    op_def: Object exposing an iterable `attr` field and a `name` field.
    name: Name of the attr to look up.

  Returns:
    The first matching entry of `op_def.attr`.

  Raises:
    TypeError: If no entry of `op_def.attr` has the requested name.
  """
  found = next(
      (candidate for candidate in op_def.attr if candidate.name == name), None)
  if found is not None:
    return found
  raise TypeError(f"Inconsistent OpDef for '{op_def.name}', missing attr "
                  f"'{name}'")
def _AttrValue(attr_protos, name, op_type_name):
  """Returns `attr_protos[name]`, failing loudly when the key is absent.

  Args:
    attr_protos: Mapping from attr name to attr value.
    name: Attr name to fetch.
    op_type_name: Op type name, used only in the error message.

  Returns:
    The value stored under `name`.

  Raises:
    TypeError: If `name` is not a key of `attr_protos`.
  """
  if name not in attr_protos:
    raise TypeError(f"Inconsistent OpDef for '{op_type_name}', missing attr "
                    f"'{name}' from '{attr_protos}'.")
  return attr_protos[name]
def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises if `dtype` is not among the types allowed by `attr_def`.

  Args:
    dtype: The type to check; compared by membership against
      `attr_def.allowed_values.list.type`.
    attr_def: Attr definition. Only consulted when it has the
      `allowed_values` field set.
    param_name: Parameter name, used only in the error message.

  Raises:
    TypeError: If `attr_def` restricts the allowed types and `dtype` is not
      one of them.
  """
  if not attr_def.HasField("allowed_values"):
    return
  allowed_list = attr_def.allowed_values.list.type
  if dtype in allowed_list:
    return
  # Only build the human-readable list once we know we are going to raise;
  # this function runs for every typed input, so the success path should not
  # pay for formatting an error message.
  allowed_values = ", ".join(dtypes.as_dtype(x).name for x in allowed_list)
  raise TypeError(
      f"Value passed to parameter '{param_name}' has DataType "
      f"{dtypes.as_dtype(dtype).name} not in list of allowed values: "
      f"{allowed_values}")
def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
  """Raises if `length` is below the minimum declared by `attr_def`.

  Args:
    length: Length of the list that was passed.
    attr_def: Attr definition providing `has_minimum` and `minimum`.
    param_name: Attr name, used only in the error message.
    op_type_name: Op type name, used only in the error message.

  Raises:
    ValueError: If `attr_def.has_minimum` is set and `length` is smaller than
      `attr_def.minimum`.
  """
  too_short = attr_def.has_minimum and length < attr_def.minimum
  if not too_short:
    return
  raise ValueError(f"Attr '{param_name}' of '{op_type_name}' Op passed list "
                   f"of length {length} less than minimum "
                   f"{attr_def.minimum}.")
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises if `value` is not one of the strings allowed by `attr_def`.

  Args:
    value: The string value to check.
    attr_def: Attr definition; `attr_def.allowed_values.list.s` holds the
      permitted strings.
    arg_name: Attr name, used only in the error message.
    op_type_name: Op type name, used only in the error message.

  Raises:
    ValueError: If `value` is not contained in the allowed strings.
  """
  permitted = attr_def.allowed_values.list.s
  if value in permitted:
    return
  allowed_values = '", "'.join(compat.as_text(s) for s in permitted)
  raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed string "
                   f"'{compat.as_text(value)}' not in: \"{allowed_values}\".")
def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
  """Raises if `value` is smaller than `attr_def.minimum`.

  Args:
    value: The integer value to check.
    attr_def: Attr definition providing `minimum`.
    arg_name: Attr name, used only in the error message.
    op_type_name: Op type name, used only in the error message.

  Raises:
    ValueError: If `value < attr_def.minimum`.
  """
  if not value < attr_def.minimum:
    return
  raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed {value} "
                   f"less than minimum {attr_def.minimum}.")
def _IsListParameter(arg):
  """Returns True if `arg` has `number_attr` or `type_list_attr` set."""
  return bool(arg.number_attr) or bool(arg.type_list_attr)
def _NumTypeFields(arg):
  """Counts how many of `arg`'s three type-describing fields are populated.

  The fields considered are `type` (populated when it differs from
  `types_pb2.DT_INVALID`), `type_attr` and `type_list_attr`.

  Args:
    arg: Object exposing `type`, `type_attr` and `type_list_attr`.

  Returns:
    An integer between 0 and 3.
  """
  has_fixed_type = arg.type != types_pb2.DT_INVALID
  populated = (has_fixed_type, arg.type_attr, arg.type_list_attr)
  return sum(1 for field in populated if field)
def _IsListValue(v):
  """Returns True if `v` is a `list` or a `tuple` (including subclasses)."""
  return isinstance(v, list) or isinstance(v, tuple)
def _Flatten(l):
  """Flattens one level of nesting: [1, 2, [3, 4], [5]] -> [1, 2, 3, 4, 5].

  Elements that are lists or tuples are spliced into the result; every other
  element is kept as is. Only the outermost level is flattened.

  Args:
    l: An iterable of values.

  Returns:
    A new list.
  """
  flat = []
  for entry in l:
    if isinstance(entry, (list, tuple)):
      flat.extend(entry)
    else:
      flat.append(entry)
  return flat
def _Restructure(l, structure):
  """Groups the elements of `l` according to `structure`.

  `structure` is a list whose entries are either `None` or a non-negative
  integer. `None` consumes one element of `l` and yields it directly; an
  integer N consumes the next N elements and yields them as a slice of `l`.

  When `structure` has exactly one entry, the single piece is returned on its
  own; otherwise the pieces are returned as a tuple. For example:

    _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
      -> ("foo", ["bar", "baz"], "qux")

    _Restructure(["foo"], [None]) -> "foo"

    _Restructure(["foo"], [1]) -> ["foo"]

    _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The single piece implied by `structure[0]` if `structure` has length 1,
    otherwise a tuple with one piece per entry of `structure`.
  """
  pieces = []
  cursor = 0
  for spec in structure:
    if spec is None:
      piece, consumed = l[cursor], 1
    else:
      piece, consumed = l[cursor:cursor + spec], spec
    pieces.append(piece)
    cursor += consumed

  return pieces[0] if len(pieces) == 1 else tuple(pieces)
def _MakeFloat(v, arg_name):
  """Returns `float(v)` after checking `v` is one of `compat.real_types`.

  Args:
    v: The value to convert.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError(f"Expected float for argument '{arg_name}' not {repr(v)}.")
def _MakeInt(v, arg_name):
  """Returns `int(v)`, rejecting `str` inputs and unconvertible values.

  Args:
    v: The value to convert.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` is a `str`, or if `int(v)` raises `ValueError` or
      `TypeError`.
  """
  if not isinstance(v, str):
    try:
      return int(v)
    except (ValueError, TypeError):
      raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
  # Strings are refused outright, even ones `int()` could parse.
  raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
def _MakeStr(v, arg_name):
  """Returns `compat.as_bytes(v)` after checking `v` is bytes or text.

  Args:
    v: The value to convert.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    return compat.as_bytes(v)  # Convert unicode strings to bytes.
  raise TypeError(f"Expected string for argument '{arg_name}' not {repr(v)}.")
def _MakeBool(v, arg_name):
  """Returns `v` unchanged after checking it is a `bool`.

  Args:
    v: The value to check.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` is not a `bool`.
  """
  if isinstance(v, bool):
    return v
  raise TypeError(f"Expected bool for argument '{arg_name}' not {repr(v)}.")
def _MakeType(v, arg_name):
  """Converts `v` to the datatype enum of its base dtype.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    arg_name: Argument name, used only in the error message.

  Returns:
    `dtypes.as_dtype(v).base_dtype.as_datatype_enum`.

  Raises:
    TypeError: If `dtypes.as_dtype(v)` raises `TypeError`.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError(f"Expected DataType for argument '{arg_name}' not "
                    f"{repr(v)}.")
  return base.as_datatype_enum
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto. A TensorShapeProto input is returned as is.

  Raises:
    TypeError: If `tensor_shape.as_shape(v).as_proto()` raises `TypeError` or
      `ValueError`.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Warn once if any dimension carries a name, then pass the proto through.
    if any(d.name for d in v.dim):
      logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                      str(v))
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except (TypeError, ValueError) as e:
    raise TypeError(f"Error converting {repr(v)} (arg name = {arg_name}) to a "
                    f"TensorShape: {e}")
def _MakeTensor(v, arg_name):
  """Returns `v` unchanged after checking it is a TensorProto.

  Args:
    v: The value to check.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` is not a `tensor_pb2.TensorProto`.
  """
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        f"Don't know how to convert {repr(v)} to a TensorProto for argument "
        f"'{arg_name}'")
  return v
def _MakeFunc(v, arg_name):
  """Converts `v` into a NameAttrList describing a function.

  Accepted inputs, checked in this order:
    * a `NameAttrList`, returned as is;
    * bytes or text, used as the `name` of a new `NameAttrList`;
    * any object with an `add_to_graph` method. It is added to the default
      graph first; the result is then its `_as_name_attr_list` attribute when
      present, otherwise a new `NameAttrList` named `v.name`.

  Args:
    v: The value to convert.
    arg_name: Argument name, used only in the error message.

  Raises:
    TypeError: If `v` matches none of the accepted forms.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    return attr_value_pb2.NameAttrList(name=v)
  if not hasattr(v, "add_to_graph"):
    raise TypeError(f"Don't know how to convert {repr(v)} to a func for "
                    f"argument {arg_name}")
  v.add_to_graph(ops.get_default_graph())
  if hasattr(v, "_as_name_attr_list"):
    return v._as_name_attr_list  # pylint: disable=protected-access
  return attr_value_pb2.NameAttrList(name=v.name)
248# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """Context manager that colocates with every element of `inputs`, if any.

  Args:
    inputs: A list of `Tensor` or `Operation` objects. May be empty.

  Returns:
    A context manager.
  """
  if inputs:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we peel off the first element and recurse on the rest,
    # nesting one context manager per element in the list.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
  else:
    yield
266# pylint: enable=g-doc-return-or-yield
def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, and optional
      parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    # No declared outputs at all: hand back the Operation.
    return op

  res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
  # A stateful op whose only output is an empty list also yields the
  # Operation rather than the empty list.
  empty_list_result = isinstance(res, list) and not res
  if empty_list_result and is_stateful:
    return op
  return res
313# This is temporary Python/C++ code duplication until all of it can be ported
314# over to C++.
315# LINT.IfChange
def _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos):
  """Extracts `attr_protos`. For use in _apply_op_helper.

  For every attr declared in `op_def`, converts `attrs[name]` into an
  `AttrValue` proto, validates it against the attr definition's constraints,
  and stores it in `attr_protos[name]`. A value of `None` for an attr that
  declares a default is replaced by a copy of that default and skips
  validation.

  Args:
    op_type_name: Op type name, used in error messages.
    op_def: The OpDef whose `attr` entries are processed.
    attrs: Mapping from attr name to Python value; must contain every attr
      name in `op_def`.
    attr_protos: Output mapping, filled in place.
  """
  for attr_def in op_def.attr:
    attr_name = attr_def.name
    attr_type = attr_def.type
    raw_value = attrs[attr_name]

    if attr_def.HasField("default_value") and raw_value is None:
      default_proto = attr_value_pb2.AttrValue()
      default_proto.CopyFrom(attr_def.default_value)
      attr_protos[attr_name] = default_proto
      continue

    proto = value_to_attr_value(raw_value, attr_type, attr_name)

    # Length constraint applies to every list-typed attr.
    if attr_type.startswith("list("):
      _SatisfiesLengthConstraint(len(raw_value), attr_def, attr_name,
                                 op_type_name)

    # Allowed-values constraint for string attrs.
    if attr_def.HasField("allowed_values"):
      if attr_type == "string":
        _SatisfiesAllowedStringsConstraint(proto.s, attr_def, attr_name,
                                           op_type_name)
      elif attr_type == "list(string)":
        for item in proto.list.s:
          _SatisfiesAllowedStringsConstraint(item, attr_def, attr_name,
                                             op_type_name)

    # Minimum constraint for int attrs.
    if attr_def.has_minimum and attr_type == "int":
      _SatisfiesIntMinimumConstraint(proto.i, attr_def, attr_name,
                                     op_type_name)

    # Allowed-types constraint for type attrs.
    if attr_type == "type":
      _SatisfiesTypeConstraint(proto.type, attr_def, attr_name)
    elif attr_type == "list(type)":
      for item in proto.list.type:
        _SatisfiesTypeConstraint(item, attr_def, attr_name)

    attr_protos[attr_name] = proto
def _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                            output_structure):
  """Extracts `output_structure`. For use in _apply_op_helper.

  Appends one entry per output arg of `op_def`: the integer value of the
  arg's `number_attr`, the number of types in its `type_list_attr`, or `None`
  for a single-tensor output.

  Args:
    op_type_name: Op type name, used in error messages.
    op_def: The OpDef whose `output_arg` entries are processed.
    attr_protos: Mapping from attr name to AttrValue proto.
    output_structure: Output list, extended in place.

  Raises:
    TypeError: If an attr referenced by an output arg is missing from
      `attr_protos`.
  """
  for arg in op_def.output_arg:
    if arg.number_attr:
      slot = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
    elif arg.type_attr:
      # Looked up only to fail on a missing attr; the value is not needed.
      _AttrValue(attr_protos, arg.type_attr, op_type_name)
      slot = None
    elif arg.type_list_attr:
      type_list = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
      slot = len(type_list.list.type)
    else:
      slot = None
    output_structure.append(slot)
def _CanExtractAttrsFastPath(op_def, keywords):
  """Check if the fast path for _apply_op_helper is applicable.

  Args:
    op_def: The OpDef describing the op's inputs and attrs.
    keywords: Mapping from argument name to the value passed by the caller.

  Returns:
    True only if every input arg of `op_def` maps to an `ops.Tensor` in
    `keywords` and no attr of `op_def` has type `func` or `list(func)`.
  """
  # Every input must already be a tf.Tensor (a missing input counts as not).
  if any(not isinstance(keywords.get(input_arg.name, None), ops.Tensor)
         for input_arg in op_def.input_arg):
    return False

  # Function-valued attrs are not supported on the fast path.
  return all(attr_def.type not in ("func", "list(func)")
             for attr_def in op_def.attr)
def _CheckOpDeprecation(op_type_name, op_def, producer):
  """Raises if the op has been removed at the given producer version.

  Args:
    op_type_name: Op type name, used in the error message.
    op_def: The OpDef; `op_def.deprecation.version` is the removal version,
      where a falsy value disables the check.
    producer: The GraphDef producer version to check against.

  Raises:
    NotImplementedError: If a removal version is set and `producer` is at or
      past it.
  """
  deprecation_version = op_def.deprecation.version
  if not deprecation_version or producer < deprecation_version:
    return
  raise NotImplementedError(
      f"Op {op_type_name} is not available in GraphDef version {producer}. "
      f"It has been removed in version {deprecation_version}. "
      f"{op_def.deprecation.explanation}.")
def _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map):
  """Extracts the `default_type_attr_map` and `allowed_list_attr_map`.

  Only attrs whose type is exactly "type" are considered.

  Args:
    op_def: The OpDef whose `attr` entries are scanned.
    default_type_attr_map: Output mapping from attr name to the default dtype
      (via `dtypes.as_dtype`), for attrs that declare a default value.
    allowed_list_attr_map: Output mapping from attr name to the list of
      allowed types, for attrs that declare allowed values.
  """
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  type_attr_defs = (a for a in op_def.attr if a.type == "type")
  for attr_def in type_attr_defs:
    attr_name = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[attr_name] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[attr_name] = attr_def.allowed_values.list.type
def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  For every input arg of `op_def`, pops the matching value from `keywords`,
  converts it to a tensor (or list of tensors), infers or validates the attrs
  that describe the input's type and length, and records the results.

  Args:
    op_type_name: Op type name, used in error messages.
    op_def: The OpDef whose `input_arg` entries are processed.
    allowed_list_attr_map: Mapping from type-attr name to its allowed types.
    keywords: Mapping from argument name to caller-supplied value. Consumed
      inputs are popped from it. An input may also be supplied under its name
      with a trailing underscore.
    default_type_attr_map: Mapping from type-attr name to its default dtype.
    attrs: Output mapping from attr name to inferred value, updated in place.
    inputs: Output list of converted input tensors, extended in place.
    input_types: Output list of input dtypes, extended in place. Ref inputs
      contribute their actual dtypes, other inputs their base dtypes.

  Raises:
    TypeError: If an input is missing, is not a list where one is required,
      cannot be converted, or has a type inconsistent with the OpDef or with
      an earlier input.
    ValueError: If a list input has the wrong length, or a single input cannot
      be converted to a tensor at all.
  """
  # Maps an inferred attr name to the input (or "Default in OpDef") that
  # determined it; only used to produce helpful error messages.
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = set(v.dtype.base_dtype for v in values)
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
        observed = ", ".join(observed_types)

        prefix = ("Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                  (input_name, op_type_name, observed))
        if input_arg.number_attr:
          # `dtype` is only normalised to a DType above for non-ref inputs;
          # for ref inputs it may still be the raw enum taken from
          # `input_arg.type`, so go through `as_dtype` before reading `.name`.
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtypes.as_dtype(dtype).name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type "
                            f"{dtypes.as_dtype(dtype).name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          raise err
        else:
          raise TypeError(
              f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
              f"'{input_arg.name}' of op '{op_type_name}', got "
              f"{repr(values)} of type '{type(values).__name__}' instead. "
              f"Error: {err}")
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}")
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")
      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)
def _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                           default_type_attr_map, attrs):
  """Extracts the remaining attributes into `attrs` in _apply_op_helper.

  Attrs already present in `attrs` are treated as inferred and must not also
  appear in `keywords`. Every other attr is taken, in order of preference,
  from `keywords[name]`, from `keywords[name + "_"]`, or from
  `default_type_attr_map[name]`.

  Args:
    op_type_name: Op type name, used in error messages.
    op_def: The OpDef whose `attr` entries are processed.
    keywords: Caller-supplied arguments; consumed entries are popped.
    default_type_attr_map: Mapping from type-attr name to its default dtype.
    attrs: Mapping from attr name to value, updated in place.

  Raises:
    TypeError: If an inferred attr is also given explicitly, or if no value
      can be found for an attr.
  """
  for attr in op_def.attr:
    attr_name = attr.name

    # Skip attrs that have already had their values inferred
    if attr_name in attrs:
      if attr_name in keywords:
        raise TypeError(
            f"Should not specify value for inferred attr '{attr_name}' for "
            f"{op_type_name}.")
      continue

    # Attrs whose names match Python keywords have an extra '_'
    # appended, so we must check for that as well.
    escaped_name = attr_name + "_"
    if attr_name in keywords:
      attrs[attr_name] = keywords.pop(attr_name)
    elif escaped_name in keywords:
      attrs[attr_name] = keywords.pop(escaped_name)
    elif attr_name in default_type_attr_map:
      attrs[attr_name] = default_type_attr_map[attr_name]
    else:
      raise TypeError(f"No argument found for attr {attr_name} for "
                      f"{op_type_name}")
def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper.

  Args:
    op_type_name: Name of the op to look up in the op registry.
    keywords: Caller-supplied arguments; their (flattened) values are used to
      determine which graph the op belongs to.

  Returns:
    A tuple `(op_def, graph, producer)` where `producer` is
    `graph.graph_def_versions.producer`.

  Raises:
    RuntimeError: If the op name is not registered, or if the graph cannot be
      determined from the inputs.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so this handler raises the intended RuntimeError.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}") from e

  return op_def, g, producer
def _CheckAllInputsUsed(op_type_name, keywords):
  """Ensures all inputs passed into _apply_op_helper were used.

  Args:
    op_type_name: Op type name, used in the error message.
    keywords: The arguments left over after inputs and attrs were consumed.

  Raises:
    TypeError: If `keywords` is non-empty; the message lists the leftover
      names in sorted order.
  """
  if not keywords:
    return
  all_keywords = ", ".join(sorted(keywords))
  raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                  f"{all_keywords}.")
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op; defaults to
      `op_type_name`.
    **keywords: input and attr arguments specified by name.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`, where `outputs`
    is `op.outputs` unless an op callback replaced it.
  """
  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name or op_type_name

  attrs = {}
  attr_protos = {}
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  inputs = []
  input_types = []
  output_structure = []

  use_fast_path = (
      _CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value())
  fallback = not use_fast_path

  if use_fast_path:
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))
  else:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [
        val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
    ]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If `attr_type` is a list type but `value` is not a list or
      tuple, if `attr_type` is not recognized, or if a per-type conversion
      helper rejects `value`.
  """
  proto = attr_value_pb2.AttrValue()

  if attr_type.startswith("list("):
    if not _IsListValue(value):
      raise TypeError(f"Expected list for attr {arg_name}, obtained "
                      f"{type(value).__name__} instead.")
    # List-typed attrs: convert each element and append to the matching
    # repeated field of `proto.list`.
    if attr_type == "list(string)":
      proto.list.s.extend([_MakeStr(x, arg_name) for x in value])
    elif attr_type == "list(int)":
      proto.list.i.extend([_MakeInt(x, arg_name) for x in value])
    elif attr_type == "list(float)":
      proto.list.f.extend([_MakeFloat(x, arg_name) for x in value])
    elif attr_type == "list(bool)":
      proto.list.b.extend([_MakeBool(x, arg_name) for x in value])
    elif attr_type == "list(type)":
      proto.list.type.extend([_MakeType(x, arg_name) for x in value])
    elif attr_type == "list(shape)":
      proto.list.shape.extend([_MakeShape(x, arg_name) for x in value])
    elif attr_type == "list(tensor)":
      proto.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
    elif attr_type == "list(func)":
      proto.list.func.extend([_MakeFunc(x, arg_name) for x in value])
    else:
      raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")
    return proto

  # Scalar attrs: plain fields are assigned, message fields are copied into.
  if attr_type == "string":
    proto.s = _MakeStr(value, arg_name)
  elif attr_type == "int":
    proto.i = _MakeInt(value, arg_name)
  elif attr_type == "float":
    proto.f = _MakeFloat(value, arg_name)
  elif attr_type == "bool":
    proto.b = _MakeBool(value, arg_name)
  elif attr_type == "type":
    proto.type = _MakeType(value, arg_name)
  elif attr_type == "shape":
    proto.shape.CopyFrom(_MakeShape(value, arg_name))
  elif attr_type == "tensor":
    proto.tensor.CopyFrom(_MakeTensor(value, arg_name))
  elif attr_type == "func":
    proto.func.CopyFrom(_MakeFunc(value, arg_name))
  else:
    raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")
  return proto
869# LINT.ThenChange(//tensorflow/python/framework/op_def_library_pybind.cc)
# The following symbols are used by op_def_util.cc.
# Each call registers a Python object under a string key at import time.
# NOTE(review): the keys are presumably looked up verbatim from C++ — confirm
# against op_def_util.cc before renaming any of them.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
_pywrap_utils.RegisterPyObject("tf.TensorShape", tensor_shape.TensorShape)
_pywrap_utils.RegisterPyObject("tf.as_shape", tensor_shape.as_shape)
_pywrap_utils.RegisterPyObject("tf.TensorProto", tensor_pb2.TensorProto)
_pywrap_utils.RegisterPyObject("text_format.Parse", text_format.Parse)
_pywrap_utils.RegisterPyObject("tf.convert_to_tensor", ops.convert_to_tensor)