1# SPDX-FileCopyrightText: 2022 James R. Barlow
2# SPDX-License-Identifier: MPL-2.0
3
4from __future__ import annotations
5
6import struct
7from collections.abc import Callable
8from typing import Any, NamedTuple
9
10from PIL import Image
11from PIL.TiffTags import TAGS_V2 as TIFF_TAGS
12
13
class ImageDecompressionError(Exception):
    """Image decompression error.

    Raised by the helpers in this module when Pillow rejects a pixel buffer
    while an image is being constructed from it.
    """
16
17
# Read-only byte sources accepted by the helpers in this module.
BytesLike = bytes | memoryview
# Writable byte buffers that the unpacking loops fill in place.
MutableBytesLike = bytearray | memoryview
20
21
22def _next_multiple(n: int, k: int) -> int:
23 """Return the multiple of k that is greater than or equal n.
24
25 >>> _next_multiple(101, 4)
26 104
27 >>> _next_multiple(100, 4)
28 100
29 """
30 div, mod = divmod(n, k)
31 if mod > 0:
32 div += 1
33 return div * k
34
35
def unpack_subbyte_pixels(
    packed: BytesLike, size: tuple[int, int], bits: int, scale: int = 0
) -> tuple[BytesLike, int]:
    """Unpack subbyte *bits* pixels into full bytes and rescale.

    Each output byte holds one pixel. Rows are padded so that every row
    starts on a packed-byte boundary of the input; the padded row length is
    returned as the stride.

    When scale is 0, the appropriate scale is calculated.
    e.g. for 2-bit, the scale is adjusted so that
        0b00 = 0.00 = 0x00
        0b01 = 0.33 = 0x55
        0b10 = 0.66 = 0xaa
        0b11 = 1.00 = 0xff
    When scale is 1, no scaling is applied, appropriate when
    the bytes are palette indexes.

    Args:
        packed: Packed pixel data, most significant pixel first in each byte.
        size: ``(width, height)`` of the image in pixels.
        bits: Bits per pixel; only 2 and 4 are supported.
        scale: Multiplier applied to every unpacked value, or 0 to derive it
            from *bits*.

    Returns:
        A ``(buffer, stride)`` tuple, where *buffer* is a memoryview over
        ``stride * height`` unpacked bytes and *stride* is the number of
        bytes per row.

    Raises:
        NotImplementedError: If *bits* is not 2 or 4.
    """
    # Reject unsupported depths before doing any arithmetic or allocation, so
    # that e.g. bits=0 or bits=16 fail the same way as bits=1.
    if bits == 4:
        unpack = _4bit_inner_loop
    elif bits == 2:
        unpack = _2bit_inner_loop
    else:
        # 1-bit input is not handled here; its dedicated loop is disabled.
        raise NotImplementedError(bits)

    width, height = size
    pixels_per_byte = 8 // bits
    stride = _next_multiple(width, pixels_per_byte)
    # One output byte per pixel, including the row padding.
    buffer = bytearray(stride * height)
    # Never read more packed bytes than the output buffer can hold.
    max_read = len(buffer) // pixels_per_byte
    if scale == 0:
        # 255 divides evenly by 3 (2-bit) and 15 (4-bit).
        scale = 255 // ((1 << bits) - 1)
    unpack(packed[:max_read], buffer, scale)
    return memoryview(buffer), stride
