1"""
2Module contains tools for processing Stata files into DataFrames
4The StataReader below was originally written by Joe Presbrey as part of PyDTA.
5It has been extended and improved by Skipper Seabold from the Statsmodels
6project who also developed the StataWriter and was finally added to pandas in
7a once again improved version.
9You can find more information on http://presbrey.mit.edu/PyDTA and
10https://www.statsmodels.org/devel/
11"""
12from __future__ import annotations
14from collections import abc
15from datetime import (
16 datetime,
17 timedelta,
18)
19from io import BytesIO
20import os
21import struct
22import sys
23from typing import (
24 IO,
25 TYPE_CHECKING,
26 AnyStr,
27 Callable,
28 Final,
29 cast,
30)
31import warnings
33import numpy as np
35from pandas._libs import lib
36from pandas._libs.lib import infer_dtype
37from pandas._libs.writers import max_len_string_array
38from pandas.errors import (
39 CategoricalConversionWarning,
40 InvalidColumnName,
41 PossiblePrecisionLoss,
42 ValueLabelTypeMismatch,
43)
44from pandas.util._decorators import (
45 Appender,
46 doc,
47)
48from pandas.util._exceptions import find_stack_level
50from pandas.core.dtypes.base import ExtensionDtype
51from pandas.core.dtypes.common import (
52 ensure_object,
53 is_numeric_dtype,
54 is_string_dtype,
55)
56from pandas.core.dtypes.dtypes import CategoricalDtype
58from pandas import (
59 Categorical,
60 DatetimeIndex,
61 NaT,
62 Timestamp,
63 isna,
64 to_datetime,
65 to_timedelta,
66)
67from pandas.core.frame import DataFrame
68from pandas.core.indexes.base import Index
69from pandas.core.indexes.range import RangeIndex
70from pandas.core.series import Series
71from pandas.core.shared_docs import _shared_docs
73from pandas.io.common import get_handle
75if TYPE_CHECKING:
76 from collections.abc import (
77 Hashable,
78 Sequence,
79 )
80 from types import TracebackType
81 from typing import Literal
83 from pandas._typing import (
84 CompressionOptions,
85 FilePath,
86 ReadBuffer,
87 Self,
88 StorageOptions,
89 WriteBuffer,
90 )
92_version_error = (
93 "Version of given Stata file is {version}. pandas supports importing "
94 "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
95 "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16),"
96 "and 119 (Stata 15/16, over 32,767 variables)."
97)

_statafile_processing_params1 = """\
convert_dates : bool, default True
    Convert date variables to DataFrame time values.
convert_categoricals : bool, default True
    Read value labels and convert columns to Categorical/Factor variables."""

_statafile_processing_params2 = """\
index_col : str, optional
    Column to set as index.
convert_missing : bool, default False
    Flag indicating whether to convert missing values to their Stata
    representations. If False, missing values are replaced with nan.
    If True, columns containing missing values are returned with
    object data types and missing values are represented by
    StataMissingValue objects.
preserve_dtypes : bool, default True
    Preserve Stata datatypes. If False, numeric data are upcast to pandas
    default types for foreign data (float64 or int64).
columns : list or None
    Columns to retain. Columns will be returned in the given order. None
    returns all columns.
order_categoricals : bool, default True
    Flag indicating whether converted categorical data are ordered."""

_chunksize_params = """\
chunksize : int, default None
    Return a StataReader object for iteration; returns chunks with
    the given number of lines."""

_iterator_params = """\
iterator : bool, default False
    Return a StataReader object."""

_reader_notes = """\
Notes
-----
Categorical variables read through an iterator may not have the same
categories and dtype. This occurs when a variable stored in a DTA
file is associated with an incomplete set of value labels that only
label a strict subset of the values."""

_read_stata_doc = f"""
Read Stata file into DataFrame.

Parameters
----------
filepath_or_buffer : str, path object or file-like object
    Any valid string path is acceptable. The string could be a URL. Valid
    URL schemes include http, ftp, s3, and file. For file URLs, a host is
    expected. A local file could be: ``file://localhost/path/to/table.dta``.

    If you want to pass in a path object, pandas accepts any ``os.PathLike``.

    By file-like object, we refer to objects with a ``read()`` method,
    such as a file handle (e.g. via builtin ``open`` function)
    or ``StringIO``.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_iterator_params}
{_shared_docs["decompression_options"] % "filepath_or_buffer"}
{_shared_docs["storage_options"]}

Returns
-------
DataFrame or pandas.api.typing.StataReader

See Also
--------
io.stata.StataReader : Low-level reader for Stata data files.
DataFrame.to_stata : Export Stata data files.

{_reader_notes}

Examples
--------

Creating a dummy stata file for this example

>>> df = pd.DataFrame({{'animal': ['falcon', 'parrot', 'falcon', 'parrot'],
...                     'speed': [350, 18, 361, 15]}})  # doctest: +SKIP
>>> df.to_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file:

>>> df = pd.read_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file in 10,000 line chunks:

>>> values = np.random.randint(0, 10, size=(20_000, 1), dtype="uint8")  # doctest: +SKIP
>>> df = pd.DataFrame(values, columns=["i"])  # doctest: +SKIP
>>> df.to_stata('filename.dta')  # doctest: +SKIP

>>> with pd.read_stata('filename.dta', chunksize=10000) as itr:  # doctest: +SKIP
...     for chunk in itr:
...         # Operate on a single chunk, e.g., chunk.mean()
...         pass  # doctest: +SKIP
"""

_read_method_doc = f"""\
Reads observations from a Stata file, converting them into a DataFrame

Parameters
----------
nrows : int
    Number of lines to read from data file, if None read whole file.
{_statafile_processing_params1}
{_statafile_processing_params2}

Returns
-------
DataFrame
"""

_stata_reader_doc = f"""\
Class for reading Stata dta files.

Parameters
----------
path_or_buf : path (string), buffer or path object
    string, path object (pathlib.Path or py._path.local.LocalPath) or object
    implementing a binary read() function.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_shared_docs["decompression_options"]}
{_shared_docs["storage_options"]}

{_reader_notes}
"""


_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]


stata_epoch: Final = datetime(1960, 1, 1)
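

# Illustrative sketch (not part of pandas): every Stata Internal Format (SIF)
# date is an integer offset from this 1960-01-01 epoch. For the daily format
# %td the offset is a day count, so converting a hypothetical raw value by
# hand looks like:
#
#     sif_days = 21915                        # hypothetical %td value
#     stata_epoch + timedelta(days=sif_days)  # -> datetime(2020, 1, 1)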


def _stata_elapsed_date_to_datetime_vec(dates: Series, fmt: str) -> Series:
    """
    Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        The Stata Internal Format date to convert to datetime according to fmt
    fmt : str
        The format to convert to. Can be tc, td, tw, tm, tq, th, or ty

    Returns
    -------
    converted : Series
        The converted dates

    Examples
    --------
    >>> dates = pd.Series([52])
    >>> _stata_elapsed_date_to_datetime_vec(dates, "%tw")
    0   1961-01-01
    dtype: datetime64[ns]

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1
    yearly date - ty
        years since 0000
    """
    MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year
    MAX_DAY_DELTA = (Timestamp.max - datetime(1960, 1, 1)).days
    MIN_DAY_DELTA = (Timestamp.min - datetime(1960, 1, 1)).days
    MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000
    MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000

    def convert_year_month_safe(year, month) -> Series:
        """
        Convert year and month to datetimes, using pandas vectorized versions
        when the date range falls within the range supported by pandas.
        Otherwise it falls back to a slower but more robust method
        using datetime.
        """
        if year.max() < MAX_YEAR and year.min() > MIN_YEAR:
            return to_datetime(100 * year + month, format="%Y%m")
        else:
            index = getattr(year, "index", None)
            return Series([datetime(y, m, 1) for y, m in zip(year, month)], index=index)

    def convert_year_days_safe(year, days) -> Series:
        """
        Converts year (e.g. 1999) and days since the start of the year to a
        datetime or datetime64 Series
        """
        if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR:
            return to_datetime(year, format="%Y") + to_timedelta(days, unit="d")
        else:
            index = getattr(year, "index", None)
            value = [
                datetime(y, 1, 1) + timedelta(days=int(d)) for y, d in zip(year, days)
            ]
            return Series(value, index=index)

    def convert_delta_safe(base, deltas, unit) -> Series:
        """
        Convert base dates and deltas to datetimes, using pandas vectorized
        versions if the deltas satisfy restrictions required to be expressed
        as dates in pandas.
        """
        index = getattr(deltas, "index", None)
        if unit == "d":
            if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA:
                values = [base + timedelta(days=int(d)) for d in deltas]
                return Series(values, index=index)
        elif unit == "ms":
            if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA:
                values = [
                    base + timedelta(microseconds=(int(d) * 1000)) for d in deltas
                ]
                return Series(values, index=index)
        else:
            raise ValueError("format not understood")
        base = to_datetime(base)
        deltas = to_timedelta(deltas, unit=unit)
        return base + deltas

    # TODO(non-nano): If/when pandas supports more than datetime64[ns], this
    # should be improved to use correct range, e.g. datetime[Y] for yearly
    bad_locs = np.isnan(dates)
    has_bad_values = False
    if bad_locs.any():
        has_bad_values = True
        dates._values[bad_locs] = 1.0  # Replace with NaT
    dates = dates.astype(np.int64)

    if fmt.startswith(("%tc", "tc")):  # Delta ms relative to base
        base = stata_epoch
        ms = dates
        conv_dates = convert_delta_safe(base, ms, "ms")
    elif fmt.startswith(("%tC", "tC")):
        warnings.warn(
            "Encountered %tC format. Leaving in Stata Internal Format.",
            stacklevel=find_stack_level(),
        )
        conv_dates = Series(dates, dtype=object)
        if has_bad_values:
            conv_dates[bad_locs] = NaT
        return conv_dates
    # Delta days relative to base
    elif fmt.startswith(("%td", "td", "%d", "d")):
        base = stata_epoch
        days = dates
        conv_dates = convert_delta_safe(base, days, "d")
    # does not count leap days - 7 days is a week.
    # 52nd week may have more than 7 days
    elif fmt.startswith(("%tw", "tw")):
        year = stata_epoch.year + dates // 52
        days = (dates % 52) * 7
        conv_dates = convert_year_days_safe(year, days)
    elif fmt.startswith(("%tm", "tm")):  # Delta months relative to base
        year = stata_epoch.year + dates // 12
        month = (dates % 12) + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
        year = stata_epoch.year + dates // 4
        quarter_month = (dates % 4) * 3 + 1
        conv_dates = convert_year_month_safe(year, quarter_month)
    elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
        year = stata_epoch.year + dates // 2
        month = (dates % 2) * 6 + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
        year = dates
        first_month = np.ones_like(dates)
        conv_dates = convert_year_month_safe(year, first_month)
    else:
        raise ValueError(f"Date fmt {fmt} not understood")

    if has_bad_values:  # Restore NaT for bad values
        conv_dates[bad_locs] = NaT

    return conv_dates
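

# Illustrative sketch (not part of pandas): converting the hypothetical %tq
# value 4 by hand mirrors the arithmetic above -- 1960 + 4 // 4 = 1961 and
# (4 % 4) * 3 + 1 = 1, i.e. 1961q1 starts on 1961-01-01:
#
#     _stata_elapsed_date_to_datetime_vec(pd.Series([4]), "%tq")
#     # 0   1961-01-01
#     # dtype: datetime64[ns]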


def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
    """
    Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        Series or array containing datetime or datetime64[ns] to
        convert to the Stata Internal Format given by fmt
    fmt : str
        The format to convert to. Can be tc, td, tw, tm, tq, th, or ty
    """
    index = dates.index
    NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000
    US_PER_DAY = NS_PER_DAY / 1000

    def parse_dates_safe(
        dates: Series, delta: bool = False, year: bool = False, days: bool = False
    ):
        d = {}
        if lib.is_np_dtype(dates.dtype, "M"):
            if delta:
                time_delta = dates - Timestamp(stata_epoch).as_unit("ns")
                d["delta"] = time_delta._values.view(np.int64) // 1000  # microseconds
            if days or year:
                date_index = DatetimeIndex(dates)
                d["year"] = date_index._data.year
                d["month"] = date_index._data.month
            if days:
                days_in_ns = dates._values.view(np.int64) - to_datetime(
                    d["year"], format="%Y"
                )._values.view(np.int64)
                d["days"] = days_in_ns // NS_PER_DAY

        elif infer_dtype(dates, skipna=False) == "datetime":
            if delta:
                delta = dates._values - stata_epoch

                def f(x: timedelta) -> float:
                    return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds

                v = np.vectorize(f)
                d["delta"] = v(delta)
            if year:
                year_month = dates.apply(lambda x: 100 * x.year + x.month)
                d["year"] = year_month._values // 100
                d["month"] = year_month._values - d["year"] * 100
            if days:

                def g(x: datetime) -> int:
                    return (x - datetime(x.year, 1, 1)).days

                v = np.vectorize(g)
                d["days"] = v(dates)
        else:
            raise ValueError(
                "Columns containing dates must contain either "
                "datetime64, datetime or null values."
            )

        return DataFrame(d, index=index)

    bad_loc = isna(dates)
    index = dates.index
    if bad_loc.any():
        if lib.is_np_dtype(dates.dtype, "M"):
            dates._values[bad_loc] = to_datetime(stata_epoch)
        else:
            dates._values[bad_loc] = stata_epoch

    if fmt in ["%tc", "tc"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta / 1000
    elif fmt in ["%tC", "tC"]:
        warnings.warn(
            "Stata Internal Format tC not supported.",
            stacklevel=find_stack_level(),
        )
        conv_dates = dates
    elif fmt in ["%td", "td"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta // US_PER_DAY
    elif fmt in ["%tw", "tw"]:
        d = parse_dates_safe(dates, year=True, days=True)
        conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7
    elif fmt in ["%tm", "tm"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1
    elif fmt in ["%tq", "tq"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3
    elif fmt in ["%th", "th"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int)
    elif fmt in ["%ty", "ty"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = d.year
    else:
        raise ValueError(f"Format {fmt} is not a known Stata date format")

    conv_dates = Series(conv_dates, dtype=np.float64, copy=False)
    missing_value = struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
    conv_dates[bad_loc] = missing_value

    return Series(conv_dates, index=index, copy=False)
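

# Illustrative sketch (not part of pandas): the two vectorized helpers above
# are inverses for the supported formats, so a round trip through %tm recovers
# the first day of the month:
#
#     sif = _datetime_to_stata_elapsed_vec(
#         pd.Series([pd.Timestamp("1961-02-15")]), "%tm"
#     )
#     # sif.iloc[0] == 13.0, i.e. 13 months after 1960m1
#     _stata_elapsed_date_to_datetime_vec(sif, "%tm")
#     # 0   1961-02-01  (day-of-month information is not stored in %tm)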


excessive_string_length_error: Final = """
Fixed width strings in Stata .dta files are limited to 244 (or fewer)
characters. Column '{0}' does not satisfy this restriction. Use the
'version=117' parameter to write the newer (Stata 13 and later) format.
"""


precision_loss_doc: Final = """
Column converted from {0} to {1}, and some data are outside of the lossless
conversion range. This may result in a loss of precision in the saved data.
"""


value_label_mismatch_doc: Final = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""


invalid_name_doc: Final = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:

    {0}

If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only
alphanumerics and underscores, no Stata reserved words)
"""


categorical_conversion_warning: Final = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in a categorical variable with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read the dataset without an iterator, or manually convert categorical data by
setting ``convert_categoricals`` to False and then accessing the variable
labels through the value_labels method of the reader.
"""


def _cast_to_stata_types(data: DataFrame) -> DataFrame:
    """
    Checks the dtypes of the columns of a pandas DataFrame for
    compatibility with the data types and ranges supported by Stata, and
    converts if necessary.

    Parameters
    ----------
    data : DataFrame
        The DataFrame to check and convert

    Notes
    -----
    Numeric columns in Stata must be one of int8, int16, int32, float32 or
    float64, with some additional value restrictions. int8 and int16 columns
    are checked for violations of the value restrictions and upcast if needed.
    int64 data is not usable in Stata, and so it is downcast to int32 whenever
    the values are in the int32 range, and sidecast to float64 when larger than
    this range. If the int64 values are outside of the range of those
    perfectly representable as float64 values, a warning is raised.

    bool columns are cast to int8. uint columns are converted to int of the
    same size if there is no loss in precision, otherwise are upcast to a
    larger type. uint64 is currently not supported since it is converted to
    object in a DataFrame.
    """
    ws = ""
    # original, if small, if large
    conversion_data: tuple[
        tuple[type, type, type],
        tuple[type, type, type],
        tuple[type, type, type],
        tuple[type, type, type],
        tuple[type, type, type],
    ] = (
        (np.bool_, np.int8, np.int8),
        (np.uint8, np.int8, np.int16),
        (np.uint16, np.int16, np.int32),
        (np.uint32, np.int32, np.int64),
        (np.uint64, np.int64, np.float64),
    )

    float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
    float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]

    for col in data:
        # Cast from unsupported types to supported types
        is_nullable_int = (
            isinstance(data[col].dtype, ExtensionDtype)
            and data[col].dtype.kind in "iub"
        )
        # We need to find orig_missing before altering data below
        orig_missing = data[col].isna()
        if is_nullable_int:
            fv = 0 if data[col].dtype.kind in "iu" else False
            # Replace with NumPy-compatible column
            data[col] = data[col].fillna(fv).astype(data[col].dtype.numpy_dtype)
        elif isinstance(data[col].dtype, ExtensionDtype):
            if getattr(data[col].dtype, "numpy_dtype", None) is not None:
                data[col] = data[col].astype(data[col].dtype.numpy_dtype)
            elif is_string_dtype(data[col].dtype):
                data[col] = data[col].astype("object")

        dtype = data[col].dtype
        empty_df = data.shape[0] == 0
        for c_data in conversion_data:
            if dtype == c_data[0]:
                if empty_df or data[col].max() <= np.iinfo(c_data[1]).max:
                    dtype = c_data[1]
                else:
                    dtype = c_data[2]
                    if c_data[2] == np.int64:  # Warn if necessary
                        if data[col].max() >= 2**53:
                            ws = precision_loss_doc.format("uint64", "float64")

                data[col] = data[col].astype(dtype)

        # Check values and upcast if necessary

        if dtype == np.int8 and not empty_df:
            if data[col].max() > 100 or data[col].min() < -127:
                data[col] = data[col].astype(np.int16)
        elif dtype == np.int16 and not empty_df:
            if data[col].max() > 32740 or data[col].min() < -32767:
                data[col] = data[col].astype(np.int32)
        elif dtype == np.int64:
            if empty_df or (
                data[col].max() <= 2147483620 and data[col].min() >= -2147483647
            ):
                data[col] = data[col].astype(np.int32)
            else:
                data[col] = data[col].astype(np.float64)
                if data[col].max() >= 2**53 or data[col].min() <= -(2**53):
                    ws = precision_loss_doc.format("int64", "float64")
        elif dtype in (np.float32, np.float64):
            if np.isinf(data[col]).any():
                raise ValueError(
                    f"Column {col} contains infinity or -infinity "
                    "which is outside the range supported by Stata."
                )
            value = data[col].max()
            if dtype == np.float32 and value > float32_max:
                data[col] = data[col].astype(np.float64)
            elif dtype == np.float64:
                if value > float64_max:
                    raise ValueError(
                        f"Column {col} has a maximum value ({value}) outside the range "
                        f"supported by Stata ({float64_max})"
                    )
        if is_nullable_int:
            if orig_missing.any():
                # Replace missing by Stata sentinel value
                sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
                data.loc[orig_missing, col] = sentinel

    if ws:
        warnings.warn(
            ws,
            PossiblePrecisionLoss,
            stacklevel=find_stack_level(),
        )

    return data
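

# Illustrative sketch (not part of pandas): uint8 is not a Stata type, so a
# hypothetical column with values up to 200 cannot stay int8 (max 127) and is
# upcast along the (np.uint8, np.int8, np.int16) row of conversion_data:
#
#     df = DataFrame({"x": np.array([0, 200], dtype=np.uint8)})
#     _cast_to_stata_types(df)["x"].dtype   # -> dtype('int16')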


class StataValueLabel:
    """
    Parse a categorical column and prepare formatted output

    Parameters
    ----------
    catarray : Series
        Categorical Series to encode
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self, catarray: Series, encoding: Literal["latin-1", "utf-8"] = "latin-1"
    ) -> None:
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")
        self.labname = catarray.name
        self._encoding = encoding
        categories = catarray.cat.categories
        self.value_labels = enumerate(categories)

        self._prepare_value_labels()

    def _prepare_value_labels(self) -> None:
        """Encode value labels."""

        self.text_len = 0
        self.txt: list[bytes] = []
        self.n = 0
        # Offsets (length of categories), converted to int32
        self.off = np.array([], dtype=np.int32)
        # Values, converted to int32
        self.val = np.array([], dtype=np.int32)
        self.len = 0

        # Compute lengths and setup lists of offsets and labels
        offsets: list[int] = []
        values: list[float] = []
        for vl in self.value_labels:
            category: str | bytes = vl[1]
            if not isinstance(category, str):
                category = str(category)
                warnings.warn(
                    value_label_mismatch_doc.format(self.labname),
                    ValueLabelTypeMismatch,
                    stacklevel=find_stack_level(),
                )
            category = category.encode(self._encoding)
            offsets.append(self.text_len)
            self.text_len += len(category) + 1  # +1 for the padding
            values.append(vl[0])
            self.txt.append(category)
            self.n += 1

        if self.text_len > 32000:
            raise ValueError(
                "Stata value labels for a single variable must "
                "have a combined length less than 32,000 characters."
            )

        # Ensure int32
        self.off = np.array(offsets, dtype=np.int32)
        self.val = np.array(values, dtype=np.int32)

        # Total length
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len

    def generate_value_label(self, byteorder: str) -> bytes:
        """
        Generate the binary representation of the value labels.

        Parameters
        ----------
        byteorder : str
            Byte order of the output

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """
        encoding = self._encoding
        bio = BytesIO()
        null_byte = b"\x00"

        # len
        bio.write(struct.pack(byteorder + "i", self.len))

        # labname
        labname = str(self.labname)[:32].encode(encoding)
        lab_len = 32 if encoding not in ("utf-8", "utf8") else 128
        labname = _pad_bytes(labname, lab_len + 1)
        bio.write(labname)

        # padding - 3 bytes
        for i in range(3):
            bio.write(struct.pack("c", null_byte))

        # value_label_table
        # n - int32
        bio.write(struct.pack(byteorder + "i", self.n))

        # textlen - int32
        bio.write(struct.pack(byteorder + "i", self.text_len))

        # off - int32 array (n elements)
        for offset in self.off:
            bio.write(struct.pack(byteorder + "i", offset))

        # val - int32 array (n elements)
        for value in self.val:
            bio.write(struct.pack(byteorder + "i", value))

        # txt - Text labels, null terminated
        for text in self.txt:
            bio.write(text + null_byte)

        return bio.getvalue()
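

# Illustrative sketch (not part of pandas): for a hypothetical categorical
# with categories ["no", "yes"], _prepare_value_labels produces offsets [0, 3]
# (each label is null-terminated, so "no" occupies 3 bytes), values [0, 1],
# text_len 7, and a table length of 4 + 4 + 4 * 2 + 4 * 2 + 7 = 31 bytes.
# generate_value_label then prepends the length field, the padded label name
# and 3 padding bytes (4 + 33 + 3 + 31 = 71 bytes total for latin-1):
#
#     svl = StataValueLabel(Series(["no", "yes"], dtype="category", name="flag"))
#     len(svl.generate_value_label("<"))   # -> 71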


class StataNonCatValueLabel(StataValueLabel):
    """
    Prepare formatted version of value labels

    Parameters
    ----------
    labname : str
        Value label name
    value_labels : dict
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ) -> None:
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding
        self.value_labels = sorted(  # type: ignore[assignment]
            value_labels.items(), key=lambda x: x[0]
        )
        self._prepare_value_labels()


class StataMissingValue:
    """
    An observation's missing value.

    Parameters
    ----------
    value : {int, float}
        The Stata missing value code

    Notes
    -----
    More information: <https://www.stata.com/help.cgi?missing>

    Integer missing values map the codes '.', '.a', ..., '.z' to the ranges
    101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
    2147483647 (for int32). Missing values for floating point data types are
    more complex, but the pattern is simple to discern from the following table.

    np.float32 missing values (float in Stata)
    0000007f    .
    0008007f    .a
    0010007f    .b
    ...
    00c0007f    .x
    00c8007f    .y
    00d0007f    .z

    np.float64 missing values (double in Stata)
    000000000000e07f    .
    000000000001e07f    .a
    000000000002e07f    .b
    ...
    000000000018e07f    .x
    000000000019e07f    .y
    00000000001ae07f    .z
    """

    # Construct a dictionary of missing values
    MISSING_VALUES: dict[float, str] = {}
    bases: Final = (101, 32741, 2147483621)
    for b in bases:
        # Conversion to long to avoid hash issues on 32 bit platforms #8968
        MISSING_VALUES[b] = "."
        for i in range(1, 27):
            MISSING_VALUES[i + b] = "." + chr(96 + i)

    float32_base: bytes = b"\x00\x00\x00\x7f"
    increment_32: int = struct.unpack("<i", b"\x00\x08\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<f", float32_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment_32
        float32_base = struct.pack("<i", int_value)

    float64_base: bytes = b"\x00\x00\x00\x00\x00\x00\xe0\x7f"
    increment_64 = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<d", float64_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment_64
        float64_base = struct.pack("q", int_value)

    BASE_MISSING_VALUES: Final = {
        "int8": 101,
        "int16": 32741,
        "int32": 2147483621,
        "float32": struct.unpack("<f", float32_base)[0],
        "float64": struct.unpack("<d", float64_base)[0],
    }

    def __init__(self, value: float) -> None:
        self._value = value
        # Conversion to int to avoid hash issues on 32 bit platforms #8968
        value = int(value) if value < 2147483648 else float(value)
        self._str = self.MISSING_VALUES[value]

    @property
    def string(self) -> str:
        """
        The Stata representation of the missing value: '.', '.a'..'.z'

        Returns
        -------
        str
            The representation of the missing value.
        """
        return self._str

    @property
    def value(self) -> float:
        """
        The binary representation of the missing value.

        Returns
        -------
        {int, float}
            The binary representation of the missing value.
        """
        return self._value

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return f"{type(self)}({self})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self.string == other.string
            and self.value == other.value
        )

    @classmethod
    def get_base_missing_value(cls, dtype: np.dtype) -> float:
        if dtype.type is np.int8:
            value = cls.BASE_MISSING_VALUES["int8"]
        elif dtype.type is np.int16:
            value = cls.BASE_MISSING_VALUES["int16"]
        elif dtype.type is np.int32:
            value = cls.BASE_MISSING_VALUES["int32"]
        elif dtype.type is np.float32:
            value = cls.BASE_MISSING_VALUES["float32"]
        elif dtype.type is np.float64:
            value = cls.BASE_MISSING_VALUES["float64"]
        else:
            raise ValueError("Unsupported dtype")
        return value
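

# Illustrative sketch (not part of pandas): Stata encodes '.a' for an int8
# column as 102, one above the base code 101 for the generic '.', so:
#
#     mv = StataMissingValue(102)
#     str(mv)    # -> '.a'
#     mv.value   # -> 102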


class StataParser:
    def __init__(self) -> None:
        # type          code.
        # --------------------
        # str1        1 = 0x01
        # str2        2 = 0x02
        # ...
        # str244    244 = 0xf4
        # byte      251 = 0xfb  (sic)
        # int       252 = 0xfc
        # long      253 = 0xfd
        # float     254 = 0xfe
        # double    255 = 0xff
        # --------------------
        # NOTE: the byte type seems to be reserved for categorical variables
        # with a label, but the underlying variable is -127 to 100
        # we're going to drop the label and cast to int
        self.DTYPE_MAP = dict(
            [(i, np.dtype(f"S{i}")) for i in range(1, 245)]
            + [
                (251, np.dtype(np.int8)),
                (252, np.dtype(np.int16)),
                (253, np.dtype(np.int32)),
                (254, np.dtype(np.float32)),
                (255, np.dtype(np.float64)),
            ]
        )
        self.DTYPE_MAP_XML: dict[int, np.dtype] = {
            32768: np.dtype(np.uint8),  # Keys to GSO
            65526: np.dtype(np.float64),
            65527: np.dtype(np.float32),
            65528: np.dtype(np.int32),
            65529: np.dtype(np.int16),
            65530: np.dtype(np.int8),
        }
        self.TYPE_MAP = list(tuple(range(251)) + tuple("bhlfd"))
        self.TYPE_MAP_XML = {
            # Not really a Q, unclear how to handle byteswap
            32768: "Q",
            65526: "d",
            65527: "f",
            65528: "l",
            65529: "h",
            65530: "b",
        }
        # NOTE: technically, some of these are wrong. there are more numbers
        # that can be represented. it's the 27 ABOVE and BELOW the max listed
        # numeric data type in [U] 12.2.2 of the 11.2 manual
        float32_min = b"\xff\xff\xff\xfe"
        float32_max = b"\xff\xff\xff\x7e"
        float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff"
        float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f"
        self.VALID_RANGE = {
            "b": (-127, 100),
            "h": (-32767, 32740),
            "l": (-2147483647, 2147483620),
            "f": (
                np.float32(struct.unpack("<f", float32_min)[0]),
                np.float32(struct.unpack("<f", float32_max)[0]),
            ),
            "d": (
                np.float64(struct.unpack("<d", float64_min)[0]),
                np.float64(struct.unpack("<d", float64_max)[0]),
            ),
        }

        self.OLD_TYPE_MAPPING = {
            98: 251,  # byte
            105: 252,  # int
            108: 253,  # long
            102: 254,  # float
            100: 255,  # double
        }

        # These missing values are the generic '.' in Stata, and are used
        # to replace nans
        self.MISSING_VALUES = {
            "b": 101,
            "h": 32741,
            "l": 2147483621,
            "f": np.float32(struct.unpack("<f", b"\x00\x00\x00\x7f")[0]),
            "d": np.float64(
                struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
            ),
        }
        self.NUMPY_TYPE_MAP = {
            "b": "i1",
            "h": "i2",
            "l": "i4",
            "f": "f4",
            "d": "f8",
            "Q": "u8",
        }

        # Reserved words cannot be used as variable names
        self.RESERVED_WORDS = {
            "aggregate",
            "array",
            "boolean",
            "break",
            "byte",
            "case",
            "catch",
            "class",
            "colvector",
            "complex",
            "const",
            "continue",
            "default",
            "delegate",
            "delete",
            "do",
            "double",
            "else",
            "eltypedef",
            "end",
            "enum",
            "explicit",
            "export",
            "external",
            "float",
            "for",
            "friend",
            "function",
            "global",
            "goto",
            "if",
            "inline",
            "int",
            "local",
            "long",
            "NULL",
            "pragma",
            "protected",
            "quad",
            "rowvector",
            "short",
            "typedef",
            "typename",
            "virtual",
            "_all",
            "_N",
            "_skip",
            "_b",
            "_pi",
            "str#",
            "in",
            "_pred",
            "strL",
            "_coef",
            "_rc",
            "using",
            "_cons",
            "_se",
            "with",
            "_n",
        }
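

# Illustrative sketch (not part of pandas): the reader flags out-of-range
# values as missing using VALID_RANGE. For the int8 type code "b" the valid
# data range is (-127, 100), so a stored byte of 101 or more is one of the 27
# missing codes '.', '.a', ..., '.z' rather than data:
#
#     parser = StataParser()
#     nmin, nmax = parser.VALID_RANGE["b"]   # (-127, 100)
#     101 > nmax                             # True -> treated as missing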


class StataReader(StataParser, abc.Iterator):
    __doc__ = _stata_reader_doc

    _path_or_buf: IO[bytes]  # Will be assigned by `_open_file`.

    def __init__(
        self,
        path_or_buf: FilePath | ReadBuffer[bytes],
        convert_dates: bool = True,
        convert_categoricals: bool = True,
        index_col: str | None = None,
        convert_missing: bool = False,
        preserve_dtypes: bool = True,
        columns: Sequence[str] | None = None,
        order_categoricals: bool = True,
        chunksize: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions | None = None,
    ) -> None:
        super().__init__()

        # Arguments to the reader (can be temporarily overridden in
        # calls to read).
        self._convert_dates = convert_dates
        self._convert_categoricals = convert_categoricals
        self._index_col = index_col
        self._convert_missing = convert_missing
        self._preserve_dtypes = preserve_dtypes
        self._columns = columns
        self._order_categoricals = order_categoricals
        self._original_path_or_buf = path_or_buf
        self._compression = compression
        self._storage_options = storage_options
        self._encoding = ""
        self._chunksize = chunksize
        self._using_iterator = False
        self._entered = False
        if self._chunksize is None:
            self._chunksize = 1
        elif not isinstance(chunksize, int) or chunksize <= 0:
            raise ValueError("chunksize must be a positive integer when set.")

        # State variables for the file
        self._close_file: Callable[[], None] | None = None
        self._missing_values = False
        self._can_read_value_labels = False
        self._column_selector_set = False
        self._value_labels_read = False
        self._data_read = False
        self._dtype: np.dtype | None = None
        self._lines_read = 0

        self._native_byteorder = _set_endianness(sys.byteorder)

    def _ensure_open(self) -> None:
        """
        Ensure the file has been opened and its header data read.
        """
        if not hasattr(self, "_path_or_buf"):
            self._open_file()

    def _open_file(self) -> None:
        """
        Open the file (with compression options, etc.), and read header information.
        """
        if not self._entered:
            warnings.warn(
                "StataReader is being used without using a context manager. "
                "Using StataReader as a context manager is the only supported method.",
                ResourceWarning,
                stacklevel=find_stack_level(),
            )
        handles = get_handle(
            self._original_path_or_buf,
            "rb",
            storage_options=self._storage_options,
            is_text=False,
            compression=self._compression,
        )
        if hasattr(handles.handle, "seekable") and handles.handle.seekable():
            # If the handle is directly seekable, use it without an extra copy.
            self._path_or_buf = handles.handle
            self._close_file = handles.close
        else:
            # Copy to memory, and ensure no encoding.
            with handles:
                self._path_or_buf = BytesIO(handles.handle.read())
                self._close_file = self._path_or_buf.close

        self._read_header()
        self._setup_dtype()

    def __enter__(self) -> Self:
        """enter context manager"""
        self._entered = True
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if self._close_file:
            self._close_file()

    def close(self) -> None:
        """Close the handle if it's open.

        .. deprecated:: 2.0.0

            The close method is not part of the public API.
            The only supported way to use StataReader is to use it as a
            context manager.
        """
        warnings.warn(
            "The StataReader.close() method is not part of the public API and "
            "will be removed in a future version without notice. "
            "Using StataReader as a context manager is the only supported method.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        if self._close_file:
            self._close_file()

    def _set_encoding(self) -> None:
        """
        Set string encoding which depends on file version
        """
        if self._format_version < 118:
            self._encoding = "latin-1"
        else:
            self._encoding = "utf-8"

    def _read_int8(self) -> int:
        return struct.unpack("b", self._path_or_buf.read(1))[0]

    def _read_uint8(self) -> int:
        return struct.unpack("B", self._path_or_buf.read(1))[0]

    def _read_uint16(self) -> int:
        return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0]

    def _read_uint32(self) -> int:
        return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0]

    def _read_uint64(self) -> int:
        return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0]

    def _read_int16(self) -> int:
        return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0]

    def _read_int32(self) -> int:
        return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0]

    def _read_int64(self) -> int:
        return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0]

    def _read_char8(self) -> bytes:
        return struct.unpack("c", self._path_or_buf.read(1))[0]

    def _read_int16_count(self, count: int) -> tuple[int, ...]:
        return struct.unpack(
            f"{self._byteorder}{'h' * count}",
            self._path_or_buf.read(2 * count),
        )
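
    # Illustrative sketch (not part of pandas): the _read_* helpers above are
    # thin wrappers over struct.unpack, with self._byteorder ("<" or ">")
    # prepended so the same code reads both little- and big-endian dta files.
    # For example, assuming a little-endian file:
    #
    #     struct.unpack("<H", b"\x2c\x01")[0]   # -> 300, as _read_uint16 would return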

    def _read_header(self) -> None:
        first_char = self._read_char8()
        if first_char == b"<":
            self._read_new_header()
        else:
            self._read_old_header(first_char)

    def _read_new_header(self) -> None:
        # The first part of the header is common to 117 - 119.
        self._path_or_buf.read(27)  # stata_dta><header><release>
        self._format_version = int(self._path_or_buf.read(3))
        if self._format_version not in [117, 118, 119]:
            raise ValueError(_version_error.format(version=self._format_version))
        self._set_encoding()
        self._path_or_buf.read(21)  # </release><byteorder>
        self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<"
        self._path_or_buf.read(15)  # </byteorder><K>
        self._nvar = (
            self._read_uint16() if self._format_version <= 118 else self._read_uint32()
        )
        self._path_or_buf.read(7)  # </K><N>

        self._nobs = self._get_nobs()
        self._path_or_buf.read(11)  # </N><label>
        self._data_label = self._get_data_label()
        self._path_or_buf.read(19)  # </label><timestamp>
        self._time_stamp = self._get_time_stamp()
        self._path_or_buf.read(26)  # </timestamp></header><map>
        self._path_or_buf.read(8)  # 0x0000000000000000
        self._path_or_buf.read(8)  # position of <map>

        self._seek_vartypes = self._read_int64() + 16
        self._seek_varnames = self._read_int64() + 10
        self._seek_sortlist = self._read_int64() + 10
        self._seek_formats = self._read_int64() + 9
        self._seek_value_label_names = self._read_int64() + 19

        # Requires version-specific treatment
        self._seek_variable_labels = self._get_seek_variable_labels()

        self._path_or_buf.read(8)  # <characteristics>
        self._data_location = self._read_int64() + 6
        self._seek_strls = self._read_int64() + 7
        self._seek_value_labels = self._read_int64() + 14

        self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes)

        self._path_or_buf.seek(self._seek_varnames)
        self._varlist = self._get_varlist()

        self._path_or_buf.seek(self._seek_sortlist)
        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

        self._path_or_buf.seek(self._seek_formats)
        self._fmtlist = self._get_fmtlist()

        self._path_or_buf.seek(self._seek_value_label_names)
        self._lbllist = self._get_lbllist()

        self._path_or_buf.seek(self._seek_variable_labels)
        self._variable_labels = self._get_variable_labels()

    # Get data type information, works for versions 117-119.
    def _get_dtypes(
        self, seek_vartypes: int
    ) -> tuple[list[int | str], list[str | np.dtype]]:
        self._path_or_buf.seek(seek_vartypes)
        typlist = []
        dtyplist = []
        for _ in range(self._nvar):
            typ = self._read_uint16()
            if typ <= 2045:
                typlist.append(typ)
                dtyplist.append(str(typ))
            else:
                try:
                    typlist.append(self.TYPE_MAP_XML[typ])  # type: ignore[arg-type]
                    dtyplist.append(self.DTYPE_MAP_XML[typ])  # type: ignore[arg-type]
                except KeyError as err:
                    raise ValueError(f"cannot convert stata types [{typ}]") from err

        return typlist, dtyplist  # type: ignore[return-value]

    def _get_varlist(self) -> list[str]:
        # 33 in older formats, 129 in formats 118 and 119
        b = 33 if self._format_version < 118 else 129
        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    # Returns the format list
    def _get_fmtlist(self) -> list[str]:
        if self._format_version >= 118:
            b = 57
        elif self._format_version > 113:
            b = 49
        elif self._format_version > 104:
            b = 12
        else:
            b = 7

        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    # Returns the label list
    def _get_lbllist(self) -> list[str]:
        if self._format_version >= 118:
            b = 129
        elif self._format_version > 108:
            b = 33
        else:
            b = 9
        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    def _get_variable_labels(self) -> list[str]:
        if self._format_version >= 118:
            vlblist = [
                self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar)
            ]
        elif self._format_version > 105:
            vlblist = [
                self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar)
            ]
        else:
            vlblist = [
                self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar)
            ]
        return vlblist

    def _get_nobs(self) -> int:
        if self._format_version >= 118:
            return self._read_uint64()
        else:
            return self._read_uint32()

    def _get_data_label(self) -> str:
        if self._format_version >= 118:
            strlen = self._read_uint16()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version == 117:
            strlen = self._read_int8()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version > 105:
            return self._decode(self._path_or_buf.read(81))
        else:
            return self._decode(self._path_or_buf.read(32))

    def _get_time_stamp(self) -> str:
        if self._format_version >= 118:
            strlen = self._read_int8()
            return self._path_or_buf.read(strlen).decode("utf-8")
        elif self._format_version == 117:
            strlen = self._read_int8()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version > 104:
            return self._decode(self._path_or_buf.read(18))
        else:
            raise ValueError()

    def _get_seek_variable_labels(self) -> int:
        if self._format_version == 117:
            self._path_or_buf.read(8)  # <variable_labels>, throw away
            # Stata 117 data files do not follow the described format. This is
            # a work around that uses the previous label, 33 bytes for each
            # variable, 20 for the closing tag and 17 for the opening tag
            return self._seek_value_label_names + (33 * self._nvar) + 20 + 17
        elif self._format_version >= 118:
            return self._read_int64() + 17
        else:
            raise ValueError()

    def _read_old_header(self, first_char: bytes) -> None:
        self._format_version = int(first_char[0])
        if self._format_version not in [104, 105, 108, 111, 113, 114, 115]:
            raise ValueError(_version_error.format(version=self._format_version))
        self._set_encoding()
        self._byteorder = ">" if self._read_int8() == 0x1 else "<"
        self._filetype = self._read_int8()
        self._path_or_buf.read(1)  # unused

        self._nvar = self._read_uint16()
        self._nobs = self._get_nobs()

        self._data_label = self._get_data_label()

        self._time_stamp = self._get_time_stamp()

        # descriptors
        if self._format_version > 108:
            typlist = [int(c) for c in self._path_or_buf.read(self._nvar)]
        else:
            buf = self._path_or_buf.read(self._nvar)
            typlistb = np.frombuffer(buf, dtype=np.uint8)
            typlist = []
            for tp in typlistb:
                if tp in self.OLD_TYPE_MAPPING:
                    typlist.append(self.OLD_TYPE_MAPPING[tp])
                else:
                    typlist.append(tp - 127)  # bytes

        try:
            self._typlist = [self.TYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_types = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
        try:
            self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_dtypes = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err

        if self._format_version > 108:
            self._varlist = [
                self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar)
            ]
        else:
            self._varlist = [
                self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar)
            ]
        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

        self._fmtlist = self._get_fmtlist()

        self._lbllist = self._get_lbllist()

        self._variable_labels = self._get_variable_labels()

        # ignore expansion fields (Format 105 and later)
        # When reading, read five bytes; the last four bytes now tell you
        # the size of the next read, which you discard. You then continue
        # like this until you read 5 bytes of zeros.

        if self._format_version > 104:
            while True:
                data_type = self._read_int8()
                if self._format_version > 108:
                    data_len = self._read_int32()
                else:
                    data_len = self._read_int16()
                if data_type == 0:
                    break
                self._path_or_buf.read(data_len)

        # necessary data to continue parsing
        self._data_location = self._path_or_buf.tell()

    def _setup_dtype(self) -> np.dtype:
        """Map between numpy and Stata dtypes"""
        if self._dtype is not None:
            return self._dtype

        dtypes = []  # Convert struct data types to numpy data type
        for i, typ in enumerate(self._typlist):
            if typ in self.NUMPY_TYPE_MAP:
                typ = cast(str, typ)  # only strs in NUMPY_TYPE_MAP
                dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}"))
            else:
                dtypes.append((f"s{i}", f"S{typ}"))
        self._dtype = np.dtype(dtypes)

        return self._dtype

    def _decode(self, s: bytes) -> str:
        # have bytes not strings, so must decode
        s = s.partition(b"\0")[0]
        try:
            return s.decode(self._encoding)
        except UnicodeDecodeError:
            # GH 25960, fallback to handle incorrect format produced when 117
            # files are converted to 118 files in Stata
            encoding = self._encoding
            msg = f"""
One or more strings in the dta file could not be decoded using {encoding}, and
so the fallback encoding of latin-1 is being used. This can happen when a file
has been incorrectly encoded by Stata or some other software. You should verify
the string values returned are correct."""
            warnings.warn(
                msg,
                UnicodeWarning,
                stacklevel=find_stack_level(),
            )
            return s.decode("latin-1")
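
    # Illustrative sketch (not part of pandas): _setup_dtype builds a single
    # numpy structured dtype for one whole row, so for a hypothetical file
    # with a little-endian float64 column and a str4 column the row dtype is
    #
    #     np.dtype([("s0", "<f8"), ("s1", "S4")])
    #
    # which lets np.frombuffer decode many rows in one call in read() below.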

    def _read_value_labels(self) -> None:
        self._ensure_open()
        if self._value_labels_read:
            # Don't read twice
            return
        if self._format_version <= 108:
            # Value labels are not supported in version 108 and earlier.
            self._value_labels_read = True
            self._value_label_dict: dict[str, dict[float, str]] = {}
            return

        if self._format_version >= 117:
            self._path_or_buf.seek(self._seek_value_labels)
        else:
            assert self._dtype is not None
            offset = self._nobs * self._dtype.itemsize
            self._path_or_buf.seek(self._data_location + offset)

        self._value_labels_read = True
        self._value_label_dict = {}

        while True:
            if self._format_version >= 117:
                if self._path_or_buf.read(5) == b"</val":  # <lbl>
                    break  # end of value label table

            slength = self._path_or_buf.read(4)
            if not slength:
                break  # end of value label table (format < 117)
            if self._format_version <= 117:
                labname = self._decode(self._path_or_buf.read(33))
            else:
                labname = self._decode(self._path_or_buf.read(129))
            self._path_or_buf.read(3)  # padding

            n = self._read_uint32()
            txtlen = self._read_uint32()
            off = np.frombuffer(
                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
            )
            val = np.frombuffer(
                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
            )
            ii = np.argsort(off)
            off = off[ii]
            val = val[ii]
            txt = self._path_or_buf.read(txtlen)
            self._value_label_dict[labname] = {}
            for i in range(n):
                end = off[i + 1] if i < n - 1 else txtlen
                self._value_label_dict[labname][val[i]] = self._decode(
                    txt[off[i] : end]
                )
            if self._format_version >= 117:
                self._path_or_buf.read(6)  # </lbl>
        self._value_labels_read = True

    def _read_strls(self) -> None:
        self._path_or_buf.seek(self._seek_strls)
        # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
        self.GSO = {"0": ""}
        while True:
            if self._path_or_buf.read(3) != b"GSO":
                break

            if self._format_version == 117:
                v_o = self._read_uint64()
            else:
                buf = self._path_or_buf.read(12)
                # Only tested on little endian file on little endian machine.
                v_size = 2 if self._format_version == 118 else 3
                if self._byteorder == "<":
                    buf = buf[0:v_size] + buf[4 : (12 - v_size)]
                else:
                    # This path may not be correct, impossible to test
                    buf = buf[0:v_size] + buf[(4 + v_size) :]
                v_o = struct.unpack("Q", buf)[0]
            typ = self._read_uint8()
            length = self._read_uint32()
            va = self._path_or_buf.read(length)
            if typ == 130:
                decoded_va = va[0:-1].decode(self._encoding)
            else:
                # Stata says typ 129 can be binary, so use str
                decoded_va = str(va)
            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
            self.GSO[str(v_o)] = decoded_va
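
    # Illustrative sketch (not part of pandas): each strL record in the file
    # is a "GSO" tag followed by a (variable, observation) key, a type byte
    # and a length-prefixed payload. Type 130 payloads are null-terminated
    # text, so a hypothetical record decodes like this:
    #
    #     va = b"hello\x00"         # payload read for typ == 130
    #     va[0:-1].decode("utf-8")   # -> "hello", stored as self.GSO[str(v_o)]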
1650 def __next__(self) -> DataFrame:
1651 self._using_iterator = True
1652 return self.read(nrows=self._chunksize)
1654 def get_chunk(self, size: int | None = None) -> DataFrame:
1655 """
1656 Reads lines from Stata file and returns as dataframe
1658 Parameters
1659 ----------
1660 size : int, defaults to None
1661 Number of lines to read. If None, reads whole file.
1663 Returns
1664 -------
1665 DataFrame
1666 """
1667 if size is None:
1668 size = self._chunksize
1669 return self.read(nrows=size)
1671 @Appender(_read_method_doc)
1672 def read(
1673 self,
1674 nrows: int | None = None,
1675 convert_dates: bool | None = None,
1676 convert_categoricals: bool | None = None,
1677 index_col: str | None = None,
1678 convert_missing: bool | None = None,
1679 preserve_dtypes: bool | None = None,
1680 columns: Sequence[str] | None = None,
1681 order_categoricals: bool | None = None,
1682 ) -> DataFrame:
1683 self._ensure_open()
1685 # Handle options
1686 if convert_dates is None:
1687 convert_dates = self._convert_dates
1688 if convert_categoricals is None:
1689 convert_categoricals = self._convert_categoricals
1690 if convert_missing is None:
1691 convert_missing = self._convert_missing
1692 if preserve_dtypes is None:
1693 preserve_dtypes = self._preserve_dtypes
1694 if columns is None:
1695 columns = self._columns
1696 if order_categoricals is None:
1697 order_categoricals = self._order_categoricals
1698 if index_col is None:
1699 index_col = self._index_col
1700 if nrows is None:
1701 nrows = self._nobs
1703 # Handle empty file or chunk. If reading incrementally raise
1704 # StopIteration. If reading the whole thing return an empty
1705 # data frame.
1706 if (self._nobs == 0) and nrows == 0:
1707 self._can_read_value_labels = True
1708 self._data_read = True
1709 data = DataFrame(columns=self._varlist)
1710 # Apply dtypes correctly
1711 for i, col in enumerate(data.columns):
1712 dt = self._dtyplist[i]
1713 if isinstance(dt, np.dtype):
1714 if dt.char != "S":
1715 data[col] = data[col].astype(dt)
1716 if columns is not None:
1717 data = self._do_select_columns(data, columns)
1718 return data
1720 if (self._format_version >= 117) and (not self._value_labels_read):
1721 self._can_read_value_labels = True
1722 self._read_strls()
1724 # Read data
1725 assert self._dtype is not None
1726 dtype = self._dtype
1727 max_read_len = (self._nobs - self._lines_read) * dtype.itemsize
1728 read_len = nrows * dtype.itemsize
1729 read_len = min(read_len, max_read_len)
1730 if read_len <= 0:
1731 # Iterator has finished, should never be here unless
1732 # we are reading the file incrementally
1733 if convert_categoricals:
1734 self._read_value_labels()
1735 raise StopIteration
1736 offset = self._lines_read * dtype.itemsize
1737 self._path_or_buf.seek(self._data_location + offset)
1738 read_lines = min(nrows, self._nobs - self._lines_read)
1739 raw_data = np.frombuffer(
1740 self._path_or_buf.read(read_len), dtype=dtype, count=read_lines
1741 )
1743 self._lines_read += read_lines
1744 if self._lines_read == self._nobs:
1745 self._can_read_value_labels = True
1746 self._data_read = True
1747 # if necessary, swap the byte order to native here
1748 if self._byteorder != self._native_byteorder:
1749 raw_data = raw_data.byteswap().view(raw_data.dtype.newbyteorder())
1751 if convert_categoricals:
1752 self._read_value_labels()
1754 if len(raw_data) == 0:
1755 data = DataFrame(columns=self._varlist)
1756 else:
1757 data = DataFrame.from_records(raw_data)
1758 data.columns = Index(self._varlist)
1760 # If index is not specified, use actual row number rather than
1761 # restarting at 0 for each chunk.
1762 if index_col is None:
1763 data.index = RangeIndex(
1764 self._lines_read - read_lines, self._lines_read
1765 ) # set attr instead of set_index to avoid copy
1767 if columns is not None:
1768 data = self._do_select_columns(data, columns)
1770 # Decode strings
1771 for col, typ in zip(data, self._typlist):
1772 if isinstance(typ, int):
1773 data[col] = data[col].apply(self._decode)
1775 data = self._insert_strls(data)
1777 # Convert columns (if needed) to match input type
1778 valid_dtypes = [i for i, dtyp in enumerate(self._dtyplist) if dtyp is not None]
1779 object_type = np.dtype(object)
1780 for idx in valid_dtypes:
1781 dtype = data.iloc[:, idx].dtype
1782 if dtype not in (object_type, self._dtyplist[idx]):
1783 data.isetitem(idx, data.iloc[:, idx].astype(dtype))
1785 data = self._do_convert_missing(data, convert_missing)
1787 if convert_dates:
1788 for i, fmt in enumerate(self._fmtlist):
1789 if any(fmt.startswith(date_fmt) for date_fmt in _date_formats):
1790 data.isetitem(
1791 i, _stata_elapsed_date_to_datetime_vec(data.iloc[:, i], fmt)
1792 )
1794 if convert_categoricals and self._format_version > 108:
1795 data = self._do_convert_categoricals(
1796 data, self._value_label_dict, self._lbllist, order_categoricals
1797 )
1799 if not preserve_dtypes:
1800 retyped_data = []
1801 convert = False
1802 for col in data:
1803 dtype = data[col].dtype
1804 if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
1805 dtype = np.dtype(np.float64)
1806 convert = True
1807 elif dtype in (
1808 np.dtype(np.int8),
1809 np.dtype(np.int16),
1810 np.dtype(np.int32),
1811 ):
1812 dtype = np.dtype(np.int64)
1813 convert = True
1814 retyped_data.append((col, data[col].astype(dtype)))
1815 if convert:
1816 data = DataFrame.from_dict(dict(retyped_data))
1818 if index_col is not None:
1819 data = data.set_index(data.pop(index_col))
1821 return data
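# Illustrative use of the incremental path above (hypothetical file name;
# process() is a stand-in): chunked reads keep a running RangeIndex rather
# than restarting at 0 for each chunk, and raise StopIteration once all
# observations have been consumed.
#
# >>> with pd.read_stata("big.dta", chunksize=50_000) as reader:  # doctest: +SKIP
# ...     for chunk in reader:  # doctest: +SKIP
# ...         process(chunk)  # index continues across chunks  # doctest: +SKIP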
1823 def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
1824 # Check for missing values, and replace if found
1825 replacements = {}
1826 for i in range(len(data.columns)):
1827 fmt = self._typlist[i]
1828 if fmt not in self.VALID_RANGE:
1829 continue
1831 fmt = cast(str, fmt) # only strs in VALID_RANGE
1832 nmin, nmax = self.VALID_RANGE[fmt]
1833 series = data.iloc[:, i]
1835 # appreciably faster to do this with ndarray instead of Series
1836 svals = series._values
1837 missing = (svals < nmin) | (svals > nmax)
1839 if not missing.any():
1840 continue
1842 if convert_missing: # Replacement follows Stata notation
1843 missing_loc = np.nonzero(np.asarray(missing))[0]
1844 umissing, umissing_loc = np.unique(series[missing], return_inverse=True)
1845 replacement = Series(series, dtype=object)
1846 for j, um in enumerate(umissing):
1847 missing_value = StataMissingValue(um)
1849 loc = missing_loc[umissing_loc == j]
1850 replacement.iloc[loc] = missing_value
1851 else: # All replacements are identical
1852 dtype = series.dtype
1853 if dtype not in (np.float32, np.float64):
1854 dtype = np.float64
1855 replacement = Series(series, dtype=dtype)
1856 if not replacement._values.flags["WRITEABLE"]:
1857 # only relevant for ArrayManager; construction
1858 # path for BlockManager ensures writeability
1859 replacement = replacement.copy()
1860 # Note: operating on ._values is much faster than setting values
1861 # through the Series directly. TODO: can we fix that?
1862 replacement._values[missing] = np.nan
1863 replacements[i] = replacement
1864 if replacements:
1865 for idx, value in replacements.items():
1866 data.isetitem(idx, value)
1867 return data
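# Illustrative effect of the two branches above (hypothetical file): with
# convert_missing=False, out-of-range Stata missing codes become NaN in a
# float column; with convert_missing=True the column is upcast to object
# and each missing cell holds a StataMissingValue (rendered ".", ".a", ...,
# ".z" in Stata notation).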
1869 def _insert_strls(self, data: DataFrame) -> DataFrame:
1870 if not hasattr(self, "GSO") or len(self.GSO) == 0:
1871 return data
1872 for i, typ in enumerate(self._typlist):
1873 if typ != "Q":
1874 continue
1875 # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
1876 data.isetitem(i, [self.GSO[str(k)] for k in data.iloc[:, i]])
1877 return data
1879 def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame:
1880 if not self._column_selector_set:
1881 column_set = set(columns)
1882 if len(column_set) != len(columns):
1883 raise ValueError("columns contains duplicate entries")
1884 unmatched = column_set.difference(data.columns)
1885 if unmatched:
1886 joined = ", ".join(list(unmatched))
1887 raise ValueError(
1888 "The following columns were not "
1889 f"found in the Stata data set: {joined}"
1890 )
1891 # Copy information for retained columns for later processing
1892 dtyplist = []
1893 typlist = []
1894 fmtlist = []
1895 lbllist = []
1896 for col in columns:
1897 i = data.columns.get_loc(col)
1898 dtyplist.append(self._dtyplist[i])
1899 typlist.append(self._typlist[i])
1900 fmtlist.append(self._fmtlist[i])
1901 lbllist.append(self._lbllist[i])
1903 self._dtyplist = dtyplist
1904 self._typlist = typlist
1905 self._fmtlist = fmtlist
1906 self._lbllist = lbllist
1907 self._column_selector_set = True
1909 return data[columns]
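# Illustrative (hypothetical file): the subset is returned in the requested
# order, since both the metadata lists above and the frame itself are
# rebuilt in `columns` order.
#
# >>> pd.read_stata("wide.dta", columns=["y", "x"]).columns  # doctest: +SKIP
# Index(['y', 'x'], dtype='object')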
1911 def _do_convert_categoricals(
1912 self,
1913 data: DataFrame,
1914 value_label_dict: dict[str, dict[float, str]],
1915 lbllist: Sequence[str],
1916 order_categoricals: bool,
1917 ) -> DataFrame:
1918 """
1919 Converts value-labelled columns to pandas Categorical dtype.
1920 """
1921 if not value_label_dict:
1922 return data
1923 cat_converted_data = []
1924 for col, label in zip(data, lbllist):
1925 if label in value_label_dict:
1926 # Explicit Categorical construction with ordered=order_categoricals
1927 vl = value_label_dict[label]
1928 keys = np.array(list(vl.keys()))
1929 column = data[col]
1930 key_matches = column.isin(keys)
1931 if self._using_iterator and key_matches.all():
1932 initial_categories: np.ndarray | None = keys
1933 # If all categories are in the keys and we are iterating,
1934 # use the same keys for all chunks. If some are missing
1935 # value labels, then we will fall back to the categories
1936 # varying across chunks.
1937 else:
1938 if self._using_iterator:
1939 # warn if using an iterator
1940 warnings.warn(
1941 categorical_conversion_warning,
1942 CategoricalConversionWarning,
1943 stacklevel=find_stack_level(),
1944 )
1945 initial_categories = None
1946 cat_data = Categorical(
1947 column, categories=initial_categories, ordered=order_categoricals
1948 )
1949 if initial_categories is None:
1950 # If None here, then we need to match the cats in the Categorical
1951 categories = []
1952 for category in cat_data.categories:
1953 if category in vl:
1954 categories.append(vl[category])
1955 else:
1956 categories.append(category)
1957 else:
1958 # If all cats are matched, we can use the values
1959 categories = list(vl.values())
1960 try:
1961 # Try to catch duplicate categories
1962 # TODO: if we get a non-copying rename_categories, use that
1963 cat_data = cat_data.rename_categories(categories)
1964 except ValueError as err:
1965 vc = Series(categories, copy=False).value_counts()
1966 repeated_cats = list(vc.index[vc > 1])
1967 repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
1968 # GH 25772
1969 msg = f"""
1970Value labels for column {col} are not unique. These cannot be converted to
1971pandas categoricals.
1973Either read the file with `convert_categoricals` set to False or use the
1974low level interface in `StataReader` to separately read the values and the
1975value_labels.
1977The repeated labels are:
1978{repeats}
1979"""
1980 raise ValueError(msg) from err
1981 # TODO: is the next line needed above in the data(...) method?
1982 cat_series = Series(cat_data, index=data.index, copy=False)
1983 cat_converted_data.append((col, cat_series))
1984 else:
1985 cat_converted_data.append((col, data[col]))
1986 data = DataFrame(dict(cat_converted_data), copy=False)
1987 return data
1989 @property
1990 def data_label(self) -> str:
1991 """
1992 Return data label of Stata file.
1994 Examples
1995 --------
1996 >>> df = pd.DataFrame([(1,)], columns=["variable"])
1997 >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21)
1998 >>> data_label = "This is a data file."
1999 >>> path = "/My_path/filename.dta"
2000 >>> df.to_stata(path, time_stamp=time_stamp, # doctest: +SKIP
2001 ... data_label=data_label, # doctest: +SKIP
2002 ... version=None) # doctest: +SKIP
2003 >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP
2004 ... print(reader.data_label) # doctest: +SKIP
2005 This is a data file.
2006 """
2007 self._ensure_open()
2008 return self._data_label
2010 @property
2011 def time_stamp(self) -> str:
2012 """
2013 Return time stamp of Stata file.
2014 """
2015 self._ensure_open()
2016 return self._time_stamp
2018 def variable_labels(self) -> dict[str, str]:
2019 """
2020 Return a dict associating each variable name with corresponding label.
2022 Returns
2023 -------
2024 dict
2026 Examples
2027 --------
2028 >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["col_1", "col_2"])
2029 >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21)
2030 >>> path = "/My_path/filename.dta"
2031 >>> variable_labels = {"col_1": "This is an example"}
2032 >>> df.to_stata(path, time_stamp=time_stamp, # doctest: +SKIP
2033 ... variable_labels=variable_labels, version=None) # doctest: +SKIP
2034 >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP
2035 ... print(reader.variable_labels()) # doctest: +SKIP
2036 {'index': '', 'col_1': 'This is an example', 'col_2': ''}
2037 >>> pd.read_stata(path) # doctest: +SKIP
2038 index col_1 col_2
2039 0 0 1 2
2040 1 1 3 4
2041 """
2042 self._ensure_open()
2043 return dict(zip(self._varlist, self._variable_labels))
2045 def value_labels(self) -> dict[str, dict[float, str]]:
2046 """
2047 Return a nested dict associating each variable name to its value and label.
2049 Returns
2050 -------
2051 dict
2053 Examples
2054 --------
2055 >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["col_1", "col_2"])
2056 >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21)
2057 >>> path = "/My_path/filename.dta"
2058 >>> value_labels = {"col_1": {3: "x"}}
2059 >>> df.to_stata(path, time_stamp=time_stamp, # doctest: +SKIP
2060 ... value_labels=value_labels, version=None) # doctest: +SKIP
2061 >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP
2062 ... print(reader.value_labels()) # doctest: +SKIP
2063 {'col_1': {3: 'x'}}
2064 >>> pd.read_stata(path) # doctest: +SKIP
2065 index col_1 col_2
2066 0 0 1 2
2067 1 1 x 4
2068 """
2069 if not self._value_labels_read:
2070 self._read_value_labels()
2072 return self._value_label_dict
2075@Appender(_read_stata_doc)
2076def read_stata(
2077 filepath_or_buffer: FilePath | ReadBuffer[bytes],
2078 *,
2079 convert_dates: bool = True,
2080 convert_categoricals: bool = True,
2081 index_col: str | None = None,
2082 convert_missing: bool = False,
2083 preserve_dtypes: bool = True,
2084 columns: Sequence[str] | None = None,
2085 order_categoricals: bool = True,
2086 chunksize: int | None = None,
2087 iterator: bool = False,
2088 compression: CompressionOptions = "infer",
2089 storage_options: StorageOptions | None = None,
2090) -> DataFrame | StataReader:
2091 reader = StataReader(
2092 filepath_or_buffer,
2093 convert_dates=convert_dates,
2094 convert_categoricals=convert_categoricals,
2095 index_col=index_col,
2096 convert_missing=convert_missing,
2097 preserve_dtypes=preserve_dtypes,
2098 columns=columns,
2099 order_categoricals=order_categoricals,
2100 chunksize=chunksize,
2101 storage_options=storage_options,
2102 compression=compression,
2103 )
2105 if iterator or chunksize:
2106 return reader
2108 with reader:
2109 return reader.read()
2112def _set_endianness(endianness: str) -> str:
2113 if endianness.lower() in ["<", "little"]:
2114 return "<"
2115 elif endianness.lower() in [">", "big"]:
2116 return ">"
2117 else: # pragma : no cover
2118 raise ValueError(f"Endianness {endianness} not understood")
2121def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
2122 """
2123 Take a str or bytes string and pad it with null bytes until it is `length` chars long.
2124 """
2125 if isinstance(name, bytes):
2126 return name + b"\x00" * (length - len(name))
2127 return name + "\x00" * (length - len(name))
2130def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
2131 """
2132 Convert from one of the stata date formats to a type in TYPE_MAP.
2133 """
2134 if fmt in [
2135 "tc",
2136 "%tc",
2137 "td",
2138 "%td",
2139 "tw",
2140 "%tw",
2141 "tm",
2142 "%tm",
2143 "tq",
2144 "%tq",
2145 "th",
2146 "%th",
2147 "ty",
2148 "%ty",
2149 ]:
2150 return np.dtype(np.float64) # Stata expects doubles for SIFs
2151 else:
2152 raise NotImplementedError(f"Format {fmt} not implemented")
2155def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict:
2156 new_dict = {}
2157 for key in convert_dates:
2158 if not convert_dates[key].startswith("%"): # make sure proper fmts
2159 convert_dates[key] = "%" + convert_dates[key]
2160 if key in varlist:
2161 new_dict.update({varlist.index(key): convert_dates[key]})
2162 else:
2163 if not isinstance(key, int):
2164 raise ValueError("convert_dates key must be a column or an integer")
2165 new_dict.update({key: convert_dates[key]})
2166 return new_dict
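# Illustrative (hypothetical names): with varlist ["id", "when"],
# {"when": "td"} is normalized to {1: "%td"} -- a "%" prefix is added when
# missing and column names are replaced by their positional index.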
2169def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int:
2170 """
2171 Convert numpy dtypes to Stata type codes (string length or numeric type byte).
2172 See TYPE_MAP and comments for an explanation. This is also explained in
2173 the dta spec.
2174 1 - 244 are strings of this length
2175 Pandas Stata
2176 251 - for int8 byte
2177 252 - for int16 int
2178 253 - for int32 long
2179 254 - for float32 float
2180 255 - for double double
2182 If there are dates to convert, then dtype will already have the correct
2183 type inserted.
2184 """
2185 # TODO: expand to handle datetime to integer conversion
2186 if dtype.type is np.object_: # try to coerce it to the biggest string
2187 # not memory efficient, what else could we
2188 # do?
2189 itemsize = max_len_string_array(ensure_object(column._values))
2190 return max(itemsize, 1)
2191 elif dtype.type is np.float64:
2192 return 255
2193 elif dtype.type is np.float32:
2194 return 254
2195 elif dtype.type is np.int32:
2196 return 253
2197 elif dtype.type is np.int16:
2198 return 252
2199 elif dtype.type is np.int8:
2200 return 251
2201 else: # pragma : no cover
2202 raise NotImplementedError(f"Data type {dtype} not supported.")
2205def _dtype_to_default_stata_fmt(
2206 dtype, column: Series, dta_version: int = 114, force_strl: bool = False
2207) -> str:
2208 """
2209 Map numpy dtype to stata's default format for this type. Not terribly
2210 important since users can change this in Stata. Semantics are
2212 object -> "%DDs" where DD is the length of the string. If not a string,
2213 raise ValueError
2214 float64 -> "%10.0g"
2215 float32 -> "%9.0g"
2216 int64 -> "%9.0g"
2217 int32 -> "%12.0g"
2218 int16 -> "%8.0g"
2219 int8 -> "%8.0g"
2220 strl -> "%9s"
2221 """
2222 # TODO: Refactor to combine type with format
2223 # TODO: expand this to handle a default datetime format?
2224 if dta_version < 117:
2225 max_str_len = 244
2226 else:
2227 max_str_len = 2045
2228 if force_strl:
2229 return "%9s"
2230 if dtype.type is np.object_:
2231 itemsize = max_len_string_array(ensure_object(column._values))
2232 if itemsize > max_str_len:
2233 if dta_version >= 117:
2234 return "%9s"
2235 else:
2236 raise ValueError(excessive_string_length_error.format(column.name))
2237 return "%" + str(max(itemsize, 1)) + "s"
2238 elif dtype == np.float64:
2239 return "%10.0g"
2240 elif dtype == np.float32:
2241 return "%9.0g"
2242 elif dtype == np.int32:
2243 return "%12.0g"
2244 elif dtype in (np.int8, np.int16):
2245 return "%8.0g"
2246 else: # pragma : no cover
2247 raise NotImplementedError(f"Data type {dtype} not supported.")
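# Illustrative: an object column whose longest string is 12 characters maps
# to "%12s"; the same column forced to strl on dta 117+ maps to "%9s".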
2250@doc(
2251 storage_options=_shared_docs["storage_options"],
2252 compression_options=_shared_docs["compression_options"] % "fname",
2253)
2254class StataWriter(StataParser):
2255 """
2256 A class for writing Stata binary dta files
2258 Parameters
2259 ----------
2260 fname : path (string), buffer or path object
2261 string, path object (pathlib.Path or py._path.local.LocalPath) or
2262 object implementing a binary write() functions. If using a buffer
2263 then the buffer will not be automatically closed after the file
2264 is written.
2265 data : DataFrame
2266 Input to save
2267 convert_dates : dict
2268 Dictionary mapping columns containing datetime types to stata internal
2269 format to use when writing the dates. Options are 'tc', 'td', 'tm',
2270 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
2271 Datetime columns that do not have a conversion type specified will be
2272 converted to 'tc'. Raises NotImplementedError if a datetime column has
2273 timezone information
2274 write_index : bool
2275 Write the index to Stata dataset.
2276 byteorder : str
2277 Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
2278 time_stamp : datetime
2279 A datetime to use as the file creation date. Default is the current time.
2280 data_label : str
2281 A label for the data set. Must be 80 characters or smaller.
2282 variable_labels : dict
2283 Dictionary containing columns as keys and variable labels as values.
2284 Each label must be 80 characters or smaller.
2285 {compression_options}
2287 .. versionchanged:: 1.4.0 Zstandard support.
2289 {storage_options}
2291 value_labels : dict of dicts
2292 Dictionary containing columns as keys and dictionaries of column value
2293 to labels as values. The combined length of all labels for a single
2294 variable must be 32,000 characters or smaller.
2296 .. versionadded:: 1.4.0
2298 Returns
2299 -------
2300 writer : StataWriter instance
2301 The StataWriter instance has a write_file method, which will
2302 write the file to the given `fname`.
2304 Raises
2305 ------
2306 NotImplementedError
2307 * If datetimes contain timezone information
2308 ValueError
2309 * Columns listed in convert_dates are neither datetime64[ns]
2310 nor datetime
2311 * Column dtype is not representable in Stata
2312 * Column listed in convert_dates is not in DataFrame
2313 * Categorical label contains more than 32,000 characters
2315 Examples
2316 --------
2317 >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b'])
2318 >>> writer = StataWriter('./data_file.dta', data)
2319 >>> writer.write_file()
2321 Directly write a zip file
2322 >>> compression = {{"method": "zip", "archive_name": "data_file.dta"}}
2323 >>> writer = StataWriter('./data_file.zip', data, compression=compression)
2324 >>> writer.write_file()
2326 Save a DataFrame with dates
2327 >>> from datetime import datetime
2328 >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
2329 >>> writer = StataWriter('./date_data_file.dta', data, {{'date' : 'tw'}})
2330 >>> writer.write_file()
2331 """
2333 _max_string_length = 244
2334 _encoding: Literal["latin-1", "utf-8"] = "latin-1"
2336 def __init__(
2337 self,
2338 fname: FilePath | WriteBuffer[bytes],
2339 data: DataFrame,
2340 convert_dates: dict[Hashable, str] | None = None,
2341 write_index: bool = True,
2342 byteorder: str | None = None,
2343 time_stamp: datetime | None = None,
2344 data_label: str | None = None,
2345 variable_labels: dict[Hashable, str] | None = None,
2346 compression: CompressionOptions = "infer",
2347 storage_options: StorageOptions | None = None,
2348 *,
2349 value_labels: dict[Hashable, dict[float, str]] | None = None,
2350 ) -> None:
2351 super().__init__()
2352 self.data = data
2353 self._convert_dates = {} if convert_dates is None else convert_dates
2354 self._write_index = write_index
2355 self._time_stamp = time_stamp
2356 self._data_label = data_label
2357 self._variable_labels = variable_labels
2358 self._non_cat_value_labels = value_labels
2359 self._value_labels: list[StataValueLabel] = []
2360 self._has_value_labels = np.array([], dtype=bool)
2361 self._compression = compression
2362 self._output_file: IO[bytes] | None = None
2363 self._converted_names: dict[Hashable, str] = {}
2364 # attach nobs, nvars, data, varlist, typlist
2365 self._prepare_pandas(data)
2366 self.storage_options = storage_options
2368 if byteorder is None:
2369 byteorder = sys.byteorder
2370 self._byteorder = _set_endianness(byteorder)
2371 self._fname = fname
2372 self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
2374 def _write(self, to_write: str) -> None:
2375 """
2376 Helper to call encode before writing to file for Python 3 compat.
2377 """
2378 self.handles.handle.write(to_write.encode(self._encoding))
2380 def _write_bytes(self, value: bytes) -> None:
2381 """
2382 Helper to assert file is open before writing.
2383 """
2384 self.handles.handle.write(value)
2386 def _prepare_non_cat_value_labels(
2387 self, data: DataFrame
2388 ) -> list[StataNonCatValueLabel]:
2389 """
2390 Check for value labels provided for non-categorical columns and build
2391 the corresponding StataNonCatValueLabel objects.
2392 """
2393 non_cat_value_labels: list[StataNonCatValueLabel] = []
2394 if self._non_cat_value_labels is None:
2395 return non_cat_value_labels
2397 for labname, labels in self._non_cat_value_labels.items():
2398 if labname in self._converted_names:
2399 colname = self._converted_names[labname]
2400 elif labname in data.columns:
2401 colname = str(labname)
2402 else:
2403 raise KeyError(
2404 f"Can't create value labels for {labname}, it wasn't "
2405 "found in the dataset."
2406 )
2408 if not is_numeric_dtype(data[colname].dtype):
2409 # Labels should not be passed explicitly for categorical
2410 # columns that will be converted to int
2411 raise ValueError(
2412 f"Can't create value labels for {labname}, value labels "
2413 "can only be applied to numeric columns."
2414 )
2415 svl = StataNonCatValueLabel(colname, labels, self._encoding)
2416 non_cat_value_labels.append(svl)
2417 return non_cat_value_labels
2419 def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
2420 """
2421 Check for categorical columns, retain categorical information for
2422 Stata file and convert categorical data to int
2423 """
2424 is_cat = [isinstance(dtype, CategoricalDtype) for dtype in data.dtypes]
2425 if not any(is_cat):
2426 return data
2428 self._has_value_labels |= np.array(is_cat)
2430 get_base_missing_value = StataMissingValue.get_base_missing_value
2431 data_formatted = []
2432 for col, col_is_cat in zip(data, is_cat):
2433 if col_is_cat:
2434 svl = StataValueLabel(data[col], encoding=self._encoding)
2435 self._value_labels.append(svl)
2436 dtype = data[col].cat.codes.dtype
2437 if dtype == np.int64:
2438 raise ValueError(
2439 "It is not possible to export "
2440 "int64-based categorical data to Stata."
2441 )
2442 values = data[col].cat.codes._values.copy()
2444 # Upcast if needed so that correct missing values can be set
2445 if values.max() >= get_base_missing_value(dtype):
2446 if dtype == np.int8:
2447 dtype = np.dtype(np.int16)
2448 elif dtype == np.int16:
2449 dtype = np.dtype(np.int32)
2450 else:
2451 dtype = np.dtype(np.float64)
2452 values = np.array(values, dtype=dtype)
2454 # Replace missing values with Stata missing value for type
2455 values[values == -1] = get_base_missing_value(dtype)
2456 data_formatted.append((col, values))
2457 else:
2458 data_formatted.append((col, data[col]))
2459 return DataFrame.from_dict(dict(data_formatted))
2461 def _replace_nans(self, data: DataFrame) -> DataFrame:
2463 """
2464 Checks floating point data columns for nans, and replaces these with
2465 the generic Stata missing value (.)
2466 """
2467 for c in data:
2468 dtype = data[c].dtype
2469 if dtype in (np.float32, np.float64):
2470 if dtype == np.float32:
2471 replacement = self.MISSING_VALUES["f"]
2472 else:
2473 replacement = self.MISSING_VALUES["d"]
2474 data[c] = data[c].fillna(replacement)
2476 return data
2478 def _update_strl_names(self) -> None:
2479 """No-op, forward compatibility"""
2481 def _validate_variable_name(self, name: str) -> str:
2482 """
2483 Validate variable names for Stata export.
2485 Parameters
2486 ----------
2487 name : str
2488 Variable name
2490 Returns
2491 -------
2492 str
2493 The validated name with invalid characters replaced with
2494 underscores.
2496 Notes
2497 -----
2498 Stata 114 and 117 support ASCII characters in a-z, A-Z, 0-9
2499 and _.
2500 """
2501 for c in name:
2502 if (
2503 (c < "A" or c > "Z")
2504 and (c < "a" or c > "z")
2505 and (c < "0" or c > "9")
2506 and c != "_"
2507 ):
2508 name = name.replace(c, "_")
2509 return name
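# Illustrative: invalid characters are replaced one at a time, so a name
# like "gdp.growth (%)" is exported as "gdp_growth____".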
2511 def _check_column_names(self, data: DataFrame) -> DataFrame:
2512 """
2513 Checks column names to ensure that they are valid Stata column names.
2514 This includes checks for:
2515 * Non-string names
2516 * Stata keywords
2517 * Variables that start with numbers
2518 * Variables with names that are too long
2520 When an illegal variable name is detected, it is converted, and if
2521 dates are exported, the variable name is propagated to the date
2522 conversion dictionary
2523 """
2524 converted_names: dict[Hashable, str] = {}
2525 columns = list(data.columns)
2526 original_columns = columns[:]
2528 duplicate_var_id = 0
2529 for j, name in enumerate(columns):
2530 orig_name = name
2531 if not isinstance(name, str):
2532 name = str(name)
2534 name = self._validate_variable_name(name)
2536 # Variable name must not be a reserved word
2537 if name in self.RESERVED_WORDS:
2538 name = "_" + name
2540 # Variable name may not start with a number
2541 if "0" <= name[0] <= "9":
2542 name = "_" + name
2544 name = name[: min(len(name), 32)]
2546 if not name == orig_name:
2547 # check for duplicates
2548 while columns.count(name) > 0:
2549 # prepend ascending number to avoid duplicates
2550 name = "_" + str(duplicate_var_id) + name
2551 name = name[: min(len(name), 32)]
2552 duplicate_var_id += 1
2553 converted_names[orig_name] = name
2555 columns[j] = name
2557 data.columns = Index(columns)
2559 # Check date conversion, and fix key if needed
2560 if self._convert_dates:
2561 for c, o in zip(columns, original_columns):
2562 if c != o:
2563 self._convert_dates[c] = self._convert_dates[o]
2564 del self._convert_dates[o]
2566 if converted_names:
2567 conversion_warning = []
2568 for orig_name, name in converted_names.items():
2569 msg = f"{orig_name} -> {name}"
2570 conversion_warning.append(msg)
2572 ws = invalid_name_doc.format("\n ".join(conversion_warning))
2573 warnings.warn(
2574 ws,
2575 InvalidColumnName,
2576 stacklevel=find_stack_level(),
2577 )
2579 self._converted_names = converted_names
2580 self._update_strl_names()
2582 return data
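# Illustrative renames produced above (hypothetical frame): a non-string
# column name 1 becomes "_1", the reserved word "if" becomes "_if", and all
# renames are reported together in a single InvalidColumnName warning.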
2584 def _set_formats_and_types(self, dtypes: Series) -> None:
2585 self.fmtlist: list[str] = []
2586 self.typlist: list[int] = []
2587 for col, dtype in dtypes.items():
2588 self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col]))
2589 self.typlist.append(_dtype_to_stata_type(dtype, self.data[col]))
2591 def _prepare_pandas(self, data: DataFrame) -> None:
2592 # NOTE: we might need a different API / class for pandas objects so
2593 # we can set different semantics - handle this with a PR to pandas.io
2595 data = data.copy()
2597 if self._write_index:
2598 temp = data.reset_index()
2599 if isinstance(temp, DataFrame):
2600 data = temp
2602 # Ensure column names are strings
2603 data = self._check_column_names(data)
2605 # Check columns for compatibility with stata, upcast if necessary
2606 # Raise if outside the supported range
2607 data = _cast_to_stata_types(data)
2609 # Replace NaNs with Stata missing values
2610 data = self._replace_nans(data)
2612 # Set all columns to initially unlabelled
2613 self._has_value_labels = np.repeat(False, data.shape[1])
2615 # Create value labels for non-categorical data
2616 non_cat_value_labels = self._prepare_non_cat_value_labels(data)
2618 non_cat_columns = [svl.labname for svl in non_cat_value_labels]
2619 has_non_cat_val_labels = data.columns.isin(non_cat_columns)
2620 self._has_value_labels |= has_non_cat_val_labels
2621 self._value_labels.extend(non_cat_value_labels)
2623 # Convert categoricals to int data, and strip labels
2624 data = self._prepare_categoricals(data)
2626 self.nobs, self.nvar = data.shape
2627 self.data = data
2628 self.varlist = data.columns.tolist()
2630 dtypes = data.dtypes
2632 # Ensure all date columns are converted
2633 for col in data:
2634 if col in self._convert_dates:
2635 continue
2636 if lib.is_np_dtype(data[col].dtype, "M"):
2637 self._convert_dates[col] = "tc"
2639 self._convert_dates = _maybe_convert_to_int_keys(
2640 self._convert_dates, self.varlist
2641 )
2642 for key in self._convert_dates:
2643 new_type = _convert_datetime_to_stata_type(self._convert_dates[key])
2644 dtypes.iloc[key] = np.dtype(new_type)
2646 # Verify object arrays are strings and encode to bytes
2647 self._encode_strings()
2649 self._set_formats_and_types(dtypes)
2651 # set the given format for the datetime cols
2652 if self._convert_dates is not None:
2653 for key in self._convert_dates:
2654 if isinstance(key, int):
2655 self.fmtlist[key] = self._convert_dates[key]
2657 def _encode_strings(self) -> None:
2658 """
2659 Encode strings in dta-specific encoding
2661 Do not encode columns marked for date conversion or for strL
2662 conversion. The strL converter independently handles conversion and
2663 also accepts empty string arrays.
2664 """
2665 convert_dates = self._convert_dates
2666 # _convert_strl is not available in dta 114
2667 convert_strl = getattr(self, "_convert_strl", [])
2668 for i, col in enumerate(self.data):
2669 # Skip columns marked for date conversion or strl conversion
2670 if i in convert_dates or col in convert_strl:
2671 continue
2672 column = self.data[col]
2673 dtype = column.dtype
2674 if dtype.type is np.object_:
2675 inferred_dtype = infer_dtype(column, skipna=True)
2676 if not ((inferred_dtype == "string") or len(column) == 0):
2677 col = column.name
2678 raise ValueError(
2679 f"""\
2680Column `{col}` cannot be exported.\n\nOnly string-like object arrays
2681containing all strings or a mix of strings and None can be exported.
2682Object arrays containing only null values are prohibited. Other object
2683types cannot be exported and must first be converted to one of the
2684supported types."""
2685 )
2686 encoded = self.data[col].str.encode(self._encoding)
2687 # If larger than _max_string_length do nothing
2688 if (
2689 max_len_string_array(ensure_object(encoded._values))
2690 <= self._max_string_length
2691 ):
2692 self.data[col] = encoded
2694 def write_file(self) -> None:
2695 """
2696 Export DataFrame object to Stata dta format.
2698 Examples
2699 --------
2700 >>> df = pd.DataFrame({"fully_labelled": [1, 2, 3, 3, 1],
2701 ... "partially_labelled": [1.0, 2.0, np.nan, 9.0, np.nan],
2702 ... "Y": [7, 7, 9, 8, 10],
2703 ... "Z": pd.Categorical(["j", "k", "l", "k", "j"]),
2704 ... })
2705 >>> path = "/My_path/filename.dta"
2706 >>> labels = {"fully_labelled": {1: "one", 2: "two", 3: "three"},
2707 ... "partially_labelled": {1.0: "one", 2.0: "two"},
2708 ... }
2709 >>> writer = pd.io.stata.StataWriter(path,
2710 ... df,
2711 ... value_labels=labels) # doctest: +SKIP
2712 >>> writer.write_file() # doctest: +SKIP
2713 >>> df = pd.read_stata(path) # doctest: +SKIP
2714 >>> df # doctest: +SKIP
2715 index fully_labelled partially_labelled Y Z
2716 0 0 one one 7 j
2717 1 1 two two 7 k
2718 2 2 three NaN 9 l
2719 3 3 three 9.0 8 k
2720 4 4 one NaN 10 j
2721 """
2722 with get_handle(
2723 self._fname,
2724 "wb",
2725 compression=self._compression,
2726 is_text=False,
2727 storage_options=self.storage_options,
2728 ) as self.handles:
2729 if self.handles.compression["method"] is not None:
2730 # ZipFile creates a file (with the same name) for each write call.
2731 # Write it first into a buffer and then write the buffer to the ZipFile.
2732 self._output_file, self.handles.handle = self.handles.handle, BytesIO()
2733 self.handles.created_handles.append(self.handles.handle)
2735 try:
2736 self._write_header(
2737 data_label=self._data_label, time_stamp=self._time_stamp
2738 )
2739 self._write_map()
2740 self._write_variable_types()
2741 self._write_varnames()
2742 self._write_sortlist()
2743 self._write_formats()
2744 self._write_value_label_names()
2745 self._write_variable_labels()
2746 self._write_expansion_fields()
2747 self._write_characteristics()
2748 records = self._prepare_data()
2749 self._write_data(records)
2750 self._write_strls()
2751 self._write_value_labels()
2752 self._write_file_close_tag()
2753 self._write_map()
2754 self._close()
2755 except Exception as exc:
2756 self.handles.close()
2757 if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile(
2758 self._fname
2759 ):
2760 try:
2761 os.unlink(self._fname)
2762 except OSError:
2763 warnings.warn(
2764 f"This save was not successful but {self._fname} could not "
2765 "be deleted. This file is not valid.",
2766 ResourceWarning,
2767 stacklevel=find_stack_level(),
2768 )
2769 raise exc
2771 def _close(self) -> None:
2772 """
2773 Close the file if it was created by the writer.
2775 If a buffer or file-like object was passed in, for example a GzipFile,
2776 then leave this file open for the caller to close.
2777 """
2778 # flush the buffered output to the actual (possibly compressed) handle
2779 if self._output_file is not None:
2780 assert isinstance(self.handles.handle, BytesIO)
2781 bio, self.handles.handle = self.handles.handle, self._output_file
2782 self.handles.handle.write(bio.getvalue())
2784 def _write_map(self) -> None:
2785 """No-op, future compatibility"""
2787 def _write_file_close_tag(self) -> None:
2788 """No-op, future compatibility"""
2790 def _write_characteristics(self) -> None:
2791 """No-op, future compatibility"""
2793 def _write_strls(self) -> None:
2794 """No-op, future compatibility"""
2796 def _write_expansion_fields(self) -> None:
2797 """Write 5 zeros for expansion fields"""
2798 self._write(_pad_bytes("", 5))
2800 def _write_value_labels(self) -> None:
2801 for vl in self._value_labels:
2802 self._write_bytes(vl.generate_value_label(self._byteorder))
2804 def _write_header(
2805 self,
2806 data_label: str | None = None,
2807 time_stamp: datetime | None = None,
2808 ) -> None:
2809 byteorder = self._byteorder
2810 # ds_format - just use 114
2811 self._write_bytes(struct.pack("b", 114))
2812 # byteorder
2813 self._write("\x01" if byteorder == ">" else "\x02")
2814 # filetype
2815 self._write("\x01")
2816 # unused
2817 self._write("\x00")
2818 # number of vars, 2 bytes
2819 self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2])
2820 # number of obs, 4 bytes
2821 self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4])
2822 # data label 81 bytes, char, null terminated
2823 if data_label is None:
2824 self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80)))
2825 else:
2826 self._write_bytes(
2827 self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))
2828 )
2829 # time stamp, 18 bytes, char, null terminated
2830 # format dd Mon yyyy hh:mm
2831 if time_stamp is None:
2832 time_stamp = datetime.now()
2833 elif not isinstance(time_stamp, datetime):
2834 raise ValueError("time_stamp should be datetime type")
2835 # GH #13856
2836 # Avoid locale-specific month conversion
2837 months = [
2838 "Jan",
2839 "Feb",
2840 "Mar",
2841 "Apr",
2842 "May",
2843 "Jun",
2844 "Jul",
2845 "Aug",
2846 "Sep",
2847 "Oct",
2848 "Nov",
2849 "Dec",
2850 ]
2851 month_lookup = {i + 1: month for i, month in enumerate(months)}
2852 ts = (
2853 time_stamp.strftime("%d ")
2854 + month_lookup[time_stamp.month]
2855 + time_stamp.strftime(" %Y %H:%M")
2856 )
2857 self._write_bytes(self._null_terminate_bytes(ts))
2859 def _write_variable_types(self) -> None:
2860 for typ in self.typlist:
2861 self._write_bytes(struct.pack("B", typ))
2863 def _write_varnames(self) -> None:
2864 # varlist names are checked by _check_column_names
2865 # varlist, requires null terminated
2866 for name in self.varlist:
2867 name = self._null_terminate_str(name)
2868 name = _pad_bytes(name[:32], 33)
2869 self._write(name)
2871 def _write_sortlist(self) -> None:
2872 # srtlist, 2*(nvar+1), int array, encoded by byteorder
2873 srtlist = _pad_bytes("", 2 * (self.nvar + 1))
2874 self._write(srtlist)
2876 def _write_formats(self) -> None:
2877 # fmtlist, 49*nvar, char array
2878 for fmt in self.fmtlist:
2879 self._write(_pad_bytes(fmt, 49))
2881 def _write_value_label_names(self) -> None:
2882 # lbllist, 33*nvar, char array
2883 for i in range(self.nvar):
2884 # Use variable name when categorical
2885 if self._has_value_labels[i]:
2886 name = self.varlist[i]
2887 name = self._null_terminate_str(name)
2888 name = _pad_bytes(name[:32], 33)
2889 self._write(name)
2890 else: # Default is empty label
2891 self._write(_pad_bytes("", 33))
2893 def _write_variable_labels(self) -> None:
2894 # Missing labels are 80 blank characters plus null termination
2895 blank = _pad_bytes("", 81)
2897 if self._variable_labels is None:
2898 for i in range(self.nvar):
2899 self._write(blank)
2900 return
2902 for col in self.data:
2903 if col in self._variable_labels:
2904 label = self._variable_labels[col]
2905 if len(label) > 80:
2906 raise ValueError("Variable labels must be 80 characters or fewer")
2907 is_latin1 = all(ord(c) < 256 for c in label)
2908 if not is_latin1:
2909 raise ValueError(
2910 "Variable labels must contain only characters that "
2911 "can be encoded in Latin-1"
2912 )
2913 self._write(_pad_bytes(label, 81))
2914 else:
2915 self._write(blank)
2917 def _convert_strls(self, data: DataFrame) -> DataFrame:
2918 """No-op, future compatibility"""
2919 return data
2921 def _prepare_data(self) -> np.rec.recarray:
2922 data = self.data
2923 typlist = self.typlist
2924 convert_dates = self._convert_dates
2925 # 1. Convert dates
2926 if self._convert_dates is not None:
2927 for i, col in enumerate(data):
2928 if i in convert_dates:
2929 data[col] = _datetime_to_stata_elapsed_vec(
2930 data[col], self.fmtlist[i]
2931 )
2932 # 2. Convert strls
2933 data = self._convert_strls(data)
2935 # 3. Convert bad string data to '' and pad to correct length
2936 dtypes = {}
2937 native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
2938 for i, col in enumerate(data):
2939 typ = typlist[i]
2940 if typ <= self._max_string_length:
2941 with warnings.catch_warnings():
2942 warnings.filterwarnings(
2943 "ignore",
2944 "Downcasting object dtype arrays",
2945 category=FutureWarning,
2946 )
2947 dc = data[col].fillna("")
2948 data[col] = dc.apply(_pad_bytes, args=(typ,))
2949 stype = f"S{typ}"
2950 dtypes[col] = stype
2951 data[col] = data[col].astype(stype)
2952 else:
2953 dtype = data[col].dtype
2954 if not native_byteorder:
2955 dtype = dtype.newbyteorder(self._byteorder)
2956 dtypes[col] = dtype
2958 return data.to_records(index=False, column_dtypes=dtypes)
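# Illustrative: a string column with Stata type 10 is null-padded to 10
# bytes and stored as fixed-width dtype "S10" in the output recarray, while
# numeric columns only get a byte-order adjustment when needed.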
2960 def _write_data(self, records: np.rec.recarray) -> None:
2961 self._write_bytes(records.tobytes())
2963 @staticmethod
2964 def _null_terminate_str(s: str) -> str:
2965 s += "\x00"
2966 return s
2968 def _null_terminate_bytes(self, s: str) -> bytes:
2969 return self._null_terminate_str(s).encode(self._encoding)
2972def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
2973 """
2974 Converts numpy dtypes to Stata 117 type codes (string length or numeric code).
2975 See TYPE_MAP and comments for an explanation. This is also explained in
2976 the dta spec.
2977 1 - 2045 are strings of this length
2978 Pandas Stata
2979 32768 - for object strL
2980 65526 - for float64 double
2981 65527 - for float32 float
2982 65528 - for int32 long
2983 65529 - for int16 int
2984 65530 - for int8 byte
2986 If there are dates to convert, then dtype will already have the correct
2987 type inserted.
2988 """
2989 # TODO: expand to handle datetime to integer conversion
2990 if force_strl:
2991 return 32768
2992 if dtype.type is np.object_: # try to coerce it to the biggest string
2993 # not memory efficient, what else could we
2994 # do?
2995 itemsize = max_len_string_array(ensure_object(column._values))
2996 itemsize = max(itemsize, 1)
2997 if itemsize <= 2045:
2998 return itemsize
2999 return 32768
3000 elif dtype.type is np.float64:
3001 return 65526
3002 elif dtype.type is np.float32:
3003 return 65527
3004 elif dtype.type is np.int32:
3005 return 65528
3006 elif dtype.type is np.int16:
3007 return 65529
3008 elif dtype.type is np.int8:
3009 return 65530
3010 else: # pragma : no cover
3011 raise NotImplementedError(f"Data type {dtype} not supported.")
3014def _pad_bytes_new(name: str | bytes, length: int) -> bytes:
3015 """
3016 Takes a str or bytes instance and pads it with null bytes until it is `length` chars long.
3017 """
3018 if isinstance(name, str):
3019 name = bytes(name, "utf-8")
3020 return name + b"\x00" * (length - len(name))
3023class StataStrLWriter:
3024 """
3025 Converter for Stata StrLs
3027 Stata StrLs map 8-byte values to strings, which are stored using a
3028 dictionary-like format where each string is keyed by two values (v, o).
3030 Parameters
3031 ----------
3032 df : DataFrame
3033 DataFrame to convert
3034 columns : Sequence[str]
3035 List of columns names to convert to StrL
3036 version : int, optional
3037 dta version. Currently supports 117, 118 and 119
3038 byteorder : str, optional
3039 Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
3041 Notes
3042 -----
3043 Supports creation of the StrL block of a dta file for dta versions
3044 117, 118 and 119. These differ in how the GSO is stored. 118 and
3045 119 store the GSO lookup value as a uint32 and a uint64, while 117
3046 uses two uint32s. 118 and 119 also encode all strings as unicode
3047 which is required by the format. 117 uses 'latin-1', a fixed-width
3048 encoding that extends the 7-bit ASCII table with an additional 128
3049 characters.
3050 """
3052 def __init__(
3053 self,
3054 df: DataFrame,
3055 columns: Sequence[str],
3056 version: int = 117,
3057 byteorder: str | None = None,
3058 ) -> None:
3059 if version not in (117, 118, 119):
3060 raise ValueError("Only dta versions 117, 118 and 119 supported")
3061 self._dta_ver = version
3063 self.df = df
3064 self.columns = columns
3065 self._gso_table = {"": (0, 0)}
3066 if byteorder is None:
3067 byteorder = sys.byteorder
3068 self._byteorder = _set_endianness(byteorder)
3070 gso_v_type = "I" # uint32
3071 gso_o_type = "Q" # uint64
3072 self._encoding = "utf-8"
3073 if version == 117:
3074 o_size = 4
3075 gso_o_type = "I" # 117 used uint32
3076 self._encoding = "latin-1"
3077 elif version == 118:
3078 o_size = 6
3079 else: # version == 119
3080 o_size = 5
3081 self._o_offset = 2 ** (8 * (8 - o_size))
3082 self._gso_o_type = gso_o_type
3083 self._gso_v_type = gso_v_type
3085 def _convert_key(self, key: tuple[int, int]) -> int:
3086 v, o = key
3087 return v + self._o_offset * o
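# Illustrative: for dta 118, o_size is 6 and _o_offset is 2 ** 16, so the
# key (v=2, o=3) encodes as 2 + 3 * 2 ** 16 == 196610 -- v occupies the low
# two bytes of the uint64 and o the upper six.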
3089 def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]:
3090 """
3091 Generates the GSO lookup table for the DataFrame
3093 Returns
3094 -------
3095 gso_table : dict
3096 Ordered dictionary using the string found as keys
3097 and their lookup position (v,o) as values
3098 gso_df : DataFrame
3099 DataFrame where strl columns have been converted to
3100 (v,o) values
3102 Notes
3103 -----
3104 Modifies the DataFrame in-place.
3106 The DataFrame returned encodes the (v,o) values as uint64s. The
3107 encoding depends on the dta version, and can be expressed as
3109 enc = v + o * 2 ** (8 * (8 - o_size))
3111 so that v is stored in the lower bytes and o in the upper bytes.
3112 o_size, the number of bytes reserved for o, is
3114 * 117: 4
3115 * 118: 6
3116 * 119: 5
3117 """
3118 gso_table = self._gso_table
3119 gso_df = self.df
3120 columns = list(gso_df.columns)
3121 selected = gso_df[self.columns]
3122 col_index = [(col, columns.index(col)) for col in self.columns]
3123 keys = np.empty(selected.shape, dtype=np.uint64)
3124 for o, (idx, row) in enumerate(selected.iterrows()):
3125 for j, (col, v) in enumerate(col_index):
3126 val = row[col]
3127 # Allow columns with mixed str and None (GH 23633)
3128 val = "" if val is None else val
3129 key = gso_table.get(val, None)
3130 if key is None:
3131 # Stata uses 1-based ("human") numbering for v and o
3132 key = (v + 1, o + 1)
3133 gso_table[val] = key
3134 keys[o, j] = self._convert_key(key)
3135 for i, col in enumerate(self.columns):
3136 gso_df[col] = keys[:, i]
3138 return gso_table, gso_df
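# Illustrative (assumed data): for a single strl column ["a", "b", "a"] at
# position 0, the table becomes {"": (0, 0), "a": (1, 1), "b": (1, 2)} and
# the column is replaced by the encoded uint64 keys; repeated strings reuse
# their first (v, o) assignment.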
3140 def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes:
3141 """
3142 Generates the binary blob of GSOs that is written to the dta file.
3144 Parameters
3145 ----------
3146 gso_table : dict
3147 Ordered dictionary (str, vo)
3149 Returns
3150 -------
3151 gso : bytes
3152 Binary content of dta file to be placed between strl tags
3154 Notes
3155 -----
3156 Output format depends on dta version. 117 uses two uint32s to
3157 express v and o while 118+ uses a uint32 for v and a uint64 for o.
3158 """
3159 # Format information
3160 # Length includes null term
3161 # 117
3162 # GSOvvvvooootllllxxxxxxxxxxxxxxx...x
3163 # 3 u4 u4 u1 u4 string + null term
3164 #
3165 # 118, 119
3166 # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x
3167 # 3 u4 u8 u1 u4 string + null term
3169 bio = BytesIO()
3170 gso = bytes("GSO", "ascii")
3171 gso_type = struct.pack(self._byteorder + "B", 130)
3172 null = struct.pack(self._byteorder + "B", 0)
3173 v_type = self._byteorder + self._gso_v_type
3174 o_type = self._byteorder + self._gso_o_type
3175 len_type = self._byteorder + "I"
3176 for strl, vo in gso_table.items():
3177 if vo == (0, 0):
3178 continue
3179 v, o = vo
3181 # GSO
3182 bio.write(gso)
3184 # vvvv
3185 bio.write(struct.pack(v_type, v))
3187 # oooo / oooooooo
3188 bio.write(struct.pack(o_type, o))
3190 # t
3191 bio.write(gso_type)
3193 # llll
3194 utf8_string = bytes(strl, "utf-8")
3195 bio.write(struct.pack(len_type, len(utf8_string) + 1))
3197 # xxx...xxx
3198 bio.write(utf8_string)
3199 bio.write(null)
3201 return bio.getvalue()
3204class StataWriter117(StataWriter):
3205 """
3206 A class for writing Stata binary dta files in Stata 13 format (117)
3208 Parameters
3209 ----------
3210 fname : path (string), buffer or path object
3211 string, path object (pathlib.Path or py._path.local.LocalPath) or
3212 object implementing a binary write() functions. If using a buffer
3213 then the buffer will not be automatically closed after the file
3214 is written.
3215 data : DataFrame
3216 Input to save
3217 convert_dates : dict
3218 Dictionary mapping columns containing datetime types to stata internal
3219 format to use when writing the dates. Options are 'tc', 'td', 'tm',
3220 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
3221 Datetime columns that do not have a conversion type specified will be
3222 converted to 'tc'. Raises NotImplementedError if a datetime column has
3223 timezone information
3224 write_index : bool
3225 Write the index to Stata dataset.
3226 byteorder : str
3227 Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
3228 time_stamp : datetime
3229 A datetime to use as the file creation date. Default is the current time.
3230 data_label : str
3231 A label for the data set. Must be 80 characters or smaller.
3232 variable_labels : dict
3233 Dictionary containing columns as keys and variable labels as values.
3234 Each label must be 80 characters or smaller.
3235 convert_strl : list
3236 List of columns names to convert to Stata StrL format. Columns with
3237 more than 2045 characters are automatically written as StrL.
3238 Smaller columns can be converted by including the column name. Using
3239 StrLs can reduce output file size when strings are longer than 8
3240 characters and are either frequently repeated or sparse.
3241 {compression_options}
3243 .. versionchanged:: 1.4.0 Zstandard support.
3245 value_labels : dict of dicts
3246 Dictionary containing columns as keys and dictionaries of column value
3247 to labels as values. The combined length of all labels for a single
3248 variable must be 32,000 characters or smaller.
3250 .. versionadded:: 1.4.0
3252 Returns
3253 -------
3254 writer : StataWriter117 instance
3255 The StataWriter117 instance has a write_file method, which will
3256 write the file to the given `fname`.
3258 Raises
3259 ------
3260 NotImplementedError
3261 * If datetimes contain timezone information
3262 ValueError
3263 * Columns listed in convert_dates are neither datetime64[ns]
3264 nor datetime
3265 * Column dtype is not representable in Stata
3266 * Column listed in convert_dates is not in DataFrame
3267 * Categorical label contains more than 32,000 characters
3269 Examples
3270 --------
3271 >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
3272 >>> writer = pd.io.stata.StataWriter117('./data_file.dta', data)
3273 >>> writer.write_file()
3275 Directly write a zip file
3276 >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
3277 >>> writer = pd.io.stata.StataWriter117(
3278 ... './data_file.zip', data, compression=compression
3279 ... )
3280 >>> writer.write_file()
3282 Or with long strings stored in strl format
3283 >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
3284 ... columns=['strls'])
3285 >>> writer = pd.io.stata.StataWriter117(
3286 ... './data_file_with_long_strings.dta', data, convert_strl=['strls'])
3287 >>> writer.write_file()
3288 """
3290 _max_string_length = 2045
3291 _dta_version = 117
3293 def __init__(
3294 self,
3295 fname: FilePath | WriteBuffer[bytes],
3296 data: DataFrame,
3297 convert_dates: dict[Hashable, str] | None = None,
3298 write_index: bool = True,
3299 byteorder: str | None = None,
3300 time_stamp: datetime | None = None,
3301 data_label: str | None = None,
3302 variable_labels: dict[Hashable, str] | None = None,
3303 convert_strl: Sequence[Hashable] | None = None,
3304 compression: CompressionOptions = "infer",
3305 storage_options: StorageOptions | None = None,
3306 *,
3307 value_labels: dict[Hashable, dict[float, str]] | None = None,
3308 ) -> None:
3309 # Copy to new list since convert_strl might be modified later
3310 self._convert_strl: list[Hashable] = []
3311 if convert_strl is not None:
3312 self._convert_strl.extend(convert_strl)
3314 super().__init__(
3315 fname,
3316 data,
3317 convert_dates,
3318 write_index,
3319 byteorder=byteorder,
3320 time_stamp=time_stamp,
3321 data_label=data_label,
3322 variable_labels=variable_labels,
3323 value_labels=value_labels,
3324 compression=compression,
3325 storage_options=storage_options,
3326 )
3327 self._map: dict[str, int] = {}
3328 self._strl_blob = b""
3330 @staticmethod
3331 def _tag(val: str | bytes, tag: str) -> bytes:
3332 """Surround val with <tag></tag>"""
3333 if isinstance(val, str):
3334 val = bytes(val, "utf-8")
3335 return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8")
3337 def _update_map(self, tag: str) -> None:
3338 """Update map location for tag with file position"""
3339 assert self.handles.handle is not None
3340 self._map[tag] = self.handles.handle.tell()
3342 def _write_header(
3343 self,
3344 data_label: str | None = None,
3345 time_stamp: datetime | None = None,
3346 ) -> None:
3347 """Write the file header"""
3348 byteorder = self._byteorder
3349 self._write_bytes(bytes("<stata_dta>", "utf-8"))
3350 bio = BytesIO()
3351 # ds_format - release version (117, 118, or 119)
3352 bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release"))
3353 # byteorder
3354 bio.write(self._tag("MSF" if byteorder == ">" else "LSF", "byteorder"))
3355 # number of vars, 2 bytes in 117 and 118, 4 bytes in 119
3356 nvar_type = "H" if self._dta_version <= 118 else "I"
3357 bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K"))
3358 # number of obs: 117 uses 4 bytes, 118 and 119 use 8
3359 nobs_size = "I" if self._dta_version == 117 else "Q"
3360 bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
3361 # data label 81 bytes, char, null terminated
3362 label = data_label[:80] if data_label is not None else ""
3363 encoded_label = label.encode(self._encoding)
3364 label_size = "B" if self._dta_version == 117 else "H"
3365 label_len = struct.pack(byteorder + label_size, len(encoded_label))
3366 encoded_label = label_len + encoded_label
3367 bio.write(self._tag(encoded_label, "label"))
3368 # time stamp, 18 bytes, char, null terminated
3369 # format dd Mon yyyy hh:mm
3370 if time_stamp is None:
3371 time_stamp = datetime.now()
3372 elif not isinstance(time_stamp, datetime):
3373 raise ValueError("time_stamp should be datetime type")
3374 # Avoid locale-specific month conversion
3375 months = [
3376 "Jan",
3377 "Feb",
3378 "Mar",
3379 "Apr",
3380 "May",
3381 "Jun",
3382 "Jul",
3383 "Aug",
3384 "Sep",
3385 "Oct",
3386 "Nov",
3387 "Dec",
3388 ]
3389 month_lookup = {i + 1: month for i, month in enumerate(months)}
3390 ts = (
3391 time_stamp.strftime("%d ")
3392 + month_lookup[time_stamp.month]
3393 + time_stamp.strftime(" %Y %H:%M")
3394 )
3395 # '\x11' added due to inspection of Stata file
3396 stata_ts = b"\x11" + bytes(ts, "utf-8")
3397 bio.write(self._tag(stata_ts, "timestamp"))
3398 self._write_bytes(self._tag(bio.getvalue(), "header"))
3400 def _write_map(self) -> None:
3401 """
3402 Called twice during file write. The first call populates the values in
3403 the map with 0s. The second call writes the final map locations when
3404 all blocks have been written.
3405 """
3406 if not self._map:
3407 self._map = {
3408 "stata_data": 0,
3409 "map": self.handles.handle.tell(),
3410 "variable_types": 0,
3411 "varnames": 0,
3412 "sortlist": 0,
3413 "formats": 0,
3414 "value_label_names": 0,
3415 "variable_labels": 0,
3416 "characteristics": 0,
3417 "data": 0,
3418 "strls": 0,
3419 "value_labels": 0,
3420 "stata_data_close": 0,
3421 "end-of-file": 0,
3422 }
3423 # Move to start of map
3424 self.handles.handle.seek(self._map["map"])
3425 bio = BytesIO()
3426 for val in self._map.values():
3427 bio.write(struct.pack(self._byteorder + "Q", val))
3428 self._write_bytes(self._tag(bio.getvalue(), "map"))
3430 def _write_variable_types(self) -> None:
3431 self._update_map("variable_types")
3432 bio = BytesIO()
3433 for typ in self.typlist:
3434 bio.write(struct.pack(self._byteorder + "H", typ))
3435 self._write_bytes(self._tag(bio.getvalue(), "variable_types"))
3437 def _write_varnames(self) -> None:
3438 self._update_map("varnames")
3439 bio = BytesIO()
3440 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3441 vn_len = 32 if self._dta_version == 117 else 128
3442 for name in self.varlist:
3443 name = self._null_terminate_str(name)
3444 name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
3445 bio.write(name)
3446 self._write_bytes(self._tag(bio.getvalue(), "varnames"))
3448 def _write_sortlist(self) -> None:
3449 self._update_map("sortlist")
3450 sort_size = 2 if self._dta_version < 119 else 4
3451 self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist"))
3453 def _write_formats(self) -> None:
3454 self._update_map("formats")
3455 bio = BytesIO()
3456 fmt_len = 49 if self._dta_version == 117 else 57
3457 for fmt in self.fmtlist:
3458 bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len))
3459 self._write_bytes(self._tag(bio.getvalue(), "formats"))
3461 def _write_value_label_names(self) -> None:
3462 self._update_map("value_label_names")
3463 bio = BytesIO()
3464 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3465 vl_len = 32 if self._dta_version == 117 else 128
3466 for i in range(self.nvar):
3467 # Use variable name when categorical
3468 name = "" # default name
3469 if self._has_value_labels[i]:
3470 name = self.varlist[i]
3471 name = self._null_terminate_str(name)
3472 encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
3473 bio.write(encoded_name)
3474 self._write_bytes(self._tag(bio.getvalue(), "value_label_names"))
3476 def _write_variable_labels(self) -> None:
3477 # Missing labels are 80 blank characters plus null termination
3478 self._update_map("variable_labels")
3479 bio = BytesIO()
3480 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3481 vl_len = 80 if self._dta_version == 117 else 320
3482 blank = _pad_bytes_new("", vl_len + 1)
3484 if self._variable_labels is None:
3485 for _ in range(self.nvar):
3486 bio.write(blank)
3487 self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3488 return
3490 for col in self.data:
3491 if col in self._variable_labels:
3492 label = self._variable_labels[col]
3493 if len(label) > 80:
3494 raise ValueError("Variable labels must be 80 characters or fewer")
3495 try:
3496 encoded = label.encode(self._encoding)
3497 except UnicodeEncodeError as err:
3498 raise ValueError(
3499 "Variable labels must contain only characters that "
3500 f"can be encoded in {self._encoding}"
3501 ) from err
3503 bio.write(_pad_bytes_new(encoded, vl_len + 1))
3504 else:
3505 bio.write(blank)
3506 self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3508 def _write_characteristics(self) -> None:
3509 self._update_map("characteristics")
3510 self._write_bytes(self._tag(b"", "characteristics"))
3512 def _write_data(self, records) -> None:
3513 self._update_map("data")
3514 self._write_bytes(b"<data>")
3515 self._write_bytes(records.tobytes())
3516 self._write_bytes(b"</data>")
3518 def _write_strls(self) -> None:
3519 self._update_map("strls")
3520 self._write_bytes(self._tag(self._strl_blob, "strls"))
3522 def _write_expansion_fields(self) -> None:
3523 """No-op in dta 117+"""
3525 def _write_value_labels(self) -> None:
3526 self._update_map("value_labels")
3527 bio = BytesIO()
3528 for vl in self._value_labels:
3529 lab = vl.generate_value_label(self._byteorder)
3530 lab = self._tag(lab, "lbl")
3531 bio.write(lab)
3532 self._write_bytes(self._tag(bio.getvalue(), "value_labels"))
3534 def _write_file_close_tag(self) -> None:
3535 self._update_map("stata_data_close")
3536 self._write_bytes(bytes("</stata_dta>", "utf-8"))
3537 self._update_map("end-of-file")
3539 def _update_strl_names(self) -> None:
3540 """
3541 Update column names for conversion to strl if they might have been
3542 changed to comply with Stata naming rules
3543 """
3544 # Update convert_strl if names changed
3545 for orig, new in self._converted_names.items():
3546 if orig in self._convert_strl:
3547 idx = self._convert_strl.index(orig)
3548 self._convert_strl[idx] = new

    def _convert_strls(self, data: DataFrame) -> DataFrame:
        """
        Convert columns to StrLs if either very large or in the
        convert_strl variable
        """
        convert_cols = [
            col
            for i, col in enumerate(data)
            if self.typlist[i] == 32768 or col in self._convert_strl
        ]

        if convert_cols:
            ssw = StataStrLWriter(data, convert_cols, version=self._dta_version)
            tab, new_data = ssw.generate_table()
            data = new_data
            self._strl_blob = ssw.generate_blob(tab)
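        # Illustration (a sketch of the flow above): a column is routed through
        # StataStrLWriter either because its Stata type is 32768 (strL) or
        # because the caller listed it in convert_strl; its values in the
        # returned DataFrame are replaced by references into the blob that
        # _write_strls later writes to the file.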
        return data

    def _set_formats_and_types(self, dtypes: Series) -> None:
        self.typlist = []
        self.fmtlist = []
        for col, dtype in dtypes.items():
            force_strl = col in self._convert_strl
            fmt = _dtype_to_default_stata_fmt(
                dtype,
                self.data[col],
                dta_version=self._dta_version,
                force_strl=force_strl,
            )
            self.fmtlist.append(fmt)
            self.typlist.append(
                _dtype_to_stata_type_117(dtype, self.data[col], force_strl)
            )
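        # Example of the resulting parallel lists (illustrative; the exact
        # format strings come from _dtype_to_default_stata_fmt): a float64
        # column plus a forced-strL string column might yield
        #   fmtlist == ["%10.0g", "%9s"]
        #   typlist == [65526, 32768]
        # where 65526 is the dta 117+ type code for double and 32768 for strL.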


class StataWriterUTF8(StataWriter117):
    """
    Stata binary dta file writing in Stata 15 (118) and 16 (119) formats

    DTA 118 and 119 format files support unicode string data (both fixed
    width and strL). Unicode is also supported in value labels, variable
    labels and the dataset label. Format 119 is automatically used if the
    file contains more than 32,767 variables.

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() function. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict, default None
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information.
    write_index : bool, default True
        Write the index to Stata dataset.
    byteorder : str, default None
        Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
    time_stamp : datetime, default None
        A datetime to use as file creation date. Default is the current time.
    data_label : str, default None
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict, default None
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    convert_strl : list, default None
        List of column names to convert to Stata StrL format. Columns with
        more than 2045 characters are automatically written as StrL.
        Smaller columns can be converted by including the column name. Using
        StrLs can reduce output file size when strings are longer than 8
        characters, and either frequently repeated or sparse.
    version : int, default None
        The dta version to use. By default, uses the size of data to determine
        the version. 118 is used if data.shape[1] <= 32767, and 119 is used
        for storing larger DataFrames.
    {compression_options}

        .. versionchanged:: 1.4.0 Zstandard support.

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    StataWriterUTF8
        The instance has a write_file method, which will write the file to the
        given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          nor datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    Using Unicode data and column names

    >>> from pandas.io.stata import StataWriterUTF8
    >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ'])
    >>> writer = StataWriterUTF8('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file

    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Or with long strings stored in strl format

    >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
    ...                     columns=['strls'])
    >>> writer = StataWriterUTF8('./data_file_with_long_strings.dta', data,
    ...                          convert_strl=['strls'])
    >>> writer.write_file()
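
    Value labels can be attached at write time as well (a minimal sketch using
    the ``value_labels`` argument described above; the output path is
    illustrative):

    >>> data = pd.DataFrame({'fully_labelled': [1, 2, 3, 3, 1]})
    >>> value_labels = {'fully_labelled': {1: 'one', 2: 'two', 3: 'three'}}
    >>> writer = StataWriterUTF8('./data_file_labelled.dta', data,
    ...                          value_labels=value_labels)
    >>> writer.write_file()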
3680 """
3682 _encoding: Literal["utf-8"] = "utf-8"

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        convert_strl: Sequence[Hashable] | None = None,
        version: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions | None = None,
        *,
        value_labels: dict[Hashable, dict[float, str]] | None = None,
    ) -> None:
        if version is None:
            version = 118 if data.shape[1] <= 32767 else 119
        elif version not in (118, 119):
            raise ValueError("version must be either 118 or 119.")
        elif version == 118 and data.shape[1] > 32767:
            raise ValueError(
                "You must use version 119 for data sets containing more than "
                "32,767 variables"
            )

        super().__init__(
            fname,
            data,
            convert_dates=convert_dates,
            write_index=write_index,
            byteorder=byteorder,
            time_stamp=time_stamp,
            data_label=data_label,
            variable_labels=variable_labels,
            value_labels=value_labels,
            convert_strl=convert_strl,
            compression=compression,
            storage_options=storage_options,
        )
        # Override version set in StataWriter117 init
        self._dta_version = version

    def _validate_variable_name(self, name: str) -> str:
        """
        Validate variable names for Stata export.

        Parameters
        ----------
        name : str
            Variable name

        Returns
        -------
        str
            The validated name with invalid characters replaced with
            underscores.

        Notes
        -----
        Stata 118+ supports most unicode characters. The only limitation is in
        the ASCII range, where the supported characters are a-z, A-Z, 0-9 and _.
        """
        # High code points appear to be acceptable
        for c in name:
            if (
                (
                    ord(c) < 128
                    and (c < "A" or c > "Z")
                    and (c < "a" or c > "z")
                    and (c < "0" or c > "9")
                    and c != "_"
                )
                or 128 <= ord(c) < 192
                or c in {"×", "÷"}  # noqa: RUF001
            ):
                name = name.replace(c, "_")
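        # Illustration (derived from the rule above, not a doctest):
        #   "gdp-growth" -> "gdp_growth"  (ASCII hyphen is not in [A-Za-z0-9_])
        #   "βeta"       -> "βeta"        (code points >= 192 pass through)
        #   "a×b"        -> "a_b"         (multiplication sign is explicitly banned)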

        return name