from math import log, isnan
import numpy as np
from bokeh.models import *
from bokeh.plotting import figure
from itertools import cycle
from hail.expr import aggregators
from hail.expr.expressions import *
from hail.expr.expressions import Expression
from hail.typecheck import *
from hail import Table
import hail
# Default categorical colors used to distinguish labels in plots (ten entries).
palette = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]
@typecheck(data=oneof(hail.utils.struct.Struct, expr_float64), range=nullable(sized_tupleof(numeric, numeric)),
           bins=int, legend=nullable(str), title=nullable(str))
def histogram(data, range=None, bins=50, legend=None, title=None):
    """Create a histogram.

    Parameters
    ----------
    data : :class:`.Struct` or :class:`.Float64Expression`
        Sequence of data to plot. A struct is used as-is and must expose
        ``bin_edges``, ``bin_freq``, ``n_smaller`` and ``n_larger``; an
        expression is aggregated into such a struct first.
    range : Tuple[float]
        Range of x values in the histogram. When omitted for an expression,
        the minimum and maximum of the data are aggregated and used.
    bins : int
        Number of bins in the histogram.
    legend : str
        Label of data on the x-axis.
    title : str
        Title of the histogram.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`

    Raises
    ------
    ValueError
        If `data` is an expression whose indices have no source.
    """
    if isinstance(data, Expression):
        if data._indices.source is None:
            # Previously the exception object was returned instead of raised,
            # so callers silently received a ValueError in place of a figure.
            raise ValueError('Invalid input')
        agg_f = data._aggregation_method()
        if range is not None:
            start = range[0]
            end = range[1]
        else:
            start, end = agg_f((aggregators.min(data), aggregators.max(data)))
        data = agg_f(aggregators.hist(data, start, end, bins))

    p = figure(title=title, x_axis_label=legend, y_axis_label='Frequency', background_fill_color='#EEEEEE')
    p.quad(
        bottom=0, top=data.bin_freq,
        left=data.bin_edges[:-1], right=data.bin_edges[1:],
        legend=legend, line_color='black')

    # Values outside the binned range are drawn as one extra bar on either
    # side, each as wide as the first bin.
    if data.n_larger > 0:
        p.quad(
            bottom=0, top=data.n_larger,
            left=data.bin_edges[-1], right=(data.bin_edges[-1] + (data.bin_edges[1] - data.bin_edges[0])),
            line_color='black', fill_color='green', legend='Outliers Above')
    if data.n_smaller > 0:
        p.quad(
            bottom=0, top=data.n_smaller,
            left=data.bin_edges[0] - (data.bin_edges[1] - data.bin_edges[0]), right=data.bin_edges[0],
            line_color='black', fill_color='red', legend='Outliers Below')
    return p
@typecheck(data=oneof(hail.utils.struct.Struct, expr_float64), range=nullable(sized_tupleof(numeric, numeric)),
           bins=int, legend=nullable(str), title=nullable(str), normalize=bool, log=bool)
def cumulative_histogram(data, range=None, bins=50, legend=None, title=None, normalize=True, log=False):
    """Create a cumulative histogram.

    Parameters
    ----------
    data : :class:`.Struct` or :class:`.Float64Expression`
        Sequence of data to plot. A struct is used as-is and must expose
        ``bin_edges``, ``bin_freq``, ``n_smaller`` and ``n_larger``; an
        expression is aggregated into such a struct first.
    range : Tuple[float]
        Range of x values in the histogram. When omitted for an expression,
        the minimum and maximum of the data are aggregated and used.
    bins : int
        Number of bins in the histogram.
    legend : str
        Label of data on the x-axis.
    title : str
        Title of the histogram. When given, the total number of data points
        is appended to it.
    normalize: bool
        Whether or not the cumulative data should be normalized.
    log: bool
        Whether or not the y-axis should be of type log.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`

    Raises
    ------
    ValueError
        If `data` is an expression whose indices have no source.
    """
    if isinstance(data, Expression):
        if data._indices.source is None:
            # Previously the exception object was returned instead of raised.
            raise ValueError('Invalid input')
        agg_f = data._aggregation_method()
        if range is not None:
            start = range[0]
            end = range[1]
        else:
            start, end = agg_f((aggregators.min(data), aggregators.max(data)))
        data = agg_f(aggregators.hist(data, start, end, bins))

    # Running count per bin, starting from the values that fall below the range.
    cumulative_data = np.cumsum(data.bin_freq) + data.n_smaller
    # The original called np.append() and discarded its result, so values above
    # the range never reached the total. Compute the total directly instead of
    # appending, which keeps cumulative_data the same length as the x values
    # plotted below.
    num_data_points = cumulative_data[-1] + data.n_larger
    if normalize:
        cumulative_data = cumulative_data / num_data_points
    if title is not None:
        title = f'{title} ({num_data_points:,} data points)'

    # Only pass y_axis_type when a log axis is requested so bokeh's own default
    # applies otherwise.
    axis_options = {'y_axis_type': 'log'} if log else {}
    p = figure(title=title, x_axis_label=legend, y_axis_label='Frequency',
               background_fill_color='#EEEEEE', **axis_options)
    # Each cumulative value is plotted against the left edge of its bin.
    p.line(data.bin_edges[:-1], cumulative_data, line_color='#036564', line_width=3)
    return p
