Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/fastavro/_write_py.py: 26%
301 statements
coverage.py v7.2.2, created at 2023-03-26 06:10 +0000

# cython: auto_cpdef=True

"""Python code for writing AVRO files"""

# This code is a modified version of the code at
# http://svn.apache.org/viewvc/avro/trunk/lang/py/src/avro/ which is under
# Apache 2.0 license (http://www.apache.org/licenses/LICENSE-2.0)

from abc import ABC, abstractmethod
import json
from io import BytesIO
from os import urandom, SEEK_SET
import bz2
import lzma
import zlib
from typing import Union, IO, Iterable, Any, Optional, Dict

from .const import NAMED_TYPES
from .io.binary_encoder import BinaryEncoder
from .io.json_encoder import AvroJSONEncoder
from .validation import _validate
from .read import HEADER_SCHEMA, SYNC_SIZE, MAGIC, reader
from .logical_writers import LOGICAL_WRITERS
from .schema import extract_record_type, extract_logical_type, parse_schema
from ._write_common import _is_appendable
from .types import Schema, NamedSchemas


def write_null(encoder, datum, schema, named_schemas, fname, options):
    """null is written as zero bytes"""
    encoder.write_null()


def write_boolean(encoder, datum, schema, named_schemas, fname, options):
    """A boolean is written as a single byte whose value is either 0 (false) or
    1 (true)."""
    encoder.write_boolean(datum)


def write_int(encoder, datum, schema, named_schemas, fname, options):
    """int and long values are written using variable-length, zig-zag coding."""
    encoder.write_int(datum)


def write_long(encoder, datum, schema, named_schemas, fname, options):
    """int and long values are written using variable-length, zig-zag coding."""
    encoder.write_long(datum)
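

# For illustration (per the Avro spec): zig-zag coding maps signed integers
# onto unsigned ones so that values of small magnitude stay small:
#     0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
# The result is then written as a variable-length integer, 7 bits per byte,
# with the high bit of each byte acting as a continuation flag. For example,
# the long 27 maps to 54 and is encoded as the single byte 0x36.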


def write_float(encoder, datum, schema, named_schemas, fname, options):
    """A float is written as 4 bytes. The float is converted into a 32-bit
    integer using a method equivalent to Java's floatToIntBits and then encoded
    in little-endian format."""
    encoder.write_float(datum)


def write_double(encoder, datum, schema, named_schemas, fname, options):
    """A double is written as 8 bytes. The double is converted into a 64-bit
    integer using a method equivalent to Java's doubleToLongBits and then
    encoded in little-endian format."""
    encoder.write_double(datum)
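

# For illustration: this amounts to emitting the raw IEEE 754 bits in
# little-endian byte order, roughly struct.pack("<f", value) for a float and
# struct.pack("<d", value) for a double. For example, the float 1.0
# (bit pattern 0x3F800000) is written as the bytes 00 00 80 3F.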


def write_bytes(encoder, datum, schema, named_schemas, fname, options):
    """Bytes are encoded as a long followed by that many bytes of data."""
    encoder.write_bytes(datum)


def write_utf8(encoder, datum, schema, named_schemas, fname, options):
    """A string is encoded as a long followed by that many bytes of UTF-8
    encoded character data."""
    encoder.write_utf8(datum)


def write_crc32(encoder, datum):
    """A 4-byte, big-endian CRC32 checksum"""
    encoder.write_crc32(datum)


def write_fixed(encoder, datum, schema, named_schemas, fname, options):
    """Fixed instances are encoded using the number of bytes declared in the
    schema."""
    if len(datum) != schema["size"]:
        raise ValueError(
            f"data of length {len(datum)} does not match schema size: {schema}"
        )
    encoder.write_fixed(datum)


def write_enum(encoder, datum, schema, named_schemas, fname, options):
    """An enum is encoded by an int, representing the zero-based position of
    the symbol in the schema."""
    index = schema["symbols"].index(datum)
    encoder.write_enum(index)


def write_array(encoder, datum, schema, named_schemas, fname, options):
    """Arrays are encoded as a series of blocks.

    Each block consists of a long count value, followed by that many array
    items. A block with count zero indicates the end of the array. Each item
    is encoded per the array's item schema.

    If a block's count is negative, then the count is followed immediately by a
    long block size, indicating the number of bytes in the block. The actual
    count in this case is the absolute value of the count written."""
    encoder.write_array_start()
    if len(datum) > 0:
        encoder.write_item_count(len(datum))
        dtype = schema["items"]
        for item in datum:
            write_data(encoder, item, dtype, named_schemas, fname, options)
            encoder.end_item()
    encoder.write_array_end()
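

# For illustration (per the Avro spec): the array [3, 27] with "long" items
# is encoded as the bytes 04 06 36 00, i.e. a block count of 2 (zig-zag 4 ->
# 0x04), the items 3 (0x06) and 27 (0x36), and a count of 0 (0x00) marking
# the end of the array.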


def write_map(encoder, datum, schema, named_schemas, fname, options):
    """Maps are encoded as a series of blocks.

    Each block consists of a long count value, followed by that many key/value
    pairs. A block with count zero indicates the end of the map. Each item is
    encoded per the map's value schema.

    If a block's count is negative, then the count is followed immediately by a
    long block size, indicating the number of bytes in the block. The actual
    count in this case is the absolute value of the count written."""
    encoder.write_map_start()
    if len(datum) > 0:
        encoder.write_item_count(len(datum))
        vtype = schema["values"]
        for key, val in datum.items():
            encoder.write_utf8(key)
            write_data(encoder, val, vtype, named_schemas, fname, options)
    encoder.write_map_end()


def write_union(encoder, datum, schema, named_schemas, fname, options):
    """A union is encoded by first writing a long value indicating the
    zero-based position within the union of the schema of its value. The value
    is then encoded per the indicated schema within the union."""

    best_match_index = -1
    if isinstance(datum, tuple) and not options.get("disable_tuple_notation"):
        (name, datum) = datum
        for index, candidate in enumerate(schema):
            extracted_type = extract_record_type(candidate)
            if extracted_type in NAMED_TYPES:
                schema_name = candidate["name"]
            else:
                schema_name = extracted_type
            if name == schema_name:
                best_match_index = index
                break

        if best_match_index == -1:
            field = f"on field {fname}" if fname else ""
            msg = (
                f"provided union type name {name} not found in schema "
                + f"{schema} {field}"
            )
            raise ValueError(msg)
        index = best_match_index
    else:
        pytype = type(datum)
        most_fields = -1

        # All of Python's floating point values are doubles, so to
        # avoid loss of precision, we should always prefer 'double'
        # if we are forced to choose between float and double.
        #
        # If 'double' comes before 'float' in the union, then we'll immediately
        # choose it, and don't need to worry. But if 'float' comes before
        # 'double', we don't want to pick it.
        #
        # So, if we ever see 'float', we skim through the rest of the options,
        # just to see if 'double' is a possibility, because we'd prefer it.
        could_be_float = False

        for index, candidate in enumerate(schema):
            if could_be_float:
                if extract_record_type(candidate) == "double":
                    best_match_index = index
                    break
                else:
                    # Nothing except "double" is even worth considering.
                    continue

            if _validate(
                datum,
                candidate,
                named_schemas,
                raise_errors=False,
                field="",
                options=options,
            ):
                record_type = extract_record_type(candidate)
                if record_type == "record":
                    logical_type = extract_logical_type(candidate)
                    if logical_type:
                        prepare = LOGICAL_WRITERS.get(logical_type)
                        if prepare:
                            datum = prepare(datum, candidate)

                    candidate_fields = set(f["name"] for f in candidate["fields"])
                    datum_fields = set(datum)
                    fields = len(candidate_fields.intersection(datum_fields))
                    if fields > most_fields:
                        best_match_index = index
                        most_fields = fields
                elif record_type == "float":
                    best_match_index = index
                    # Continue in the loop, because it's possible that there's
                    # another candidate which has record type 'double'
                    could_be_float = True
                else:
                    best_match_index = index
                    break
        if best_match_index == -1:
            field = f"on field {fname}" if fname else ""
            raise ValueError(
                f"{repr(datum)} (type {pytype}) does not match {schema} {field}"
            )
        index = best_match_index

    # write data
    # TODO: There should be a way to give just the index
    encoder.write_index(index, schema[index])
    write_data(encoder, datum, schema[index], named_schemas, fname, options)
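

# For illustration (per the Avro spec): with the union schema
# ["null", "string"], the value None encodes as the single byte 0x00
# (index 0; null itself adds nothing), while the value "a" encodes as
# 02 02 61 (index 1, then the string length 1 and the byte for "a").
# With tuple notation, e.g. ("string", "a"), the branch is picked by name
# instead of by validating the datum against each candidate.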


def write_record(encoder, datum, schema, named_schemas, fname, options):
    """A record is encoded by encoding the values of its fields in the order
    that they are declared. In other words, a record is encoded as just the
    concatenation of the encodings of its fields. Field values are encoded per
    their schema."""
    extras = set(datum) - set(field["name"] for field in schema["fields"])
    if (options.get("strict") or options.get("strict_allow_default")) and extras:
        raise ValueError(
            f'record contains more fields than the schema specifies: {", ".join(extras)}'
        )
    for field in schema["fields"]:
        name = field["name"]
        field_type = field["type"]
        if name not in datum:
            if options.get("strict") or (
                options.get("strict_allow_default") and "default" not in field
            ):
                raise ValueError(
                    f"Field {name} is specified in the schema but missing from the record"
                )
            elif "default" not in field and "null" not in field_type:
                raise ValueError(f"no value and no default for {name}")
        datum_value = datum.get(name, field.get("default"))
        if field_type == "float" or field_type == "double":
            # Handle float values like "NaN"
            datum_value = float(datum_value)
        write_data(
            encoder,
            datum_value,
            field_type,
            named_schemas,
            name,
            options,
        )
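

# For illustration (per the Avro spec): a record with fields
# {"name": "a", "type": "long"} and {"name": "b", "type": "string"} and the
# datum {"a": 27, "b": "foo"} is written as the concatenation of its field
# encodings: 36 (long 27) followed by 06 66 6f 6f (string "foo").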


WRITERS = {
    "null": write_null,
    "boolean": write_boolean,
    "string": write_utf8,
    "int": write_int,
    "long": write_long,
    "float": write_float,
    "double": write_double,
    "bytes": write_bytes,
    "fixed": write_fixed,
    "enum": write_enum,
    "array": write_array,
    "map": write_map,
    "union": write_union,
    "error_union": write_union,
    "record": write_record,
    "error": write_record,
}


def write_data(encoder, datum, schema, named_schemas, fname, options):
    """Write a datum to the output stream.

    Parameters
    ----------
    encoder: encoder
        Type of encoder (e.g. binary or json)
    datum: object
        Data to write
    schema: dict
        Schema to use
    named_schemas: dict
        Mapping of fullname to schema definition
    """
    record_type = extract_record_type(schema)
    logical_type = extract_logical_type(schema)

    fn = WRITERS.get(record_type)
    if fn:
        if logical_type:
            prepare = LOGICAL_WRITERS.get(logical_type)
            if prepare:
                datum = prepare(datum, schema)
        try:
            return fn(encoder, datum, schema, named_schemas, fname, options)
        except TypeError as ex:
            if fname:
                raise TypeError(f"{ex} on field {fname}")
            raise
    else:
        return write_data(
            encoder, datum, named_schemas[record_type], named_schemas, "", options
        )


def write_header(encoder, metadata, sync_marker):
    header = {
        "magic": MAGIC,
        "meta": {key: value.encode() for key, value in metadata.items()},
        "sync": sync_marker,
    }
    write_data(encoder, header, HEADER_SCHEMA, {}, "", {})
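

# For reference (per the Avro container-file spec): the header is the 4-byte
# magic (b"Obj" followed by the version byte 0x01), a map of metadata (this
# writer always stores "avro.schema" and "avro.codec"), and a 16-byte sync
# marker that is repeated after every data block.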


def null_write_block(encoder, block_bytes, compression_level):
    """Write block in "null" codec."""
    encoder.write_long(len(block_bytes))
    encoder._fo.write(block_bytes)


def deflate_write_block(encoder, block_bytes, compression_level):
    """Write block in "deflate" codec."""
    # The first two characters and last character are zlib
    # wrappers around deflate data.
    if compression_level is not None:
        data = zlib.compress(block_bytes, compression_level)[2:-1]
    else:
        data = zlib.compress(block_bytes)[2:-1]
    encoder.write_long(len(data))
    encoder._fo.write(data)


def bzip2_write_block(encoder, block_bytes, compression_level):
    """Write block in "bzip2" codec."""
    data = bz2.compress(block_bytes)
    encoder.write_long(len(data))
    encoder._fo.write(data)


def xz_write_block(encoder, block_bytes, compression_level):
    """Write block in "xz" codec."""
    data = lzma.compress(block_bytes)
    encoder.write_long(len(data))
    encoder._fo.write(data)


BLOCK_WRITERS = {
    "null": null_write_block,
    "deflate": deflate_write_block,
    "bzip2": bzip2_write_block,
    "xz": xz_write_block,
}
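

# Note: each data block written by Writer.dump() below consists of a long
# count of records, then the (possibly compressed) serialized records as
# produced by one of these block writers (which prefix the data with its
# byte length), then the file's 16-byte sync marker.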


def _missing_codec_lib(codec, library):
    def missing(encoder, block_bytes, compression_level):
        raise ValueError(
            f"{codec} codec is supported but you need to install {library}"
        )

    return missing


def snappy_write_block(encoder, block_bytes, compression_level):
    """Write block in "snappy" codec."""
    data = snappy.compress(block_bytes)
    encoder.write_long(len(data) + 4)  # for CRC
    encoder._fo.write(data)
    encoder.write_crc32(block_bytes)


try:
    import snappy
except ImportError:
    BLOCK_WRITERS["snappy"] = _missing_codec_lib("snappy", "python-snappy")
else:
    BLOCK_WRITERS["snappy"] = snappy_write_block


def zstandard_write_block(encoder, block_bytes, compression_level):
    """Write block in "zstandard" codec."""
    if compression_level is not None:
        data = zstd.ZstdCompressor(level=compression_level).compress(block_bytes)
    else:
        data = zstd.ZstdCompressor().compress(block_bytes)
    encoder.write_long(len(data))
    encoder._fo.write(data)


try:
    import zstandard as zstd
except ImportError:
    BLOCK_WRITERS["zstandard"] = _missing_codec_lib("zstandard", "zstandard")
else:
    BLOCK_WRITERS["zstandard"] = zstandard_write_block


def lz4_write_block(encoder, block_bytes, compression_level):
    """Write block in "lz4" codec."""
    data = lz4.block.compress(block_bytes)
    encoder.write_long(len(data))
    encoder._fo.write(data)


try:
    import lz4.block
except ImportError:
    BLOCK_WRITERS["lz4"] = _missing_codec_lib("lz4", "lz4")
else:
    BLOCK_WRITERS["lz4"] = lz4_write_block


class GenericWriter(ABC):
    def __init__(self, schema, metadata=None, validator=None, options={}):
        self._named_schemas = {}
        self.validate_fn = _validate if validator else None
        self.metadata = metadata or {}
        self.options = options

        # A schema of None is allowed when appending and when doing so the
        # self.schema will be updated later
        if schema is not None:
            self.schema = parse_schema(schema, self._named_schemas)

        if isinstance(schema, dict):
            schema = {
                key: value
                for key, value in schema.items()
                if key not in ("__fastavro_parsed", "__named_schemas")
            }
        elif isinstance(schema, list):
            schemas = []
            for s in schema:
                if isinstance(s, dict):
                    schemas.append(
                        {
                            key: value
                            for key, value in s.items()
                            if key
                            not in (
                                "__fastavro_parsed",
                                "__named_schemas",
                            )
                        }
                    )
                else:
                    schemas.append(s)
            schema = schemas

        self.metadata["avro.schema"] = json.dumps(schema)

    @abstractmethod
    def write(self, record):
        pass

    @abstractmethod
    def flush(self):
        pass


class Writer(GenericWriter):
    def __init__(
        self,
        fo: Union[IO, BinaryEncoder],
        schema: Schema,
        codec: str = "null",
        sync_interval: int = 1000 * SYNC_SIZE,
        metadata: Optional[Dict[str, str]] = None,
        validator: bool = False,
        sync_marker: bytes = b"",
        compression_level: Optional[int] = None,
        options: Dict[str, bool] = {},
    ):
        super().__init__(schema, metadata, validator, options)

        self.metadata["avro.codec"] = codec
        if isinstance(fo, BinaryEncoder):
            self.encoder = fo
        else:
            self.encoder = BinaryEncoder(fo)
        self.io = BinaryEncoder(BytesIO())
        self.block_count = 0
        self.sync_interval = sync_interval
        self.compression_level = compression_level

        if _is_appendable(self.encoder._fo):
            # Seek to the beginning to read the header
            self.encoder._fo.seek(0)
            avro_reader = reader(self.encoder._fo)
            header = avro_reader._header

            self._named_schemas = {}
            self.schema = parse_schema(avro_reader.writer_schema, self._named_schemas)

            codec = avro_reader.metadata.get("avro.codec", "null")

            self.sync_marker = header["sync"]

            # Seek to the end of the file
            self.encoder._fo.seek(0, 2)

            self.block_writer = BLOCK_WRITERS[codec]
        else:
            self.sync_marker = sync_marker or urandom(SYNC_SIZE)

            try:
                self.block_writer = BLOCK_WRITERS[codec]
            except KeyError:
                raise ValueError(f"unrecognized codec: {codec}")

            write_header(self.encoder, self.metadata, self.sync_marker)

    def dump(self):
        self.encoder.write_long(self.block_count)
        self.block_writer(self.encoder, self.io._fo.getvalue(), self.compression_level)
        self.encoder._fo.write(self.sync_marker)
        self.io._fo.truncate(0)
        self.io._fo.seek(0, SEEK_SET)
        self.block_count = 0

    def write(self, record):
        if self.validate_fn:
            self.validate_fn(
                record, self.schema, self._named_schemas, "", True, self.options
            )
        write_data(self.io, record, self.schema, self._named_schemas, "", self.options)
        self.block_count += 1
        if self.io._fo.tell() >= self.sync_interval:
            self.dump()

    def write_block(self, block):
        # Clear existing block if there are any records pending
        if self.io._fo.tell() or self.block_count > 0:
            self.dump()
        self.encoder.write_long(block.num_records)
        self.block_writer(self.encoder, block.bytes_.getvalue(), self.compression_level)
        self.encoder._fo.write(self.sync_marker)

    def flush(self):
        if self.io._fo.tell() or self.block_count > 0:
            self.dump()
        self.encoder._fo.flush()


class JSONWriter(GenericWriter):
    def __init__(
        self,
        fo: AvroJSONEncoder,
        schema: Schema,
        codec: str = "null",
        sync_interval: int = 1000 * SYNC_SIZE,
        metadata: Optional[Dict[str, str]] = None,
        validator: bool = False,
        sync_marker: bytes = b"",
        codec_compression_level: Optional[int] = None,
        options: Dict[str, bool] = {},
    ):
        super().__init__(schema, metadata, validator, options)

        self.encoder = fo
        self.encoder.configure(self.schema, self._named_schemas)

    def write(self, record):
        if self.validate_fn:
            self.validate_fn(
                record, self.schema, self._named_schemas, "", True, self.options
            )
        write_data(
            self.encoder, record, self.schema, self._named_schemas, "", self.options
        )

    def flush(self):
        self.encoder.flush()


def writer(
    fo: Union[IO, AvroJSONEncoder],
    schema: Schema,
    records: Iterable[Any],
    codec: str = "null",
    sync_interval: int = 1000 * SYNC_SIZE,
    metadata: Optional[Dict[str, str]] = None,
    validator: bool = False,
    sync_marker: bytes = b"",
    codec_compression_level: Optional[int] = None,
    *,
    strict: bool = False,
    strict_allow_default: bool = False,
    disable_tuple_notation: bool = False,
):
    """Write records to fo (stream) according to schema

    Parameters
    ----------
    fo
        Output stream
    schema
        Writer schema
    records
        Records to write. This is commonly a list of the dictionary
        representation of the records, but it can be any iterable
    codec
        Compression codec; can be 'null', 'deflate', 'bzip2', or 'xz' out of
        the box, or 'snappy', 'zstandard', or 'lz4' if the corresponding
        libraries are installed
    sync_interval
        Size of sync interval
    metadata
        Header metadata
    validator
        If true, validation will be done on the records
    sync_marker
        A byte string used as the avro sync marker. If not provided, a random
        byte string will be used.
    codec_compression_level
        Compression level to use with the specified codec (if the codec
        supports it)
    strict
        If set to True, an error will be raised if records do not contain
        exactly the same fields that the schema states
    strict_allow_default
        If set to True, an error will be raised if records do not contain
        exactly the same fields that the schema states unless it is a missing
        field that has a default value in the schema
    disable_tuple_notation
        If set to True, tuples will not be treated as a special case. Therefore,
        using a tuple to indicate the type of a record will not work


    Example::

        from fastavro import writer, parse_schema

        schema = {
            'doc': 'A weather reading.',
            'name': 'Weather',
            'namespace': 'test',
            'type': 'record',
            'fields': [
                {'name': 'station', 'type': 'string'},
                {'name': 'time', 'type': 'long'},
                {'name': 'temp', 'type': 'int'},
            ],
        }
        parsed_schema = parse_schema(schema)

        records = [
            {u'station': u'011990-99999', u'temp': 0, u'time': 1433269388},
            {u'station': u'011990-99999', u'temp': 22, u'time': 1433270389},
            {u'station': u'011990-99999', u'temp': -11, u'time': 1433273379},
            {u'station': u'012650-99999', u'temp': 111, u'time': 1433275478},
        ]

        with open('weather.avro', 'wb') as out:
            writer(out, parsed_schema, records)

    The `fo` argument is a file-like object so another common example usage
    would use an `io.BytesIO` object like so::

        from io import BytesIO
        from fastavro import writer

        fo = BytesIO()
        writer(fo, schema, records)

    Given an existing avro file, it's possible to append to it by re-opening
    the file in `a+b` mode. If the file is only opened in `ab` mode, we aren't
    able to read some of the existing header information and an error will be
    raised. For example::

        # Write initial records
        with open('weather.avro', 'wb') as out:
            writer(out, parsed_schema, records)

        # Write some more records
        with open('weather.avro', 'a+b') as out:
            writer(out, None, more_records)

    Note: When appending, any schema provided will be ignored since the schema
    in the avro file will be re-used. Therefore it is convenient to just use
    None as the schema.
    """
    # Sanity check that records is not a single dictionary (as that is a common
    # mistake and the exception that gets raised is not helpful)
    if isinstance(records, dict):
        raise ValueError('"records" argument should be an iterable, not dict')

    output: Union[JSONWriter, Writer]
    if isinstance(fo, AvroJSONEncoder):
        output = JSONWriter(
            fo,
            schema,
            codec,
            sync_interval,
            metadata,
            validator,
            sync_marker,
            codec_compression_level,
            options={
                "strict": strict,
                "strict_allow_default": strict_allow_default,
                "disable_tuple_notation": disable_tuple_notation,
            },
        )
    else:
        output = Writer(
            BinaryEncoder(fo),
            schema,
            codec,
            sync_interval,
            metadata,
            validator,
            sync_marker,
            codec_compression_level,
            options={
                "strict": strict,
                "strict_allow_default": strict_allow_default,
                "disable_tuple_notation": disable_tuple_notation,
            },
        )

    for record in records:
        output.write(record)
    output.flush()


def schemaless_writer(
    fo: IO,
    schema: Schema,
    record: Any,
    *,
    strict: bool = False,
    strict_allow_default: bool = False,
    disable_tuple_notation: bool = False,
):
    """Write a single record without the schema or header information

    Parameters
    ----------
    fo
        Output file
    schema
        Schema
    record
        Record to write
    strict
        If set to True, an error will be raised if records do not contain
        exactly the same fields that the schema states
    strict_allow_default
        If set to True, an error will be raised if records do not contain
        exactly the same fields that the schema states unless it is a missing
        field that has a default value in the schema
    disable_tuple_notation
        If set to True, tuples will not be treated as a special case. Therefore,
        using a tuple to indicate the type of a record will not work


    Example::

        parsed_schema = fastavro.parse_schema(schema)
        with open('file', 'wb') as fp:
            fastavro.schemaless_writer(fp, parsed_schema, record)

    Note: The ``schemaless_writer`` can only write a single record.
    """
    named_schemas: NamedSchemas = {}
    schema = parse_schema(schema, named_schemas)

    encoder = BinaryEncoder(fo)
    write_data(
        encoder,
        record,
        schema,
        named_schemas,
        "",
        {
            "strict": strict,
            "strict_allow_default": strict_allow_default,
            "disable_tuple_notation": disable_tuple_notation,
        },
    )
    encoder.flush()
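

# A record written with schemaless_writer has no header or embedded schema,
# so it is typically read back with fastavro.schemaless_reader using the same
# (or a compatible) schema, e.g.:
#
#     with open('file', 'rb') as fp:
#         record = fastavro.schemaless_reader(fp, parsed_schema)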