1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Model serializers that operate at file shard level of granularity."""
16
17from collections.abc import Callable, Iterable
18import concurrent.futures
19import itertools
20import os
21import pathlib
22from typing import Optional
23
24from typing_extensions import override
25
26from model_signing import manifest
27from model_signing._hashing import io
28from model_signing._serialization import serialization
29
30
31def _endpoints(step: int, end: int) -> Iterable[int]:
32 """Yields numbers from `step` to `end` inclusive, spaced by `step`.
33
34 Last value is always equal to `end`, even when `end` is not a multiple of
35 `step`. There is always a value returned.
36
37 Examples:
38 ```python
39 >>> list(_endpoints(2, 8))
40 [2, 4, 6, 8]
41 >>> list(_endpoints(2, 9))
42 [2, 4, 6, 8, 9]
43 >>> list(_endpoints(2, 2))
44 [2]
45
46 Yields:
47 Values in the range, from `step` and up to `end`.
48 """
49 yield from range(step, end, step)
50 yield end
51
52
class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file shard.

    Traverses the model directory and creates digests for every file found,
    sharding the file in equal shards and computing the digests in parallel.
    """

    def __init__(
        self,
        sharded_hasher_factory: Callable[
            [pathlib.Path, int, int], io.ShardedFileHasher
        ],
        *,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            sharded_hasher_factory: A callable to build the hash engine used to
              hash every shard of the files in the model. Because each shard is
              processed in parallel, every thread needs to call the factory to
              start hashing. The arguments are the file, and the endpoints of
              the shard.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore.
        """
        self._hasher_factory = sharded_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        self._ignore_paths = ignore_paths

        # Precompute some private values only once by using a mock file hasher.
        # None of the arguments used to build the hasher are used.
        hasher = sharded_hasher_factory(pathlib.Path(), 0, 1)
        self._shard_size = hasher.shard_size
        self._serialization_description = (
            self._build_serialization_description(self._ignore_paths)
        )

    def _build_serialization_description(
        self, ignore_paths: Iterable[pathlib.Path]
    ) -> manifest._ShardSerialization:
        """Builds the serialization description attached to manifests.

        Single place where the description is constructed, so `__init__`,
        `set_allow_symlinks` and `serialize` all stay in agreement.

        Args:
            ignore_paths: The ignore paths to record in the description.

        Returns:
            The description of this serialization configuration.
        """
        hasher = self._hasher_factory(pathlib.Path(), 0, 1)
        return manifest._ShardSerialization(
            # Here we need the internal hasher name, not the mangled name.
            # This name is used when guessing the hashing configuration.
            hasher._content_hasher.digest_name,
            self._shard_size,
            self._allow_symlinks,
            ignore_paths,
        )

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Set whether following symlinks is allowed."""
        self._allow_symlinks = allow_symlinks
        self._serialization_description = (
            self._build_serialization_description(self._ignore_paths)
        )

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Optional[Iterable[pathlib.Path]] = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
                provided path is a directory, all children of the directory are
                ignored.
            files_to_hash: Optional list of files to hash; ignore all others.

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
                was not initialized with `allow_symlinks=True`.
        """
        shards = []
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )
        for path in files_to_hash:
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file() and not serialization.should_ignore(
                path, ignore_paths
            ):
                shards.extend(self._get_shards(path))

        manifest_items = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path, start, end)
                for path, start, end in shards
            ]
            for future in concurrent.futures.as_completed(futures):
                manifest_items.append(future.result())

        # Recreate serialization_description for new ignore_paths
        if ignore_paths:
            rel_ignore_paths = []
            for p in ignore_paths:
                rp = os.path.relpath(p, model_path)
                # Drop paths outside model_path: `relpath` returns exactly
                # ".." or a path starting with ".." plus the platform
                # separator for those. Using `os.pardir`/`os.sep` instead of
                # a hard-coded "../" keeps the check correct on Windows,
                # where the separator is "\\".
                if rp != os.pardir and not rp.startswith(os.pardir + os.sep):
                    rel_ignore_paths.append(pathlib.Path(rp))

            self._serialization_description = (
                self._build_serialization_description(
                    frozenset(list(self._ignore_paths) + rel_ignore_paths)
                )
            )

        # `model_path.name` is empty for paths like "/" or ".", and ".." for
        # a literal ".." path; resolve to get a real directory name instead.
        model_name = model_path.name
        if not model_name or model_name == "..":
            model_name = os.path.basename(model_path.resolve())

        return manifest.Manifest(
            model_name, manifest_items, self._serialization_description
        )

    def _get_shards(
        self, path: pathlib.Path
    ) -> list[tuple[pathlib.Path, int, int]]:
        """Determines the shards of a given file path.

        Empty files produce no shards (and thus no manifest items).
        """
        shards = []
        path_size = path.stat().st_size
        if path_size > 0:
            start = 0
            for end in _endpoints(self._shard_size, path_size):
                shards.append((path, start, end))
                start = end
        return shards

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path, start: int, end: int
    ) -> manifest.ShardedFileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.
            start: The start offset of the shard (included).
            end: The end offset of the shard (not included).

        Returns:
            The manifest item for the given shard of the file.
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path, start, end).compute()
        return manifest.ShardedFileManifestItem(
            path=relative_path, digest=digest, start=start, end=end
        )