@typecheck(x=oneof(sequenceof(numeric), expr_float64), y=oneof(sequenceof(numeric), expr_float64),
           label=oneof(nullable(str), expr_str, sequenceof(str)), title=nullable(str),
           xlabel=nullable(str), ylabel=nullable(str), size=int, legend=bool,
           source_fields=nullable(dictof(str, sequenceof(anytype))), collect_all=nullable(bool), n_divisions=int)
def scatter(x, y, label=None, title=None, xlabel=None, ylabel=None, size=4, legend=True,
            collect_all=False, n_divisions=500, source_fields=None):
    """Create a scatterplot.

    Parameters
    ----------
    x : List[float] or :class:`.Float64Expression`
        List of x-values to be plotted.
    y : List[float] or :class:`.Float64Expression`
        List of y-values to be plotted.
    label : List[str] or :class:`.StringExpression`
        List of labels for x and y values, used to assign each point a label (e.g. population)
    title : str
        Title of the scatterplot.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    size : int
        Size of markers in screen space units.
    legend : bool
        Whether or not to show the legend in the resulting figure.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
        This parameter will be ignored if x and y are Python objects.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.
    source_fields : Dict[str, List[Any]]
        Extra fields for the ColumnDataSource of the plot. Only used when
        `label` is given.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`

    Raises
    ------
    TypeError
        If exactly one of `x` and `y` is an expression.
    """
    if isinstance(x, Expression) and isinstance(y, Expression):
        agg_f = x._aggregation_method()
        if isinstance(label, Expression):
            if collect_all:
                res = hail.tuple([x, y, label]).collect()
                label = [point[2] for point in res]
            else:
                res = agg_f(aggregators.downsample(x, y, label=label, n_divisions=n_divisions))
                # Downsampled labels come back as an array; take its first entry.
                label = [point[2][0] for point in res]
        elif collect_all:
            res = hail.tuple([x, y]).collect()
        else:
            res = agg_f(aggregators.downsample(x, y, n_divisions=n_divisions))
        x = [point[0] for point in res]
        y = [point[1] for point in res]
    elif isinstance(x, Expression) or isinstance(y, Expression):
        raise TypeError('Invalid input: x and y must both be either Expressions or Python Lists.')
    elif isinstance(label, Expression):
        label = label.collect()

    p = figure(title=title, x_axis_label=xlabel, y_axis_label=ylabel, background_fill_color='#EEEEEE')
    if label is None:
        p.circle(x, y, alpha=0.5, size=size)
        return p

    fields = dict(x=x, y=y, label=label)
    if source_fields is not None:
        fields.update(source_fields)
    source = ColumnDataSource(fields)

    leg = 'label' if legend else None

    # Sort the distinct labels so each label gets the same color on every run;
    # plain set iteration order for strings varies with hash randomization.
    # key=str keeps the sort well-defined if a label is missing (None).
    factors = sorted(set(label), key=str)
    # Reuse palette colors cyclically when there are more labels than colors.
    colors = [palette[i % len(palette)] for i in range(len(factors))]
    color_mapper = CategoricalColorMapper(factors=factors, palette=colors)
    p.circle('x', 'y', alpha=0.5, source=source, size=size,
             color={'field': 'label', 'transform': color_mapper}, legend=leg)
    return p
@typecheck(pvals=oneof(sequenceof(numeric), expr_float64), collect_all=bool, n_divisions=int)
def qq(pvals, collect_all=False, n_divisions=500):
    """Create a Quantile-Quantile plot. (https://en.wikipedia.org/wiki/Q-Q_plot)

    Parameters
    ----------
    pvals : List[float] or :class:`.Float64Expression`
        P-values to be plotted.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
        This parameter will be ignored if pvals is a Python object.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`

    Raises
    ------
    ValueError
        If `pvals` is an expression whose indices have no source.
    """
    source = None
    if isinstance(pvals, Expression):
        source = pvals._indices.source
        if source is None:
            # Previously the exception object was returned instead of raised.
            raise ValueError('Invalid input: expression has no source')
        if collect_all:
            # After collecting, pvals is a local sequence and takes the same
            # path as a Python list argument.
            pvals = pvals.collect()

    if isinstance(pvals, Expression):
        # Downsampling path: order the p-values by keying on them, give each
        # row its running index, and downsample the (expected, observed) pairs.
        if isinstance(source, Table):
            ht = source.select(pval=pvals).key_by().persist().key_by('pval')
        else:
            ht = source.select_rows(pval=pvals).rows().key_by().select('pval').persist().key_by('pval')
        n = ht.count()
        ht = ht.select(idx=hail.scan.count())
        ht = ht.annotate(expected_p=(ht.idx + 1) / n)
        points = ht.aggregate(
            aggregators.downsample(-hail.log10(ht.expected_p), -hail.log10(ht.pval), n_divisions=n_divisions))
        kept = [point for point in points if not isnan(point[1])]
        exp = [point[0] for point in kept]
        obs = [point[1] for point in kept]
    else:
        # Drop missing, zero and NaN p-values; log() is undefined for them.
        spvals = sorted(value for value in pvals if value and not isnan(value))
        count = len(spvals)
        exp = [-log(rank / count, 10) for rank in range(1, count + 1)]
        obs = [-log(value, 10) for value in spvals]

    p = figure(
        title='Q-Q Plot',
        x_axis_label='Expected p-value (-log10 scale)',
        y_axis_label='Observed p-value (-log10 scale)')
    p.scatter(x=exp, y=obs, color='black')
    # default=0 keeps an input with no usable p-values from raising on max().
    bound = max(max(exp, default=0), max(obs, default=0)) * 1.1
    # Reference diagonal where observed equals expected.
    p.line([0, bound], [0, bound], color='red')
    return p
