Source code for hail.expr.aggregators.aggregators

import difflib
from functools import wraps, update_wrapper

import hail as hl
from hail.expr.expressions import *
from hail.expr.types import *
from hail.ir import *
from hail.typecheck import *
from hail.utils import wrap_to_list
from hail.utils.java import Env


class AggregableChecker(TypeChecker):
    """Type checker accepting an :class:`Aggregable` or a value promotable to one.

    The checker wraps an expression coercer.  An :class:`Aggregable` argument
    is validated (and, when required, converted element-wise) against that
    coercer; any other argument is checked by the coercer itself and the
    result is wrapped with :func:`_to_agg`.
    """

    def __init__(self, coercer):
        self.coercer = coercer
        super().__init__()

    def expects(self):
        """Return the wrapped coercer's description of the expected argument."""
        return self.coercer.expects()

    def format(self, arg):
        """Render `arg` for display; non-aggregables are formatted by the coercer."""
        if not isinstance(arg, Aggregable):
            return self.coercer.format(arg)
        return f'<aggregable Expression of type {repr(arg.dtype)}>'

    def check(self, x, caller, param):
        """Validate `x` and return it as an :class:`Aggregable`.

        Raises
        ------
        TypecheckFailure
            If `x` is an :class:`Aggregable` whose type the coercer cannot
            coerce.
        """
        wrapped = self.coercer

        # Plain values: let the coercer check them, then promote the result.
        if not isinstance(x, Aggregable):
            checked = wrapped.check(x, caller, param)
            return _to_agg(checked)

        if not wrapped.can_coerce(x.dtype):
            raise TypecheckFailure

        if wrapped.requires_conversion(x.dtype):
            # Convert each aggregated element rather than the aggregable itself.
            return x._map(lambda element: wrapped.coerce(element))
        return x


def _to_agg(x):
    """Wrap expression `x` in an :class:`Aggregable` with the same IR, type, indices and aggregations."""
    ir, dtype = x._ir, x._type
    indices, aggregations = x._indices, x._aggregations
    return Aggregable(ir, dtype, indices, aggregations)


# Short alias used when declaring typecheck signatures, e.g. ``agg_expr(expr_any)``.
agg_expr = AggregableChecker


class AggFunc(object):
    """Builder of aggregation (or scan) expressions.

    Calling an instance assembles the IR for a named aggregator applied to an
    :class:`Aggregable`.  When ``_as_scan`` is true the result is an
    ``ApplyScanOp`` node, otherwise an ``ApplyAggOp`` node.
    """

    def __init__(self):
        # False => build ApplyAggOp nodes; True => build ApplyScanOp nodes.
        self._as_scan = False

    @typecheck_method(name=str,
                      aggregable=Aggregable,
                      ret_type=hail_type,
                      constructor_args=sequenceof(expr_any),
                      init_op_args=nullable(sequenceof(expr_any)),
                      seq_op_args=nullable(sequenceof(oneof(expr_any, func_spec(1, expr_any)))))
    def __call__(self, name, aggregable, ret_type, constructor_args=(), init_op_args=None, seq_op_args=None):
        """Build the expression for aggregator `name` applied to `aggregable`.

        Parameters
        ----------
        name : :obj:`str`
            Aggregator name recorded in the ``AggSignature``.
        aggregable : :class:`Aggregable`
            Records to aggregate.
        ret_type
            Type given to the resulting expression.
        constructor_args : sequence of expressions
            Arguments passed to the aggregator constructor.
        init_op_args : sequence of expressions, optional
            Arguments of the init op, if the aggregator has one.
        seq_op_args : sequence of expressions or one-argument functions, optional
            Arguments of the seq op.  Functions are applied to a reference to
            the aggregated element.  Defaults to the element itself.

        Returns
        -------
        Expression of type `ret_type`.

        Raises
        ------
        ExpressionException
            If any input already carries aggregations, or if
            ``_check_agg_bindings`` rejects one of the inputs.
        """
        # Normalise to lists: the default for `constructor_args` is a tuple
        # while callers pass lists, and ``tuple + list`` raises TypeError.
        constructor_args = list(constructor_args)
        if init_op_args is not None:
            init_op_args = list(init_op_args)

        args = constructor_args if init_op_args is None else constructor_args + init_op_args
        if seq_op_args is None:
            seq_op_args = [lambda x: x]
        indices, aggregations = unify_all(aggregable, *args)
        if aggregations:
            raise ExpressionException('Cannot aggregate an already-aggregated expression')
        for a in args:
            _check_agg_bindings(a)
        _check_agg_bindings(aggregable)

        # Name of the variable the aggregated element is bound to inside the seq op.
        uid = Env.get_uid()

        def get_type(expr):
            return expr.dtype

        def get_ir(expr):
            return expr._ir

        def apply_seq_ops(seq_op_args, agg):
            # Functions receive a reference to the bound element; expressions pass through.
            ref = construct_variable(uid, get_type(agg), agg._indices)
            return [x(ref) if callable(x) else x for x in seq_op_args]

        def agg_sig(applied_seq_ops):
            return AggSignature(name,
                                list(map(get_type, constructor_args)),
                                None if init_op_args is None else list(map(get_type, init_op_args)),
                                list(map(get_type, applied_seq_ops)))

        def make_seq_op(agg):
            applied_seq_ops = apply_seq_ops(seq_op_args, agg)
            return construct_expr(
                Let(uid, get_ir(agg), SeqOp(I32(0), [get_ir(x) for x in applied_seq_ops], agg_sig(applied_seq_ops))), None)

        seq_op = aggregable._transformations(aggregable, make_seq_op)
        applied_seq_ops = apply_seq_ops(seq_op_args, aggregable)
        signature = agg_sig(applied_seq_ops)

        if self._as_scan:
            ir = ApplyScanOp(seq_op._ir,
                             list(map(get_ir, constructor_args)),
                             None if init_op_args is None else list(map(get_ir, init_op_args)),
                             signature)
            aggs = aggregations
        else:
            ir = ApplyAggOp(seq_op._ir,
                            list(map(get_ir, constructor_args)),
                            None if init_op_args is None else list(map(get_ir, init_op_args)),
                            signature)
            aggs = aggregations.push(Aggregation(aggregable, *args))
        return construct_expr(ir, ret_type, Indices(indices.source, set()), aggs)

    def _group_by(self, group, agg_expr):
        """Key an existing aggregation expression by `group`.

        The aggregator signature of `agg_expr` is rewritten to
        ``Keyed(<op>)`` and `group` is prepended to every seq op's arguments.

        Returns
        -------
        Expression of type ``dict<group.dtype, agg_expr.dtype>``.

        Raises
        ------
        ExpressionException
            If `group` already carries aggregations.
        TypeError
            If `agg_expr` is not an aggregation expression, or is a scan when
            an aggregation is required (or vice versa).
        """
        if group._aggregations:
            raise ExpressionException("'group_by' does not support an already-aggregated expression as the argument to 'group'")

        ir = agg_expr._ir
        if isinstance(ir, ApplyScanOp):
            if not self._as_scan:
                raise TypeError("'agg.group_by' requires a non-scan aggregation expression (agg.*) as the argument to 'agg_expr'")
        elif isinstance(ir, ApplyAggOp):
            if self._as_scan:
                raise TypeError("'scan.group_by' requires a scan aggregation expression (scan.*) as the argument to 'agg_expr'")
        else:
            raise TypeError("'group_by' requires an aggregation expression as the argument to 'agg_expr'")

        agg_sig = ir.agg_sig

        new_agg_sig = AggSignature(f'Keyed({agg_sig.op})',
                                    agg_sig.ctor_arg_types,
                                    agg_sig.initop_arg_types,
                                    [group.dtype] + agg_sig.seqop_arg_types)

        def rewrite_a(node):
            # Prepend the grouping key to every SeqOp; recurse through everything else.
            if isinstance(node, SeqOp):
                return SeqOp(node.i, [group._ir] + node.args, new_agg_sig)
            return node.map_ir(rewrite_a)

        new_a = rewrite_a(ir.a)

        if isinstance(ir, ApplyAggOp):
            new_ir = ApplyAggOp(new_a,
                                ir.constructor_args,
                                ir.init_op_args,
                                new_agg_sig)
        else:
            assert isinstance(ir, ApplyScanOp)
            new_ir = ApplyScanOp(new_a,
                                 ir.constructor_args,
                                 ir.init_op_args,
                                 new_agg_sig)

        return construct_expr(new_ir,
                              hl.tdict(group.dtype, agg_expr.dtype),
                              agg_expr._indices,
                              agg_expr._aggregations)


