Making Figures in Python

by Jerry Chen

A good data visualization turns data into a compelling story by translating numbers into figures that are easy to understand.

Matplotlib is a plotting library that helps researchers visualize their data in many different ways, including line plots, histograms, bar charts, pie charts, scatter plots, stream plots, and simple 3-D plots. It also supports TeX expressions for mathematical notation and works in Python scripts, Jupyter notebooks, and web application servers. Its interface is similar to MATLAB's, so the transition is minimal for MATLAB users. Matplotlib is well documented, and the official website provides tutorials and examples.

Why matplotlib?

  1. It is free and open source. It does not ship with Python itself, but it is bundled with scientific distributions such as Anaconda.
  2. Coding a figure takes more time, but a script guarantees that you generate the same plot with the same settings every time.
  3. If you already process your data in Python, you can stay in the same environment.

Installing

As described in our tutorial on installing Python packages, you can install Matplotlib with conda by entering the following line into your console:

conda install matplotlib

To test whether the installation succeeded, try importing the library. If you get no error messages, you are ready to go!

import matplotlib.pyplot as plt
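As a small additional check, you can also print the installed version. This is only a sketch of one way to confirm which release you have; the version string you see will depend on your environment.

```python
import matplotlib

# Print the installed version to confirm that the import works
print(matplotlib.__version__)
```

Knowing the version helps when an example from the web uses an option that was renamed or removed in a newer release.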

Workflow

  1. Prepare your data: usually stored in a list or NumPy array.
  2. Create the plot/sub-plot.
  3. Customize your plot: adjust the linewidth, color, marker, axes, etc.
  4. Show/Save plot
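The four steps above can be sketched with sub-plots as follows. This is a minimal illustration rather than the only way to do it: the variable names and the output file name workflow_demo.png are assumptions made for this example, and it uses the axes-based interface returned by plt.subplots.

```python
import numpy as np
import matplotlib.pyplot as plt

# Step 1. Prepare your data
x = np.linspace(0, 2 * np.pi, 100)
y_sin, y_cos = np.sin(x), np.cos(x)

# Step 2. Create a figure with two sub-plots side by side
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))
ax_left.plot(x, y_sin)
ax_right.plot(x, y_cos)

# Step 3. Customize each sub-plot
ax_left.set_title("sin(x)")
ax_right.set_title("cos(x)")
for ax in (ax_left, ax_right):
    ax.set_xlabel("x")

# Step 4. Save the figure (or call plt.show() to display it instead)
fig.tight_layout()
fig.savefig("workflow_demo.png")  # hypothetical file name
```

Each sub-plot is an independent axes object, so steps 2 and 3 can be repeated per panel without affecting the others.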

Simple Example

This example is taken from the article Ten simple rules for better figures, where it illustrates Rule 5: Do Not Trust the Defaults. It is simple enough to follow and demonstrates common tweaks for making professional figures for your research publications.

Default plot

Step 0. Import libraries

import numpy as np
import matplotlib.pyplot as plt

Step 1. Prepare your data

# Create a mesh with 256 points from -Pi to Pi
x = np.linspace(-np.pi, np.pi, 256)

# Find the y values accordingly
y1 = np.sin(x)
y2 = np.cos(x)

Step 2. Create the plot

plt.plot(x,y1)
plt.plot(x,y2)

Step 3. Customize your plot: since this is the default plot, there is nothing to do here.

Step 4. Show/Save plot

plt.show()

Professional Plot

Step 0. Import libraries

import numpy as np
import matplotlib.pyplot as plt

Step 1. Prepare your data

# Create a mesh with 256 points from -Pi to Pi
x = np.linspace(-np.pi, np.pi, 256)

# Find the y values accordingly - you can also assign y1 and y2 in a single line
y1, y2 = np.sin(x), np.cos(x)

Step 2 & 3. Create and customize your plot

# Setting the figure size and resolution
fig = plt.figure(figsize=(10, 6), dpi=300)

# Setting the color, linewidth, and linestyle
plt.plot(x, y1, color="red",  linewidth=2.5, linestyle="-")
plt.plot(x, y2, color="blue", linewidth=2.5, linestyle="-")

# Setting the boundaries of the figure
plt.xlim(x.min()*1.1, x.max()*1.1)
plt.ylim(y2.min()*1.1, y2.max()*1.1)

# Changing spine style
ax = plt.gca()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data', 0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data', 0))

# Use Latex to set tick labels
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])
plt.yticks([-1, 0, +1], [r'$-1$', r'$0$', r'$+1$'])
plt.xticks(fontsize=14, rotation=0)
plt.yticks(fontsize=14, rotation=0)

# Adding legends: draw the curves again with labels so plt.legend can list them
plt.plot(x, y1, color="red",  linewidth=2.5, linestyle="-", label="sin")
plt.plot(x, y2, color="blue", linewidth=2.5, linestyle="-", label="cos")
plt.legend(loc='upper left', prop={'size': 16}, frameon=False)

# Mark the point on the sine curve at t = 2*pi/3
t = 2*np.pi/3
plt.plot([t, t], [0, np.sin(t)], color='red',
         linewidth=1.5, linestyle="--")  # dashed vertical line
plt.scatter([t], [np.sin(t)], 50, color='red')  # point

# Label the sine value with its equation
plt.annotate(r'$\sin(\frac{2\pi}{3})=\frac{\sqrt{3}}{2}$',
             xy=(t, np.sin(t)), xycoords='data',
             xytext=(+10, +30), textcoords='offset points', fontsize=16,
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

# Mark the point on the cosine curve at the same t
plt.plot([t, t], [0, np.cos(t)], color='blue',
         linewidth=1.5, linestyle="--")  # dashed vertical line
plt.scatter([t], [np.cos(t)], 50, color='blue')  # point

# Label the cosine value with its equation
plt.annotate(r'$\cos(\frac{2\pi}{3})=-\frac{1}{2}$',
             xy=(t, np.cos(t)), xycoords='data',
             xytext=(-90, -50), textcoords='offset points', fontsize=16,
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

plt.show() # show figure

Step 4. Save plot

fig.savefig("sincos.png", dpi=300)  # save figure

Take away

As you can see, generating a plot with matplotlib is not hard when you use the default settings. The defaults quickly give you a sense of how the dependent variables change, which makes them well suited to fast prototyping. However, if you want to use matplotlib to generate professional figures for your publications, you will need to put some effort into tweaking the individual elements of the figure. Once you are done, you can process similar data with the same code, which ensures that your figures always have the same style.
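One way to reuse the same style is to wrap the tweaks in a function. The sketch below is only an illustration: the function name plot_styled_lines, its parameters, and the particular subset of style settings are assumptions made for this example, not part of matplotlib or of the original article.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_styled_lines(x, curves, filename=None):
    """Plot several curves with one shared style (hypothetical helper).

    curves maps a legend label to a (y_values, color) pair.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for label, (y, color) in curves.items():
        ax.plot(x, y, color=color, linewidth=2.5, linestyle="-", label=label)

    # Shared styling: hide the top and right spines, enlarge tick labels,
    # and add a frameless legend
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.tick_params(labelsize=14)
    ax.legend(loc="upper left", prop={"size": 16}, frameon=False)

    # Save only when a file name is given
    if filename is not None:
        fig.savefig(filename, dpi=300)
    return fig, ax


x = np.linspace(-np.pi, np.pi, 256)
fig, ax = plot_styled_lines(x, {"sin": (np.sin(x), "red"),
                                "cos": (np.cos(x), "blue")})
```

Calling the same function with a different dataset produces a figure with identical line widths, spines, and legend settings, so the style only has to be tuned once.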

This figure is taken from here. You can tweak every element you see on the figure.


Frequently-used plots for researchers

Here we use simple datasets to demonstrate some frequently used plots. The datasets are kept simple enough for you to understand and modify.

1. Showing 2D data

2. Showing Statistical results

3. Showing 3D data

Scatter plot

# Step 0
import numpy as np
import matplotlib.pyplot as plt

# Step 1
# Note: data can be a list or a NumPy array
x = [1, 2, 3, 4, 5]
y1 = [2, 3, 4, 5, 7]
y2 = [3, 4, 5, 5.5, 7.5]

# Step 2 & 3
plt.scatter(x, y1, color='blue')
plt.scatter(x, y2, color='red')

# Step 4
plt.show()

Filled line plot

# Step 0
import numpy as np
import matplotlib.pyplot as plt

# Step 1
x = np.linspace(-np.pi, np.pi, 256)
y1, y2 = np.sin(x), np.cos(x)

# Step 2
plt.plot(x, y1)
plt.plot(x, y2)

# Step 3
plt.fill_between(x, y1, y2, facecolor='yellow')

# Step 4
plt.show()

Bar chart with error bar

# Step 0
import numpy as np
import matplotlib.pyplot as plt

# Step 1
n_groups = 2

mean_y = [4, 6]
std_y = [1, 2]

mean_n = [3, 4]
std_n = [2, 1]

# Step 2
# np.arange returns evenly spaced values, e.g. np.arange(3) gives array([0, 1, 2])
index = np.arange(n_groups)
bar_width = 0.3

plt.bar(index, mean_y, bar_width,
        color='b',
        yerr=std_y,  # err bar
        label='Yes')

plt.bar(index + bar_width, mean_n, bar_width,
        color='r',
        yerr=std_n,  # err bar
        label='No')

# Step 3
plt.xlabel('Question')  # add x-label
plt.ylabel('Votes')  # add y-label
plt.title('Votes on Questions')  # add title
# index + bar_width / 2 places each tick midway between the paired bars
plt.xticks(index + bar_width / 2, ('A', 'B'))
plt.legend()  # add legend

# Step 4
plt.show()

Histogram

# Step 0
import numpy as np
import matplotlib.pyplot as plt

def normpdf(x, mu, sigma):
    '''Finding norm PDF'''
    u = (x-mu)/abs(sigma)
    return (1/(np.sqrt(2*np.pi)*abs(sigma)))*np.exp(-u*u/2)

# Step 1
x = [1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5,
     5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9]
mean = np.mean(x)
std = np.std(x)

# Step 2
# density=True normalizes the histogram (the older normed option was removed from Matplotlib)
n, bins, patches = plt.hist(x, 9, density=True, facecolor='blue', edgecolor='cyan')

y = normpdf(bins, mean, std)
plt.plot(bins, y, 'r--', linewidth=1) # add a 'best fit' line

# Step 3
plt.xlabel('Ball Number')  # add x-label
plt.ylabel('Probability')  # add y-label
plt.title('Lottery')  # add title
plt.axis([0, 10, 0, 0.4])  # define the axes display range
plt.grid(True)  # show the grid

# Step 4
plt.show()
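The histogram above is normalized, so the bar heights are probability densities rather than raw probabilities. The following sketch, which uses np.histogram instead of plotting and introduces the names density, bin_edges, and bar_areas purely for illustration, checks that the bar areas (height times bin width) add up to one.

```python
import numpy as np

x = [1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5,
     5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9]

# Same binning as the histogram above: 9 equal-width bins
density, bin_edges = np.histogram(x, bins=9, density=True)

# Area of each bar = height * width; the areas sum to one
bar_areas = density * np.diff(bin_edges)
print(bar_areas.sum())
```

Because the bins here are narrower than one unit, an individual bar height can differ from the fraction of samples that fall in that bin.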

Contour plot with color bar

# Step 0
import numpy as np
import matplotlib.pyplot as plt

def f(x, y):
    '''use "function" to get f(x,y) and return the value'''
    return np.sin(x) + np.sin(y)

# Step 1
n = 256
x = np.linspace(-np.pi, np.pi, n)
y = np.linspace(-np.pi, np.pi, n)
X, Y = np.meshgrid(x, y) # generate the grid
Z = f(X, Y)

# Step 2
Contour = plt.contour(X, Y, Z)  # reuse the Z values computed in Step 1

# Step 3
plt.clabel(Contour, fontsize=10)
plt.colorbar(Contour)

# Step 4
plt.show()

Color grid (imshow) with color bar

# Step 0
import numpy as np
import matplotlib.pyplot as plt

def f(x, y):
    '''use "function" to get f(x,y) and return the value'''
    return np.sin(x) + np.sin(y)

# Step 1
n = 256
x = np.linspace(-np.pi, np.pi, n)
y = np.linspace(-np.pi, np.pi, n)
X, Y = np.meshgrid(x, y) # generate the grid
Z = f(X, Y)

# Step 2
plt.imshow(Z, origin='lower')  # origin at the bottom-left

# Step 3
plt.colorbar()

# Step 4
plt.show()

Simple 3D plot

# Step 0
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # import 3D axes
from matplotlib import cm  # import colormap

def f(x, y):
    '''use "function" to get f(x,y) and return the value'''
    return np.sin(x) + np.sin(y)


# Step 1
n = 16
x = np.linspace(-np.pi, np.pi, n)
y = np.linspace(-np.pi, np.pi, n)
X, Y = np.meshgrid(x, y) # generate the grid
Z = f(X, Y)

# Step 2 & 3
fig = plt.figure()
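# Create 3D axes; Axes3D(fig) no longer attaches itself to the figure in recent
# Matplotlib releases (the Axes3D import above registers the projection on older ones)
ax = fig.add_subplot(111, projection='3d')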
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.viridis)

# Step 4
plt.show()


Other Data Visualization Packages in Python

  1. Seaborn: a statistical plotting library built on top of matplotlib.
  2. Bokeh: an interactive visualization library for Python that targets web browsers.
  3. Geoplotlib: a Python toolbox for visualizing geographical data.