145 lines
5.0 KiB
Python
145 lines
5.0 KiB
Python
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib import colors
|
|
from matplotlib.ticker import LogFormatterSciNotation, SymmetricalLogLocator, LogLocator
|
|
|
|
|
|
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", logcolor=False, sym_logcolor=False,
            xlabel=None, ylabel=None, **kwargs):
    """
    Create a heatmap from a 2D array and two lists of labels.

    Zero-valued cells are masked: they receive no color, and the axes
    background is hatched so those cells read as "no data".

    Parameters
    ----------
    data
        A 2D array-like of shape (N, M). Plain ndarrays, masked arrays
        (existing masks are kept) and nested lists are accepted.
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. In the
        ``logcolor`` / ``sym_logcolor`` modes these override the default
        ``extend`` / ``ticks`` / ``format`` settings. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    logcolor
        If True, color cells on a logarithmic scale (`colors.LogNorm`) with a
        base-2 ticked colorbar, using the "Reds" colormap unless ``cmap`` is
        given. The non-masked data must be positive. Optional.
    sym_logcolor
        If True (and ``logcolor`` is False), color cells on a symmetric log
        scale (`colors.SymLogNorm`, linthresh=1) using the "RdBu_r" colormap
        unless ``cmap`` is given. Optional.
    xlabel, ylabel
        Axis labels. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`. In the log modes,
        ``vmin`` / ``vmax`` set the limits of the log norm instead of the
        data minimum / maximum.

    Returns
    -------
    im
        The `AxesImage` that draws the cells; its ``norm`` is the one used
        for the colors, so it can be passed to ``annotate_heatmap``.
    cbar
        The colorbar attached to ``ax``.
    """
    # np.ma.asarray accepts lists and keeps any mask the caller already set.
    data = np.ma.asarray(data)
    data = np.ma.masked_where(data == 0, data)

    if ax is None:
        ax = plt.gca()

    # Colorbar defaults per color mode; the caller's cbar_kw is merged on top.
    cbar_opts = {}
    if logcolor or sym_logcolor:
        # The log norm is given to imshow itself (rather than drawing a second,
        # offset pcolor layer) so the cells line up with the ticks and the
        # returned image carries the norm that is actually displayed.
        # imshow rejects vmin/vmax alongside a norm, so fold them into it.
        vmin = kwargs.pop("vmin", data.min())
        vmax = kwargs.pop("vmax", data.max())
        if logcolor:
            kwargs["norm"] = colors.LogNorm(vmin=vmin, vmax=vmax)
            kwargs.setdefault("cmap", "Reds")
            cbar_opts.update(extend="max",
                             ticks=LogLocator(base=2),
                             format=LogFormatterSciNotation(base=2))
        else:
            linthresh = 1.0
            kwargs["norm"] = colors.SymLogNorm(linthresh=linthresh,
                                               linscale=1.0,
                                               vmin=vmin,
                                               vmax=vmax)
            kwargs.setdefault("cmap", "RdBu_r")
            cbar_opts.update(extend="both",
                             ticks=SymmetricalLogLocator(base=2,
                                                         linthresh=linthresh),
                             format=LogFormatterSciNotation(base=2))
    cbar_opts.update(cbar_kw or {})

    # Plot the heatmap and its colorbar.
    im = ax.imshow(data, **kwargs)
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_opts)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape

    # Show every tick, labelled with the matching row/column entry.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)
    # Target this heatmap's axes; plt.tick_params would hit the current axes.
    ax.tick_params(labelsize=6)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    # Turn spines off and create a grid on the cell boundaries.
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.tick_params(which="minor", bottom=False, left=False)

    # Masked (zero) cells are transparent, so this hatch shows through them.
    ax.patch.set(hatch="xx", edgecolor="gray")

    # Grid lines are invisible by default; use linewidth=0.1 for visible cell
    # borders (e.g. when the cells are annotated with annotate_heatmap).
    ax.grid(which="minor", color="w", linestyle='-', linewidth=0)

    return im, cbar
|
|
|
|
|
|
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Annotate each cell of a heatmap with its value.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate (2D array-like; lists are accepted). If None,
        the image's data is used. Masked cells are not annotated. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter` (any callable ``f(value, pos)``).
        Optional.
    textcolors
        A sequence of two color specifications. The first is used for
        values at or below the threshold, the second for those above.
        Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized data maximum
        is used, i.e. the middle of the colormap when the norm spans the data.
        Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to
        create the text labels. A ``color`` given here is overridden by
        ``textcolors``.

    Returns
    -------
    list
        The created `Text` objects, row by row, skipping masked cells.
    """
    if data is None:
        data = im.get_array()
    # asanyarray turns lists into arrays while keeping masked-array masks.
    data = np.asanyarray(data)
    masked = np.ma.getmaskarray(data)

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Center the labels by default, but let textkw override the alignment.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # One Text per unmasked "pixel"; its color depends on the cell value.
    texts = []
    for i, j in np.ndindex(*data.shape):
        if masked[i, j]:
            # Masked cells (e.g. zeros hidden by heatmap) get no label.
            continue
        value = data[i, j]
        kw["color"] = textcolors[int(im.norm(value) > threshold)]
        texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts