1# Copyright 2024 The Sigstore Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""An in-memory serialized representation of an ML model.
16
A manifest pairs objects from the model (e.g., files, shards of files) with
their hashes. When signing the model, we first generate a manifest by
serializing the model using a configured serialization method (see
`model_signing.signing`).
21
22When verifying the integrity of the model, after checking the authenticity of
23the signature, we extract a manifest from it. Then, we serialize the local model
24(the model being tested) and compare the two manifests.
25
26The serialization method used during signing must match the one used during
27verification. We can auto detect the method to use during verification from the
28signature, but it is recommended to be explicit when possible.
29
30Comparing the manifests can be done by checking that every item matches, both in
31name and in associated hash. In the future we will support partial object
32matching. This is useful, for example, for the cases where the original model
33contained files for multiple ML frameworks, but the user only uses the model
34with one framework. This way, the user can verify the integrity only for the
35files that are actually used.
36
37This API should not be used directly, we don't guarantee that it is fully stable
38at the moment.
39"""
40
41import abc
42from collections.abc import Iterable, Iterator
43import dataclasses
44import pathlib
45import sys
46from typing import Any, Final
47
48from typing_extensions import override
49
50from model_signing._hashing import hashing
51
52
53if sys.version_info >= (3, 11):
54 from typing import Self
55else:
56 from typing_extensions import Self
57
58
@dataclasses.dataclass(order=True, frozen=True)
class _ResourceDescriptor:
    """Describes a single piece of content recorded in a `Manifest`.

    This is modeled after in-toto's `ResourceDescriptor`. It is kept as a plain
    dataclass so that it can be converted to the in-toto form whenever that
    format is required, and used directly everywhere else.

    Only a subset of the in-toto fields is modeled for now. In contrast to
    in-toto, where every field is optional, both fields here are mandatory.

    The in-toto specification can be found at
    github.com/in-toto/attestation/blob/main/spec/v1/resource_descriptor.md.

    Attributes:
        identifier: String identifying this object uniquely inside the
          manifest. Some serialized formats may further require uniqueness
          across every manifest stored in a system. Producers and consumers
          may also agree on extra conventions (for example, several
          descriptors sharing a common identifier pattern, where the integrity
          of the model implies the integrity of exactly those items and any
          other descriptor is ignored). Maps to `name`, `uri`, or `content` in
          the in-toto specification.
        digest: The single, mandatory digest of the item. Unlike in-toto,
          which allows several digests per resource, exactly one is stored.
    """

    identifier: str
    digest: hashing.Digest
88
89
class _ManifestKey(metaclass=abc.ABCMeta):
    """Base class for objects that can key an item of a manifest.

    A key has to survive a round-trip through its string form. It is turned
    into a string when the manifest is serialized, and it is rebuilt from that
    string when the manifest is read back from a signature.
    """

    @classmethod
    @abc.abstractmethod
    def from_str(cls, s: str) -> Self:
        """Parses `s` back into a manifest key.

        Every derived class `C` guarantees, for any key `key` of type `C` and
        any valid string `s`, both of these round-trip properties:

        ```
        str(C.from_str(s)) == s
        C.from_str(str(key)) == key
        ```

        Raises:
            ValueError: if `s` cannot be decoded into a key.
        """
114
115
@dataclasses.dataclass(frozen=True, order=True)
class _File(_ManifestKey):
    """Manifest key identifying a whole file.

    Attributes:
        path: Location of the file, relative to the model root.
    """

    path: pathlib.PurePath

    def __str__(self) -> str:
        return f"{self.path}"

    @classmethod
    @override
    def from_str(cls, s: str) -> Self:
        # Keys are always decoded as POSIX paths, independently of the host OS.
        posix_path = pathlib.PurePosixPath(s)
        return cls(posix_path)
134
135
@dataclasses.dataclass(frozen=True, order=True)
class _Shard(_ManifestKey):
    """A dataclass to hold information about a file shard as a manifest key.

    The string form of a shard key is `<path>:<start>:<end>`. Because `:` is a
    legal character in POSIX file names, the two offsets are parsed from the
    right-hand side so that paths containing `:` still round-trip.

    Attributes:
        path: The path to the file, relative to the model root.
        start: The start offset of the shard (included).
        end: The end offset of the shard (not included).
    """

    path: pathlib.PurePath
    start: int
    end: int

    def __str__(self) -> str:
        return f"{self.path}:{self.start}:{self.end}"

    @classmethod
    @override
    def from_str(cls, s: str) -> Self:
        """Builds a shard key from its `<path>:<start>:<end>` representation.

        The path is always decoded as a pure POSIX path.

        Raises:
            ValueError: if `s` does not end with two `:`-separated components,
              or if those components are not integers.
        """
        # Split from the right: only the last two `:` delimit the offsets, any
        # earlier `:` is part of the file path itself.
        parts = s.rsplit(":", 2)
        if len(parts) != 3:
            raise ValueError(f"Expected 3 components separated by `:`, got {s}")

        path = pathlib.PurePosixPath(parts[0])
        start = int(parts[1])
        end = int(parts[2])

        return cls(path, start, end)
165
166
class ManifestItem(metaclass=abc.ABCMeta):
    """A single object of the model, recorded as one entry of a manifest.

    Typical items are whole files or shards of files. Every file path is taken
    relative to the model root, so the model can be moved or renamed without
    invalidating its signature.

    Items can be verified one by one, and therefore in parallel. For an item
    backed by a file, the hash of the portion of the file that the item covers
    is recomputed.

    Attributes:
        digest: The digest of the item. The `key` property provides the
          canonical, unique identity of the item.
    """

    digest: hashing.Digest

    @property
    @abc.abstractmethod
    def key(self) -> _ManifestKey:
        """The unique identity of this item within a manifest.

        No two items of one manifest may have equal keys. The key must carry
        enough information to determine how the item's digest is computed.
        """
194
195
class FileManifestItem(ManifestItem):
    """Manifest item pairing a whole file with its digest.

    The path is a `pathlib.PurePath` relative to the model root. It is
    converted to a POSIX path so that the manifest is consistent across
    operating systems.
    """

    def __init__(self, *, path: pathlib.PurePath, digest: hashing.Digest):
        """Records a file together with its digest.

        Args:
            path: The path to the file, relative to the model root.
            digest: The digest of the file.
        """
        self.digest = digest
        # Canonicalize to a PurePosixPath so the manifest does not depend on
        # the OS that produced it.
        self._path = pathlib.PurePosixPath(path)

    @property
    @override
    def key(self) -> _ManifestKey:
        file_key = _File(self._path)
        return file_key
219
220
class ShardedFileManifestItem(ManifestItem):
    """Manifest item pairing a shard of a file with its digest.

    The path is a `pathlib.PurePath` relative to the model root. It is
    converted to a POSIX path so that the manifest is consistent across
    operating systems.
    """

    def __init__(
        self,
        *,
        path: pathlib.PurePath,
        start: int,
        end: int,
        digest: hashing.Digest,
    ):
        """Records a file shard together with its digest.

        Args:
            path: The path to the file, relative to the model root.
            start: The start offset of the shard (included).
            end: The end offset of the shard (not included).
            digest: The digest of the file shard.
        """
        self.digest = digest
        self._start, self._end = start, end
        # Canonicalize to a PurePosixPath so the manifest does not depend on
        # the OS that produced it.
        self._path = pathlib.PurePosixPath(path)

    @property
    @override
    def key(self) -> _ManifestKey:
        shard_key = _Shard(path=self._path, start=self._start, end=self._end)
        return shard_key
255
256
class SerializationType(metaclass=abc.ABCMeta):
    """A description of the serialization process that generated the manifest.

    These should record all the parameters needed to ensure a reproducible
    serialization. These are used to build a manifest back from the signature in
    a backwards compatible way. We use these to determine what serialization to
    use when verifying a signature.
    """

    @property
    @abc.abstractmethod
    def serialization_parameters(self) -> dict[str, Any]:
        """The arguments of the serialization method.

        The dictionary must contain a `method` entry: `from_args` uses it to
        select the concrete serialization class when rebuilding the instance.
        """

    @classmethod
    def from_args(cls, args: dict[str, Any]) -> Self:
        """Builds an instance of this class based on the dict representation.

        This is the reverse of `serialization_parameters`. The concrete class
        is selected from `args["method"]`, independently of the class on which
        this method is called.

        Args:
            args: The arguments as a dictionary (equivalent to `**kwargs`).

        Returns:
            The serialization type described by `args`.

        Raises:
            KeyError: if `args` has no `method` entry. The selected subclass's
              `_from_args` may also raise `KeyError` for missing arguments.
            ValueError: if `args["method"]` does not name a known
              serialization type.
        """
        serialization_type = args["method"]
        for subclass in [_FileSerialization, _ShardSerialization]:
            if serialization_type == subclass.method:
                return subclass._from_args(args)
        raise ValueError(f"Unknown serialization type {serialization_type}")

    @classmethod
    @abc.abstractmethod
    def _from_args(cls, args: dict[str, Any]) -> Self:
        """Performs the actual build for `from_args`.

        Args:
            args: The arguments as a dictionary, as produced by
              `serialization_parameters`.
        """

    @abc.abstractmethod
    def new_item(self, name: str, digest: hashing.Digest) -> ManifestItem:
        """Builds a `ManifestItem` of the correct type.

        Each serialization type results in different types for the items in the
        manifest. This method parses the `name` of the item according to the
        serialization type to create the proper manifest item.

        Args:
            name: The name of the item, as shown in the manifest.
            digest: The digest of the item.
        """
303
304
class _FileSerialization(SerializationType):
    """Serialization type for models serialized file by file."""

    method: Final[str] = "files"

    def __init__(
        self,
        hash_type: str,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Records the manifest serialization type for serialization by files.

        We need to record the hashing engine used, whether symlinks are hashed
        or ignored, and which paths were excluded from serialization.

        Args:
            hash_type: A string representation of the hash algorithm.
            allow_symlinks: Controls whether symbolic links are included.
            ignore_paths: Paths of files to ignore. They are stored as strings,
              and `_from_args` passes these string forms back in.
        """
        self._hash_type = hash_type
        self._allow_symlinks = allow_symlinks
        self._ignore_paths = [str(p) for p in ignore_paths]

    @property
    @override
    def serialization_parameters(self) -> dict[str, Any]:
        """The arguments of this serialization, keyed by parameter name.

        `ignore_paths` is only included when at least one path is ignored.
        """
        res = {
            "method": self.method,
            "hash_type": self._hash_type,
            "allow_symlinks": self._allow_symlinks,
        }
        if self._ignore_paths:
            res["ignore_paths"] = self._ignore_paths
        return res

    @classmethod
    @override
    def _from_args(cls, args: dict[str, Any]) -> Self:
        # `ignore_paths` is optional because `serialization_parameters` omits
        # it when no path is ignored.
        return cls(
            args["hash_type"],
            args["allow_symlinks"],
            args.get("ignore_paths", frozenset()),
        )

    @override
    def new_item(self, name: str, digest: hashing.Digest) -> ManifestItem:
        """Builds a `FileManifestItem` whose path is `name` as a POSIX path."""
        path = pathlib.PurePosixPath(name)
        return FileManifestItem(path=path, digest=digest)
