Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/model_signing/_serialization/file_shard.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

67 statements  

# Copyright 2024 The Sigstore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14 

15"""Model serializers that operate at file shard level of granularity.""" 

16 

17from collections.abc import Callable, Iterable 

18import concurrent.futures 

19import itertools 

20import os 

21import pathlib 

22from typing import Optional 

23 

24from typing_extensions import override 

25 

26from model_signing import manifest 

27from model_signing._hashing import io 

28from model_signing._serialization import serialization 

29 

30 

31def _endpoints(step: int, end: int) -> Iterable[int]: 

32 """Yields numbers from `step` to `end` inclusive, spaced by `step`. 

33 

34 Last value is always equal to `end`, even when `end` is not a multiple of 

35 `step`. There is always a value returned. 

36 

37 Examples: 

38 ```python 

39 >>> list(_endpoints(2, 8)) 

40 [2, 4, 6, 8] 

41 >>> list(_endpoints(2, 9)) 

42 [2, 4, 6, 8, 9] 

43 >>> list(_endpoints(2, 2)) 

44 [2] 

45 

46 Yields: 

47 Values in the range, from `step` and up to `end`. 

48 """ 

49 yield from range(step, end, step) 

50 yield end 

51 

52 

class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file shard.

    Traverses the model directory and creates digests for every file found,
    sharding the file in equal shards and computing the digests in parallel.
    """

    def __init__(
        self,
        sharded_hasher_factory: Callable[
            [pathlib.Path, int, int], io.ShardedFileHasher
        ],
        *,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            sharded_hasher_factory: A callable to build the hash engine used to
              hash every shard of the files in the model. Because each shard is
              processed in parallel, every thread needs to call the factory to
              start hashing. The arguments are the file, and the endpoints of
              the shard.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore.
        """
        self._hasher_factory = sharded_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        self._ignore_paths = ignore_paths

        # Precompute the shard size only once by using a mock file hasher.
        # None of the arguments used to build the hasher are used.
        hasher = sharded_hasher_factory(pathlib.Path(), 0, 1)
        self._shard_size = hasher.shard_size
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    def _build_description(
        self, ignore_paths: Iterable[pathlib.Path]
    ) -> manifest._ShardSerialization:
        """Builds the serialization description recorded in the manifest."""
        # A mock hasher is enough; only its name and configuration are read.
        hasher = self._hasher_factory(pathlib.Path(), 0, 1)
        return manifest._ShardSerialization(
            # Here we need the internal hasher name, not the mangled name.
            # This name is used when guessing the hashing configuration.
            hasher._content_hasher.digest_name,
            self._shard_size,
            self._allow_symlinks,
            ignore_paths,
        )

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Set whether following symlinks is allowed."""
        self._allow_symlinks = allow_symlinks
        # The flag is part of the recorded serialization description, so the
        # description must be rebuilt whenever the flag changes.
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Optional[Iterable[pathlib.Path]] = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
              provided path is a directory, all children of the directory are
              ignored.
            files_to_hash: Optional list of files to hash; ignore all others.

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
              was not initialized with `allow_symlinks=True`.
        """
        shards = []
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )
        for path in files_to_hash:
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file() and not serialization.should_ignore(
                path, ignore_paths
            ):
                shards.extend(self._get_shards(path))

        manifest_items = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path, start, end)
                for path, start, end in shards
            ]
            for future in concurrent.futures.as_completed(futures):
                manifest_items.append(future.result())

        # Recreate the serialization description for the new ignore_paths.
        if ignore_paths:
            rel_ignore_paths = []
            for p in ignore_paths:
                rp = os.path.relpath(p, model_path)
                # Paths outside the model tree escape via `..` and cannot be
                # recorded relative to the model. Compare against `os.pardir`
                # and `os.sep` (instead of a hard-coded "../") so the filter
                # also works on Windows, and also drop a bare "..".
                if rp != os.pardir and not rp.startswith(os.pardir + os.sep):
                    rel_ignore_paths.append(pathlib.Path(rp))

            self._serialization_description = self._build_description(
                frozenset(
                    itertools.chain(self._ignore_paths, rel_ignore_paths)
                )
            )

        model_name = model_path.name
        if not model_name or model_name == "..":
            # `model_path` may end in a separator or be a relative path such
            # as ".."; resolve it to recover a usable model name.
            model_name = os.path.basename(model_path.resolve())

        return manifest.Manifest(
            model_name, manifest_items, self._serialization_description
        )

    def _get_shards(
        self, path: pathlib.Path
    ) -> list[tuple[pathlib.Path, int, int]]:
        """Determines the shards of a given file path.

        Empty files produce no shards; otherwise each shard covers
        `[start, end)` with every shard except possibly the last spanning
        exactly `self._shard_size` bytes.
        """
        shards = []
        path_size = path.stat().st_size
        if path_size > 0:
            start = 0
            for end in _endpoints(self._shard_size, path_size):
                shards.append((path, start, end))
                start = end
        return shards

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path, start: int, end: int
    ) -> manifest.ShardedFileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.
            start: The start offset of the shard (included).
            end: The end offset of the shard (not included).

        Returns:
            The itemized manifest.
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path, start, end).compute()
        return manifest.ShardedFileManifestItem(
            path=relative_path, digest=digest, start=start, end=end
        )