1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Model serializers that operate at file level granularity."""
16
17from collections.abc import Callable, Iterable
18import concurrent.futures
19import itertools
20import os
21import pathlib
22
23from typing_extensions import override
24
25from model_signing import manifest
26from model_signing._hashing import io
27from model_signing._serialization import serialization
28
29
class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file.

    Traverses the model directory and creates digests for every file found,
    possibly in parallel.
    """

    def __init__(
        self,
        file_hasher_factory: Callable[[pathlib.Path], io.FileHasher],
        *,
        max_workers: int | None = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            file_hasher_factory: A callable to build the hash engine used to
              hash individual files.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore.
        """
        self._hasher_factory = file_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        # Materialize once: the value is reused by `set_allow_symlinks` and by
        # every `serialize` call, so a one-shot iterable (e.g. a generator)
        # must not be consumed by the first use.
        self._ignore_paths = tuple(ignore_paths)

        # Precompute some private values only once by using a mock file hasher.
        # None of the arguments used to build the hasher are used.
        hasher = file_hasher_factory(pathlib.Path())
        self._digest_name = hasher.digest_name
        self._is_blake3 = self._digest_name == "blake3"
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    def _build_description(
        self, ignore_paths: Iterable[pathlib.Path]
    ) -> manifest._FileSerialization:
        """Builds the serialization description for the current settings.

        Args:
            ignore_paths: The ignored paths to record in the description.

        Returns:
            A description built from the cached digest name, the current
            `allow_symlinks` flag and the given `ignore_paths`.
        """
        return manifest._FileSerialization(
            self._digest_name, self._allow_symlinks, ignore_paths
        )

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Set whether following symlinks is allowed.

        The stored serialization description is rebuilt so that it reflects
        the new flag.

        Args:
            allow_symlinks: The new value of the flag.
        """
        self._allow_symlinks = allow_symlinks
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Iterable[pathlib.Path] | None = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        The serializer's own state is not modified: ignore paths passed to
        one call are only recorded in the manifest returned by that call.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
              provided path is a directory, all children of the directory are
              ignored.
            files_to_hash: Optional list of files that are to be hashed;
              ignore all others

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
              was not initialized with `allow_symlinks=True`.
        """
        # `ignore_paths` is iterated once per candidate file and once more
        # below, so a one-shot iterable has to be materialized up front.
        ignore_paths = tuple(ignore_paths)

        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )

        paths = []
        for path in files_to_hash:
            if serialization.should_ignore(path, ignore_paths):
                continue
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            # Directories are validated above but only files get hashed.
            if path.is_file():
                paths.append(path)

        with concurrent.futures.ThreadPoolExecutor(
            # blake3 parallelizes internally
            max_workers=1 if self._is_blake3 else self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path)
                for path in paths
            ]
            # `result()` re-raises any exception raised while hashing.
            manifest_items = [
                future.result()
                for future in concurrent.futures.as_completed(futures)
            ]

        # Per-call ignore paths get a per-call description, so they do not
        # leak into manifests produced by later calls on this instance.
        description = self._serialization_description
        if ignore_paths:
            rel_ignore_paths = []
            for ignored in ignore_paths:
                rel = pathlib.Path(os.path.relpath(ignored, model_path))
                # A leading ".." component means the path lies outside of
                # `model_path`; comparing components (rather than a "../"
                # string prefix) also covers a bare ".." and OS-specific
                # separators.
                if rel.parts[:1] != ("..",):
                    rel_ignore_paths.append(rel)
            description = self._build_description(
                frozenset(list(self._ignore_paths) + rel_ignore_paths)
            )

        # Paths such as "." or ".." have no usable final component, so the
        # name is taken from the resolved path instead.
        model_name = model_path.name
        if not model_name or model_name == "..":
            model_name = model_path.resolve().name

        return manifest.Manifest(model_name, manifest_items, description)

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path
    ) -> manifest.FileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.

        Returns:
            The itemized manifest.
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path).compute()
        return manifest.FileManifestItem(path=relative_path, digest=digest)