Reordering the high-level clusters from seaborn clustermap results

Is there a way to get from a to b in the following figure with scripting? I am using seaborn.clustermap() to get to a (i.e. the order of the rows is preserved; the column order changes only at the second-highest level).

I was wondering whether it is possible to take the seaborn.matrix.ClusterGrid that is returned by seaborn.clustermap(), modify it, and plot the modified result.

P.S. The reason I am asking is that the order has a meaning (first comes blue, next green, and finally red).

Update: Here is a small data set to generate the situation:

import pandas as pd
import seaborn as sns

df = pd.DataFrame([[1, 1.1, 0.9, 1.9, 2, 2.1, 2.8, 3, 3.1],
                   [1.8, 2, 2.1, 0.7, 1, 1.1, 2.7, 3, 3.3]],
                  columns=['d1', 'd2', 'd3',
                           'l3', 'l2', 'l1',
                           'b1', 'b2', 'b3'],
                  index=['p1', 'p2'])

cg = sns.clustermap(df)  # returns a ClusterGrid

We can think of columns starting with b as breakfast, l as lunch, and d as dinner. In the resulting clustermap, the column order is breakfast -> dinner -> lunch. I want to get to breakfast -> lunch -> dinner.



Solution 1:[1]

This is how I solved my problem. It works but it might not be as elegant as one would hope for!

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram

# set the desired order of groups eg: breakfast, lunch, dinner
groups = ['b', 'l', 'd'] 

# reorder the column indices based on the desired order
new_order = []
for group in groups:
    indexes = cg.data2d.columns.str.startswith(group)
    indexes_locs = np.where(indexes)[0].tolist()
    new_order += indexes_locs
    
## reorder df based on the new order
ordered_df = cg.data2d.iloc[:, new_order]

## Run clustermap on the reordered dataframe by disabling 
## the clustering for both rows and columns
ocg = sns.clustermap(ordered_df, 
                     row_cluster=False, 
                     col_cluster=False,
                    );

## draw dendrogram x-axis
axx = ocg.ax_col_dendrogram.axes
axx.clear()

with plt.rc_context({'lines.linewidth': 0.5}):
    
    link = cg.dendrogram_col.linkage ## extract the linkage information

    ## manually inspect the linkage and determine the new desired order
    link[[4, 2]] = link[[2, 4]]  ## swapping the two groups of higher hierarchy
    
    ## draw the dendrogram on the x-axis
    dendrogram(link, 
           color_threshold=0, 
           ax=axx,
           truncate_mode='lastp',
           orientation='top',
           link_color_func=lambda x: 'k'
          );

axx.set_yticklabels(['']*len(axx.get_yticklabels()))
axx.tick_params(color='w')
    
## draw dendrogram y-axis (no change here)
axy = ocg.ax_row_dendrogram.axes
axy.clear()

with plt.rc_context({'lines.linewidth': 0.5}):
    
    ## draw the dendrogram on the y-axis
    dendrogram(cg.dendrogram_row.linkage, 
           color_threshold=0, 
           ax=axy,
           truncate_mode='lastp',
           orientation='left',
           link_color_func=lambda x: 'k',
          );

axy.set_xticklabels(['']*len(axy.get_xticklabels()))
axy.tick_params(color='w')
# axy.invert_yaxis() # we might need to invert y-axis

The resulting heatmap has the columns in the desired order: breakfast -> lunch -> dinner.
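The prefix-based reordering step from Solution 1 can be tried on its own, without any plotting. The following is a minimal sketch: the helper name `order_columns_by_prefix` is hypothetical, and the `clustered` list is an assumed column order standing in for `cg.data2d.columns`.

```python
import numpy as np
import pandas as pd

def order_columns_by_prefix(columns, prefixes):
    """Return integer positions of `columns`, grouped by prefix in the given order."""
    columns = pd.Index(columns)
    order = []
    for prefix in prefixes:
        # Boolean mask of the columns whose name starts with this prefix
        mask = columns.str.startswith(prefix)
        order.extend(np.flatnonzero(mask).tolist())
    return order

# Assumed column order after clustering: breakfast, dinner, lunch
clustered = ['b1', 'b2', 'b3', 'd3', 'd1', 'd2', 'l3', 'l2', 'l1']

new_order = order_columns_by_prefix(clustered, ['b', 'l', 'd'])
reordered = [clustered[i] for i in new_order]
print(reordered)
```

Within each group the columns keep the relative order they had after clustering; only the groups themselves are moved.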

Solution 2:[2]

This solution should be simpler and more elegant. The trick is to understand what is stored in the linkage matrix. With 9 samples (ids 0 to 8), each merge creates a new "sample"/"group" whose id is 9 plus the row index of that merge. In the matrix below, row 0 merges 4 and 5 into group 9, row 1 merges 0 and 1 into group 10, and row 2 merges 2 and 10 into group 11. You just need to find the group id and swap its position in the row where it is merged.
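To see how ids are assigned on a tiny case, here is a minimal sketch using scipy's `linkage` on four made-up 1-D points (the data is illustrative and not from the question). With n original samples, the merge in row i of the linkage matrix forms the cluster with id n + i.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

# Illustrative data: two tight pairs that are far apart from each other
points = np.array([[0.0], [0.1], [5.0], [5.3]])
Z = linkage(points)  # default method is 'single'

n = points.shape[0]
for i, (a, b, dist, count) in enumerate(Z):
    # Columns 0 and 1 hold the ids being merged, column 2 the distance,
    # column 3 the number of original samples in the new cluster
    print(f"row {i}: merge {int(a)} and {int(b)} -> new cluster id {n + i}")
```

The last row merges ids 4 and 5, which are not original samples but the two clusters formed in rows 0 and 1.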

from scipy.cluster.hierarchy import linkage, dendrogram

link = linkage(df.T)
link
_ = dendrogram(link)

array([[ 4.        ,  5.        ,  0.14142136,  2.        ],
       [ 0.        ,  1.        ,  0.2236068 ,  2.        ],
       [ 2.        , 10.        ,  0.2236068 ,  3.        ],
       [ 7.        ,  8.        ,  0.31622777,  2.        ],
       [ 3.        ,  9.        ,  0.31622777,  3.        ],
       [ 6.        , 12.        ,  0.36055513,  3.        ],
       [11.        , 13.        ,  1.28062485,  6.        ],
       [14.        , 15.        ,  1.74642492,  9.        ]])

# swap the two children merged in rows 6 and 7
link[6][[0, 1]] = link[6][[1, 0]]
link[7][[0, 1]] = link[7][[1, 0]]
_ = dendrogram(link)

cg = sns.clustermap(df, col_linkage=link)
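Before plotting, it can help to check which column order a given set of swaps produces. The sketch below hardcodes the linkage matrix printed earlier in this answer and uses scipy's `leaves_list` to read off the left-to-right leaf order. The helpers `swap_children` and `column_order` are hypothetical names introduced for illustration.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list

columns = ['d1', 'd2', 'd3', 'l3', 'l2', 'l1', 'b1', 'b2', 'b3']

# Linkage matrix copied from the printed output of linkage(df.T)
link = np.array([[ 4.,  5., 0.14142136, 2.],
                 [ 0.,  1., 0.2236068 , 2.],
                 [ 2., 10., 0.2236068 , 3.],
                 [ 7.,  8., 0.31622777, 2.],
                 [ 3.,  9., 0.31622777, 3.],
                 [ 6., 12., 0.36055513, 3.],
                 [11., 13., 1.28062485, 6.],
                 [14., 15., 1.74642492, 9.]])

def swap_children(Z, row):
    """Return a copy of Z with the two merged ids in `row` exchanged."""
    Z = Z.copy()
    Z[row, [0, 1]] = Z[row, [1, 0]]
    return Z

def column_order(Z):
    """Column names in the left-to-right leaf order implied by Z."""
    return [columns[i] for i in leaves_list(Z)]

original = column_order(link)
swap_6 = column_order(swap_children(link, 6))
swap_6_and_7 = column_order(swap_children(swap_children(link, 6), 7))

print(original)
print(swap_6)
print(swap_6_and_7)
```

With this matrix, the unmodified linkage gives breakfast -> dinner -> lunch, and swapping row 6 alone gives breakfast -> lunch -> dinner. Swapping row 7 as well flips the top-level split, which moves the breakfast group to the end, so it may be worth checking which swaps give the order you need before passing the result as `col_linkage`.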

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Shidan