1"""
2transforms.py is for shape-preserving functions.
3"""
4
5from __future__ import annotations
6
7from typing import TYPE_CHECKING
8
9import numpy as np
10
11if TYPE_CHECKING:
12 from pandas._typing import (
13 AxisInt,
14 Scalar,
15 )
16
17
def shift(
    values: np.ndarray, periods: int, axis: AxisInt, fill_value: Scalar
) -> np.ndarray:
    """
    Shift the elements of an array along one axis, filling vacated positions.

    Parameters
    ----------
    values : np.ndarray
        Array to shift. It is never modified; a new array is returned.
    periods : int
        Number of positions to shift by. Positive values move elements
        towards higher indices, negative values towards lower indices.
    axis : int
        Axis along which to shift. Negative values count back from the
        last axis, regardless of the memory layout of ``values``.
    fill_value : scalar
        Value assigned to the positions vacated by the shift.

    Returns
    -------
    np.ndarray
        New array with the same shape as ``values``.
    """
    if periods == 0 or values.size == 0:
        return values.copy()

    ndim = values.ndim

    # Resolve an in-range negative axis before anything else: the F-order
    # remapping below is only correct for non-negative axes, so e.g. axis=-1
    # would otherwise be turned into an out-of-bounds axis. Out-of-range
    # axes are deliberately left untouched so their handling is unchanged.
    if -ndim <= axis < 0:
        axis += ndim

    # make sure array sent to np.roll is c_contiguous
    f_ordered = values.flags.f_contiguous
    new_values = values
    if f_ordered:
        new_values = values.T
        # transposing reverses the axes, so mirror the axis index
        axis = ndim - axis - 1

    # np.roll returns a new array, so the fill below never writes to `values`
    new_values = np.roll(new_values, np.intp(periods), axis=axis)

    # the slots that np.roll wrapped around are the ones to overwrite
    vacated = slice(None, periods) if periods > 0 else slice(periods, None)
    axis_indexer = [slice(None)] * ndim
    axis_indexer[axis] = vacated
    new_values[tuple(axis_indexer)] = fill_value

    # restore original order
    if f_ordered:
        new_values = new_values.T

    return new_values