Source code for hail.methods.misc

import hail as hl
from hail.expr.expr_ast import VariableReference
from hail.expr.expressions import *
from hail.expr.types import *
from hail.matrixtable import MatrixTable
from hail.table import Table
from hail.typecheck import *
from hail.utils import Interval, Struct
from hail.utils.java import Env, joption


@typecheck(i=Expression,
           j=Expression,
           keep=bool,
           tie_breaker=nullable(func_spec(2, expr_numeric)))
def maximal_independent_set(i, j, keep=True, tie_breaker=None) -> Table:
    """Return a table containing the vertices in a near
    `maximal independent set <https://en.wikipedia.org/wiki/Maximal_independent_set>`_
    of an undirected graph whose edges are given by a two-column table.

    Examples
    --------
    Run PC-relate and compute pairs of closely related individuals:

    >>> pc_rel = hl.pc_relate(dataset.GT, 0.001, k=2, statistics='kin')
    >>> pairs = pc_rel.filter(pc_rel['kin'] > 0.125)

    Starting from the above pairs, prune individuals from a dataset until no
    close relationships remain:

    >>> related_samples_to_remove = hl.maximal_independent_set(pairs.i, pairs.j, False)
    >>> result = dataset.filter_cols(
    ...     hl.is_defined(related_samples_to_remove[dataset.col_key]), keep=False)

    Starting from the above pairs, prune individuals from a dataset until no
    close relationships remain, preferring to keep cases over controls:

    >>> samples = dataset.cols()
    >>> pairs_with_case = pairs.key_by(
    ...     i=hl.struct(id=pairs.i, is_case=samples[pairs.i].is_case),
    ...     j=hl.struct(id=pairs.j, is_case=samples[pairs.j].is_case))

    >>> def tie_breaker(l, r):
    ...     return hl.cond(l.is_case & ~r.is_case, -1,
    ...                    hl.cond(~l.is_case & r.is_case, 1, 0))

    >>> related_samples_to_remove = hl.maximal_independent_set(
    ...     pairs_with_case.i, pairs_with_case.j, False, tie_breaker)
    >>> result = dataset.filter_cols(hl.is_defined(
    ...     related_samples_to_remove.key_by(
    ...         s = related_samples_to_remove.node.id.s)[dataset.col_key]), keep=False)

    Notes
    -----
    The vertex set of the graph is implicitly all the values realized by `i`
    and `j` on the rows of this table. Each row of the table corresponds to an
    undirected edge between the vertices given by evaluating `i` and `j` on
    that row. An undirected edge may appear multiple times in the table and
    will not affect the output. Vertices with self-edges are removed as they
    are not independent of themselves.

    The expressions for `i` and `j` must have the same type.

    The value of `keep` determines whether the vertices returned are those in
    the maximal independent set, or those in the complement of this set. This
    is useful if you need to filter a table without removing vertices that
    don't appear in the graph at all.

    This method implements a greedy algorithm which iteratively removes a
    vertex of highest degree until the graph contains no edges. The greedy
    algorithm always returns an independent set, but the set may not always
    be perfectly maximal.

    `tie_breaker` is a Python function taking two arguments---say `l` and
    `r`---each of which is an :class:`Expression` of the same type as `i` and
    `j`. `tie_breaker` returns a :class:`NumericExpression`, which defines an
    ordering on nodes. A pair of nodes can be ordered in one of three ways,
    and `tie_breaker` must encode the relationship as follows:

     - if ``l < r`` then ``tie_breaker`` evaluates to some negative integer
     - if ``l == r`` then ``tie_breaker`` evaluates to 0
     - if ``l > r`` then ``tie_breaker`` evaluates to some positive integer

    For example, the usual ordering on the integers is defined by: ``l - r``.

    The `tie_breaker` function must satisfy the following property:
    ``tie_breaker(l, r) == -tie_breaker(r, l)``.

    When multiple nodes have the same degree, this algorithm will order the
    nodes according to ``tie_breaker`` and remove the *largest* node.

    Parameters
    ----------
    i : :class:`.Expression`
        Expression to compute one endpoint of an edge.
    j : :class:`.Expression`
        Expression to compute another endpoint of an edge.
    keep : :obj:`bool`
        If ``True``, return vertices in set. If ``False``, return vertices removed.
    tie_breaker : function
        Function used to order nodes with equal degree.

    Returns
    -------
    :class:`.Table`
        Table with the set of independent vertices. The table schema is one row
        field `node` which has the same type as input expressions `i` and `j`.
    """
    # --- argument validation -------------------------------------------------
    if i.dtype != j.dtype:
        raise ValueError("'maximal_independent_set' expects arguments `i` and `j` to have same type. "
                         "Found {} and {}.".format(i.dtype, j.dtype))

    source = i._indices.source
    if not isinstance(source, Table):
        if source is not None:
            found = "expression of '{}'".format(source.__class__)
        else:
            found = 'scalar expression'
        raise ValueError("'maximal_independent_set' expects an expression of 'Table'. Found {}".format(found))

    if i._indices.source != j._indices.source:
        raise ValueError(
            "'maximal_independent_set' expects arguments `i` and `j` to be expressions of the same Table. "
            "Found\n{}\n{}".format(i, j))

    node_t = i.dtype

    # --- build the optional tie-breaker expression ---------------------------
    # The user function is applied to element 0 of two tuple-typed variable
    # references named 'l' and 'r'; the resulting expression is cast to int64.
    join_exprs = [i, j]
    tie_breaker_expr = None
    if tie_breaker:
        wrapped_node_t = ttuple(node_t)
        left = construct_expr(VariableReference('l'), wrapped_node_t)
        right = construct_expr(VariableReference('r'), wrapped_node_t)
        tie_breaker_expr = hl.int64(tie_breaker(left[0], right[0]))
        join_exprs.append(tie_breaker_expr)

    joined, _ = source._process_joins(*join_exprs)

    # The HQL string is produced only after the joins have been processed,
    # exactly as in the two-branch formulation.
    if tie_breaker_expr is not None:
        tie_breaker_hql = tie_breaker_expr._ast.to_hql()
    else:
        tie_breaker_hql = None

    # --- vertex table: one row per endpoint, keyed by `node` -----------------
    exploded = joined.select(node=[i, j]).explode('node')
    nodes = exploded.key_by('node').select()

    # --- edge list ------------------------------------------------------------
    # NOTE(review): the edge table is built by selecting fields literally named
    # 'i' and 'j'; presumably this requires the source table to contain fields
    # with those names (as in the docstring examples) -- confirm.
    edges = joined.key_by(None).select('i', 'j')

    independent = Env.hail().utils.Graph.maximalIndependentSet(
        edges._jt.collect(), node_t._jtype, joption(tie_breaker_hql))

    # Attach the computed set as a temporary global, filter on membership
    # (keeping or dropping members according to `keep`), then drop the global.
    annotated = Table(nodes._jt.annotateGlobal(independent, hl.tset(node_t)._jtype, 'nodes_in_set'))
    is_member = annotated.nodes_in_set.contains(annotated.node)
    return annotated.filter(is_member, keep).drop('nodes_in_set')


