1"""
2Module contains tools for processing Stata files into DataFrames
4The StataReader below was originally written by Joe Presbrey as part of PyDTA.
5It has been extended and improved by Skipper Seabold from the Statsmodels
6project who also developed the StataWriter and was finally added to pandas in
7a once again improved version.
9You can find more information on http://presbrey.mit.edu/PyDTA and
10https://www.statsmodels.org/devel/
11"""
from __future__ import annotations

from collections import abc
import datetime
from io import BytesIO
import os
import struct
import sys
from types import TracebackType
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    AnyStr,
    Callable,
    Final,
    Hashable,
    Sequence,
    cast,
)
import warnings

from dateutil.relativedelta import relativedelta
import numpy as np

from pandas._libs.lib import infer_dtype
from pandas._libs.writers import max_len_string_array
from pandas._typing import (
    CompressionOptions,
    FilePath,
    ReadBuffer,
    StorageOptions,
    WriteBuffer,
)
from pandas.errors import (
    CategoricalConversionWarning,
    InvalidColumnName,
    PossiblePrecisionLoss,
    ValueLabelTypeMismatch,
)
from pandas.util._decorators import (
    Appender,
    doc,
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import (
    ensure_object,
    is_categorical_dtype,
    is_datetime64_dtype,
    is_numeric_dtype,
)

from pandas import (
    Categorical,
    DatetimeIndex,
    NaT,
    Timestamp,
    isna,
    to_datetime,
    to_timedelta,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import IntegerDtype
from pandas.core.frame import DataFrame
from pandas.core.indexes.base import Index
from pandas.core.series import Series
from pandas.core.shared_docs import _shared_docs

from pandas.io.common import get_handle

if TYPE_CHECKING:
    from typing import Literal

_version_error = (
    "Version of given Stata file is {version}. pandas supports importing "
    "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
    "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16), "
    "and 119 (Stata 15/16, over 32,767 variables)."
)
_statafile_processing_params1 = """\
convert_dates : bool, default True
    Convert date variables to DataFrame time values.
convert_categoricals : bool, default True
    Read value labels and convert columns to Categorical/Factor variables."""

_statafile_processing_params2 = """\
index_col : str, optional
    Column to set as index.
convert_missing : bool, default False
    Flag indicating whether to convert missing values to their Stata
    representations. If False, missing values are replaced with nan.
    If True, columns containing missing values are returned with
    object data types and missing values are represented by
    StataMissingValue objects.
preserve_dtypes : bool, default True
    Preserve Stata datatypes. If False, numeric data are upcast to pandas
    default types for foreign data (float64 or int64).
columns : list or None
    Columns to retain. Columns will be returned in the given order. None
    returns all columns.
order_categoricals : bool, default True
    Flag indicating whether converted categorical data are ordered."""

_chunksize_params = """\
chunksize : int, default None
    Return StataReader object for iteration; returns chunks with the
    given number of lines."""

_iterator_params = """\
iterator : bool, default False
    Return StataReader object."""

_reader_notes = """\
Notes
-----
Categorical variables read through an iterator may not have the same
categories and dtype. This occurs when a variable stored in a DTA
file is associated with an incomplete set of value labels that only
label a strict subset of the values."""
_read_stata_doc = f"""
Read Stata file into DataFrame.

Parameters
----------
filepath_or_buffer : str, path object or file-like object
    Any valid string path is acceptable. The string could be a URL. Valid
    URL schemes include http, ftp, s3, and file. For file URLs, a host is
    expected. A local file could be: ``file://localhost/path/to/table.dta``.

    If you want to pass in a path object, pandas accepts any ``os.PathLike``.

    By file-like object, we refer to objects with a ``read()`` method,
    such as a file handle (e.g. via builtin ``open`` function)
    or ``StringIO``.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_iterator_params}
{_shared_docs["decompression_options"] % "filepath_or_buffer"}
{_shared_docs["storage_options"]}

Returns
-------
DataFrame or StataReader

See Also
--------
io.stata.StataReader : Low-level reader for Stata data files.
DataFrame.to_stata : Export Stata data files.

{_reader_notes}

Examples
--------

Creating a dummy stata file for this example

>>> df = pd.DataFrame({{'animal': ['falcon', 'parrot', 'falcon', 'parrot'],
...                     'speed': [350, 18, 361, 15]}})  # doctest: +SKIP
>>> df.to_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file:

>>> df = pd.read_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file in 10,000 line chunks:

>>> values = np.random.randint(0, 10, size=(20_000, 1), dtype="uint8")  # doctest: +SKIP
>>> df = pd.DataFrame(values, columns=["i"])  # doctest: +SKIP
>>> df.to_stata('filename.dta')  # doctest: +SKIP

>>> with pd.read_stata('filename.dta', chunksize=10000) as itr:  # doctest: +SKIP
...     for chunk in itr:
...         # Operate on a single chunk, e.g., chunk.mean()
...         pass  # doctest: +SKIP
"""
_read_method_doc = f"""\
Reads observations from Stata file, converting them into a DataFrame.

Parameters
----------
nrows : int
    Number of lines to read from data file, if None read whole file.
{_statafile_processing_params1}
{_statafile_processing_params2}

Returns
-------
DataFrame
"""
_stata_reader_doc = f"""\
Class for reading Stata dta files.

Parameters
----------
path_or_buf : path (string), buffer or path object
    string, path object (pathlib.Path or py._path.local.LocalPath) or object
    implementing a binary read() function.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_shared_docs["decompression_options"]}
{_shared_docs["storage_options"]}

{_reader_notes}
"""
_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]


stata_epoch: Final = datetime.datetime(1960, 1, 1)
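
# Illustrative note (added comment, not original code): a %td value is simply
# days elapsed since stata_epoch, so, e.g.,
#     stata_epoch + datetime.timedelta(days=366) -> datetime.datetime(1961, 1, 1)
# (1960 is a leap year). The helpers below vectorize this kind of conversion
# for every supported format.
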
# TODO: Add typing. As of January 2020 it is not possible to type this function
# since mypy doesn't understand that a Series and an int can be combined using
# mathematical operations. (+, -).
def _stata_elapsed_date_to_datetime_vec(dates, fmt) -> Series:
    """
    Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        The Stata Internal Format date to convert to datetime according to fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty

    Returns
    -------
    converted : Series
        The converted dates

    Examples
    --------
    >>> dates = pd.Series([52])
    >>> _stata_elapsed_date_to_datetime_vec(dates, "%tw")
    0   1961-01-01
    dtype: datetime64[ns]

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1
    yearly date - ty
        years since 0000
    """
    MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year
    MAX_DAY_DELTA = (Timestamp.max - datetime.datetime(1960, 1, 1)).days
    MIN_DAY_DELTA = (Timestamp.min - datetime.datetime(1960, 1, 1)).days
    MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000
    MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000
    def convert_year_month_safe(year, month) -> Series:
        """
        Convert year and month to datetimes, using pandas vectorized versions
        when the date range falls within the range supported by pandas.
        Otherwise it falls back to a slower but more robust method
        using datetime.
        """
        if year.max() < MAX_YEAR and year.min() > MIN_YEAR:
            return to_datetime(100 * year + month, format="%Y%m")
        else:
            index = getattr(year, "index", None)
            return Series(
                [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index
            )
    def convert_year_days_safe(year, days) -> Series:
        """
        Converts year (e.g. 1999) and days since the start of the year to a
        datetime or datetime64 Series
        """
        if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR:
            return to_datetime(year, format="%Y") + to_timedelta(days, unit="d")
        else:
            index = getattr(year, "index", None)
            value = [
                datetime.datetime(y, 1, 1) + relativedelta(days=int(d))
                for y, d in zip(year, days)
            ]
            return Series(value, index=index)
    def convert_delta_safe(base, deltas, unit) -> Series:
        """
        Convert base dates and deltas to datetimes, using pandas vectorized
        versions if the deltas satisfy restrictions required to be expressed
        as dates in pandas.
        """
        index = getattr(deltas, "index", None)
        if unit == "d":
            if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA:
                values = [base + relativedelta(days=int(d)) for d in deltas]
                return Series(values, index=index)
        elif unit == "ms":
            if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA:
                values = [
                    base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas
                ]
                return Series(values, index=index)
        else:
            raise ValueError("format not understood")
        base = to_datetime(base)
        deltas = to_timedelta(deltas, unit=unit)
        return base + deltas
    # TODO(non-nano): If/when pandas supports more than datetime64[ns], this
    # should be improved to use correct range, e.g. datetime[Y] for yearly
    bad_locs = np.isnan(dates)
    has_bad_values = False
    if bad_locs.any():
        has_bad_values = True
        # reset cache to avoid SettingWithCopy checks (we own the DataFrame and
        # the `dates` Series is used to overwrite itself in the DataFrame)
        dates._reset_cacher()
        dates[bad_locs] = 1.0  # Replace with NaT
    dates = dates.astype(np.int64)
    if fmt.startswith(("%tc", "tc")):  # Delta ms relative to base
        base = stata_epoch
        ms = dates
        conv_dates = convert_delta_safe(base, ms, "ms")
    elif fmt.startswith(("%tC", "tC")):
        warnings.warn(
            "Encountered %tC format. Leaving in Stata Internal Format.",
            stacklevel=find_stack_level(),
        )
        conv_dates = Series(dates, dtype=object)
        if has_bad_values:
            conv_dates[bad_locs] = NaT
        return conv_dates
    # Delta days relative to base
    elif fmt.startswith(("%td", "td", "%d", "d")):
        base = stata_epoch
        days = dates
        conv_dates = convert_delta_safe(base, days, "d")
    # does not count leap days - 7 days is a week.
    # 52nd week may have more than 7 days
    elif fmt.startswith(("%tw", "tw")):
        year = stata_epoch.year + dates // 52
        days = (dates % 52) * 7
        conv_dates = convert_year_days_safe(year, days)
    elif fmt.startswith(("%tm", "tm")):  # Delta months relative to base
        year = stata_epoch.year + dates // 12
        month = (dates % 12) + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
        year = stata_epoch.year + dates // 4
        quarter_month = (dates % 4) * 3 + 1
        conv_dates = convert_year_month_safe(year, quarter_month)
    elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
        year = stata_epoch.year + dates // 2
        month = (dates % 2) * 6 + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
        year = dates
        first_month = np.ones_like(dates)
        conv_dates = convert_year_month_safe(year, first_month)
    else:
        raise ValueError(f"Date fmt {fmt} not understood")

    if has_bad_values:  # Restore NaT for bad values
        conv_dates[bad_locs] = NaT

    return conv_dates
def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
    """
    Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        Series or array containing datetime.datetime or datetime64[ns] to
        convert to the Stata Internal Format given by fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty
    """
    index = dates.index
    NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000
    US_PER_DAY = NS_PER_DAY / 1000
    def parse_dates_safe(
        dates, delta: bool = False, year: bool = False, days: bool = False
    ):
        d = {}
        if is_datetime64_dtype(dates.dtype):
            if delta:
                time_delta = dates - Timestamp(stata_epoch).as_unit("ns")
                d["delta"] = time_delta._values.view(np.int64) // 1000  # microseconds
            if days or year:
                date_index = DatetimeIndex(dates)
                d["year"] = date_index._data.year
                d["month"] = date_index._data.month
            if days:
                days_in_ns = dates.view(np.int64) - to_datetime(
                    d["year"], format="%Y"
                ).view(np.int64)
                d["days"] = days_in_ns // NS_PER_DAY

        elif infer_dtype(dates, skipna=False) == "datetime":
            if delta:
                delta = dates._values - stata_epoch

                def f(x: datetime.timedelta) -> float:
                    return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds

                v = np.vectorize(f)
                d["delta"] = v(delta)
            if year:
                year_month = dates.apply(lambda x: 100 * x.year + x.month)
                d["year"] = year_month._values // 100
                d["month"] = year_month._values - d["year"] * 100
            if days:

                def g(x: datetime.datetime) -> int:
                    return (x - datetime.datetime(x.year, 1, 1)).days

                v = np.vectorize(g)
                d["days"] = v(dates)
        else:
            raise ValueError(
                "Columns containing dates must contain either "
                "datetime64, datetime.datetime or null values."
            )

        return DataFrame(d, index=index)
    bad_loc = isna(dates)
    index = dates.index
    if bad_loc.any():
        dates = Series(dates)
        if is_datetime64_dtype(dates):
            dates[bad_loc] = to_datetime(stata_epoch)
        else:
            dates[bad_loc] = stata_epoch

    if fmt in ["%tc", "tc"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta / 1000
    elif fmt in ["%tC", "tC"]:
        warnings.warn(
            "Stata Internal Format tC not supported.",
            stacklevel=find_stack_level(),
        )
        conv_dates = dates
    elif fmt in ["%td", "td"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta // US_PER_DAY
    elif fmt in ["%tw", "tw"]:
        d = parse_dates_safe(dates, year=True, days=True)
        conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7
    elif fmt in ["%tm", "tm"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1
    elif fmt in ["%tq", "tq"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3
    elif fmt in ["%th", "th"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int)
    elif fmt in ["%ty", "ty"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = d.year
    else:
        raise ValueError(f"Format {fmt} is not a known Stata date format")

    conv_dates = Series(conv_dates, dtype=np.float64)
    missing_value = struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
    conv_dates[bad_loc] = missing_value

    return Series(conv_dates, index=index)
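
# Illustrative round trip (added comment, a sketch rather than original code):
# the two converters invert each other, e.g. for a monthly series
#     _datetime_to_stata_elapsed_vec(Series([Timestamp("1962-04-01")]), "%tm")
# yields 27.0 (12 * 2 years + 3 months), and feeding 27 back through
# _stata_elapsed_date_to_datetime_vec(..., "%tm") recovers 1962-04-01.
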
excessive_string_length_error: Final = """
Fixed width strings in Stata .dta files are limited to 244 (or fewer)
characters. Column '{0}' does not satisfy this restriction. Use the
'version=117' parameter to write the newer (Stata 13 and later) format.
"""


precision_loss_doc: Final = """
Column converted from {0} to {1}, and some data are outside of the lossless
conversion range. This may result in a loss of precision in the saved data.
"""


value_label_mismatch_doc: Final = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""


invalid_name_doc: Final = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:

    {0}

If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only
alphanumerics and underscores, no Stata reserved words)
"""


categorical_conversion_warning: Final = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in a categorical variable with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read the dataset without an iterator, or manually convert the categorical data
by setting ``convert_categoricals`` to False and then accessing the variable
labels through the value_labels method of the reader.
"""
def _cast_to_stata_types(data: DataFrame) -> DataFrame:
    """
    Checks the dtypes of the columns of a pandas DataFrame for
    compatibility with the data types and ranges supported by Stata, and
    converts if necessary.

    Parameters
    ----------
    data : DataFrame
        The DataFrame to check and convert

    Notes
    -----
    Numeric columns in Stata must be one of int8, int16, int32, float32 or
    float64, with some additional value restrictions. int8 and int16 columns
    are checked for violations of the value restrictions and upcast if needed.
    int64 data is not usable in Stata, and so it is downcast to int32 whenever
    the values are in the int32 range, and sidecast to float64 when larger than
    this range. If the int64 values are outside of the range of those
    perfectly representable as float64 values, a warning is raised.

    bool columns are cast to int8. uint columns are converted to int of the
    same size if there is no loss in precision, otherwise are upcast to a
    larger type. uint64 is currently not supported since it is converted to
    object in a DataFrame.
    """
577 ws = ""
578 # original, if small, if large
579 conversion_data: tuple[
580 tuple[type, type, type],
581 tuple[type, type, type],
582 tuple[type, type, type],
583 tuple[type, type, type],
584 tuple[type, type, type],
585 ] = (
586 (np.bool_, np.int8, np.int8),
587 (np.uint8, np.int8, np.int16),
588 (np.uint16, np.int16, np.int32),
589 (np.uint32, np.int32, np.int64),
590 (np.uint64, np.int64, np.float64),
591 )
593 float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
594 float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
596 for col in data:
597 # Cast from unsupported types to supported types
598 is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype))
599 orig = data[col]
600 # We need to find orig_missing before altering data below
601 orig_missing = orig.isna()
602 if is_nullable_int:
603 missing_loc = data[col].isna()
604 if missing_loc.any():
605 # Replace with always safe value
606 fv = 0 if isinstance(data[col].dtype, IntegerDtype) else False
607 data.loc[missing_loc, col] = fv
608 # Replace with NumPy-compatible column
609 data[col] = data[col].astype(data[col].dtype.numpy_dtype)
610 dtype = data[col].dtype
611 for c_data in conversion_data:
612 if dtype == c_data[0]:
613 if data[col].max() <= np.iinfo(c_data[1]).max:
614 dtype = c_data[1]
615 else:
616 dtype = c_data[2]
617 if c_data[2] == np.int64: # Warn if necessary
618 if data[col].max() >= 2**53:
619 ws = precision_loss_doc.format("uint64", "float64")
621 data[col] = data[col].astype(dtype)
623 # Check values and upcast if necessary
624 if dtype == np.int8:
625 if data[col].max() > 100 or data[col].min() < -127:
626 data[col] = data[col].astype(np.int16)
627 elif dtype == np.int16:
628 if data[col].max() > 32740 or data[col].min() < -32767:
629 data[col] = data[col].astype(np.int32)
630 elif dtype == np.int64:
631 if data[col].max() <= 2147483620 and data[col].min() >= -2147483647:
632 data[col] = data[col].astype(np.int32)
633 else:
634 data[col] = data[col].astype(np.float64)
635 if data[col].max() >= 2**53 or data[col].min() <= -(2**53):
636 ws = precision_loss_doc.format("int64", "float64")
637 elif dtype in (np.float32, np.float64):
638 if np.isinf(data[col]).any():
639 raise ValueError(
640 f"Column {col} contains infinity or -infinity"
641 "which is outside the range supported by Stata."
642 )
643 value = data[col].max()
644 if dtype == np.float32 and value > float32_max:
645 data[col] = data[col].astype(np.float64)
646 elif dtype == np.float64:
647 if value > float64_max:
648 raise ValueError(
649 f"Column {col} has a maximum value ({value}) outside the range "
650 f"supported by Stata ({float64_max})"
651 )
652 if is_nullable_int:
653 if orig_missing.any():
654 # Replace missing by Stata sentinel value
655 sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
656 data.loc[orig_missing, col] = sentinel
657 if ws:
658 warnings.warn(
659 ws,
660 PossiblePrecisionLoss,
661 stacklevel=find_stack_level(),
662 )
664 return data
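
# Illustrative effect of _cast_to_stata_types (added comment, not original
# code): a bool column becomes int8; uint8 becomes int8, or int16 if any value
# exceeds np.iinfo(np.int8).max; int64 values fitting in
# [-2147483647, 2147483620] are downcast to int32, and anything larger is
# sidecast to float64, warning via PossiblePrecisionLoss past 2**53.
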
class StataValueLabel:
    """
    Parse a categorical column and prepare formatted output

    Parameters
    ----------
    catarray : Series
        Categorical Series to encode
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self, catarray: Series, encoding: Literal["latin-1", "utf-8"] = "latin-1"
    ) -> None:
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")
        self.labname = catarray.name
        self._encoding = encoding
        categories = catarray.cat.categories
        self.value_labels: list[tuple[float, str]] = list(
            zip(np.arange(len(categories)), categories)
        )
        self.value_labels.sort(key=lambda x: x[0])

        self._prepare_value_labels()
    def _prepare_value_labels(self):
        """Encode value labels."""

        self.text_len = 0
        self.txt: list[bytes] = []
        self.n = 0
        # Offsets (length of categories), converted to int32
        self.off = np.array([], dtype=np.int32)
        # Values, converted to int32
        self.val = np.array([], dtype=np.int32)
        self.len = 0

        # Compute lengths and setup lists of offsets and labels
        offsets: list[int] = []
        values: list[float] = []
        for vl in self.value_labels:
            category: str | bytes = vl[1]
            if not isinstance(category, str):
                category = str(category)
                warnings.warn(
                    value_label_mismatch_doc.format(self.labname),
                    ValueLabelTypeMismatch,
                    stacklevel=find_stack_level(),
                )
            category = category.encode(self._encoding)
            offsets.append(self.text_len)
            self.text_len += len(category) + 1  # +1 for the padding
            values.append(vl[0])
            self.txt.append(category)
            self.n += 1

        if self.text_len > 32000:
            raise ValueError(
                "Stata value labels for a single variable must "
                "have a combined length less than 32,000 characters."
            )

        # Ensure int32
        self.off = np.array(offsets, dtype=np.int32)
        self.val = np.array(values, dtype=np.int32)

        # Total length
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
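        # Layout note (added comment): 4 bytes each for n and textlen, then
        # the int32 offset and value arrays (4 * n bytes each), then the
        # null-terminated label text itself.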
    def generate_value_label(self, byteorder: str) -> bytes:
        """
        Generate the binary representation of the value labels.

        Parameters
        ----------
        byteorder : str
            Byte order of the output

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """
        encoding = self._encoding
        bio = BytesIO()
        null_byte = b"\x00"

        # len
        bio.write(struct.pack(byteorder + "i", self.len))

        # labname
        labname = str(self.labname)[:32].encode(encoding)
        lab_len = 32 if encoding not in ("utf-8", "utf8") else 128
        labname = _pad_bytes(labname, lab_len + 1)
        bio.write(labname)

        # padding - 3 bytes
        for i in range(3):
            bio.write(struct.pack("c", null_byte))

        # value_label_table
        # n - int32
        bio.write(struct.pack(byteorder + "i", self.n))

        # textlen - int32
        bio.write(struct.pack(byteorder + "i", self.text_len))

        # off - int32 array (n elements)
        for offset in self.off:
            bio.write(struct.pack(byteorder + "i", offset))

        # val - int32 array (n elements)
        for value in self.val:
            bio.write(struct.pack(byteorder + "i", value))

        # txt - Text labels, null terminated
        for text in self.txt:
            bio.write(text + null_byte)

        return bio.getvalue()
class StataNonCatValueLabel(StataValueLabel):
    """
    Prepare formatted version of value labels

    Parameters
    ----------
    labname : str
        Value label name
    value_labels : dict
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ) -> None:
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding
        self.value_labels: list[tuple[float, str]] = sorted(
            value_labels.items(), key=lambda x: x[0]
        )
        self._prepare_value_labels()
class StataMissingValue:
    """
    An observation's missing value.

    Parameters
    ----------
    value : {int, float}
        The Stata missing value code

    Notes
    -----
    More information: <https://www.stata.com/help.cgi?missing>

    Integer missing values map the codes '.', '.a', ..., '.z' to the ranges
    101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
    2147483647 (for int32). Missing values for floating point data types are
    more complex but the pattern is simple to discern from the following table.

    np.float32 missing values (float in Stata)
    0000007f    .
    0008007f    .a
    0010007f    .b
    ...
    00c0007f    .x
    00c8007f    .y
    00d0007f    .z

    np.float64 missing values (double in Stata)
    000000000000e07f    .
    000000000001e07f    .a
    000000000002e07f    .b
    ...
    000000000018e07f    .x
    000000000019e07f    .y
    00000000001ae07f    .z
    """
    # Construct a dictionary of missing values
    MISSING_VALUES: dict[float, str] = {}
    bases: Final = (101, 32741, 2147483621)
    for b in bases:
        # Conversion to long to avoid hash issues on 32 bit platforms #8968
        MISSING_VALUES[b] = "."
        for i in range(1, 27):
            MISSING_VALUES[i + b] = "." + chr(96 + i)

    float32_base: bytes = b"\x00\x00\x00\x7f"
    increment: int = struct.unpack("<i", b"\x00\x08\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<f", float32_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment
        float32_base = struct.pack("<i", int_value)

    float64_base: bytes = b"\x00\x00\x00\x00\x00\x00\xe0\x7f"
    increment = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<d", float64_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment
        float64_base = struct.pack("q", int_value)

    BASE_MISSING_VALUES: Final = {
        "int8": 101,
        "int16": 32741,
        "int32": 2147483621,
        "float32": struct.unpack("<f", float32_base)[0],
        "float64": struct.unpack("<d", float64_base)[0],
    }
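
    # Illustrative note (added comment, not original code): with the tables
    # above, StataMissingValue(101).string == "." and
    # StataMissingValue(102).string == ".a" for int8 data.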
    def __init__(self, value: float) -> None:
        self._value = value
        # Conversion to int to avoid hash issues on 32 bit platforms #8968
        value = int(value) if value < 2147483648 else float(value)
        self._str = self.MISSING_VALUES[value]

    @property
    def string(self) -> str:
        """
        The Stata representation of the missing value: '.', '.a'..'.z'

        Returns
        -------
        str
            The representation of the missing value.
        """
        return self._str

    @property
    def value(self) -> float:
        """
        The binary representation of the missing value.

        Returns
        -------
        {int, float}
            The binary representation of the missing value.
        """
        return self._value

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return f"{type(self)}({self})"

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, type(self))
            and self.string == other.string
            and self.value == other.value
        )

    @classmethod
    def get_base_missing_value(cls, dtype: np.dtype) -> float:
        if dtype.type is np.int8:
            value = cls.BASE_MISSING_VALUES["int8"]
        elif dtype.type is np.int16:
            value = cls.BASE_MISSING_VALUES["int16"]
        elif dtype.type is np.int32:
            value = cls.BASE_MISSING_VALUES["int32"]
        elif dtype.type is np.float32:
            value = cls.BASE_MISSING_VALUES["float32"]
        elif dtype.type is np.float64:
            value = cls.BASE_MISSING_VALUES["float64"]
        else:
            raise ValueError("Unsupported dtype")
        return value
class StataParser:
    def __init__(self) -> None:
        # type          code.
        # --------------------
        # str1        1 = 0x01
        # str2        2 = 0x02
        # ...
        # str244    244 = 0xf4
        # byte      251 = 0xfb  (sic)
        # int       252 = 0xfc
        # long      253 = 0xfd
        # float     254 = 0xfe
        # double    255 = 0xff
        # --------------------
        # NOTE: the byte type seems to be reserved for categorical variables
        # with a label, but the underlying variable is -127 to 100
        # we're going to drop the label and cast to int
        self.DTYPE_MAP = dict(
            list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
            + [
                (251, np.dtype(np.int8)),
                (252, np.dtype(np.int16)),
                (253, np.dtype(np.int32)),
                (254, np.dtype(np.float32)),
                (255, np.dtype(np.float64)),
            ]
        )
        self.DTYPE_MAP_XML: dict[int, np.dtype] = {
            32768: np.dtype(np.uint8),  # Keys to GSO
            65526: np.dtype(np.float64),
            65527: np.dtype(np.float32),
            65528: np.dtype(np.int32),
            65529: np.dtype(np.int16),
            65530: np.dtype(np.int8),
        }
        self.TYPE_MAP = list(tuple(range(251)) + tuple("bhlfd"))
        self.TYPE_MAP_XML = {
            # Not really a Q, unclear how to handle byteswap
            32768: "Q",
            65526: "d",
            65527: "f",
            65528: "l",
            65529: "h",
            65530: "b",
        }
        # NOTE: technically, some of these are wrong. there are more numbers
        # that can be represented. it's the 27 ABOVE and BELOW the max listed
        # numeric data type in [U] 12.2.2 of the 11.2 manual
        float32_min = b"\xff\xff\xff\xfe"
        float32_max = b"\xff\xff\xff\x7e"
        float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff"
        float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f"
        self.VALID_RANGE = {
            "b": (-127, 100),
            "h": (-32767, 32740),
            "l": (-2147483647, 2147483620),
            "f": (
                np.float32(struct.unpack("<f", float32_min)[0]),
                np.float32(struct.unpack("<f", float32_max)[0]),
            ),
            "d": (
                np.float64(struct.unpack("<d", float64_min)[0]),
                np.float64(struct.unpack("<d", float64_max)[0]),
            ),
        }

        self.OLD_TYPE_MAPPING = {
            98: 251,  # byte
            105: 252,  # int
            108: 253,  # long
            102: 254,  # float
            100: 255,  # double
        }
        # These missing values are the generic '.' in Stata, and are used
        # to replace nans
        self.MISSING_VALUES = {
            "b": 101,
            "h": 32741,
            "l": 2147483621,
            "f": np.float32(struct.unpack("<f", b"\x00\x00\x00\x7f")[0]),
            "d": np.float64(
                struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
            ),
        }
        self.NUMPY_TYPE_MAP = {
            "b": "i1",
            "h": "i2",
            "l": "i4",
            "f": "f4",
            "d": "f8",
            "Q": "u8",
        }
        # Reserved words cannot be used as variable names
        self.RESERVED_WORDS = (
            "aggregate",
            "array",
            "boolean",
            "break",
            "byte",
            "case",
            "catch",
            "class",
            "colvector",
            "complex",
            "const",
            "continue",
            "default",
            "delegate",
            "delete",
            "do",
            "double",
            "else",
            "eltypedef",
            "end",
            "enum",
            "explicit",
            "export",
            "external",
            "float",
            "for",
            "friend",
            "function",
            "global",
            "goto",
            "if",
            "inline",
            "int",
            "local",
            "long",
            "NULL",
            "pragma",
            "protected",
            "quad",
            "rowvector",
            "short",
            "typedef",
            "typename",
            "virtual",
            "_all",
            "_N",
            "_skip",
            "_b",
            "_pi",
            "str#",
            "in",
            "_pred",
            "strL",
            "_coef",
            "_rc",
            "using",
            "_cons",
            "_se",
            "with",
            "_n",
        )
class StataReader(StataParser, abc.Iterator):
    __doc__ = _stata_reader_doc

    _path_or_buf: IO[bytes]  # Will be assigned by `_open_file`.

    def __init__(
        self,
        path_or_buf: FilePath | ReadBuffer[bytes],
        convert_dates: bool = True,
        convert_categoricals: bool = True,
        index_col: str | None = None,
        convert_missing: bool = False,
        preserve_dtypes: bool = True,
        columns: Sequence[str] | None = None,
        order_categoricals: bool = True,
        chunksize: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
    ) -> None:
        super().__init__()
        self._col_sizes: list[int] = []

        # Arguments to the reader (can be temporarily overridden in
        # calls to read).
        self._convert_dates = convert_dates
        self._convert_categoricals = convert_categoricals
        self._index_col = index_col
        self._convert_missing = convert_missing
        self._preserve_dtypes = preserve_dtypes
        self._columns = columns
        self._order_categoricals = order_categoricals
        self._original_path_or_buf = path_or_buf
        self._compression = compression
        self._storage_options = storage_options
        self._encoding = ""
        self._chunksize = chunksize
        self._using_iterator = False
        self._entered = False
        if self._chunksize is None:
            self._chunksize = 1
        elif not isinstance(chunksize, int) or chunksize <= 0:
            raise ValueError("chunksize must be a positive integer when set.")

        # State variables for the file
        self._close_file: Callable[[], None] | None = None
        self._has_string_data = False
        self._missing_values = False
        self._can_read_value_labels = False
        self._column_selector_set = False
        self._value_labels_read = False
        self._data_read = False
        self._dtype: np.dtype | None = None
        self._lines_read = 0

        self._native_byteorder = _set_endianness(sys.byteorder)
    def _ensure_open(self) -> None:
        """
        Ensure the file has been opened and its header data read.
        """
        if not hasattr(self, "_path_or_buf"):
            self._open_file()

    def _open_file(self) -> None:
        """
        Open the file (with compression options, etc.), and read header information.
        """
        if not self._entered:
            warnings.warn(
                "StataReader is being used without using a context manager. "
                "Using StataReader as a context manager is the only supported method.",
                ResourceWarning,
                stacklevel=find_stack_level(),
            )
        handles = get_handle(
            self._original_path_or_buf,
            "rb",
            storage_options=self._storage_options,
            is_text=False,
            compression=self._compression,
        )
        if hasattr(handles.handle, "seekable") and handles.handle.seekable():
            # If the handle is directly seekable, use it without an extra copy.
            self._path_or_buf = handles.handle
            self._close_file = handles.close
        else:
            # Copy to memory, and ensure no encoding.
            with handles:
                self._path_or_buf = BytesIO(handles.handle.read())
            self._close_file = self._path_or_buf.close

        self._read_header()
        self._setup_dtype()
    def __enter__(self) -> StataReader:
        """enter context manager"""
        self._entered = True
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if self._close_file:
            self._close_file()

    def close(self) -> None:
        """Close the handle if it is open.

        .. deprecated:: 2.0.0

            The close method is not part of the public API.
            The only supported way to use StataReader is to use it as a context manager.
        """
        warnings.warn(
            "The StataReader.close() method is not part of the public API and "
            "will be removed in a future version without notice. "
            "Using StataReader as a context manager is the only supported method.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        if self._close_file:
            self._close_file()
    def _set_encoding(self) -> None:
        """
        Set string encoding which depends on file version
        """
        if self._format_version < 118:
            self._encoding = "latin-1"
        else:
            self._encoding = "utf-8"

    def _read_int8(self) -> int:
        return struct.unpack("b", self._path_or_buf.read(1))[0]

    def _read_uint8(self) -> int:
        return struct.unpack("B", self._path_or_buf.read(1))[0]

    def _read_uint16(self) -> int:
        return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0]

    def _read_uint32(self) -> int:
        return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0]

    def _read_uint64(self) -> int:
        return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0]

    def _read_int16(self) -> int:
        return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0]

    def _read_int32(self) -> int:
        return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0]

    def _read_int64(self) -> int:
        return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0]

    def _read_char8(self) -> bytes:
        return struct.unpack("c", self._path_or_buf.read(1))[0]

    def _read_int16_count(self, count: int) -> tuple[int, ...]:
        return struct.unpack(
            f"{self._byteorder}{'h' * count}",
            self._path_or_buf.read(2 * count),
        )
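
    # Note (added comment): the format characters above follow the struct
    # module's conventions -- "b"/"B" for int8/uint8, "h"/"H" for int16/uint16,
    # "i"/"I" for int32/uint32, "q"/"Q" for int64/uint64 -- prefixed with the
    # file's byte order ("<" little endian, ">" big endian).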
    def _read_header(self) -> None:
        first_char = self._read_char8()
        if first_char == b"<":
            self._read_new_header()
        else:
            self._read_old_header(first_char)

        self._has_string_data = len([x for x in self._typlist if type(x) is int]) > 0

        # calculate size of a data record
        self._col_sizes = [self._calcsize(typ) for typ in self._typlist]
    def _read_new_header(self) -> None:
        # The first part of the header is common to 117 - 119.
        self._path_or_buf.read(27)  # stata_dta><header><release>
        self._format_version = int(self._path_or_buf.read(3))
        if self._format_version not in [117, 118, 119]:
            raise ValueError(_version_error.format(version=self._format_version))
        self._set_encoding()
        self._path_or_buf.read(21)  # </release><byteorder>
        self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<"
        self._path_or_buf.read(15)  # </byteorder><K>
        self._nvar = (
            self._read_uint16() if self._format_version <= 118 else self._read_uint32()
        )
        self._path_or_buf.read(7)  # </K><N>

        self._nobs = self._get_nobs()
        self._path_or_buf.read(11)  # </N><label>
        self._data_label = self._get_data_label()
        self._path_or_buf.read(19)  # </label><timestamp>
        self._time_stamp = self._get_time_stamp()
        self._path_or_buf.read(26)  # </timestamp></header><map>
        self._path_or_buf.read(8)  # 0x0000000000000000
        self._path_or_buf.read(8)  # position of <map>

        self._seek_vartypes = self._read_int64() + 16
        self._seek_varnames = self._read_int64() + 10
        self._seek_sortlist = self._read_int64() + 10
        self._seek_formats = self._read_int64() + 9
        self._seek_value_label_names = self._read_int64() + 19

        # Requires version-specific treatment
        self._seek_variable_labels = self._get_seek_variable_labels()

        self._path_or_buf.read(8)  # <characteristics>
        self._data_location = self._read_int64() + 6
        self._seek_strls = self._read_int64() + 7
        self._seek_value_labels = self._read_int64() + 14

        self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes)

        self._path_or_buf.seek(self._seek_varnames)
        self._varlist = self._get_varlist()

        self._path_or_buf.seek(self._seek_sortlist)
        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

        self._path_or_buf.seek(self._seek_formats)
        self._fmtlist = self._get_fmtlist()

        self._path_or_buf.seek(self._seek_value_label_names)
        self._lbllist = self._get_lbllist()

        self._path_or_buf.seek(self._seek_variable_labels)
        self._variable_labels = self._get_variable_labels()
    # Get data type information, works for versions 117-119.
    def _get_dtypes(
        self, seek_vartypes: int
    ) -> tuple[list[int | str], list[str | np.dtype]]:
        self._path_or_buf.seek(seek_vartypes)
        raw_typlist = [self._read_uint16() for _ in range(self._nvar)]

        def f(typ: int) -> int | str:
            if typ <= 2045:
                return typ
            try:
                return self.TYPE_MAP_XML[typ]
            except KeyError as err:
                raise ValueError(f"cannot convert stata types [{typ}]") from err

        typlist = [f(x) for x in raw_typlist]

        def g(typ: int) -> str | np.dtype:
            if typ <= 2045:
                return str(typ)
            try:
                return self.DTYPE_MAP_XML[typ]
            except KeyError as err:
                raise ValueError(f"cannot convert stata dtype [{typ}]") from err

        dtyplist = [g(x) for x in raw_typlist]

        return typlist, dtyplist
    def _get_varlist(self) -> list[str]:
        # 33 in older formats, 129 in formats 118 and 119
        b = 33 if self._format_version < 118 else 129
        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    # Returns the format list
    def _get_fmtlist(self) -> list[str]:
        if self._format_version >= 118:
            b = 57
        elif self._format_version > 113:
            b = 49
        elif self._format_version > 104:
            b = 12
        else:
            b = 7

        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    # Returns the label list
    def _get_lbllist(self) -> list[str]:
        if self._format_version >= 118:
            b = 129
        elif self._format_version > 108:
            b = 33
        else:
            b = 9
        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

    def _get_variable_labels(self) -> list[str]:
        if self._format_version >= 118:
            vlblist = [
                self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar)
            ]
        elif self._format_version > 105:
            vlblist = [
                self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar)
            ]
        else:
            vlblist = [
                self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar)
            ]
        return vlblist
    def _get_nobs(self) -> int:
        if self._format_version >= 118:
            return self._read_uint64()
        else:
            return self._read_uint32()

    def _get_data_label(self) -> str:
        if self._format_version >= 118:
            strlen = self._read_uint16()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version == 117:
            strlen = self._read_int8()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version > 105:
            return self._decode(self._path_or_buf.read(81))
        else:
            return self._decode(self._path_or_buf.read(32))

    def _get_time_stamp(self) -> str:
        if self._format_version >= 118:
            strlen = self._read_int8()
            return self._path_or_buf.read(strlen).decode("utf-8")
        elif self._format_version == 117:
            strlen = self._read_int8()
            return self._decode(self._path_or_buf.read(strlen))
        elif self._format_version > 104:
            return self._decode(self._path_or_buf.read(18))
        else:
            raise ValueError()

    def _get_seek_variable_labels(self) -> int:
        if self._format_version == 117:
            self._path_or_buf.read(8)  # <variable_labels>, throw away
            # Stata 117 data files do not follow the described format. This is
            # a work around that uses the previous label, 33 bytes for each
            # variable, 20 for the closing tag and 17 for the opening tag
            return self._seek_value_label_names + (33 * self._nvar) + 20 + 17
        elif self._format_version >= 118:
            return self._read_int64() + 17
        else:
            raise ValueError()
    def _read_old_header(self, first_char: bytes) -> None:
        self._format_version = int(first_char[0])
        if self._format_version not in [104, 105, 108, 111, 113, 114, 115]:
            raise ValueError(_version_error.format(version=self._format_version))
        self._set_encoding()
        self._byteorder = ">" if self._read_int8() == 0x1 else "<"
        self._filetype = self._read_int8()
        self._path_or_buf.read(1)  # unused

        self._nvar = self._read_uint16()
        self._nobs = self._get_nobs()

        self._data_label = self._get_data_label()

        self._time_stamp = self._get_time_stamp()

        # descriptors
        if self._format_version > 108:
            typlist = [int(c) for c in self._path_or_buf.read(self._nvar)]
        else:
            buf = self._path_or_buf.read(self._nvar)
            typlistb = np.frombuffer(buf, dtype=np.uint8)
            typlist = []
            for tp in typlistb:
                if tp in self.OLD_TYPE_MAPPING:
                    typlist.append(self.OLD_TYPE_MAPPING[tp])
                else:
                    typlist.append(tp - 127)  # bytes

        try:
            self._typlist = [self.TYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_types = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
        try:
            self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_dtypes = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err

        if self._format_version > 108:
            self._varlist = [
                self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar)
            ]
        else:
            self._varlist = [
                self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar)
            ]
        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

        self._fmtlist = self._get_fmtlist()

        self._lbllist = self._get_lbllist()

        self._variable_labels = self._get_variable_labels()

        # ignore expansion fields (Format 105 and later)
        # When reading, read five bytes; the last four bytes now tell you
        # the size of the next read, which you discard. You then continue
        # like this until you read 5 bytes of zeros.

        if self._format_version > 104:
            while True:
                data_type = self._read_int8()
                if self._format_version > 108:
                    data_len = self._read_int32()
                else:
                    data_len = self._read_int16()
                if data_type == 0:
                    break
                self._path_or_buf.read(data_len)

        # necessary data to continue parsing
        self._data_location = self._path_or_buf.tell()
    def _setup_dtype(self) -> np.dtype:
        """Map between numpy and stata dtypes"""
        if self._dtype is not None:
            return self._dtype

        dtypes = []  # Convert struct data types to numpy data type
        for i, typ in enumerate(self._typlist):
            if typ in self.NUMPY_TYPE_MAP:
                typ = cast(str, typ)  # only strs in NUMPY_TYPE_MAP
                dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}"))
            else:
                dtypes.append((f"s{i}", f"S{typ}"))
        self._dtype = np.dtype(dtypes)

        return self._dtype
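
    # Illustrative note (added comment, not original code): for a file holding
    # one float64 and one str4 column on a little-endian machine, the mapping
    # above yields np.dtype([("s0", "<f8"), ("s1", "S4")]), which lets
    # np.frombuffer view each fixed-width record in read() without copying.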
    def _calcsize(self, fmt: int | str) -> int:
        if isinstance(fmt, int):
            return fmt
        return struct.calcsize(self._byteorder + fmt)
    def _decode(self, s: bytes) -> str:
        # have bytes not strings, so must decode
        s = s.partition(b"\0")[0]
        try:
            return s.decode(self._encoding)
        except UnicodeDecodeError:
            # GH 25960, fallback to handle incorrect format produced when 117
            # files are converted to 118 files in Stata
            encoding = self._encoding
            msg = f"""
One or more strings in the dta file could not be decoded using {encoding}, and
so the fallback encoding of latin-1 is being used. This can happen when a file
has been incorrectly encoded by Stata or some other software. You should verify
the string values returned are correct."""
            warnings.warn(
                msg,
                UnicodeWarning,
                stacklevel=find_stack_level(),
            )
            return s.decode("latin-1")
    def _read_value_labels(self) -> None:
        self._ensure_open()
        if self._value_labels_read:
            # Don't read twice
            return
        if self._format_version <= 108:
            # Value labels are not supported in version 108 and earlier.
            self._value_labels_read = True
            self._value_label_dict: dict[str, dict[float, str]] = {}
            return

        if self._format_version >= 117:
            self._path_or_buf.seek(self._seek_value_labels)
        else:
            assert self._dtype is not None
            offset = self._nobs * self._dtype.itemsize
            self._path_or_buf.seek(self._data_location + offset)

        self._value_labels_read = True
        self._value_label_dict = {}

        while True:
            if self._format_version >= 117:
                if self._path_or_buf.read(5) == b"</val":  # <lbl>
                    break  # end of value label table

            slength = self._path_or_buf.read(4)
            if not slength:
                break  # end of value label table (format < 117)
            if self._format_version <= 117:
                labname = self._decode(self._path_or_buf.read(33))
            else:
                labname = self._decode(self._path_or_buf.read(129))
            self._path_or_buf.read(3)  # padding

            n = self._read_uint32()
            txtlen = self._read_uint32()
            off = np.frombuffer(
                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
            )
            val = np.frombuffer(
                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
            )
            ii = np.argsort(off)
            off = off[ii]
            val = val[ii]
            txt = self._path_or_buf.read(txtlen)
            self._value_label_dict[labname] = {}
            for i in range(n):
                end = off[i + 1] if i < n - 1 else txtlen
                self._value_label_dict[labname][val[i]] = self._decode(
                    txt[off[i] : end]
                )
            if self._format_version >= 117:
                self._path_or_buf.read(6)  # </lbl>
        self._value_labels_read = True
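
    # Illustrative note (added comment): after _read_value_labels,
    # _value_label_dict maps label names to {value: label} dicts, e.g.
    # {"sexlbl": {0: "male", 1: "female"}} for a hypothetical survey file.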
    def _read_strls(self) -> None:
        self._path_or_buf.seek(self._seek_strls)
        # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
        self.GSO = {"0": ""}
        while True:
            if self._path_or_buf.read(3) != b"GSO":
                break

            if self._format_version == 117:
                v_o = self._read_uint64()
            else:
                buf = self._path_or_buf.read(12)
                # Only tested on little endian file on little endian machine.
                v_size = 2 if self._format_version == 118 else 3
                if self._byteorder == "<":
                    buf = buf[0:v_size] + buf[4 : (12 - v_size)]
                else:
                    # This path may not be correct, impossible to test
                    buf = buf[0:v_size] + buf[(4 + v_size) :]
                v_o = struct.unpack("Q", buf)[0]
            typ = self._read_uint8()
            length = self._read_uint32()
            va = self._path_or_buf.read(length)
            if typ == 130:
                decoded_va = va[0:-1].decode(self._encoding)
            else:
                # Stata says typ 129 can be binary, so use str
                decoded_va = str(va)
            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
            self.GSO[str(v_o)] = decoded_va
    def __next__(self) -> DataFrame:
        self._using_iterator = True
        return self.read(nrows=self._chunksize)

    def get_chunk(self, size: int | None = None) -> DataFrame:
        """
        Reads lines from Stata file and returns as dataframe

        Parameters
        ----------
        size : int, defaults to None
            Number of lines to read. If None, reads whole file.

        Returns
        -------
        DataFrame
        """
        if size is None:
            size = self._chunksize
        return self.read(nrows=size)
    @Appender(_read_method_doc)
    def read(
        self,
        nrows: int | None = None,
        convert_dates: bool | None = None,
        convert_categoricals: bool | None = None,
        index_col: str | None = None,
        convert_missing: bool | None = None,
        preserve_dtypes: bool | None = None,
        columns: Sequence[str] | None = None,
        order_categoricals: bool | None = None,
    ) -> DataFrame:
        self._ensure_open()
        # Handle empty file or chunk. If reading incrementally raise
        # StopIteration. If reading the whole thing return an empty
        # data frame.
        if (self._nobs == 0) and (nrows is None):
            self._can_read_value_labels = True
            self._data_read = True
            return DataFrame(columns=self._varlist)

        # Handle options
        if convert_dates is None:
            convert_dates = self._convert_dates
        if convert_categoricals is None:
            convert_categoricals = self._convert_categoricals
        if convert_missing is None:
            convert_missing = self._convert_missing
        if preserve_dtypes is None:
            preserve_dtypes = self._preserve_dtypes
        if columns is None:
            columns = self._columns
        if order_categoricals is None:
            order_categoricals = self._order_categoricals
        if index_col is None:
            index_col = self._index_col

        if nrows is None:
            nrows = self._nobs

        if (self._format_version >= 117) and (not self._value_labels_read):
            self._can_read_value_labels = True
            self._read_strls()

        # Read data
        assert self._dtype is not None
        dtype = self._dtype
        max_read_len = (self._nobs - self._lines_read) * dtype.itemsize
        read_len = nrows * dtype.itemsize
        read_len = min(read_len, max_read_len)
        if read_len <= 0:
            # Iterator has finished, should never be here unless
            # we are reading the file incrementally
            if convert_categoricals:
                self._read_value_labels()
            raise StopIteration
        offset = self._lines_read * dtype.itemsize
        self._path_or_buf.seek(self._data_location + offset)
        read_lines = min(nrows, self._nobs - self._lines_read)
        raw_data = np.frombuffer(
            self._path_or_buf.read(read_len), dtype=dtype, count=read_lines
        )

        self._lines_read += read_lines
        if self._lines_read == self._nobs:
            self._can_read_value_labels = True
            self._data_read = True
        # if necessary, swap the byte order to native here
        if self._byteorder != self._native_byteorder:
            raw_data = raw_data.byteswap().newbyteorder()
        if convert_categoricals:
            self._read_value_labels()

        if len(raw_data) == 0:
            data = DataFrame(columns=self._varlist)
        else:
            data = DataFrame.from_records(raw_data)
            data.columns = Index(self._varlist)

        # If index is not specified, use actual row number rather than
        # restarting at 0 for each chunk.
        if index_col is None:
            rng = range(self._lines_read - read_lines, self._lines_read)
            data.index = Index(rng)  # set attr instead of set_index to avoid copy

        if columns is not None:
            data = self._do_select_columns(data, columns)

        # Decode strings
        for col, typ in zip(data, self._typlist):
            if type(typ) is int:
                data[col] = data[col].apply(self._decode, convert_dtype=True)

        data = self._insert_strls(data)

        cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0]
        # Convert columns (if needed) to match input type
        ix = data.index
        requires_type_conversion = False
        data_formatted = []
        for i in cols_:
            if self._dtyplist[i] is not None:
                col = data.columns[i]
                dtype = data[col].dtype
                if dtype != np.dtype(object) and dtype != self._dtyplist[i]:
                    requires_type_conversion = True
                    data_formatted.append(
                        (col, Series(data[col], ix, self._dtyplist[i]))
                    )
                else:
                    data_formatted.append((col, data[col]))
        if requires_type_conversion:
            data = DataFrame.from_dict(dict(data_formatted))
        del data_formatted

        data = self._do_convert_missing(data, convert_missing)

        if convert_dates:

            def any_startswith(x: str) -> bool:
                return any(x.startswith(fmt) for fmt in _date_formats)

            cols = np.where([any_startswith(x) for x in self._fmtlist])[0]
            for i in cols:
                col = data.columns[i]
                data[col] = _stata_elapsed_date_to_datetime_vec(
                    data[col], self._fmtlist[i]
                )

        if convert_categoricals and self._format_version > 108:
            data = self._do_convert_categoricals(
                data, self._value_label_dict, self._lbllist, order_categoricals
            )

        if not preserve_dtypes:
            retyped_data = []
            convert = False
            for col in data:
                dtype = data[col].dtype
                if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
                    dtype = np.dtype(np.float64)
                    convert = True
                elif dtype in (
                    np.dtype(np.int8),
                    np.dtype(np.int16),
                    np.dtype(np.int32),
                ):
                    dtype = np.dtype(np.int64)
                    convert = True
                retyped_data.append((col, data[col].astype(dtype)))
            if convert:
                data = DataFrame.from_dict(dict(retyped_data))

        if index_col is not None:
            data = data.set_index(data.pop(index_col))

        return data
1849 def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
1850 # Check for missing values, and replace if found
1851 replacements = {}
1852 for i, colname in enumerate(data):
1853 fmt = self._typlist[i]
1854 if fmt not in self.VALID_RANGE:
1855 continue
1857 fmt = cast(str, fmt) # only strs in VALID_RANGE
1858 nmin, nmax = self.VALID_RANGE[fmt]
1859 series = data[colname]
1861 # appreciably faster to do this with ndarray instead of Series
1862 svals = series._values
1863 missing = (svals < nmin) | (svals > nmax)
1865 if not missing.any():
1866 continue
1868 if convert_missing: # Replacement follows Stata notation
1869 missing_loc = np.nonzero(np.asarray(missing))[0]
1870 umissing, umissing_loc = np.unique(series[missing], return_inverse=True)
1871 replacement = Series(series, dtype=object)
1872 for j, um in enumerate(umissing):
1873 missing_value = StataMissingValue(um)
1875 loc = missing_loc[umissing_loc == j]
1876 replacement.iloc[loc] = missing_value
1877 else: # All replacements are identical
1878 dtype = series.dtype
1879 if dtype not in (np.float32, np.float64):
1880 dtype = np.float64
1881 replacement = Series(series, dtype=dtype)
1882 if not replacement._values.flags["WRITEABLE"]:
1883 # only relevant for ArrayManager; construction
1884 # path for BlockManager ensures writeability
1885 replacement = replacement.copy()
1886 # Note: operating on ._values is much faster than operating on the Series directly
1887 # TODO: can we fix that?
1888 replacement._values[missing] = np.nan
1889 replacements[colname] = replacement
1891 if replacements:
1892 for col, value in replacements.items():
1893 data[col] = value
1894 return data
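# A hedged usage sketch (not part of the original source) of how the
# convert_missing flag above surfaces at the public API. With the default
# convert_missing=False, out-of-range Stata codes become NaN; with True,
# affected columns come back as object dtype holding StataMissingValue
# instances. The file name "survey.dta" is hypothetical.
# >>> import pandas as pd
# >>> df = pd.read_stata("survey.dta", convert_missing=True)  # doctest: +SKIP
# >>> isinstance(df.loc[0, "income"], pd.io.stata.StataMissingValue)  # doctest: +SKIP
# True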
1896 def _insert_strls(self, data: DataFrame) -> DataFrame:
1897 if not hasattr(self, "GSO") or len(self.GSO) == 0:
1898 return data
1899 for i, typ in enumerate(self._typlist):
1900 if typ != "Q":
1901 continue
1902 # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
1903 data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]]
1904 return data
1906 def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame:
1907 if not self._column_selector_set:
1908 column_set = set(columns)
1909 if len(column_set) != len(columns):
1910 raise ValueError("columns contains duplicate entries")
1911 unmatched = column_set.difference(data.columns)
1912 if unmatched:
1913 joined = ", ".join(list(unmatched))
1914 raise ValueError(
1915 "The following columns were not "
1916 f"found in the Stata data set: {joined}"
1917 )
1918 # Copy information for retained columns for later processing
1919 dtyplist = []
1920 typlist = []
1921 fmtlist = []
1922 lbllist = []
1923 for col in columns:
1924 i = data.columns.get_loc(col)
1925 dtyplist.append(self._dtyplist[i])
1926 typlist.append(self._typlist[i])
1927 fmtlist.append(self._fmtlist[i])
1928 lbllist.append(self._lbllist[i])
1930 self._dtyplist = dtyplist
1931 self._typlist = typlist
1932 self._fmtlist = fmtlist
1933 self._lbllist = lbllist
1934 self._column_selector_set = True
1936 return data[columns]
1938 def _do_convert_categoricals(
1939 self,
1940 data: DataFrame,
1941 value_label_dict: dict[str, dict[float, str]],
1942 lbllist: Sequence[str],
1943 order_categoricals: bool,
1944 ) -> DataFrame:
1945 """
1946 Converts categorical columns to Categorical type.
1947 """
1948 value_labels = list(value_label_dict.keys())
1949 cat_converted_data = []
1950 for col, label in zip(data, lbllist):
1951 if label in value_labels:
1952 # Explicit call with ordered=True
1953 vl = value_label_dict[label]
1954 keys = np.array(list(vl.keys()))
1955 column = data[col]
1956 key_matches = column.isin(keys)
1957 if self._using_iterator and key_matches.all():
1958 initial_categories: np.ndarray | None = keys
1959 # If all categories are in the keys and we are iterating,
1960 # use the same keys for all chunks. If some are missing
1961 # value labels, then we will fall back to the categories
1962 # varying across chunks.
1963 else:
1964 if self._using_iterator:
1965 # warn if using an iterator
1966 warnings.warn(
1967 categorical_conversion_warning,
1968 CategoricalConversionWarning,
1969 stacklevel=find_stack_level(),
1970 )
1971 initial_categories = None
1972 cat_data = Categorical(
1973 column, categories=initial_categories, ordered=order_categoricals
1974 )
1975 if initial_categories is None:
1976 # If None here, then we need to match the cats in the Categorical
1977 categories = []
1978 for category in cat_data.categories:
1979 if category in vl:
1980 categories.append(vl[category])
1981 else:
1982 categories.append(category)
1983 else:
1984 # If all cats are matched, we can use the values
1985 categories = list(vl.values())
1986 try:
1987 # Try to catch duplicate categories
1988 # TODO: if we get a non-copying rename_categories, use that
1989 cat_data = cat_data.rename_categories(categories)
1990 except ValueError as err:
1991 vc = Series(categories, copy=False).value_counts()
1992 repeated_cats = list(vc.index[vc > 1])
1993 repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
1994 # GH 25772
1995 msg = f"""
1996Value labels for column {col} are not unique. These cannot be converted to
1997pandas categoricals.
1999Either read the file with `convert_categoricals` set to False or use the
2000low level interface in `StataReader` to separately read the values and the
2001value_labels.
2003The repeated labels are:
2004{repeats}
2005"""
2006 raise ValueError(msg) from err
2007 # TODO: is the next line needed above in the data(...) method?
2008 cat_series = Series(cat_data, index=data.index, copy=False)
2009 cat_converted_data.append((col, cat_series))
2010 else:
2011 cat_converted_data.append((col, data[col]))
2012 data = DataFrame(dict(cat_converted_data), copy=False)
2013 return data
2015 @property
2016 def data_label(self) -> str:
2017 """
2018 Return data label of Stata file.
2019 """
2020 self._ensure_open()
2021 return self._data_label
2023 @property
2024 def time_stamp(self) -> str:
2025 """
2026 Return time stamp of Stata file.
2027 """
2028 self._ensure_open()
2029 return self._time_stamp
2031 def variable_labels(self) -> dict[str, str]:
2032 """
2033 Return a dict associating each variable name with corresponding label.
2035 Returns
2036 -------
2037 dict
2038 """
2039 self._ensure_open()
2040 return dict(zip(self._varlist, self._variable_labels))
2042 def value_labels(self) -> dict[str, dict[float, str]]:
2043 """
2044 Return a nested dict associating each variable name to its value and label.
2046 Returns
2047 -------
2048 dict
2049 """
2050 if not self._value_labels_read:
2051 self._read_value_labels()
2053 return self._value_label_dict
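# A hedged usage sketch of the low-level interface that the duplicate-label
# error message above recommends: read raw values and the label mappings
# separately instead of converting to Categorical. "survey.dta" is a
# hypothetical file name.
# >>> from pandas.io.stata import StataReader
# >>> with StataReader("survey.dta") as reader:  # doctest: +SKIP
# ...     df = reader.read(convert_categoricals=False)
# ...     labels = reader.value_labels()         # nested {name: {value: label}}
# ...     var_labels = reader.variable_labels()  # {varname: label}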
2056@Appender(_read_stata_doc)
2057def read_stata(
2058 filepath_or_buffer: FilePath | ReadBuffer[bytes],
2059 *,
2060 convert_dates: bool = True,
2061 convert_categoricals: bool = True,
2062 index_col: str | None = None,
2063 convert_missing: bool = False,
2064 preserve_dtypes: bool = True,
2065 columns: Sequence[str] | None = None,
2066 order_categoricals: bool = True,
2067 chunksize: int | None = None,
2068 iterator: bool = False,
2069 compression: CompressionOptions = "infer",
2070 storage_options: StorageOptions = None,
2071) -> DataFrame | StataReader:
2072 reader = StataReader(
2073 filepath_or_buffer,
2074 convert_dates=convert_dates,
2075 convert_categoricals=convert_categoricals,
2076 index_col=index_col,
2077 convert_missing=convert_missing,
2078 preserve_dtypes=preserve_dtypes,
2079 columns=columns,
2080 order_categoricals=order_categoricals,
2081 chunksize=chunksize,
2082 storage_options=storage_options,
2083 compression=compression,
2084 )
2086 if iterator or chunksize:
2087 return reader
2089 with reader:
2090 return reader.read()
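# A hedged usage sketch: as the branch above shows, passing chunksize (or
# iterator=True) makes read_stata return the StataReader itself, which can
# then be iterated to yield DataFrames of at most chunksize rows. "big.dta"
# and process() are placeholders, not part of this module.
# >>> import pandas as pd
# >>> with pd.read_stata("big.dta", chunksize=10_000) as reader:  # doctest: +SKIP
# ...     for chunk in reader:
# ...         process(chunk)  # user-defined callback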
2093def _set_endianness(endianness: str) -> str:
2094 if endianness.lower() in ["<", "little"]:
2095 return "<"
2096 elif endianness.lower() in [">", "big"]:
2097 return ">"
2098 else: # pragma : no cover
2099 raise ValueError(f"Endianness {endianness} not understood")
2102def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
2103 """
2104 Takes a char string and pads it with null bytes until it is length characters long.
2105 """
2106 if isinstance(name, bytes):
2107 return name + b"\x00" * (length - len(name))
2108 return name + "\x00" * (length - len(name))
2111def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
2112 """
2113 Convert from one of the stata date formats to a type in TYPE_MAP.
2114 """
2115 if fmt in [
2116 "tc",
2117 "%tc",
2118 "td",
2119 "%td",
2120 "tw",
2121 "%tw",
2122 "tm",
2123 "%tm",
2124 "tq",
2125 "%tq",
2126 "th",
2127 "%th",
2128 "ty",
2129 "%ty",
2130 ]:
2131 return np.dtype(np.float64) # Stata expects doubles for SIFs
2132 else:
2133 raise NotImplementedError(f"Format {fmt} not implemented")
2136def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict:
2137 new_dict = {}
2138 for key in convert_dates:
2139 if not convert_dates[key].startswith("%"): # make sure proper fmts
2140 convert_dates[key] = "%" + convert_dates[key]
2141 if key in varlist:
2142 new_dict.update({varlist.index(key): convert_dates[key]})
2143 else:
2144 if not isinstance(key, int):
2145 raise ValueError("convert_dates key must be a column or an integer")
2146 new_dict.update({key: convert_dates[key]})
2147 return new_dict
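# Illustrative sketch (constructed for this edit, not a doctest from the
# source): the helper above prefixes bare formats with "%" and swaps column
# names for integer positions. Note that it mutates the dict passed in.
# >>> _maybe_convert_to_int_keys({"date": "td"}, ["x", "date"])
# {1: '%td'}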
2150def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int:
2151 """
2152 Convert numpy dtypes to stata type codes, returning the type byte for the given dtype.
2153 See TYPE_MAP and comments for an explanation. This is also explained in
2154 the dta spec.
2155 1 - 244 are strings of this length
2156 Pandas Stata
2157 251 - for int8 byte
2158 252 - for int16 int
2159 253 - for int32 long
2160 254 - for float32 float
2161 255 - for double double
2163 If there are dates to convert, then dtype will already have the correct
2164 type inserted.
2165 """
2166 # TODO: expand to handle datetime to integer conversion
2167 if dtype.type is np.object_: # try to coerce it to the biggest string
2168 # not memory efficient; what else could we do?
2170 itemsize = max_len_string_array(ensure_object(column._values))
2171 return max(itemsize, 1)
2172 elif dtype.type is np.float64:
2173 return 255
2174 elif dtype.type is np.float32:
2175 return 254
2176 elif dtype.type is np.int32:
2177 return 253
2178 elif dtype.type is np.int16:
2179 return 252
2180 elif dtype.type is np.int8:
2181 return 251
2182 else: # pragma : no cover
2183 raise NotImplementedError(f"Data type {dtype} not supported.")
2186def _dtype_to_default_stata_fmt(
2187 dtype, column: Series, dta_version: int = 114, force_strl: bool = False
2188) -> str:
2189 """
2190 Map numpy dtype to stata's default format for this type. Not terribly
2191 important since users can change this in Stata. Semantics are
2193 object -> "%DDs" where DD is the length of the string. If not a string,
2194 raise ValueError
2195 float64 -> "%10.0g"
2196 float32 -> "%9.0g"
2197 int64 -> "%9.0g"
2198 int32 -> "%12.0g"
2199 int16 -> "%8.0g"
2200 int8 -> "%8.0g"
2201 strl -> "%9s"
2202 """
2203 # TODO: Refactor to combine type with format
2204 # TODO: expand this to handle a default datetime format?
2205 if dta_version < 117:
2206 max_str_len = 244
2207 else:
2208 max_str_len = 2045
2209 if force_strl:
2210 return "%9s"
2211 if dtype.type is np.object_:
2212 itemsize = max_len_string_array(ensure_object(column._values))
2213 if itemsize > max_str_len:
2214 if dta_version >= 117:
2215 return "%9s"
2216 else:
2217 raise ValueError(excessive_string_length_error.format(column.name))
2218 return "%" + str(max(itemsize, 1)) + "s"
2219 elif dtype == np.float64:
2220 return "%10.0g"
2221 elif dtype == np.float32:
2222 return "%9.0g"
2223 elif dtype == np.int32:
2224 return "%12.0g"
2225 elif dtype in (np.int8, np.int16):
2226 return "%8.0g"
2227 else: # pragma : no cover
2228 raise NotImplementedError(f"Data type {dtype} not supported.")
2231@doc(
2232 storage_options=_shared_docs["storage_options"],
2233 compression_options=_shared_docs["compression_options"] % "fname",
2234)
2235class StataWriter(StataParser):
2236 """
2237 A class for writing Stata binary dta files
2239 Parameters
2240 ----------
2241 fname : path (string), buffer or path object
2242 string, path object (pathlib.Path or py._path.local.LocalPath) or
2243 object implementing a binary write() function. If using a buffer
2244 then the buffer will not be automatically closed after the file
2245 is written.
2246 data : DataFrame
2247 Input to save
2248 convert_dates : dict
2249 Dictionary mapping columns containing datetime types to stata internal
2250 format to use when writing the dates. Options are 'tc', 'td', 'tm',
2251 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
2252 Datetime columns that do not have a conversion type specified will be
2253 converted to 'tc'. Raises NotImplementedError if a datetime column has
2254 timezone information
2255 write_index : bool
2256 Write the index to Stata dataset.
2257 byteorder : str
2258 Can be ">", "<", "little", or "big". default is `sys.byteorder`
2259 time_stamp : datetime
2260 A datetime to use as file creation date. Default is the current time
2261 data_label : str
2262 A label for the data set. Must be 80 characters or fewer.
2263 variable_labels : dict
2264 Dictionary containing columns as keys and variable labels as values.
2265 Each label must be 80 characters or fewer.
2266 {compression_options}
2268 .. versionadded:: 1.1.0
2270 .. versionchanged:: 1.4.0 Zstandard support.
2272 {storage_options}
2274 .. versionadded:: 1.2.0
2276 value_labels : dict of dicts
2277 Dictionary containing columns as keys and dictionaries of column value
2278 to labels as values. The combined length of all labels for a single
2279 variable must be 32,000 characters or fewer.
2281 .. versionadded:: 1.4.0
2283 Returns
2284 -------
2285 writer : StataWriter instance
2286 The StataWriter instance has a write_file method, which will
2287 write the file to the given `fname`.
2289 Raises
2290 ------
2291 NotImplementedError
2292 * If datetimes contain timezone information
2293 ValueError
2294 * Columns listed in convert_dates are neither datetime64[ns]
2295 nor datetime.datetime
2296 * Column dtype is not representable in Stata
2297 * Column listed in convert_dates is not in DataFrame
2298 * Categorical label contains more than 32,000 characters
2300 Examples
2301 --------
2302 >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b'])
2303 >>> writer = StataWriter('./data_file.dta', data)
2304 >>> writer.write_file()
2306 Directly write a zip file
2307 >>> compression = {{"method": "zip", "archive_name": "data_file.dta"}}
2308 >>> writer = StataWriter('./data_file.zip', data, compression=compression)
2309 >>> writer.write_file()
2311 Save a DataFrame with dates
2312 >>> from datetime import datetime
2313 >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
2314 >>> writer = StataWriter('./date_data_file.dta', data, {{'date' : 'tw'}})
2315 >>> writer.write_file()
2316 """
2318 _max_string_length = 244
2319 _encoding: Literal["latin-1", "utf-8"] = "latin-1"
2321 def __init__(
2322 self,
2323 fname: FilePath | WriteBuffer[bytes],
2324 data: DataFrame,
2325 convert_dates: dict[Hashable, str] | None = None,
2326 write_index: bool = True,
2327 byteorder: str | None = None,
2328 time_stamp: datetime.datetime | None = None,
2329 data_label: str | None = None,
2330 variable_labels: dict[Hashable, str] | None = None,
2331 compression: CompressionOptions = "infer",
2332 storage_options: StorageOptions = None,
2333 *,
2334 value_labels: dict[Hashable, dict[float, str]] | None = None,
2335 ) -> None:
2336 super().__init__()
2337 self.data = data
2338 self._convert_dates = {} if convert_dates is None else convert_dates
2339 self._write_index = write_index
2340 self._time_stamp = time_stamp
2341 self._data_label = data_label
2342 self._variable_labels = variable_labels
2343 self._non_cat_value_labels = value_labels
2344 self._value_labels: list[StataValueLabel] = []
2345 self._has_value_labels = np.array([], dtype=bool)
2346 self._compression = compression
2347 self._output_file: IO[bytes] | None = None
2348 self._converted_names: dict[Hashable, str] = {}
2349 # attach nobs, nvars, data, varlist, typlist
2350 self._prepare_pandas(data)
2351 self.storage_options = storage_options
2353 if byteorder is None:
2354 byteorder = sys.byteorder
2355 self._byteorder = _set_endianness(byteorder)
2356 self._fname = fname
2357 self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
2359 def _write(self, to_write: str) -> None:
2360 """
2361 Helper to call encode before writing to file for Python 3 compat.
2362 """
2363 self.handles.handle.write(to_write.encode(self._encoding))
2365 def _write_bytes(self, value: bytes) -> None:
2366 """
2367 Helper to assert file is open before writing.
2368 """
2369 self.handles.handle.write(value)
2371 def _prepare_non_cat_value_labels(
2372 self, data: DataFrame
2373 ) -> list[StataNonCatValueLabel]:
2374 """
2375 Check for value labels provided for non-categorical columns. Value
2376 labels are validated and wrapped in StataNonCatValueLabel instances.
2377 """
2378 non_cat_value_labels: list[StataNonCatValueLabel] = []
2379 if self._non_cat_value_labels is None:
2380 return non_cat_value_labels
2382 for labname, labels in self._non_cat_value_labels.items():
2383 if labname in self._converted_names:
2384 colname = self._converted_names[labname]
2385 elif labname in data.columns:
2386 colname = str(labname)
2387 else:
2388 raise KeyError(
2389 f"Can't create value labels for {labname}, it wasn't "
2390 "found in the dataset."
2391 )
2393 if not is_numeric_dtype(data[colname].dtype):
2394 # Labels should not be passed explicitly for categorical
2395 # columns that will be converted to int
2396 raise ValueError(
2397 f"Can't create value labels for {labname}, value labels "
2398 "can only be applied to numeric columns."
2399 )
2400 svl = StataNonCatValueLabel(colname, labels, self._encoding)
2401 non_cat_value_labels.append(svl)
2402 return non_cat_value_labels
2404 def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
2405 """
2406 Check for categorical columns, retain categorical information for
2407 Stata file and convert categorical data to int
2408 """
2409 is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
2410 if not any(is_cat):
2411 return data
2413 self._has_value_labels |= np.array(is_cat)
2415 get_base_missing_value = StataMissingValue.get_base_missing_value
2416 data_formatted = []
2417 for col, col_is_cat in zip(data, is_cat):
2418 if col_is_cat:
2419 svl = StataValueLabel(data[col], encoding=self._encoding)
2420 self._value_labels.append(svl)
2421 dtype = data[col].cat.codes.dtype
2422 if dtype == np.int64:
2423 raise ValueError(
2424 "It is not possible to export "
2425 "int64-based categorical data to Stata."
2426 )
2427 values = data[col].cat.codes._values.copy()
2429 # Upcast if needed so that correct missing values can be set
2430 if values.max() >= get_base_missing_value(dtype):
2431 if dtype == np.int8:
2432 dtype = np.dtype(np.int16)
2433 elif dtype == np.int16:
2434 dtype = np.dtype(np.int32)
2435 else:
2436 dtype = np.dtype(np.float64)
2437 values = np.array(values, dtype=dtype)
2439 # Replace missing values with Stata missing value for type
2440 values[values == -1] = get_base_missing_value(dtype)
2441 data_formatted.append((col, values))
2442 else:
2443 data_formatted.append((col, data[col]))
2444 return DataFrame.from_dict(dict(data_formatted))
2446 def _replace_nans(self, data: DataFrame) -> DataFrame:
2448 """
2449 Checks floating point data columns for nans, and replaces these with
2450 the generic Stata missing value (.)
2451 """
2452 for c in data:
2453 dtype = data[c].dtype
2454 if dtype in (np.float32, np.float64):
2455 if dtype == np.float32:
2456 replacement = self.MISSING_VALUES["f"]
2457 else:
2458 replacement = self.MISSING_VALUES["d"]
2459 data[c] = data[c].fillna(replacement)
2461 return data
2463 def _update_strl_names(self) -> None:
2464 """No-op, forward compatibility"""
2466 def _validate_variable_name(self, name: str) -> str:
2467 """
2468 Validate variable names for Stata export.
2470 Parameters
2471 ----------
2472 name : str
2473 Variable name
2475 Returns
2476 -------
2477 str
2478 The validated name with invalid characters replaced with
2479 underscores.
2481 Notes
2482 -----
2483 Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9
2484 and _.
2485 """
2486 for c in name:
2487 if (
2488 (c < "A" or c > "Z")
2489 and (c < "a" or c > "z")
2490 and (c < "0" or c > "9")
2491 and c != "_"
2492 ):
2493 name = name.replace(c, "_")
2494 return name
2496 def _check_column_names(self, data: DataFrame) -> DataFrame:
2497 """
2498 Checks column names to ensure that they are valid Stata column names.
2499 This includes checks for:
2500 * Non-string names
2501 * Stata keywords
2502 * Variables that start with numbers
2503 * Variables with names that are too long
2505 When an illegal variable name is detected, it is converted, and if
2506 dates are exported, the variable name is propagated to the date
2507 conversion dictionary
2508 """
2509 converted_names: dict[Hashable, str] = {}
2510 columns = list(data.columns)
2511 original_columns = columns[:]
2513 duplicate_var_id = 0
2514 for j, name in enumerate(columns):
2515 orig_name = name
2516 if not isinstance(name, str):
2517 name = str(name)
2519 name = self._validate_variable_name(name)
2521 # Variable name must not be a reserved word
2522 if name in self.RESERVED_WORDS:
2523 name = "_" + name
2525 # Variable name may not start with a number
2526 if "0" <= name[0] <= "9":
2527 name = "_" + name
2529 name = name[: min(len(name), 32)]
2531 if name != orig_name:
2532 # check for duplicates
2533 while columns.count(name) > 0:
2534 # prepend ascending number to avoid duplicates
2535 name = "_" + str(duplicate_var_id) + name
2536 name = name[: min(len(name), 32)]
2537 duplicate_var_id += 1
2538 converted_names[orig_name] = name
2540 columns[j] = name
2542 data.columns = Index(columns)
2544 # Check date conversion, and fix key if needed
2545 if self._convert_dates:
2546 for c, o in zip(columns, original_columns):
2547 if c != o:
2548 self._convert_dates[c] = self._convert_dates[o]
2549 del self._convert_dates[o]
2551 if converted_names:
2552 conversion_warning = []
2553 for orig_name, name in converted_names.items():
2554 msg = f"{orig_name} -> {name}"
2555 conversion_warning.append(msg)
2557 ws = invalid_name_doc.format("\n ".join(conversion_warning))
2558 warnings.warn(
2559 ws,
2560 InvalidColumnName,
2561 stacklevel=find_stack_level(),
2562 )
2564 self._converted_names = converted_names
2565 self._update_strl_names()
2567 return data
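# Illustrative sketch (constructed for this edit) of how the checks above
# rewrite illegal names: non-string names are stringified, invalid characters
# become underscores, reserved words and leading digits gain a leading
# underscore, names are truncated to 32 characters, and an InvalidColumnName
# warning reports every rename, e.g.
#   "gdp growth"      -> "gdp_growth"
#   "1x"              -> "_1x"
#   "byte" (reserved) -> "_byte"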
2569 def _set_formats_and_types(self, dtypes: Series) -> None:
2570 self.fmtlist: list[str] = []
2571 self.typlist: list[int] = []
2572 for col, dtype in dtypes.items():
2573 self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col]))
2574 self.typlist.append(_dtype_to_stata_type(dtype, self.data[col]))
2576 def _prepare_pandas(self, data: DataFrame) -> None:
2577 # NOTE: we might need a different API / class for pandas objects so
2578 # we can set different semantics - handle this with a PR to pandas.io
2580 data = data.copy()
2582 if self._write_index:
2583 temp = data.reset_index()
2584 if isinstance(temp, DataFrame):
2585 data = temp
2587 # Ensure column names are strings
2588 data = self._check_column_names(data)
2590 # Check columns for compatibility with stata, upcast if necessary
2591 # Raise if outside the supported range
2592 data = _cast_to_stata_types(data)
2594 # Replace NaNs with Stata missing values
2595 data = self._replace_nans(data)
2597 # Set all columns to initially unlabelled
2598 self._has_value_labels = np.repeat(False, data.shape[1])
2600 # Create value labels for non-categorical data
2601 non_cat_value_labels = self._prepare_non_cat_value_labels(data)
2603 non_cat_columns = [svl.labname for svl in non_cat_value_labels]
2604 has_non_cat_val_labels = data.columns.isin(non_cat_columns)
2605 self._has_value_labels |= has_non_cat_val_labels
2606 self._value_labels.extend(non_cat_value_labels)
2608 # Convert categoricals to int data, and strip labels
2609 data = self._prepare_categoricals(data)
2611 self.nobs, self.nvar = data.shape
2612 self.data = data
2613 self.varlist = data.columns.tolist()
2615 dtypes = data.dtypes
2617 # Ensure all date columns are converted
2618 for col in data:
2619 if col in self._convert_dates:
2620 continue
2621 if is_datetime64_dtype(data[col]):
2622 self._convert_dates[col] = "tc"
2624 self._convert_dates = _maybe_convert_to_int_keys(
2625 self._convert_dates, self.varlist
2626 )
2627 for key in self._convert_dates:
2628 new_type = _convert_datetime_to_stata_type(self._convert_dates[key])
2629 dtypes[key] = np.dtype(new_type)
2631 # Verify object arrays are strings and encode to bytes
2632 self._encode_strings()
2634 self._set_formats_and_types(dtypes)
2636 # set the given format for the datetime cols
2637 if self._convert_dates is not None:
2638 for key in self._convert_dates:
2639 if isinstance(key, int):
2640 self.fmtlist[key] = self._convert_dates[key]
2642 def _encode_strings(self) -> None:
2643 """
2644 Encode strings in dta-specific encoding
2646 Do not encode columns marked for date conversion or for strL
2647 conversion. The strL converter independently handles conversion and
2648 also accepts empty string arrays.
2649 """
2650 convert_dates = self._convert_dates
2651 # _convert_strl is not available in dta 114
2652 convert_strl = getattr(self, "_convert_strl", [])
2653 for i, col in enumerate(self.data):
2654 # Skip columns marked for date conversion or strl conversion
2655 if i in convert_dates or col in convert_strl:
2656 continue
2657 column = self.data[col]
2658 dtype = column.dtype
2659 if dtype.type is np.object_:
2660 inferred_dtype = infer_dtype(column, skipna=True)
2661 if not ((inferred_dtype == "string") or len(column) == 0):
2662 col = column.name
2663 raise ValueError(
2664 f"""\
2665Column `{col}` cannot be exported.\n\nOnly string-like object arrays
2666containing all strings or a mix of strings and None can be exported.
2667Object arrays containing only null values are prohibited. Other object
2668types cannot be exported and must first be converted to one of the
2669supported types."""
2670 )
2671 encoded = self.data[col].str.encode(self._encoding)
2672 # If larger than _max_string_length do nothing
2673 if (
2674 max_len_string_array(ensure_object(encoded._values))
2675 <= self._max_string_length
2676 ):
2677 self.data[col] = encoded
2679 def write_file(self) -> None:
2680 """
2681 Export DataFrame object to Stata dta format.
2682 """
2683 with get_handle(
2684 self._fname,
2685 "wb",
2686 compression=self._compression,
2687 is_text=False,
2688 storage_options=self.storage_options,
2689 ) as self.handles:
2690 if self.handles.compression["method"] is not None:
2691 # ZipFile creates a file (with the same name) for each write call.
2692 # Write it first into a buffer and then write the buffer to the ZipFile.
2693 self._output_file, self.handles.handle = self.handles.handle, BytesIO()
2694 self.handles.created_handles.append(self.handles.handle)
2696 try:
2697 self._write_header(
2698 data_label=self._data_label, time_stamp=self._time_stamp
2699 )
2700 self._write_map()
2701 self._write_variable_types()
2702 self._write_varnames()
2703 self._write_sortlist()
2704 self._write_formats()
2705 self._write_value_label_names()
2706 self._write_variable_labels()
2707 self._write_expansion_fields()
2708 self._write_characteristics()
2709 records = self._prepare_data()
2710 self._write_data(records)
2711 self._write_strls()
2712 self._write_value_labels()
2713 self._write_file_close_tag()
2714 self._write_map()
2715 self._close()
2716 except Exception as exc:
2717 self.handles.close()
2718 if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile(
2719 self._fname
2720 ):
2721 try:
2722 os.unlink(self._fname)
2723 except OSError:
2724 warnings.warn(
2725 f"This save was not successful but {self._fname} could not "
2726 "be deleted. This file is not valid.",
2727 ResourceWarning,
2728 stacklevel=find_stack_level(),
2729 )
2730 raise exc
2732 def _close(self) -> None:
2733 """
2734 Close the file if it was created by the writer.
2736 If a buffer or file-like object was passed in, for example a GzipFile,
2737 then leave this file open for the caller to close.
2738 """
2739 # when compressing, write the buffered data to the actual output file
2740 if self._output_file is not None:
2741 assert isinstance(self.handles.handle, BytesIO)
2742 bio, self.handles.handle = self.handles.handle, self._output_file
2743 self.handles.handle.write(bio.getvalue())
2745 def _write_map(self) -> None:
2746 """No-op, future compatibility"""
2748 def _write_file_close_tag(self) -> None:
2749 """No-op, future compatibility"""
2751 def _write_characteristics(self) -> None:
2752 """No-op, future compatibility"""
2754 def _write_strls(self) -> None:
2755 """No-op, future compatibility"""
2757 def _write_expansion_fields(self) -> None:
2758 """Write 5 zeros for expansion fields"""
2759 self._write(_pad_bytes("", 5))
2761 def _write_value_labels(self) -> None:
2762 for vl in self._value_labels:
2763 self._write_bytes(vl.generate_value_label(self._byteorder))
2765 def _write_header(
2766 self,
2767 data_label: str | None = None,
2768 time_stamp: datetime.datetime | None = None,
2769 ) -> None:
2770 byteorder = self._byteorder
2771 # ds_format - just use 114
2772 self._write_bytes(struct.pack("b", 114))
2773 # byteorder
2774 self._write("\x01" if byteorder == ">" else "\x02")
2775 # filetype
2776 self._write("\x01")
2777 # unused
2778 self._write("\x00")
2779 # number of vars, 2 bytes
2780 self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2])
2781 # number of obs, 4 bytes
2782 self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4])
2783 # data label 81 bytes, char, null terminated
2784 if data_label is None:
2785 self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80)))
2786 else:
2787 self._write_bytes(
2788 self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))
2789 )
2790 # time stamp, 18 bytes, char, null terminated
2791 # format dd Mon yyyy hh:mm
2792 if time_stamp is None:
2793 time_stamp = datetime.datetime.now()
2794 elif not isinstance(time_stamp, datetime.datetime):
2795 raise ValueError("time_stamp should be datetime type")
2796 # GH #13856
2797 # Avoid locale-specific month conversion
2798 months = [
2799 "Jan",
2800 "Feb",
2801 "Mar",
2802 "Apr",
2803 "May",
2804 "Jun",
2805 "Jul",
2806 "Aug",
2807 "Sep",
2808 "Oct",
2809 "Nov",
2810 "Dec",
2811 ]
2812 month_lookup = {i + 1: month for i, month in enumerate(months)}
2813 ts = (
2814 time_stamp.strftime("%d ")
2815 + month_lookup[time_stamp.month]
2816 + time_stamp.strftime(" %Y %H:%M")
2817 )
2818 self._write_bytes(self._null_terminate_bytes(ts))
2820 def _write_variable_types(self) -> None:
2821 for typ in self.typlist:
2822 self._write_bytes(struct.pack("B", typ))
2824 def _write_varnames(self) -> None:
2825 # varlist names are checked by _check_column_names
2826 # varlist entries must be null-terminated
2827 for name in self.varlist:
2828 name = self._null_terminate_str(name)
2829 name = _pad_bytes(name[:32], 33)
2830 self._write(name)
2832 def _write_sortlist(self) -> None:
2833 # srtlist, 2*(nvar+1), int array, encoded by byteorder
2834 srtlist = _pad_bytes("", 2 * (self.nvar + 1))
2835 self._write(srtlist)
2837 def _write_formats(self) -> None:
2838 # fmtlist, 49*nvar, char array
2839 for fmt in self.fmtlist:
2840 self._write(_pad_bytes(fmt, 49))
2842 def _write_value_label_names(self) -> None:
2843 # lbllist, 33*nvar, char array
2844 for i in range(self.nvar):
2845 # Use variable name when categorical
2846 if self._has_value_labels[i]:
2847 name = self.varlist[i]
2848 name = self._null_terminate_str(name)
2849 name = _pad_bytes(name[:32], 33)
2850 self._write(name)
2851 else: # Default is empty label
2852 self._write(_pad_bytes("", 33))
2854 def _write_variable_labels(self) -> None:
2855 # Missing labels are 80 blank characters plus null termination
2856 blank = _pad_bytes("", 81)
2858 if self._variable_labels is None:
2859 for i in range(self.nvar):
2860 self._write(blank)
2861 return
2863 for col in self.data:
2864 if col in self._variable_labels:
2865 label = self._variable_labels[col]
2866 if len(label) > 80:
2867 raise ValueError("Variable labels must be 80 characters or fewer")
2868 is_latin1 = all(ord(c) < 256 for c in label)
2869 if not is_latin1:
2870 raise ValueError(
2871 "Variable labels must contain only characters that "
2872 "can be encoded in Latin-1"
2873 )
2874 self._write(_pad_bytes(label, 81))
2875 else:
2876 self._write(blank)
2878 def _convert_strls(self, data: DataFrame) -> DataFrame:
2879 """No-op, future compatibility"""
2880 return data
2882 def _prepare_data(self) -> np.recarray:
2883 data = self.data
2884 typlist = self.typlist
2885 convert_dates = self._convert_dates
2886 # 1. Convert dates
2887 if self._convert_dates is not None:
2888 for i, col in enumerate(data):
2889 if i in convert_dates:
2890 data[col] = _datetime_to_stata_elapsed_vec(
2891 data[col], self.fmtlist[i]
2892 )
2893 # 2. Convert strls
2894 data = self._convert_strls(data)
2896 # 3. Convert bad string data to '' and pad to correct length
2897 dtypes = {}
2898 native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
2899 for i, col in enumerate(data):
2900 typ = typlist[i]
2901 if typ <= self._max_string_length:
2902 data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,))
2903 stype = f"S{typ}"
2904 dtypes[col] = stype
2905 data[col] = data[col].astype(stype)
2906 else:
2907 dtype = data[col].dtype
2908 if not native_byteorder:
2909 dtype = dtype.newbyteorder(self._byteorder)
2910 dtypes[col] = dtype
2912 return data.to_records(index=False, column_dtypes=dtypes)
2914 def _write_data(self, records: np.recarray) -> None:
2915 self._write_bytes(records.tobytes())
2917 @staticmethod
2918 def _null_terminate_str(s: str) -> str:
2919 s += "\x00"
2920 return s
2922 def _null_terminate_bytes(self, s: str) -> bytes:
2923 return self._null_terminate_str(s).encode(self._encoding)
2926def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
2927 """
2928 Converts numpy dtypes to stata type codes, returning the type code for the given dtype.
2929 See TYPE_MAP and comments for an explanation. This is also explained in
2930 the dta spec.
2931 1 - 2045 are strings of this length
2932 Pandas Stata
2933 32768 - for object strL
2934 65526 - for float64 double
2935 65527 - for float32 float
2936 65528 - for int32 long
2937 65529 - for int16 int
2938 65530 - for int8 byte
2940 If there are dates to convert, then dtype will already have the correct
2941 type inserted.
2942 """
2943 # TODO: expand to handle datetime to integer conversion
2944 if force_strl:
2945 return 32768
2946 if dtype.type is np.object_: # try to coerce it to the biggest string
2947 # not memory efficient; what else could we do?
2949 itemsize = max_len_string_array(ensure_object(column._values))
2950 itemsize = max(itemsize, 1)
2951 if itemsize <= 2045:
2952 return itemsize
2953 return 32768
2954 elif dtype.type is np.float64:
2955 return 65526
2956 elif dtype.type is np.float32:
2957 return 65527
2958 elif dtype.type is np.int32:
2959 return 65528
2960 elif dtype.type is np.int16:
2961 return 65529
2962 elif dtype.type is np.int8:
2963 return 65530
2964 else: # pragma : no cover
2965 raise NotImplementedError(f"Data type {dtype} not supported.")
2968def _pad_bytes_new(name: str | bytes, length: int) -> bytes:
2969 """
2970 Takes a str or bytes instance and pads it with null bytes until it is length bytes long.
2971 """
2972 if isinstance(name, str):
2973 name = bytes(name, "utf-8")
2974 return name + b"\x00" * (length - len(name))
2977class StataStrLWriter:
2978 """
2979 Converter for Stata StrLs
2981 Stata StrLs map 8-byte (v, o) keys to strings, which are stored using a
2982 dictionary-like format where each string is keyed by two values.
2984 Parameters
2985 ----------
2986 df : DataFrame
2987 DataFrame to convert
2988 columns : Sequence[str]
2989 List of columns names to convert to StrL
2990 version : int, optional
2991 dta version. Currently supports 117, 118 and 119
2992 byteorder : str, optional
2993 Can be ">", "<", "little", or "big". default is `sys.byteorder`
2995 Notes
2996 -----
2997 Supports creation of the StrL block of a dta file for dta versions
2998 117, 118 and 119. These differ in how the GSO is stored. 118 and
2999 119 store the GSO lookup value as a uint32 and a uint64, while 117
3000 uses two uint32s. 118 and 119 also encode all strings as unicode
3001 which is required by the format. 117 uses 'latin-1', a fixed-width
3002 encoding that extends the 7-bit ASCII table with an additional 128
3003 characters.
3004 """
3006 def __init__(
3007 self,
3008 df: DataFrame,
3009 columns: Sequence[str],
3010 version: int = 117,
3011 byteorder: str | None = None,
3012 ) -> None:
3013 if version not in (117, 118, 119):
3014 raise ValueError("Only dta versions 117, 118 and 119 supported")
3015 self._dta_ver = version
3017 self.df = df
3018 self.columns = columns
3019 self._gso_table = {"": (0, 0)}
3020 if byteorder is None:
3021 byteorder = sys.byteorder
3022 self._byteorder = _set_endianness(byteorder)
3024 gso_v_type = "I" # uint32
3025 gso_o_type = "Q" # uint64
3026 self._encoding = "utf-8"
3027 if version == 117:
3028 o_size = 4
3029 gso_o_type = "I" # 117 used uint32
3030 self._encoding = "latin-1"
3031 elif version == 118:
3032 o_size = 6
3033 else: # version == 119
3034 o_size = 5
3035 self._o_offset = 2 ** (8 * (8 - o_size))
3036 self._gso_o_type = gso_o_type
3037 self._gso_v_type = gso_v_type
3039 def _convert_key(self, key: tuple[int, int]) -> int:
3040 v, o = key
3041 return v + self._o_offset * o
3043 def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]:
3044 """
3045 Generates the GSO lookup table for the DataFrame
3047 Returns
3048 -------
3049 gso_table : dict
3050 Ordered dictionary using the strings found as keys
3051 and their lookup position (v,o) as values
3052 gso_df : DataFrame
3053 DataFrame where strl columns have been converted to
3054 (v,o) values
3056 Notes
3057 -----
3058 Modifies the DataFrame in-place.
3060 The DataFrame returned encodes the (v,o) values as uint64s. The
3061 encoding depends on the dta version, and can be expressed as
3063 enc = v + o * 2 ** ((8 - o_size) * 8)
3065 so that v is stored in the low (8 - o_size) bytes and o in the top
3066 o_size bytes, matching _convert_key above. o_size is
3068 * 117: 4
3069 * 118: 6
3070 * 119: 5
3071 """
3072 gso_table = self._gso_table
3073 gso_df = self.df
3074 columns = list(gso_df.columns)
3075 selected = gso_df[self.columns]
3076 col_index = [(col, columns.index(col)) for col in self.columns]
3077 keys = np.empty(selected.shape, dtype=np.uint64)
3078 for o, (idx, row) in enumerate(selected.iterrows()):
3079 for j, (col, v) in enumerate(col_index):
3080 val = row[col]
3081 # Allow columns with mixed str and None (GH 23633)
3082 val = "" if val is None else val
3083 key = gso_table.get(val, None)
3084 if key is None:
3085 # Stata uses 1-based ("human") numbering
3086 key = (v + 1, o + 1)
3087 gso_table[val] = key
3088 keys[o, j] = self._convert_key(key)
3089 for i, col in enumerate(self.columns):
3090 gso_df[col] = keys[:, i]
3092 return gso_table, gso_df
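# Worked example of the (v, o) encoding above (a sketch using the dta 117
# layout, where o_size is 4 and v therefore occupies the low 32 bits):
# >>> v, o = 2, 3            # variable 2, observation 3 (1-based)
# >>> v + o * 2 ** (8 * (8 - 4))
# 12884901890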
3094 def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes:
3095 """
3096 Generates the binary blob of GSOs that is written to the dta file.
3098 Parameters
3099 ----------
3100 gso_table : dict
3101 Ordered dictionary (str, vo)
3103 Returns
3104 -------
3105 gso : bytes
3106 Binary content of dta file to be placed between strl tags
3108 Notes
3109 -----
3110 Output format depends on dta version. 117 uses two uint32s to
3111 express v and o while 118+ uses a uint32 for v and a uint64 for o.
3112 """
3113 # Format information
3114 # Length includes null term
3115 # 117
3116 # GSOvvvvooootllllxxxxxxxxxxxxxxx...x
3117 # 3 u4 u4 u1 u4 string + null term
3118 #
3119 # 118, 119
3120 # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x
3121 # 3 u4 u8 u1 u4 string + null term
3123 bio = BytesIO()
3124 gso = bytes("GSO", "ascii")
3125 gso_type = struct.pack(self._byteorder + "B", 130)
3126 null = struct.pack(self._byteorder + "B", 0)
3127 v_type = self._byteorder + self._gso_v_type
3128 o_type = self._byteorder + self._gso_o_type
3129 len_type = self._byteorder + "I"
3130 for strl, vo in gso_table.items():
3131 if vo == (0, 0):
3132 continue
3133 v, o = vo
3135 # GSO
3136 bio.write(gso)
3138 # vvvv
3139 bio.write(struct.pack(v_type, v))
3141 # oooo / oooooooo
3142 bio.write(struct.pack(o_type, o))
3144 # t
3145 bio.write(gso_type)
3147 # llll
3148 utf8_string = bytes(strl, "utf-8")
3149 bio.write(struct.pack(len_type, len(utf8_string) + 1))
3151 # xxx...xxx
3152 bio.write(utf8_string)
3153 bio.write(null)
3155 return bio.getvalue()
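# A hedged decoding sketch mirroring the GSO layout comment above: unpack one
# little-endian dta 117 record (v and o both uint32, type byte 130, uint32
# length including the null terminator). The bytes are constructed here for
# illustration only.
# >>> import struct
# >>> rec = b"GSO" + struct.pack("<IIBI", 1, 1, 130, 3) + b"hi\x00"
# >>> struct.unpack_from("<IIBI", rec, 3)
# (1, 1, 130, 3)
# >>> rec[16:]
# b'hi\x00'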
3158class StataWriter117(StataWriter):
3159 """
3160 A class for writing Stata binary dta files in Stata 13 format (117)
3162 Parameters
3163 ----------
3164 fname : path (string), buffer or path object
3165 string, path object (pathlib.Path or py._path.local.LocalPath) or
3166 object implementing a binary write() function. If using a buffer
3167 then the buffer will not be automatically closed after the file
3168 is written.
3169 data : DataFrame
3170 Input to save
3171 convert_dates : dict
3172 Dictionary mapping columns containing datetime types to stata internal
3173 format to use when writing the dates. Options are 'tc', 'td', 'tm',
3174 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
3175 Datetime columns that do not have a conversion type specified will be
3176 converted to 'tc'. Raises NotImplementedError if a datetime column has
3177 timezone information
3178 write_index : bool
3179 Write the index to Stata dataset.
3180 byteorder : str
3181 Can be ">", "<", "little", or "big". default is `sys.byteorder`
3182 time_stamp : datetime
3183 A datetime to use as file creation date. Default is the current time
3184 data_label : str
3185 A label for the data set. Must be 80 characters or fewer.
3186 variable_labels : dict
3187 Dictionary containing columns as keys and variable labels as values.
3188 Each label must be 80 characters or fewer.
3189 convert_strl : list
3190 List of columns names to convert to Stata StrL format. Columns with
3191 more than 2045 characters are automatically written as StrL.
3192 Smaller columns can be converted by including the column name. Using
3193 StrLs can reduce output file size when strings are longer than 8
3194 characters, and either frequently repeated or sparse.
3195 {compression_options}
3197 .. versionadded:: 1.1.0
3199 .. versionchanged:: 1.4.0 Zstandard support.
3201 value_labels : dict of dicts
3202 Dictionary containing columns as keys and dictionaries of column value
3203 to labels as values. The combined length of all labels for a single
3204 variable must be 32,000 characters or fewer.
3206 .. versionadded:: 1.4.0
3208 Returns
3209 -------
3210 writer : StataWriter117 instance
3211 The StataWriter117 instance has a write_file method, which will
3212 write the file to the given `fname`.
3214 Raises
3215 ------
3216 NotImplementedError
3217 * If datetimes contain timezone information
3218 ValueError
3219 * Columns listed in convert_dates are neither datetime64[ns]
3220 nor datetime.datetime
3221 * Column dtype is not representable in Stata
3222 * Column listed in convert_dates is not in DataFrame
3223 * Categorical label contains more than 32,000 characters
3225 Examples
3226 --------
3227 >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
3228 >>> writer = pd.io.stata.StataWriter117('./data_file.dta', data)
3229 >>> writer.write_file()
3231 Directly write a zip file
3232 >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
3233 >>> writer = pd.io.stata.StataWriter117(
3234 ... './data_file.zip', data, compression=compression
3235 ... )
3236 >>> writer.write_file()
3238 Or with long strings stored in strl format
3239 >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
3240 ... columns=['strls'])
3241 >>> writer = pd.io.stata.StataWriter117(
3242 ... './data_file_with_long_strings.dta', data, convert_strl=['strls'])
3243 >>> writer.write_file()
3244 """
3246 _max_string_length = 2045
3247 _dta_version = 117
3249 def __init__(
3250 self,
3251 fname: FilePath | WriteBuffer[bytes],
3252 data: DataFrame,
3253 convert_dates: dict[Hashable, str] | None = None,
3254 write_index: bool = True,
3255 byteorder: str | None = None,
3256 time_stamp: datetime.datetime | None = None,
3257 data_label: str | None = None,
3258 variable_labels: dict[Hashable, str] | None = None,
3259 convert_strl: Sequence[Hashable] | None = None,
3260 compression: CompressionOptions = "infer",
3261 storage_options: StorageOptions = None,
3262 *,
3263 value_labels: dict[Hashable, dict[float, str]] | None = None,
3264 ) -> None:
3265 # Copy to new list since convert_strl might be modified later
3266 self._convert_strl: list[Hashable] = []
3267 if convert_strl is not None:
3268 self._convert_strl.extend(convert_strl)
3270 super().__init__(
3271 fname,
3272 data,
3273 convert_dates,
3274 write_index,
3275 byteorder=byteorder,
3276 time_stamp=time_stamp,
3277 data_label=data_label,
3278 variable_labels=variable_labels,
3279 value_labels=value_labels,
3280 compression=compression,
3281 storage_options=storage_options,
3282 )
3283 self._map: dict[str, int] = {}
3284 self._strl_blob = b""
3286 @staticmethod
3287 def _tag(val: str | bytes, tag: str) -> bytes:
3288 """Surround val with <tag></tag>"""
3289 if isinstance(val, str):
3290 val = bytes(val, "utf-8")
3291 return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8")
3293 def _update_map(self, tag: str) -> None:
3294 """Update map location for tag with file position"""
3295 assert self.handles.handle is not None
3296 self._map[tag] = self.handles.handle.tell()
3298 def _write_header(
3299 self,
3300 data_label: str | None = None,
3301 time_stamp: datetime.datetime | None = None,
3302 ) -> None:
3303 """Write the file header"""
3304 byteorder = self._byteorder
3305 self._write_bytes(bytes("<stata_dta>", "utf-8"))
3306 bio = BytesIO()
3307 # ds_format - 117
3308 bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release"))
3309 # byteorder
3310 bio.write(self._tag("MSF" if byteorder == ">" else "LSF", "byteorder"))
3311 # number of vars, 2 bytes in 117 and 118, 4 byte in 119
3312 nvar_type = "H" if self._dta_version <= 118 else "I"
3313 bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K"))
3314 # 117 uses 4 bytes, 118 uses 8
3315 nobs_size = "I" if self._dta_version == 117 else "Q"
3316 bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
3317 # data label 81 bytes, char, null terminated
3318 label = data_label[:80] if data_label is not None else ""
3319 encoded_label = label.encode(self._encoding)
3320 label_size = "B" if self._dta_version == 117 else "H"
3321 label_len = struct.pack(byteorder + label_size, len(encoded_label))
3322 encoded_label = label_len + encoded_label
3323 bio.write(self._tag(encoded_label, "label"))
3324 # time stamp, 18 bytes, char, null terminated
3325 # format dd Mon yyyy hh:mm
3326 if time_stamp is None:
3327 time_stamp = datetime.datetime.now()
3328 elif not isinstance(time_stamp, datetime.datetime):
3329 raise ValueError("time_stamp should be datetime type")
3330 # Avoid locale-specific month conversion
3331 months = [
3332 "Jan",
3333 "Feb",
3334 "Mar",
3335 "Apr",
3336 "May",
3337 "Jun",
3338 "Jul",
3339 "Aug",
3340 "Sep",
3341 "Oct",
3342 "Nov",
3343 "Dec",
3344 ]
3345 month_lookup = {i + 1: month for i, month in enumerate(months)}
3346 ts = (
3347 time_stamp.strftime("%d ")
3348 + month_lookup[time_stamp.month]
3349 + time_stamp.strftime(" %Y %H:%M")
3350 )
3351 # '\x11' added due to inspection of Stata file
3352 stata_ts = b"\x11" + bytes(ts, "utf-8")
3353 bio.write(self._tag(stata_ts, "timestamp"))
3354 self._write_bytes(self._tag(bio.getvalue(), "header"))
3356 def _write_map(self) -> None:
3357 """
3358 Called twice during the file write. The first call populates the values
3359 in the map with 0s. The second call writes the final map locations when
3360 all blocks have been written.
3361 """
3362 if not self._map:
3363 self._map = {
3364 "stata_data": 0,
3365 "map": self.handles.handle.tell(),
3366 "variable_types": 0,
3367 "varnames": 0,
3368 "sortlist": 0,
3369 "formats": 0,
3370 "value_label_names": 0,
3371 "variable_labels": 0,
3372 "characteristics": 0,
3373 "data": 0,
3374 "strls": 0,
3375 "value_labels": 0,
3376 "stata_data_close": 0,
3377 "end-of-file": 0,
3378 }
3379 # Move to start of map
3380 self.handles.handle.seek(self._map["map"])
3381 bio = BytesIO()
3382 for val in self._map.values():
3383 bio.write(struct.pack(self._byteorder + "Q", val))
3384 self._write_bytes(self._tag(bio.getvalue(), "map"))
3386 def _write_variable_types(self) -> None:
3387 self._update_map("variable_types")
3388 bio = BytesIO()
3389 for typ in self.typlist:
3390 bio.write(struct.pack(self._byteorder + "H", typ))
3391 self._write_bytes(self._tag(bio.getvalue(), "variable_types"))
3393 def _write_varnames(self) -> None:
3394 self._update_map("varnames")
3395 bio = BytesIO()
3396 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3397 vn_len = 32 if self._dta_version == 117 else 128
3398 for name in self.varlist:
3399 name = self._null_terminate_str(name)
3400 name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
3401 bio.write(name)
3402 self._write_bytes(self._tag(bio.getvalue(), "varnames"))
3404 def _write_sortlist(self) -> None:
3405 self._update_map("sortlist")
3406 sort_size = 2 if self._dta_version < 119 else 4
3407 self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist"))
3409 def _write_formats(self) -> None:
3410 self._update_map("formats")
3411 bio = BytesIO()
3412 fmt_len = 49 if self._dta_version == 117 else 57
3413 for fmt in self.fmtlist:
3414 bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len))
3415 self._write_bytes(self._tag(bio.getvalue(), "formats"))
3417 def _write_value_label_names(self) -> None:
3418 self._update_map("value_label_names")
3419 bio = BytesIO()
3420 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3421 vl_len = 32 if self._dta_version == 117 else 128
3422 for i in range(self.nvar):
3423 # Use variable name when categorical
3424 name = "" # default name
3425 if self._has_value_labels[i]:
3426 name = self.varlist[i]
3427 name = self._null_terminate_str(name)
3428 encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
3429 bio.write(encoded_name)
3430 self._write_bytes(self._tag(bio.getvalue(), "value_label_names"))
3432 def _write_variable_labels(self) -> None:
3433 # Missing labels are 80 blank characters plus null termination
3434 self._update_map("variable_labels")
3435 bio = BytesIO()
3436 # 118 scales by 4 to accommodate utf-8 data worst case encoding
3437 vl_len = 80 if self._dta_version == 117 else 320
3438 blank = _pad_bytes_new("", vl_len + 1)
3440 if self._variable_labels is None:
3441 for _ in range(self.nvar):
3442 bio.write(blank)
3443 self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3444 return
3446 for col in self.data:
3447 if col in self._variable_labels:
3448 label = self._variable_labels[col]
3449 if len(label) > 80:
3450 raise ValueError("Variable labels must be 80 characters or fewer")
3451 try:
3452 encoded = label.encode(self._encoding)
3453 except UnicodeEncodeError as err:
3454 raise ValueError(
3455 "Variable labels must contain only characters that "
3456 f"can be encoded in {self._encoding}"
3457 ) from err
3459 bio.write(_pad_bytes_new(encoded, vl_len + 1))
3460 else:
3461 bio.write(blank)
3462 self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3464 def _write_characteristics(self) -> None:
3465 self._update_map("characteristics")
3466 self._write_bytes(self._tag(b"", "characteristics"))
3468 def _write_data(self, records) -> None:
3469 self._update_map("data")
3470 self._write_bytes(b"<data>")
3471 self._write_bytes(records.tobytes())
3472 self._write_bytes(b"</data>")
3474 def _write_strls(self) -> None:
3475 self._update_map("strls")
3476 self._write_bytes(self._tag(self._strl_blob, "strls"))
3478 def _write_expansion_fields(self) -> None:
3479 """No-op in dta 117+"""
3481 def _write_value_labels(self) -> None:
3482 self._update_map("value_labels")
3483 bio = BytesIO()
3484 for vl in self._value_labels:
3485 lab = vl.generate_value_label(self._byteorder)
3486 lab = self._tag(lab, "lbl")
3487 bio.write(lab)
3488 self._write_bytes(self._tag(bio.getvalue(), "value_labels"))
3490 def _write_file_close_tag(self) -> None:
3491 self._update_map("stata_data_close")
3492 self._write_bytes(bytes("</stata_dta>", "utf-8"))
3493 self._update_map("end-of-file")
3495 def _update_strl_names(self) -> None:
3496 """
3497 Update column names for conversion to strl if they might have been
3498 changed to comply with Stata naming rules
3499 """
3500 # Update convert_strl if names changed
3501 for orig, new in self._converted_names.items():
3502 if orig in self._convert_strl:
3503 idx = self._convert_strl.index(orig)
3504 self._convert_strl[idx] = new
3506 def _convert_strls(self, data: DataFrame) -> DataFrame:
3507 """
3508 Convert columns to StrLs if either very large or in the
3509 convert_strl variable
3510 """
3511 convert_cols = [
3512 col
3513 for i, col in enumerate(data)
3514 if self.typlist[i] == 32768 or col in self._convert_strl
3515 ]
3517 if convert_cols:
3518 ssw = StataStrLWriter(data, convert_cols, version=self._dta_version)
3519 tab, new_data = ssw.generate_table()
3520 data = new_data
3521 self._strl_blob = ssw.generate_blob(tab)
3522 return data
3524 def _set_formats_and_types(self, dtypes: Series) -> None:
3525 self.typlist = []
3526 self.fmtlist = []
3527 for col, dtype in dtypes.items():
3528 force_strl = col in self._convert_strl
3529 fmt = _dtype_to_default_stata_fmt(
3530 dtype,
3531 self.data[col],
3532 dta_version=self._dta_version,
3533 force_strl=force_strl,
3534 )
3535 self.fmtlist.append(fmt)
3536 self.typlist.append(
3537 _dtype_to_stata_type_117(dtype, self.data[col], force_strl)
3538 )
3541class StataWriterUTF8(StataWriter117):
3542 """
3543 Stata binary dta file writing in Stata 15 (118) and 16 (119) formats
3545 DTA 118 and 119 format files support unicode string data (both fixed
3546 and strL). Unicode is also supported in value labels, variable
3547 labels and the dataset label. Format 119 is automatically used if the
3548 file contains more than 32,767 variables.
3550 Parameters
3551 ----------
3552 fname : path (string), buffer or path object
3553 string, path object (pathlib.Path or py._path.local.LocalPath) or
3554 object implementing a binary write() function. If using a buffer
3555 then the buffer will not be automatically closed after the file
3556 is written.
3557 data : DataFrame
3558 Input to save
3559 convert_dates : dict, default None
3560 Dictionary mapping columns containing datetime types to the Stata
3561 internal format to use when writing the dates. Options are 'tc', 'td',
3562 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
3563 Datetime columns that do not have a conversion type specified will be
3564 converted to 'tc'. Raises NotImplementedError if a datetime column has
3565 timezone information.
3566 write_index : bool, default True
3567 Write the index to Stata dataset.
3568 byteorder : str, default None
3569 Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
3570 time_stamp : datetime, default None
3571 A datetime to use as the file creation date. Default is the current time.
3572 data_label : str, default None
3573 A label for the data set. Must be 80 characters or fewer.
3574 variable_labels : dict, default None
3575 Dictionary containing columns as keys and variable labels as values.
3576 Each label must be 80 characters or fewer.
3577 convert_strl : list, default None
3578 List of column names to convert to Stata StrL format. Columns with
3579 more than 2045 characters are automatically written as StrL.
3580 Smaller columns can be converted by including the column name. Using
3581 StrLs can reduce output file size when strings are longer than 8
3582 characters, and either frequently repeated or sparse.
3583 version : int, default None
3584 The dta version to use. By default, uses the size of data to determine
3585 the version. 118 is used if data.shape[1] <= 32767, and 119 is used
3586 for storing larger DataFrames.
3587 {compression_options}
3589 .. versionadded:: 1.1.0
3591 .. versionchanged:: 1.4.0 Zstandard support.
3593 value_labels : dict of dicts
3594 Dictionary containing columns as keys and dictionaries of column value
3595 to labels as values. The combined length of all labels for a single
3596 variable must not exceed 32,000 characters.
3598 .. versionadded:: 1.4.0
3600 Returns
3601 -------
3602 StataWriterUTF8
3603 The instance has a write_file method, which will write the file to the
3604 given `fname`.
3606 Raises
3607 ------
3608 NotImplementedError
3609 * If datetimes contain timezone information
3610 ValueError
3611 * Columns listed in convert_dates are neither datetime64[ns]
3612 nor datetime.datetime
3613 * Column dtype is not representable in Stata
3614 * Column listed in convert_dates is not in DataFrame
3615 * Categorical label contains more than 32,000 characters
3617 Examples
3618 --------
3619 Using Unicode data and column names
3621 >>> from pandas.io.stata import StataWriterUTF8
3622 >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ'])
3623 >>> writer = StataWriterUTF8('./data_file.dta', data)
3624 >>> writer.write_file()
3626 Directly write a zip file
3627 >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
3628 >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
3629 >>> writer.write_file()
3631 Or with long strings stored in strl format
3633 >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
3634 ... columns=['strls'])
3635 >>> writer = StataWriterUTF8('./data_file_with_long_strings.dta', data,
3636 ... convert_strl=['strls'])
3637 >>> writer.write_file()
3638 """
3640 _encoding: Literal["utf-8"] = "utf-8"
3642 def __init__(
3643 self,
3644 fname: FilePath | WriteBuffer[bytes],
3645 data: DataFrame,
3646 convert_dates: dict[Hashable, str] | None = None,
3647 write_index: bool = True,
3648 byteorder: str | None = None,
3649 time_stamp: datetime.datetime | None = None,
3650 data_label: str | None = None,
3651 variable_labels: dict[Hashable, str] | None = None,
3652 convert_strl: Sequence[Hashable] | None = None,
3653 version: int | None = None,
3654 compression: CompressionOptions = "infer",
3655 storage_options: StorageOptions = None,
3656 *,
3657 value_labels: dict[Hashable, dict[float, str]] | None = None,
3658 ) -> None:
3659 if version is None:
3660 version = 118 if data.shape[1] <= 32767 else 119
3661 elif version not in (118, 119):
3662 raise ValueError("version must be either 118 or 119.")
3663 elif version == 118 and data.shape[1] > 32767:
3664 raise ValueError(
3665 "You must use version 119 for data sets containing more than"
3666 "32,767 variables"
3667 )
3669 super().__init__(
3670 fname,
3671 data,
3672 convert_dates=convert_dates,
3673 write_index=write_index,
3674 byteorder=byteorder,
3675 time_stamp=time_stamp,
3676 data_label=data_label,
3677 variable_labels=variable_labels,
3678 value_labels=value_labels,
3679 convert_strl=convert_strl,
3680 compression=compression,
3681 storage_options=storage_options,
3682 )
3683 # Override version set in StataWriter117 init
3684 self._dta_version = version
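# Demonstration sketch, not part of pandas source, of the validation above:
# explicitly requesting version 118 for a frame wider than 32,767 columns
# raises immediately, while version=None would auto-select 119.
import numpy as np
import pandas as pd
from pandas.io.stata import StataWriterUTF8

wide = pd.DataFrame(np.zeros((1, 32768)), columns=[f"v{i}" for i in range(32768)])
try:
    StataWriterUTF8("wide.dta", wide, version=118)
except ValueError as err:
    print(err)  # version 119 is required for > 32,767 variables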
3686 def _validate_variable_name(self, name: str) -> str:
3687 """
3688 Validate variable names for Stata export.
3690 Parameters
3691 ----------
3692 name : str
3693 Variable name
3695 Returns
3696 -------
3697 str
3698 The validated name with invalid characters replaced with
3699 underscores.
3701 Notes
3702 -----
3703 Stata 118+ supports most Unicode characters. The only limitation is in
3704 the ASCII range, where the supported characters are a-z, A-Z, 0-9 and _.
3705 """
3706 # High code points appear to be acceptable
3707 for c in name:
3708 if (
3709 (
3710 ord(c) < 128
3711 and (c < "A" or c > "Z")
3712 and (c < "a" or c > "z")
3713 and (c < "0" or c > "9")
3714 and c != "_"
3715 )
3716 or 128 <= ord(c) < 192
3717 or c in {"×", "÷"}
3718 ):
3719 name = name.replace(c, "_")
3721 return name
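# Standalone sketch, not part of pandas source, restating the rule the loop
# above implements: disallowed ASCII characters, the U+0080-U+00BF block,
# and the multiplication/division signs all become underscores, while other
# non-ASCII characters pass through unchanged.
import string

ALLOWED_ASCII = set(string.ascii_letters + string.digits + "_")

def sanitize_sketch(name: str) -> str:
    return "".join(
        "_"
        if (ord(c) < 128 and c not in ALLOWED_ASCII)
        or 128 <= ord(c) < 192
        or c in {"×", "÷"}
        else c
        for c in name
    )

assert sanitize_sketch("flag ×") == "flag__"
assert sanitize_sketch("β_ĉ") == "β_ĉ"  # non-ASCII letters survive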