# Shared aggregator builder; a fresh AggFunc has `_as_scan` False, so calling
# it produces ApplyAggOp (not ApplyScanOp) nodes.
_agg_func = AggFunc()


def _check_agg_bindings(expr):
    """Reject expressions that reference variables not bound within their own IR.

    Raises
    ------
    ExpressionException
        If the IR of `expr` contains a ``Ref`` (other than a
        ``TopLevelReference``) whose name is not among the IR's
        ``bound_variables``.
    """
    def is_local_ref(node):
        return isinstance(node, Ref) and not isinstance(node, TopLevelReference)

    referenced_names = set()
    for ref in expr._ir.search(is_local_ref):
        referenced_names.add(ref.name)

    unbound_names = referenced_names - expr._ir.bound_variables
    if unbound_names:
        raise ExpressionException("dynamic variables created by 'hl.bind' or lambda methods like 'hl.map' may not be aggregated")


@typecheck(expr=agg_expr(expr_any))
def collect(expr) -> ArrayExpression:
    """Collect records into an array.

    Examples
    --------
    Collect the `ID` field where `HT` is greater than 68:

    >>> table1.aggregate(agg.collect(agg.filter(table1.HT > 68, table1.ID)))
    [2, 3]

    Notes
    -----
    The element order of the resulting array is not guaranteed, and in some
    cases is non-deterministic.

    Use :meth:`collect_as_set` to collect unique items.

    Warning
    -------
    Collecting a large number of items can cause out-of-memory exceptions.

    Parameters
    ----------
    expr : :class:`.Expression`
        Expression to collect.

    Returns
    -------
    :class:`.ArrayExpression`
        Array of all `expr` records.
    """
    element_array = tarray(expr.dtype)
    return _agg_func('Collect', expr, element_array)
@typecheck(expr=agg_expr(expr_any))
def collect_as_set(expr) -> SetExpression:
    """Collect records into a set.

    Examples
    --------
    Collect the unique `ID` field where `HT` is greater than 68:

    >>> table1.aggregate(agg.collect_as_set(agg.filter(table1.HT > 68, table1.ID)))
    set([2, 3])

    Warning
    -------
    Collecting a large number of items can cause out-of-memory exceptions.

    Parameters
    ----------
    expr : :class:`.Expression`
        Expression to collect.

    Returns
    -------
    :class:`.SetExpression`
        Set of unique `expr` records.
    """
    element_set = tset(expr.dtype)
    return _agg_func('CollectAsSet', expr, element_set)
@typecheck(expr=nullable(agg_expr(expr_any)))
def count(expr=None) -> Int64Expression:
    """Count the number of records.

    Examples
    --------
    Group by the `SEX` field and count the number of rows in each category:

    >>> (table1.group_by(table1.SEX)
    ...        .aggregate(n=agg.count())
    ...        .show())
    +-----+-------+
    | SEX |     n |
    +-----+-------+
    | str | int64 |
    +-----+-------+
    | M   |     2 |
    | F   |     2 |
    +-----+-------+

    Notes
    -----
    If `expr` is not provided, then this method will count the number of
    records aggregated. If `expr` is provided, then the result should
    make use of :meth:`filter` or :meth:`explode` so that the number of
    records aggregated changes.

    Parameters
    ----------
    expr : :class:`.Expression`, or :obj:`None`
        Expression to count.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tint64`
        Total number of records.
    """
    if expr is None:
        # No expression given: aggregate over a constant placeholder so that
        # every record contributes one to the count.
        expr = _to_agg(hl.int32(0))
    return _agg_func('Count', expr, tint64, seq_op_args=())
@typecheck(condition=expr_bool)
def count_where(condition) -> Int64Expression:
    """Count the number of records where a predicate is ``True``.

    Examples
    --------
    Count the number of individuals with `HT` greater than 68:

    >>> table1.aggregate(agg.count_where(table1.HT > 68))
    2

    Parameters
    ----------
    condition : :class:`.BooleanExpression`
        Criteria for inclusion.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tint64`
        Total number of records where `condition` is ``True``.
    """
    # Filter a constant placeholder by the predicate, then count what remains.
    matching = filter(condition, 0)
    return _agg_func('Count', matching, tint64, seq_op_args=())
@typecheck(condition=agg_expr(expr_bool))
def any(condition) -> BooleanExpression:
    """Returns ``True`` if `condition` is ``True`` for any record.

    Examples
    --------

    >>> (table1.group_by(table1.SEX)
    ...        .aggregate(any_over_70 = agg.any(table1.HT > 70))
    ...        .show())
    +-----+-------------+
    | SEX | any_over_70 |
    +-----+-------------+
    | str | bool        |
    +-----+-------------+
    | M   | true        |
    | F   | false       |
    +-----+-------------+

    Notes
    -----
    If there are no records to aggregate, the result is ``False``.

    Missing records are not considered. If every record is missing,
    the result is also ``False``.

    Parameters
    ----------
    condition : :class:`.BooleanExpression`
        Condition to test.

    Returns
    -------
    :class:`.BooleanExpression`
    """
    n_true = count(filter(lambda value: value, condition))
    return n_true > 0
