1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Model serializers that operate at file level granularity."""
16
17from collections.abc import Callable, Iterable
18import concurrent.futures
19import itertools
20import os
21import pathlib
22from typing import Optional
23
24from typing_extensions import override
25
26from model_signing import manifest
27from model_signing._hashing import io
28from model_signing._serialization import serialization
29
30
class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file.

    Traverses the model directory and creates digests for every file found,
    possibly in parallel.
    """

    def __init__(
        self,
        file_hasher_factory: Callable[[pathlib.Path], io.FileHasher],
        *,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            file_hasher_factory: A callable to build the hash engine used to
              hash individual files.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library. When the hasher
              reports the "blake3" digest a single worker is used instead,
              since that engine parallelizes internally.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore. They are passed into
              the serialization description attached to every manifest.
        """
        self._hasher_factory = file_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        # Materialize once: these paths are read again by `set_allow_symlinks`
        # and `serialize`, so a one-shot iterator (e.g. a generator) would
        # otherwise be silently exhausted after its first use below.
        self._ignore_paths = tuple(ignore_paths)

        # Precompute some private values only once by using a mock file hasher.
        # None of the arguments used to build the hasher are used.
        hasher = file_hasher_factory(pathlib.Path())
        self._serialization_description = manifest._FileSerialization(
            hasher.digest_name, self._allow_symlinks, self._ignore_paths
        )
        self._is_blake3 = hasher.digest_name == "blake3"

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Sets whether symbolic links are allowed in the model.

        The cached serialization description is rebuilt so that it reflects
        the new flag.
        """
        self._allow_symlinks = allow_symlinks
        hasher = self._hasher_factory(pathlib.Path())
        self._serialization_description = manifest._FileSerialization(
            hasher.digest_name, self._allow_symlinks, self._ignore_paths
        )

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Optional[Iterable[pathlib.Path]] = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
              provided path is a directory, all children of the directory are
              ignored. Those lying under `model_path` are also recorded,
              relative to it, in the manifest's serialization description.
            files_to_hash: Optional explicit list of candidate paths; when
              given, the model directory is not traversed. Candidates are
              still validated, and only regular files not covered by
              `ignore_paths` are hashed. Every hashed path must lie under
              `model_path`.

        Returns:
            The model's serialized manifest. Items are in traversal order.

        Raises:
            ValueError: The model contains a symbolic link, but the serializer
              was not initialized with `allow_symlinks=True`; or a path in
              `files_to_hash` is not under `model_path`.
        """
        # `ignore_paths` is consulted once per candidate file and again when
        # rebuilding the description below; materialize it so a one-shot
        # iterator (e.g. a generator) is not exhausted by the first file.
        ignore_paths = tuple(ignore_paths)

        paths = self._collect_paths(model_path, ignore_paths, files_to_hash)

        with concurrent.futures.ThreadPoolExecutor(
            # blake3 parallelizes internally
            max_workers=1 if self._is_blake3 else self._max_workers
        ) as tpe:
            # `map` returns results in submission order, so the manifest items
            # follow the traversal order rather than the nondeterministic
            # completion order of the workers.
            manifest_items = list(
                tpe.map(
                    self._compute_hash, itertools.repeat(model_path), paths
                )
            )

        # Rebuild the description on every call (not only when ignore paths
        # are given) so ignore paths passed to an earlier call cannot leak
        # into this call's manifest.
        if ignore_paths:
            described_ignores = frozenset(
                list(self._ignore_paths)
                + self._relative_ignore_paths(model_path, ignore_paths)
            )
        else:
            described_ignores = self._ignore_paths
        hasher = self._hasher_factory(pathlib.Path())
        self._serialization_description = manifest._FileSerialization(
            hasher.digest_name, self._allow_symlinks, described_ignores
        )

        return manifest.Manifest(
            self._model_name(model_path),
            manifest_items,
            self._serialization_description,
        )

    def _collect_paths(
        self,
        model_path: pathlib.Path,
        ignore_paths: Iterable[pathlib.Path],
        files_to_hash: Optional[Iterable[pathlib.Path]],
    ) -> list[pathlib.Path]:
        """Returns the regular files to hash, in traversal order.

        Every candidate is first checked with `check_file_or_directory` under
        the current symlink policy. Only regular files that are not covered
        by `ignore_paths` are kept.
        """
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )
        paths = []
        for path in files_to_hash:
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file() and not serialization.should_ignore(
                path, ignore_paths
            ):
                paths.append(path)
        return paths

    @staticmethod
    def _relative_ignore_paths(
        model_path: pathlib.Path, ignore_paths: Iterable[pathlib.Path]
    ) -> list[pathlib.Path]:
        """Returns the `ignore_paths` under `model_path`, made relative to it.

        Paths that are not under `model_path` (their relative form starts with
        a `..` component) are dropped.
        """
        relative = []
        for ignore_path in ignore_paths:
            try:
                rel = pathlib.Path(os.path.relpath(ignore_path, model_path))
            except ValueError:
                # On Windows, paths on different drives have no relative form;
                # such a path cannot be under the model directory.
                continue
            # Compare the first component rather than a "../" string prefix so
            # a bare ".." and Windows "..\\" prefixes are also recognized.
            if rel.parts[:1] != (os.pardir,):
                relative.append(rel)
        return relative

    @staticmethod
    def _model_name(model_path: pathlib.Path) -> str:
        """Returns the name recorded for the model rooted at `model_path`."""
        name = model_path.name
        # `Path(".").name` is empty and `Path("..").name` is "..": neither is
        # a useful name, so fall back to the resolved directory's basename.
        if not name or name == "..":
            name = os.path.basename(model_path.resolve())
        return name

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path
    ) -> manifest.FileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.

        Returns:
            The itemized manifest, keyed by `path` relative to `model_path`.

        Raises:
            ValueError: `path` is not under `model_path` (from
              `pathlib.PurePath.relative_to`).
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path).compute()
        return manifest.FileManifestItem(path=relative_path, digest=digest)