@typecheck(pvals=expr_float64, locus=nullable(expr_locus()), title=nullable(str),
           size=int, hover_fields=nullable(dictof(str, expr_any)), collect_all=bool, n_divisions=int)
def manhattan(pvals, locus=None, title=None, size=4, hover_fields=None, collect_all=False, n_divisions=500):
    """Create a Manhattan plot. (https://en.wikipedia.org/wiki/Manhattan_plot)

    Parameters
    ----------
    pvals : :class:`.Float64Expression`
        P-values to be plotted.
    locus : :class:`.LocusExpression`
        Locus values to be plotted. When omitted, the ``locus`` field of the
        source of `pvals` is used.
    title : str
        Title of the plot.
    size : int
        Size of markers in screen space units.
    hover_fields : Dict[str, :class:`.Expression`]
        Dictionary of field names and values to be shown in the HoverTool of the plot.
        A ``locus`` entry is always added. The dictionary passed in is not modified.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    # Local import: bisect is not among the module-level imports.
    from bisect import bisect_right

    pvals = -hail.log10(pvals)
    if locus is None:
        locus = pvals._indices.source.locus

    # Work on a copy: the original added 'locus' to, and later overwrote the
    # values of, the dictionary object supplied by the caller.
    hover_exprs = dict(hover_fields) if hover_fields is not None else {}
    hover_exprs['locus'] = hail.str(locus)

    if collect_all:
        res = hail.tuple([locus.global_position(), pvals, hail.struct(**hover_exprs)]).collect()
        hover_structs = [point[2] for point in res]
        hover_data = {key: [item[key] for item in hover_structs] for key in hover_exprs}
    else:
        agg_f = pvals._aggregation_method()
        res = agg_f(aggregators.downsample(locus.global_position(), pvals,
                                           label=hail.array([hail.str(x) for x in hover_exprs.values()]),
                                           n_divisions=n_divisions))
        # Each point's label array holds the hover values in hover_exprs order.
        hover_rows = [point[2] for point in res]
        hover_data = {key: [row[idx] for row in hover_rows] for idx, key in enumerate(hover_exprs)}

    x = [point[0] for point in res]
    y = [point[1] for point in res]

    # Global start position of every contig, followed by the end of the last one.
    ref = locus.dtype.reference_genome
    start_points = []
    total_pos = 0
    for contig in ref.contigs:
        start_points.append(total_pos)
        total_pos += ref.lengths.get(contig)
    start_points.append(total_pos)  # end point of all contigs

    # Alternate the label by contig parity so neighbouring contigs get
    # different colors, and remember which contigs actually have points.
    observed_contigs = set()
    label = []
    for position in x:
        # Index of the last contig whose start is <= position.
        contig_index = bisect_right(start_points, position) - 1
        label.append(str(contig_index % 2))
        observed_contigs.add(ref.contigs[contig_index])

    # One axis tick per observed contig, placed at the contig's midpoint.
    mid_points = []
    labels = []
    for i, contig in enumerate(ref.contigs):
        if contig not in observed_contigs:
            continue
        mid = start_points[i] + ref.lengths.get(contig) / 2
        # NOTE(review): integral midpoints are nudged by 0.5, presumably so the
        # tick keys used for the label overrides are never whole numbers - confirm.
        if mid % 1 == 0:
            mid += 0.5
        mid_points.append(mid)
        labels.append(contig)

    p = scatter(x, y, label=label, title=title, xlabel='Chromosome', ylabel='P-value (-log10 scale)',
                size=size, legend=False, source_fields=hover_data)
    p.xaxis.ticker = mid_points
    p.xaxis.major_label_overrides = dict(zip(mid_points, labels))
    p.width = 1000

    tooltips = [(key, "@{}".format(key)) for key in hover_data]
    tooltips.append(('p-value', "$y"))
    p.add_tools(HoverTool(
        tooltips=tooltips
    ))
    return p