1from __future__ import annotations
2
3from collections.abc import (
4 Hashable,
5 Sequence,
6)
7from typing import (
8 TYPE_CHECKING,
9 Callable,
10 Literal,
11 cast,
12)
13import warnings
14
15import numpy as np
16
17from pandas._libs import lib
18from pandas.util._decorators import (
19 Appender,
20 Substitution,
21)
22from pandas.util._exceptions import find_stack_level
23
24from pandas.core.dtypes.cast import maybe_downcast_to_dtype
25from pandas.core.dtypes.common import (
26 is_list_like,
27 is_nested_list_like,
28 is_scalar,
29)
30from pandas.core.dtypes.dtypes import ExtensionDtype
31from pandas.core.dtypes.generic import (
32 ABCDataFrame,
33 ABCSeries,
34)
35
36import pandas.core.common as com
37from pandas.core.frame import _shared_docs
38from pandas.core.groupby import Grouper
39from pandas.core.indexes.api import (
40 Index,
41 MultiIndex,
42 get_objs_combined_axis,
43)
44from pandas.core.reshape.concat import concat
45from pandas.core.reshape.util import cartesian_product
46from pandas.core.series import Series
47
48if TYPE_CHECKING:
49 from pandas._typing import (
50 AggFuncType,
51 AggFuncTypeBase,
52 AggFuncTypeDict,
53 IndexLabel,
54 )
55
56 from pandas import DataFrame
57
58
59# Note: We need to make sure `frame` is imported before `pivot`, otherwise
60# _shared_docs['pivot_table'] will not yet exist. TODO: Fix this dependency
61@Substitution("\ndata : DataFrame")
62@Appender(_shared_docs["pivot_table"], indents=1)
63def pivot_table(
64 data: DataFrame,
65 values=None,
66 index=None,
67 columns=None,
68 aggfunc: AggFuncType = "mean",
69 fill_value=None,
70 margins: bool = False,
71 dropna: bool = True,
72 margins_name: Hashable = "All",
73 observed: bool | lib.NoDefault = lib.no_default,
74 sort: bool = True,
75) -> DataFrame:
76 index = _convert_by(index)
77 columns = _convert_by(columns)
78
79 if isinstance(aggfunc, list):
80 pieces: list[DataFrame] = []
81 keys = []
82 for func in aggfunc:
83 _table = __internal_pivot_table(
84 data,
85 values=values,
86 index=index,
87 columns=columns,
88 fill_value=fill_value,
89 aggfunc=func,
90 margins=margins,
91 dropna=dropna,
92 margins_name=margins_name,
93 observed=observed,
94 sort=sort,
95 )
96 pieces.append(_table)
97 keys.append(getattr(func, "__name__", func))
98
99 table = concat(pieces, keys=keys, axis=1)
100 return table.__finalize__(data, method="pivot_table")
101
102 table = __internal_pivot_table(
103 data,
104 values,
105 index,
106 columns,
107 aggfunc,
108 fill_value,
109 margins,
110 dropna,
111 margins_name,
112 observed,
113 sort,
114 )
115 return table.__finalize__(data, method="pivot_table")
116
117
def __internal_pivot_table(
    data: DataFrame,
    values,
    index,
    columns,
    aggfunc: AggFuncTypeBase | AggFuncTypeDict,
    fill_value,
    margins: bool,
    dropna: bool,
    margins_name: Hashable,
    observed: bool | lib.NoDefault,
    sort: bool,
) -> DataFrame:
    """
    Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.

    Groups ``data`` by ``index + columns``, aggregates with ``aggfunc``,
    unstacks the column keys, and optionally fills missing values, sorts,
    and adds margins.
    """
    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        # Keep only the columns used as keys or values so unrelated columns
        # are not carried into the groupby.
        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                # unhashable entry (e.g. an array grouper) - skip the lookup
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        # No values passed: aggregate every column that is not a grouping key.
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    # Resolve the observed default (False) while remembering whether the
    # caller relied on it, so the deprecation warning below can fire.
    observed_bool = False if observed is lib.no_default else observed
    grouped = data.groupby(keys, observed=observed_bool, sort=sort, dropna=dropna)
    if observed is lib.no_default and any(
        ping._passed_categorical for ping in grouped._grouper.groupings
    ):
        warnings.warn(
            "The default value of observed=False is deprecated and will change "
            "to observed=True in a future version of pandas. Specify "
            "observed=False to silence this warning and retain the current behavior",
            category=FutureWarning,
            stacklevel=find_stack_level(),
        )
    agged = grouped.agg(aggfunc)

    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")

    table = agged

    # GH17038, this check should only happen if index is defined (not None)
    if table.index.nlevels > 1 and index:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack, fill_value=fill_value)

    if not dropna:
        # Reindex to the full cartesian product of the levels so that empty
        # key combinations still appear in the result.
        if isinstance(table.index, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0, fill_value=fill_value)

        if isinstance(table.columns, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1, fill_value=fill_value)

    if sort is True and isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        table = table.fillna(fill_value)
        if aggfunc is len and not observed and lib.is_integer(fill_value):
            # TODO: can we avoid this? this used to be handled by
            # downcast="infer" in fillna
            table = table.astype(np.int64)

    if margins:
        if dropna:
            # Margins are computed from fully-observed rows only.
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if values_passed and not values_multi and table.columns.nlevels > 1:
        table.columns = table.columns.droplevel(0)
    if len(index) == 0 and len(columns) > 0:
        # Only column keys were given: present them as rows.
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table
254
255
def _add_margins(
    table: DataFrame | Series,
    data: DataFrame,
    values,
    rows,
    cols,
    aggfunc,
    observed: bool,
    margins_name: Hashable = "All",
    fill_value=None,
):
    """
    Append grand-total ("margin") rows/columns to a pivoted table.

    Parameters
    ----------
    table : DataFrame or Series
        The pivoted result to which margins are added.
    data : DataFrame
        The original (pre-pivot) data used to compute the margins.
    values : list
        Value column labels that were aggregated (may be empty).
    rows, cols : list
        The index / column groupers of the pivot.
    aggfunc : callable, str or dict
        Aggregation used for the margins.
    observed : bool
        Passed through to ``groupby`` for categorical groupers.
    margins_name : Hashable, default "All"
        Label for the margin row/column; must be a string.
    fill_value : scalar, optional
        Value used when reindexing the margin row against the result columns.

    Raises
    ------
    ValueError
        If ``margins_name`` is not a string, or conflicts with an existing
        index/column label.
    """
    if not isinstance(margins_name, str):
        raise ValueError("margins_name argument must be a string")

    msg = f'Conflicting name "{margins_name}" in margins'
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    if table.ndim == 2:
        # i.e. DataFrame
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    # The margin row key: pad with empty strings for extra index levels.
    key: str | tuple[str, ...]
    if len(rows) > 1:
        key = (margins_name,) + ("",) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        return table._append(table._constructor({key: grand_margin[margins_name]}))

    elif values:
        marginal_result_set = _generate_marginal_results(
            table, data, values, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            # helper bailed out (e.g. nothing to margin); return it unchanged
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        # no values, and table is a DataFrame
        assert isinstance(table, ABCDataFrame)
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, str):
            row_margin[k] = grand_margin[k]
        else:
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame

    margin_dummy = DataFrame(row_margin, columns=Index([key])).T

    row_names = result.index.names
    # check the result column and leave floats

    for dtype in set(result.dtypes):
        if isinstance(dtype, ExtensionDtype):
            # Can hold NA already
            continue

        cols = result.select_dtypes([dtype]).columns
        margin_dummy[cols] = margin_dummy[cols].apply(
            maybe_downcast_to_dtype, args=(dtype,)
        )
    result = result._append(margin_dummy)
    result.index.names = row_names

    return result
