Skip to content

Commit

Permalink
transcripts finally working
Browse files: browse the repository at this point in the history
  • Loading branch information
quentinblampey committed Sep 6, 2023
1 parent f63e1f1 commit 421a8a3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ show_small.py
archive/
workflow/.snakemake
workflow/logs
*.ignore.*
*ignore*
explore

# OS related
Expand Down
42 changes: 27 additions & 15 deletions sopa/io/explorer/transcripts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import argparse
from math import ceil
from pathlib import Path

import numpy as np
import pandas as pd
import zarr

from ...utils.tiling import Tiles2D


def subsample_indices(n_samples, factor: int = 4):
n_sub = n_samples // factor
Expand All @@ -27,8 +26,10 @@ def write_transcripts(
location = df[[x, y]]
location = np.concatenate([location, np.zeros((num_transcripts, 1))], axis=1)

xmin, xmax = location[:, 0].min(), location[:, 0].max()
ymin, ymax = location[:, 1].min(), location[:, 1].max()
xmax, ymax = location[:, :2].max(axis=0)

assert location[:, 0].min() >= 0
assert location[:, 1].min() >= 0

gene_names = list(df[gene].cat.categories)
num_genes = len(gene_names)
Expand All @@ -39,6 +40,9 @@ def write_transcripts(
uuid = np.stack(
[np.arange(num_transcripts), np.full(num_transcripts, 65535)], axis=1
)
transcript_id = np.stack(
[np.arange(num_transcripts), np.full(num_transcripts, 65535)], axis=1
)
gene_identity = df[gene].cat.codes.values[:, None]
codeword_identity = np.stack(
[gene_identity[:, 0], np.full(num_transcripts, 65535)], axis=1
Expand Down Expand Up @@ -87,21 +91,20 @@ def write_transcripts(
level_group = grids.create_group(level)

tile_size = GRID_SIZE * 2**level
tiles = Tiles2D(xmin, xmax, ymin, ymax, tile_size, 0)

print(f"Level {level}: {len(tiles)} n_tiles, {len(location)} transcripts")
print(f"Level {level}: {len(location)} transcripts")

indices = tiles.coords_to_indices(location[:, :2])
tiles_str_indices = np.apply_along_axis(
lambda l: ",".join(l), 1, indices.astype(str)
)
indices = np.floor(location[:, :2] / tile_size).clip(0).astype(int)
tiles_str_indices = np.array([f"{tx},{ty}" for (tx, ty) in indices])

GRIDS_ATTRS["grid_array_shapes"].append([{}] * len(tiles))
GRIDS_ATTRS["grid_number_objects"].append(len(location))
GRIDS_ATTRS["grid_array_shapes"].append([])
GRIDS_ATTRS["grid_number_objects"].append([])
GRIDS_ATTRS["grid_keys"].append([])

for tx in range(tiles.tile_x.count):
for ty in range(tiles.tile_y.count):
n_tiles_x, n_tiles_y = ceil(xmax / tile_size), ceil(ymax / tile_size)

for tx in range(n_tiles_x):
for ty in range(n_tiles_y):
str_index = f"{tx},{ty}"
loc = np.where(tiles_str_indices == str_index)[0]

Expand All @@ -111,7 +114,9 @@ def write_transcripts(
if n_points_tile == 0:
continue

GRIDS_ATTRS["grid_array_shapes"][-1].append({})
GRIDS_ATTRS["grid_keys"][-1].append(str_index)
GRIDS_ATTRS["grid_number_objects"][-1].append(n_points_tile)

tile_group = level_group.create_group(str_index)
tile_group.array(
Expand Down Expand Up @@ -156,8 +161,14 @@ def write_transcripts(
dtype="uint32",
chunks=chunks,
)
tile_group.array(
"id",
transcript_id[loc],
dtype="uint32",
chunks=chunks,
)

if len(tiles) == 1:
if n_tiles_x * n_tiles_y == 1:
GRIDS_ATTRS["number_levels"] = level + 1
break

Expand All @@ -170,6 +181,7 @@ def write_transcripts(
quality_score = quality_score[sub_indices]
codeword_identity = codeword_identity[sub_indices]
uuid = uuid[sub_indices]
transcript_id = transcript_id[sub_indices]

grids.attrs.put(GRIDS_ATTRS)

Expand Down

0 comments on commit 421a8a3

Please sign in to comment.