import hail
from hail.utils.java import Env, joption, error
from hail.typecheck import enumeration, typecheck, nullable
import difflib
from collections import defaultdict, Counter, OrderedDict
import atexit
import shutil
import tempfile
@typecheck(n_rows=int, n_cols=int, n_partitions=nullable(int))
def range_matrix_table(n_rows, n_cols, n_partitions=None) -> 'hail.MatrixTable':
    """Construct a matrix table with row and column indices and no entry fields.

    Examples
    --------

    >>> range_ds = hl.utils.range_matrix_table(n_rows=100, n_cols=10)

    >>> range_ds.count_rows()
    100

    >>> range_ds.count_cols()
    10

    Notes
    -----
    The resulting matrix table contains the following fields:

     - `row_idx` (:py:data:`.tint32`) - Row index (row key).
     - `col_idx` (:py:data:`.tint32`) - Column index (column key).

    It contains no entry fields.

    This method is meant for testing and learning, and is not optimized for
    production performance.

    Parameters
    ----------
    n_rows : :obj:`int`
        Number of rows.
    n_cols : :obj:`int`
        Number of columns.
    n_partitions : int, optional
        Number of partitions (uses Spark default parallelism if None).

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If `n_rows`, `n_cols`, or a non-None `n_partitions` is not positive
        or is larger than ``hl.tint32.max_value``.
    """
    check_positive_and_in_range('range_matrix_table', 'n_rows', n_rows)
    check_positive_and_in_range('range_matrix_table', 'n_cols', n_cols)
    if n_partitions is not None:
        check_positive_and_in_range('range_matrix_table', 'n_partitions', n_partitions)
    return hail.MatrixTable(Env.hail().variant.MatrixTable.range(Env.hc()._jhc, n_rows, n_cols, joption(n_partitions)))
@typecheck(n=int, n_partitions=nullable(int))
def range_table(n, n_partitions=None) -> 'hail.Table':
    """Construct a table with the row index and no other fields.

    Examples
    --------

    >>> df = hl.utils.range_table(100)

    >>> df.count()
    100

    Notes
    -----
    The resulting table contains one field:

     - `idx` (:py:data:`.tint32`) - Row index (key).

    This method is meant for testing and learning, and is not optimized for
    production performance.

    Parameters
    ----------
    n : int
        Number of rows.
    n_partitions : int, optional
        Number of partitions (uses Spark default parallelism if None).

    Returns
    -------
    :class:`.Table`

    Raises
    ------
    ValueError
        If `n` or a non-None `n_partitions` is not positive or is larger than
        ``hl.tint32.max_value``.
    """
    check_positive_and_in_range('range_table', 'n', n)
    if n_partitions is not None:
        check_positive_and_in_range('range_table', 'n_partitions', n_partitions)
    return hail.Table(Env.hail().table.Table.range(Env.hc()._jhc, n, joption(n_partitions)))
def check_positive_and_in_range(caller, name, value):
    """Validate that `value` lies in ``1 .. hail.tint32.max_value``.

    `caller` and `name` are only used to build the error message.

    Raises
    ------
    ValueError
        If `value` is zero or negative, or exceeds ``hail.tint32.max_value``.
    """
    # Guard clauses, checked in this order: the lower bound first, so the
    # upper-bound lookup is only performed for positive values.
    if value <= 0:
        raise ValueError(f"'{caller}': parameter '{name}' must be positive, found {value}")
    if value > hail.tint32.max_value:
        raise ValueError(f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, "
                         f"found {value}")
def wrap_to_list(s):
    """Return `s` unchanged if it is a list, otherwise a one-element list ``[s]``."""
    return s if isinstance(s, list) else [s]
def wrap_to_tuple(x):
    """Return `x` unchanged if it is a tuple, otherwise the 1-tuple ``(x,)``."""
    return x if isinstance(x, tuple) else (x,)
def wrap_to_sequence(x):
    """Normalize `x` to a tuple.

    Tuples (including tuple subclasses) are returned as-is, lists are
    converted with ``tuple()``, and anything else becomes ``(x,)``.
    """
    if isinstance(x, tuple):
        # Returned by identity, so tuple subclasses keep their type.
        return x
    return tuple(x) if isinstance(x, list) else (x,)
