Skip to content

Commit

Permalink
Merge branch 'main' of github.com:theislab/scib
Browse files Browse the repository at this point in the history
  • Loading branch information
mumichae committed Feb 1, 2023
2 parents d53b17e + d15ffc4 commit 98ee9a2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
24 changes: 18 additions & 6 deletions scib/metrics/isolated_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def isolated_labels_asw(
)
"""

return isolated_labels(
adata,
label_key=label_key,
Expand Down Expand Up @@ -156,9 +157,18 @@ def isolated_labels(

# 2. compute isolated label score for each isolated label
scores = {}
if not cluster:
adata.obs["silhouette_temp"] = silhouette_samples(
adata.obsm[embed], adata.obs[label_key]
)
for label in isolated_labels:
score = score_isolated_label(
adata, label_key, label, embed, cluster, verbose=verbose
adata,
label_key,
label,
embed,
cluster,
verbose=verbose,
)
scores[label] = score
scores = pd.Series(scores)
Expand Down Expand Up @@ -225,11 +235,13 @@ def max_f1(adata, label_key, cluster_key, label, argmax=False):
score = max_f1(adata, label_key, iso_label_key, isolated_label, argmax=False)
else:
# ASW score between isolated label vs rest
adata.obs[iso_label_key] = adata.obs[label_key] == isolated_label
adata.obs["silhouette_temp"] = silhouette_samples(
adata.obsm[embed], adata.obs[iso_label_key]
)
score = adata.obs[adata.obs[iso_label_key]].silhouette_temp.mean()

if "silhouette_temp" not in adata.obs:
adata.obs["silhouette_temp"] = silhouette_samples(
adata.obsm[embed], adata.obs[label_key]
)
# aggregate silhouette scores for isolated label only
score = adata.obs[adata.obs[label_key] == isolated_label].silhouette_temp.mean()

if verbose:
print(f"{isolated_label}: {score}")
Expand Down
19 changes: 16 additions & 3 deletions tests/metrics/test_isolated_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _random_embedding(partition):
embedding = OneHotEncoder().fit_transform(
LabelEncoder().fit_transform(partition)[:, None]
)
embedding = embedding + np.random.uniform(-0.1, 0.1, embedding.shape)
embedding = embedding + np.random.uniform(-0.001, 0.001, embedding.shape)
# convert to numpy array
embedding = np.asarray(embedding)
return embedding
Expand All @@ -43,10 +43,10 @@ def test_isolated_labels_ASW(adata_pca):
verbose=True,
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.1938440054655075, diff=1e-3)
assert_near_exact(score, 0.14066337049007416, diff=1e-3)


def test_isolated_labels_perfect(adata_pca):
def test_isolated_labels_f1_perfect(adata_pca):
adata_pca.obsm["X_emb"] = _random_embedding(partition=adata_pca.obs["celltype"])
score = scib.me.isolated_labels_f1(
adata_pca,
Expand All @@ -57,3 +57,16 @@ def test_isolated_labels_perfect(adata_pca):
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 1, diff=1e-12)


def test_isolated_labels_asw_perfect(adata_pca):
    """Check that the isolated-label ASW is approximately 1 on a label-derived embedding.

    The embedding stored under ``obsm["X_emb"]`` is produced by
    ``_random_embedding`` from the ``celltype`` column, and the resulting
    score is compared against 1 with ``diff=1e-2``.
    """
    # Build the embedding from the cell-type partition itself.
    celltype_partition = adata_pca.obs["celltype"]
    label_embedding = _random_embedding(partition=celltype_partition)
    adata_pca.obsm["X_emb"] = label_embedding

    score = scib.me.isolated_labels_asw(
        adata_pca, label_key="celltype", batch_key="batch", embed="X_emb", verbose=True
    )

    LOGGER.info(f"score: {score}")
    assert_near_exact(score, 1, diff=1e-2)
2 changes: 1 addition & 1 deletion tests/metrics/test_silhouette_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def test_isolated_labels_silhouette(adata_pca):
verbose=True,
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.1938440054655075, diff=1e-3)
assert_near_exact(score, 0.14066337049007416, diff=1e-3)

0 comments on commit 98ee9a2

Please sign in to comment.