@typecheck(condition=agg_expr(expr_bool))
def all(condition) -> BooleanExpression:
    """Returns ``True`` if `condition` is ``True`` for every record.

    Examples
    --------

    >>> (table1.group_by(table1.SEX)
    ...        .aggregate(all_under_70 = agg.all(table1.HT < 70))
    ...        .show())
    +-----+--------------+
    | SEX | all_under_70 |
    +-----+--------------+
    | str | bool         |
    +-----+--------------+
    | M   | false        |
    | F   | false        |
    +-----+--------------+

    Notes
    -----
    If there are no records to aggregate, the result is ``True``.

    Missing records are not considered. If every record is missing,
    the result is also ``True``.

    Parameters
    ----------
    condition : :class:`.BooleanExpression`
        Condition to test.

    Returns
    -------
    :class:`.BooleanExpression`
    """
    def is_present(value):
        return hl.is_defined(value)

    def is_present_and_true(value):
        return hl.is_defined(value) & value

    # Every defined record is true exactly when the two counts agree.
    n_defined = count(filter(is_present, condition))
    n_true = count(filter(is_present_and_true, condition))
    return n_defined == n_true
@typecheck(expr=agg_expr(expr_any))
def counter(expr) -> DictExpression:
    """Count the occurrences of each unique record and return a dictionary.

    Examples
    --------
    Count the number of individuals for each unique `SEX` value:

    >>> table1.aggregate(agg.counter(table1.SEX))
    {'M': 2L, 'F': 2L}

    Notes
    -----
    This aggregator method returns a dict expression whose key type is the
    same type as `expr` and whose value type is :class:`.Expression` of type
    :py:data:`.tint64`. This dict contains a key for each unique value of
    `expr`, and the value is the number of times that key was observed.

    Ensure that the result can be stored in memory on a single machine.

    Warning
    -------
    Using :meth:`counter` with a large number of unique items can cause
    out-of-memory exceptions.

    Parameters
    ----------
    expr : :class:`.Expression`
        Expression to count by key.

    Returns
    -------
    :class:`.DictExpression`
        Dictionary with the number of occurrences of each unique record.
    """
    counts_by_value = tdict(expr.dtype, tint64)
    return _agg_func('Counter', expr, counts_by_value)
@typecheck(expr=agg_expr(expr_any), n=int, ordering=nullable(oneof(expr_any, func_spec(1, expr_any))))
def take(expr, n, ordering=None) -> ArrayExpression:
    """Take `n` records of `expr`, optionally ordered by `ordering`.

    Examples
    --------
    Take 3 elements of field `X`:

    >>> table1.aggregate(agg.take(table1.X, 3))
    [5, 6, 7]

    Take the `ID` and `HT` fields, ordered by `HT` (descending):

    >>> table1.aggregate(agg.take(hl.struct(ID=table1.ID, HT=table1.HT),
    ...                           3,
    ...                           ordering=-table1.HT))
    [Struct(ID=2, HT=72), Struct(ID=3, HT=70), Struct(ID=1, HT=65)]

    Notes
    -----
    The resulting array can include fewer than `n` elements if there are
    fewer than `n` total records.

    The `ordering` argument may be an expression, a function, or ``None``.

    If `ordering` is an expression, this expression's type should be one with
    a natural ordering (e.g. numeric).

    If `ordering` is a function, it will be evaluated on each record of `expr`
    to compute the value used for ordering. In the above example,
    ``ordering=-table1.HT`` and ``ordering=lambda x: -x.HT`` would be
    equivalent.

    If `ordering` is ``None``, then there is no guaranteed ordering on the
    elements taken, and the results may be non-deterministic.

    Missing values are always sorted **last**.

    Parameters
    ----------
    expr : :class:`.Expression`
        Expression to store.
    n : :obj:`int`
        Number of records to take.
    ordering : :class:`.Expression` or function ((arg) -> :class:`.Expression`) or None
        Optional ordering on records.

    Returns
    -------
    :class:`.ArrayExpression`
        Array of up to `n` records of `expr`.
    """
    n = to_expr(n)
    element_array = tarray(expr.dtype)
    if ordering is not None:
        # The seq op receives both the record itself and its ordering value.
        return _agg_func('TakeBy', expr, element_array, [n], seq_op_args=[lambda expr: expr, ordering])
    return _agg_func('Take', expr, element_array, [n])
@typecheck(expr=agg_expr(expr_numeric))
def min(expr) -> NumericExpression:
    """Compute the minimum `expr`.

    Examples
    --------
    Compute the minimum value of `HT`:

    >>> table1.aggregate(agg.min(table1.HT))
    min_ht=60

    Notes
    -----
    This method returns the minimum non-missing value. If there are no
    non-missing values, then the result is missing.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.NumericExpression`
        Minimum value of all `expr` records, same type as `expr`.
    """
    result_type = expr.dtype
    return _agg_func('Min', expr, result_type)
@typecheck(expr=agg_expr(expr_numeric))
def max(expr) -> NumericExpression:
    """Compute the maximum `expr`.

    Examples
    --------
    Compute the maximum value of `HT`:

    >>> table1.aggregate(agg.max(table1.HT))
    max_ht=72

    Notes
    -----
    This method returns the maximum non-missing value. If there are no
    non-missing values, then the result is missing.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.NumericExpression`
        Maximum value of all `expr` records, same type as `expr`.
    """
    result_type = expr.dtype
    return _agg_func('Max', expr, result_type)
@typecheck(expr=agg_expr(expr_oneof(expr_int64, expr_float64)))
def sum(expr):
    """Compute the sum of all records of `expr`.

    Examples
    --------
    Compute the sum of field `C1`:

    >>> table1.aggregate(agg.sum(table1.C1))
    25

    Notes
    -----
    Missing values are ignored (treated as zero).

    If `expr` is an expression of type :py:data:`.tint32`, :py:data:`.tint64`,
    or :py:data:`.tbool`, then the result is an expression of type
    :py:data:`.tint64`. If `expr` is an expression of type
    :py:data:`.tfloat32` or :py:data:`.tfloat64`, then the result is an
    expression of type :py:data:`.tfloat64`.

    Warning
    -------
    Boolean values are cast to integers before computing the sum.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tint64` or :py:data:`.tfloat64`
        Sum of records of `expr`.
    """
    result_type = expr.dtype
    return _agg_func('Sum', expr, result_type)
@typecheck(expr=agg_expr(expr_array(expr_oneof(expr_int64, expr_float64))))
def array_sum(expr) -> ArrayExpression:
    """Compute the coordinate-wise sum of all records of `expr`.

    Examples
    --------
    Compute the sum of `C1` and `C2`:

    >>> table1.aggregate(agg.array_sum([table1.C1, table1.C2]))
    [25, 46]

    Notes
    -----
    All records must have the same length. Each coordinate is summed
    independently as described in :func:`sum`.

    Parameters
    ----------
    expr : :class:`.ArrayNumericExpression`

    Returns
    -------
    :class:`.ArrayExpression` with element type :py:data:`.tint64` or :py:data:`.tfloat64`
    """
    result_type = expr.dtype
    return _agg_func('Sum', expr, result_type)
