Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/io/_mmio.py: 16%
422 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1"""
2 Matrix Market I/O in Python.
3 See http://math.nist.gov/MatrixMarket/formats.html
4 for information about the Matrix Market format.
5"""
6#
7# Author: Pearu Peterson <pearu@cens.ioc.ee>
8# Created: October, 2004
9#
10# References:
11# http://math.nist.gov/MatrixMarket/
12#
13import os
15import numpy as np
16from numpy import (asarray, real, imag, conj, zeros, ndarray, concatenate,
17 ones, can_cast)
19from scipy.sparse import coo_matrix, issparse
21__all__ = ['mminfo', 'mmread', 'mmwrite', 'MMFile']
24# -----------------------------------------------------------------------------
def asstr(s):
    """Return `s` as text: bytes are decoded as Latin-1, anything else goes through ``str``."""
    return s.decode('latin1') if isinstance(s, bytes) else str(s)
def mminfo(source):
    """
    Return size and storage parameters from Matrix Market file-like 'source'.

    This is a convenience wrapper around ``MMFile.info``.

    Parameters
    ----------
    source : str or file-like
        Matrix Market filename (extension .mtx) or open file-like object

    Returns
    -------
    rows : int
        Number of matrix rows.
    cols : int
        Number of matrix columns.
    entries : int
        Number of non-zero entries of a sparse matrix
        or rows*cols for a dense matrix.
    format : str
        Either 'coordinate' or 'array'.
    field : str
        Field type declared in the header, lower-cased, e.g. 'real',
        'complex', 'pattern', 'integer' or 'unsigned-integer'.
    symmetry : str
        Symmetry declared in the header, lower-cased, e.g. 'general',
        'symmetric', 'skew-symmetric' or 'hermitian'.

    Examples
    --------
    >>> from io import StringIO
    >>> from scipy.io import mminfo

    >>> text = '''%%MatrixMarket matrix coordinate real general
    ...  5 5 7
    ...  2 3 1.0
    ...  3 4 2.0
    ...  3 5 3.0
    ...  4 1 4.0
    ...  4 2 5.0
    ...  4 3 6.0
    ...  4 4 7.0
    ... '''

    ``mminfo(source)`` returns the number of rows, number of columns,
    number of entries, format, field type and symmetry attribute of the
    source file.

    >>> mminfo(StringIO(text))
    (5, 5, 7, 'coordinate', 'real', 'general')
    """
    header = MMFile.info(source)
    return header
81# -----------------------------------------------------------------------------
def mmread(source):
    """
    Reads the contents of a Matrix Market file-like 'source' into a matrix.

    This is a convenience wrapper that builds a fresh ``MMFile`` and calls
    its ``read`` method.

    Parameters
    ----------
    source : str or file-like
        Matrix Market filename (extensions .mtx, .mtx.gz, .mtx.bz2)
        or open file-like object.

    Returns
    -------
    a : ndarray or coo_matrix
        Dense or sparse matrix depending on the matrix format in the
        Matrix Market file.

    Examples
    --------
    >>> from io import StringIO
    >>> from scipy.io import mmread

    >>> text = '''%%MatrixMarket matrix coordinate real general
    ...  5 5 7
    ...  2 3 1.0
    ...  3 4 2.0
    ...  3 5 3.0
    ...  4 1 4.0
    ...  4 2 5.0
    ...  4 3 6.0
    ...  4 4 7.0
    ... '''

    ``mmread(source)`` returns the data as sparse matrix in COO format.

    >>> m = mmread(StringIO(text))
    >>> m
    <5x5 sparse matrix of type '<class 'numpy.float64'>'
        with 7 stored elements in COOrdinate format>
    >>> m.toarray()
    array([[0., 0., 0., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 2., 3.],
           [4., 5., 6., 7., 0.],
           [0., 0., 0., 0., 0.]])
    """
    reader = MMFile()
    return reader.read(source)