352
353
class _ShardSerialization(SerializationType):
    """Serialization type for models serialized as shards of files."""

    method: Final[str] = "shards"

    def __init__(
        self,
        hash_type: str,
        shard_size: int,
        allow_symlinks: bool = False,
        ignore_paths: Iterable[pathlib.Path] = frozenset(),
    ):
        """Records the manifest serialization type for serialization by shards.

        We need to record the hashing engine used and whether symlinks are
        hashed or ignored, just like for file serialization. We also need to
        record the shard size used to split the files, since different shard
        sizes result in different resources.

        Args:
            hash_type: A string representation of the hash algorithm.
            shard_size: The size of the shards the files are split into.
            allow_symlinks: Controls whether symbolic links are included.
            ignore_paths: Paths of files to ignore. They are stored as strings,
              and `_from_args` passes these string forms back in.
        """
        self._hash_type = hash_type
        self._allow_symlinks = allow_symlinks
        self._shard_size = shard_size
        self._ignore_paths = [str(p) for p in ignore_paths]

    @property
    @override
    def serialization_parameters(self) -> dict[str, Any]:
        """The arguments of this serialization, keyed by parameter name.

        `ignore_paths` is only included when at least one path is ignored.
        """
        res = {
            "method": self.method,
            "hash_type": self._hash_type,
            "shard_size": self._shard_size,
            "allow_symlinks": self._allow_symlinks,
        }
        if self._ignore_paths:
            res["ignore_paths"] = self._ignore_paths
        return res

    @classmethod
    @override
    def _from_args(cls, args: dict[str, Any]) -> Self:
        # `ignore_paths` is optional because `serialization_parameters` omits
        # it when no path is ignored.
        return cls(
            args["hash_type"],
            args["shard_size"],
            args["allow_symlinks"],
            args.get("ignore_paths", frozenset()),
        )

    @override
    def new_item(self, name: str, digest: hashing.Digest) -> ManifestItem:
        """Parses a `<path>:<start>:<end>` resource name into a shard item.

        Raises:
            ValueError: if `name` does not end with two `:`-separated
              components, or if those components are not integers.
        """
        # Only the last two `:` delimit the offsets. Earlier ones belong to the
        # file path, since `:` is a valid character in POSIX file names.
        parts = name.rsplit(":", 2)
        if len(parts) != 3:
            raise ValueError(
                "Invalid resource name: expected 3 components separated by "
                f"`:`, got {name}"
            )

        path = pathlib.PurePosixPath(parts[0])
        start = int(parts[1])
        end = int(parts[2])
        return ShardedFileManifestItem(
            path=path, start=start, end=end, digest=digest
        )
