[1]:
import random
from pathlib import Path
from random import uniform

import dynamo as dyn
import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.stats import ttest_ind
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import seaborn as sns
/opt/anaconda3/envs/graphvelo-release/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[1]:
from graphvelo.graph_velocity import tangent_space_projection
from graphvelo.tangent_space import corr_kernel, cos_corr, density_corrected_transition_matrix
[3]:
# Build dynamo's two-gene bifurcation simulator from its bundled parameter set.
# `tau` is forwarded to the simulator; the log below reports that parameters
# are adjusted based on it.
tau=1
simulator = dyn.sim.BifurcationTwoGenes(dyn.sim.bifur2genes_params, tau=tau)
|-----> The model contains 2 genes and 2 species
|-----> Adjusting parameters based on `r_aug` and `tau`...
|-----> 10 initial conditions have been created by augmentation.
[4]:
# Run the simulation over [0, 40] and keep 2000 sampled cells in an AnnData.
# NOTE(review): no seed is set here, so the sampled cells presumably differ
# between runs — confirm how dynamo seeds its simulator.
simulator.simulate([0, 40], n_cells=2000)
adata = simulator.generate_anndata()
# Copy the 2D state and velocity layers into `obsm` under the 'raw' basis so
# the plotting calls with basis='raw' can find them.
adata.obsm['X_raw'] = adata.layers['total'].copy()
adata.obsm['velocity_raw'] = adata.layers['velocity_T'].copy()
|-----> Sampling 2000 from 35794 simulated data points.
|-----> 2000 cell with 2 genes stored in AnnData.
/opt/anaconda3/envs/graphvelo/lib/python3.8/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
[5]:
# Scatter of the simulated cells in the 'raw' basis, coloured by simulation time.
# The returned axes object is kept in `ax` but not used again in this cell.
ax = dyn.pl.zscatter(adata, basis='raw', color='time', cmap='viridis', save_show_or_return='return')
# Streamline plot of the velocity field in the same basis.
dyn.pl.zstreamline(adata, basis='raw')
../../_images/graphvelo_notebooks_tutorials_figure2_analytical_manifold_4_0.png
[6]:
# adata.write('bif_sim.h5ad')
[7]:
# Reuse the cached simulation when it exists so every downstream figure is
# computed from the same cells; on a fresh checkout (no cache yet) store the
# simulation that was just generated instead of failing on a missing file.
sim_cache = 'bif_sim.h5ad'
if Path(sim_cache).exists():
    adata = dyn.read(sim_cache)
else:
    adata.write(sim_cache)
adata
[7]:
AnnData object with n_obs × n_vars = 2000 × 2
    obs: 'trajectory', 'time'
    var: 'a', 'b', 'S', 'K', 'm', 'n', 'gamma'
    obsm: 'X_raw', 'velocity_raw'
    layers: 'total', 'velocity_T'

Mapping the 2D vector field to 3D

[8]:
def mapping(X, center_x, center_y, radius):
    """Lift 2D points onto the upper half of a sphere.

    The first two columns of ``X`` are passed through unchanged as x and y.
    The z coordinate is the height of the sphere of the given ``radius``
    centred at ``(center_x, center_y, 0)`` above each point; points outside
    the sphere's footprint are clamped to z = 0.

    Returns an array with three columns: x, y, z.
    """
    xs, ys = X[:, 0], X[:, 1]
    # Squared height above the z = 0 plane; negative outside the footprint.
    height_sq = radius**2 - (xs - center_x)**2 - (ys - center_y)**2
    zs = np.sqrt(np.maximum(0, height_sq))
    return np.column_stack((xs, ys, zs))

def mapping_with_velocity(X_t0, X_t1, dt, center_x, center_y, radius):
    """Finite-difference velocity of points lifted with ``mapping``.

    Both ``X_t0`` and ``X_t1`` are lifted to 3D with the same sphere
    parameters, and the result is ``(lifted(X_t1) - lifted(X_t0)) / dt``.
    """
    sphere = (center_x, center_y, radius)
    start_3d = mapping(X_t0, *sphere)
    end_3d = mapping(X_t1, *sphere)
    displacement = end_3d - start_3d
    return displacement / dt