131# -----------------------------------------------------------------------------
def mmwrite(target, a, comment='', field=None, precision=None, symmetry=None):
    r"""
    Writes the sparse or dense array `a` to Matrix Market file-like `target`.

    This is a convenience wrapper that builds a fresh ``MMFile`` and calls
    its ``write`` method.

    Parameters
    ----------
    target : str or file-like
        Matrix Market filename (extension .mtx) or open file-like object.
    a : array like
        Sparse or dense 2-D array.
    comment : str, optional
        Comments to be prepended to the Matrix Market file.
    field : None or str, optional
        Either 'real', 'complex', 'pattern', or 'integer'.
    precision : None or int, optional
        Number of digits to display for real or complex values.
    symmetry : None or str, optional
        Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
        If symmetry is None the symmetry type of 'a' is determined by its
        values.

    Returns
    -------
    None

    Examples
    --------
    >>> from io import BytesIO
    >>> import numpy as np
    >>> from scipy.sparse import coo_matrix
    >>> from scipy.io import mmwrite

    Write a small NumPy array to a matrix market file.  The file will be
    written in the ``'array'`` format.

    >>> a = np.array([[1.0, 0, 0, 0], [0, 2.5, 0, 6.25]])
    >>> target = BytesIO()
    >>> mmwrite(target, a)
    >>> print(target.getvalue().decode('latin1'))
    %%MatrixMarket matrix array real general
    %
    2 4
    1.0000000000000000e+00
    0.0000000000000000e+00
    0.0000000000000000e+00
    2.5000000000000000e+00
    0.0000000000000000e+00
    0.0000000000000000e+00
    0.0000000000000000e+00
    6.2500000000000000e+00

    Add a comment to the output file, and set the precision to 3.

    >>> target = BytesIO()
    >>> mmwrite(target, a, comment='\n Some test data.\n', precision=3)
    >>> print(target.getvalue().decode('latin1'))
    %%MatrixMarket matrix array real general
    %
    % Some test data.
    %
    2 4
    1.000e+00
    0.000e+00
    0.000e+00
    2.500e+00
    0.000e+00
    0.000e+00
    0.000e+00
    6.250e+00

    Convert to a sparse matrix before calling ``mmwrite``.  This will
    result in the output format being ``'coordinate'`` rather than
    ``'array'``.

    >>> target = BytesIO()
    >>> mmwrite(target, coo_matrix(a), precision=3)
    >>> print(target.getvalue().decode('latin1'))
    %%MatrixMarket matrix coordinate real general
    %
    2 4 3
    1 1 1.00e+00
    2 2 2.50e+00
    2 4 6.25e+00

    Write a complex Hermitian array to a matrix market file.  Note that
    only six values are actually written to the file; the other values
    are implied by the symmetry.

    >>> z = np.array([[3, 1+2j, 4-3j], [1-2j, 1, -5j], [4+3j, 5j, 2.5]])
    >>> z
    array([[ 3. +0.j,  1. +2.j,  4. -3.j],
           [ 1. -2.j,  1. +0.j, -0. -5.j],
           [ 4. +3.j,  0. +5.j,  2.5+0.j]])

    >>> target = BytesIO()
    >>> mmwrite(target, z, precision=2)
    >>> print(target.getvalue().decode('latin1'))
    %%MatrixMarket matrix array complex hermitian
    %
    3 3
    3.00e+00 0.00e+00
    1.00e+00 -2.00e+00
    4.00e+00 3.00e+00
    1.00e+00 0.00e+00
    0.00e+00 5.00e+00
    2.50e+00 0.00e+00

    """
    writer = MMFile()
    writer.write(target, a, comment=comment, field=field,
                 precision=precision, symmetry=symmetry)