419
420
class Manifest:
    """Generic manifest file to represent a model.

    A manifest maps the key of every item of the model (see
    `ManifestItem.key`) to the item's digest. It also records the
    serialization used to produce those items, so that verification can
    repeat the same serialization.
    """

    def __init__(
        self,
        model_name: str,
        items: Iterable[ManifestItem],
        serialization_type: SerializationType,
    ):
        """Builds a manifest from a collection of already hashed objects.

        Args:
            model_name: A name for the model that generated the manifest. This
              is the final component of the model path, and is only
              informative. See `model_name` property.
            items: An iterable sequence of objects and their hashes.
            serialization_type: The serialization method (and its parameters)
              used to produce `items`.
        """
        self._name = model_name
        # Keys must be unique (see `ManifestItem.key`). If a key is repeated,
        # the digest of the last item with that key is the one kept.
        self._item_to_digest = {item.key: item.digest for item in items}
        self._serialization_type = serialization_type

    def __eq__(self, other: object) -> bool:
        """Two manifests are equal when they record the same items and digests.

        The model name and the serialization type are not compared: renaming
        the model must not invalidate a signature.
        """
        if not isinstance(other, Manifest):
            # Let Python fall back to the reflected comparison / identity
            # instead of failing on a missing attribute.
            return NotImplemented
        return self._item_to_digest == other._item_to_digest

    def resource_descriptors(self) -> Iterator[_ResourceDescriptor]:
        """Yields each resource from the manifest, one by one.

        Resources are yielded in ascending order of their keys, independently
        of the order in which the items were provided.
        """
        # Keys are unique, so sorting on the key alone gives the same order as
        # sorting (key, digest) pairs without ever comparing digests.
        for item, digest in sorted(
            self._item_to_digest.items(), key=lambda entry: entry[0]
        ):
            yield _ResourceDescriptor(identifier=str(item), digest=digest)

    @property
    def model_name(self) -> str:
        """The name of the model when serialized (final component of the path).

        This is only informative. Changing the name of the model should still
        result in the same digests after serialization, it must not invalidate
        signatures. As a result, two manifests with different model names but
        with the same resource descriptors will compare equal.
        """
        return self._name

    @property
    def serialization_type(self) -> dict[str, Any]:
        """The serialization (and arguments) used to build the manifest.

        This is needed to record the serialization method used to generate the
        manifest so that signature verification can use the same method. The
        value returned is the `serialization_parameters` dictionary of the
        serialization type passed to the constructor.
        """
        return self._serialization_type.serialization_parameters