Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/model_signing/_serialization/file_shard.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

68 statements  

1# Copyright 2024 The Sigstore Authors 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Model serializers that operate at file shard level of granularity.""" 

16 

17from collections.abc import Callable, Iterable 

18import concurrent.futures 

19import itertools 

20import os 

21import pathlib 

22 

23from typing_extensions import override 

24 

25from model_signing import manifest 

26from model_signing._hashing import io 

27from model_signing._serialization import serialization 

28 

29 

30def _endpoints(step: int, end: int) -> Iterable[int]: 

31 """Yields numbers from `step` to `end` inclusive, spaced by `step`. 

32 

33 Last value is always equal to `end`, even when `end` is not a multiple of 

34 `step`. There is always a value returned. 

35 

36 Examples: 

37 ```python 

38 >>> list(_endpoints(2, 8)) 

39 [2, 4, 6, 8] 

40 >>> list(_endpoints(2, 9)) 

41 [2, 4, 6, 8, 9] 

42 >>> list(_endpoints(2, 2)) 

43 [2] 

44 

45 Yields: 

46 Values in the range, from `step` and up to `end`. 

47 """ 

48 yield from range(step, end, step) 

49 yield end 

50 

51 

class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file shard.

    Traverses the model directory and creates digests for every file found,
    splitting each file into shards of `shard_size` bytes (the last shard of a
    file may be shorter) and computing the shard digests in parallel.
    """

    def __init__(
        self,
        sharded_hasher_factory: Callable[
            [pathlib.Path, int, int], io.ShardedFileHasher
        ],
        *,
        max_workers: int | None = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            sharded_hasher_factory: A callable to build the hash engine used to
              hash every shard of the files in the model. Because each shard is
              processed in parallel, every thread needs to call the factory to
              start hashing. The arguments are the file, and the endpoints of
              the shard. It is also called once here, with placeholder
              arguments, to read the shard size and the digest name.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore. Within this class they
              are only recorded in the serialization description; the files
              actually skipped by `serialize` are chosen by that method's own
              `ignore_paths` argument.
        """
        self._hasher_factory = sharded_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        # Materialized once: these paths are re-read every time the
        # description is rebuilt, which a one-shot iterator (e.g. a generator)
        # would silently not support. A tuple keeps the caller's order.
        self._ignore_paths = tuple(ignore_paths)

        # Precompute some private values only once by using a mock file hasher.
        # None of the arguments used to build the hasher are used.
        hasher = sharded_hasher_factory(pathlib.Path(), 0, 1)
        self._shard_size = hasher.shard_size
        # Here we need the internal hasher name, not the mangled name.
        # This name is used when guessing the hashing configuration.
        self._digest_name = hasher._content_hasher.digest_name
        self._serialization_description = self._build_description()

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Set whether following symlinks is allowed.

        Also rebuilds the serialization description so that it records the new
        setting, together with the ignore paths given at construction.
        """
        self._allow_symlinks = allow_symlinks
        self._serialization_description = self._build_description()

    def _build_description(
        self, ignore_paths: Iterable[pathlib.Path] | None = None
    ) -> manifest._ShardSerialization:
        """Returns the serialization description for the current settings.

        Args:
            ignore_paths: The ignore paths to record. Defaults to the paths
              given at construction.

        Returns:
            A description built from the cached digest name and shard size,
            the current symlink setting and the chosen ignore paths.
        """
        if ignore_paths is None:
            ignore_paths = self._ignore_paths
        return manifest._ShardSerialization(
            self._digest_name,
            self._shard_size,
            self._allow_symlinks,
            ignore_paths,
        )

    @staticmethod
    def _relative_ignore_paths(
        model_path: pathlib.Path, ignore_paths: Iterable[pathlib.Path]
    ) -> list[pathlib.Path]:
        """Returns the ignore paths inside `model_path`, relative to it.

        Paths outside the model are dropped: they cannot name anything in the
        manifest, so recording them would only leak unrelated locations.
        """
        relative = []
        for p in ignore_paths:
            rp = pathlib.Path(os.path.relpath(p, model_path))
            # A path outside the model comes back with a leading ".."
            # component: "..", "../x", or "..\\x" on Windows. Comparing the
            # first component (rather than a "../" string prefix) catches all
            # of those while still keeping names such as "..foo".
            if rp.parts[:1] != (os.pardir,):
                relative.append(rp)
        return relative

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Iterable[pathlib.Path] | None = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
              provided path is a directory, all children of the directory are
              ignored. Those located inside `model_path` are also recorded,
              relative to it, in the manifest's serialization description.
            files_to_hash: Optional list of files to hash; ignore all others.
              When omitted, `model_path` and everything below it is visited.

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
              was not initialized with `allow_symlinks=True`; or a file in
              `files_to_hash` is not located under `model_path`.
        """
        # Materialized because the paths are consulted once per candidate file
        # and again when building the description; a one-shot iterator would
        # be exhausted after its first use.
        ignore_paths = tuple(ignore_paths)

        shards = []
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )
        for path in files_to_hash:
            if serialization.should_ignore(path, ignore_paths):
                continue
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file():
                shards.extend(self._get_shards(path))

        manifest_items = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path, start, end)
                for path, start, end in shards
            ]
            # `result()` re-raises any exception from a worker thread here.
            for future in concurrent.futures.as_completed(futures):
                manifest_items.append(future.result())

        # Rebuild the description on every call, so that ignore paths passed
        # to an earlier `serialize` call never leak into this manifest.
        if ignore_paths:
            rel_ignore_paths = self._relative_ignore_paths(
                model_path, ignore_paths
            )
            self._serialization_description = self._build_description(
                frozenset([*self._ignore_paths, *rel_ignore_paths])
            )
        else:
            self._serialization_description = self._build_description()

        # Paths such as "." or ".." have no usable `name`; fall back to the
        # final component of the resolved absolute path.
        model_name = model_path.name
        if not model_name or model_name == "..":
            model_name = os.path.basename(model_path.resolve())

        return manifest.Manifest(
            model_name, manifest_items, self._serialization_description
        )

    def _get_shards(
        self, path: pathlib.Path
    ) -> list[tuple[pathlib.Path, int, int]]:
        """Determines the shards of a given file path.

        Returns:
            `(path, start, end)` triples that cover the file contiguously, each
            at most `shard_size` bytes long; empty for a zero-length file.
        """
        shards = []
        path_size = path.stat().st_size
        if path_size > 0:
            start = 0
            for end in _endpoints(self._shard_size, path_size):
                shards.append((path, start, end))
                start = end
        return shards

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path, start: int, end: int
    ) -> manifest.ShardedFileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.
            start: The start offset of the shard (included).
            end: The end offset of the shard (not included).

        Returns:
            The itemized manifest.

        Raises:
            ValueError: `path` is not located under `model_path`.
        """
        relative_path = path.relative_to(model_path)
        # A fresh hasher per shard, since shards are hashed concurrently.
        digest = self._hasher_factory(path, start, end).compute()
        return manifest.ShardedFileManifestItem(
            path=relative_path, digest=digest, start=start, end=end
        )