Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scikit_learn-1.4.dev0-py3.8-linux-x86_64.egg/sklearn/utils/_mask.py: 28%

25 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1from contextlib import suppress 

2 

3import numpy as np 

4from scipy import sparse as sp 

5 

6from . import is_scalar_nan 

7from .fixes import _object_dtype_isnan 

8 

9 

def _get_dense_mask(X, value_to_mask):
    """Return a boolean mask flagging the entries of ``X`` equal to ``value_to_mask``.

    Parameters
    ----------
    X : array-like
        Dense data to inspect: an ndarray, or the ``data`` array of a sparse
        matrix. The NaN branches read ``X.dtype.kind`` and ``X.shape``.

    value_to_mask : scalar
        Value to look for. ``pandas.NA`` (when pandas provides it) is detected
        with ``pandas.isna``. Values recognised by ``is_scalar_nan`` are
        detected with a NaN check chosen from ``X``'s dtype, because
        ``NaN == NaN`` is False. Any other value is compared with ``==``.

    Returns
    -------
    Xt : array-like of bool
        Mask of the entries of ``X`` that match ``value_to_mask``.
    """
    # Guard only the import and the ``pandas.NA`` lookup: pandas may not be
    # installed, and older pandas versions do not define ``NA``. Keeping the
    # ``pandas.isna`` call outside the guard means an AttributeError raised
    # by it propagates instead of being silently swallowed.
    try:
        import pandas

        pandas_na = pandas.NA
    except (ImportError, AttributeError):
        pass
    else:
        if value_to_mask is pandas_na:
            return pandas.isna(X)

    if not is_scalar_nan(value_to_mask):
        return X == value_to_mask

    kind = X.dtype.kind
    if kind == "f":
        return np.isnan(X)
    if kind in ("i", "u"):
        # Integer arrays cannot hold NaN, so nothing is masked.
        return np.zeros(X.shape, dtype=bool)
    # np.isnan does not support object (and other non-numeric) dtypes.
    return _object_dtype_isnan(X)

32 

33 

def _get_mask(X, value_to_mask):
    """Compute the boolean mask X == value_to_mask.

    Parameters
    ----------
    X : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Input data, where ``n_samples`` is the number of samples and
        ``n_features`` is the number of features.

    value_to_mask : scalar
        The value which is to be masked in X, e.g. an int, a float, NaN or
        ``pandas.NA``. NaN and ``pandas.NA`` are matched with a missing-value
        check rather than ``==``.

    Returns
    -------
    X_mask : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Missing mask. For sparse input, only the stored entries of ``X`` are
        compared, so implicit zeros are never flagged. Stored entries that do
        not match are kept as explicitly stored ``False`` values. The result
        is a ``csr_matrix`` for CSR input and a ``csc_matrix`` otherwise.
    """
    if not sp.issparse(X):
        # For all cases apart from a sparse input, where we need to rebuild
        # a sparse output with the same sparsity structure.
        return _get_dense_mask(X, value_to_mask)

    if X.format not in ("csr", "csc"):
        # Only CSR/CSC store a 1-D ``data`` array aligned with ``indices`` and
        # ``indptr``. Other formats (e.g. COO or LIL) are converted to CSC,
        # the layout used below for every non-CSR input.
        X = X.tocsc()

    Xt = _get_dense_mask(X.data, value_to_mask)

    sparse_constructor = sp.csr_matrix if X.format == "csr" else sp.csc_matrix
    # Reuse X's sparsity structure (copied, so the mask does not share index
    # arrays with X) and use the per-entry boolean mask as the data.
    Xt_sparse = sparse_constructor(
        (Xt, X.indices.copy(), X.indptr.copy()), shape=X.shape, dtype=bool
    )

    return Xt_sparse