Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/reduce_util.py: 67%
15 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for reduce operations."""
17import enum
19from tensorflow.python.ops import variable_scope
20from tensorflow.python.util.tf_export import tf_export
@tf_export("distribute.ReduceOp")
class ReduceOp(enum.Enum):
  """Indicates how a set of values should be reduced.

  * `SUM`: Add all the values.
  * `MEAN`: Take the arithmetic mean ("average") of the values.
  """
  # TODO(priyag): Add the following types:
  # `MIN`: Return the minimum of all values.
  # `MAX`: Return the maximum of all values.
  SUM = "SUM"
  MEAN = "MEAN"

  @staticmethod
  def from_variable_aggregation(aggregation):
    """Converts a `tf.VariableAggregation` to the corresponding `ReduceOp`.

    Only `VariableAggregation.SUM` and `VariableAggregation.MEAN` have a
    `ReduceOp` counterpart; any other value is rejected.

    Args:
      aggregation: The `tf.VariableAggregation` value to convert.

    Returns:
      `ReduceOp.SUM` or `ReduceOp.MEAN`.

    Raises:
      ValueError: If `aggregation` has no corresponding `ReduceOp`.
    """
    mapping = {
        variable_scope.VariableAggregation.SUM: ReduceOp.SUM,
        variable_scope.VariableAggregation.MEAN: ReduceOp.MEAN,
    }

    reduce_op = mapping.get(aggregation)
    # `dict.get` signals a missing key with None, so test for that explicitly.
    if reduce_op is None:
      # Wrap `aggregation` in a 1-tuple so %-formatting cannot misinterpret a
      # tuple argument (which would raise TypeError instead of ValueError).
      raise ValueError("Could not convert from `tf.VariableAggregation` %s to "
                       "`tf.distribute.ReduceOp` type" % (aggregation,))
    return reduce_op