@typecheck(expr=agg_expr(expr_float64))
def mean(expr) -> Float64Expression:
    """Compute the mean value of records of `expr`.

    Examples
    --------
    Compute the mean of field `HT`:

    >>> table1.aggregate(agg.mean(table1.HT))
    66.75

    Notes
    -----
    Missing values are ignored.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tfloat64`
        Mean value of records of `expr`.
    """
    total = sum(expr)
    # Only defined records contribute to the denominator.
    n_defined = count(filter(lambda value: hl.is_defined(value), expr))
    return total / n_defined
@typecheck(expr=agg_expr(expr_float64))
def stats(expr) -> StructExpression:
    """Compute a number of useful statistics about `expr`.

    Examples
    --------
    Compute statistics about field `HT`:

    >>> table1.aggregate(agg.stats(table1.HT))
    Struct(min=60.0, max=72.0, sum=267.0, stdev=4.65698400255, n=4, mean=66.75)

    Notes
    -----
    Computes a struct with the following fields:

    - `min` (:py:data:`.tfloat64`) - Minimum value.
    - `max` (:py:data:`.tfloat64`) - Maximum value.
    - `mean` (:py:data:`.tfloat64`) - Mean value.
    - `stdev` (:py:data:`.tfloat64`) - Standard deviation.
    - `n` (:py:data:`.tint64`) - Number of non-missing records.
    - `sum` (:py:data:`.tfloat64`) - Sum.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `mean`, `stdev`, `min`, `max`,
        `n`, and `sum`.
    """
    result_type = tstruct(mean=tfloat64,
                          stdev=tfloat64,
                          min=tfloat64,
                          max=tfloat64,
                          n=tint64,
                          sum=tfloat64)
    return _agg_func('Statistics', expr, result_type)
@typecheck(expr=agg_expr(expr_oneof(expr_int64, expr_float64)))
def product(expr):
    """Compute the product of all records of `expr`.

    Examples
    --------
    Compute the product of field `C1`:

    >>> table1.aggregate(agg.product(table1.C1))
    440

    Notes
    -----
    Missing values are ignored (treated as one).

    If `expr` is an expression of type :py:data:`.tint32`, :py:data:`.tint64`
    or :py:data:`.tbool`, then the result is an expression of type
    :py:data:`.tint64`. If `expr` is an expression of type
    :py:data:`.tfloat32` or :py:data:`.tfloat64`, then the result is an
    expression of type :py:data:`.tfloat64`.

    Warning
    -------
    Boolean values are cast to integers before computing the product.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tint64` or :py:data:`.tfloat64`
        Product of records of `expr`.
    """
    result_type = expr.dtype
    return _agg_func('Product', expr, result_type)
@typecheck(predicate=agg_expr(expr_bool))
def fraction(predicate) -> Float64Expression:
    """Compute the fraction of records where `predicate` is ``True``.

    Examples
    --------
    Compute the fraction of rows where `SEX` is "F" and `HT` > 65:

    >>> table1.aggregate(agg.fraction((table1.SEX == 'F') & (table1.HT > 65)))
    0.25

    Notes
    -----
    Missing values for `predicate` are treated as ``False``.

    Parameters
    ----------
    predicate : :class:`.BooleanExpression`
        Boolean predicate.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tfloat64`
        Fraction of records where `predicate` is ``True``.
    """
    result_type = tfloat64
    return _agg_func("Fraction", predicate, result_type)
@typecheck(expr=agg_expr(expr_call))
def hardy_weinberg_test(expr) -> StructExpression:
    """Performs test of Hardy-Weinberg equilibrium.

    Examples
    --------
    Test each row of a dataset:

    >>> dataset_result = dataset.annotate_rows(hwe = agg.hardy_weinberg_test(dataset.GT))

    Test each row on a sub-population:

    >>> dataset_result = dataset.annotate_rows(
    ...     hwe_eas = agg.hardy_weinberg_test(agg.filter(dataset.pop == 'EAS', dataset.GT)))

    Notes
    -----
    This method performs the test described in
    :func:`.functions.hardy_weinberg_test` based solely on the counts of
    homozygous reference, heterozygous, and homozygous variant calls.

    The resulting struct expression has two fields:

    - `het_freq_hwe` (:py:data:`.tfloat64`) - Expected frequency
      of heterozygous calls under Hardy-Weinberg equilibrium.

    - `p_value` (:py:data:`.tfloat64`) - p-value from test of Hardy-Weinberg
      equilibrium.

    Hail computes the exact p-value with mid-p-value correction, i.e. the
    probability of a less-likely outcome plus one-half the probability of an
    equally-likely outcome. See this `document <LeveneHaldane.pdf>`__ for
    details on the Levene-Haldane distribution and references.

    Warning
    -------
    Non-diploid calls (``ploidy != 2``) are ignored in the counts. While the
    counts are defined for multiallelic variants, this test is only
    statistically rigorous in the biallelic setting; use
    :func:`~hail.methods.split_multi` to split multiallelic variants
    beforehand.

    Parameters
    ----------
    expr : :class:`.CallExpression`
        Call to test for Hardy-Weinberg equilibrium.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `het_freq_hwe` and `p_value`.
    """
    return _agg_func('HardyWeinberg',
                     expr,
                     tstruct(het_freq_hwe=tfloat64, p_value=tfloat64))
@typecheck(expr=agg_expr(expr_oneof(expr_array(), expr_set())))
def explode(expr) -> Aggregable:
    """Explode an array or set expression to aggregate the elements of all records.

    Examples
    --------
    Compute the mean of all elements in fields `C1`, `C2`, and `C3`:

    >>> table1.aggregate(agg.mean(agg.explode([table1.C1, table1.C2, table1.C3])))
    24.8333333333

    Compute the set of all observed elements in the `filters` field (``Set[String]``):

    >>> dataset.aggregate_rows(agg.collect_as_set(agg.explode(dataset.filters)))
    set([u'VQSRTrancheSNP99.80to99.90',
         u'VQSRTrancheINDEL99.95to100.00',
         u'VQSRTrancheINDEL99.00to99.50',
         u'VQSRTrancheINDEL97.00to99.00',
         u'VQSRTrancheSNP99.95to100.00',
         u'VQSRTrancheSNP99.60to99.80',
         u'VQSRTrancheINDEL99.50to99.90',
         u'VQSRTrancheSNP99.90to99.95',
         u'VQSRTrancheINDEL96.00to97.00'])

    Notes
    -----
    This method can be used with aggregator functions to aggregate the
    elements of collection types (:class:`.tarray` and :class:`.tset`).

    The result of the :meth:`explode` and :meth:`filter` methods is an
    :class:`.Aggregable` expression which can be used only in aggregator
    methods.

    Parameters
    ----------
    expr : :class:`.CollectionExpression`
        Expression of type :class:`.tarray` or :class:`.tset`.

    Returns
    -------
    :class:`.Aggregable`
        Aggregable expression.
    """
    # Flat-mapping with the identity yields one record per collection element.
    exploded = expr._flatmap(identity)
    return exploded
