1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Model serializers that operate at file level granularity."""
16
17from collections.abc import Callable, Iterable
18import concurrent.futures
19import itertools
20import os
21import pathlib
22from typing import Optional
23
24from typing_extensions import override
25
26from model_signing import manifest
27from model_signing._hashing import io
28from model_signing._serialization import serialization
29
30
class Serializer(serialization.Serializer):
    """Model serializer that produces a manifest recording every file.

    Traverses the model directory and creates digests for every file found,
    possibly in parallel.
    """

    def __init__(
        self,
        file_hasher_factory: Callable[[pathlib.Path], io.FileHasher],
        *,
        max_workers: Optional[int] = None,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Initializes an instance to serialize a model with this serializer.

        Args:
            file_hasher_factory: A callable to build the hash engine used to
              hash individual files.
            max_workers: Maximum number of workers to use in parallel. Default
              is to defer to the `concurrent.futures` library.
            allow_symlinks: Controls whether symbolic links are included. If a
              symlink is present but the flag is `False` (default) the
              serialization would raise an error.
            ignore_paths: The paths of files to ignore. They are recorded in
              the serialization description of every produced manifest.
        """
        self._hasher_factory = file_hasher_factory
        self._max_workers = max_workers
        self._allow_symlinks = allow_symlinks
        self._ignore_paths = ignore_paths

        # Base description, used as-is by `serialize` calls that do not pass
        # additional `ignore_paths`.
        self._serialization_description = self._build_description()

    def set_allow_symlinks(self, allow_symlinks: bool) -> None:
        """Sets whether symbolic links are allowed in the model.

        The base serialization description is rebuilt so that subsequently
        produced manifests record the new policy.
        """
        self._allow_symlinks = allow_symlinks
        self._serialization_description = self._build_description()

    def _build_description(
        self, extra_ignore_paths: Optional[Iterable[pathlib.Path]] = None
    ) -> manifest._FileSerialization:
        """Builds the serialization description recorded in manifests.

        Args:
            extra_ignore_paths: Optional per-call ignore paths (relative to the
              model) to record on top of the constructor's `ignore_paths`.
              When `None`, the constructor's `ignore_paths` are used as given.

        Returns:
            A `manifest._FileSerialization` built from the hasher's digest
            name, the symlink policy and the ignored paths.
        """
        # A hasher is built only to read its `digest_name`; the path argument
        # is not used to hash anything here.
        digest_name = self._hasher_factory(pathlib.Path()).digest_name
        if extra_ignore_paths is None:
            ignore_paths = self._ignore_paths
        else:
            ignore_paths = frozenset(
                list(self._ignore_paths) + list(extra_ignore_paths)
            )
        return manifest._FileSerialization(
            digest_name, self._allow_symlinks, ignore_paths
        )

    @staticmethod
    def _relative_ignore_paths(
        model_path: pathlib.Path, ignore_paths: Iterable[pathlib.Path]
    ) -> list[pathlib.Path]:
        """Returns `ignore_paths` made relative to `model_path`.

        Paths that are not inside `model_path` are dropped: those whose
        relative form starts with a parent-directory component (`..`), and,
        on Windows, those on a different drive (for which `os.path.relpath`
        raises `ValueError`).
        """
        relative_paths = []
        for p in ignore_paths:
            try:
                rel = pathlib.Path(os.path.relpath(p, model_path))
            except ValueError:
                # No relative path exists (e.g. different drives on Windows),
                # so `p` cannot be inside the model.
                continue
            # Compare path components rather than a "../" string prefix, so
            # that an exact ".." and Windows separators are handled too.
            if rel.parts and rel.parts[0] == os.pardir:
                continue
            relative_paths.append(rel)
        return relative_paths

    def _collect_paths(
        self,
        model_path: pathlib.Path,
        ignore_paths: Iterable[pathlib.Path],
        files_to_hash: Optional[Iterable[pathlib.Path]],
    ) -> list[pathlib.Path]:
        """Returns the regular files to hash, validated and filtered.

        Raises:
            ValueError: A non-ignored path is a symbolic link, but the
              serializer was not initialized with `allow_symlinks=True`.
        """
        # TODO: github.com/sigstore/model-transparency/issues/200 - When
        # Python3.12 is the minimum supported version, the glob can be replaced
        # with `pathlib.Path.walk` for a clearer interface, and some speed
        # improvement.
        if files_to_hash is None:
            files_to_hash = itertools.chain(
                (model_path,), model_path.glob("**/*")
            )

        paths = []
        for path in files_to_hash:
            # Filter ignored entries before validating them, so an ignored
            # symlink (or other rejected entry) cannot abort serialization.
            if serialization.should_ignore(path, ignore_paths):
                continue
            serialization.check_file_or_directory(
                path, allow_symlinks=self._allow_symlinks
            )
            if path.is_file():
                paths.append(path)
        return paths

    @override
    def serialize(
        self,
        model_path: pathlib.Path,
        *,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
        files_to_hash: Optional[Iterable[pathlib.Path]] = None,
    ) -> manifest.Manifest:
        """Serializes the model given by the `model_path` argument.

        The serializer's own state is not modified, so repeated calls with
        different `ignore_paths` produce independent manifests.

        Args:
            model_path: The path to the model.
            ignore_paths: The paths to ignore during serialization. If a
              provided path is a directory, all children of the directory are
              ignored. Paths inside `model_path` are also recorded, relative
              to it, in the manifest's serialization description.
            files_to_hash: Optional list of files that are to be hashed;
              ignore all others.

        Returns:
            The model's serialized manifest.

        Raises:
            ValueError: The model contains a symbolic link that is not
              ignored, but the serializer was not initialized with
              `allow_symlinks=True`.
        """
        # Materialize once: `ignore_paths` is consulted for every candidate
        # path and again below, which would exhaust a one-shot iterator.
        ignore_paths = tuple(ignore_paths)

        paths = self._collect_paths(model_path, ignore_paths, files_to_hash)

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers
        ) as tpe:
            futures = [
                tpe.submit(self._compute_hash, model_path, path)
                for path in paths
            ]
            # Gather in submission order so item order does not depend on
            # thread scheduling.
            manifest_items = [future.result() for future in futures]

        # Per-call description; never overwrite the shared base description,
        # otherwise a later call without `ignore_paths` would report stale
        # ignore paths from this one.
        if ignore_paths:
            description = self._build_description(
                self._relative_ignore_paths(model_path, ignore_paths)
            )
        else:
            description = self._serialization_description

        model_name = model_path.name
        # `Path(".").name` and `Path("/").name` are empty, and ".." is not a
        # meaningful name, so fall back to the resolved directory's name.
        if not model_name or model_name == "..":
            model_name = os.path.basename(model_path.resolve())

        return manifest.Manifest(model_name, manifest_items, description)

    def _compute_hash(
        self, model_path: pathlib.Path, path: pathlib.Path
    ) -> manifest.FileManifestItem:
        """Produces the manifest item of the file given by `path`.

        Args:
            model_path: The path to the model.
            path: Path to the file in the model, that is currently transformed
              to a manifest item.

        Returns:
            The itemized manifest.
        """
        relative_path = path.relative_to(model_path)
        digest = self._hasher_factory(path).compute()
        return manifest.FileManifestItem(path=relative_path, digest=digest)