339
340
341def _compute_grand_margin(
342 data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
343):
344 if values:
345 grand_margin = {}
346 for k, v in data[values].items():
347 try:
348 if isinstance(aggfunc, str):
349 grand_margin[k] = getattr(v, aggfunc)()
350 elif isinstance(aggfunc, dict):
351 if isinstance(aggfunc[k], str):
352 grand_margin[k] = getattr(v, aggfunc[k])()
353 else:
354 grand_margin[k] = aggfunc[k](v)
355 else:
356 grand_margin[k] = aggfunc(v)
357 except TypeError:
358 pass
359 return grand_margin
360 else:
361 return {margins_name: aggfunc(data.index)}
362
363
def _generate_marginal_results(
    table,
    data: DataFrame,
    values,
    rows,
    cols,
    aggfunc,
    observed: bool,
    margins_name: Hashable = "All",
):
    """
    Compute the margin pieces for ``pivot_table`` when ``values`` were passed.

    Returns either ``table``/``result`` directly (when nothing further needs
    to be assembled) or the tuple ``(result, margin_keys, row_margin)``
    consumed by ``_add_margins``.
    """
    margin_keys: list | Index
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            # Margin-column key: pad with "" for the extra column levels.
            return (key, margins_name) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            # Insert a margin column next to each top-level column group.
            for key, piece in table.T.groupby(level=0, observed=observed):
                piece = piece.T
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            from pandas import DataFrame

            cat_axis = 0
            for key, piece in table.groupby(level=0, observed=observed):
                if len(cols) > 1:
                    all_key = _all_key(key)
                else:
                    all_key = margins_name
                table_pieces.append(piece)
                # GH31016 this is to calculate margin for each group, and assign
                # corresponded key as index
                transformed_piece = DataFrame(piece.apply(aggfunc)).T
                if isinstance(piece.index, MultiIndex):
                    # We are adding an empty level
                    transformed_piece.index = MultiIndex.from_tuples(
                        [all_key], names=piece.index.names + [None]
                    )
                else:
                    transformed_piece.index = Index([all_key], name=piece.index.name)

                # append piece for margin into table_piece
                table_pieces.append(transformed_piece)
                margin_keys.append(all_key)

        if not table_pieces:
            # GH 49240
            return table
        else:
            result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack(future_stack=True)

        # GH#26568. Use names instead of indices in case of numeric names
        new_order_indices = [len(cols)] + list(range(len(cols)))
        new_order_names = [row_margin.index.names[i] for i in new_order_indices]
        row_margin.index = row_margin.index.reorder_levels(new_order_names)
    else:
        row_margin = data._constructor_sliced(np.nan, index=result.columns)

    return result, margin_keys, row_margin