@typecheck(condition=oneof(func_spec(1, expr_bool), expr_bool), expr=agg_expr(expr_any))
def filter(condition, expr) -> Aggregable:
    """Filter records according to a predicate.

    Examples
    --------
    Collect the `ID` field where `HT` >= 70:

    >>> table1.aggregate(agg.collect(agg.filter(table1.HT >= 70, table1.ID)))
    [2, 3]

    Notes
    -----
    This method can be used with aggregator functions to remove records from
    aggregation.

    The result of the :meth:`explode` and :meth:`filter` methods is an
    :class:`.Aggregable` expression which can be used only in aggregator
    methods.

    Parameters
    ----------
    condition : :class:`.BooleanExpression` or function ( (arg) -> :class:`.BooleanExpression`)
        Filter expression, or a function to evaluate for each record.
    expr : :class:`.Expression`
        Expression to filter.

    Returns
    -------
    :class:`.Aggregable`
        Aggregable expression.
    """
    if callable(condition):
        predicate = condition
    else:
        # A bare expression applies to every record regardless of its value.
        def predicate(_record):
            return condition

    return expr._filter(predicate)
@typecheck(f=oneof(func_spec(1, expr_any), expr_any), expr=agg_expr(expr_any))
def _map(f, expr) -> Aggregable:
    """Map each record of `expr` through `f`; a non-callable `f` is used as a constant result."""
    if callable(f):
        transform = f
    else:
        def transform(_record):
            return f

    return expr._map(transform)


@typecheck(f=oneof(func_spec(1, expr_array()), expr_array()), expr=agg_expr(expr_any))
def _flatmap(f, expr) -> Aggregable:
    """Flat-map each record of `expr` through `f`; a non-callable `f` is used as a constant result."""
    if callable(f):
        transform = f
    else:
        def transform(_record):
            return f

    return expr._flatmap(transform)
@typecheck(expr=agg_expr(expr_call), prior=expr_float64)
def inbreeding(expr, prior) -> StructExpression:
    """Compute inbreeding statistics on calls.

    Examples
    --------
    Compute inbreeding statistics per column:

    >>> dataset_result = dataset.annotate_cols(IB = agg.inbreeding(dataset.GT, dataset.variant_qc.AF[1]))
    >>> dataset_result.cols().show()
    +----------------+--------------+-------------+------------------+------------------+
    | s              | IB.f_stat    | IB.n_called | IB.expected_homs | IB.observed_homs |
    +----------------+--------------+-------------+------------------+------------------+
    | str            | float64      | int64       | float64          | int64            |
    +----------------+--------------+-------------+------------------+------------------+
    | C1046::HG02024 | -1.23867e-01 |         338 |      2.96180e+02 |              291 |
    | C1046::HG02025 |  2.02944e-02 |         339 |      2.97151e+02 |              298 |
    | C1046::HG02026 |  5.47269e-02 |         336 |      2.94742e+02 |              297 |
    | C1047::HG00731 | -1.89046e-02 |         337 |      2.95779e+02 |              295 |
    | C1047::HG00732 |  1.38718e-01 |         337 |      2.95202e+02 |              301 |
    | C1047::HG00733 |  3.50684e-01 |         338 |      2.96418e+02 |              311 |
    | C1048::HG02024 | -1.95603e-01 |         338 |      2.96180e+02 |              288 |
    | C1048::HG02025 |  2.02944e-02 |         339 |      2.97151e+02 |              298 |
    | C1048::HG02026 |  6.74296e-02 |         338 |      2.96180e+02 |              299 |
    | C1049::HG00731 | -1.00467e-02 |         337 |      2.95418e+02 |              295 |
    +----------------+--------------+-------------+------------------+------------------+

    Notes
    -----
    ``E`` is total number of expected homozygous calls, given by the sum of
    ``1 - 2.0 * prior * (1 - prior)`` across records.

    ``O`` is the observed number of homozygous calls across records.

    ``N`` is the number of non-missing calls.

    ``F`` is the inbreeding coefficient, and is computed by:
    ``(O - E) / (N - E)``.

    This method returns a struct expression with four fields:

    - `f_stat` (:py:data:`.tfloat64`): ``F``, the inbreeding coefficient.
    - `n_called` (:py:data:`.tint64`): ``N``, the number of non-missing calls.
    - `expected_homs` (:py:data:`.tfloat64`): ``E``, the expected number of homozygotes.
    - `observed_homs` (:py:data:`.tint64`): ``O``, the number of observed homozygotes.

    Parameters
    ----------
    expr : :class:`.CallExpression`
        Call expression.
    prior : :class:`.Expression` of type :py:data:`.tfloat64`
        Alternate allele frequency prior.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `f_stat`, `n_called`, `expected_homs`,
        `observed_homs`.
    """
    result_type = tstruct(f_stat=tfloat64,
                          n_called=tint64,
                          expected_homs=tfloat64,
                          observed_homs=tint64)
    # The seq op receives each call together with the allele-frequency prior.
    per_record_args = [lambda expr: expr, prior]
    return _agg_func('Inbreeding', expr, result_type, seq_op_args=per_record_args)
@typecheck(call=agg_expr(expr_call), alleles=expr_array(expr_str))
def call_stats(call, alleles) -> StructExpression:
    """Compute useful call statistics.

    Examples
    --------
    Compute call statistics per row:

    >>> dataset_result = dataset.annotate_rows(gt_stats = agg.call_stats(dataset.GT, dataset.alleles))
    >>> dataset_result.rows().key_by('locus').select('gt_stats').show()
    +---------------+--------------+----------------+-------------+---------------------------+
    | locus         | gt_stats.AC  | gt_stats.AF    | gt_stats.AN | gt_stats.homozygote_count |
    +---------------+--------------+----------------+-------------+---------------------------+
    | locus<GRCh37> | array<int32> | array<float64> | int32       | array<int32>              |
    +---------------+--------------+----------------+-------------+---------------------------+
    | 20:10579373   | [199,1]      | [0.995,0.005]  |         200 | [99,0]                    |
    | 20:13695607   | [177,23]     | [0.885,0.115]  |         200 | [77,0]                    |
    | 20:13698129   | [198,2]      | [0.99,0.01]    |         200 | [98,0]                    |
    | 20:14306896   | [142,58]     | [0.71,0.29]    |         200 | [51,9]                    |
    | 20:14306953   | [121,79]     | [0.605,0.395]  |         200 | [38,17]                   |
    | 20:15948325   | [172,2]      | [0.989,0.012]  |         174 | [85,0]                    |
    | 20:15948326   | [174,8]      | [0.956,0.043]  |         182 | [83,0]                    |
    | 20:17479423   | [199,1]      | [0.995,0.005]  |         200 | [99,0]                    |
    | 20:17600357   | [79,121]     | [0.395,0.605]  |         200 | [24,45]                   |
    | 20:17640833   | [193,3]      | [0.985,0.015]  |         196 | [95,0]                    |
    +---------------+--------------+----------------+-------------+---------------------------+

    Notes
    -----
    This method is meaningful for computing call metrics per variant, but not
    especially meaningful for computing metrics per sample.

    This method returns a struct expression with four fields:

    - `AC` (:class:`.tarray` of :py:data:`.tint32`) - Allele counts. One element
      for each allele, including the reference.
    - `AF` (:class:`.tarray` of :py:data:`.tfloat64`) - Allele frequencies. One
      element for each allele, including the reference.
    - `AN` (:py:data:`.tint32`) - Allele number. The total number of called
      alleles, or the number of non-missing calls * 2.
    - `homozygote_count` (:class:`.tarray` of :py:data:`.tint32`) - Homozygote
      genotype counts for each allele, including the reference. Only **diploid**
      genotype calls are counted.

    Parameters
    ----------
    call : :class:`.CallExpression`
    alleles : :class:`.ArrayStringExpression`
        Variant alleles.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `AC`, `AF`, `AN`, and `homozygote_count`.
    """
    # The aggregator is initialised with the number of alleles at the site.
    allele_count = hl.len(alleles)
    result_type = tstruct(AC=tarray(tint32),
                          AF=tarray(tfloat64),
                          AN=tint32,
                          homozygote_count=tarray(tint32))
    return _agg_func('CallStats', call, result_type, [], init_op_args=[allele_count])
