Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/python_state.py: 82%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Utilities for including Python state in TensorFlow checkpoints.""" 

2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import abc 

17 

18from tensorflow.python.framework import constant_op 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.trackable import base 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

# Key under which `PythonState._serialize_to_tensors` files the serialized
# Python state in the dict of tensors it returns.
PYTHON_STATE = "py_state"

26 

27 

28@tf_export("train.experimental.PythonState") 

29class PythonState(base.Trackable, metaclass=abc.ABCMeta): 

30 """A mixin for putting Python state in an object-based checkpoint. 

31 

32 This is an abstract class which allows extensions to TensorFlow's object-based 

33 checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy 

34 arrays: 

35 

36 ```python 

37 import io 

38 import numpy 

39 

40 class NumpyWrapper(tf.train.experimental.PythonState): 

41 

42 def __init__(self, array): 

43 self.array = array 

44 

45 def serialize(self): 

46 string_file = io.BytesIO() 

47 try: 

48 numpy.save(string_file, self.array, allow_pickle=False) 

49 serialized = string_file.getvalue() 

50 finally: 

51 string_file.close() 

52 return serialized 

53 

54 def deserialize(self, string_value): 

55 string_file = io.BytesIO(string_value) 

56 try: 

57 self.array = numpy.load(string_file, allow_pickle=False) 

58 finally: 

59 string_file.close() 

60 ``` 

61 

62 Instances of `NumpyWrapper` are checkpointable objects, and will be saved and 

63 restored from checkpoints along with TensorFlow state like variables. 

64 

65 ```python 

66 root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.]))) 

67 save_path = root.save(prefix) 

68 root.numpy.array *= 2. 

69 assert [2.] == root.numpy.array 

70 root.restore(save_path) 

71 assert [1.] == root.numpy.array 

72 ``` 

73 """ 

74 

@abc.abstractmethod
def serialize(self):
  """Produces the checkpointed representation of this object's Python state.

  Subclasses must override this method. When the object is saved, the
  value returned here is passed to
  `constant_op.constant(..., dtype=dtypes.string)`. It should therefore be a
  string value (`str` or `bytes`). The NumPy example in the class docstring
  returns `bytes`.

  Returns:
    The serialized state, as a string.
  """

78 

@abc.abstractmethod
def deserialize(self, string_value):
  """Rebuilds this object's Python state from a serialized value.

  Subclasses must override this method. It is the counterpart of
  `serialize`.

  Args:
    string_value: The serialized state restored from a checkpoint. The NumPy
      example in the class docstring treats it as `bytes` (it wraps the
      value in `io.BytesIO`). Presumably it is the value `serialize`
      produced, read back as bytes; confirm against the restore path.
  """

82 

def _serialize_to_tensors(self):
  """Implements Trackable._serialize_to_tensors.

  Invokes the user-supplied `serialize()` callback and wraps its result in a
  `dtypes.string` constant tensor, so the Python state can be written into
  the checkpoint alongside other TensorFlow state.

  Returns:
    A single-entry dict mapping `PYTHON_STATE` to the string constant.
  """
  # Per `tf.init_scope` documentation, this lifts op creation out of any
  # function-building graph or control-flow scope. Both the callback and the
  # constant are evaluated inside it, exactly as before.
  with ops.init_scope():
    payload = self.serialize()
    state_tensor = constant_op.constant(payload, dtype=dtypes.string)
  return {PYTHON_STATE: state_tensor}