1#
2# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
3#
4# This code is distributed under the terms and conditions
5# from the MIT License (MIT).
6#
7"""Implements ByteBuffer class for amortizing network transfer overhead."""
8
9from __future__ import annotations
10
11import io
12from typing import IO, TYPE_CHECKING, cast
13
14if TYPE_CHECKING:
15 from collections.abc import Iterable
16
17
class ByteBuffer:
    """Byte buffer that allows callers to read data with minimal copying, and has a fast ``__len__`` method.

    The buffer is parametrized by its ``chunk_size``, which is the number of
    bytes that it will read in from the supplied reader or iterable when the
    buffer is being filled. As the primary use case for this buffer is to
    amortize the overhead costs of transferring data over the network (rather
    than capping memory consumption), it leads to more predictable performance
    to always read the same amount of bytes each time the buffer is filled,
    hence the ``chunk_size`` parameter instead of some fixed capacity.

    The bytes are stored in a ``bytearray`` together with a read position.
    Previously-read bytes are discarded when the buffer is next filled (by
    slicing the unread tail into a smaller copy).

    Args:
        chunk_size: The number of bytes that will be read from the supplied reader
            or iterable when filling the buffer.

    Example:
        >>> buf = ByteBuffer(chunk_size=8)
        >>> message_bytes = iter([b"Hello, W", b"orld!"])
        >>> buf.fill(message_bytes)
        8
        >>> len(buf)  # only chunk_size bytes are filled
        8
        >>> buf.peek()
        b'Hello, W'
        >>> len(buf)  # peek() does not change read position
        8
        >>> buf.read(6)
        b'Hello,'
        >>> len(buf)  # read() does change read position
        2
        >>> buf.fill(message_bytes)
        5
        >>> buf.read()
        b' World!'
        >>> len(buf)
        0
    """

    def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None:
        self._chunk_size = chunk_size
        self.empty()

    def __len__(self) -> int:
        """Return the number of unread bytes in the buffer as an int."""
        return len(self._bytes) - self._pos

    def read(self, size: int = -1) -> bytes:
        """Read bytes from the buffer and advance the read position.

        Args:
            size: Maximum number of bytes to read. If negative or not supplied, read
                all unread bytes in the buffer.

        Returns:
            The bytes read from the buffer.
        """
        part = self.peek(size)
        self._pos += len(part)
        return part

    def peek(self, size: int = -1) -> bytes:
        """Get bytes from the buffer without advancing the read position.

        Args:
            size: Maximum number of bytes to return. If negative or not supplied,
                return all unread bytes in the buffer.

        Returns:
            The peeked bytes from the buffer.
        """
        if size < 0 or size > len(self):
            size = len(self)

        return bytes(self._bytes[self._pos : self._pos + size])

    def empty(self) -> None:
        """Remove all bytes from the buffer."""
        self._bytes = bytearray()
        self._pos = 0

    def fill(self, source: IO[bytes] | Iterable[bytes], size: int = -1) -> int:
        """Fill the buffer with bytes from source.

        Reads from ``source`` until one of these conditions is met:

        * ``size`` bytes have been read from source (if ``size >= 0``);
        * ``chunk_size`` bytes have been read from source;
        * no more bytes can be read from source.

        Note:
            All previously-read bytes in the buffer are removed.

        Args:
            source: The source of bytes to fill the buffer with, either a file-like
                object or an iterable/list of bytes. If this argument has the ``read``
                attribute, it's assumed to be a file-like object and ``read`` is called
                to get the bytes; otherwise it's assumed to be an iterable or list that
                contains bytes, and a for loop is used to get the bytes. Note that a
                list (or any re-iterable container) is iterated from its start on
                every call; pass an iterator to resume where the previous fill
                stopped.
            size: The number of bytes to try to read from source. If not supplied,
                negative, or larger than the buffer's ``chunk_size``, then ``chunk_size``
                bytes are read. Note that if source is an iterable or list, then
                it's possible that more than size bytes will be read if iterating
                over source produces more than one byte at a time. A ``size`` of 0
                reads nothing from either kind of source.

        Returns:
            The number of new bytes added to the buffer.
        """
        size = size if size >= 0 else self._chunk_size
        size = min(size, self._chunk_size)

        # Drop already-read bytes so the buffer only holds unread data.
        if self._pos != 0:
            self._bytes = self._bytes[self._pos :]
            self._pos = 0

        if hasattr(source, "read"):
            new_bytes = cast("IO[bytes]", source).read(size)
        else:
            new_bytes = bytearray()
            # Guard before pulling the first item: otherwise a zero-size request
            # would still consume (and buffer) a whole element from the iterator,
            # unlike the file-like branch where read(0) returns nothing.
            if size > 0:
                for more_bytes in source:
                    new_bytes += more_bytes
                    if len(new_bytes) >= size:
                        break

        self._bytes += new_bytes
        return len(new_bytes)

    def readline(self, terminator: bytes) -> bytes:
        """Read a line from this buffer efficiently.

        A line is a contiguous sequence of bytes that ends with either:

        1. The ``terminator`` sequence (which may be more than one byte, e.g. ``b"\\r\\n"``)
        2. The end of the buffer itself

        Only the bytes currently in the buffer are searched, so a multi-byte
        terminator that straddles the end of the buffer is not detected; in that
        case all unread bytes are returned.

        Args:
            terminator: The line terminator bytes. A single int byte value is
                also accepted.

        Returns:
            The line bytes (including the full terminator if present).

        Raises:
            ValueError: If ``terminator`` is empty.
        """
        # bytearray.find also accepts an int byte value; normalize it so that
        # len() below works for both forms.
        if isinstance(terminator, int):
            terminator = bytes((terminator,))
        if not terminator:
            raise ValueError("terminator must be a non-empty bytes object")

        index = self._bytes.find(terminator, self._pos)
        if index == -1:
            size = len(self)
        else:
            # Consume the whole terminator, not just its first byte.
            size = index - self._pos + len(terminator)
        return self.read(size)