Skip to content

Commit

Permalink
use pandas assert to compare data frames for var_df tests and add a…
Browse files Browse the repository at this point in the history
… couple of new tests.
  • Loading branch information
fidelram authored and ivirshup committed Dec 4, 2020
1 parent 081065b commit 2bb76b8
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions scanpy/tests/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,36 +89,45 @@ def test_obs_df(adata):


def test_var_df(adata):
adata.varm["eye"] = np.eye(2)
adata.varm["sparse"] = sparse.csr_matrix(np.eye(2))
adata.varm["eye"] = np.eye(2, dtype=int)
adata.varm["sparse"] = sparse.csr_matrix(np.eye(2), dtype='float64')

assert np.all(
np.equal(
sc.get.var_df(
adata,
keys=["cell2", "gene_symbols"],
varm_keys=[("eye", 0), ("sparse", 1)],
),
pd.DataFrame(
{
"cell2": [1, 1],
"gene_symbols": ["genesymbol1", "genesymbol2"],
"eye-0": [1, 0],
"sparse-1": [0, 1],
},
index=adata.obs_names,
),
)
pd.testing.assert_frame_equal(
sc.get.var_df(
adata,
keys=["cell2", "gene_symbols"],
varm_keys=[("eye", 0), ("sparse", 1)],
),
pd.DataFrame(
{
"cell2": [1, 1],
"gene_symbols": ["genesymbol1", "genesymbol2"],
"eye-0": [1, 0],
"sparse-1": [0.0, 1.0],
},
index=adata.var_names,
),
)
assert np.all(
np.equal(
sc.get.var_df(adata, keys=["cell1", "gene_symbols"], layer="double"),
pd.DataFrame(
{"cell1": [2, 2], "gene_symbols": ["genesymbol1", "genesymbol2"]},
index=adata.obs_names,
),
)
pd.testing.assert_frame_equal(
sc.get.var_df(adata, keys=["cell1", "gene_symbols"], layer="double"),
pd.DataFrame(
{"cell1": [2, 2], "gene_symbols": ["genesymbol1", "genesymbol2"]},
index=adata.var_names,
),
)
# test only cells
pd.testing.assert_frame_equal(
sc.get.var_df(adata, keys=["cell1", "cell2"]),
pd.DataFrame({"cell1": [1, 1], "cell2": [1, 1]}, index=adata.var_names,),
)
# test only var columns
pd.testing.assert_frame_equal(
sc.get.var_df(adata, keys=["gene_symbols"]),
pd.DataFrame(
{"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=adata.var_names,
),
)

badkeys = ["badkey1", "badkey2"]
with pytest.raises(KeyError) as badkey_err:
sc.get.var_df(adata, keys=badkeys)
Expand Down

0 comments on commit 2bb76b8

Please sign in to comment.