@typecheck(expr=agg_expr(expr_float64), start=expr_float64, end=expr_float64, bins=expr_int32)
def hist(expr, start, end, bins) -> StructExpression:
    """Compute binned counts of a numeric expression.

    Examples
    --------
    Compute a histogram of field `GQ`:

    >>> dataset.aggregate_entries(agg.hist(dataset.GQ, 0, 100, 10))
    Struct(bin_edges=[0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0],
           bin_freq=[2194L, 637L, 2450L, 1081L, 518L, 402L, 11168L, 1918L, 1379L, 11973L],
           n_smaller=0,
           n_larger=0)

    Notes
    -----
    This method returns a struct expression with four fields:

    - `bin_edges` (:class:`.tarray` of :py:data:`.tfloat64`): Bin edges. Bin `i`
      contains values in the left-inclusive, right-exclusive range
      ``[ bin_edges[i], bin_edges[i+1] )``.
    - `bin_freq` (:class:`.tarray` of :py:data:`.tint64`): Bin
      frequencies. The number of records found in each bin.
    - `n_smaller` (:py:data:`.tint64`): The number of records smaller than the start
      of the first bin.
    - `n_larger` (:py:data:`.tint64`): The number of records larger than the end
      of the last bin.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Target numeric expression.
    start : :obj:`int` or :obj:`float`
        Start of histogram range.
    end : :obj:`int` or :obj:`float`
        End of histogram range.
    bins : :obj:`int`
        Number of bins.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `bin_edges`, `bin_freq`, `n_smaller`, and
        `n_larger`.
    """
    result_type = tstruct(bin_edges=tarray(tfloat64),
                          bin_freq=tarray(tint64),
                          n_smaller=tint64,
                          n_larger=tint64)
    histogram_params = [start, end, bins]
    return _agg_func('Histogram', expr, result_type, constructor_args=histogram_params)
@typecheck(x=expr_float64, y=expr_float64, label=nullable(oneof(expr_str, expr_array(expr_str))), n_divisions=int)
def downsample(x, y, label=None, n_divisions=500) -> ArrayExpression:
    """Downsample (x, y) coordinate datapoints.

    Parameters
    ----------
    x : :class:`.NumericExpression`
        X-values to be downsampled.
    y : :class:`.NumericExpression`
        Y-values to be downsampled.
    label : :class:`.StringExpression` or :class:`.ArrayExpression`
        Additional data for each (x, y) coordinate. Can pass in multiple fields
        in an :class:`.ArrayExpression`.
    n_divisions : :obj:`int`
        Factor by which to downsample (default value = 500). A lower input
        results in fewer output datapoints.

    Returns
    -------
    :class:`.ArrayExpression`
        Expression for downsampled coordinate points (x, y). The element type
        of the array is :py:data:`.ttuple` of :py:data:`.tfloat64`,
        :py:data:`.tfloat64`, and :py:data:`.tarray` of :py:data:`.tstring`
    """
    # Normalise `label` to an array of strings: missing when absent, a
    # one-element array when a single string expression was given.
    if label is None:
        label = hl.null(hl.tarray(hl.tstr))
    elif isinstance(label, StringExpression):
        label = hl.array([label])

    x_records = _to_agg(x)
    point_array = tarray(ttuple(tfloat64, tfloat64, tarray(tstr)))
    return _agg_func('downsample',
                     x_records,
                     point_array,
                     constructor_args=[n_divisions],
                     seq_op_args=[lambda x: x, y, label])
@typecheck(gp=agg_expr(expr_array(expr_float64)))
def info_score(gp) -> StructExpression:
    r"""Compute the IMPUTE information score.

    Examples
    --------
    Calculate the info score per variant:

    >>> gen_mt = hl.import_gen('data/example.gen', sample_file='data/example.sample')
    >>> gen_mt = gen_mt.annotate_rows(info_score = hl.agg.info_score(gen_mt.GP))

    Calculate group-specific info scores per variant:

    >>> gen_mt = hl.import_gen('data/example.gen', sample_file='data/example.sample')
    >>> gen_mt = gen_mt.annotate_cols(is_case = hl.rand_bool(0.5))
    >>> gen_mt = gen_mt.annotate_rows(info_score = hl.agg.group_by(gen_mt.is_case, hl.agg.info_score(gen_mt.GP)))

    Notes
    -----
    The result of :func:`.info_score` is a struct with two fields:

    - `score` (``float64``) -- Info score.
    - `n_included` (``int32``) -- Number of non-missing samples included in the
      calculation.

    We implemented the IMPUTE info measure as described in the supplementary
    information from `Marchini & Howie. Genotype imputation for genome-wide
    association studies. Nature Reviews Genetics (2010)
    <http://www.nature.com/nrg/journal/v11/n7/extref/nrg2796-s3.pdf>`__.
    To calculate the info score :math:`I_{A}` for one SNP:

    .. math::

        I_{A} =
        \begin{cases}
        1 - \frac{\sum_{i=1}^{N}(f_{i} - e_{i}^2)}{2N\hat{\theta}(1 - \hat{\theta})} & \text{when } \hat{\theta} \in (0, 1) \\
        1 & \text{when } \hat{\theta} = 0, \hat{\theta} = 1\\
        \end{cases}

    - :math:`N` is the number of samples with imputed genotype probabilities
      [:math:`p_{ik} = P(G_{i} = k)` where :math:`k \in \{0, 1, 2\}`]
    - :math:`e_{i} = p_{i1} + 2p_{i2}` is the expected genotype per sample
    - :math:`f_{i} = p_{i1} + 4p_{i2}`
    - :math:`\hat{\theta} = \frac{\sum_{i=1}^{N}e_{i}}{2N}` is the MLE for the
      population minor allele frequency

    Hail will not generate identical results to `QCTOOL
    <http://www.well.ox.ac.uk/~gav/qctool/#overview>`__ for the following
    reasons:

    - Hail automatically removes genotype probability distributions that do not
      meet certain requirements on data import with :func:`.import_gen` and
      :func:`.import_bgen`.
    - Hail does not use the population frequency to impute genotype
      probabilities when a genotype probability distribution has been set to
      missing.
    - Hail calculates the same statistic for sex chromosomes as autosomes while
      QCTOOL incorporates sex information.
    - The floating point number Hail stores for each genotype probability is
      slightly different than the original data due to rounding and
      normalization of probabilities.

    Warning
    -------
    - The info score Hail reports will be extremely different from QCTOOL when
      a SNP has a high missing rate.
    - The `gp` array must contain 3 elements, and its elements may not be
      missing.
    - If the genotype data was not imported using the :func:`.import_gen` or
      :func:`.import_bgen` functions, then the results for all variants will be
      ``score = NA`` and ``n_included = 0``.
    - It only makes semantic sense to compute the info score per variant. While
      the aggregator will run in any context if its arguments are the right
      type, the results are only meaningful in a narrow context.

    Parameters
    ----------
    gp : :class:`.ArrayNumericExpression`
        Genotype probability array. Must have 3 elements, all of which must be
        defined.

    Returns
    -------
    :class:`.StructExpression`
        Struct with fields `score` and `n_included`.
    """
    # `gp` arrives already wrapped as an aggregable by the `agg_expr` checker,
    # so it is handed to the aggregator function directly.
    return _agg_func('InfoScore',
                     gp,
                     hl.tstruct(score=hl.tfloat64, n_included=hl.tint32))