def add_noise(X_3d, V_tangent, epsilon_range=(0, 1), center_x=60, center_y=60):
    """Perturb each velocity along the radial direction of its point.

    For every row, the unit vector from ``(center_x, center_y, 0)`` to the
    point is scaled by ``epsilon * ||V_tangent[i]||`` and added to the
    velocity, where ``epsilon`` is drawn with ``random.uniform`` from
    ``epsilon_range``. One draw is made per row, in row order, from Python's
    global ``random`` generator, so seed that generator for reproducibility.

    Parameters
    ----------
    X_3d : array-like with three columns
        Point coordinates.
    V_tangent : array-like, same shape as ``X_3d``
        Velocities to perturb.
    epsilon_range : tuple
        ``(low, high)`` bounds passed to ``random.uniform``.
    center_x, center_y : float
        x/y coordinates of the centre; its z coordinate is 0.

    Returns
    -------
    np.ndarray
        Perturbed velocities. Integer input is promoted to float so the noise
        is not truncated; floating input keeps its dtype.
    """
    X_3d = np.asarray(X_3d)
    V_tangent = np.asarray(V_tangent)
    if not np.issubdtype(V_tangent.dtype, np.inexact):
        # zeros_like on an integer array would silently truncate the noise.
        V_tangent = V_tangent.astype(float)
    # A float centre keeps `point - center` floating even for integer points,
    # so the normalisation below cannot hit an integer in-place division.
    center = np.array([center_x, center_y, 0], dtype=float)
    noisy_V_tangent = np.zeros_like(V_tangent)

    for i in range(X_3d.shape[0]):
        normal_vector = X_3d[i] - center
        normal_length = np.linalg.norm(normal_vector)
        if normal_length > 0:
            normal_vector = normal_vector / normal_length
        # else: the point sits on the centre, there is no radial direction and
        # the zero vector is kept so the velocity is left unperturbed (not NaN).

        tangent_magnitude = np.linalg.norm(V_tangent[i])

        # Always draw, even for degenerate rows, so the random stream consumed
        # per row is independent of the data.
        epsilon = uniform(*epsilon_range)
        noise_vector = epsilon * tangent_magnitude * normal_vector

        noisy_V_tangent[i] = V_tangent[i] + noise_vector

    return noisy_V_tangent
[9]:
# Sphere the 2D simulation is lifted onto, and the step used for the
# finite-difference velocity.
center_x, center_y = 60, 60
r  = 70
dt = 1
X = adata.layers['total']
V = adata.layers['velocity_T']
X_3d = mapping(X, center_x, center_y, radius=r)
# Reference velocity: difference of the lifted positions at X and X + V*dt.
V_tangent = mapping_with_velocity(X, X+V*dt, dt, center_x, center_y, r)
# add_noise draws from Python's global `random` generator; seed it so the
# noisy field, and every metric and figure derived from it, is reproducible.
random.seed(0)
V_noisy = add_noise(X_3d, V_tangent, epsilon_range=(0, 1), center_x=center_x, center_y=center_y)
[10]:
# Vectors from the sphere centre to every lifted cell (unnormalised normals).
normals = X_3d - np.array([center_x, center_y, 0])
# Component of each velocity along its normal, (v·n / n·n) n, evaluated for
# all rows at once; n·n is shared by both fields so it is computed once.
normals_sq = np.sum(normals * normals, axis=1)
V_N = (np.sum(V_tangent * normals, axis=1) / normals_sq)[:, None] * normals
V_N_noise = (np.sum(V_noisy * normals, axis=1) / normals_sq)[:, None] * normals
[11]:
def project_velocity(X_embedding, T=None) -> np.ndarray:
    """Weighted sum of neighbour displacements for every row of ``T``.

    For each cell ``i`` the result is ``sum_j T[i, j] * (X[j] - X[i])`` over
    the stored (non-zero) entries of row ``i``. NaN displacement entries are
    treated as zero.

    Parameters
    ----------
    X_embedding : np.ndarray
        Coordinates, one row per cell.
    T : sparse matrix or array-like
        Square weight matrix. Any SciPy sparse format or a dense array is
        accepted and converted to CSR; the argument is required despite the
        ``None`` default kept for signature compatibility.

    Returns
    -------
    np.ndarray
        Array of shape ``(T.shape[0], X_embedding.shape[1])``.

    Raises
    ------
    ValueError
        If ``T`` is not provided.
    """
    if T is None:
        raise ValueError("project_velocity requires a transition matrix `T`.")
    # Row access below reads CSR internals, so normalise the storage format
    # first; other formats expose `.indices` with a different meaning.
    T = sp.csr_matrix(T)
    n = T.shape[0]
    delta_X = np.zeros((n, X_embedding.shape[1]))
    indptr, indices, weights = T.indptr, T.indices, T.data
    for i in range(n):
        # Stored entries of row i live in [indptr[i], indptr[i + 1]).
        start, stop = indptr[i], indptr[i + 1]
        diff_emb = X_embedding[indices[start:stop]] - X_embedding[i, None]
        nan_mask = np.isnan(diff_emb)
        if nan_mask.any():
            diff_emb[nan_mask] = 0
        delta_X[i] = weights[start:stop].dot(diff_emb)

    return delta_X
