Skip to content

Commit

Permalink
Merge pull request rasbt#1022 from rasbt/heatmap-param
Browse files Browse the repository at this point in the history
add heatmap font color threshold parameter
  • Loading branch information
rasbt authored Apr 2, 2023
2 parents 77d0696 + 67b9615 commit 45096c4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Version 0.22.0dev (TBD)
- Internal `fpmax` code improvement that avoids casting a sparse DataFrame into a dense NumPy array. ([#1000](https://github.com/rasbt/mlxtend/pull/1000) via [Tim Kellogg](https://github.com/tkellogg))
- The `plot_decision_regions` function now has a `n_jobs` parameter to parallelize the computation. (In a particular use case, on a small dataset, there was a 21x speed-up (449 seconds vs 21 seconds on local HPC instance of 36 cores). ([#998](https://github.com/rasbt/mlxtend/pull/998) via [Khalid ElHaj](https://github.com/Ne-oL))
- Added `mlxtend.frequent_patterns.hmine` algorithm and documentation for mining frequent itemsets using the H-Mine algorithm. ([#1020](https://github.com/rasbt/mlxtend/pull/1020) via [Fatih Sen](https://github.com/fatihsen20))
- Adds a new `font_color_threshold` parameter to the `mlxtend.plotting.heatmap` function to manually override the threshold for the text annotation color. ([#1022](https://github.com/rasbt/mlxtend/pull/1022))

### Version 0.21.0 (09/17/2022)

Expand Down
35 changes: 18 additions & 17 deletions docs/sources/user_guide/plotting/heatmap.ipynb

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions mlxtend/plotting/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def heatmap(
cell_values=True,
cell_fmt=".2f",
cell_font_size=None,
text_color_threshold=None,
):
"""Plot a heatmap via matplotlib.
Expand Down Expand Up @@ -64,6 +65,11 @@ def heatmap(
cell_font_size : int (default: None)
Font size for cell values (if `cell_values=True`)
text_color_threshold : float (default: None)
Threshold for the black/white text threshold of the text
annotation. Default (None) tried to infer a good threshold
automatically using `np.max(normed_matrix) / 2`.
Returns
-----------
fig, ax : matplotlib.pyplot subplot objects
Expand Down Expand Up @@ -108,6 +114,10 @@ def heatmap(
fig.colorbar(matshow)

normed_matrix = matrix.astype("float") / matrix.max()
if text_color_threshold is None:
thres_value = np.max(normed_matrix) / 2
else:
thres_value = text_color_threshold

if cell_values:
for i in range(matrix.shape[0]):
Expand All @@ -121,9 +131,7 @@ def heatmap(
s=cell_text,
va="center",
ha="center",
color="white"
if normed_matrix[i, j] > np.max(normed_matrix) / 2
else "black",
color="white" if normed_matrix[i, j] > thres_value else "black",
)

if row_names is not None:
Expand Down

0 comments on commit 45096c4

Please sign in to comment.