3. Customizing plots with Matplotlib#

Although seaborn generates good-looking figures out of the box, most of the time we need to make some adjustments. You already saw in the previous chapters that we can relabel axes and adjust their range or ticks by calling matplotlib functions directly on figures generated with seaborn. Here, we will review some basic concepts of matplotlib figures and learn how to adjust some of their elements to create custom figures.

The two most important concepts to be aware of when using matplotlib are the figure and axes objects:

  • axes (also referred to as subplots): the area where we plot the data; has an x- and y-axis, which contain ticks, tick locations, labels, and other elements.

  • figure: the overall window/page where everything is drawn; can contain multiple axes (subplots) organized in the form of a grid.

Before you begin: Be sure to check out the official matplotlib cheat sheet: it outlines all of the most useful functionalities discussed below and more!

3.1. Environment setup#

# for Google Colab
import os
if 'COLAB_JUPYTER_IP' in os.environ:
    !git clone https://github.com/bokulich-lab/DataVisualizationBook.git book

    from book.utils import utils
    utils.ensure_packages('book/requirements.txt')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

%config InlineBackend.figure_format='svg'

3.2. Creating figures and subplots#

Let’s now see how to create a figure and its axes. Note: the following examples don’t show any data on the (sub)plots; they only demonstrate how to create and refer to figures with one or more axes.

Documentation you might want to check for further information:

3.2.1. One Subplot#

fig, ax = plt.subplots(nrows=1, ncols=1);
../_images/03_matplotlib_manipulations_8_0.svg

3.2.2. Many Subplots - Horizontal Layout#

We can create multiple subplots by using the function subplots of the pyplot module in matplotlib. Here we specify the size of the figure by passing a tuple in the form (width, height), in inches, to the figsize parameter. In addition, the nrows and ncols parameters specify how many rows and columns of subplots the figure should have.

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4));
../_images/03_matplotlib_manipulations_11_0.svg

3.2.3. Multiple Subplots - Grid Layout#

Here we can see an example where, instead of specifying the width and height of the figure, we specify the aspect ratio by setting figsize = plt.figaspect(aspect_ratio), where the aspect ratio is the height of the figure divided by its width.

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=plt.figaspect(0.4));
../_images/03_matplotlib_manipulations_14_0.svg
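To see what plt.figaspect actually produces, the short sketch below prints the size it computes for an aspect ratio of 0.4. The exact numbers depend on the default figure height configured in matplotlib, and matplotlib may clamp the size for extreme ratios, so treat the printed values as illustrative; for a moderate ratio such as this one the height-to-width ratio should match the requested value.

```python
import matplotlib.pyplot as plt

# figaspect returns a (width, height) pair in inches whose
# height / width ratio equals the requested aspect ratio
width, height = plt.figaspect(0.4)
print(f"width={width:.2f}, height={height:.2f}, ratio={height / width:.2f}")

# the pair can be unpacked and passed on as an ordinary figsize
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(width, height))
```

Passing the result of plt.figaspect directly to figsize, as in the example above, is equivalent; unpacking it first is only useful when you want to inspect or reuse the numbers.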

3.2.4. Gridspec#

Sometimes we need a grid of plots in which some subplots span multiple rows or columns. The code below shows how to do this using a GridSpec. There are two examples: one where the larger subplot spans several rows (vertical alignment) and one where it spans several columns (horizontal alignment). We first retrieve the GridSpec object underlying the grid by calling get_gridspec() on one of the axes. Next, we remove the subplots occupying the cells that the larger subplot will cover; the for loop selects them by indexing the axes array with the corresponding rows and columns. Then we add the new subplot in place of the removed ones using fig.add_subplot(gs[rows, columns]), where gs is the GridSpec object retrieved at the beginning. Note that the rows and columns used to index gs must match the locations of the axes removed in the for loop. Finally, we use the annotate method to add text to the subplot that spans several rows or columns.

# vertical alignment
fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(10, 4))

# GridSpec object describing the whole 3x3 grid
gs = axes[1, 2].get_gridspec()

for ax in axes[1:, -1]:
    ax.remove()
big_ax = fig.add_subplot(gs[1:, -1])
big_ax.annotate(
    'Big Axes \nGridSpec[1:, -1]', (0.1, 0.5),
    xycoords='axes fraction', va='center'
)

fig.tight_layout();
../_images/03_matplotlib_manipulations_17_0.svg
# horizontal alignment
fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(10,4))

# GridSpec object describing the whole 3x3 grid
gs = axes[1, 2].get_gridspec()

for ax in axes[1, 0:-1]:
    ax.remove()
ax_big = fig.add_subplot(gs[1, :-1])
ax_big.annotate(
    'Big Axes \nGridSpec[1, :-1]', (0.1, 0.5),
    xycoords='axes fraction', va='center'
)

fig.tight_layout();
../_images/03_matplotlib_manipulations_18_0.svg
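As an alternative sketch, not the approach used in this chapter, a similar layout can be built without creating and then removing subplots: start from an empty figure, request a GridSpec with fig.add_gridspec, and add only the subplots that are needed. The 3x3 layout and the variable names (big_ax, small_axes) are chosen here purely for illustration.

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 4))
gs = fig.add_gridspec(nrows=3, ncols=3)

# one tall subplot spanning rows 1-2 of the last column
big_ax = fig.add_subplot(gs[1:, -1])
big_ax.annotate(
    'Big Axes \nGridSpec[1:, -1]', (0.1, 0.5),
    xycoords='axes fraction', va='center'
)

# fill every remaining cell with a regular subplot
small_axes = []
for row in range(3):
    for col in range(3):
        if row >= 1 and col == 2:
            continue  # this cell is already covered by big_ax
        small_axes.append(fig.add_subplot(gs[row, col]))

fig.tight_layout()
```

Which variant to prefer is largely a matter of taste: removing axes is convenient when most of the grid is regular, while building from a GridSpec avoids creating subplots that are thrown away.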

Note

  • axes: the area where we plot the data; has an x- and y-axis, which contain ticks, tick locations, labels, and other elements.

  • figure: the overall window/page where everything is drawn; can contain multiple axes (subplots) organized in the form of a grid.

3.3. Adding data to specific subplots#

When you use seaborn, it will automatically create axes and figures for you so you don’t need to do any of the above. However, when you want to have better control over your plots you may want to first create your own figure with the desired properties and then place specific plots in its axes. To demonstrate, let’s generate some data that we will use in the plots later:

x = np.arange(0.1, 4, 0.1)
df = pd.DataFrame({
    'x': x,
    'y1': np.exp(-1.0 * x),
    'y2': np.exp(-0.5 * x)
})

df.head()
x y1 y2
0 0.1 0.904837 0.951229
1 0.2 0.818731 0.904837
2 0.3 0.740818 0.860708
3 0.4 0.670320 0.818731
4 0.5 0.606531 0.778801

3.3.1. A single plot#

As the code below shows, we first create a figure and its axes and then connect the seaborn plot to them by passing the axes to the ax parameter of the seaborn plotting function.

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4,4))
sns.scatterplot(data=df, x='x', y='y1', ax=axes);
../_images/03_matplotlib_manipulations_25_0.svg

3.3.2. Multiple plots#

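When a figure has more than one subplot, axes becomes a numpy array: one-dimensional when the grid has a single row or a single column (as below, where its shape is (2,)), and two-dimensional with the shape (nrows, ncols) otherwise. In this case, we need to select the subplot we want to use and pass it to the plotting function.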

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

sns.scatterplot(data=df, x='x', y='y1', ax=axes[0])
sns.lineplot(data=df, x='x', y='y2', ax=axes[1])

fig.tight_layout();
../_images/03_matplotlib_manipulations_28_0.svg

Below is an example where axes is a 2D array, so each subplot is selected by its row and column index.

with sns.axes_style("darkgrid"):
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
    sns.scatterplot(data=df, x='x', y='y1', ax=axes[0, 0])
    sns.lineplot(data=df, x='x', y='y2', ax=axes[0, 1])
    sns.lineplot(x=df['x'], y=df['y1']**2, ax=axes[1, 0])
    sns.scatterplot(x=df['x'], y=df['y2']**2, ax=axes[1, 1])

fig.tight_layout();
../_images/03_matplotlib_manipulations_30_0.svg
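The shape of the object returned by plt.subplots can be checked directly. The sketch below compares a few layouts and also shows the squeeze=False option, which always returns a 2D array; the variable names are illustrative only.

```python
import matplotlib.pyplot as plt

fig1, ax_single = plt.subplots(nrows=1, ncols=1)   # a single Axes object, not an array
fig2, axes_row = plt.subplots(nrows=1, ncols=3)    # 1D array
fig3, axes_grid = plt.subplots(nrows=2, ncols=2)   # 2D array
# squeeze=False keeps two dimensions even for a single row
fig4, axes_2d = plt.subplots(nrows=1, ncols=3, squeeze=False)

print(axes_row.shape)   # (3,)
print(axes_grid.shape)  # (2, 2)
print(axes_2d.shape)    # (1, 3)

# .flat iterates over all subplots regardless of the array's dimensionality
for i, ax in enumerate(axes_grid.flat):
    ax.set_title(f'Subplot {i}')
```

Using squeeze=False or .flat can be handy when the number of rows and columns is computed at runtime, because the same indexing or loop then works for every layout.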

3.4. Modifying elements of a plot#

As we have seen so far, we often need to adjust certain elements of a plot. Most of the attributes we are usually interested in can be modified directly on the Axes object (the one that you passed to or received from a seaborn plotting function). Adjustments that concern how subplots relate to one another (such as the spacing between them) or the figure itself (such as a figure title) are set on the Figure object.
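The following minimal sketch illustrates the relationship between these objects: the Axes returned by a seaborn plotting function is the same object that was passed in via ax, and its figure attribute leads back to the Figure. The small data frame is recreated here only to keep the example self-contained, and the label and title texts are placeholders.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

x = np.arange(0.1, 4, 0.1)
df = pd.DataFrame({'x': x, 'y1': np.exp(-1.0 * x)})

fig, ax = plt.subplots()
returned_ax = sns.scatterplot(data=df, x='x', y='y1', ax=ax)

# seaborn hands back the very Axes object it drew on
print(returned_ax is ax)
print(returned_ax.figure is fig)

# Axes-level adjustment
returned_ax.set_ylabel('exp(-x)')
# Figure-level adjustment, reached through the Axes
returned_ax.figure.suptitle('Title set on the Figure')
```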

3.4.1. Axes’ labels and title#

fig, ax = plt.subplots()
sns.scatterplot(data=df, x='x', y='y1', ax=ax)
ax.set_xlabel('This is the new x-label', fontsize=14)
ax.set_ylabel('Y')
ax.set_title('Plot title goes here', fontsize=20);
../_images/03_matplotlib_manipulations_35_0.svg

3.4.2. Axes’ ranges, ticks and tick labels#

fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))

# add space between subplots
plt.tight_layout(pad=4.0)

for i, ax in enumerate(axes):
    sns.scatterplot(data=df, x='x', y='y1', ax=axes[i])

    axes[i].set_xlim((0.5, 2.0))
    axes[i].set_ylim((0.2, 0.6))

    xticks = np.arange(0.5, 2.1, 0.5)
    axes[i].set_xticks(xticks)
    axes[i].tick_params(axis='x', which='major', bottom=True)
    axes[i].tick_params(axis='y', which='major', left=True)
    xtick_labels = [f'x={x}' for x in xticks]

    if i == 0:
        axes[i].set_xticklabels(xtick_labels, rotation='horizontal', fontsize=12)
        axes[i].set_title("Horizontal Labels")
    else:
        axes[i].set_xticklabels(xtick_labels, rotation=45, fontsize=12)
        axes[i].set_title("Rotated Labels")
../_images/03_matplotlib_manipulations_38_0.svg

3.4.3. Axis’ scale#

fig, ax = plt.subplots()
sns.scatterplot(data=df, x='x', y='y1', ax=ax)

# set x-axis to logarithmic scale
ax.set_xscale('log');
../_images/03_matplotlib_manipulations_41_0.svg

3.4.4. Zooming in/out in a plot#

# define the function that will be plotted
def f(x):
    return np.sin(2*np.pi*x) + np.cos(3*np.pi*x)

x = np.arange(0.0, 5.0, 0.01)
axis1 = plt.subplot(212)
axis1.margins() # Default margin is 0.05
axis1.plot(x, f(x))
axis1.set_title('Normal')

axis2 = plt.subplot(221)
# zoom out: pad the data range by 25% on each side of both axes
axis2.margins(0.25, 0.25)
axis2.plot(x, f(x))
axis2.set_title('Zoomed out')

axis3 = plt.subplot(222)
# zoom in: trim 25% of the data range from each side of both axes
axis3.margins(-0.25, -0.25)
axis3.plot(x, f(x))
axis3.set_title('Zoomed in')

plt.tight_layout(pad=2.0);
../_images/03_matplotlib_manipulations_44_0.svg
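To make the effect of margins more concrete, the sketch below plots a simple line from 0 to 10 and prints the x-limits that result from different margin values. It assumes matplotlib's default margin of 0.05 and linear axes; the data and variable names are illustrative only.

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=3, figsize=(10, 3))
margins = [0.05, 0.25, -0.25]  # default, zoomed out, zoomed in

limits = []
for ax, margin in zip(axes, margins):
    ax.plot([0, 10], [0, 10])
    # the margin is a fraction of the data range added to (or, if negative,
    # removed from) each end of the axis
    ax.margins(x=margin, y=margin)
    limits.append(ax.get_xlim())
    ax.set_title(f'margins={margin}')

for margin, (low, high) in zip(margins, limits):
    print(f'margin={margin:+.2f}: x-limits = ({low:.2f}, {high:.2f})')
```

With a data range of 10, a margin of 0.25 extends each end of the axis by 2.5, while a margin of -0.25 cuts 2.5 off each end, which is why negative margins appear as zooming in.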

3.4.5. Color Maps#

There are many ways to specify colors when plotting. You can either use color palettes from the seaborn package or color maps from the matplotlib library. Below you will find some examples of how to apply both. Please consult the respective documentation for more details on both approaches (links are given in the code comments below).

# we can use seaborn's built in color palettes 
# https://seaborn.pydata.org/tutorial/color_palettes.html

with sns.axes_style("darkgrid"):
    fig, ax = plt.subplots()
    
    colors = sns.color_palette("rocket")
    for i in range(1, 5):
        sns.scatterplot(
            x=df['x'], y=df['y1']**i, ax=ax, 
            label=f'i={i}', color=colors[-i]
        )
../_images/03_matplotlib_manipulations_47_0.svg
# alternatively, we can use matplotlib's color maps 
# https://matplotlib.org/stable/tutorials/colors/colormaps.html

import matplotlib as mpl

with sns.axes_style("darkgrid"):
    fig, ax = plt.subplots()
    
    # sample 8 evenly spaced colors from the 'plasma' color map
    colors = mpl.colormaps['plasma'].resampled(8).colors
    for i in range(1, 5):
        sns.scatterplot(
            x=df['x'], y=df['y1']**i, ax=ax, 
            label=f'i={i}', color=colors[i]
        )
../_images/03_matplotlib_manipulations_48_1.svg
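Color maps can also be sampled at arbitrary positions, which is useful when the number of colors needed is only known at runtime. The sketch below assumes Matplotlib 3.5 or newer (for the matplotlib.colormaps registry); the sampling positions and variable names are illustrative choices.

```python
import matplotlib as mpl
import numpy as np
import seaborn as sns

# matplotlib color maps are callable: values in [0, 1] are mapped to RGBA colors
plasma = mpl.colormaps['plasma']
positions = np.linspace(0.1, 0.9, 4)
cmap_colors = plasma(positions)
print(cmap_colors.shape)  # (4, 4): four colors, each with R, G, B, A

# seaborn can return a palette with a requested number of colors
palette_colors = sns.color_palette('rocket', n_colors=4)
print(len(palette_colors))  # 4
```

Either list can then be indexed and passed to the color parameter of a plotting function, as in the examples above.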

3.4.6. Legend position and title#

with sns.axes_style("whitegrid"):
    fig, ax = plt.subplots()
    
    for i in range(1, 5):
        sns.scatterplot(x=df['x'], y=df['y1']**i, ax=ax, label=f'i={i}')
        
    ax.legend(
        loc='center left', bbox_to_anchor=(1, 0.5), 
        fontsize=14, title_fontsize=14
    )
    legend = ax.get_legend()
    legend.set_title('Magical parameter')
../_images/03_matplotlib_manipulations_51_0.svg

3.4.7. Figure Title#

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

sns.scatterplot(data=df, x='x', y='y1', ax=axes[0])
sns.lineplot(data=df, x='x', y='y2', ax=axes[1])

# we can add title to individual subplots
for i, ax in enumerate(axes):
    ax.set_title(f'Measurement {i+1}', fontsize=12)

# but also to the entire figure
fig.suptitle('Important measurements', fontsize=16);
../_images/03_matplotlib_manipulations_54_0.svg

3.4.8. Layout of Subplots#

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

sns.scatterplot(data=df, x='x', y='y1', ax=axes[0])
sns.lineplot(data=df, x='x', y='y2', ax=axes[1])

# we can adjust spacing between individual subplots
fig.tight_layout(w_pad=10);
../_images/03_matplotlib_manipulations_57_0.svg

3.4.9. Axis label position and colorbars#

Here we create a scatterplot from a dataset with four points: (1,1), (2,2), (3,3) and (4,4). The argument c provides a value for each point, which is mapped to a color; the colorbar shows this mapping.

fig, ax = plt.subplots()

sc = ax.scatter([1, 2, 3, 4], [1, 2, 3, 4], c=[2, 2.2, 2.7, 3])
ax.set_ylabel('y')
ax.set_xlabel('x')
cbar = fig.colorbar(sc)
cbar.set_label("z");
../_images/03_matplotlib_manipulations_60_0.svg
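The position of the axis labels themselves can be changed as well. The sketch below reuses the scatterplot from above and moves each label to the end of its axis with the loc parameter; it assumes Matplotlib 3.3 or newer, where this parameter is available.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
sc = ax.scatter([1, 2, 3, 4], [1, 2, 3, 4], c=[2, 2.2, 2.7, 3])

# loc moves the label along its axis:
# 'left' / 'center' / 'right' for x, 'bottom' / 'center' / 'top' for y
ax.set_xlabel('x', loc='right')
ax.set_ylabel('y', loc='top')

# the colorbar label accepts the same parameter
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('z', loc='top')
```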

3.5. Little Plotting Exercise#

Time for a small exercise! It will allow you to practice some of the concepts we have introduced so far. We will use one of seaborn’s built-in datasets to create a simple visualization and customize it to our liking. The dataset contains prices and other properties of approximately 54,000 diamonds.

For the sake of this exercise we are only interested in three columns of this dataset: carat, price and cut. Your task is to:

  • create a square scatter plot depicting the dependence of the price on the carat value

  • adjust the axes labels to be capitalized and with an appropriate font size

  • give the plot an appropriate title

  • if not present, add major ticks to both axes

  • change the transparency of the points to 0.2 (look out for an alpha parameter)

  • change the color of the points to your favourite one

Bonus: Create one plot where diamonds with cut == 'Ideal' are plotted in a different colour than all the other ones. Hint: you can create two dataframes with the two sub-datasets and plot both on the same ax (you can just call the plotting function twice, passing the same ax to both).

diamonds = sns.load_dataset("diamonds")
diamonds.head()
carat cut color clarity depth table price x y z
0 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
1 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
2 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
3 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
4 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75

To reveal the code solution and see what the plot could look like, unfold the cells below:

ideal = diamonds[diamonds['cut'] == 'Ideal']
other = diamonds[diamonds['cut'] != 'Ideal']

with sns.axes_style("white"):
    fig, ax = plt.subplots(figsize=(10, 10))

    sns.scatterplot(data=other, x='carat', y='price', ax=ax, color='darkgrey', alpha=0.2)
    sns.scatterplot(data=ideal, x='carat', y='price', ax=ax, color='royalblue', alpha=0.2)

    ax.set_xlabel('Carat', fontsize=16)
    ax.set_ylabel('Price [USD]', fontsize=16)

    ax.set_title('Diamond prices', fontsize=18)

    ax.tick_params(axis='both', which='major', bottom=True, left=True, labelsize=12)
../_images/03_matplotlib_manipulations_66_0.svg