Skip to content

Commit

Permalink
Fix n_top_genes sorting problem
Browse files Browse the repository at this point in the history
  • Loading branch information
Gökcen Eraslan authored and gokceneraslan committed Apr 30, 2019
1 parent a2f1689 commit 3db4715
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
9 changes: 7 additions & 2 deletions scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ def highly_variable_genes(
df.append(hvg)

df = pd.concat(df, axis=0)
df['highly_variable'] = df['highly_variable'].astype(int)
df = df.groupby('gene').agg({'means': np.nanmean,
'dispersions': np.nanmean,
'dispersions_norm': np.nanmean,
'highly_variable': np.nansum})
df.rename(columns={'highly_variable': 'highly_variable_nbatches'}, inplace=True)
df['highly_variable_intersection'] = df['highly_variable_nbatches'] == len(batches)

if n_top_genes is not None:
# sort genes by how often they selected as hvg within each batch and
Expand Down Expand Up @@ -251,6 +253,7 @@ def highly_variable_genes(
adata.var['dispersions_norm'] = df['dispersions_norm'].values.astype('float32', copy=False)
if batch_key is not None:
adata.var['highly_variable_nbatches'] = df['highly_variable_nbatches'].values
adata.var['highly_variable_intersection'] = df['highly_variable_intersection'].values
if subset:
adata._inplace_subset_var(gene_subset)
else:
Expand All @@ -267,6 +270,8 @@ def highly_variable_genes(
('dispersions_norm', 'float32'),
]
if batch_key is not None:
arrays.append(df['highly_variable_nbatches'].values)
dtypes.append(('highly_variable_nbatches', int))
arrays.extend([df['highly_variable_nbatches'].values,
df['highly_variable_intersection'].values])
dtypes.append([('highly_variable_nbatches', int),
('highly_variable_intersection', np.bool_)])
return np.rec.fromarrays(arrays, dtype=dtypes)
1 change: 1 addition & 0 deletions scanpy/tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_highly_variable_genes_basic():
adata = sc.datasets.blobs()
sc.pp.highly_variable_genes(adata, batch_key='blobs')
assert 'highly_variable_nbatches' in adata.var.columns
assert 'highly_variable_intersection' in adata.var.columns

adata = sc.datasets.blobs()
sc.pp.highly_variable_genes(adata, batch_key='blobs', n_top_genes=3)
Expand Down

0 comments on commit 3db4715

Please sign in to comment.