245###############################################################################
class MMFile:
    """
    Reader and writer for the Matrix Market exchange format.

    An instance stores the header values of the most recently parsed file
    (``rows``, ``cols``, ``entries``, ``format``, ``field``, ``symmetry``)
    and exposes them as read-only properties.  Any of them may be given as
    keyword arguments to the constructor; the others default to None.
    """
    __slots__ = ('_rows',
                 '_cols',
                 '_entries',
                 '_format',
                 '_field',
                 '_symmetry')

    @property
    def rows(self):
        return self._rows

    @property
    def cols(self):
        return self._cols

    @property
    def entries(self):
        return self._entries

    @property
    def format(self):
        return self._format

    @property
    def field(self):
        return self._field

    @property
    def symmetry(self):
        return self._symmetry

    @property
    def has_symmetry(self):
        # True for every symmetry kind except 'general' (and None)
        return self._symmetry in (self.SYMMETRY_SYMMETRIC,
                                  self.SYMMETRY_SKEW_SYMMETRIC,
                                  self.SYMMETRY_HERMITIAN)

    # format values
    FORMAT_COORDINATE = 'coordinate'
    FORMAT_ARRAY = 'array'
    FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)

    @classmethod
    def _validate_format(cls, format):
        """Raise ValueError unless `format` is one of ``FORMAT_VALUES``."""
        if format not in cls.FORMAT_VALUES:
            raise ValueError('unknown format type %s, must be one of %s' %
                             (format, cls.FORMAT_VALUES))

    # field values
    FIELD_INTEGER = 'integer'
    FIELD_UNSIGNED = 'unsigned-integer'
    FIELD_REAL = 'real'
    FIELD_COMPLEX = 'complex'
    FIELD_PATTERN = 'pattern'
    FIELD_VALUES = (FIELD_INTEGER, FIELD_UNSIGNED, FIELD_REAL, FIELD_COMPLEX,
                    FIELD_PATTERN)

    @classmethod
    def _validate_field(cls, field):
        """Raise ValueError unless `field` is one of ``FIELD_VALUES``."""
        if field not in cls.FIELD_VALUES:
            raise ValueError('unknown field type %s, must be one of %s' %
                             (field, cls.FIELD_VALUES))

    # symmetry values
    SYMMETRY_GENERAL = 'general'
    SYMMETRY_SYMMETRIC = 'symmetric'
    SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
    SYMMETRY_HERMITIAN = 'hermitian'
    SYMMETRY_VALUES = (SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC,
                       SYMMETRY_SKEW_SYMMETRIC, SYMMETRY_HERMITIAN)

    @classmethod
    def _validate_symmetry(cls, symmetry):
        """Raise ValueError unless `symmetry` is one of ``SYMMETRY_VALUES``."""
        if symmetry not in cls.SYMMETRY_VALUES:
            raise ValueError('unknown symmetry type %s, must be one of %s' %
                             (symmetry, cls.SYMMETRY_VALUES))

    # NumPy dtype used for the returned matrix, keyed by header field
    DTYPES_BY_FIELD = {FIELD_INTEGER: 'intp',
                       FIELD_UNSIGNED: 'uint64',
                       FIELD_REAL: 'd',
                       FIELD_COMPLEX: 'D',
                       FIELD_PATTERN: 'd'}

    # -------------------------------------------------------------------------
    @staticmethod
    def reader():
        pass

    # -------------------------------------------------------------------------
    @staticmethod
    def writer():
        pass

    # -------------------------------------------------------------------------
    @classmethod
    def info(cls, source):
        """
        Return size, storage parameters from Matrix Market file-like 'source'.

        Parameters
        ----------
        source : str or file-like
            Matrix Market filename (extension .mtx) or open file-like object

        Returns
        -------
        rows : int
            Number of matrix rows.
        cols : int
            Number of matrix columns.
        entries : int
            Number of non-zero entries of a sparse matrix
            or rows*cols for a dense matrix.
        format : str
            Either 'coordinate' or 'array'.
        field : str
            Field type from the header, lower-cased, e.g. 'real', 'complex',
            'pattern', 'integer' or 'unsigned-integer'.
        symmetry : str
            Symmetry from the header, lower-cased, e.g. 'general',
            'symmetric', 'skew-symmetric' or 'hermitian'.

        Raises
        ------
        ValueError
            If the header line or the size line is malformed, or if the
            source ends before a size line is found.
        """

        stream, close_it = cls._open(source)

        try:

            # read and validate header line
            line = stream.readline()
            mmid, matrix, format, field, symmetry = \
                (asstr(part.strip()) for part in line.split())
            if not mmid.startswith('%%MatrixMarket'):
                raise ValueError('source is not in Matrix Market format')
            if not matrix.lower() == 'matrix':
                # `line` may be bytes or str depending on the stream
                raise ValueError("Problem reading file header: " +
                                 asstr(line))

            # http://math.nist.gov/MatrixMarket/formats.html
            if format.lower() == 'array':
                format = cls.FORMAT_ARRAY
            elif format.lower() == 'coordinate':
                format = cls.FORMAT_COORDINATE

            # skip comments (the header line itself starts with '%' too);
            # 37 is ord('%'), which is what indexing a bytes line yields
            while line:
                if line.lstrip() and line.lstrip()[0] in ['%', 37]:
                    line = stream.readline()
                else:
                    break

            # skip empty lines
            while not line.strip():
                if not line:
                    # readline() returns an empty string at EOF and would
                    # do so forever; bail out instead of spinning
                    raise ValueError('source ended before the size line '
                                     'was found')
                line = stream.readline()

            split_line = line.split()
            if format == cls.FORMAT_ARRAY:
                if not len(split_line) == 2:
                    raise ValueError("Header line not of length 2: " +
                                     asstr(line))
                rows, cols = map(int, split_line)
                entries = rows * cols
            else:
                if not len(split_line) == 3:
                    raise ValueError("Header line not of length 3: " +
                                     asstr(line))
                rows, cols, entries = map(int, split_line)

            return (rows, cols, entries, format, field.lower(),
                    symmetry.lower())

        finally:
            if close_it:
                stream.close()

    # -------------------------------------------------------------------------
    @staticmethod
    def _open(filespec, mode='rb'):
        """ Return an open file stream for reading based on source.

        If source is a file name, open it (after trying to find it with mtx and
        gzipped mtx extensions). Otherwise, just return source.

        Parameters
        ----------
        filespec : str or file-like
            String giving file name or file-like object
        mode : str, optional
            Mode with which to open file, if `filespec` is a file name.

        Returns
        -------
        fobj : file-like
            Open file-like object.
        close_it : bool
            True if the calling function should close this file when done,
            false otherwise.
        """
        # If 'filespec' is path-like (str, pathlib.Path, os.DirEntry, other
        # class implementing a '__fspath__' method), try to convert it to str.
        # If this fails by throwing a 'TypeError', assume it's an open file
        # handle and return it as-is.
        try:
            filespec = os.fspath(filespec)
        except TypeError:
            return filespec, False

        # open for reading
        if mode[0] == 'r':

            # determine filename plus extension
            if not os.path.isfile(filespec):
                if os.path.isfile(filespec+'.mtx'):
                    filespec = filespec + '.mtx'
                elif os.path.isfile(filespec+'.mtx.gz'):
                    filespec = filespec + '.mtx.gz'
                elif os.path.isfile(filespec+'.mtx.bz2'):
                    filespec = filespec + '.mtx.bz2'
            # open filename
            if filespec.endswith('.gz'):
                import gzip
                stream = gzip.open(filespec, mode)
            elif filespec.endswith('.bz2'):
                import bz2
                stream = bz2.BZ2File(filespec, 'rb')
            else:
                stream = open(filespec, mode)

        # open for writing
        else:
            if filespec[-4:] != '.mtx':
                filespec = filespec + '.mtx'
            stream = open(filespec, mode)

        return stream, True

    # -------------------------------------------------------------------------
    @staticmethod
    def _get_symmetry(a):
        """
        Determine the Matrix Market symmetry kind of the 2-D matrix `a`.

        Returns one of the ``SYMMETRY_*`` constants.  Non-square input is
        always 'general'.  When several kinds apply, the order of preference
        is symmetric, skew-symmetric, hermitian.
        """
        m, n = a.shape
        if m != n:
            return MMFile.SYMMETRY_GENERAL
        issymm = True
        isskew = True
        # only complex dtypes are candidates for 'hermitian'
        isherm = a.dtype.char in 'FD'

        # sparse input
        if issparse(a):
            # check if number of nonzero entries of lower and upper triangle
            # matrix are equal
            a = a.tocoo()
            (row, col) = a.nonzero()
            if (row < col).sum() != (row > col).sum():
                return MMFile.SYMMETRY_GENERAL

            # define iterator over symmetric pair entries
            a = a.todok()

            def symm_iterator():
                for ((i, j), aij) in a.items():
                    if i > j:
                        aji = a[j, i]
                        yield (aij, aji, False)
                    elif i == j:
                        yield (aij, aij, True)

        # non-sparse input
        else:
            # define iterator over symmetric pair entries
            def symm_iterator():
                for j in range(n):
                    for i in range(j, n):
                        aij, aji = a[i][j], a[j][i]
                        yield (aij, aji, i == j)

        # check for symmetry
        # yields aij, aji, is_diagonal
        for (aij, aji, is_diagonal) in symm_iterator():
            if isskew and is_diagonal and aij != 0:
                # a skew-symmetric matrix has an all-zero diagonal
                isskew = False
                # a Hermitian matrix needs a real diagonal, so this entry
                # must still be tested against its conjugate
                if isherm and aij != conj(aji):
                    isherm = False
            else:
                if issymm and aij != aji:
                    issymm = False
                with np.errstate(over="ignore"):
                    # This can give a warning for uint dtypes, so silence that
                    if isskew and aij != -aji:
                        isskew = False
                if isherm and aij != conj(aji):
                    isherm = False
            if not (issymm or isskew or isherm):
                break

        # return symmetry value
        if issymm:
            return MMFile.SYMMETRY_SYMMETRIC
        if isskew:
            return MMFile.SYMMETRY_SKEW_SYMMETRIC
        if isherm:
            return MMFile.SYMMETRY_HERMITIAN
        return MMFile.SYMMETRY_GENERAL

    # -------------------------------------------------------------------------
    @staticmethod
    def _field_template(field, precision):
        """
        Return the %-format template for one value of type `field`.

        Returns None for fields without a template (e.g. 'pattern').
        """
        return {MMFile.FIELD_REAL: '%%.%ie\n' % precision,
                MMFile.FIELD_INTEGER: '%i\n',
                MMFile.FIELD_UNSIGNED: '%u\n',
                MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' %
                    (precision, precision)
                }.get(field, None)

    # -------------------------------------------------------------------------
    def __init__(self, **kwargs):
        self._init_attrs(**kwargs)

    # -------------------------------------------------------------------------
    def read(self, source):
        """
        Reads the contents of a Matrix Market file-like 'source' into a matrix.

        Parameters
        ----------
        source : str or file-like
            Matrix Market filename (extensions .mtx, .mtx.gz, .mtx.bz2)
            or open file object.

        Returns
        -------
        a : ndarray or coo_matrix
            Dense or sparse matrix depending on the matrix format in the
            Matrix Market file.
        """
        stream, close_it = self._open(source)

        try:
            self._parse_header(stream)
            return self._parse_body(stream)

        finally:
            if close_it:
                stream.close()

    # -------------------------------------------------------------------------
    def write(self, target, a, comment='', field=None, precision=None,
              symmetry=None):
        """
        Writes sparse or dense array `a` to Matrix Market file-like `target`.

        Parameters
        ----------
        target : str or file-like
            Matrix Market filename (extension .mtx) or open file-like object.
        a : array like
            Sparse or dense 2-D array.
        comment : str, optional
            Comments to be prepended to the Matrix Market file.
        field : None or str, optional
            Either 'real', 'complex', 'pattern', or 'integer'.
        precision : None or int, optional
            Number of digits to display for real or complex values.
        symmetry : None or str, optional
            Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
            If symmetry is None the symmetry type of 'a' is determined by its
            values.
        """

        stream, close_it = self._open(target, 'wb')

        try:
            self._write(stream, a, comment, field, precision, symmetry)

        finally:
            if close_it:
                stream.close()
            else:
                # caller owns the stream; just push out what was written
                stream.flush()

    # -------------------------------------------------------------------------
    def _init_attrs(self, **kwargs):
        """
        Initialize each attributes with the corresponding keyword arg value
        or a default of None
        """

        attrs = self.__class__.__slots__
        # slot names carry a leading underscore; keyword names do not
        public_attrs = [attr[1:] for attr in attrs]
        invalid_keys = set(kwargs.keys()) - set(public_attrs)

        if invalid_keys:
            raise ValueError('''found {} invalid keyword arguments, please only
                                use {}'''.format(tuple(invalid_keys),
                                                 public_attrs))

        for attr in attrs:
            setattr(self, attr, kwargs.get(attr[1:], None))

    # -------------------------------------------------------------------------
    def _parse_header(self, stream):
        """Read the header from `stream` and store its values on `self`."""
        rows, cols, entries, format, field, symmetry = \
            self.__class__.info(stream)
        self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
                         field=field, symmetry=symmetry)

    # -------------------------------------------------------------------------
    def _parse_body(self, stream):
        """
        Read the matrix data that follows the header in `stream`.

        Relies on the header attributes set by ``_parse_header``.  Returns an
        ndarray for the 'array' format and a coo_matrix for 'coordinate'.
        """
        rows, cols, entries, format, field, symm = (self.rows, self.cols,
                                                    self.entries, self.format,
                                                    self.field, self.symmetry)

        dtype = self.DTYPES_BY_FIELD.get(field, None)

        has_symmetry = self.has_symmetry
        is_integer = field == self.FIELD_INTEGER
        is_unsigned_integer = field == self.FIELD_UNSIGNED
        is_complex = field == self.FIELD_COMPLEX
        is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
        is_herm = symm == self.SYMMETRY_HERMITIAN
        is_pattern = field == self.FIELD_PATTERN

        if format == self.FORMAT_ARRAY:
            a = zeros((rows, cols), dtype=dtype)
            line = 1
            i, j = 0, 0
            if is_skew:
                # skew-symmetric files omit the (zero) diagonal
                a[i, j] = 0
                if i < rows - 1:
                    i += 1
            while line:
                line = stream.readline()
                # line.startswith('%')
                if not line or line[0] in ['%', 37] or not line.strip():
                    continue
                if is_integer:
                    aij = int(line)
                elif is_unsigned_integer:
                    aij = int(line)
                elif is_complex:
                    aij = complex(*map(float, line.split()))
                else:
                    aij = float(line)
                a[i, j] = aij
                if has_symmetry and i != j:
                    # mirror the value into the upper triangle
                    if is_skew:
                        a[j, i] = -aij
                    elif is_herm:
                        a[j, i] = conj(aij)
                    else:
                        a[j, i] = aij
                # advance down the column, then on to the next column
                if i < rows-1:
                    i = i + 1
                else:
                    j = j + 1
                    if not has_symmetry:
                        i = 0
                    else:
                        # symmetric files only store the lower triangle
                        i = j
                        if is_skew:
                            a[i, j] = 0
                            if i < rows-1:
                                i += 1

            if is_skew:
                if not (i in [0, j] and j == cols - 1):
                    raise ValueError("Parse error, did not read all lines.")
            else:
                if not (i in [0, j] and j == cols):
                    raise ValueError("Parse error, did not read all lines.")

        elif format == self.FORMAT_COORDINATE:
            # Read sparse COOrdinate format

            if entries == 0:
                # empty matrix
                return coo_matrix((rows, cols), dtype=dtype)

            I = zeros(entries, dtype='intc')
            J = zeros(entries, dtype='intc')
            if is_pattern:
                V = ones(entries, dtype='int8')
            elif is_integer:
                V = zeros(entries, dtype='intp')
            elif is_unsigned_integer:
                V = zeros(entries, dtype='uint64')
            elif is_complex:
                V = zeros(entries, dtype='complex')
            else:
                V = zeros(entries, dtype='float')

            entry_number = 0
            for line in stream:
                # line.startswith('%')
                if not line or line[0] in ['%', 37] or not line.strip():
                    continue

                if entry_number+1 > entries:
                    raise ValueError("'entries' in header is smaller than "
                                     "number of entries")
                l = line.split()
                I[entry_number], J[entry_number] = map(int, l[:2])

                if not is_pattern:
                    if is_integer:
                        V[entry_number] = int(l[2])
                    elif is_unsigned_integer:
                        V[entry_number] = int(l[2])
                    elif is_complex:
                        V[entry_number] = complex(*map(float, l[2:]))
                    else:
                        V[entry_number] = float(l[2])
                entry_number += 1
            if entry_number < entries:
                raise ValueError("'entries' in header is larger than "
                                 "number of entries")

            I -= 1  # adjust indices (base 1 -> base 0)
            J -= 1

            if has_symmetry:
                # append the mirrored off-diagonal entries
                mask = (I != J)       # off diagonal mask
                od_I = I[mask]
                od_J = J[mask]
                od_V = V[mask]

                I = concatenate((I, od_J))
                J = concatenate((J, od_I))

                if is_skew:
                    od_V *= -1
                elif is_herm:
                    od_V = od_V.conjugate()

                V = concatenate((V, od_V))

            a = coo_matrix((V, (I, J)), shape=(rows, cols), dtype=dtype)
        else:
            raise NotImplementedError(format)

        return a

    #  ------------------------------------------------------------------------
    def _write(self, stream, a, comment='', field=None, precision=None,
               symmetry=None):
        """
        Write matrix `a` to the binary stream `stream` in Matrix Market form.

        Dense input (list, tuple, ndarray or anything with ``__array__``) is
        written in 'array' format, sparse input in 'coordinate' format.  The
        parameters have the same meaning as for ``write``.
        """
        if isinstance(a, (list, ndarray, tuple)) or hasattr(a, '__array__'):
            rep = self.FORMAT_ARRAY
            a = asarray(a)
            if len(a.shape) != 2:
                raise ValueError('Expected 2 dimensional array')
            rows, cols = a.shape

            if field is not None:

                if field == self.FIELD_INTEGER:
                    if not can_cast(a.dtype, 'intp'):
                        raise OverflowError("mmwrite does not support integer "
                                            "dtypes larger than native 'intp'.")
                    a = a.astype('intp')
                elif field == self.FIELD_REAL:
                    if a.dtype.char not in 'fd':
                        a = a.astype('d')
                elif field == self.FIELD_COMPLEX:
                    if a.dtype.char not in 'FD':
                        a = a.astype('D')

        else:
            if not issparse(a):
                raise ValueError('unknown matrix type: %s' % type(a))

            rep = 'coordinate'
            rows, cols = a.shape

        typecode = a.dtype.char

        if precision is None:
            # single-precision dtypes get fewer digits by default
            if typecode in 'fF':
                precision = 8
            else:
                precision = 16
        if field is None:
            # derive the field from the dtype kind
            kind = a.dtype.kind
            if kind == 'i':
                if not can_cast(a.dtype, 'intp'):
                    raise OverflowError("mmwrite does not support integer "
                                        "dtypes larger than native 'intp'.")
                field = 'integer'
            elif kind == 'f':
                field = 'real'
            elif kind == 'c':
                field = 'complex'
            elif kind == 'u':
                field = 'unsigned-integer'
            else:
                raise TypeError('unexpected dtype kind ' + kind)

        if symmetry is None:
            symmetry = self._get_symmetry(a)

        # validate rep, field, and symmetry
        self.__class__._validate_format(rep)
        self.__class__._validate_field(field)
        self.__class__._validate_symmetry(symmetry)

        # write initial header line
        data = f'%%MatrixMarket matrix {rep} {field} {symmetry}\n'
        stream.write(data.encode('latin1'))

        # write comments
        for line in comment.split('\n'):
            data = '%%%s\n' % (line)
            stream.write(data.encode('latin1'))

        template = self._field_template(field, precision)
        # write dense format
        if rep == self.FORMAT_ARRAY:
            # write shape spec
            data = '%i %i\n' % (rows, cols)
            stream.write(data.encode('latin1'))

            if field in (self.FIELD_INTEGER, self.FIELD_REAL,
                         self.FIELD_UNSIGNED):
                if symmetry == self.SYMMETRY_GENERAL:
                    # full matrix, column by column
                    for j in range(cols):
                        for i in range(rows):
                            data = template % a[i, j]
                            stream.write(data.encode('latin1'))

                elif symmetry == self.SYMMETRY_SKEW_SYMMETRIC:
                    # strictly lower triangle; the diagonal is implied zero
                    for j in range(cols):
                        for i in range(j + 1, rows):
                            data = template % a[i, j]
                            stream.write(data.encode('latin1'))

                else:
                    # lower triangle including the diagonal
                    for j in range(cols):
                        for i in range(j, rows):
                            data = template % a[i, j]
                            stream.write(data.encode('latin1'))

            elif field == self.FIELD_COMPLEX:

                if symmetry == self.SYMMETRY_GENERAL:
                    for j in range(cols):
                        for i in range(rows):
                            aij = a[i, j]
                            data = template % (real(aij), imag(aij))
                            stream.write(data.encode('latin1'))
                else:
                    for j in range(cols):
                        for i in range(j, rows):
                            aij = a[i, j]
                            data = template % (real(aij), imag(aij))
                            stream.write(data.encode('latin1'))

            elif field == self.FIELD_PATTERN:
                raise ValueError('pattern type inconsisted with dense format')

            else:
                raise TypeError('Unknown field type %s' % field)

        # write sparse format
        else:
            coo = a.tocoo()  # convert to COOrdinate format

            # if symmetry format used, remove values above main diagonal
            if symmetry != self.SYMMETRY_GENERAL:
                lower_triangle_mask = coo.row >= coo.col
                coo = coo_matrix((coo.data[lower_triangle_mask],
                                 (coo.row[lower_triangle_mask],
                                  coo.col[lower_triangle_mask])),
                                 shape=coo.shape)

            # write shape spec
            data = '%i %i %i\n' % (rows, cols, coo.nnz)
            stream.write(data.encode('latin1'))

            # NOTE: the coordinate writer uses one digit fewer than the
            # array writer; kept as-is because existing output depends on it
            template = self._field_template(field, precision-1)

            if field == self.FIELD_PATTERN:
                for r, c in zip(coo.row+1, coo.col+1):
                    data = "%i %i\n" % (r, c)
                    stream.write(data.encode('latin1'))
            elif field in (self.FIELD_INTEGER, self.FIELD_REAL,
                           self.FIELD_UNSIGNED):
                for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
                    data = ("%i %i " % (r, c)) + (template % d)
                    stream.write(data.encode('latin1'))
            elif field == self.FIELD_COMPLEX:
                for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
                    data = ("%i %i " % (r, c)) + (template % (d.real, d.imag))
                    stream.write(data.encode('latin1'))
            else:
                raise TypeError('Unknown field type %s' % field)
940def _is_fromfile_compatible(stream):
941 """
942 Check whether `stream` is compatible with numpy.fromfile.
944 Passing a gzipped file object to ``fromfile/fromstring`` doesn't work with
945 Python 3.
946 """
948 bad_cls = []
949 try:
950 import gzip
951 bad_cls.append(gzip.GzipFile)
952 except ImportError:
953 pass
954 try:
955 import bz2
956 bad_cls.append(bz2.BZ2File)
957 except ImportError:
958 pass
960 bad_cls = tuple(bad_cls)
961 return not isinstance(stream, bad_cls)