Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/compiler_ir.py: 28%
39 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Implmentation for defining get_compiler_ir."""
from typing import List, Optional

from tensorflow.core.function import trace_type
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import random_ops
from tensorflow.python.util import nest


def maybe_get_device_name(device_name):
  # TODO(cheshire): This is a hack to get the current "preferred" device;
  # there is no API to get it otherwise.
  if device_name is None:
    device_name = random_ops.random_normal([]).device
  return device_name
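

def _example_maybe_get_device_name():
  """Hypothetical usage sketch for maybe_get_device_name (assumes eager mode)."""
  # With no argument, the helper falls back to the device a trivial op is
  # placed on, e.g. "/job:localhost/replica:0/task:0/device:CPU:0".
  inferred = maybe_get_device_name(None)
  # An explicit device name is passed through unchanged.
  explicit = maybe_get_device_name("/device:CPU:0")
  return inferred, explicit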


# TODO(fmuham): Use trace_type._flatten here instead when available.
def make_handledata_tensor_specs(resource_vars):
  """Converts a tf.Variable list to its corresponding TensorSpec list."""
  if not all(x.dtype is dtypes.resource for x in resource_vars):
    raise RuntimeError("resource_vars must be a list of tf.resource tensors.")
  inner_context = trace_type.InternalTracingContext()
  trace_type_inputs = trace_type.from_value(
      tuple(resource_vars), inner_context
  ).components

  def to_resource_spec(traced_input):
    try:
      handle_data = traced_input.dtype._handle_data  # pylint: disable=protected-access
      shape_and_type = handle_data.shape_and_type[0]
      spec = tensor_spec.TensorSpec(
          shape=shape_and_type.shape, dtype=shape_and_type.dtype
      )
      return spec
    except Exception as e:
      raise ValueError(
          "Failed to convert the tf.Variable list to a TensorSpec list. The"
          " error is: %s" % e
      ) from e

  return [to_resource_spec(traced_input) for traced_input in trace_type_inputs]
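

def _example_make_handledata_tensor_specs():
  """Hypothetical usage sketch for make_handledata_tensor_specs.

  Assumes eager execution and that the DT_RESOURCE handles of eager
  tf.Variables carry the shape/dtype handle data read above; the
  `_example_` helper and the toy variable are illustrative only.
  """
  from tensorflow.python.ops import resource_variable_ops  # local import for the sketch

  v = resource_variable_ops.ResourceVariable([[0.0, 0.0], [0.0, 0.0]])
  # Pass the resource handle (not the variable itself) so the dtype check is
  # satisfied; the expected result is roughly
  # [TensorSpec(shape=(2, 2), dtype=tf.float32)].
  return make_handledata_tensor_specs([v.handle])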


def from_concrete_function(
    concrete_fn,
    specialized_flat_specs: Optional[List[tensor_spec.TensorSpec]] = None,
):
  """Generates the compiler IR from a tf concrete function and its TensorSpecs.

  Args:
    concrete_fn: the concrete function returned by get_concrete_function.
    specialized_flat_specs: specialized flat tf.TensorSpecs for the function
      arguments.

  Returns:
    A callable that generates the HLO text.

  Raises:
    ValueError: if concrete_fn is not "compilable" without concrete inputs.
  """
  context.ensure_initialized()
  fn_name = concrete_fn.name
  filtered_flat_specs = specialized_flat_specs or list(
      nest.flatten(concrete_fn.structured_input_signature)
  )

  if not all(s.shape.is_fully_defined() for s in filtered_flat_specs):
    raise ValueError(
        "Only static input shapes are supported, but got inputs ="
        f" {concrete_fn.inputs}"
    )

  def compiler_ir_generator(stage="hlo", device_name=None):
    device_name = maybe_get_device_name(device_name)
    res_bytes = context.context().get_compiler_ir(
        device_name=device_name,
        function_name=fn_name,
        flat_args=filtered_flat_specs,
        captured_inputs=concrete_fn.captured_inputs,
        stage=stage,
    )
    if stage in (
        "hlo_serialized",
        "optimized_hlo_serialized",
        "optimized_hlo_proto_serialized",
    ):
      return res_bytes
    else:
      return res_bytes.decode("utf-8")

  return compiler_ir_generator
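

def _example_from_concrete_function():
  """Hypothetical usage sketch, mirroring experimental_get_compiler_ir.

  Assumes eager execution and a tf.function whose inputs have fully static
  shapes; the `_example_` helper and the toy function are illustrative only.
  """
  from tensorflow.python.eager.polymorphic_function import polymorphic_function  # local import for the sketch

  @polymorphic_function.function(jit_compile=True)
  def f(x):
    return x * x + 1.0

  concrete_fn = f.get_concrete_function(
      tensor_spec.TensorSpec([2, 2], dtypes.float32)
  )
  generator = from_concrete_function(concrete_fn)
  hlo_text = generator(stage="hlo")              # textual stages decode to str
  hlo_bytes = generator(stage="hlo_serialized")  # *_serialized stages stay bytes
  return hlo_text, hlo_bytes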