def get_env_or_default(maybe, envvar, default):
    """Return the first truthy value among `maybe` and ``os.environ[envvar]``, else `default`.

    Falsy values (``None``, ``''``, ``0``) of `maybe` or of the environment
    variable are skipped.
    """
    import os
    if maybe:
        return maybe
    from_env = os.environ.get(envvar)
    if from_env:
        return from_env
    return default
def uri_path(uri):
    """Delegate to the JVM-side ``uriPath`` utility (via ``Env.jutils()``) and return its result.

    NOTE(review): presumably extracts the path component of `uri` — confirm
    against the JVM implementation.
    """
    jutils = Env.jutils()
    return jutils.uriPath(uri)
def local_path_uri(path):
    """Prefix the string `path` with the ``file://`` scheme.

    NOTE(review): no normalization is done; callers presumably pass absolute
    paths (a relative path would yield a host-like URI) — confirm.
    """
    scheme = 'file://'
    return scheme + path
def new_temp_file(suffix=None, prefix=None, n_char=10):
    """Ask the JVM HailContext for a temporary file name.

    Delegates to ``getTemporaryFile`` with `n_char` and the optional `prefix`
    and `suffix` wrapped via ``joption``; returns whatever that call returns.
    """
    jhc = Env.hc()._jhc
    return jhc.getTemporaryFile(n_char, joption(prefix), joption(suffix))
def new_local_temp_dir(suffix=None, prefix=None, dir=None):
    """Create a local temporary directory that is deleted when the interpreter exits.

    Parameters
    ----------
    suffix, prefix, dir : str, optional
        Passed positionally to :func:`tempfile.mkdtemp`.

    Returns
    -------
    str
        Path of the newly created directory.
    """
    local_temp_dir = tempfile.mkdtemp(suffix, prefix, dir)
    # ignore_errors=True: if the directory has already been removed (fully or
    # partly) by exit time, skip it quietly instead of having the atexit
    # handler print a FileNotFoundError traceback.
    atexit.register(shutil.rmtree, local_temp_dir, ignore_errors=True)
    return local_temp_dir
def new_local_temp_file(filename="temp"):
    """Return the path `filename` inside a fresh local temporary directory.

    Only the path is built; the file itself is not created. The directory
    comes from ``new_local_temp_dir``.
    """
    directory = new_local_temp_dir()
    return "/".join((directory, filename))
# Type-check enumeration of accepted storage-level names. The names mirror
# Spark StorageLevel constants; presumably used to validate persist/cache
# arguments — confirm at call sites.
storage_level = enumeration(
    'NONE',
    'DISK_ONLY',
    'DISK_ONLY_2',
    'MEMORY_ONLY',
    'MEMORY_ONLY_2',
    'MEMORY_ONLY_SER',
    'MEMORY_ONLY_SER_2',
    'MEMORY_AND_DISK',
    'MEMORY_AND_DISK_2',
    'MEMORY_AND_DISK_SER',
    'MEMORY_AND_DISK_SER_2',
    'OFF_HEAP',
)
def run_command(args):
    """Run `args` as a subprocess, echoing its output only if it fails.

    Parameters
    ----------
    args : list of str
        Command and arguments, passed to :func:`subprocess.check_output`.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status. Its combined
        stdout/stderr is printed before the exception propagates.
    """
    import subprocess as sp
    try:
        sp.check_output(args, stderr=sp.STDOUT)
    except sp.CalledProcessError as e:
        # check_output captures bytes; decode so the log shows readable text
        # instead of a b'...' repr with escaped newlines.
        print(e.output.decode('utf-8', errors='replace'))
        raise
def plural(orig, n, alternate=None):
    """Return `orig` when ``n == 1``; otherwise `alternate` if truthy, else ``orig + 's'``."""
    if n != 1:
        return alternate if alternate else orig + 's'
    return orig
