from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Callable,
    Hashable,
    Sequence,
    cast,
)

import numpy as np

from pandas._libs import lib
from pandas._typing import (
    AggFuncType,
    AggFuncTypeBase,
    AggFuncTypeDict,
    IndexLabel,
)
from pandas.util._decorators import (
    Appender,
    Substitution,
)

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
    is_extension_array_dtype,
    is_integer_dtype,
    is_list_like,
    is_nested_list_like,
    is_scalar,
)
from pandas.core.dtypes.generic import (
    ABCDataFrame,
    ABCSeries,
)

import pandas.core.common as com
from pandas.core.frame import _shared_docs
from pandas.core.groupby import Grouper
from pandas.core.indexes.api import (
    Index,
    MultiIndex,
    get_objs_combined_axis,
)
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import cartesian_product
from pandas.core.series import Series

if TYPE_CHECKING:
    from pandas import DataFrame


# Note: We need to make sure `frame` is imported before `pivot`, otherwise
# _shared_docs['pivot_table'] will not yet exist. TODO: Fix this dependency
@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot_table"], indents=1)
def pivot_table(
    data: DataFrame,
    values=None,
    index=None,
    columns=None,
    aggfunc: AggFuncType = "mean",
    fill_value=None,
    margins: bool = False,
    dropna: bool = True,
    margins_name: Hashable = "All",
    observed: bool = False,
    sort: bool = True,
) -> DataFrame:
    index = _convert_by(index)
    columns = _convert_by(columns)

    if isinstance(aggfunc, list):
        pieces: list[DataFrame] = []
        keys = []
        for func in aggfunc:
            _table = __internal_pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
                sort=sort,
            )
            pieces.append(_table)
            keys.append(getattr(func, "__name__", func))

        table = concat(pieces, keys=keys, axis=1)
        return table.__finalize__(data, method="pivot_table")

    table = __internal_pivot_table(
        data,
        values,
        index,
        columns,
        aggfunc,
        fill_value,
        margins,
        dropna,
        margins_name,
        observed,
        sort,
    )
    return table.__finalize__(data, method="pivot_table")
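
# Illustrative sketch (hypothetical frame, not executed here): with a list
# ``aggfunc``, the per-function tables are concatenated along the columns,
# keyed by function name:
#
#     >>> df = pd.DataFrame({"A": ["x", "x", "y"], "B": [1, 2, 3]})
#     >>> pd.pivot_table(df, values="B", index="A", aggfunc=["sum", "mean"])
#       sum mean
#         B    B
#     A
#     x   3  1.5
#     y   3  3.0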


def __internal_pivot_table(
    data: DataFrame,
    values,
    index,
    columns,
    aggfunc: AggFuncTypeBase | AggFuncTypeDict,
    fill_value,
    margins: bool,
    dropna: bool,
    margins_name: Hashable,
    observed: bool,
    sort: bool,
) -> DataFrame:
    """
    Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
    """
    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=observed, sort=sort)
    agged = grouped.agg(aggfunc)

    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")
    # GH#21133
    # We want to downcast if the original values are ints: grouping with a
    # NaN value and then dropping coerces the result to float.
    for v in values:
        if (
            v in data
            and is_integer_dtype(data[v])
            and v in agged
            and not is_integer_dtype(agged[v])
        ):
            if not isinstance(agged[v], ABCDataFrame) and isinstance(
                data[v].dtype, np.dtype
            ):
                # exclude DataFrame case bc maybe_downcast_to_dtype expects
                # ArrayLike
                # e.g. test_pivot_table_multiindex_columns_doctest_case
                # agged.columns is a MultiIndex and 'v' is indexing only
                # on its first level.
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged

    # GH17038, this check should only happen if index is defined (not None)
    if table.index.nlevels > 1 and index:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack)

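    # With dropna=False, the observed labels are expanded below to the full
    # cartesian grid (illustrative: cartesian_product([["a", "b"], [1, 2]])
    # -> [array(['a', 'a', 'b', 'b']), array([1, 2, 1, 2])]), so unobserved
    # combinations reappear as all-NaN rows/columns.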
    if not dropna:
        if isinstance(table.index, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if isinstance(table.columns, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if sort is True and isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        table = table.fillna(fill_value, downcast="infer")

    if margins:
        if dropna:
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if values_passed and not values_multi and table.columns.nlevels > 1:
        table = table.droplevel(0, axis=1)
    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table


def _add_margins(
    table: DataFrame | Series,
    data: DataFrame,
    values,
    rows,
    cols,
    aggfunc,
    observed=None,
    margins_name: Hashable = "All",
    fill_value=None,
):
    if not isinstance(margins_name, str):
        raise ValueError("margins_name argument must be a string")

    msg = f'Conflicting name "{margins_name}" in margins'
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    if table.ndim == 2:
        # i.e. DataFrame
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    key: str | tuple[str, ...]
    if len(rows) > 1:
        key = (margins_name,) + ("",) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is
        # only one column in the data. Compute grand margin and return it.
        return table._append(Series({key: grand_margin[margins_name]}))

    elif values:
        marginal_result_set = _generate_marginal_results(
            table, data, values, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        # no values, and table is a DataFrame
        assert isinstance(table, ABCDataFrame)
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, str):
            row_margin[k] = grand_margin[k]
        else:
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame

    margin_dummy = DataFrame(row_margin, columns=Index([key])).T

    row_names = result.index.names
    # check the result column and leave floats
    for dtype in set(result.dtypes):
        if is_extension_array_dtype(dtype):
            # Can hold NA already
            continue

        cols = result.select_dtypes([dtype]).columns
        margin_dummy[cols] = margin_dummy[cols].apply(
            maybe_downcast_to_dtype, args=(dtype,)
        )
    result = result._append(margin_dummy)
    result.index.names = row_names

    return result
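
# Illustrative layout (hypothetical count table): _add_margins appends an
# "All" column per column group and a final "All" row holding the grand
# aggregates, e.g.
#
#     b    b1  b2  All
#     a
#     a1    1   2    3
#     a2    4   5    9
#     All   5   7   12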


def _compute_grand_margin(
    data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
):
    if values:
        grand_margin = {}
        for k, v in data[values].items():
            try:
                if isinstance(aggfunc, str):
                    grand_margin[k] = getattr(v, aggfunc)()
                elif isinstance(aggfunc, dict):
                    if isinstance(aggfunc[k], str):
                        grand_margin[k] = getattr(v, aggfunc[k])()
                    else:
                        grand_margin[k] = aggfunc[k](v)
                else:
                    grand_margin[k] = aggfunc(v)
            except TypeError:
                pass
        return grand_margin
    else:
        return {margins_name: aggfunc(data.index)}
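
# Illustrative sketch (hypothetical frame, not executed here): the grand
# margin is one scalar aggregate per value column, honoring a per-column
# dict ``aggfunc``:
#
#     >>> df = pd.DataFrame({"A": [1, 2], "B": [3.0, 4.0]})
#     >>> _compute_grand_margin(df, ["A", "B"], {"A": "sum", "B": "mean"})
#     {'A': 3, 'B': 3.5}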


def _generate_marginal_results(
    table, data, values, rows, cols, aggfunc, observed, margins_name: Hashable = "All"
):
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            return (key, margins_name) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            from pandas import DataFrame

            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                if len(cols) > 1:
                    all_key = _all_key(key)
                else:
                    all_key = margins_name
                table_pieces.append(piece)
                # GH31016 compute the margin for each group and assign the
                # corresponding key as its index
                transformed_piece = DataFrame(piece.apply(aggfunc)).T
                if isinstance(piece.index, MultiIndex):
                    # We are adding an empty level
                    transformed_piece.index = MultiIndex.from_tuples(
                        [all_key], names=piece.index.names + [None]
                    )
                else:
                    transformed_piece.index = Index([all_key], name=piece.index.name)

                # append the margin piece to table_pieces
                table_pieces.append(transformed_piece)
                margin_keys.append(all_key)

        if not table_pieces:
            # GH 49240
            return table
        else:
            result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        new_order = [len(cols)] + list(range(len(cols)))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
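
# Illustrative shape note: with rows=["a"], cols=["b"] and one value column
# "v", the "interleave" above gives the top-level column group "v" an extra
# ("v", "All") margin column beside its "b" columns, while ``row_margin``
# carries the values for the eventual "All" row.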


def _generate_marginal_results_without_values(
    table: DataFrame,
    data,
    rows,
    cols,
    aggfunc,
    observed,
    margins_name: Hashable = "All",
):
    if len(cols) > 0:
        # need to "interleave" the margins
        margin_keys: list | Index = []

        def _all_key():
            if len(cols) == 1:
                return margins_name
            return (margins_name,) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows].groupby(rows, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)

        else:
            margin = data.groupby(level=0, axis=0, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols):
        row_margin = data[cols].groupby(cols, observed=observed).apply(aggfunc)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin


def _convert_by(by):
    if by is None:
        by = []
    elif (
        is_scalar(by)
        or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper))
        or callable(by)
    ):
        by = [by]
    else:
        by = list(by)
    return by
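
# Quick reference (illustrative): _convert_by normalizes the ``index`` /
# ``columns`` arguments of pivot_table to a list of grouping keys:
#
#     >>> _convert_by(None)
#     []
#     >>> _convert_by("A")
#     ['A']
#     >>> _convert_by(["A", "B"])
#     ['A', 'B']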


@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot"], indents=1)
def pivot(
    data: DataFrame,
    *,
    columns: IndexLabel,
    index: IndexLabel | lib.NoDefault = lib.no_default,
    values: IndexLabel | lib.NoDefault = lib.no_default,
) -> DataFrame:
    columns_listlike = com.convert_to_list_like(columns)

    # If columns is None we will create a MultiIndex level with None as name
    # which might cause duplicated names because None is the default for
    # level names
    data = data.copy(deep=False)
    data.index = data.index.copy()
    data.index.names = [
        name if name is not None else lib.no_default for name in data.index.names
    ]

    indexed: DataFrame | Series
    if values is lib.no_default:
        if index is not lib.no_default:
            cols = com.convert_to_list_like(index)
        else:
            cols = []

        append = index is lib.no_default
        # error: Unsupported operand types for + ("List[Any]" and "ExtensionArray")
        # error: Unsupported left operand type for + ("ExtensionArray")
        indexed = data.set_index(
            cols + columns_listlike, append=append  # type: ignore[operator]
        )
    else:
        if index is lib.no_default:
            if isinstance(data.index, MultiIndex):
                # GH 23955
                index_list = [
                    data.index.get_level_values(i) for i in range(data.index.nlevels)
                ]
            else:
                index_list = [Series(data.index, name=data.index.name)]
        else:
            index_list = [data[idx] for idx in com.convert_to_list_like(index)]

        data_columns = [data[col] for col in columns_listlike]
        index_list.extend(data_columns)
        multiindex = MultiIndex.from_arrays(index_list)

        if is_list_like(values) and not isinstance(values, tuple):
            # Exclude tuple because it is seen as a single column name
            values = cast(Sequence[Hashable], values)
            indexed = data._constructor(
                data[values]._values, index=multiindex, columns=values
            )
        else:
            indexed = data._constructor_sliced(data[values]._values, index=multiindex)
    # error: Argument 1 to "unstack" of "DataFrame" has incompatible type "Union
    # [List[Any], ExtensionArray, ndarray[Any, Any], Index, Series]"; expected
    # "Hashable"
    result = indexed.unstack(columns_listlike)  # type: ignore[arg-type]
    result.index.names = [
        name if name is not lib.no_default else None for name in result.index.names
    ]

    return result
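
# Illustrative sketch (hypothetical frame, not executed here):
#
#     >>> df = pd.DataFrame(
#     ...     {"foo": ["one", "one", "two"], "bar": ["A", "B", "A"], "baz": [1, 2, 3]}
#     ... )
#     >>> pd.pivot(df, index="foo", columns="bar", values="baz")
#     bar    A    B
#     foo
#     one  1.0  2.0
#     two  3.0  NaN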


def crosstab(
    index,
    columns,
    values=None,
    rownames=None,
    colnames=None,
    aggfunc=None,
    margins: bool = False,
    margins_name: Hashable = "All",
    dropna: bool = True,
    normalize: bool = False,
) -> DataFrame:
    """
    Compute a simple cross tabulation of two (or more) factors.

    By default, computes a frequency table of the factors unless an
    array of values and an aggregation function are passed.

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows.
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns.
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed.
    colnames : sequence, default None
        If passed, must match number of column arrays passed.
    aggfunc : function, optional
        If specified, requires `values` be specified as well.
    margins : bool, default False
        Add row/column margins (subtotals).
    margins_name : str, default 'All'
        Name of the row/column that will contain the totals
        when margins is True.
    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

    Returns
    -------
    DataFrame
        Cross tabulation of the data.

    See Also
    --------
    DataFrame.pivot : Reshape data based on column values.
    pivot_table : Create a pivot table as a DataFrame.

    Notes
    -----
    Any Series passed will have its name attribute used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    If there are no overlapping indexes, an empty DataFrame will be returned.

    Reference :ref:`the user guide <reshaping.crosstabulations>` for more examples.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    Here 'c' and 'f' are not represented in the data and will not be
    shown in the output because dropna is True by default. Set
    dropna=False to preserve categories with no data.

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> pd.crosstab(foo, bar)
    col_0  d  e
    row_0
    a      1  0
    b      0  1
    >>> pd.crosstab(foo, bar, dropna=False)
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """
    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    if not is_nested_list_like(index):
        index = [index]
    if not is_nested_list_like(columns):
        columns = [columns]

    common_idx = None
    pass_objs = [x for x in index + columns if isinstance(x, (ABCSeries, ABCDataFrame))]
    if pass_objs:
        common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)

    rownames = _get_names(index, rownames, prefix="row")
    colnames = _get_names(columns, colnames, prefix="col")

    # duplicate names mapped to unique names for pivot op
    (
        rownames_mapper,
        unique_rownames,
        colnames_mapper,
        unique_colnames,
    ) = _build_names_mapper(rownames, colnames)

    from pandas import DataFrame

    data = {
        **dict(zip(unique_rownames, index)),
        **dict(zip(unique_colnames, columns)),
    }
    df = DataFrame(data, index=common_idx)

    if values is None:
        df["__dummy__"] = 0
        kwargs = {"aggfunc": len, "fill_value": 0}
    else:
        df["__dummy__"] = values
        kwargs = {"aggfunc": aggfunc}

    # error: Argument 7 to "pivot_table" of "DataFrame" has incompatible type
    # "**Dict[str, object]"; expected "Union[...]"
    table = df.pivot_table(
        "__dummy__",
        index=unique_rownames,
        columns=unique_colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        **kwargs,  # type: ignore[arg-type]
    )

    # Post-process
    if normalize is not False:
        table = _normalize(
            table, normalize=normalize, margins=margins, margins_name=margins_name
        )

    table = table.rename_axis(index=rownames_mapper, axis=0)
    table = table.rename_axis(columns=colnames_mapper, axis=1)

    return table
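
# Illustrative sketch (not executed here) of the strategy above: the factors
# become columns of a helper frame with a constant "__dummy__" column, and a
# count pivot_table over it yields the frequency table:
#
#     >>> pd.crosstab(pd.Series(["x", "y", "x"]), pd.Series([1, 1, 2]))
#     col_0  1  2
#     row_0
#     x      1  1
#     y      1  0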


def _normalize(
    table: DataFrame, normalize, margins: bool, margins_name: Hashable = "All"
) -> DataFrame:
    if not isinstance(normalize, (bool, str)):
        axis_subs = {0: "index", 1: "columns"}
        try:
            normalize = axis_subs[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

    if margins is False:
        # Actual Normalizations
        normalizers: dict[bool | str, Callable] = {
            "all": lambda x: x / x.sum(axis=1).sum(axis=0),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis=0),
        }

        normalizers[True] = normalizers["all"]

        try:
            f = normalizers[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

        table = f(table)
        table = table.fillna(0)

    elif margins is True:
        # keep index and column of pivoted table
        table_index = table.index
        table_columns = table.columns
        last_ind_or_col = table.iloc[-1, :].name
        # The margin must be the last row/column: raise if margins_name is
        # neither contained in (MultiIndex case) nor equal to the last
        # index/column label; then save the column and index margins.
        if (margins_name not in last_ind_or_col) & (margins_name != last_ind_or_col):
            raise ValueError(f"{margins_name} not in pivoted DataFrame")
        column_margin = table.iloc[:-1, -1]
        index_margin = table.iloc[-1, :-1]

        # keep the core table
        table = table.iloc[:-1, :-1]

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == "columns":
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)
            table.columns = table_columns

        elif normalize == "index":
            index_margin = index_margin / index_margin.sum()
            table = table._append(index_margin)
            table = table.fillna(0)
            table.index = table_index

        elif normalize == "all" or normalize is True:
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table._append(index_margin)

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

        else:
            raise ValueError("Not a valid normalize argument")

    else:
        raise ValueError("Not a valid margins argument")

    return table
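
# Worked sketch (illustrative, not executed here): with normalize="index"
# each row is divided by its own sum, so every row of the result sums to 1:
#
#     >>> t = pd.DataFrame({"x": [1, 3], "y": [1, 1]}, index=["a", "b"])
#     >>> _normalize(t, normalize="index", margins=False)
#           x     y
#     a  0.50  0.50
#     b  0.75  0.25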


def _get_names(arrs, names, prefix: str = "row"):
    if names is None:
        names = []
        for i, arr in enumerate(arrs):
            if isinstance(arr, ABCSeries) and arr.name is not None:
                names.append(arr.name)
            else:
                names.append(f"{prefix}_{i}")
    else:
        if len(names) != len(arrs):
            raise AssertionError("arrays and names must have the same length")
        if not isinstance(names, list):
            names = list(names)

    return names
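
# Illustrative: a named Series contributes its own name; anything else gets
# a positional default:
#
#     >>> _get_names([pd.Series([1], name="s"), [1, 2]], None, prefix="row")
#     ['s', 'row_1']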


def _build_names_mapper(
    rownames: list[str], colnames: list[str]
) -> tuple[dict[str, str], list[str], dict[str, str], list[str]]:
    """
    Given the names of a DataFrame's rows and columns, returns a set of unique row
    and column names and mappers that convert to original names.

    A row or column name is replaced if it is duplicate among the rows of the inputs,
    among the columns of the inputs or between the rows and the columns.

    Parameters
    ----------
    rownames: list[str]
    colnames: list[str]

    Returns
    -------
    Tuple(Dict[str, str], List[str], Dict[str, str], List[str])

    rownames_mapper: dict[str, str]
        a dictionary with new row names as keys and original rownames as values
    unique_rownames: list[str]
        a list of rownames with duplicate names replaced by dummy names
    colnames_mapper: dict[str, str]
        a dictionary with new column names as keys and original column names as values
    unique_colnames: list[str]
        a list of column names with duplicate names replaced by dummy names

    """

    def get_duplicates(names):
        # names that occur more than once within ``names``
        seen: set = set()
        dups = set()
        for name in names:
            if name in seen:
                dups.add(name)
            seen.add(name)
        return dups

    shared_names = set(rownames).intersection(set(colnames))
    dup_names = get_duplicates(rownames) | get_duplicates(colnames) | shared_names

    rownames_mapper = {
        f"row_{i}": name for i, name in enumerate(rownames) if name in dup_names
    }
    unique_rownames = [
        f"row_{i}" if name in dup_names else name for i, name in enumerate(rownames)
    ]

    colnames_mapper = {
        f"col_{i}": name for i, name in enumerate(colnames) if name in dup_names
    }
    unique_colnames = [
        f"col_{i}" if name in dup_names else name for i, name in enumerate(colnames)
    ]

    return rownames_mapper, unique_rownames, colnames_mapper, unique_colnames
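
# Illustrative sketch: a name shared between the row and column inputs is
# replaced by a positional dummy on both axes; the mappers invert the
# renaming afterwards:
#
#     >>> _build_names_mapper(["a", "b"], ["b", "c"])
#     ({'row_1': 'b'}, ['a', 'row_1'], {'col_0': 'b'}, ['col_0', 'c'])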