66
67
68# def _1bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
69# """Unpack 1-bit values to their 8-bit equivalents.
70
# Thus *out* must be 8x as long as *in*.
72# """
73# for n, val in enumerate(in_):
74# out[8 * n + 0] = int((val >> 7) & 0b1) * scale
75# out[8 * n + 1] = int((val >> 6) & 0b1) * scale
76# out[8 * n + 2] = int((val >> 5) & 0b1) * scale
77# out[8 * n + 3] = int((val >> 4) & 0b1) * scale
78# out[8 * n + 4] = int((val >> 3) & 0b1) * scale
79# out[8 * n + 5] = int((val >> 2) & 0b1) * scale
80# out[8 * n + 6] = int((val >> 1) & 0b1) * scale
81# out[8 * n + 7] = int((val >> 0) & 0b1) * scale
82
83
84def _2bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
85 """Unpack 2-bit values to their 8-bit equivalents.
86
87 Thus *out* must be 4x at long as *in*.
88
89 Images of this type are quite rare in practice, so we don't
90 optimize this loop.
91 """
92 for n, val in enumerate(in_):
93 out[4 * n] = int((val >> 6) * scale)
94 out[4 * n + 1] = int(((val >> 4) & 0b11) * scale)
95 out[4 * n + 2] = int(((val >> 2) & 0b11) * scale)
96 out[4 * n + 3] = int((val & 0b11) * scale)
97
98
99def _4bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
100 """Unpack 4-bit values to their 8-bit equivalents.
101
102 Thus *out* must be 2x at long as *in*.
103
104 Images of this type are quite rare in practice, so we don't
105 optimize this loop.
106 """
107 for n, val in enumerate(in_):
108 out[2 * n] = int((val >> 4) * scale)
109 out[2 * n + 1] = int((val & 0b1111) * scale)
110
111
def image_from_byte_buffer(buffer: BytesLike, size: tuple[int, int], stride: int):
    """Build a one-component Pillow image on top of a raw byte buffer.

    *stride* is the length of one row in bytes. It matters for data that was
    unpacked from packed bits, where a row can be wider than the image.

    Raises ImageDecompressionError when Pillow cannot create the image.
    """
    top_to_bottom = 1  # rows are stored top to bottom in memory
    # Palette ('P') images are also created as 'L' grayscale at this step;
    # the palette is attached later.
    try:
        return Image.frombuffer('L', size, buffer, "raw", 'L', stride, top_to_bottom)
    except ValueError as first_error:
        if 'buffer is not large enough' not in str(first_error):
            raise ImageDecompressionError() from first_error
        # Pillow says the buffer is too small for the requested stride. Guess
        # at a different row layout and retry; in practice the image may just
        # be corrupted.
        # NOTE(review): this fallback stride is ceil(width / 4), not the width
        # rounded up to a multiple of 4 bytes -- confirm this is intended.
        fallback_stride = (size[0] + 3) // 4
        try:
            return Image.frombuffer(
                'L', size, buffer, "raw", 'L', fallback_stride, top_to_bottom
            )
        except ValueError as second_error:
            raise ImageDecompressionError(str(second_error)) from second_error
136
137
138def _make_rgb_palette(gray_palette: bytes) -> bytes:
139 palette = b''
140 for entry in gray_palette:
141 palette += bytes([entry]) * 3
142 return palette
143
144
145def _depalettize_cmyk(buffer: BytesLike, palette: BytesLike):
146 with memoryview(buffer) as mv:
147 output = bytearray(4 * len(mv))
148 for n, pal_idx in enumerate(mv):
149 output[4 * n : 4 * (n + 1)] = palette[4 * pal_idx : 4 * (pal_idx + 1)]
150 return output
151
152
def image_from_buffer_and_palette(
    buffer: BytesLike,
    size: tuple[int, int],
    stride: int,
    base_mode: str,
    palette: BytesLike,
) -> Image.Image:
    """Create an image from palette indexes and attach its palette.

    *buffer* must hold one palette index per byte. 1/2/4-bit images therefore
    have to be unpacked to bytes first, without any scaling.

    Raises NotImplementedError for a *base_mode* other than 'RGB', 'L' or
    'CMYK'.
    """
    if base_mode == 'RGB':
        rgb_image = image_from_byte_buffer(buffer, size, stride)
        rgb_image.putpalette(palette, rawmode=base_mode)
        return rgb_image

    if base_mode == 'L':
        # Pillow does not fully support palettes with rawmode='L', so the
        # gray levels are widened to RGB triplets first.
        widened_palette = _make_rgb_palette(palette)
        gray_image = image_from_byte_buffer(buffer, size, stride)
        gray_image.putpalette(widened_palette, rawmode='RGB')
        return gray_image

    if base_mode == 'CMYK':
        # Pillow has no palette support for CMYK; resolve the indexes here.
        cmyk_samples = _depalettize_cmyk(buffer, palette)
        return Image.frombuffer('CMYK', size, data=cmyk_samples, decoder_name='raw')

    raise NotImplementedError(f'palette with {base_mode}')
181
182
def fix_1bit_palette_image(
    im: Image.Image, base_mode: str, palette: BytesLike
) -> Image.Image:
    """Convert *im* to mode 'P' and attach *palette* where supported.

    A palette is only applied for a 6-byte 'RGB' palette or for base mode
    'L'; in every other case the converted image is returned without one.
    """
    paletted = im.convert('P')

    if base_mode == 'RGB' and len(palette) == 6:
        # rgbrgb -> rgb000000...rgb: the two colours go into the first and
        # the last of 256 palette slots.
        first_colour, last_colour = palette[0:3], palette[3:6]
        filler = b'\x00\x00\x00' * (256 - 2)
        paletted.putpalette(
            b''.join([first_colour, filler, last_colour]), rawmode='RGB'
        )
        return paletted

    if base_mode != 'L':
        return paletted

    try:
        paletted.putpalette(palette, rawmode='L')
    except ValueError as err:
        # NOTE(review): any other ValueError is silently ignored here --
        # confirm that this is deliberate.
        if 'unrecognized raw mode' in str(err):
            paletted.putpalette(_make_rgb_palette(palette), rawmode='RGB')
    return paletted