446
447
448def _generate_marginal_results_without_values(
449 table: DataFrame,
450 data: DataFrame,
451 rows,
452 cols,
453 aggfunc,
454 observed: bool,
455 margins_name: Hashable = "All",
456):
457 margin_keys: list | Index
458 if len(cols) > 0:
459 # need to "interleave" the margins
460 margin_keys = []
461
462 def _all_key():
463 if len(cols) == 1:
464 return margins_name
465 return (margins_name,) + ("",) * (len(cols) - 1)
466
467 if len(rows) > 0:
468 margin = data.groupby(rows, observed=observed)[rows].apply(aggfunc)
469 all_key = _all_key()
470 table[all_key] = margin
471 result = table
472 margin_keys.append(all_key)
473
474 else:
475 margin = data.groupby(level=0, axis=0, observed=observed).apply(aggfunc)
476 all_key = _all_key()
477 table[all_key] = margin
478 result = table
479 margin_keys.append(all_key)
480 return result
481 else:
482 result = table
483 margin_keys = table.columns
484
485 if len(cols):
486 row_margin = data.groupby(cols, observed=observed)[cols].apply(aggfunc)
487 else:
488 row_margin = Series(np.nan, index=result.columns)
489
490 return result, margin_keys, row_margin
491
492
493def _convert_by(by):
494 if by is None:
495 by = []
496 elif (
497 is_scalar(by)
498 or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper))
499 or callable(by)
500 ):
501 by = [by]
502 else:
503 by = list(by)
504 return by
505
506
507@Substitution("\ndata : DataFrame")
508@Appender(_shared_docs["pivot"], indents=1)
509def pivot(
510 data: DataFrame,
511 *,
512 columns: IndexLabel,
513 index: IndexLabel | lib.NoDefault = lib.no_default,
514 values: IndexLabel | lib.NoDefault = lib.no_default,
515) -> DataFrame:
516 columns_listlike = com.convert_to_list_like(columns)
517
518 # If columns is None we will create a MultiIndex level with None as name
519 # which might cause duplicated names because None is the default for
520 # level names
521 data = data.copy(deep=False)
522 data.index = data.index.copy()
523 data.index.names = [
524 name if name is not None else lib.no_default for name in data.index.names
525 ]
526
527 indexed: DataFrame | Series
528 if values is lib.no_default:
529 if index is not lib.no_default:
530 cols = com.convert_to_list_like(index)
531 else:
532 cols = []
533
534 append = index is lib.no_default
535 # error: Unsupported operand types for + ("List[Any]" and "ExtensionArray")
536 # error: Unsupported left operand type for + ("ExtensionArray")
537 indexed = data.set_index(
538 cols + columns_listlike, append=append # type: ignore[operator]
539 )
540 else:
541 index_list: list[Index] | list[Series]
542 if index is lib.no_default:
543 if isinstance(data.index, MultiIndex):
544 # GH 23955
545 index_list = [
546 data.index.get_level_values(i) for i in range(data.index.nlevels)
547 ]
548 else:
549 index_list = [
550 data._constructor_sliced(data.index, name=data.index.name)
551 ]
552 else:
553 index_list = [data[idx] for idx in com.convert_to_list_like(index)]
554
555 data_columns = [data[col] for col in columns_listlike]
556 index_list.extend(data_columns)
557 multiindex = MultiIndex.from_arrays(index_list)
558
559 if is_list_like(values) and not isinstance(values, tuple):
560 # Exclude tuple because it is seen as a single column name
561 values = cast(Sequence[Hashable], values)
562 indexed = data._constructor(
563 data[values]._values, index=multiindex, columns=values
564 )
565 else:
566 indexed = data._constructor_sliced(data[values]._values, index=multiindex)
567 # error: Argument 1 to "unstack" of "DataFrame" has incompatible type "Union
568 # [List[Any], ExtensionArray, ndarray[Any, Any], Index, Series]"; expected
569 # "Hashable"
570 result = indexed.unstack(columns_listlike) # type: ignore[arg-type]
571 result.index.names = [
572 name if name is not lib.no_default else None for name in result.index.names
573 ]
574
575 return result
576
577
def crosstab(
    index,
    columns,
    values=None,
    rownames=None,
    colnames=None,
    aggfunc=None,
    margins: bool = False,
    margins_name: Hashable = "All",
    dropna: bool = True,
    normalize: bool | Literal[0, 1, "all", "index", "columns"] = False,
) -> DataFrame:
    """
    Compute a simple cross tabulation of two (or more) factors.

    By default, computes a frequency table of the factors unless an
    array of values and an aggregation function are passed.

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows.
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns.
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed.
    colnames : sequence, default None
        If passed, must match number of column arrays passed.
    aggfunc : function, optional
        If specified, requires `values` be specified as well.
    margins : bool, default False
        Add row/column margins (subtotals).
    margins_name : str, default 'All'
        Name of the row/column that will contain the totals
        when margins is True.
    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

    Returns
    -------
    DataFrame
        Cross tabulation of the data.

    See Also
    --------
    DataFrame.pivot : Reshape data based on column values.
    pivot_table : Create a pivot table as a DataFrame.

    Notes
    -----
    Any Series passed will have their name attributes used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    In the event that there aren't overlapping indexes an empty DataFrame will
    be returned.

    Reference :ref:`the user guide <reshaping.crosstabulations>` for more examples.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    Here 'c' and 'f' are not represented in the data and will not be
    shown in the output because dropna is True by default. Set
    dropna=False to preserve categories with no data.

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> pd.crosstab(foo, bar)
    col_0  d  e
    row_0
    a      1  0
    b      0  1
    >>> pd.crosstab(foo, bar, dropna=False)
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """
    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    # Wrap single array-likes so index/columns are always lists of arrays.
    if not is_nested_list_like(index):
        index = [index]
    if not is_nested_list_like(columns):
        columns = [columns]

    # Align all Series/DataFrame inputs on their common (intersected) index.
    common_idx = None
    pass_objs = [x for x in index + columns if isinstance(x, (ABCSeries, ABCDataFrame))]
    if pass_objs:
        common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)

    rownames = _get_names(index, rownames, prefix="row")
    colnames = _get_names(columns, colnames, prefix="col")

    # duplicate names mapped to unique names for pivot op
    (
        rownames_mapper,
        unique_rownames,
        colnames_mapper,
        unique_colnames,
    ) = _build_names_mapper(rownames, colnames)

    from pandas import DataFrame

    data = {
        **dict(zip(unique_rownames, index)),
        **dict(zip(unique_colnames, columns)),
    }
    df = DataFrame(data, index=common_idx)

    if values is None:
        # Frequency table: count a constant dummy column.
        df["__dummy__"] = 0
        kwargs = {"aggfunc": len, "fill_value": 0}
    else:
        # Aggregation table: aggregate the provided values.
        df["__dummy__"] = values
        kwargs = {"aggfunc": aggfunc}

    # error: Argument 7 to "pivot_table" of "DataFrame" has incompatible type
    # "**Dict[str, object]"; expected "Union[...]"
    table = df.pivot_table(
        "__dummy__",
        index=unique_rownames,
        columns=unique_colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        observed=False,
        **kwargs,  # type: ignore[arg-type]
    )

    # Post-process
    if normalize is not False:
        table = _normalize(
            table, normalize=normalize, margins=margins, margins_name=margins_name
        )

    # Restore the caller's (possibly duplicated) axis names.
    table = table.rename_axis(index=rownames_mapper, axis=0)
    table = table.rename_axis(columns=colnames_mapper, axis=1)

    return table
