How to customize the colorbar of a heatmap in seaborn?
Background: I compared the performance of 13 models by using each of them to make predictions on four data sets. Now I have 4 × 13 R-squared values that indicate the goodness of fit. The problem is that some of the R-squared values are large and negative, which makes the visualization hard to read.

The positive R-squared values are hard to tell apart because of negative values such as -11 or -9.7. How can I expand the positive range and compress the negative range by customizing the colorbar? The code and data are as follows.
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
# A single figure holding one Axes, onto which the question's heatmap is drawn.
fig, ax = plt.subplots(nrows=1, ncols=1)
# R^2 values, 4 rows x 14 columns (column = model index on the x axis).
# comp_heatmap labels the rows 'lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'.
# Columns 1, 8, 10 and 13 hold only zeros; the answer's plot masks them out.
data = np.array([
    [0.9848, 0.0, 0.9504, -0.8198, 0.9501, 0.9071, 0.8598,
     0.9348, 0.0, 0.713, 0.0, 0.669, 0.6184, 0.0],
    [0.9733, 0.0, 0.0566, -9.654, 0.1291, -0.0926, -0.0661,
     -2.3085, 0.0, -10.63, 0.0, -3.797, -7.592, 0.0],
    [0.9676, 0.0, 0.9331, 0.9177, 0.9401, 0.9352, 0.9251,
     0.7987, 0.0, 0.5635, 0.0, 0.5924, 0.2456, 0.0],
    [0.9759, 0.0, -0.114, 0.1566, 0.0412, 0.3588, 0.2605,
     -0.5471, 0.0, 0.2534, 0.0, 0.5216, 0.3784, 0.0],
])
def comp_heatmap(ax, *, text_font=None, formula_font=None):
    """Draw the annotated R^2 heatmap of the module-level ``data`` on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw the heatmap on.
    text_font : dict, optional
        ``fontdict`` for the x-axis label. When omitted, the module-level
        ``font_text`` is used if it exists, otherwise matplotlib's default.
    formula_font : FontProperties, optional
        ``fontproperties`` for the y-axis label. When omitted, the module-level
        ``font_formula`` is used if it exists, otherwise matplotlib's default.

    Returns
    -------
    The Axes returned by ``sns.heatmap``.
    """
    # Resolve fonts lazily. In the original script this function is called
    # before font_text / font_formula are defined, which raised NameError.
    if text_font is None:
        text_font = globals().get('font_text')
    if formula_font is None:
        formula_font = globals().get('font_formula')

    with sns.axes_style('white'):
        ax = sns.heatmap(
            data, ax=ax, vmax=.3,
            annot=True,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            # A colorbar is added explicitly below. Seaborn's default
            # (cbar=True) would draw a second one next to it.
            cbar=False,
        )
    ax.set_xlabel('Model', fontdict=text_font)
    ax.set_ylabel(r'$R^2$', fontproperties=formula_font, labelpad=5)
    ax.figure.colorbar(ax.collections[0])
    # Pin the current tick positions before relabelling them. Presumably the
    # heatmap ticks sit at cell centres (i + 0.5), so astype(int) gives the
    # column index.
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int))
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'])
    return ax
# Render the question's heatmap onto the Axes created by plt.subplots().
comp_heatmap(ax=ax)
Solution 1:[1]
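I resolved it with matplotlib's `colors.FuncNorm`, which maps the data through a custom pair of forward/inverse functions. With the exponential forward function `base ** x - 1` (here `base = 5`, `vmin=-11`, `vmax=1`), values near 1 are spread over most of the colormap, while the large negative values are squeezed into a small part of it.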

from matplotlib import pyplot as plt, font_manager as fm, colors
def forward(x):
    """Forward transform for ``colors.FuncNorm``: return ``base ** x - 1``.

    ``base`` is looked up as a module-level name at call time (the script
    assigns ``base = 5`` before plotting).
    """
    powered = base ** x
    return powered - 1
def inverse(x):
    """Inverse of ``forward``: solve ``y = base ** x - 1`` for ``x``.

    Computes ``log(x + 1) / log(base)``, reading the module-level ``base``
    at call time.
    """
    numerator = np.log(x + 1)
    return numerator / np.log(base)
def comp_heatmap(ax):
    """Draw the R^2 heatmap on *ax* with a colour scale stretched towards 1.

    Values are normalised with
    ``colors.FuncNorm((forward, inverse), vmin=-11, vmax=1)``, so the
    exponential ``forward`` gives most of the colormap to values near 1 and
    squeezes the large negative R^2 values into a small part of it.

    Reads the module-level ``data``, ``font_annot``, ``font_tick``,
    ``forward`` and ``inverse``.

    Side effects: sets the global ``font`` rc parameters and adjusts the
    subplot margins of the figure that owns *ax*.

    Returns the Axes produced by ``sns.heatmap``.
    """
    plt.rc('font', family='Times New Roman', size=15)
    # Act on ax's own figure rather than whichever figure is "current".
    ax.figure.subplots_adjust(left=0.05, right=1)
    norm = colors.FuncNorm((forward, inverse), vmin=-11, vmax=1)

    # Hide columns 1, 8, 10 and 13 (all zeros in ``data``). Build the mask as
    # bool directly: the ``np.bool`` alias was removed in NumPy 1.24.
    mask = np.zeros_like(data, dtype=bool)
    mask[:, [1, 8, 10, 13]] = True

    with sns.axes_style('white'):
        ax = sns.heatmap(
            # vmax is deliberately not passed. The colour limits are defined
            # by ``norm`` (vmin=-11, vmax=1), and a separate vmax=.3
            # contradicted them.
            data, ax=ax,
            mask=mask,
            annot=True, fmt='.4',
            annot_kws=font_annot,
            norm=norm,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            cbar=False,  # the colorbar is added explicitly below
            cmap='rainbow',
        )
    cbar = ax.figure.colorbar(ax.collections[0])
    cbar.set_ticks([-11, -0.5, 0, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

    # Pin the current tick positions before relabelling them.
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int), **font_tick)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    # Row labels are intentionally blank.
    ax.set_yticklabels([''] * len(yticks))
    return ax
# Font settings. comp_heatmap reads font_annot (cell annotations) and
# font_tick (x tick labels).
font_formula = fm.FontProperties(math_fontfamily='cm', size=22)
font_text = dict(size=22, fontfamily='Times New Roman')
font_annot = dict(size=17, fontfamily='Times New Roman')
font_tick = dict(size=18, fontfamily='Times New Roman')

# Base of the exponential scale used by forward() / inverse(); it must be set
# before comp_heatmap runs, because those functions read it at call time.
base = 5

fig, axes = plt.subplots()
ax = comp_heatmap(axes)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Jasmine |