202
203
def generate_ccitt_header(
    size: tuple[int, int],
    *,
    data_length: int,
    ccitt_group: int,
    t4_options: int | None,
    photometry: int,
    icc: bytes,
) -> bytes:
    """Generate binary CCITT header for image with given parameters.

    The result is a little-endian TIFF header with a single IFD, followed by
    the ICC profile bytes when *icc* is non-empty. The ``StripOffsets`` entry
    points at the first byte after the returned data, so the caller is
    presumably expected to append the compressed image data directly after
    it.

    Args:
        size: ``(width, height)``, written to ``ImageWidth`` and
            ``ImageLength``; the height is also used for ``RowsPerStrip``.
        data_length: Value written to ``StripByteCounts``.
        ccitt_group: Value written to the ``Compression`` tag.
        t4_options: Value for the ``T4Options`` tag, or None to omit the tag.
        photometry: Value written to ``PhotometricInterpretation``.
        icc: ICC profile bytes; when empty, no ``ICCProfile`` tag is written.

    Returns:
        The packed header, IFD and (optional) ICC profile.
    """
    # Byte-order mark, version, offset of the first IFD, number of entries.
    tiff_header_struct = '<' + '2s' + 'H' + 'L' + 'H'

    # Reverse lookup from tag name to numeric tag key.
    tag_keys = {tag.name: key for key, tag in TIFF_TAGS.items()}  # type: ignore
    # One IFD entry: tag key, type code, value count, value or offset.
    ifd_struct = '<HHLL'

    class IFD(NamedTuple):
        key: int
        typecode: Any
        count_: int
        # Either the literal value, or a callable that is resolved when the
        # header is packed.
        data: int | Callable[[], int | None]

    ifds: list[IFD] = []

    def header_length(ifd_count: int) -> int:
        """Return the size in bytes of the header plus *ifd_count* entries."""
        return (
            struct.calcsize(tiff_header_struct)
            + struct.calcsize(ifd_struct) * ifd_count
            + 4  # trailing 'L' that holds the next-IFD offset
        )

    def add_ifd(
        tag_name: str, data: int | Callable[[], int | None], count: int = 1
    ) -> None:
        """Append an entry for *tag_name*, using the type Pillow lists for it."""
        key = tag_keys[tag_name]
        typecode = TIFF_TAGS[key].type  # type: ignore
        ifds.append(IFD(key, typecode, count, data))

    # Placeholder; the lambda below reads the final value at pack time.
    image_offset = None
    width, height = size
    # NOTE(review): entries are written in insertion order, not sorted by tag
    # number -- confirm that every consumer of this header accepts that.
    add_ifd('ImageWidth', width)
    add_ifd('ImageLength', height)
    add_ifd('BitsPerSample', 1)
    add_ifd('Compression', ccitt_group)
    add_ifd('FillOrder', 1)
    if t4_options is not None:
        add_ifd('T4Options', t4_options)
    add_ifd('PhotometricInterpretation', photometry)
    # The offset depends on the final number of entries, so it is deferred.
    add_ifd('StripOffsets', lambda: image_offset)
    add_ifd('RowsPerStrip', height)
    add_ifd('StripByteCounts', data_length)

    icc_offset = 0
    if icc:
        # Deferred for the same reason as StripOffsets.
        add_ifd('ICCProfile', lambda: icc_offset, count=len(icc))

    # All entries are known now: the ICC profile starts right after the IFD
    # and the image data right after the ICC profile.
    icc_offset = header_length(len(ifds))
    image_offset = icc_offset + len(icc)

    # Flatten every entry into its four fields, resolving deferred values.
    ifd_args = [(arg() if callable(arg) else arg) for ifd in ifds for arg in ifd]
    tiff_header = struct.pack(
        (tiff_header_struct + ifd_struct[1:] * len(ifds) + 'L'),
        b'II',  # Byte order indication: Little endian
        42,  # Version number (always 42)
        8,  # Offset to first IFD
        len(ifds),  # Number of tags in IFD
        *ifd_args,
        0,  # Last IFD
    )

    if icc:
        tiff_header += icc
    return tiff_header