[12]:
# 15-nearest-neighbour search among the lifted 3D points. The query set is the
# training set, so each point's own index is normally among its neighbours.
nbrs = NearestNeighbors(n_neighbors=15).fit(X_3d)
# `dist`: neighbour distances, `ind`: neighbour indices (one row per cell).
dist, ind = nbrs.kneighbors(X_3d)
[13]:
# GraphVelo implementation
P = corr_kernel(X_3d, V_noisy, ind, corr_func=cos_corr)
P_dc = density_corrected_transition_matrix(P).A
T = tangent_space_projection(X_3d, V_noisy, P_dc, ind, b=0)
T = sp.csr_matrix(T)
V_p = project_velocity(X_3d, T)
V_cos = project_velocity(X_3d, sp.csr_matrix(P_dc))
Learning Phi in tangent space projection.: 100%|██████████| 2000/2000 [00:02<00:00, 989.53it/s]
[14]:
def add_significance(ax, left: int, right: int, significance: str, level: int = 0, max_y = None, **kwargs):
    """Draw a significance bracket between two x positions and label it.

    Parameters
    ----------
    ax
        Axes to draw on; its current y-limits set the vertical scale.
    left, right
        x positions of the two bracket legs.
    significance
        Text centred above the bracket.
    level
        Stacking index; each level lifts the bracket by 7% of the y-range.
    max_y
        If given, the y-range is measured from the axis bottom to ``max_y``
        instead of to the axis top.
    **kwargs
        ``bracket_level`` (fraction of the axis top where the bar sits,
        default 0.8), ``bracket_height`` (leg length as a fraction of the
        y-range, default 0.02) and ``text_height`` (gap between bar and text
        as a fraction of the y-range, default 0.005) are consumed here; every
        remaining keyword is forwarded to ``ax.plot``.
    """
    bar_frac = kwargs.pop("bracket_level", 0.8)
    leg_frac = kwargs.pop("bracket_height", 0.02)
    text_frac = kwargs.pop("text_height", 0.005)

    y_low, y_high = ax.get_ylim()
    if max_y is None:
        y_span = y_high - y_low
    else:
        y_span = max_y - y_low

    # y of the horizontal bar, and of the lower end of the two legs.
    bar_y = (y_span * 0.07 * level) + y_high * bar_frac
    leg_y = bar_y - (y_span * leg_frac)

    xs = [left, left, right, right]
    ys = [leg_y, bar_y, bar_y, leg_y]
    ax.plot(xs, ys, **kwargs)

    label_x = (left + right) * 0.5
    label_y = bar_y + (y_span * text_frac)
    ax.text(label_x, label_y, significance, ha='center', va='bottom', c='k')

def get_significance(pvalue, thresholds=(0.001, 0.01, 0.1)):
    """Translate a p-value into a star label.

    Parameters
    ----------
    pvalue : float
        P-value to classify.
    thresholds : sequence of three floats, ascending
        Exclusive upper bounds for ``"***"``, ``"**"`` and ``"*"``. The
        defaults reproduce the cut-offs this notebook has always used; note
        that the loosest default (0.1) is more lenient than the common 0.05,
        so pass ``thresholds=(0.001, 0.01, 0.05)`` for that convention.

    Returns
    -------
    str
        ``"***"``, ``"**"``, ``"*"`` or ``"n.s."``. A NaN p-value compares
        false against every bound and therefore yields ``"n.s."``.
    """
    for cutoff, label in zip(thresholds, ("***", "**", "*")):
        if pvalue < cutoff:
            return label
    return "n.s."