def get_obj_metadata(obj):
    """Describe a Hail container for use in attribute/field error messages.

    Returns
    -------
    tuple
        ``(class_name, cls, field_formatter, has_describe)``:

        - `class_name` is the display name of the matched type,
        - `cls` is the matched class,
        - `field_formatter(field)` returns the quoted field name, annotated with
          its axis (``[globals]``, ``[row]``, ``[col]``, ``[entry]``) for tables,
        - `has_describe` says whether a ``describe()`` hint should be shown.

    Raises
    ------
    NotImplementedError
        If `obj` is not one of the supported container types.
    """
    from hail.matrixtable import MatrixTable, GroupedMatrixTable
    from hail.table import Table, GroupedTable
    from hail.utils import Struct
    from hail.expr.expressions import StructExpression

    def table_formatter(container):
        def fmt(field):
            assert field in container._fields
            field_indices = container[field]._indices
            # Attributes are fetched lazily, in this order; per the original
            # note, a Table's field always matches before the col check.
            for attr, label in (('_global_indices', 'globals'),
                                ('_row_indices', 'row'),
                                ('_col_indices', 'col')):
                if field_indices == getattr(container, attr):
                    return "'{}' [{}]".format(field, label)
            assert field_indices == container._entry_indices
            return "'{}' [entry]".format(field)
        return fmt

    def struct_formatter(container):
        def fmt(field):
            assert field in container._fields
            return "'{}'".format(field)
        return fmt

    # Checked in order; the first isinstance match wins.
    dispatch = (
        ('MatrixTable', MatrixTable, table_formatter, True),
        ('GroupedMatrixTable', GroupedMatrixTable, lambda o: table_formatter(o._parent), True),
        ('Table', Table, table_formatter, True),
        ('GroupedTable', GroupedTable, table_formatter, False),
        ('Struct', Struct, struct_formatter, False),
        ('StructExpression', StructExpression, struct_formatter, True),
    )
    for class_name, cls, make_formatter, has_describe in dispatch:
        if isinstance(obj, cls):
            return class_name, cls, make_formatter(obj), has_describe
    raise NotImplementedError(obj)
def get_nice_attr_error(obj, item):
    """Build a helpful message for a failed attribute lookup `item` on `obj`.

    Parameters
    ----------
    obj
        A container supported by ``get_obj_metadata``.
    item : str
        The attribute name that could not be found.

    Returns
    -------
    str
        For names starting with ``_``, a plain "no attribute" message.
        Otherwise, a "no field, method, or property" message. It is followed
        by up to five close matches each among data fields, methods,
        properties and inherited members of the object's class. If there are
        no close matches and the object supports ``describe()``, a hint to
        call it is appended instead.
    """
    class_name, cls, handler, has_describe = get_obj_metadata(obj)

    # Leading-underscore names are treated as private: no suggestions.
    if item.startswith('_'):
        return "{} instance has no attribute '{}'".format(class_name, item)

    # Group field names by lower-cased spelling: matching is case-insensitive
    # while suggestions still show every original spelling.
    fields_by_lower = defaultdict(list)
    for field_name in obj._fields.keys():
        fields_by_lower[field_name.lower()].append(field_name)

    # Public class names not defined directly on `cls` are "inherited"; of
    # the rest, callables are methods and everything else counts as a property.
    own = cls.__dict__
    public = {name for name in dir(cls) if not name.startswith('_')}
    inherited = {name for name in public if name not in own}
    methods = {name for name in public if name in own and callable(own[name])}
    props = public - methods - inherited

    target = item.lower()
    field_hits = difflib.get_close_matches(target, fields_by_lower, n=5)
    method_hits = difflib.get_close_matches(target, methods, n=5)
    prop_hits = difflib.get_close_matches(target, props, n=5)
    inherited_hits = difflib.get_close_matches(target, inherited, n=5)

    parts = ["{} instance has no field, method, or property '{}'".format(class_name, item)]

    if not (field_hits or method_hits or prop_hits or inherited_hits):
        if has_describe:
            parts.append("\n Hint: use 'describe()' to show the names of all data fields.")
        return ''.join(parts)

    parts.append('\n Did you mean:')
    if field_hits:
        spellings = [orig for key in field_hits for orig in fields_by_lower[key]]
        parts.append('\n Data {}: {}'.format(plural('field', len(spellings)),
                                             ', '.join(handler(orig) for orig in spellings)))
    for hits, noun, noun_plural in ((method_hits, 'method', None),
                                    (prop_hits, 'property', 'properties'),
                                    (inherited_hits, 'inherited method', None)):
        if hits:
            parts.append('\n {} {}: {}'.format(class_name, plural(noun, len(hits), noun_plural),
                                               ', '.join("'{}'".format(h) for h in hits)))
    return ''.join(parts)