749
750
def _normalize(
    table: DataFrame, normalize, margins: bool, margins_name: Hashable = "All"
) -> DataFrame:
    """
    Normalize the values of a crosstab result.

    Parameters
    ----------
    table : DataFrame
        The (possibly margined) crosstab result.
    normalize : bool, str or {0, 1}
        'all'/True, 'index' (rows) or 'columns'; 0/1 map to
        'index'/'columns'.
    margins : bool
        Whether ``table`` includes a margin row/column; if so, the margins
        are normalized separately and re-attached.
    margins_name : Hashable, default "All"
        Label of the margin row/column.

    Raises
    ------
    ValueError
        On an unrecognized ``normalize`` or ``margins`` argument, or when
        the margin label is missing from ``table``.
    """
    if not isinstance(normalize, (bool, str)):
        # Map axis numbers to their string equivalents.
        axis_subs = {0: "index", 1: "columns"}
        try:
            normalize = axis_subs[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

    if margins is False:
        # Actual Normalizations
        normalizers: dict[bool | str, Callable] = {
            "all": lambda x: x / x.sum(axis=1).sum(axis=0),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis=0),
        }

        normalizers[True] = normalizers["all"]

        try:
            f = normalizers[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

        table = f(table)
        table = table.fillna(0)

    elif margins is True:
        # keep index and column of pivoted table
        table_index = table.index
        table_columns = table.columns
        last_ind_or_col = table.iloc[-1, :].name

        # check if margin name is not in (for MI cases) and not equal to last
        # index/column and save the column and index margin
        if (margins_name not in last_ind_or_col) & (margins_name != last_ind_or_col):
            raise ValueError(f"{margins_name} not in pivoted DataFrame")
        column_margin = table.iloc[:-1, -1]
        index_margin = table.iloc[-1, :-1]

        # keep the core table
        table = table.iloc[:-1, :-1]

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == "columns":
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)
            table.columns = table_columns

        elif normalize == "index":
            index_margin = index_margin / index_margin.sum()
            table = table._append(index_margin)
            table = table.fillna(0)
            table.index = table_index

        elif normalize == "all" or normalize is True:
            # Normalize both margins; the grand-total cell becomes 1.
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table._append(index_margin)

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

        else:
            raise ValueError("Not a valid normalize argument")

    else:
        raise ValueError("Not a valid margins argument")

    return table
