Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/lock_util.py: 45%

42 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Locking related utils.""" 

16 

17import threading 

18 

19 

class GroupLock(object):
  """A lock to allow many members of a group to access a resource exclusively.

  This lock provides a way to allow access to a resource by multiple threads
  belonging to a logical group at the same time, while restricting access to
  threads from all other groups. You can think of this as an extension of a
  reader-writer lock, where you allow multiple writers at the same time. We
  made it generic to support multiple groups instead of just two - readers and
  writers.

  Simple usage example with two groups accessing the same resource:

  ```python
  lock = GroupLock(num_groups=2)

  # In a member of group 0:
  with lock.group(0):
    # do stuff, access the resource
    # ...

  # In a member of group 1:
  with lock.group(1):
    # do stuff, access the resource
    # ...
  ```

  Using as a context manager with `.group(group_id)` is the easiest way. You
  can also use the `acquire` and `release` methods directly.
  """

  __slots__ = ["_ready", "_num_groups", "_group_member_counts"]

  def __init__(self, num_groups=2):
    """Initialize a group lock.

    Args:
      num_groups: The number of groups that will be accessing the resource under
        consideration. Should be a positive number.

    Raises:
      ValueError: If num_groups is less than 1.
    """
    if num_groups < 1:
      raise ValueError(
          "Argument `num_groups` must be a positive integer. "
          f"Received: num_groups={num_groups}")
    # Guards `_group_member_counts`; waiters block on it until no other group
    # holds the lock.
    self._ready = threading.Condition(threading.Lock())
    self._num_groups = num_groups
    # Number of members currently holding the lock, indexed by group id.
    self._group_member_counts = [0] * self._num_groups

  def group(self, group_id):
    """Enter a context where the lock is with group `group_id`.

    Args:
      group_id: The group for which to acquire and release the lock.

    Returns:
      A context manager which will acquire the lock for `group_id` on entry
      and release it on exit.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)
    return self._Context(self, group_id)

  def acquire(self, group_id):
    """Acquire the group lock for a specific group `group_id`.

    Blocks until no member of any other group holds the lock, then registers
    the caller as a member of `group_id`.

    Args:
      group_id: The group for which to acquire the lock.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)

    # Using the condition as a context manager guarantees its underlying lock
    # is released even if the wait is interrupted by an exception; otherwise
    # every other thread would deadlock on `_ready`.
    with self._ready:
      self._ready.wait_for(lambda: not self._another_group_active(group_id))
      self._group_member_counts[group_id] += 1

  def release(self, group_id):
    """Release the group lock for a specific group `group_id`.

    When the last member of `group_id` releases, all waiting threads are woken
    so that members of other groups may proceed.

    Args:
      group_id: The group for which to release the lock.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
      RuntimeError: If no member of `group_id` currently holds the lock.
    """
    self._validate_group_id(group_id)

    with self._ready:
      # Releasing more times than acquired would drive the count negative,
      # after which a later acquire could bring it back to 0 while a member is
      # inside, letting other groups in concurrently. Mirror
      # `threading.Lock.release` and refuse instead.
      if self._group_member_counts[group_id] <= 0:
        raise RuntimeError(
            f"Cannot release group {group_id} of `GroupLock`: it is not "
            "currently held by any member of that group.")
      self._group_member_counts[group_id] -= 1
      if self._group_member_counts[group_id] == 0:
        self._ready.notify_all()

  def _another_group_active(self, group_id):
    """Return True if any group other than `group_id` has active members.

    Must be called while holding `self._ready`.
    """
    return any(
        c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id)

  def _validate_group_id(self, group_id):
    """Raise ValueError unless `0 <= group_id < num_groups`."""
    if group_id < 0 or group_id >= self._num_groups:
      raise ValueError(
          "Argument `group_id` should verify `0 <= group_id < num_groups` "
          f"(with `num_groups={self._num_groups}`). "
          f"Received: group_id={group_id}")

  class _Context(object):
    """Context manager helper for `GroupLock`."""

    __slots__ = ["_lock", "_group_id"]

    def __init__(self, lock, group_id):
      self._lock = lock
      self._group_id = group_id

    def __enter__(self):
      self._lock.acquire(self._group_id)

    def __exit__(self, type_arg, value_arg, traceback_arg):
      del type_arg, value_arg, traceback_arg
      self._lock.release(self._group_id)