[15]:
# Fixed colour per method, keyed by the category labels used in the
# comparison DataFrames below.
palette = {"GraphVelo": "#f15a24", "cosine kernel": "#ab99e7"}
[16]:
# Component of each projected velocity along the sphere normal, (v·n / n·n) n,
# computed for all cells at once rather than row by row.
normals_sq = np.sum(normals * normals, axis=1)
V_Np = (np.sum(V_p * normals, axis=1) / normals_sq)[:, None] * normals
V_Ncos = (np.sum(V_cos * normals, axis=1) / normals_sq)[:, None] * normals
# The magnitudes feed both the plot and the test, so compute them once.
normal_mag_p = np.linalg.norm(V_Np, axis=1)
normal_mag_cos = np.linalg.norm(V_Ncos, axis=1)
df = pd.DataFrame({
    'Normal components': list(normal_mag_p) + list(normal_mag_cos),
    'Velocity': ['GraphVelo']*V_p.shape[0] + ['cosine kernel']*V_p.shape[0]
    })
fig, ax = plt.subplots(figsize=(1.5, 2.5))
# Route the palette through `hue` with the legend suppressed: passing
# `palette` without `hue` is deprecated in seaborn and removed in v0.14.
sns.violinplot(df, x='Velocity', y='Normal components', hue='Velocity', palette=palette, legend=False, ax=ax)
plt.xticks(rotation = 45, ha='right')
# Leave a little headroom above the violins for the significance bracket.
y_min, y_max = ax.get_ylim()
ax.set_ylim([y_min, y_max + 0.02])
# plt.ylim(0, 0.5)
# One-sided Welch t-test: are GraphVelo's normal components smaller?
ttest_res = ttest_ind(normal_mag_p, normal_mag_cos, equal_var=False, alternative="less")
significance = get_significance(ttest_res.pvalue)
add_significance(
    ax=ax, left=0, right=1, significance=significance, lw=1, bracket_level=1.1, c="k", level=0
)
# plt.savefig(dpi=300, transparent=True, fname=f'figures/normal_components.pdf', bbox_inches = "tight")
plt.show()
/var/folders/zr/17m1_r_x57q5mxr1zs8px9rh0000gn/T/ipykernel_96564/1376135066.py:11: FutureWarning:

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(df, x='Velocity', y='Normal components', palette=palette, ax=ax,)
../../_images/graphvelo_notebooks_tutorials_figure2_analytical_manifold_16_1.png
[17]:
def _rowwise_cosine(A, B):
    """Cosine similarity between matching rows of ``A`` and ``B``.

    Zero-length rows are given a norm of 1 so their similarity is 0 rather
    than NaN, matching what the diagonal of sklearn's ``cosine_similarity``
    yields, without building the full pairwise matrix.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    norm_a = np.linalg.norm(A, axis=1)
    norm_b = np.linalg.norm(B, axis=1)
    denom = np.where(norm_a == 0, 1.0, norm_a) * np.where(norm_b == 0, 1.0, norm_b)
    return np.sum(A * B, axis=1) / denom

# Per-cell agreement with the reference field V_tangent.
sim = _rowwise_cosine(V_p, V_tangent)
sim_cos = _rowwise_cosine(V_cos, V_tangent)
# Per-cell RMSE over the coordinate axes: ||v - v_ref|| / sqrt(n_dims).
n_dims = V_p.shape[1]
rmse = np.linalg.norm(V_p - V_tangent, axis=1) / np.sqrt(n_dims)
rmse_cos = np.linalg.norm(V_cos - V_tangent, axis=1) / np.sqrt(n_dims)
[18]:
# Long-format table: one row per cell and method.
df = pd.DataFrame(
    {
        "cosine similarity": sim.tolist()+sim_cos.tolist(),
        "Model": ["GraphVelo"] * len(sim) + ["cosine kernel"] * len(sim_cos)
    }
)
sns.set_style(style="whitegrid")
fig, ax = plt.subplots(figsize=(1, 1.8), dpi=300)

# Route the palette through `hue` with the legend suppressed: passing
# `palette` without `hue` is deprecated in seaborn and removed in v0.14.
sns.violinplot(data=df, x="Model", y="cosine similarity", hue="Model", palette="colorblind", legend=False, ax=ax)

# One-sided Welch t-test: is GraphVelo's cosine similarity larger?
ttest_res = ttest_ind(sim, sim_cos, equal_var=False, alternative="greater")
significance = get_significance(ttest_res.pvalue)
add_significance(
    ax=ax, left=0, right=1, significance=significance, lw=1, bracket_level=1.1, c="k", level=0
)

# Extend the upper limit slightly so the bracket label is not clipped.
y_min, y_max = ax.get_ylim()
ax.set_ylim([y_min, y_max + 0.02])

ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels([0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks(rotation = 45, ha='right') # Rotates X-Axis Ticks by 45-degrees
# plt.savefig(dpi=300, transparent=True, fname=f'./figures/cosine_sim.pdf', bbox_inches = "tight")
plt.show()
/var/folders/zr/17m1_r_x57q5mxr1zs8px9rh0000gn/T/ipykernel_96564/1978391663.py:10: FutureWarning:

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(data=df, x="Model", y="cosine similarity", palette="colorblind", ax=ax)
../../_images/graphvelo_notebooks_tutorials_figure2_analytical_manifold_18_1.png
[19]:
# Long-format table: one row per cell and method.
df_rmse = pd.DataFrame(
    {
        "RMSE": rmse.tolist()+rmse_cos.tolist(),
        "Model": ["GraphVelo"] * len(rmse) + ["cosine kernel"] * len(rmse_cos)
    }
)
sns.set_style(style="whitegrid")
fig, ax = plt.subplots(figsize=(1, 1.8), dpi=300)

# Route the palette through `hue` with the legend suppressed: passing
# `palette` without `hue` is deprecated in seaborn and removed in v0.14.
sns.violinplot(data=df_rmse, x="Model", y="RMSE", hue="Model", palette="colorblind", legend=False, ax=ax)

# One-sided Welch t-test: is GraphVelo's RMSE smaller?
ttest_res = ttest_ind(rmse, rmse_cos, equal_var=False, alternative="less")
significance = get_significance(ttest_res.pvalue)
add_significance(
    ax=ax, left=0, right=1, significance=significance, lw=1, bracket_level=1.1, c="k", level=0
)

# Extend the upper limit slightly so the bracket label is not clipped.
y_min, y_max = ax.get_ylim()
ax.set_ylim([y_min, y_max + 0.02])

plt.xticks(rotation = 45, ha='right') # Rotates X-Axis Ticks by 45-degrees
# plt.savefig(dpi=300, transparent=True, fname=f'./figures/rmse.pdf', bbox_inches = "tight")
plt.show()
/var/folders/zr/17m1_r_x57q5mxr1zs8px9rh0000gn/T/ipykernel_96564/2143150486.py:10: FutureWarning:

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(data=df_rmse, x="Model", y="RMSE", palette="colorblind", ax=ax)
../../_images/graphvelo_notebooks_tutorials_figure2_analytical_manifold_19_1.png
[20]:
def plot_3d_vector_field(X, V, c, radius, downsample, scale=0.1, norm_v=True, size=5, line_width=4, cone_size=1, add_surface=True, surface_alpha=0.6, camera_view=(1, 1, 1), center_x=60, center_y=60):
    """Plot points on a spherical cap together with an arrow field in Plotly.

    Parameters
    ----------
    X : np.ndarray with three columns
        Point coordinates.
    V : np.ndarray, same shape as ``X``
        Vector at each point.
    c : array-like
        Per-point scalar mapped to marker colour with the 'plasma' colormap.
    radius : float
        Radius of the sphere drawn as the background surface.
    downsample : int
        Stride; every ``downsample``-th point gets an arrow.
    scale : float
        Multiplier applied to ``V`` (after optional normalisation).
    norm_v : bool
        If true, L2-normalise ``V`` column-wise and then row-wise first.
    size, line_width, cone_size
        Marker size, arrow shaft width and Plotly cone ``sizeref``.
    add_surface : bool
        Whether to draw the sphere surface over the x/y extent of ``X``.
    surface_alpha : float
        Surface opacity.
    camera_view : sequence of three floats
        Camera eye position (x, y, z).
    center_x, center_y : float
        x/y of the sphere centre (z is 0). These are explicit parameters so
        the function no longer depends on notebook globals; the defaults match
        ``add_noise``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    import matplotlib
    import matplotlib.colors as mcolors

    norm = mcolors.Normalize(vmin=np.min(c), vmax=np.max(c))
    cmap = matplotlib.colormaps['plasma']
    colors_hex = [mcolors.to_hex(rgba) for rgba in cmap(norm(c))]

    if norm_v:
        from sklearn.preprocessing import normalize
        V = normalize(V, axis=0, norm="l2")
        V = normalize(V, axis=1, norm="l2")
    V = scale * V
    fig = go.Figure()

    # Surface
    if add_surface:
        x = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)
        y = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)
        x_grid, y_grid = np.meshgrid(x, y)
        height_sq = radius**2 - (x_grid - center_x)**2 - (y_grid - center_y)**2
        # Grid nodes outside the sphere's footprint have no height. Mark them
        # NaN up front (still leaving a gap in the surface) instead of taking
        # the sqrt of a negative number, which raised a RuntimeWarning.
        z_grid = np.sqrt(np.where(height_sq >= 0, height_sq, np.nan))
        fig.add_surface(x=x_grid, y=y_grid, z=z_grid, colorscale=[[0, '#99BADF'], [1, '#99BADF']], showscale=False, opacity=surface_alpha)

    fig.add_trace(go.Scatter3d(
        x=X[:, 0], y=X[:, 1], z=X[:, 2],
        mode='markers',
        marker=dict(size=size, color=colors_hex, opacity=0.5),
    ))

    starts = X[::downsample]
    tips = (X + V)[::downsample]

    # All arrow shafts go into one trace: rows are (start, tip, NaN) triples,
    # and the NaN row breaks the line between consecutive arrows. Adding one
    # trace per arrow is very slow for thousands of cells.
    shafts = np.full((3 * len(starts), 3), np.nan)
    shafts[0::3] = starts
    shafts[1::3] = tips
    fig.add_trace(go.Scatter3d(
        x=shafts[:, 0], y=shafts[:, 1], z=shafts[:, 2],
        mode='lines',
        line=dict(color='black', width=line_width),
        connectgaps=False,
        showlegend=False
    ))

    # Arrow heads at the shaft tips, pointing along V.
    fig.add_trace(go.Cone(
        x=tips[:, 0],
        y=tips[:, 1],
        z=tips[:, 2],
        u=V[::downsample, 0],
        v=V[::downsample, 1],
        w=V[::downsample, 2],
        colorscale=[[0, 'black'], [1, 'black']],
        sizeref=cone_size,
        # opacity=1.0,
        showscale=False
    ))

    eye = dict(x=camera_view[0], y=camera_view[1], z=camera_view[2])
    hidden_axis = dict(title='', showticklabels=False)

    fig.update_layout(
        showlegend=False,
        scene=dict(
            xaxis=hidden_axis,
            yaxis=hidden_axis,
            zaxis=hidden_axis,
            aspectratio=dict(x=1, y=1, z=0.5),
            camera=dict(eye=eye),
        )
    )
    return fig