def require_col_key_str(dataset: MatrixTable, method: str):
    """Raise :class:`ValueError` unless the column key of `dataset` is exactly
    one field of type ``str``. `method` is the caller name used in the message.
    """
    wrong_arity = not len(dataset.col_key) == 1
    if wrong_arity or dataset[next(iter(dataset.col_key))].dtype != hl.tstr:
        raise ValueError(f"Method '{method}' requires column key to be one field of type 'str', found "
                         f"{list(str(x.dtype) for x in dataset.col_key.values())}")


def require_row_key_variant(dataset, method):
    """Raise :class:`ValueError` unless the row key of `dataset` is the two
    fields ``locus`` (a ``tlocus``) and ``alleles`` (``array<str>``), in that
    order. `method` is the caller name used in the message.
    """
    # Each test is evaluated only when the preceding ones passed, so the field
    # lookups are reached only after the key names have been confirmed.
    invalid = list(dataset.row_key) != ['locus', 'alleles']
    invalid = invalid or not isinstance(dataset['locus'].dtype, tlocus)
    invalid = invalid or not dataset['alleles'].dtype == tarray(tstr)
    if invalid:
        found = ''.join(
            "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in dataset.row_key)
        raise ValueError("Method '{}' requires row key to be two fields 'locus' (type 'locus<any>') and "
                         "'alleles' (type 'array<str>')\n"
                         " Found:{}".format(method, found))


def require_row_key_variant_w_struct_locus(dataset, method):
    """Like :func:`require_row_key_variant`, but the ``locus`` field may also
    be of type ``struct{contig: str, position: int32}``.
    """
    invalid = list(dataset.row_key) != ['locus', 'alleles']
    invalid = invalid or not dataset['alleles'].dtype == tarray(tstr)
    if not invalid:
        locus_t = dataset['locus'].dtype
        # A locus is acceptable either as a tlocus or as the struct encoding.
        invalid = (not isinstance(locus_t, tlocus)
                   and locus_t != hl.dtype('struct{contig: str, position: int32}'))
    if invalid:
        found = ''.join(
            "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in dataset.row_key)
        raise ValueError("Method '{}' requires row key to be two fields 'locus'"
                         " (type 'locus<any>' or 'struct{contig: str, position: int32}') and "
                         "'alleles' (type 'array<str>')\n"
                         " Found:{}".format(method, found))


def require_partition_key_locus(dataset, method):
    """Raise :class:`ValueError` unless the partition key of `dataset` is
    exactly one field whose type is a ``tlocus``.
    """
    invalid = len(dataset.partition_key) != 1
    invalid = invalid or not isinstance(dataset.partition_key[0].dtype, tlocus)
    if invalid:
        found = ''.join(
            "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in dataset.partition_key)
        raise ValueError("Method '{}' requires partition key to be one field of type 'locus<any>'.\n"
                         " Found:{}".format(method, found))


@typecheck(table=Table, method=str)
def require_key(table, method):
    """Raise :class:`ValueError` if `table` has no key (``table.key`` is ``None``)."""
    if table.key is None:
        raise ValueError("Method '{}' requires keyed table".format(method))


@typecheck(dataset=MatrixTable, method=str)
def require_biallelic(dataset, method) -> MatrixTable:
    """Check the variant row key, then return `dataset` wrapped by the JVM-side
    ``VerifyBiallelic`` step.
    """
    require_row_key_variant(dataset, method)
    verified = Env.hail().methods.VerifyBiallelic.apply(dataset._jvds, method)
    return MatrixTable(verified)
@typecheck(dataset=MatrixTable, name=str)
def rename_duplicates(dataset, name='unique_id') -> MatrixTable:
    """Rename duplicate column keys.

    .. include:: ../_templates/req_tstring.rst

    Examples
    --------

    >>> renamed = hl.rename_duplicates(dataset).cols()
    >>> duplicate_samples = (renamed.filter(renamed.s != renamed.unique_id)
    ...                             .select()
    ...                             .collect())

    Notes
    -----

    This method produces a new column field from the string column key by
    appending a unique suffix ``_N`` as necessary. For example, if the column
    key "NA12878" appears three times in the dataset, the first will produce
    "NA12878", the second will produce "NA12878_1", and the third will produce
    "NA12878_2". The name of this new field is parameterized by `name`.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name of new field.

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the column key of `dataset` is not a single field of type ``str``.
    """
    # Enforce the documented string-column-key requirement on the Python side,
    # so a violation is reported before the dataset is handed to the JVM.
    require_col_key_str(dataset, 'rename_duplicates')
    return MatrixTable(dataset._jvds.renameDuplicates(name))
@typecheck(ds=MatrixTable, intervals=expr_array(expr_interval(expr_any)), keep=bool)
def filter_intervals(ds, intervals, keep=True) -> MatrixTable:
    """Filter rows with a list of intervals.

    Examples
    --------

    Filter to loci falling within one interval:

    >>> ds_result = hl.filter_intervals(dataset, [hl.parse_locus_interval('17:38449840-38530994')])

    Remove all loci within list of intervals:

    >>> intervals = [hl.parse_locus_interval(x) for x in ['1:50M-75M', '2:START-400000', '3-22']]
    >>> ds_result = hl.filter_intervals(dataset, intervals, keep=False)

    Notes
    -----
    Based on the ``keep`` argument, this method will either restrict to points
    in the supplied interval ranges, or remove all rows in those ranges.

    When ``keep=True``, partitions that don't overlap any supplied interval
    will not be loaded at all. This enables :func:`.filter_intervals` to be
    used for reasonably low-latency queries of small ranges of the dataset,
    even on large datasets.

    Parameters
    ----------
    ds : :class:`.MatrixTable`
        Dataset.
    intervals : :class:`.ArrayExpression` of type :py:data:`.tinterval`
        Intervals to filter on. If there is only one row partition key, the
        point type of the interval can be the type of the first partition key.
        Otherwise, the interval point type must be a :class:`.Struct` matching
        the row partition key schema.
    keep : :obj:`bool`
        If ``True``, keep only rows that fall within any interval in
        `intervals`. If ``False``, keep only rows that fall outside all
        intervals in `intervals`.

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    TypeError
        If the interval point type matches neither the partition key struct
        type nor (for a single-field partition key) that field's type, or if
        `intervals` contains a missing value.
    """
    n_pk = len(ds.partition_key)
    pk_type = ds.partition_key.dtype
    point_type = intervals.dtype.element_type.point_type

    # Points may be given either as the full partition-key struct, or -- when
    # the partition key has exactly one field -- as a bare value of that
    # field's type, which then has to be wrapped in a struct below.
    if point_type == pk_type:
        needs_wrapper = False
    elif n_pk == 1 and point_type == ds.partition_key[0].dtype:
        needs_wrapper = True
    else:
        raise TypeError("The point type does not match the row partition key type of the dataset ('{}', '{}')".format(repr(point_type), repr(pk_type)))

    def wrap_input(interval):
        if interval is None:
            raise TypeError("'filter_intervals' does not allow missing values in 'intervals'.")
        elif needs_wrapper:
            # NOTE(review): the struct field is named 'foo'; presumably only the
            # field position matters to the JVM side -- confirm.
            return Interval(Struct(foo=interval.start),
                            Struct(foo=interval.end),
                            interval.includes_start,
                            interval.includes_end)
        else:
            return interval

    # Evaluate the array expression to Python values, then convert each
    # interval to its Java representation.
    intervals = [wrap_input(x)._jrep for x in intervals.value]
    jmt = Env.hail().methods.FilterIntervals.apply(ds._jvds, intervals, keep)
    return MatrixTable(jmt)
@typecheck(mt=MatrixTable, bp_window_size=int)
def window_by_locus(mt: MatrixTable, bp_window_size: int) -> MatrixTable:
    """Collect arrays of row and entry values from preceding loci.

    .. include:: ../_templates/req_tlocus.rst

    .. include:: ../_templates/experimental.rst

    Examples
    --------
    >>> ds_result = hl.window_by_locus(ds, 3)

    Notes
    -----
    This method groups each row (variant) with the previous rows in a window of
    `bp_window_size` base pairs, putting the row values from the previous
    variants into `prev_rows` (row field of type ``array<struct>``) and entry
    values from those variants into `prev_entries` (entry field of type
    ``array<struct>``).

    The `bp_window_size` argument is inclusive; if `bp_window_size` is 2 and
    the loci are

    .. code-block:: text

        1:100
        1:100
        1:102
        1:102
        1:103
        2:100
        2:101

    then the size of `prev_rows` is 0, 1, 2, 3, 2, 0, and 1, respectively (and
    same for the size of `prev_entries`).

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Input dataset.
    bp_window_size : :obj:`int`
        Base pairs to include in the backwards window (inclusive).

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the partition key of `mt` is not a single field of locus type.
    """
    # Validate the locus partition key before delegating to the JVM.
    require_partition_key_locus(mt, 'window_by_locus')
    return MatrixTable(mt._jvds.windowVariants(bp_window_size))