def get_nice_field_error(obj, item):
    """Build a helpful message for a missing field `item` on `obj`.

    Suggests up to five case-insensitively close field names, each formatted
    by the handler from ``get_obj_metadata``. The ``describe()`` hint is
    appended whenever the container supports it, whether or not there are
    suggestions.
    """
    class_name, _, handler, has_describe = get_obj_metadata(obj)

    by_lower = defaultdict(list)
    for name in obj._fields.keys():
        by_lower[name.lower()].append(name)

    close = difflib.get_close_matches(item.lower(), by_lower, n=5)

    message = ["{} instance has no field '{}'".format(class_name, item)]
    if close:
        message.append('\n Did you mean:')
        message.extend("\n {}".format(handler(orig)) for key in close for orig in by_lower[key])
    if has_describe:
        message.append("\n Hint: use 'describe()' to show the names of all data fields.")
    return ''.join(message)
def check_collisions(fields, name, indices):
    """Raise if `name` is an existing field whose indices differ from `indices`.

    The message is logged via ``error`` before raising.

    Raises
    ------
    ExpressionException
        On such a name collision.
    """
    from hail.expr.expressions import ExpressionException
    if name not in fields:
        return
    existing = fields[name]._indices
    if existing == indices:
        return
    msg = "name collision with field indexed by {}: {}".format(list(existing.axes), repr(name))
    error('Analysis exception: {}'.format(msg))
    raise ExpressionException(msg)
def check_field_uniqueness(fields):
    """Raise if any name occurs more than once in `fields`.

    Parameters
    ----------
    fields : iterable of str
        Field names a selection would produce.

    Raises
    ------
    ExpressionException
        Naming a duplicated field.
    """
    for k, v in Counter(fields).items():
        if v > 1:
            from hail.expr.expressions import ExpressionException
            # repr() already quotes the name; wrapping it in extra quotes
            # produced messages like "duplicate field ''x''".
            raise ExpressionException("selection would produce duplicate field {}".format(repr(k)))
def check_keys(name, indices):
    """Raise if `name` is one of the key fields of `indices`.

    Does nothing when ``indices.key`` is ``None``. The message is logged via
    ``error`` before raising.

    Raises
    ------
    ExpressionException
        If `name` is a key field.
    """
    from hail.expr.expressions import ExpressionException
    key_fields = indices.key
    if key_fields is not None and name in set(key_fields):
        msg = "cannot overwrite key field {} with annotate, select or drop; use key_by to modify keys.".format(repr(name))
        error('Analysis exception: {}'.format(msg))
        raise ExpressionException(msg)
def get_select_exprs(caller, exprs, named_exprs, indices, protect_keys=True):
    """Resolve the arguments of a select-like method into field assignments.

    Parameters
    ----------
    caller : str
        Name of the calling method, used in error messages.
    exprs : iterable
        Positional selections: field names (looked up on ``indices.source``)
        or expressions that must be field references (``is_nested_field``).
    named_exprs : dict
        Keyword selections mapping new field names to values/expressions.
    indices
        Indices every positional expression must match; ``indices.source``
        supplies existing fields and ``indices.key`` the key fields.
    protect_keys : bool
        If true, selecting or assigning a key field raises.

    Returns
    -------
    OrderedDict
        Field name -> expression, positional selections first.

    Raises
    ------
    ExpressionException
        On mismatched indices, a complex positional expression, a key
        overwrite (when `protect_keys`), a name collision, or a name that the
        selection would produce more than once.
    """
    from hail.expr.expressions import to_expr, ExpressionException
    exprs = [to_expr(e) if not isinstance(e, str) else indices.source[e] for e in exprs]
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
    assignments = OrderedDict()
    # Every produced name in order, repeats included. The duplicate check
    # cannot use `assignments.keys()`: dict keys are unique by construction,
    # so a repeated name would silently overwrite the earlier entry.
    selected_names = []
    for e in exprs:
        if not e._indices == indices:
            raise ExpressionException("method '{}' parameter 'exprs' expects {}-indexed fields,"
                                      " found indices {}".format(caller, list(indices.axes), list(e._indices.axes)))
        if not e._ast.is_nested_field:
            raise ExpressionException("method '{}' expects keyword arguments for complex expressions".format(caller))
        name = e._ast.name
        if protect_keys:
            check_keys(name, indices)
        assignments[name] = e
        selected_names.append(name)
    for k, e in named_exprs.items():
        if protect_keys:
            check_keys(k, indices)
        check_collisions(indices.source._fields, k, indices)
        assignments[k] = e
        selected_names.append(k)
    check_field_uniqueness(selected_names)
    return assignments