@typecheck(y=expr_float64,
           x=oneof(expr_float64, sequenceof(expr_float64)),
           nested_dim=int,
           weight=nullable(expr_float64))
def linreg(y, x, nested_dim=1, weight=None) -> StructExpression:
    """Compute multivariate linear regression statistics.

    Examples
    --------
    Regress HT against an intercept (1), SEX, and C1:

    >>> table1.aggregate(agg.linreg(table1.HT, [1, table1.SEX == 'F', table1.C1]))
    Struct(beta=[88.50000000000014, 81.50000000000057, -10.000000000000068],
    standard_error=[14.430869689661844, 59.70552738231206, 7.000000000000016],
    t_stat=[6.132686518775844, 1.365032746099571, -1.428571428571435],
    p_value=[0.10290201427537926, 0.40250974549499974, 0.3888002244284281],
    multiple_standard_error=4.949747468305833,
    multiple_r_squared=0.7175792507204611,
    adjusted_r_squared=0.1527377521613834,
    f_stat=1.2704081632653061,
    multiple_p_value=0.5314327326007864,
    n=4)

    Regress blood pressure against an intercept (1), genotype, age, and
    the interaction of genotype and age:

    >>> ds_ann = ds.annotate_rows(linreg =
    ...     hl.agg.linreg(ds.pheno.blood_pressure,
    ...                   [1,
    ...                    ds.GT.n_alt_alleles(),
    ...                    ds.pheno.age,
    ...                    ds.GT.n_alt_alleles() * ds.pheno.age]))

    Warning
    -------
    As in the example, the intercept covariate ``1`` must be included
    **explicitly** if desired.

    Notes
    -----
    In relation to
    `lm.summary <https://stat.ethz.ch/R-manual/R-devel/library/stats/html/summary.lm.html>`__
    in R, ``linreg(y, x = [1, mt.x1, mt.x2])`` computes
    ``summary(lm(y ~ x1 + x2))`` and
    ``linreg(y, x = [mt.x1, mt.x2], nested_dim=0)`` computes
    ``summary(lm(y ~ x1 + x2 - 1))``.

    More generally, `nested_dim` defines the number of effects to fit in the
    nested (null) model, with the effects on the remaining covariates fixed
    to zero.

    The returned struct has ten fields:

    - `beta` (:class:`.tarray` of :py:data:`.tfloat64`):
      Estimated regression coefficient for each covariate.
    - `standard_error` (:class:`.tarray` of :py:data:`.tfloat64`):
      Estimated standard error for each covariate.
    - `t_stat` (:class:`.tarray` of :py:data:`.tfloat64`):
      t-statistic for each covariate.
    - `p_value` (:class:`.tarray` of :py:data:`.tfloat64`):
      p-value for each covariate.
    - `multiple_standard_error` (:py:data:`.tfloat64`):
      Estimated standard deviation of the random error.
    - `multiple_r_squared` (:py:data:`.tfloat64`):
      Coefficient of determination for nested models.
    - `adjusted_r_squared` (:py:data:`.tfloat64`):
      Adjusted `multiple_r_squared` taking into account degrees of freedom.
    - `f_stat` (:py:data:`.tfloat64`):
      F-statistic for nested models.
    - `multiple_p_value` (:py:data:`.tfloat64`):
      p-value for the
      `F-test <https://en.wikipedia.org/wiki/F-test#Regression_problems>`__
      of nested models.
    - `n` (:py:data:`.tint64`):
      Number of samples included in the regression. A sample is included if
      and only if `y`, all elements of `x`, and `weight` (if set) are
      non-missing.

    All but the last field are missing if `n` is less than or equal to the
    number of covariates or if the covariates are linearly dependent.

    If set, the `weight` parameter generalizes the model to `weighted least
    squares <https://en.wikipedia.org/wiki/Weighted_least_squares>`__, useful
    for heteroscedastic (diagonal but non-constant) variance.

    Warning
    -------
    If any weight is negative, the resulting statistics will be ``nan``.

    Parameters
    ----------
    y : :class:`.Float64Expression`
        Response (dependent variable).
    x : :class:`.Float64Expression` or :obj:`list` of :class:`.Float64Expression`
        Covariates (independent variables).
    nested_dim : :obj:`int`
        The null model includes the first `nested_dim` covariates.
        Must be between 0 and `k` (the length of `x`).
    weight : :class:`.Float64Expression`, optional
        Non-negative weight for weighted least squares.

    Returns
    -------
    :class:`.StructExpression`
        Struct of regression results.
    """
    covariates = wrap_to_list(x)
    if not covariates:
        raise ValueError("linreg: must have at least one covariate in `x`")

    hl.methods.statgen._warn_if_no_intercept('linreg', covariates)

    # Ordinary least squares: pass the inputs through untouched.
    if weight is None:
        return _linreg(y, covariates, nested_dim)

    # Weighted least squares is reduced to the unweighted problem by scaling
    # the response and every covariate by sqrt(weight).
    scaled_response = hl.sqrt(weight) * y
    scaled_covariates = [hl.sqrt(weight) * covariate for covariate in covariates]
    return _linreg(scaled_response, scaled_covariates, nested_dim)
