How to customize the colorbar of a heatmap in seaborn?

Background: I compared the performance of 13 models by using each of them to make predictions on four data sets. Now I have 4 * 13 R-squared values that indicate the goodness of fit. The problem is that some R-squared values are large and negative, which makes the visualization hard to read. [Figure: the heatmap with a linear-scale color bar]

The positive R-squared values are hard to tell apart because of the large negative values, around -11 or -9.7. How can I stretch the positive range and squeeze the negative range by customizing the color bar? The code and data are as follows.

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
# R^2 scores to visualise: a 4 x 14 array.
# Rows line up with the y tick labels set in comp_heatmap:
#   'lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'.
# Columns are the models, labelled 0-13 on the x axis.
# Columns 1, 8, 10 and 13 are 0 in every row -- presumably placeholders for
# models without a score (TODO confirm with the data source).
# Row 1 contains the large negative values (e.g. -10.63, -9.654) that
# dominate a linear colour scale.
data = np.array([[  0.9848,   0.    ,   0.9504,  -0.8198,   0.9501,   0.9071,
          0.8598,   0.9348,   0.    ,   0.713 ,   0.    ,   0.669 ,
          0.6184,   0.    ],
       [  0.9733,   0.    ,   0.0566,  -9.654 ,   0.1291,  -0.0926,
         -0.0661,  -2.3085,   0.    , -10.63  ,   0.    ,  -3.797 ,
         -7.592 ,   0.    ],
       [  0.9676,   0.    ,   0.9331,   0.9177,   0.9401,   0.9352,
          0.9251,   0.7987,   0.    ,   0.5635,   0.    ,   0.5924,
          0.2456,   0.    ],
       [  0.9759,   0.    ,  -0.114 ,   0.1566,   0.0412,   0.3588,
          0.2605,  -0.5471,   0.    ,   0.2534,   0.    ,   0.5216,
          0.3784,   0.    ]])
def comp_heatmap(ax):
    """Draw ``data`` as an annotated heatmap on *ax* with a linear colour scale.

    This is the starting point that shows the problem: with a linear scale
    the few large negative R^2 values take up most of the colormap, and
    every value above ``vmax=.3`` gets the same top colour.

    The function reads the module-level names ``data``, ``font_text`` and
    ``font_formula``. The two font objects are not created in this snippet;
    they must exist at module level before the call, or a NameError is raised.

    Returns None; *ax* is modified in place.
    """
    with sns.axes_style('white'):
        ax = sns.heatmap(
            data, ax=ax, vmax=.3,
            annot=True,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            # heatmap() adds its own colorbar unless told otherwise; disable it
            # so the explicit figure.colorbar() call below is the only one.
            cbar=False,
        )
    ax.set_xlabel('Model', fontdict=font_text)
    ax.set_ylabel(r'$R^2$', fontproperties=font_formula, labelpad=5)
    ax.figure.colorbar(ax.collections[0])
    # Re-set the ticks so their labels can be replaced. This assumes seaborn's
    # cell-centre tick positions (0.5, 1.5, ...), which astype(int) turns
    # into 0, 1, 2, ...
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int))
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'])


comp_heatmap(ax)


Solution 1:[1]

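I resolved it with matplotlib's `colors.FuncNorm`. It maps the values through a custom forward/inverse function pair, so the negative range is compressed and the positive range is stretched. [Figure: result]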

from matplotlib import pyplot as plt, font_manager as fm, colors


def forward(x):
    """Forward transform for the FuncNorm: return ``base ** x - 1``.

    The transform is exponential and increasing. It stretches the upper end
    of the range and pushes large negative inputs toward -1.

    It reads the module-level ``base``, which must be assigned before the
    norm is evaluated. Works on scalars and float arrays.
    """
    return -1 + base ** x


def inverse(x):
    """Inverse of ``forward``: return ``log(x + 1) / log(base)``.

    This undoes ``base ** x - 1``, letting the norm map colorbar positions
    back to data values. It reads the module-level ``base``. Inputs at or
    below -1 fall outside the domain of the logarithm.
    """
    log_of_base = np.log(base)
    return np.log(1 + x) / log_of_base


def comp_heatmap(ax):
    """Draw ``data`` on *ax* with a non-linear colour scale.

    The colour mapping is a ``matplotlib.colors.FuncNorm`` built from the
    module-level ``forward``/``inverse`` pair (``base ** x - 1`` and its
    inverse) over the range -11 to 1. Because ``forward`` is exponential,
    the negative values are squeezed into roughly ``1/base`` of the
    colormap, and the rest is left for 0..1. This keeps the positive R^2
    values distinguishable.

    Columns 1, 8, 10 and 13 (all zero in ``data``) are masked out and left blank.

    Side effects: changes the global matplotlib font rc settings and the
    subplot margins of *ax*'s figure.

    Relies on the module-level names ``data``, ``base``, ``forward``,
    ``inverse``, ``font_annot`` and ``font_tick``.

    Returns the Axes returned by ``sns.heatmap``.
    """
    plt.rc('font', family='Times New Roman', size=15)
    # Adjust the figure that owns *ax*, not whichever figure happens to be current.
    ax.figure.subplots_adjust(left=0.05, right=1)
    norm = colors.FuncNorm((forward, inverse), vmin=-11, vmax=1)
    # Boolean mask of cells to hide. ``np.bool`` was removed in NumPy 1.24,
    # so build the mask directly with the builtin ``bool`` dtype.
    mask = np.zeros_like(data, dtype=bool)
    mask[:, [1, 8, 10, 13]] = True
    with sns.axes_style('white'):
        # ``norm`` alone defines the colour range. A separate vmin/vmax is
        # either ignored by seaborn or conflicts with the norm in matplotlib,
        # so the old ``vmax=.3`` is not passed.
        ax = sns.heatmap(
            data, ax=ax,
            mask=mask,
            annot=True, fmt='.4',
            annot_kws=font_annot,
            norm=norm,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            cbar=False,
            cmap='rainbow',
        )
    # Draw a single colorbar ourselves (cbar=False above) so its ticks can be
    # placed by hand: one tick for the negative tail, dense ticks over 0..1.
    cbar = ax.figure.colorbar(ax.collections[0])
    cbar.set_ticks([-11, -0.5, 0, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
    # Re-set the tick labels so the custom tick font can be applied.
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int), **font_tick)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['', '', '', ''])
    return ax


# Font settings referenced by the comp_heatmap functions.
font_text = dict(size=22, fontfamily='Times New Roman')
font_annot = dict(size=17, fontfamily='Times New Roman')
font_tick = dict(size=18, fontfamily='Times New Roman')
font_formula = fm.FontProperties(size=22, math_fontfamily='cm')

# Exponent base read by forward()/inverse(). With the norm's range of
# -11..1, roughly 1/base of the colormap goes to the negative values.
base = 5

fig, axes = plt.subplots()
ax = comp_heatmap(axes)

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Jasmine