Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/compiler_ir.py: 28%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation for defining get_compiler_ir.""" 

16from typing import List, Optional 

17 

18from tensorflow.core.function import trace_type 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import tensor_spec 

22from tensorflow.python.ops import random_ops 

23 

24from tensorflow.python.util import nest 

25 

26 

def maybe_get_device_name(device_name):
  """Resolve the device to compile for.

  Args:
    device_name: an explicit device name, or None.

  Returns:
    `device_name` unchanged when it is not None; otherwise the device that a
    freshly created scalar op is placed on.
  """
  if device_name is not None:
    return device_name
  # TODO(cheshire): This is a hack to get the current "preferred" device,
  # there is no current API to get it otherwise.
  probe = random_ops.random_normal([])
  return probe.device

33 

34 

35# TODO(fmuham): Use trace_type._flatten here instead when available 

def make_handledata_tensor_specs(resource_vars):
  """Convert resource tensors to TensorSpecs built from their handle data.

  Args:
    resource_vars: an iterable of tensors whose dtype is `tf.resource`
      (presumably variable handles — confirm against callers). Any iterable
      is accepted; it is materialized exactly once.

  Returns:
    A list with one `tf.TensorSpec` per input, built from the shape and dtype
    of the first `shape_and_type` entry of that input's traced handle data.

  Raises:
    RuntimeError: if any element's dtype is not `tf.resource`.
    ValueError: if a TensorSpec cannot be built from an element's handle data.
  """
  # Materialize once: both the dtype check and the tracing call iterate over
  # the inputs, and a one-shot iterator (e.g. a generator) would otherwise be
  # exhausted by the check, silently producing an empty result.
  resource_vars = tuple(resource_vars)
  if not all(x.dtype is dtypes.resource for x in resource_vars):
    raise RuntimeError("Resource_vars must be tf.resource list.")
  inner_context = trace_type.InternalTracingContext()
  trace_type_inputs = trace_type.from_value(
      resource_vars, inner_context
  ).components

  def to_resource_spec(traced_input):
    """Build a TensorSpec from the handle data attached to `traced_input`."""
    try:
      handle_data = traced_input.dtype._handle_data  # pylint: disable=protected-access
      # Only the first shape/type entry of the handle data is used.
      shape_and_type = handle_data.shape_and_type[0]
      return tensor_spec.TensorSpec(
          shape=shape_and_type.shape, dtype=shape_and_type.dtype
      )
    except Exception as e:  # pylint: disable=broad-except
      # Any failure (missing handle data, empty shape_and_type, bad shape) is
      # surfaced uniformly as ValueError, chained to the original cause.
      raise ValueError(
          "Fail to convert tf.Variable list to TensorSpec list. The error"
          " is: %s" % e
      ) from e

  # Named `traced` rather than `trace_type` to avoid shadowing the module.
  return [to_resource_spec(traced) for traced in trace_type_inputs]

60 

61 

def from_concrete_function(
    concrete_fn,
    specialized_flat_specs: Optional[List[tensor_spec.TensorSpec]] = None,
):
  """Generate the compiler IR from a tf concrete function with TensorSpecs.

  Args:
    concrete_fn: returned by using get_concrete_function.
    specialized_flat_specs: specialized flat tf.TensorSpecs for function args.
      When None (or empty), the flattened
      `concrete_fn.structured_input_signature` is used instead.

  Returns:
    A callable `compiler_ir_generator(stage="hlo", device_name=None)` that
    returns the compiler IR of `concrete_fn` for the requested stage: raw
    bytes for the serialized stages, UTF-8-decoded text otherwise.

  Raises:
    ValueError: if concrete_fn is not "compilable" without concrete
      inputs, i.e. some input spec lacks a fully-defined shape.
  """
  context.ensure_initialized()
  fn_name = concrete_fn.name
  # `or` rather than `is None`: an empty list also falls back to the signature.
  filtered_flat_specs = specialized_flat_specs or list(
      nest.flatten(concrete_fn.structured_input_signature)
  )

  # Every spec must have a fully-defined shape. `next` over a generator stops
  # at the first offender, evaluating specs in the same order as `all()`
  # would, and lets the error name the spec that was actually rejected.
  first_non_static = next(
      (s for s in filtered_flat_specs if not s.shape.is_fully_defined()), None
  )
  if first_non_static is not None:
    raise ValueError(
        f"Only support static input shape but got inputs = {concrete_fn.inputs}"
        f" (first input spec without a fully-defined shape: {first_non_static})"
    )

  def compiler_ir_generator(stage="hlo", device_name=None):
    """Return the compiler IR of `concrete_fn` for `stage`.

    Args:
      stage: IR stage name, forwarded to `get_compiler_ir`.
      device_name: device to compile for; when None, `maybe_get_device_name`
        resolves the current preferred device.

    Returns:
      The raw result bytes for the serialized stages listed below; otherwise
      the result decoded as UTF-8 text.
    """
    device_name = maybe_get_device_name(device_name)
    res_bytes = context.context().get_compiler_ir(
        device_name=device_name,
        function_name=fn_name,
        flat_args=filtered_flat_specs,
        # Read at call time so the current captures are used.
        captured_inputs=concrete_fn.captured_inputs,
        stage=stage,
    )
    # Serialized stages are returned undecoded (presumably binary protos
    # rather than text).
    if stage in (
        "hlo_serialized",
        "optimized_hlo_serialized",
        "optimized_hlo_proto_serialized",
    ):
      return res_bytes
    return res_bytes.decode("utf-8")

  return compiler_ir_generator