def get_annotate_exprs(caller, named_exprs, indices):
    """Convert annotate-style keyword arguments to expressions and validate them.

    Parameters
    ----------
    caller : str
        Name of the calling method. Kept for interface compatibility; the
        checks performed here do not include it in their messages.
    named_exprs : dict
        New field name -> value/expression.
    indices
        Indices of the object being annotated; ``indices.source`` supplies the
        existing fields.

    Returns
    -------
    dict
        Field name -> expression (values passed through ``to_expr``).

    Raises
    ------
    ExpressionException
        If a name is a key field (``check_keys``) or collides with an existing
        field indexed differently (``check_collisions``).
    """
    from hail.expr.expressions import to_expr
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
    for k in named_exprs:
        # check_keys already rejects every key-field name, so no separate
        # key test is needed here.
        check_keys(k, indices)
        check_collisions(indices.source._fields, k, indices)
    return named_exprs
def process_joins(obj, exprs, broadcast_f):
    """Apply the joins and broadcasts referenced by `exprs` to `obj`.

    Parameters
    ----------
    obj
        The object joins are applied to (``left`` starts as `obj`).
    exprs
        Expressions whose ASTs are searched for ``Join`` and ``Broadcast``
        nodes.
    broadcast_f
        Callable ``broadcast_f(left, data_json, jtype)`` that attaches the
        collected broadcast values to ``left`` and returns the result.

    Returns
    -------
    tuple
        ``(left, cleanup)``. `left` is the object after all joins and
        broadcasts. ``cleanup(table)`` drops every temporary uid collected here
        that is still a field of `table`.
    """
    all_uids = []
    left = obj
    used_joins = set()
    broadcasts = []
    for e in exprs:
        joins = e._ast.search(lambda a: isinstance(a, hail.expr.expr_ast.Join))
        # Joins are sorted by `idx` within each expression only; joins found in
        # a later expression are applied after all joins of earlier ones.
        # NOTE(review): ordering is per-expression, not global across `exprs` — confirm intended.
        for j in sorted(joins, key=lambda j: j.idx): # Make sure joins happen in order
            # A join referenced by several expressions is applied only once.
            if j not in used_joins:
                left = j.join_func(left)
                all_uids.extend(j.temp_vars)
                used_joins.add(j)
        broadcasts.extend(e._ast.search(lambda a: isinstance(a, hail.expr.expr_ast.Broadcast)))
    if broadcasts:
        # All broadcast values are bundled into one struct keyed by uid and
        # handed to broadcast_f in a single call. Building dicts keyed by uid
        # collapses any broadcast that was found more than once.
        t = hail.tstruct(**{b.uid: b.dtype for b in broadcasts})
        # presumably iterating the tstruct yields its field names (the uids) — confirm
        all_uids.extend(list(t))
        data = hail.Struct(**{b.uid: b.value for b in broadcasts})
        data_json = t._to_json(data)
        left = broadcast_f(left, data_json, t._jtype)
    def cleanup(table):
        # Only uids still present on `table` are passed to drop().
        remaining_uids = [uid for uid in all_uids if uid in table._fields]
        return table.drop(*remaining_uids)
    return left, cleanup