Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/python_state.py: 82%
17 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Utilities for including Python state in TensorFlow checkpoints."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import abc
18from tensorflow.python.framework import constant_op
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.trackable import base
22from tensorflow.python.util.tf_export import tf_export
# Dictionary key under which `PythonState._serialize_to_tensors` stores the
# string tensor holding the serialized Python state.
PYTHON_STATE = "py_state"
28@tf_export("train.experimental.PythonState")
29class PythonState(base.Trackable, metaclass=abc.ABCMeta):
30 """A mixin for putting Python state in an object-based checkpoint.
32 This is an abstract class which allows extensions to TensorFlow's object-based
33 checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy
34 arrays:
36 ```python
37 import io
38 import numpy
40 class NumpyWrapper(tf.train.experimental.PythonState):
42 def __init__(self, array):
43 self.array = array
45 def serialize(self):
46 string_file = io.BytesIO()
47 try:
48 numpy.save(string_file, self.array, allow_pickle=False)
49 serialized = string_file.getvalue()
50 finally:
51 string_file.close()
52 return serialized
54 def deserialize(self, string_value):
55 string_file = io.BytesIO(string_value)
56 try:
57 self.array = numpy.load(string_file, allow_pickle=False)
58 finally:
59 string_file.close()
60 ```
62 Instances of `NumpyWrapper` are checkpointable objects, and will be saved and
63 restored from checkpoints along with TensorFlow state like variables.
65 ```python
66 root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
67 save_path = root.save(prefix)
68 root.numpy.array *= 2.
69 assert [2.] == root.numpy.array
70 root.restore(save_path)
71 assert [1.] == root.numpy.array
72 ```
73 """
@abc.abstractmethod
def serialize(self):
  """Callback to serialize the object.

  Declared abstract, so concrete subclasses must override it.
  `_serialize_to_tensors` passes the returned value to
  `constant_op.constant(..., dtype=dtypes.string)`, so the result has to be
  convertible to a string constant.

  Returns:
    A string holding the object's serialized state. The example in the class
    docstring returns the bytes produced by `io.BytesIO.getvalue()`.
  """
@abc.abstractmethod
def deserialize(self, string_value):
  """Callback to deserialize the object.

  Declared abstract, so concrete subclasses must override it. The example in
  the class docstring rebuilds its state in place from `string_value` and
  returns nothing.

  Args:
    string_value: The serialized state to load.
      NOTE(review): presumably the value previously produced by `serialize`;
      the caller is not visible in this part of the file - confirm.
  """
def _serialize_to_tensors(self):
  """Implements Trackable._serialize_to_tensors.

  Calls `self.serialize()` and wraps the result in a constant of dtype
  `string`.

  Returns:
    A single-entry dict mapping `PYTHON_STATE` to the string constant that
    holds the serialized state.
  """
  # NOTE(review): `ops.init_scope()` presumably lifts the constant out of any
  # function-building graph that is active when saving - confirm against the
  # TensorFlow documentation.
  with ops.init_scope():
    value = constant_op.constant(self.serialize(), dtype=dtypes.string)
  return {PYTHON_STATE: value}