1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Model serializers that operate at file shard level of granularity."""
16
17from collections.abc import Callable, Iterable
18import concurrent.futures
19import itertools
20import os
21import pathlib
22
23from typing_extensions import override
24
25from model_signing import manifest
26from model_signing._hashing import io
27from model_signing._serialization import serialization
28
29
30def _endpoints(step: int, end: int) -> Iterable[int]:
31 """Yields numbers from `step` to `end` inclusive, spaced by `step`.
32
33 Last value is always equal to `end`, even when `end` is not a multiple of
34 `step`. There is always a value returned.
35
36 Examples:
37 ```python
38 >>> list(_endpoints(2, 8))
39 [2, 4, 6, 8]
40 >>> list(_endpoints(2, 9))
41 [2, 4, 6, 8, 9]
42 >>> list(_endpoints(2, 2))
43 [2]
44
45 Yields:
46 Values in the range, from `step` and up to `end`.
47 """
48 yield from range(step, end, step)
49 yield end
50
51
class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file shard.

    Traverses the model directory and creates digests for every file found,
    sharding the file in equal shards and computing the digests in parallel.
    """

    def __init__(
        self,
        sharded_hasher_factory: Callable[
            [pathlib.Path, int, int], io.ShardedFileHasher
        ],
        *,
        max_workers: int | None = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            sharded_hasher_factory: A callable to build the hash engine used
                to hash every shard of the files in the model. Because each
                shard is processed in parallel, every thread needs to call
                the factory to start hashing. The arguments are the file, and
                the endpoints of the shard.
            max_workers: Maximum number of workers to use in parallel. Default
                is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
                symlink is present but the flag is `False` (default) the
                serialization would raise an error.
            ignore_paths: The paths of files to ignore.
        """
        self._hasher_factory = sharded_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        self._ignore_paths = ignore_paths

        # Precompute some private values only once by using a mock file
        # hasher. None of the arguments used to build the hasher are used.
        hasher = sharded_hasher_factory(pathlib.Path(), 0, 1)
        self._shard_size = hasher.shard_size
        # Here we need the internal hasher name, not the mangled name. This
        # name is used when guessing the hashing configuration. Cache it so
        # rebuilding the description later doesn't need another mock hasher.
        self._digest_name = hasher._content_hasher.digest_name
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    def _build_description(
        self, ignore_paths: Iterable[pathlib.Path]
    ) -> manifest._ShardSerialization:
        """Builds the serialization description recorded in manifests.

        Single place that assembles `manifest._ShardSerialization` from the
        current serializer configuration, so `__init__`,
        `set_allow_symlinks` and `serialize` always stay consistent.

        Args:
            ignore_paths: The ignore paths to record in the description.

        Returns:
            The serialization description for the current configuration.
        """
        return manifest._ShardSerialization(
            self._digest_name,
            self._shard_size,
            self._allow_symlinks,
            ignore_paths,
        )

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Set whether following symlinks is allowed."""
        self._allow_symlinks = allow_symlinks
        # The description records the symlink policy, so it must be rebuilt.
        self._serialization_description = self._build_description(
            self._ignore_paths
        )

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Iterable[pathlib.Path] | None = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
                provided path is a directory, all children of the directory
                are ignored.
            files_to_hash: Optional list of files to hash; ignore all others.

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
                was not initialized with `allow_symlinks=True`.
        """
        shards = []
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be
        # replaced with `pathlib.Path.walk` for a clearer interface, and some
        # speed improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )
        for path in files_to_hash:
            if serialization.should_ignore(path, ignore_paths):
                continue
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file():
                shards.extend(self._get_shards(path))

        # Hash every shard in parallel; completion order doesn't matter
        # since each future yields an independent manifest item.
        manifest_items = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path, start, end)
                for path, start, end in shards
            ]
            for future in concurrent.futures.as_completed(futures):
                manifest_items.append(future.result())

        # Recreate serialization_description for new ignore_paths
        if ignore_paths:
            rel_ignore_paths = []
            for p in ignore_paths:
                rp = os.path.relpath(p, model_path)
                # Paths outside model_path yield exactly ".." or a relpath
                # starting with "../" ("..\\" on Windows); those cannot be
                # expressed relative to the model root, so skip them.
                if rp != os.pardir and not rp.startswith(os.pardir + os.sep):
                    rel_ignore_paths.append(pathlib.Path(rp))

            self._serialization_description = self._build_description(
                frozenset(list(self._ignore_paths) + rel_ignore_paths)
            )

        # `Path.name` is empty for paths like "." or "/" and unusable for
        # ".."; fall back to the basename of the resolved path.
        model_name = model_path.name
        if not model_name or model_name == "..":
            model_name = os.path.basename(model_path.resolve())

        return manifest.Manifest(
            model_name, manifest_items, self._serialization_description
        )

    def _get_shards(
        self, path: pathlib.Path
    ) -> list[tuple[pathlib.Path, int, int]]:
        """Determines the shards of a given file path."""
        shards = []
        path_size = path.stat().st_size
        # Empty files produce no shards; non-empty files are split into
        # `_shard_size` chunks, with the last chunk possibly shorter.
        if path_size > 0:
            start = 0
            for end in _endpoints(self._shard_size, path_size):
                shards.append((path, start, end))
                start = end
        return shards

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path, start: int, end: int
    ) -> manifest.ShardedFileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
                to a manifest item.
            start: The start offset of the shard (included).
            end: The end offset of the shard (not included).

        Returns:
            The itemized manifest.
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path, start, end).compute()
        return manifest.ShardedFileManifestItem(
            path=relative_path, digest=digest, start=start, end=end
        )