def _linreg(y, x, nested_dim):
    """Build the ``LinearRegression`` aggregation for already-prepared inputs.

    Parameters
    ----------
    y
        Response expression; it is wrapped with ``_to_agg`` and used as the
        aggregable.
    x : :obj:`list`
        Covariate expressions; combined into a single array expression.
    nested_dim : :obj:`int`
        Number of leading covariates in the nested (null) model.

    Returns
    -------
    The result of ``_agg_func`` for the ``LinearRegression`` aggregator, typed
    as a struct of regression statistics.

    Raises
    ------
    ValueError
        If `nested_dim` is not between 0 and ``len(x)``, inclusive.
    """
    n_covariates = len(x)
    if not 0 <= nested_dim <= n_covariates:
        raise ValueError("linreg: `nested_dim` must be between 0 and the number "
                         f"of covariates ({n_covariates}), inclusive")

    # Result struct: four per-covariate arrays, five scalar summary
    # statistics, then the sample count. Field order is significant.
    per_covariate_fields = ('beta', 'standard_error', 't_stat', 'p_value')
    summary_fields = ('multiple_standard_error',
                      'multiple_r_squared',
                      'adjusted_r_squared',
                      'f_stat',
                      'multiple_p_value')
    fields = {name: hl.tarray(hl.tfloat64) for name in per_covariate_fields}
    fields.update((name, hl.tfloat64) for name in summary_fields)
    fields['n'] = hl.tint64
    result_type = hl.tstruct(**fields)

    covariate_array = hl.array(x)
    constructor_args = [hl.int32(n_covariates), hl.int32(nested_dim)]
    return _agg_func('LinearRegression',
                     _to_agg(y),
                     result_type,
                     constructor_args,
                     seq_op_args=[lambda response: response, covariate_array])
@typecheck(x=expr_float64,
           y=expr_float64)
def corr(x, y) -> Float64Expression:
    """Compute the
    `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`__
    between `x` and `y`.

    Examples
    --------
    >>> ds.aggregate_cols(hl.agg.corr(ds.pheno.age, ds.pheno.blood_pressure))
    0.159882536301

    Notes
    -----
    Only records where both `x` and `y` are non-missing will be included in the
    calculation.

    In the case that there are no non-missing pairs, the result will be missing.

    See Also
    --------
    :func:`linreg`

    Parameters
    ----------
    x : :class:`.Expression` of type ``tfloat64``
    y : :class:`.Expression` of type ``tfloat64``

    Returns
    -------
    :class:`.Float64Expression`
    """
    # `x` is the aggregable; `y` is supplied as an extra per-record argument.
    aggregable_x = _to_agg(x)
    per_record_args = [lambda x_value: x_value, y]
    return _agg_func('corr', aggregable_x, tfloat64, seq_op_args=per_record_args)
@typecheck(group=expr_any,
           agg_expr=expr_any)
def group_by(group, agg_expr) -> DictExpression:
    """Compute aggregation statistics stratified by one or more groups.

    .. include:: _templates/experimental.rst

    Examples
    --------
    Compute linear regression statistics stratified by SEX:

    >>> table1.aggregate(agg.group_by(table1.SEX,
    ...                               agg.linreg(table1.HT, table1.C1, nested_dim=0)))
    {
    'F': Struct(beta=[6.153846153846154],
    standard_error=[0.7692307692307685],
    t_stat=[8.000000000000009],
    p_value=[0.07916684832113098],
    multiple_standard_error=11.4354374979373,
    multiple_r_squared=0.9846153846153847,
    adjusted_r_squared=0.9692307692307693,
    f_stat=64.00000000000014,
    multiple_p_value=0.07916684832113098,
    n=2),
    'M': Struct(beta=[34.25],
    standard_error=[1.75],
    t_stat=[19.571428571428573],
    p_value=[0.03249975499062629],
    multiple_standard_error=4.949747468305833,
    multiple_r_squared=0.9973961101073441,
    adjusted_r_squared=0.9947922202146882,
    f_stat=383.0408163265306,
    multiple_p_value=0.03249975499062629,
    n=2)
    }

    Compute call statistics stratified by population group and case status:

    >>> ann = ds.annotate_rows(call_stats=hl.agg.group_by(hl.struct(pop=ds.pop, is_case=ds.is_case),
    ...                                                   hl.agg.call_stats(ds.GT, ds.alleles)))

    Parameters
    ----------
    group : :class:`.Expression` or :obj:`list` of :class:`.Expression`
        Group to stratify the result by.
    agg_expr : :class:`.Expression`
        Aggregation or scan expression to compute per grouping.

    Returns
    -------
    :class:`.DictExpression`
        Dictionary where the keys are `group` and the values are the result of
        computing `agg_expr` for each unique value of `group`.
    """
    # All of the work is delegated to the shared aggregator-function object.
    grouped = _agg_func._group_by(group, agg_expr)
    return grouped
class ScanFunctions(object):
    """Attribute-style namespace exposing aggregator functions as scans.

    Each function in `scope` is wrapped so that, for the duration of the call,
    the module-level ``_agg_func`` object found in the function's globals has
    its ``_as_scan`` flag set to ``True``.

    Parameters
    ----------
    scope : :obj:`dict`
        Mapping from public name to aggregator function. Each function must
        expose a ``__wrapped__`` attribute (presumably set by the
        ``typecheck`` decorator -- confirm) whose ``__globals__`` contains
        ``_agg_func``.
    """

    def __init__(self, scope):
        self._functions = {name: self._scan_decorator(f) for name, f in scope.items()}

    def _scan_decorator(self, f):
        """Return a wrapper around `f` that runs it with scan mode enabled."""

        # `wraps` already applies `update_wrapper`, so no second call is needed.
        @wraps(f)
        def wrapper(*args, **kwargs):
            func = getattr(f, '__wrapped__')
            af = func.__globals__['_agg_func']
            previous = getattr(af, '_as_scan', False)
            setattr(af, '_as_scan', True)
            try:
                return f(*args, **kwargs)
            finally:
                # Always restore the flag, even when `f` raises; otherwise a
                # failed scan call would leave every later aggregator call
                # running in scan mode.
                setattr(af, '_as_scan', previous)

        return wrapper

    def __getattr__(self, field):
        """Look up a wrapped scan function by name.

        Raises
        ------
        AttributeError
            If `field` is not a known function; the message lists up to five
            close matches.
        """
        # Read `_functions` through `__dict__` so that a lookup on an instance
        # that has not run `__init__` (e.g. during copy/unpickle probing)
        # raises AttributeError instead of recursing into `__getattr__`.
        functions = self.__dict__.get('_functions', {})
        if field in functions:
            return functions[field]

        field_matches = difflib.get_close_matches(field, functions.keys(), n=5)
        raise AttributeError("hl.scan.{} does not exist. Did you mean:\n {}".format(
            field,
            "\n ".join(field_matches)))