1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""High level API for the hashing interface of `model_signing` library.
16
17Hashing is used both for signing and verification and users should ensure that
18the same configuration is used in both cases.
19
20The module could also be used to just hash a single model, without signing it:
21
22```python
23model_signing.hashing.hash(model_path)
24```
25
This module allows storing the hashing configuration in a single variable and
then sharing it between signing and verification.
28
29```python
30hashing_config = model_signing.hashing.Config().set_ignored_paths(
31 paths=["README.md"], ignore_git_paths=True
32)
33
34signing_config = (
35 model_signing.signing.Config()
36 .use_elliptic_key_signer(private_key="key")
37 .set_hashing_config(hashing_config)
38)
39
40verifying_config = (
41 model_signing.verifying.Config()
42 .use_elliptic_key_verifier(public_key="key.pub")
43 .set_hashing_config(hashing_config)
44)
45```
46
47The API defined here is stable and backwards compatible.
48"""
49
50from collections.abc import Callable, Iterable
51import os
52import pathlib
53import sys
54from typing import Literal, Optional, Union
55
56from model_signing import manifest
57from model_signing._hashing import hashing
58from model_signing._hashing import io
59from model_signing._hashing import memory
60from model_signing._serialization import file
61from model_signing._serialization import file_shard
62
63
64if sys.version_info >= (3, 11):
65 from typing import Self
66else:
67 from typing_extensions import Self
68
69
70# `TypeAlias` only exists from Python 3.10
71# `TypeAlias` is deprecated in Python 3.12 in favor of `type`
72if sys.version_info >= (3, 10):
73 from typing import TypeAlias
74else:
75 from typing_extensions import TypeAlias
76
77
78# Type alias to support `os.PathLike`, `str` and `bytes` objects in the API
79# When Python 3.12 is the minimum supported version we can use `type`
80# When Python 3.11 is the minimum supported version we can use `|`
81PathLike: TypeAlias = Union[str, bytes, os.PathLike]
82
83
def hash(model_path: PathLike) -> manifest.Manifest:
    """Hashes a model using the default configuration.

    This is a convenience wrapper equivalent to `Config().hash(model_path)`.

    Hashing is the shared part between signing and verification and is also
    expected to be the slowest component. When serializing a model, we need to
    spend time proportional to the model size on disk.

    The result is a "manifest" of the model: a collection pairing every object
    in the model with its hash. An object is currently either a file or a
    shard of a file. Large models with large files hash much faster when
    shards are hashed in parallel, at the cost of a larger signature payload.
    Future releases could add hashing of individual tensors or tensor slices
    for further speed-ups on very large models.

    Args:
        model_path: The path to the model to hash.

    Returns:
        A manifest of the hashed model.
    """
    default_config = Config()
    return default_config.hash(model_path)
106
107
class Config:
    """Configuration to use when hashing models.

    Hashing is the shared part between signing and verification and is also
    expected to be the slowest component. When serializing a model, we need to
    spend time proportional to the model size on disk.

    Hashing builds a "manifest" of the model. A manifest is a collection of
    every object in the model, paired with the corresponding hash. Currently, we
    consider an object in the model to be either a file or a shard of the file.
    Large models with large files will be hashed much faster when every shard is
    hashed in parallel, at the cost of generating a larger payload for the
    signature. In future releases we could support hashing individual tensors or
    tensor slices for further speed optimizations for very large models.

    This configuration class supports configuring the hashing granularity. By
    default, we hash at file level granularity.

    This configuration class also supports configuring the hash method used to
    generate the hash for every object in the model. We currently support SHA256
    and BLAKE2, with SHA256 being the default.

    This configuration class also supports configuring which paths from the
    model directory should be ignored. These are files that don't impact the
    behavior of the model, or files that won't be distributed with the model. By
    default, only files that are associated with a git repository (`.git`,
    `.gitattributes`, `.gitignore`, etc.) are ignored.
    """

    def __init__(self):
        """Initializes the default configuration for hashing."""
        # Paths to ignore, stored relative to the model directory and joined
        # to it at `hash()` time.
        self._ignored_paths = frozenset()
        # Git metadata files are ignored by default.
        self._ignore_git_paths = True
        self.use_file_serialization()
        self._allow_symlinks = False

    @staticmethod
    def _is_inside_model(
        model_path: pathlib.Path, full_path: pathlib.Path
    ) -> bool:
        """Checks that `full_path` does not escape the model directory.

        Both paths are resolved before the comparison so that `..` components
        (which a purely lexical `relative_to` check would not catch, e.g.
        `Path("m/../x").relative_to(Path("m"))` succeeds) and symlinks
        pointing outside the model directory cannot be used to reference
        files outside of `model_path`.

        Args:
            model_path: The path to the model directory.
            full_path: The candidate path, already joined to `model_path`.

        Returns:
            True if `full_path` stays within `model_path`.
        """
        try:
            full_path.resolve().relative_to(model_path.resolve())
        except ValueError:
            return False
        return True

    def hash(
        self,
        model_path: PathLike,
        *,
        files_to_hash: Optional[Iterable[PathLike]] = None,
    ) -> manifest.Manifest:
        """Hashes a model using the current configuration.

        Args:
            model_path: The path to the model to hash.
            files_to_hash: Optional paths to restrict the hashing to.

        Returns:
            A manifest of the hashed model.
        """
        model_path = pathlib.Path(model_path)

        # All paths in `_ignored_paths` are expected to be relative to the
        # model directory. Join them to `model_path`, dropping any entry that
        # would escape it (absolute paths, `..` traversal, outward symlinks).
        ignored_paths = []
        for relative_ignored in self._ignored_paths:
            joined = model_path / relative_ignored
            if self._is_inside_model(model_path, joined):
                ignored_paths.append(joined)

        if self._ignore_git_paths:
            git_paths = [".git/", ".gitattributes", ".github/", ".gitignore"]
            ignored_paths.extend(model_path / p for p in git_paths)

        # The symlink policy may have changed since the serializer was built.
        self._serializer.set_allow_symlinks(self._allow_symlinks)

        return self._serializer.serialize(
            model_path,
            ignore_paths=ignored_paths,
            files_to_hash=files_to_hash,
        )

    def _build_stream_hasher(
        self, hashing_algorithm: Literal["sha256", "blake2"] = "sha256"
    ) -> hashing.StreamingHashEngine:
        """Builds a streaming hasher from a constant string.

        Args:
            hashing_algorithm: The hashing algorithm to use.

        Returns:
            An instance of the requested hasher.

        Raises:
            ValueError: If the algorithm name is not supported.
        """
        # TODO: Once Python 3.9 support is deprecated revert to using `match`
        if hashing_algorithm == "sha256":
            return memory.SHA256()
        if hashing_algorithm == "blake2":
            return memory.BLAKE2()

        raise ValueError(f"Unsupported hashing method {hashing_algorithm}")

    def _build_file_hasher_factory(
        self,
        hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
        chunk_size: int = 1048576,
    ) -> Callable[[pathlib.Path], io.SimpleFileHasher]:
        """Builds the hasher factory for a serialization by file.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a file.
            chunk_size: The amount of file to read at once. Default is 1MB. A
              special value of 0 signals to attempt to read everything in a
              single call.

        Returns:
            The hasher factory that should be used by the active serialization
            method.
        """

        def _factory(path: pathlib.Path) -> io.SimpleFileHasher:
            hasher = self._build_stream_hasher(hashing_algorithm)
            return io.SimpleFileHasher(path, hasher, chunk_size=chunk_size)

        return _factory

    def _build_sharded_file_hasher_factory(
        self,
        hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
        chunk_size: int = 1048576,
        shard_size: int = 1_000_000_000,
    ) -> Callable[[pathlib.Path, int, int], io.ShardedFileHasher]:
        """Builds the hasher factory for a serialization by file shards.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a shard.
            chunk_size: The amount of file to read at once. Default is 1MB. A
              special value of 0 signals to attempt to read everything in a
              single call.
            shard_size: The size of a file shard. Default is 1 GB.

        Returns:
            The hasher factory that should be used by the active serialization
            method.
        """

        def _factory(
            path: pathlib.Path, start: int, end: int
        ) -> io.ShardedFileHasher:
            hasher = self._build_stream_hasher(hashing_algorithm)
            return io.ShardedFileHasher(
                path,
                hasher,
                start=start,
                end=end,
                chunk_size=chunk_size,
                shard_size=shard_size,
            )

        return _factory

    def use_file_serialization(
        self,
        *,
        hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
        chunk_size: int = 1048576,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ) -> Self:
        """Configures serialization to build a manifest of (file, hash) pairs.

        The serialization method in this configuration is changed to one where
        every file in the model is paired with its digest and a manifest
        containing all these pairings is being built.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a file.
            chunk_size: The amount of file to read at once. Default is 1MB. A
              special value of 0 signals to attempt to read everything in a
              single call.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library to select the best
              value for the current machine.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: Paths of files to ignore.

        Returns:
            The new hashing configuration with the new serialization method.
        """
        self._serializer = file.Serializer(
            self._build_file_hasher_factory(hashing_algorithm, chunk_size),
            max_workers=max_workers,
            allow_symlinks=allow_symlinks,
            ignore_paths=ignore_paths,
        )
        return self

    def use_shard_serialization(
        self,
        *,
        hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
        chunk_size: int = 1048576,
        shard_size: int = 1_000_000_000,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ) -> Self:
        """Configures serialization to build a manifest of (shard, hash) pairs.

        The serialization method in this configuration is changed to one where
        every file in the model is sharded in equal sized shards, every shard is
        paired with its digest and a manifest containing all these pairings is
        being built.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a shard.
            chunk_size: The amount of file to read at once. Default is 1MB. A
              special value of 0 signals to attempt to read everything in a
              single call.
            shard_size: The size of a file shard. Default is 1 GB.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library to select the best
              value for the current machine.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: Paths of files to ignore.

        Returns:
            The new hashing configuration with the new serialization method.
        """
        self._serializer = file_shard.Serializer(
            self._build_sharded_file_hasher_factory(
                hashing_algorithm, chunk_size, shard_size
            ),
            max_workers=max_workers,
            allow_symlinks=allow_symlinks,
            ignore_paths=ignore_paths,
        )
        return self

    def set_ignored_paths(
        self, *, paths: Iterable[PathLike], ignore_git_paths: bool = True
    ) -> Self:
        """Configures the paths to be ignored during serialization of a model.

        If the model is a single file, there are no paths that are ignored. If
        the model is a directory, all paths are considered as relative to the
        model directory, since we never look at files outside of it.

        If an ignored path is a directory, serialization will ignore both the
        path and any of its children.

        Args:
            paths: The paths to ignore.
            ignore_git_paths: Whether to ignore git related paths (default) or
              include them in the signature.

        Returns:
            The new hashing configuration with a new set of ignored paths.
        """
        # Preserve the user-provided relative paths; they are resolved against
        # the model directory later when hashing.
        self._ignored_paths = frozenset(pathlib.Path(p) for p in paths)
        self._ignore_git_paths = ignore_git_paths
        return self

    def add_ignored_paths(
        self, *, model_path: PathLike, paths: Iterable[PathLike]
    ) -> None:
        """Add more paths to ignore to existing set of paths.

        Paths that would escape the model directory (absolute paths or `..`
        traversal) are silently dropped.

        Args:
            model_path: The path to the model.
            paths: Additional paths to ignore. All paths must be relative to
              the model directory.
        """
        model_path = pathlib.Path(model_path)
        new_paths = set(self._ignored_paths)
        for p in paths:
            candidate = pathlib.Path(p)
            if self._is_inside_model(model_path, model_path / candidate):
                new_paths.add(candidate)
        # Keep the attribute immutable, consistent with `set_ignored_paths`.
        self._ignored_paths = frozenset(new_paths)

    def set_allow_symlinks(self, allow_symlinks: bool) -> Self:
        """Set whether following symlinks is allowed.

        Args:
            allow_symlinks: Whether symbolic links may be followed during
              serialization.

        Returns:
            The hashing configuration with the new symlink policy.
        """
        self._allow_symlinks = allow_symlinks
        return self