1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""High level API for the hashing interface of `model_signing` library.
16
17Hashing is used both for signing and verification and users should ensure that
18the same configuration is used in both cases.
19
20The module could also be used to just hash a single model, without signing it:
21
22```python
23model_signing.hashing.hash(model_path)
24```
25
This module allows setting up the hashing configuration in a single variable
and then sharing it between signing and verification.
28
29```python
30hashing_config = model_signing.hashing.Config().set_ignored_paths(
31 paths=["README.md"], ignore_git_paths=True
32)
33
34signing_config = (
35 model_signing.signing.Config()
36 .use_elliptic_key_signer(private_key="key")
37 .set_hashing_config(hashing_config)
38)
39
40verifying_config = (
41 model_signing.verifying.Config()
42 .use_elliptic_key_verifier(public_key="key.pub")
43 .set_hashing_config(hashing_config)
44)
45```
46
47The API defined here is stable and backwards compatible.
48"""
49
50from collections.abc import Callable, Iterable
51import os
52import pathlib
53import sys
54from typing import Literal, Optional, Union
55
56import blake3
57
58from model_signing import manifest
59from model_signing._hashing import hashing
60from model_signing._hashing import io
61from model_signing._hashing import memory
62from model_signing._serialization import file
63from model_signing._serialization import file_shard
64
65
66if sys.version_info >= (3, 11):
67 from typing import Self
68else:
69 from typing_extensions import Self
70
71
72# `TypeAlias` only exists from Python 3.10
73# `TypeAlias` is deprecated in Python 3.12 in favor of `type`
74if sys.version_info >= (3, 10):
75 from typing import TypeAlias
76else:
77 from typing_extensions import TypeAlias
78
79
80# Type alias to support `os.PathLike`, `str` and `bytes` objects in the API
81# When Python 3.12 is the minimum supported version we can use `type`
82# When Python 3.11 is the minimum supported version we can use `|`
83PathLike: TypeAlias = Union[str, bytes, os.PathLike]
84
85
def hash(model_path: PathLike) -> manifest.Manifest:
    """Hashes a model at `model_path` with the default configuration.

    Hashing is the component shared by signing and verification, and it is
    also expected to dominate the runtime: serializing a model requires work
    proportional to the model's size on disk.

    The result is a "manifest" of the model: a collection pairing every
    object in the model with its digest. An object is currently either a
    whole file or a shard of a file. Sharding lets large files be hashed in
    parallel at the cost of a larger signature payload; future releases may
    add finer granularities (individual tensors or tensor slices) for very
    large models.

    Args:
        model_path: The path to the model to hash.

    Returns:
        A manifest of the hashed model.
    """
    default_config = Config()
    return default_config.hash(model_path)
108
109
class Config:
    """Configuration to use when hashing models.

    Hashing is the shared part between signing and verification and is also
    expected to be the slowest component. When serializing a model, we need to
    spend time proportional to the model size on disk.

    Hashing builds a "manifest" of the model. A manifest is a collection of
    every object in the model, paired with the corresponding hash. Currently, we
    consider an object in the model to be either a file or a shard of the file.
    Large models with large files will be hashed much faster when every shard is
    hashed in parallel, at the cost of generating a larger payload for the
    signature. In future releases we could support hashing individual tensors or
    tensor slices for further speed optimizations for very large models.

    This configuration class supports configuring the hashing granularity. By
    default, we hash at file level granularity.

    This configuration class also supports configuring the hash method used to
    generate the hash for every object in the model. We currently support
    SHA256, BLAKE2, and BLAKE3, with SHA256 being the default.

    This configuration class also supports configuring which paths from the
    model directory should be ignored. These are files that don't impact the
    behavior of the model, or files that won't be distributed with the model. By
    default, only files that are associated with a git repository (`.git`,
    `.gitattributes`, `.gitignore`, etc.) are ignored.
    """

    def __init__(self):
        """Initializes the default configuration for hashing."""
        # Paths (relative to the model directory) to skip during hashing.
        self._ignored_paths = frozenset()
        # Whether git metadata paths are skipped in addition to the above.
        self._ignore_git_paths = True
        self.use_file_serialization()
        self._allow_symlinks = False

    def hash(
        self,
        model_path: PathLike,
        *,
        files_to_hash: Optional[Iterable[PathLike]] = None,
    ) -> manifest.Manifest:
        """Hashes a model using the current configuration.

        Args:
            model_path: The path to the model to hash.
            files_to_hash: Optional paths to restrict hashing to. If not
                provided, the serializer decides what to hash (everything
                under `model_path` minus the ignored paths).

        Returns:
            A manifest of the hashed model.
        """
        # All paths in ``_ignored_paths`` are expected to be relative to the
        # model directory. Join them to ``model_path`` and ensure they do not
        # escape it.
        model_path = pathlib.Path(model_path)
        ignored_paths = []
        for p in self._ignored_paths:
            full = model_path / p
            try:
                full.relative_to(model_path)
            except ValueError:
                # Silently drop paths that lexically escape the model
                # directory (e.g., absolute paths).
                # NOTE(review): ``relative_to`` is purely lexical, so paths
                # containing ".." components are not rejected here — confirm
                # the serializer tolerates them.
                continue
            ignored_paths.append(full)

        if self._ignore_git_paths:
            ignored_paths.extend(
                [
                    model_path / p
                    for p in [
                        ".git/",
                        ".gitattributes",
                        ".github/",
                        ".gitignore",
                    ]
                ]
            )

        # Propagate the symlink policy set via `set_allow_symlinks` to the
        # serializer right before serialization.
        self._serializer.set_allow_symlinks(self._allow_symlinks)

        return self._serializer.serialize(
            model_path,
            ignore_paths=ignored_paths,
            files_to_hash=files_to_hash,
        )

    def _build_stream_hasher(
        self,
        hashing_algorithm: Literal["sha256", "blake2", "blake3"] = "sha256",
    ) -> hashing.StreamingHashEngine:
        """Builds a streaming hasher from a constant string.

        Args:
            hashing_algorithm: The hashing algorithm to use.

        Returns:
            An instance of the requested hasher.

        Raises:
            ValueError: If the algorithm name is not one of the supported
                constants.
        """
        # TODO: Once Python 3.9 support is deprecated revert to using `match`
        if hashing_algorithm == "sha256":
            return memory.SHA256()
        if hashing_algorithm == "blake2":
            return memory.BLAKE2()
        if hashing_algorithm == "blake3":
            return memory.BLAKE3()

        raise ValueError(f"Unsupported hashing method {hashing_algorithm}")

    def _build_file_hasher_factory(
        self,
        hashing_algorithm: Literal["sha256", "blake2", "blake3"] = "sha256",
        chunk_size: int = 1048576,
        max_workers: Optional[int] = None,
    ) -> Callable[[pathlib.Path], io.FileHasher]:
        """Builds the hasher factory for a serialization by file.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a file.
            chunk_size: The amount of file to read at once. Default is 1MB. A
                special value of 0 signals to attempt to read everything in a
                single call. This is ignored for BLAKE3.
            max_workers: Maximum number of workers to use in parallel. Defaults
                to the number of logical cores. Only relevant for BLAKE3.

        Returns:
            The hasher factory that should be used by the active serialization
            method.
        """
        if max_workers is None:
            # BLAKE3 picks the thread count itself when given AUTO.
            max_workers = blake3.blake3.AUTO

        def _factory(path: pathlib.Path) -> io.FileHasher:
            if hashing_algorithm == "blake3":
                # BLAKE3 hashes the file in parallel internally, so no
                # chunked streaming wrapper is needed.
                return io.Blake3FileHasher(path, max_threads=max_workers)
            hasher = self._build_stream_hasher(hashing_algorithm)
            return io.SimpleFileHasher(path, hasher, chunk_size=chunk_size)

        return _factory

    def _build_sharded_file_hasher_factory(
        self,
        hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
        chunk_size: int = 1048576,
        shard_size: int = 1_000_000_000,
    ) -> Callable[[pathlib.Path, int, int], io.ShardedFileHasher]:
        """Builds the hasher factory for a serialization by file shards.

        This is not recommended for BLAKE3 because it is not necessary. BLAKE3
        already operates in parallel.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a shard.
            chunk_size: The amount of file to read at once. Default is 1MB. A
                special value of 0 signals to attempt to read everything in a
                single call.
            shard_size: The size of a file shard. Default is 1 GB.

        Returns:
            The hasher factory that should be used by the active serialization
            method.
        """

        def _factory(
            path: pathlib.Path, start: int, end: int
        ) -> io.ShardedFileHasher:
            hasher = self._build_stream_hasher(hashing_algorithm)
            return io.ShardedFileHasher(
                path,
                hasher,
                start=start,
                end=end,
                chunk_size=chunk_size,
                shard_size=shard_size,
            )

        return _factory

    def use_file_serialization(
        self,
        *,
        hashing_algorithm: Literal["sha256", "blake2", "blake3"] = "sha256",
        chunk_size: int = 1048576,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ) -> Self:
        """Configures serialization to build a manifest of (file, hash) pairs.

        The serialization method in this configuration is changed to one where
        every file in the model is paired with its digest and a manifest
        containing all these pairings is being built.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a file.
            chunk_size: The amount of file to read at once. Default is 1MB. A
                special value of 0 signals to attempt to read everything in a
                single call. Ignored for BLAKE3.
            max_workers: Maximum number of workers to use in parallel. Default
                is to defer to the `concurrent.futures` library to select the
                best value for the current machine, or the number of logical
                cores when doing BLAKE3 hashing. When reading files off of
                slower hardware like an HDD rather than an SSD, and using
                BLAKE3, setting max_workers to 1 may improve performance.
            allow_symlinks: Controls whether symbolic links are included. If a
                symlink is present but the flag is `False` (default) the
                serialization would raise an error.
            ignore_paths: Paths of files to ignore.

        Returns:
            The new hashing configuration with the new serialization method.
        """
        self._serializer = file.Serializer(
            self._build_file_hasher_factory(
                hashing_algorithm, chunk_size, max_workers
            ),
            max_workers=max_workers,
            allow_symlinks=allow_symlinks,
            ignore_paths=ignore_paths,
        )
        return self

    def use_shard_serialization(
        self,
        *,
        hashing_algorithm: Literal["sha256", "blake2", "blake3"] = "sha256",
        chunk_size: int = 1048576,
        shard_size: int = 1_000_000_000,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ) -> Self:
        """Configures serialization to build a manifest of (shard, hash) pairs.

        For BLAKE3 this is equivalent to not sharding. Sharding is bypassed
        because BLAKE3 already operates in parallel. This means the chunk_size
        and shard_size args are ignored.

        The serialization method in this configuration is changed to one where
        every file in the model is sharded in equal sized shards, every shard is
        paired with its digest and a manifest containing all these pairings is
        being built.

        Args:
            hashing_algorithm: The hashing algorithm to use to hash a shard.
            chunk_size: The amount of file to read at once. Default is 1MB. A
                special value of 0 signals to attempt to read everything in a
                single call.
            shard_size: The size of a file shard. Default is 1 GB.
            max_workers: Maximum number of workers to use in parallel. Default
                is to defer to the `concurrent.futures` library to select the
                best value for the current machine.
            allow_symlinks: Controls whether symbolic links are included. If a
                symlink is present but the flag is `False` (default) the
                serialization would raise an error.
            ignore_paths: Paths of files to ignore.

        Returns:
            The new hashing configuration with the new serialization method.
        """
        if hashing_algorithm == "blake3":
            # BLAKE3 parallelizes internally; fall back to whole-file
            # serialization instead of sharding.
            return self.use_file_serialization(
                hashing_algorithm=hashing_algorithm,
                chunk_size=chunk_size,
                max_workers=max_workers,
                allow_symlinks=allow_symlinks,
                ignore_paths=ignore_paths,
            )

        self._serializer = file_shard.Serializer(
            self._build_sharded_file_hasher_factory(
                hashing_algorithm, chunk_size, shard_size
            ),
            max_workers=max_workers,
            allow_symlinks=allow_symlinks,
            ignore_paths=ignore_paths,
        )
        return self

    def set_ignored_paths(
        self, *, paths: Iterable[PathLike], ignore_git_paths: bool = True
    ) -> Self:
        """Configures the paths to be ignored during serialization of a model.

        If the model is a single file, there are no paths that are ignored. If
        the model is a directory, all paths are considered as relative to the
        model directory, since we never look at files outside of it.

        If an ignored path is a directory, serialization will ignore both the
        path and any of its children.

        Args:
            paths: The paths to ignore.
            ignore_git_paths: Whether to ignore git related paths (default) or
                include them in the signature.

        Returns:
            The new hashing configuration with a new set of ignored paths.
        """
        # Preserve the user-provided relative paths; they are resolved against
        # the model directory later when hashing.
        self._ignored_paths = frozenset(pathlib.Path(p) for p in paths)
        self._ignore_git_paths = ignore_git_paths
        return self

    def add_ignored_paths(
        self, *, model_path: PathLike, paths: Iterable[PathLike]
    ) -> Self:
        """Add more paths to ignore to existing set of paths.

        Paths that would lexically escape the model directory (e.g., absolute
        paths) are silently skipped, mirroring the filtering done in `hash`.

        Args:
            model_path: The path to the model.
            paths: Additional paths to ignore. All paths must be relative to
                the model directory.

        Returns:
            The hashing configuration with the updated set of ignored paths.
        """
        newset = set(self._ignored_paths)
        model_path = pathlib.Path(model_path)
        for p in paths:
            candidate = pathlib.Path(p)
            full = model_path / candidate
            try:
                full.relative_to(model_path)
            except ValueError:
                continue
            newset.add(candidate)
        # Store as frozenset for consistency with `__init__` and
        # `set_ignored_paths` (keeps the attribute immutable).
        self._ignored_paths = frozenset(newset)
        return self

    def set_allow_symlinks(self, allow_symlinks: bool) -> Self:
        """Set whether following symlinks is allowed."""
        self._allow_symlinks = allow_symlinks
        return self