1# SPDX-FileCopyrightText: 2022 James R. Barlow
2# SPDX-License-Identifier: MPL-2.0
3
4from __future__ import annotations
5
6import struct
7from typing import Any, Callable, NamedTuple, Union
8
9from PIL import Image
10from PIL.TiffTags import TAGS_V2 as TIFF_TAGS
11
12
class ImageDecompressionError(Exception):
    """Image decompression error.

    Raised by helpers in this module when raw image data cannot be turned
    into a Pillow image, e.g. when Pillow rejects the supplied byte buffer.
    """
15
16
# Binary data that the helpers in this module only read from.
BytesLike = Union[bytes, memoryview]
# Writable binary buffers that the unpacking loops fill in place.
MutableBytesLike = Union[bytearray, memoryview]
19
20
21def _next_multiple(n: int, k: int) -> int:
22 """Return the multiple of k that is greater than or equal n.
23
24 >>> _next_multiple(101, 4)
25 104
26 >>> _next_multiple(100, 4)
27 100
28 """
29 div, mod = divmod(n, k)
30 if mod > 0:
31 div += 1
32 return div * k
33
34
def unpack_subbyte_pixels(
    packed: BytesLike, size: tuple[int, int], bits: int, scale: int = 0
) -> tuple[BytesLike, int]:
    """Unpack subbyte *bits* pixels into full bytes and rescale.

    When scale is 0, the appropriate scale is calculated.
    e.g. for 2-bit, the scale is adjusted so that
        0b00 = 0.00 = 0x00
        0b01 = 0.33 = 0x55
        0b10 = 0.66 = 0xaa
        0b11 = 1.00 = 0xff
    When scale is 1, no scaling is applied, appropriate when
    the bytes are palette indexes.

    Only 2-bit and 4-bit input is supported. Input bytes are expanded whole,
    so each output row holds *stride* samples, where *stride* is the image
    width rounded up to a whole number of input bytes.

    Args:
        packed: The packed pixel data. Bytes beyond what the image needs are
            ignored; if the data is too short the remaining output stays zero.
        size: Image ``(width, height)`` in pixels.
        bits: Bits per pixel in *packed* (2 or 4).
        scale: Multiplier applied to every sample, or 0 to map the full
            *bits*-wide range onto 0..255.

    Returns:
        A ``(buffer, stride)`` tuple: a memoryview over ``stride * height``
        unpacked bytes, and the number of bytes per output row.

    Raises:
        NotImplementedError: If *bits* is not 2 or 4.
    """
    # Pick the unpacker first so unsupported depths (including 1-bit) are
    # rejected before any buffer is allocated.
    if bits == 4:
        inner_loop = _4bit_inner_loop
    elif bits == 2:
        inner_loop = _2bit_inner_loop
    else:
        raise NotImplementedError(bits)

    width, height = size
    pixels_per_byte = 8 // bits
    stride = _next_multiple(width, pixels_per_byte)
    # One output byte per pixel, including the padding pixels of each row.
    buffer = bytearray(stride * height)
    # Number of packed input bytes that expand to exactly fill the buffer.
    max_read = len(buffer) // pixels_per_byte
    factor = 255 / ((2**bits) - 1) if scale == 0 else scale
    inner_loop(packed[:max_read], buffer, factor)
    return memoryview(buffer), stride
