This is a minimal Colab/Jupyter Notebook for regenerating all figures in the paper "Progress towards Deep Learning Genome Assembly: de novo ML-based Read-overlap Graphs and Partial Layout", which was submitted to the ICML Computational Biology Workshop.
Anonymized preprint: https://storage.googleapis.com/progress-towards-assembly/deep_learning_assembly.pdf
Permalink to the source notebook: https://storage.googleapis.com/progress-towards-assembly/deep_learning_assembly.ipynb
All of the required files, as well as model checkpoints, are in the public Google Cloud Storage bucket gs://progress-towards-assembly/
If you make any modifications that require writing out log files, the scripts will fail unless you copy the bucket to somewhere you have write permissions. Without modifications, read-only access to the public bucket is sufficient.
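If you do need write access, one way to copy the bucket to one you own (the destination bucket name below is a placeholder) is, for example:
!gsutil -m cp -r gs://progress-towards-assembly/* gs://YOUR-WRITABLE-BUCKET/
and then point DATASET_PATH (defined below) at your copy.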
The scripts below will load the most recent model checkpoint, as well as precomputed edit distances (edit distances take around 6 hours to compute, and the model was trained for 27 hours).
If you want to turn on exporting of Plotly figures, you will have to restart the Colab after installing Kaleido.
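One way to trigger that restart from inside the notebook (this kills the runtime process, which Colab then restarts, losing in-memory state) is, for example:
import os
os.kill(os.getpid(), 9)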
As of 2023-05-17, the notebook runs with no modifications on an A100 GPU, high-RAM Colab runtime.
# This code is only needed for authenticated writes and for exporting the notebook to HTML.
#from google.colab import auth
#auth.authenticate_user()
#!pip install --quiet nbconvert
#!jupyter nbconvert --to html --embed-images /content/colab.ipynb
import tensorflow as tf
# Tensorboard extension (for visualization purposes later)
from torch.utils.tensorboard import SummaryWriter
# This will allow viewing the training logs. Note that it also breaks HTML output,
# so we have disabled it here.
#%load_ext tensorboard
#%tensorboard --reload_multifile True --logdir=gs://progress-towards-assembly/dim3large_20230512a/e200
import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
#from jax import random
import os
## Flax (NN in JAX)
try:
import flax
except ModuleNotFoundError: # Install flax if missing
!pip install --quiet flax
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
## Optax (Optimizers in JAX)
try:
import optax
except ModuleNotFoundError: # Install optax if missing
!pip install --quiet optax
import optax
import torch
from tqdm.auto import tqdm
# Tensorboard extension (for visualization purposes later)
from torch.utils.tensorboard import SummaryWriter
#%load_ext tensorboard
import plotly
!pip install --quiet -U kaleido
!pip install --quiet levenshtein
# Path to the folder where the datasets are
DATASET_PATH = "gs://progress-towards-assembly/"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = DATASET_PATH + "dim3large_20230512a"
!pip install --quiet mpl-scatter-density
import matplotlib.pyplot as plt
import mpl_scatter_density
from matplotlib.colors import LinearSegmentedColormap
# Otherwise Matplotlib is bad about releasing memory
import gc
from matplotlib import pyplot as plt
def collect_garbage():
plt.figure().clear()
plt.close('all')
plt.cla()
plt.clf()
gc.collect(generation=2)
def numpy_collate(batch):
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple,list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)
def logarithmic_vector(x, num_bits=16):
    '''Reversibly encodes a number x as a probability-like logarithmic vector.
    For example:
      0 = [1, 0, 0, ...]
      1 = [0, 1, 0, ...]
      2 = [0, 0, 1, 0, ...]
      3 = [0, 0, 0.42, 0.58, 0, ...]
      4 = [0, 0, 0, 1, 0, ...]
    Based on the mantissa of log2(x), so that distances in log-space are preserved.
    Note that the first bit is reserved for distance 0.
    '''
    ans = np.zeros(num_bits)
    if x >= 1:
        xl = np.log2(x)
        exp = int(xl)
        mantissa = xl % 1
        ans[exp + 1] = 1 - mantissa
        ans[exp + 2] = mantissa
    elif x >= 0:
        ans[0] = 1 - x
        ans[1] = x
    else:
        raise ValueError('x must be >= 0')
    return ans
def coarse_onehot(x, num_bits=16, range=256):
ans = np.zeros(num_bits)
loc = min(x//(range // num_bits), num_bits-1)
ans[loc] = 1
return ans
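# Quick sanity check of the two encodings above (illustrative only, not part of the
# original pipeline): logarithmic_vector(3) splits its mass across bits 2 and 3,
# and coarse_onehot(40) lands in bucket 40 // (256 // 16) = 2.
print(np.round(logarithmic_vector(3), 2))  # -> [0. 0. 0.42 0.58 0. ...]
print(coarse_onehot(40))                   # -> one-hot vector with a 1 at index 2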
def adjacent_read(batch, expansion=1, theta=0, prng=np.random.RandomState(None)):
    '''Generates `expansion` simulated adjacent reads for each read in batch,
    returning triples (read, perturbed_read, d) where d = |roll| is the shift.'''
    def adjacent(x):
        '''First shifts the entire k-mer x (an array of 0,1,2,3) by a random 'roll',
        padding with random bases, to simulate a partial overlap. Then applies
        substitutions to roughly a fraction theta of the positions.
        Note that the actual edit distance of the returned read may not equal d.
        '''
        length = len(x)
        roll = prng.randint(-length, length)
        x = np.concatenate([x, prng.randint(4, size=length)])
        x = np.roll(x, roll)[:length]
        mask = (prng.uniform(size=length) < theta)
        adds = prng.randint(4, size=length) * mask
        x = (x + adds) % 4
        return x.astype(np.uint8), np.absolute(roll)
    read_array = []
    for x in batch:
        read_array.extend([(x,) + adjacent(x) for _ in range(expansion)])
    return read_array
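# Illustrative example (not part of the original pipeline): each item returned by
# adjacent_read is a (read, perturbed_read, shift) triple; with two input reads and
# expansion=3 we get six triples.
_example = adjacent_read(np.random.randint(4, size=(2, 16)), expansion=3, theta=0.05)
print(len(_example), _example[0][0].shape, _example[0][2])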
import tensorflow as tf
import jax
from jax import numpy as jnp
import numpy as np
def convert_fasta_to_numpy(string):
'''Converts ACGT or acgt to 3021'''
return ((np.frombuffer(string.upper().encode(), dtype='uint8') % 7 +1) % 5)%4
with tf.io.gfile.GFile(DATASET_PATH + "NC_045512.2.fasta", "r") as f:
    f.read(11)  # skip past the FASTA header
    data = f.read().replace('\n', '')
data = convert_fasta_to_numpy(data)
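# Spot check of the encoding above (illustrative only): 'ACGT' maps to 3, 0, 2, 1,
# matching the docstring.
print(convert_fasta_to_numpy('ACGT'))  # -> [3 0 2 1]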
from tqdm.auto import tqdm
def sample_reads(k, fasta, size=1, theta=0.01, seed=1729):
    '''Samples reads of length k from a fasta (as a 0-3 integer array), applying
    random insertions, deletions and substitutions at per-base rate theta.'''
    seq_length = len(fasta) - k
    prng = np.random.RandomState(seed)
    read_array = []
    positions = []
    for _ in range(size):
        loc = prng.randint(seq_length)
        positions.append(loc)
        read = fasta[loc:loc+k].copy()  # copy so mutations do not modify the reference in place
        num_mutations = sum(prng.uniform(size=k) < theta)
        for _ in range(num_mutations):
            if prng.uniform() < 0.33:  # insertion
                rloc = prng.randint(0, k)
                roll = prng.randint(0, k)
                read[rloc:] = np.roll(read[rloc:], roll)
                read[rloc] = (read[rloc] + prng.randint(1, 4)) % 4
            elif prng.uniform() < 0.66:  # deletion
                rloc = prng.randint(0, k)
                roll = prng.randint(0, k)
                read[0] = (read[0] + prng.randint(1, 4)) % 4
                read[:rloc] = np.roll(read[:rloc], -roll)
            else:  # substitution
                rloc = prng.randint(0, k)
                read[rloc] = (read[rloc] + prng.randint(1, 4)) % 4
        read_array.append(read.reshape((1, -1)))
    return np.concatenate(read_array, axis=0), positions
# Generate actual sampled reads for assembly
k = 256
depth_of_coverage = 200
reads, locations = sample_reads(k, data, size=(len(data)*depth_of_coverage//k), theta=0.01)
sorted_reads = np.concatenate([x.reshape((1,-1)) for x, _ in sorted(zip(reads, locations), key=lambda x: x[1])], axis=0)
sorted_locations = sorted(locations)
train_batch = adjacent_read(reads, expansion=1, theta=0.01)
train_loader = torch.utils.data.DataLoader(train_batch, batch_size=512, shuffle=True,
num_workers=1, collate_fn=numpy_collate)
test_dataset = adjacent_read(reads, expansion=1, theta=0.01, prng=np.random.RandomState(1))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, shuffle=True,
num_workers=1, collate_fn=numpy_collate)
val_dataset = adjacent_read(reads, expansion=1, theta=0.01, prng=np.random.RandomState(0))
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=512, shuffle=True,
num_workers=1, collate_fn=numpy_collate)
import torch
from torch.utils.data import Dataset, DataLoader
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
import jax
import jax.numpy as jnp
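# A brief description of the model defined below (added for readability):
# Encoder embeds a one-hot encoded read (flattened to shape [batch, 4*k, 1]) using
# parallel 1-mer..10-mer convolutions (stride 4 = one base per step), max-pools the
# longer k-mer features, concatenates everything, and funnels it through GELU/Dense
# layers down to a 3-dimensional embedding. PredictDistance returns a smoothed
# Euclidean distance between two such embeddings, and EditEmbedding wires the two
# together on (read, adjacent read, target distance) triples.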
class Encoder(nn.Module):
@nn.compact
def __call__(self, x):
x = x.reshape([x.shape[0], -1, 1])
y1 = nn.Conv(features=4, kernel_size=(4,), strides=4, name='1mer')(x)
y2 = nn.Conv(features=4**2, kernel_size=(4*2,), strides=4, name='2mer')(x)
y3 = nn.Conv(features=4**3, kernel_size=(4*3,), strides=4, name='3mer')(x)
y4 = nn.Conv(features=4**4 // 4, kernel_size=(4*4,), strides=4, name='4mer')(x)
y5 = nn.Conv(features=4**5 // 8, kernel_size=(4*5,), strides=4, name='5mer')(x)
y6 = nn.Conv(features=4**6 // 16, kernel_size=(4*6,), strides=4, name='6mer')(x)
y7 = nn.Conv(features=4**7 // 32, kernel_size=(4*7,), strides=4, name='7mer')(x)
y8 = nn.Conv(features=4**8 // 64, kernel_size=(4*8,), strides=4, name='8mer')(x)
y9 = nn.Conv(features=4**9 // 128, kernel_size=(4*9,), strides=4, name='9mer')(x)
y10 = nn.Conv(features=4**10 // 256, kernel_size=(4*10,), strides=4, name='10mer')(x)
y4 = nn.max_pool(y4, window_shape=(4,), strides=(4,))
y5 = nn.max_pool(y5, window_shape=(5,), strides=(5,))
y6 = nn.max_pool(y6, window_shape=(6,), strides=(6,))
y7 = nn.max_pool(y7, window_shape=(7,), strides=(7,))
y8 = nn.max_pool(y8, window_shape=(8,), strides=(8,))
y9 = nn.max_pool(y9, window_shape=(9,), strides=(9,))
y10 = nn.max_pool(y10, window_shape=(10,), strides=(10,))
y1 = y1.reshape(y1.shape[0], -1)
y2 = y2.reshape(y2.shape[0], -1)
y3 = y3.reshape(y3.shape[0], -1)
y4 = y4.reshape(y4.shape[0], -1)
y5 = y5.reshape(y5.shape[0], -1)
y6 = y6.reshape(y6.shape[0], -1)
y7 = y7.reshape(y7.shape[0], -1)
y8 = y8.reshape(y8.shape[0], -1)
y9 = y9.reshape(y9.shape[0], -1)
y10 = y10.reshape(y10.shape[0], -1)
y = jnp.concatenate([y1, y2, y3, y4, y5, y6, y7, y8, y9, y10], axis=1)
y = nn.gelu(y)
y = nn.Dense(512)(y)
y = nn.gelu(y)
y = nn.Dense(256)(y)
y = nn.gelu(y)
y = nn.Dense(128)(y)
y = nn.gelu(y)
y = nn.Dense(64)(y)
y = nn.gelu(y)
y = nn.Dense(32)(y)
y = nn.gelu(y)
y = nn.Dense(16)(y)
y = nn.gelu(y)
y = nn.Dense(8)(y)
y = nn.gelu(y)
y = nn.Dense(3)(y)
return y
class PredictDistance(nn.Module):
@nn.compact
def __call__(self, lx, ly):
def approx_sqrt(a):
a = a + 0.01
def update(x):
x = x*(x*x + 3*a)/(3*x*x + a)
return x
x = a/2
for _ in range(10):
x = update(x)
return x
z = ((lx - ly)**2).sum(axis=1)
z = approx_sqrt(z)
return z
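# Illustrative aside (not part of the original pipeline): approx_sqrt above runs a
# fixed number of Halley-style iterations x <- x*(x*x + 3a)/(3*x*x + a), which
# converge to sqrt(a); the +0.01 offset keeps the starting point and gradients
# well-behaved near zero. A quick numerical check:
_a = jnp.array([0.0, 1.0, 4.0, 100.0]) + 0.01
_x = _a / 2
for _ in range(10):
    _x = _x * (_x * _x + 3 * _a) / (3 * _x * _x + _a)
print(np.asarray(_x), np.asarray(jnp.sqrt(_a)))  # the two rows should agree closely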
class EditEmbedding(nn.Module):
def setup(self):
self.encoder = Encoder()
self.predictdistance = PredictDistance()
def __call__(self, triple):
x, y, d = triple
x = nn.one_hot(x, 4)
y = nn.one_hot(y, 4)
lx = self.encoder(x)
ly = self.encoder(y)
z = self.predictdistance(lx, ly)
return z, d
rng = jax.random.PRNGKey(0)
batch = next(iter(test_loader))
encoder = EditEmbedding()
params = encoder.init(rng, batch)['params']
recon_x, y = encoder.apply({'params': params}, batch)
del rng, batch, encoder, params, recon_x, y
class TrainerModule:
def __init__(self, lr=1e-3, seed=42, log_local=False):
super().__init__()
self.lr = lr
self.seed = seed
# Create empty model (with no parameters)
self.model = EditEmbedding()
# Prepare logging
self.exmp = next(iter(test_loader))
self.log_dir = os.path.join(CHECKPOINT_PATH, f'e200')
if log_local:
self.logger = SummaryWriter(log_dir = '.')
else:
self.logger = SummaryWriter(log_dir = self.log_dir)
self.create_functions() # Create jitted training and eval functions
self.init_model() # Initialize model
def loss_fn(self, params, batch):
        # Weighted Huber loss between the predicted and target distances; pairs where
        # both distances are large (well above 128) are smoothly down-weighted below.
pred_d, d = self.model.apply({'params': params}, batch)
d = d.reshape([d.shape[0], 1])
pred_d = pred_d.reshape([pred_d.shape[0], 1])
concat_both = jnp.concatenate([d, pred_d], axis=1)
weight = 1-jax.nn.sigmoid((jnp.min(concat_both, axis=1)-128.0)/16)
base_loss = optax.huber_loss(pred_d, d).reshape((d.shape[0],)) * weight
base_loss = base_loss.mean()
return base_loss
def create_functions(self):
# Training function
def train_step(state, batch):
loss, grads = jax.value_and_grad(self.loss_fn)(state.params, batch) # Get loss and gradients for loss
state = state.apply_gradients(grads=grads) # Optimizer update step
return state, loss
self.train_step = jax.jit(train_step)
def eval_step(state, batch):
return self.loss_fn(state.params, batch)
self.eval_step = jax.jit(eval_step)
def init_model(self):
#Initialize model
rng = jax.random.PRNGKey(self.seed)
rng, init_rng = jax.random.split(rng)
params = self.model.init(init_rng, self.exmp)['params']
# Initialize learning rate schedule and optimizer
lr_schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=1e-3,
warmup_steps=500,
decay_steps=2500*len(train_loader),
end_value=1e-6
)
optimizer = optax.chain(
optax.clip(1.0), # Clip gradients at 1
optax.adam(lr_schedule)
)
# Initialize training state
self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=optimizer)
def train_model(self, num_epochs=500, starting_epoch=0):
# Train model until num_epochs
best_eval = 1e6
with tqdm(range(starting_epoch+1, num_epochs+1)) as t:
for epoch_idx in t:
self.train_epoch(epoch=epoch_idx, tqdm_out=t)
if epoch_idx % 50 == 0:
eval_loss = self.eval_model(val_loader)
self.logger.add_scalar('val/loss', eval_loss, global_step=epoch_idx)
if eval_loss < best_eval:
best_eval = eval_loss
self.save_model(step=epoch_idx)
self.logger.flush()
def train_epoch(self, epoch, tqdm_out=None):
# Train model for one epoch, and log avg loss
losses = []
train_dataset = adjacent_read(reads, expansion=1, theta=0.01, prng=np.random.RandomState(1000+epoch))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True,
num_workers=4, collate_fn=numpy_collate)
for batch in train_loader:
self.state, loss = self.train_step(self.state, batch)
losses.append(loss)
losses_np = np.stack(jax.device_get(losses))
avg_loss = losses_np.mean()
self.logger.add_scalar('train/loss', avg_loss, global_step=epoch)
if tqdm_out is not None:
tqdm_out.set_description('loss: ' + str(avg_loss))
def eval_model(self, data_loader):
        # Evaluate the model on all batches of a data loader and return the size-weighted average loss
losses = []
batch_sizes = []
for batch in data_loader:
loss = self.eval_step(self.state, batch)
losses.append(loss)
batch_sizes.append(batch[0].shape[0])
losses_np = np.stack(jax.device_get(losses))
batch_sizes_np = np.stack(batch_sizes)
avg_loss = (losses_np * batch_sizes_np).sum() / batch_sizes_np.sum()
return avg_loss
def save_model(self, step=0):
# Save current model at certain training iteration
checkpoints.save_checkpoint(ckpt_dir=self.log_dir, target=self.state.params,
prefix=f'e200_',
step=step, overwrite=False, keep=10)
def load_model(self, step=-1):
prefix = f'e200_'
checkpoint_file = checkpoints.latest_checkpoint(self.log_dir, prefix)
if step < 0:
if checkpoint_file == None:
step = None
else:
step = int(checkpoint_file.split('_')[-1])
params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, prefix=prefix, step=step)
self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)
if step == None:
step = 0
return step
def load_pretrained_model(self, model_path):
params = checkpoints.restore_checkpoint(ckpt_dir=model_path, target=self.state.params)
self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)
step = 0
return step
def checkpoint_exists(self):
        # Check whether a checkpoint directory already exists for this model
return tf.io.gfile.exists(self.log_dir)
def load_aligner(total_epochs=10, training_bool=True):
# Create a trainer module with specified hyperparameters
trainer = TrainerModule(log_local=(not training_bool))
step = trainer.load_model()
if training_bool:
trainer.train_model(num_epochs=total_epochs, starting_epoch=step)
trainer.model_bd = trainer.model.bind({'params': trainer.state.params})
return trainer
training_bool = False # Setting this to True requires write permissions
trainer_ld = load_aligner(total_epochs=16100, training_bool=training_bool)
num_params = 0
for k1 in trainer_ld.state.params['encoder'].keys():
for k2 in trainer_ld.state.params['encoder'][k1]:
num_params += (np.prod(trainer_ld.state.params['encoder'][k1][k2].shape))
print("Number of trainable parameters: " + str(num_params))
Number of trainable parameters: 130437175
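# An equivalent way to count parameters (sketched here as an alternative): flatten
# the whole parameter pytree and sum the leaf sizes. PredictDistance has no trainable
# parameters, so this should match the encoder-only count printed above.
total_params = sum(int(np.prod(p.shape)) for p in jax.tree_util.tree_leaves(trainer_ld.state.params))
print("Number of trainable parameters (pytree count): " + str(total_params))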
scatter_loader = torch.utils.data.DataLoader(sorted_reads, batch_size=512, shuffle=False,
num_workers=1, collate_fn=numpy_collate)
coords_list = []
for r in iter(scatter_loader):
r = nn.one_hot(r, 4)
coords = (trainer_ld.model_bd.encoder(np.asarray(r)))
coords_list.append(coords)
import pandas as pd
aug_coords_list = []
for i, coord in enumerate(coords_list):
coord2 = i*np.ones((coord.shape[0], 1))
coord3 = np.concatenate([coord, coord2], axis=1)
aug_coords_list.append(coord3)
aug_coords = np.concatenate(aug_coords_list)
df = pd.DataFrame(aug_coords, columns = ['x', 'y', 'z', 'c'])
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from matplotlib import pyplot as plt
import plotly.express as px
df['Position'] = df['c']/df['c'].max()
fig = px.scatter_3d(df.loc[df['c'] >-1], x='x', y='y', z='z', color='Position',
color_continuous_scale=px.colors.sequential.Rainbow)
fig.update_traces(marker_size = 2)
camera = dict(
eye=dict(x=0.7, y=2.1, z=0.8)
)
fig.update_layout(scene_camera=camera)
if False: # Set this to true to write output
plotly.io.write_image(fig, '3D.png', scale=5)
fig.show()
collect_garbage()
df2 = df.copy()
center = df.mean(axis=0)
df2['c'] /= df2['c'].max() # Normalize the distance along the genome
# find the centre of the sphere
from scipy.optimize import least_squares
def f(sphere, df):
'''Sphere is [x,y,z,r]'''
xyz = df[['x', 'y', 'z']].to_numpy()
centre = sphere[:3]
ans = ((xyz - centre)**2).sum(axis=1) - sphere[3]**2
return ans
centre = least_squares(f, 10*np.ones(4), args=(df2,))
df2['x'] = (df2['x'] - centre.x[0]) / centre.x[3]
df2['y'] = (df2['y'] - centre.x[1]) / centre.x[3]
df2['z'] = (df2['z'] - centre.x[2]) / centre.x[3]
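# Quick check (illustrative only): after recentring and rescaling by the fitted
# sphere, the embedding points should lie close to the unit sphere, so the radii
# summarized here should be near 1.
print(np.sqrt((df2[['x', 'y', 'z']]**2).sum(axis=1)).describe())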
def g(cylinder, df):
'''Cylinder axis is [x, y, z]'''
xyz = df[['x', 'y', 'z']].to_numpy()
cylinder = cylinder / np.sqrt(np.sum((cylinder**2)))
ans = (xyz * cylinder).sum(axis=1)
return ans
cylinder = least_squares(g, np.ones(3), args=(df2,))
cylinder_axis = cylinder.x / np.sqrt(np.sum((cylinder.x**2)))
#print(cylinder_axis)
def rot(cylinder_axis, df):
print(cylinder_axis)
dfx = df['x']
dfy = df['y']
dfz = df['z']
x, y, z = cylinder_axis
thet1 = -np.arctan(y/x)
x_ = np.cos(thet1)*x - np.sin(thet1)*y
y_ = np.sin(thet1)*x + np.cos(thet1)*y
cylinder_axis = (x_, y_, z)
x, y, z = cylinder_axis
thet2 = -np.arctan(x/z)
z_ = np.cos(thet2)*z - np.sin(thet2)*x
x_ = np.sin(thet2)*z + np.cos(thet2)*x
cylinder_axis = (x_, y, z_)
#print(cylinder_axis)
dfx_ = np.cos(thet1)*dfx - np.sin(thet1)*dfy
dfy_ = np.sin(thet1)*dfx + np.cos(thet1)*dfy
dfx = dfx_
dfy = dfy_
dfz_ = np.cos(thet2)*dfz - np.sin(thet2)*dfx
dfx_ = np.sin(thet2)*dfz + np.cos(thet2)*dfx
dfz = dfz_
dfx = dfx_
df2 = df.copy()
df2['x'] = dfx
df2['y'] = dfy
df2['z'] = dfz
return df2
df2 = rot(cylinder_axis, df2)
def unroll(df):
'''Unrolls a cylinder around the z axis with radius 1'''
df2 = df.copy()
dfx = df['x']
dfy = df['y']
dfz = df['z']
df0 = np.arctan(dfx/dfy) + (dfy < 0)*3.14 + 3.14/2
df2['theta'] = df0
return df2
df2 = unroll(df2)
df2['Position'] = df2['c']
df2['angle'] = df2['theta']
fig = px.scatter(df2.loc[df2['c'] >-1], x='angle', y='z', color='Position',
color_continuous_scale=px.colors.sequential.Rainbow)
fig.update_traces(marker_size = 3)
fig.update_xaxes(range=[0, 6.28], row=1, col=1)
fig.update_layout(
plot_bgcolor='white'
)
fig.show()
if False: # Set this to true to write output
plotly.io.write_image(fig, '2D.png', scale=5)
collect_garbage()
from sklearn.metrics.pairwise import euclidean_distances
XYZ = df[['x', 'y', 'z']].to_numpy()
ed = euclidean_distances(XYZ)
od = euclidean_distances(np.asarray(sorted_locations).reshape([-1, 1]))
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
# "Viridis-like" colormap with white background
white_viridis = LinearSegmentedColormap.from_list('white_viridis', [
(0, '#ffffff'),
(1e-20, '#440053'),
(0.2, '#404388'),
(0.4, '#2a788e'),
(0.6, '#21a784'),
(0.8, '#78d151'),
(1, '#fde624'),
], N=256)
fig = plt.figure(figsize=(6, 6), dpi=90)
ax = fig.add_subplot(1,1,1, projection='scatter_density')
density = ax.scatter_density(od.flatten(), ed.flatten(), cmap=white_viridis)
ax.set_ylim([0,410])
ax.set_xlim([0,410])
ax.set_xlabel('Distance along sequence')
ax.set_ylabel('Embedding space Euclidean distance')
ax.set_box_aspect(1)
fig.colorbar(density, label='Number of points per pixel')
plt.show()
collect_garbage()
fig = plt.figure(figsize=(20, 6), dpi=90)
ax = fig.add_subplot(1,1,1, projection='scatter_density')
density = ax.scatter_density(od.flatten(), ed.flatten(), cmap=white_viridis)
ax.set_ylim([0,410])
ax.set_xlim([0,30000])
ax.set_xlabel('Distance along sequence')
ax.set_ylabel('Embedding space Euclidean distance')
ax.set_box_aspect(1/4)
fig.colorbar(density, label='Number of points per pixel')
plt.show()
collect_garbage()
from Levenshtein import distance
from sklearn.metrics.pairwise import pairwise_distances, paired_distances
from tqdm.auto import tqdm
# This pairwise edit-distance computation takes upwards of 6 hours, so we've cached the results here.
cache_path = DATASET_PATH + "dim3large_20230512a/e200/edit_distances.npy"
try:
    with tf.io.gfile.GFile(cache_path, "rb") as f:
        edit_distances = np.load(f)
except (FileNotFoundError, tf.errors.NotFoundError):  # gfile raises tf.errors.NotFoundError for missing GCS objects
    edit_distances = pairwise_distances(sorted_reads, sorted_reads, metric=distance)
    if False:  # Set this to True to write the cache (requires write permissions)
        with tf.io.gfile.GFile(cache_path, "wb") as f:  # binary mode for np.save
            np.save(f, edit_distances)
fig = plt.figure(figsize=(6, 6), dpi=90)
ax = fig.add_subplot(1,1,1, projection='scatter_density')
density = ax.scatter_density(edit_distances.flatten(), ed.flatten(), cmap=white_viridis)
ax.set_ylim([0,410])
ax.set_xlim([0,410])
ax.set_xlabel('Actual edit distance')
ax.set_ylabel('Embedding space Euclidean distance')
ax.set_box_aspect(1)
fig.colorbar(density, label='Number of points per pixel')
plt.show()
collect_garbage()