[21]:
# Noisy input field on the sphere, coloured by simulation time; every second
# cell gets an arrow (downsample=2).
fig = plot_3d_vector_field(X_3d, V_noisy, adata.obs['time'].values, radius=r, size=5, scale=5, downsample=2, line_width=5, cone_size=3, camera_view=(0.2, 0.5, 0.5)) # (0.3, 0.8, 0.6)
fig.show()
/var/folders/zr/17m1_r_x57q5mxr1zs8px9rh0000gn/T/ipykernel_96564/1669440842.py:23: RuntimeWarning:

invalid value encountered in sqrt

Data type cannot be displayed: application/vnd.plotly.v1+json

[28]:
# GraphVelo-projected field drawn with the same camera and arrow density as
# the noisy-field plot; only cone_size differs (5 here, 3 there).
fig = plot_3d_vector_field(X_3d, V_p, adata.obs['time'].values, radius=r, size=5, scale=5, downsample=2, line_width=5, cone_size=5, camera_view=(0.2, 0.5, 0.5))
fig.show()
/var/folders/zr/17m1_r_x57q5mxr1zs8px9rh0000gn/T/ipykernel_96564/1669440842.py:23: RuntimeWarning:

invalid value encountered in sqrt

Data type cannot be displayed: application/vnd.plotly.v1+json