65
66
67# def _1bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
68# """Unpack 1-bit values to their 8-bit equivalents.
69
#     Thus *out* must be 8x as long as *in*.
71# """
72# for n, val in enumerate(in_):
73# out[8 * n + 0] = int((val >> 7) & 0b1) * scale
74# out[8 * n + 1] = int((val >> 6) & 0b1) * scale
75# out[8 * n + 2] = int((val >> 5) & 0b1) * scale
76# out[8 * n + 3] = int((val >> 4) & 0b1) * scale
77# out[8 * n + 4] = int((val >> 3) & 0b1) * scale
78# out[8 * n + 5] = int((val >> 2) & 0b1) * scale
79# out[8 * n + 6] = int((val >> 1) & 0b1) * scale
80# out[8 * n + 7] = int((val >> 0) & 0b1) * scale
81
82
83def _2bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
84 """Unpack 2-bit values to their 8-bit equivalents.
85
86 Thus *out* must be 4x at long as *in*.
87
88 Images of this type are quite rare in practice, so we don't
89 optimize this loop.
90 """
91 for n, val in enumerate(in_):
92 out[4 * n] = int((val >> 6) * scale)
93 out[4 * n + 1] = int(((val >> 4) & 0b11) * scale)
94 out[4 * n + 2] = int(((val >> 2) & 0b11) * scale)
95 out[4 * n + 3] = int((val & 0b11) * scale)
96
97
98def _4bit_inner_loop(in_: BytesLike, out: MutableBytesLike, scale: int) -> None:
99 """Unpack 4-bit values to their 8-bit equivalents.
100
101 Thus *out* must be 2x at long as *in*.
102
103 Images of this type are quite rare in practice, so we don't
104 optimize this loop.
105 """
106 for n, val in enumerate(in_):
107 out[2 * n] = int((val >> 4) * scale)
108 out[2 * n + 1] = int((val & 0b1111) * scale)
109
110
def image_from_byte_buffer(buffer: BytesLike, size: tuple[int, int], stride: int):
    """Use Pillow to create one-component image from a byte buffer.

    *stride* is the number of bytes per row, and is essential for packed bits
    with odd image widths.

    If Pillow reports that the buffer is too small for the requested stride,
    one more attempt is made with the row length taken as the image width
    rounded up to a multiple of 4 bytes.

    Returns:
        A Pillow image in mode 'L' backed by *buffer*.

    Raises:
        ImageDecompressionError: If Pillow cannot build an image from the
            buffer. The message is Pillow's own error text.
    """
    ystep = 1  # image is top to bottom in memory
    # Even if the image is type 'P' (palette), we create it as a 'L' grayscale
    # at this step. The palette is attached later.
    try:
        return Image.frombuffer('L', size, buffer, "raw", 'L', stride, ystep)
    except ValueError as e:
        if 'buffer is not large enough' in str(e):
            # If Pillow says the buffer is not large enough, then we're going
            # to guess that it's padded to a multiple of 4 bytes. In practice
            # the image may just be corrupted.
            # The stride is a byte distance between rows, so it must be the
            # padded row length itself; a value smaller than the width would
            # make consecutive rows overlap.
            padded_stride = ((size[0] + 3) // 4) * 4
            try:
                return Image.frombuffer(
                    'L', size, buffer, "raw", 'L', padded_stride, ystep
                )
            except ValueError as e2:
                raise ImageDecompressionError(str(e2)) from e2
        else:
            # Keep Pillow's message so the failure can be diagnosed.
            raise ImageDecompressionError(str(e)) from e
135
136
137def _make_rgb_palette(gray_palette: bytes) -> bytes:
138 palette = b''
139 for entry in gray_palette:
140 palette += bytes([entry]) * 3
141 return palette
142
143
144def _depalettize_cmyk(buffer: BytesLike, palette: BytesLike):
145 with memoryview(buffer) as mv:
146 output = bytearray(4 * len(mv))
147 for n, pal_idx in enumerate(mv):
148 output[4 * n : 4 * (n + 1)] = palette[4 * pal_idx : 4 * (pal_idx + 1)]
149 return output
150
151
def image_from_buffer_and_palette(
    buffer: BytesLike,
    size: tuple[int, int],
    stride: int,
    base_mode: str,
    palette: BytesLike,
) -> Image.Image:
    """Construct an image from a byte buffer and apply the palette.

    1/2/4-bit images must be unpacked (no scaling!) to byte buffers first, such
    that every 8-bit integer is an index into the palette.

    Args:
        buffer: One palette index per byte.
        size: Image ``(width, height)`` in pixels.
        stride: Number of index bytes per row of *buffer*; 0 means rows are
            packed with no padding.
        base_mode: Color space of the palette entries: 'RGB', 'L' or 'CMYK'.
        palette: The raw palette data in *base_mode*.

    Returns:
        A Pillow image. For 'RGB' and 'L' the palette is attached to the
        image; for 'CMYK' the indexes are expanded and a plain CMYK image is
        returned.

    Raises:
        NotImplementedError: If *base_mode* is not one of the supported modes.
    """
    if base_mode == 'RGB':
        im = image_from_byte_buffer(buffer, size, stride)
        im.putpalette(palette, rawmode=base_mode)
    elif base_mode == 'L':
        # Pillow does not fully support palettes with rawmode='L'.
        # Convert to RGB palette.
        gray_palette = _make_rgb_palette(palette)
        im = image_from_byte_buffer(buffer, size, stride)
        im.putpalette(gray_palette, rawmode='RGB')
    elif base_mode == 'CMYK':
        # Pillow does not support CMYK with palettes; convert manually
        output = _depalettize_cmyk(buffer, palette)
        # Every index byte (row padding included) became 4 bytes, so a row of
        # *stride* index bytes now spans 4 * stride bytes. Passing that on
        # keeps padded rows aligned; a stride of 0 still means "packed".
        ystep = 1  # image is top to bottom in memory
        im = Image.frombuffer(
            'CMYK', size, output, 'raw', 'CMYK', 4 * stride, ystep
        )
    else:
        raise NotImplementedError(f'palette with {base_mode}')
    return im
180
181
def fix_1bit_palette_image(
    im: Image.Image, base_mode: str, palette: BytesLike
) -> Image.Image:
    """Apply palettes to 1-bit images.

    The image is first converted to mode 'P'. A two-entry RGB palette is then
    spread over a 256-entry palette, with its first color at index 0 and its
    second color at index 255 (presumably the two index values the converted
    1-bit image uses -- confirm against Pillow). A grayscale palette is
    attached as-is, falling back to an RGB expansion when Pillow reports the
    'L' raw mode as unrecognized; any other ValueError from Pillow is ignored
    and the image is returned without that palette.

    For every other combination of *base_mode* and palette length, the
    converted image is returned with no palette applied by this function.
    """
    paletted = im.convert('P')

    if base_mode == 'RGB' and len(palette) == 6:
        # rgbrgb -> rgb000000...rgb
        first_color, last_color = palette[0:3], palette[3:6]
        filler = b'\x00\x00\x00' * (256 - 2)
        paletted.putpalette(
            b''.join((first_color, filler, last_color)), rawmode='RGB'
        )
        return paletted

    if base_mode != 'L':
        return paletted

    try:
        paletted.putpalette(palette, rawmode='L')
    except ValueError as err:
        if 'unrecognized raw mode' in str(err):
            paletted.putpalette(_make_rgb_palette(palette), rawmode='RGB')
    return paletted
201
202
def generate_ccitt_header(
    size: tuple[int, int],
    *,
    data_length: int,
    ccitt_group: int,
    t4_options: int | None,
    photometry: int,
    icc: bytes,
) -> bytes:
    """Generate binary CCITT header for image with given parameters.

    The result is a little-endian TIFF header containing a single image file
    directory (IFD), followed by the ICC profile when one is given. The
    StripOffsets entry points at the byte immediately after the returned
    data, so the compressed image data is expected to be appended there.

    Args:
        size: Image ``(width, height)`` in pixels.
        data_length: Length of the compressed strip (StripByteCounts).
        ccitt_group: Value written to the TIFF Compression tag.
        t4_options: Value for the T4Options tag, or None to omit the tag.
        photometry: Value for the PhotometricInterpretation tag.
        icc: ICC profile bytes to embed; empty to omit the ICCProfile tag.

    Returns:
        The TIFF header bytes, including the ICC profile if present.
    """
    tiff_header_struct = '<' + '2s' + 'H' + 'L' + 'H'
    ifd_struct = '<HHLL'

    # Reverse lookup: tag name -> numeric TIFF tag.
    tag_keys = {tag.name: key for key, tag in TIFF_TAGS.items()}  # type: ignore

    class IFD(NamedTuple):
        key: int
        typecode: Any
        count_: int
        # Either the value itself or a callable resolved at pack time, for
        # offsets that are only known once every entry has been added.
        data: int | Callable[[], int | None]

    ifds: list[IFD] = []

    def header_length(ifd_count: int) -> int:
        return (
            struct.calcsize(tiff_header_struct)
            + struct.calcsize(ifd_struct) * ifd_count
            + 4  # trailing "offset of next IFD" field
        )

    def add_ifd(
        tag_name: str, data: int | Callable[[], int | None], count: int = 1
    ) -> None:
        key = tag_keys[tag_name]
        typecode = TIFF_TAGS[key].type  # type: ignore
        ifds.append(IFD(key, typecode, count, data))

    image_offset = None
    width, height = size
    add_ifd('ImageWidth', width)
    add_ifd('ImageLength', height)
    add_ifd('BitsPerSample', 1)
    add_ifd('Compression', ccitt_group)
    add_ifd('FillOrder', 1)
    if t4_options is not None:
        add_ifd('T4Options', t4_options)
    add_ifd('PhotometricInterpretation', photometry)
    add_ifd('StripOffsets', lambda: image_offset)
    add_ifd('RowsPerStrip', height)
    add_ifd('StripByteCounts', data_length)

    icc_offset = 0
    if icc:
        add_ifd('ICCProfile', lambda: icc_offset, count=len(icc))

    # TIFF requires the entries of an IFD to be sorted in ascending tag
    # order; the insertion order above is not (e.g. FillOrder precedes
    # PhotometricInterpretation).
    ifds.sort(key=lambda ifd: ifd.key)

    # The offsets depend on the final entry count, so they are assigned only
    # now; the lambdas above read them when the entries are flattened below.
    icc_offset = header_length(len(ifds))
    image_offset = icc_offset + len(icc)

    ifd_args = [(arg() if callable(arg) else arg) for ifd in ifds for arg in ifd]
    tiff_header = struct.pack(
        (tiff_header_struct + ifd_struct[1:] * len(ifds) + 'L'),
        b'II',  # Byte order indication: Little endian
        42,  # Version number (always 42)
        8,  # Offset to first IFD
        len(ifds),  # Number of tags in IFD
        *ifd_args,
        0,  # Last IFD
    )

    if icc:
        tiff_header += icc
    return tiff_header