829
830
831def _get_names(arrs, names, prefix: str = "row"):
832 if names is None:
833 names = []
834 for i, arr in enumerate(arrs):
835 if isinstance(arr, ABCSeries) and arr.name is not None:
836 names.append(arr.name)
837 else:
838 names.append(f"{prefix}_{i}")
839 else:
840 if len(names) != len(arrs):
841 raise AssertionError("arrays and names must have the same length")
842 if not isinstance(names, list):
843 names = list(names)
844
845 return names
846
847
848def _build_names_mapper(
849 rownames: list[str], colnames: list[str]
850) -> tuple[dict[str, str], list[str], dict[str, str], list[str]]:
851 """
852 Given the names of a DataFrame's rows and columns, returns a set of unique row
853 and column names and mappers that convert to original names.
854
855 A row or column name is replaced if it is duplicate among the rows of the inputs,
856 among the columns of the inputs or between the rows and the columns.
857
858 Parameters
859 ----------
860 rownames: list[str]
861 colnames: list[str]
862
863 Returns
864 -------
865 Tuple(Dict[str, str], List[str], Dict[str, str], List[str])
866
867 rownames_mapper: dict[str, str]
868 a dictionary with new row names as keys and original rownames as values
869 unique_rownames: list[str]
870 a list of rownames with duplicate names replaced by dummy names
871 colnames_mapper: dict[str, str]
872 a dictionary with new column names as keys and original column names as values
873 unique_colnames: list[str]
874 a list of column names with duplicate names replaced by dummy names
875
876 """
877
878 def get_duplicates(names):
879 seen: set = set()
880 return {name for name in names if name not in seen}
881
882 shared_names = set(rownames).intersection(set(colnames))
883 dup_names = get_duplicates(rownames) | get_duplicates(colnames) | shared_names
884
885 rownames_mapper = {
886 f"row_{i}": name for i, name in enumerate(rownames) if name in dup_names
887 }
888 unique_rownames = [
889 f"row_{i}" if name in dup_names else name for i, name in enumerate(rownames)
890 ]
891
892 colnames_mapper = {
893 f"col_{i}": name for i, name in enumerate(colnames) if name in dup_names
894 }
895 unique_colnames = [
896 f"col_{i}" if name in dup_names else name for i, name in enumerate(colnames)
897 ]
898
899 return rownames_mapper, unique_rownames, colnames_mapper, unique_colnames