Data Visualization – Insights with Matplotlib

3 Oct 2020 | CPOL | 8 min read
A detailed look at how to deduce insights using Matplotlib with real world examples.
Matplotlib is the most popular Python library for visualization. While working on a machine learning problem, it helps in representing and analyzing the data and in working towards insights.

matplotlib-machine-learning

Generally, it is difficult to interpret much about data just by looking at the raw values. Presenting the data in a visual form helps a great deal: it becomes easier to deduce correlations and to identify patterns and parameters of importance.

In data science, visualization plays an important role around the data pre-processing stage. It helps in picking appropriate features and in choosing an appropriate machine learning algorithm. Later, it helps in presenting the data in a meaningful way.

Data Insights via Various Plots

Where needed, we will use real-world datasets for plot examples and discussions. The following are the commonly used plots:

Line Chart | ax.plot(x,y)

It helps in representing a series of data points against a given range of a defined parameter. The real benefit is in plotting multiple lines in a single plot to compare and track changes.

Quote:

Points next to each other are related, which helps to identify a repeated or defined pattern

Python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0, 1, 0.05)
y1 = x**2
y2 = x**3

plt.plot(x, y1,
    linewidth=0.5,
    linestyle='--',
    color='b',
    marker='o',
    markersize=10,
    markerfacecolor='red')

plt.plot(x, y2,
    linewidth=0.5,
    linestyle='dotted',
    color='g',
    marker='^',
    markersize=10,
    markerfacecolor='yellow')

plt.title('x Vs f(x)')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend(['f(x)=x^2', 'f(x)=x^3'])
plt.xticks(np.arange(0, 1.1,0.2),
    ['0','0.2','0.4','0.6','0.8','1.0'])

plt.grid(True)
plt.show()

line-chart

Real World Example

We will work with a dataset collated from historical data for a few stocks, downloaded from here.

Python
import pandas as pd
import matplotlib.pyplot as plt

stocksdf1 = pd.read_csv('data-files/stock-INTU.csv') 
stocksdf2 = pd.read_csv('data-files/stock-AAPL.csv') 
stocksdf3 = pd.read_csv('data-files/stock-ADBE.csv') 

stocksdf = pd.DataFrame()
stocksdf['date'] = pd.to_datetime(stocksdf1['Date'])
stocksdf['INTU'] = stocksdf1['Open']
stocksdf['AAPL'] = stocksdf2['Open']
stocksdf['ADBE'] = stocksdf3['Open']

plt.plot(stocksdf['date'], stocksdf['INTU'])
plt.plot(stocksdf['date'], stocksdf['AAPL'])
plt.plot(stocksdf['date'], stocksdf['ADBE'])

plt.legend(labels=['INTU','AAPL','ADBE'])
plt.grid(True)

plt.show()

line-chart-stocks

With the above, we can make a couple of quick assessments:
Q: How did a particular stock fare over the last year?
A: Stocks were roughly rising till Feb 2020, took a dip in April and have been going back up since then.

Q: How did the three stocks behave during the same period?
A: The stock price of ADBE was the most sensitive and that of AAPL the least sensitive to the change during the same period.
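Comparing sensitivity across stocks is easier when every series starts from the same level. The following is a minimal sketch of that idea, not part of the original notebook: the helper name `rebase`, the column names and the prices are made up for illustration.

```python
import pandas as pd
import matplotlib.pyplot as plt

def rebase(prices, base=100.0):
    """Scale every column so that its first value equals `base`."""
    return prices / prices.iloc[0] * base

# Made-up prices purely for illustration (not real market data).
prices = pd.DataFrame(
    {'STOCK_A': [200.0, 220.0, 180.0, 240.0],
     'STOCK_B': [50.0, 52.0, 49.0, 55.0]},
    index=pd.date_range('2020-01-01', periods=4, freq='D'))

rebased = rebase(prices)

# One line per stock, all starting at the same level
fig, ax = plt.subplots()
for name in rebased.columns:
    ax.plot(rebased.index, rebased[name], label=name)
ax.set_ylabel('Price (first day = 100)')
ax.legend()
ax.grid(True)
```

After rebasing, a steeper line means a larger relative move, regardless of the absolute share price. Call `plt.show()` to display the figure.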

Histogram | ax.hist(data, n_bins)

It helps in showing the distribution of a variable: it plots quantitative data with the range of the data grouped into intervals (bins).

Quote:

We can use a log scale if the data range spans several orders of magnitude.

Python
import numpy as np
import matplotlib.pyplot as plt

mean = [0, 0]
# a covariance matrix must be symmetric
# and positive semi-definite
cov = [[2, 4], [4, 9]]
xn, yn = np.random.multivariate_normal(
                                mean, cov, 100).T

plt.hist(xn, bins=25, label="Distribution on x-axis")

plt.xlabel('x')
plt.ylabel('frequency')
plt.grid(True)
plt.legend()

Image 4
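The log-scale remark in the quote above can be sketched as follows. This is only an illustration under assumptions: the helper name `log_histogram` and the synthetic log-normal sample are not from the original article, and the data must be strictly positive for a log axis.

```python
import numpy as np
import matplotlib.pyplot as plt

def log_histogram(values, n_bins=20):
    """Plot a histogram with logarithmically spaced bins on a log x-axis."""
    values = np.asarray(values, dtype=float)
    # Bin edges evenly spaced on a log scale between min and max
    edges = np.geomspace(values.min(), values.max(), n_bins + 1)
    fig, ax = plt.subplots()
    counts, edges, _ = ax.hist(values, bins=edges)
    ax.set_xscale('log')
    ax.set_xlabel('value (log scale)')
    ax.set_ylabel('frequency')
    ax.grid(True)
    return fig, counts, edges

# Synthetic sample spanning several orders of magnitude
rng = np.random.default_rng(0)
sample = rng.lognormal(mean=3.0, sigma=1.5, size=1000)
fig, counts, edges = log_histogram(sample)
```

With linear bins, most of such a sample would pile up in the first bar; log-spaced bins spread it out so that the shape becomes visible.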

Real World Example

We will work with a dataset of Indian Census data downloaded from here.

Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

populationdf = pd.read_csv(
    "./data-files/census-population.csv")

mask1 = populationdf['Level']=='STATE'
mask2 = populationdf['TRU']=='Total'
df = populationdf[mask1 & mask2]

plt.hist(df['TOT_P'], label='Distribution')

plt.xlabel('Total Population')
plt.ylabel('State Count')
plt.yticks(np.arange(0,20,2))

plt.grid(True)
plt.legend()

histogram-state-pop

With the above, we can make a couple of quick assessments about the population of states in India:
Q: What’s the general population distribution of states in India?
A: More than 50% of states have a population of less than 2 crores (20 million).

Q: How many states have a population of more than 10 crores (100 million)?
A: Only 3 states have that high a population.

Bar Chart | ax.bar(x_pos, heights)

It helps in comparing two or more variables by displaying the values associated with categorical data.

Quote:

The plot most commonly used in the media to share survey data, displaying every data sample.

Python
import numpy as np
import matplotlib.pyplot as plt

data = [[60, 45, 65, 35],
        [35, 25, 55, 40]]

x_pos = np.arange(4)
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.set_xticks(x_pos)

ax.bar(x_pos - 0.1, data[0], color='b', width=0.2)
ax.bar(x_pos + 0.1, data[1], color='g', width=0.2)

ax.yaxis.grid(True)

bar-chart
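The sample above places two series by hand at `x_pos - 0.1` and `x_pos + 0.1`. A possible generalization to any number of series is sketched below; the helper `grouped_bars`, the group labels and the category names are hypothetical and only reuse the numbers from the sample.

```python
import numpy as np
import matplotlib.pyplot as plt

def grouped_bars(ax, series, categories, group_width=0.8):
    """Draw one bar per series, side by side within each category."""
    n = len(series)
    bar_width = group_width / n
    x_pos = np.arange(len(categories))
    for i, (label, heights) in enumerate(series.items()):
        # Shift each series so the group stays centred on its tick
        offset = (i - (n - 1) / 2) * bar_width
        ax.bar(x_pos + offset, heights, width=bar_width, label=label)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories)
    ax.legend()
    return bar_width

# Hypothetical labels; values reused from the sample above
series = {'Group 1': [60, 45, 65, 35],
          'Group 2': [35, 25, 55, 40]}
fig, ax = plt.subplots()
width = grouped_bars(ax, series, ['Q1', 'Q2', 'Q3', 'Q4'])
ax.yaxis.grid(True)
```

Adding a third entry to `series` narrows the bars automatically, and the legend and tick labels make the chart readable without the code next to it.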

Real World Example

We will work with a dataset of Indian Census data downloaded from here.

Python
import pandas as pd
import matplotlib.pyplot as plt

populationdf = pd.read_csv(
    "./data-files/census-population.csv")

mask1 = populationdf['Level']=='STATE'
mask2 = populationdf['TRU']=='Total'
statesdf = populationdf[mask1 & mask2]
statesdf = statesdf.sort_values('TOT_P')

plt.figure(figsize=(10,8))
plt.barh(range(len(statesdf)), 
    statesdf['TOT_P'], tick_label=statesdf['Name'])
plt.grid(True)
plt.title('Total Population')
plt.show()

bar-chart-state-pop

With the above, we can make a couple of quick assessments about the population of states in India:

  • Uttar Pradesh has the highest total population and Lakshadweep has the lowest
  • The relative population across states is visible, with Uttar Pradesh having almost double the population of the second most populated state

Pie Chart | ax.pie(sizes, labels=[labels])

It helps in showing the percentage (or proportional) distribution of categories at a certain point in time. Usually, it works well only when there are fewer than ten categories.

Quote:

A circular statistical graphic where the arc length of each slice is proportional to the quantity it represents.

Python
import numpy as np
import matplotlib.pyplot as plt

# Slices will be ordered and plotted counter-clockwise
labels = ['Audi','BMW','LandRover','Tesla','Ferrari']
sizes = [90, 70, 35, 20, 25]

fig, ax = plt.subplots()
ax.pie(sizes,labels=labels, autopct='%1.1f%%')
ax.set_title('Car Sales')
plt.show()

pie-chart

Real World Example

We will work with a dataset of Alcohol Consumption downloaded from here.

Python
import pandas as pd
import matplotlib.pyplot as plt

drinksdf = pd.read_csv('data-files/drinks.csv', 
    skiprows=1, 
    names = ['country', 'beer', 'spirit', 
             'wine', 'alcohol', 'continent']) 

labels = ['Beer', 'Spirit', 'Wine']
sizes = [drinksdf['beer'].sum(), 
         drinksdf['spirit'].sum(), 
         drinksdf['wine'].sum()]

fig, ax = plt.subplots()
explode = [0.05,0.05,0.2]
ax.pie(sizes,explode=explode,
    labels=labels, autopct='%1.1f%%')

ax.set_title('Alcohol Consumption')
plt.show()

pie-chart-drinks

With the above, we can quickly see how overall alcohol consumption is split between beer, spirit and wine. This view works well when there are only a few slices (categories).

Scatter Plot | ax.scatter(x_points, y_points)

It helps in representing paired numerical data, either to compare how one variable is affected by another or to see how the values of multiple dependent variables are spread for each value of the independent variable.

Quote:

Sometimes the data points in a scatter plot form distinct groups, which are called clusters.

Python
import numpy as np
import matplotlib.pyplot as plt

# random but focused cluster data
x1 = np.random.randn(100) + 8
y1 = np.random.randn(100) + 8
x2 = np.random.randn(100) + 3
y2 = np.random.randn(100) + 3

x = np.append(x1,x2)
y = np.append(y1,y2)

plt.scatter(x,y, label="xy distribution")
plt.legend()

scatter-plot

Real World Example
  1. We will work with a dataset of Alcohol Consumption downloaded from here.
    Python
    import pandas as pd
    import matplotlib.pyplot as plt
    
    drinksdf = pd.read_csv('data-files/drinks.csv', 
        skiprows=1, 
        names = ['country', 'beer', 'spirit', 
                 'wine', 'alcohol', 'continent']) 
    
    # Parentheses are needed so that the sum
    # continues across lines
    drinksdf['total'] = (drinksdf['beer']
                         + drinksdf['spirit']
                         + drinksdf['wine']
                         + drinksdf['alcohol'])
    
    # drinksdf.corr() shows that beer and alcohol
    # are highly correlated
    fig = plt.figure()
    
    # Compare beer and alcohol consumption
    # Use color to show a third variable.
    # Can also use size (s) to show a third variable.
    scat = plt.scatter(drinksdf['beer'], 
                       drinksdf['alcohol'], 
                       c=drinksdf['total'], 
                       cmap=plt.cm.rainbow)
    
    # colorbar to explain the color scheme
    fig.colorbar(scat, label='Total drinks')
    
    plt.xlabel('Beer')
    plt.ylabel('Alcohol')
    plt.title('Comparing beer and alcohol consumption')
    plt.grid(True)
    plt.show()

    scatter-plot-drinks

    With the above, we can quickly see that beer and alcohol consumption have a strong positive correlation, which would suggest a large overlap between people who drink beer and people who drink alcohol.

  2. We will work with a dataset of Mall Customers downloaded from here.
    Python
    import pandas as pd
    import matplotlib.pyplot as plt
    
    malldf = pd.read_csv('data-files/mall-customers.csv',
                    skiprows=1, 
                    names = ['customerid', 'genre', 
                             'age', 'annualincome', 
                             'spendingscore'])
    
    plt.scatter(malldf['annualincome'], 
                malldf['spendingscore'], 
                marker='p', s=40, 
                facecolor='r', edgecolor='b', 
                linewidth=2, alpha=0.4)
    
    plt.xlabel("Annual Income")
    plt.ylabel("Spending Score (1-100)")
    plt.grid(True)

    scatter-plot-mall

    With the above, we can quickly see that there are five clusters, and thus five segments or types of customers that one can plan for.
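A comment in the first example above notes that marker size (`s`) can also carry a third variable. The following is a small sketch of that, using synthetic data; the helper `scale_sizes` and its size limits are assumptions made for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

def scale_sizes(values, min_size=20.0, max_size=300.0):
    """Map values linearly onto marker areas in [min_size, max_size]."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        # All values equal: use one mid-sized marker for every point
        return np.full(values.shape, (min_size + max_size) / 2)
    return min_size + (values - values.min()) / span * (max_size - min_size)

# Synthetic data: y depends on x, `third` is an unrelated third variable
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, 50)
y = 2 * x + rng.normal(0, 2, 50)
third = rng.uniform(1, 100, 50)

sizes = scale_sizes(third)
fig, ax = plt.subplots()
ax.scatter(x, y, s=sizes, alpha=0.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.grid(True)
```

In Matplotlib, `s` is the marker area in points squared, so rescaling the raw values into a fixed range keeps the smallest points visible and the largest ones from covering the plot.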

Box Plot | ax.boxplot([data list])

A statistical plot that helps in comparing distributions of variables because the center, spread and range are immediately visible. It shows only summary statistics: the median, the quartiles (interquartile range), the whiskers and any outliers.

Quote:

Easy to identify if data is symmetrical, how tightly it is grouped, and if and how data is skewed

Python
import numpy as np
import matplotlib.pyplot as plt

# some random data
data1 = np.random.normal(0, 2, 100)
data2 = np.random.normal(0, 4, 100)
data3 = np.random.normal(0, 3, 100)
data4 = np.random.normal(0, 5, 100)
data = list([data1, data2, data3, data4])

fig, ax = plt.subplots()
bx = ax.boxplot(data, patch_artist=True)

ax.set_title('Box Plot Sample')
ax.set_ylabel('Spread')
xticklabels=['category A', 
             'category B', 
             'category C', 
             'category D']

colors = ['pink','lightblue','lightgreen','yellow']
for patch, color in zip(bx['boxes'], colors):
    patch.set_facecolor(color)

ax.set_xticklabels(xticklabels)
ax.yaxis.grid(True)
plt.show()

box-plot
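To make the summary statistics concrete, the numbers behind a single box can be computed by hand. This is a sketch assuming the common 1.5 × IQR whisker rule (Matplotlib's default `whis=1.5`); the helper name `box_stats` and the small data set are made up for illustration.

```python
import numpy as np

def box_stats(data, whis=1.5):
    """Compute the values a box plot draws for one data set."""
    data = np.sort(np.asarray(data, dtype=float))
    q1, median, q3 = np.percentile(data, [25, 50, 75])
    iqr = q3 - q1
    low_fence = q1 - whis * iqr
    high_fence = q3 + whis * iqr
    # Whiskers end at the most extreme points still inside the fences
    inside = data[(data >= low_fence) & (data <= high_fence)]
    outliers = data[(data < low_fence) | (data > high_fence)]
    return {'q1': q1, 'median': median, 'q3': q3, 'iqr': iqr,
            'whisker_low': inside.min(), 'whisker_high': inside.max(),
            'outliers': outliers}

stats = box_stats([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
print(stats['q1'], stats['median'], stats['q3'])   # 3.25 5.5 7.75
print(stats['whisker_low'], stats['whisker_high'])  # 1.0 9.0
print(stats['outliers'])                            # [100.]
```

The single value 100 lies beyond the upper fence (7.75 + 1.5 × 4.5 = 14.5), so it would be drawn as an individual outlier point rather than stretching the whisker.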

Real World Example

We will work with a dataset of Tips downloaded from here.

Python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

tipsdf = pd.read_csv('data-files/tips.csv') 
sns.boxplot(x="time", y="tip", 
            hue='sex', data=tipsdf, 
            order=["Dinner", "Lunch"],
            palette='coolwarm')

box-plot-tips

With the above, we can make a couple of quick assessments:

  • Male customers tend to tip more than female customers.
  • Tips at dinner time vary more, particularly among male customers.

Violin Plot | ax.violinplot([data list])

A statistical plot that helps in comparing distributions of variables because the center, spread and range are immediately visible. It shows the full distribution of data.

Quote:

A quick way to compare distributions across multiple variables

Python
import numpy as np
import matplotlib.pyplot as plt

data = [np.random.normal(0, std, size=100) 
        for std in range(2, 6)]

fig, ax = plt.subplots()
bx = ax.violinplot(data)

ax.set_title('Violin Plot Sample')
ax.set_ylabel('Spread')
xticklabels=['category A', 
             'category B', 
             'category C', 
             'category D']

ax.set_xticks([1,2,3,4])
ax.set_xticklabels(xticklabels)

ax.yaxis.grid(True)
plt.show()

violin-plot

Real World Example
  1. We will work with a dataset of Tips downloaded from here.
    Python
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    tipsdf = pd.read_csv('data-files/tips.csv') 
    # split expects a boolean, not the string "True"
    sns.violinplot(x="day", y="tip", 
                   split=True, data=tipsdf)

    violin-plot-tips

    With the above, we can quickly see that tips on Saturday have a wider distribution, whereas tips on Friday have a much narrower one.

  2. We will work with a dataset of Indian Census data downloaded from here.
    Python
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    populationdf = pd.read_csv(
        "./data-files/census-population.csv")
    
    mask1 = populationdf['Level']=='DISTRICT'
    mask2 = populationdf['TRU']!='Total'
    statesdf = populationdf[mask1 & mask2]
    
    maskUP = statesdf['State']==9
    maskM = statesdf['State']==27
    data = statesdf.loc[maskUP | maskM]
    
    sns.violinplot(x='State', y='P_06', 
                   inner='quartile', hue='TRU', 
                   palette={'Rural':'green','Urban':'blue'}, 
                   scale='count', split=True, 
                   data=data)
    
    plt.title('In districts of UP and Maharashtra')
    plt.show()

    violin-plot-child

    With the above, we can make a couple of quick assessments:

    • Uttar Pradesh has high volume and distribution of rural child population.
    • Maharashtra has almost equal spread of rural and urban child population.

Heatmap

It helps in representing data in 2-D matrix form, using variation of color for different values. The variation may be in hue or in intensity.

Quote:

Generally used to visualize a correlation matrix, which in turn helps in feature (variable) selection.

Python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# create 2D array
array_2d = np.random.rand(4, 6)
sns.heatmap(array_2d, annot=True)

heatmap
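The sample above relies on seaborn. A similar annotated heatmap can be sketched with Matplotlib alone, using `imshow` for the colors and `ax.text` for the cell values; the helper name `annotated_heatmap` and its defaults are assumptions for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

def annotated_heatmap(ax, matrix, fmt='{:.2f}', cmap='viridis'):
    """Draw a 2-D array as colored cells with the value written in each."""
    matrix = np.asarray(matrix)
    im = ax.imshow(matrix, cmap=cmap)
    # imshow puts column index on x and row index on y
    for (row, col), value in np.ndenumerate(matrix):
        ax.text(col, row, fmt.format(value),
                ha='center', va='center', color='w')
    ax.figure.colorbar(im, ax=ax)
    return im

rng = np.random.default_rng(0)
array_2d = rng.random((4, 6))

fig, ax = plt.subplots()
im = annotated_heatmap(ax, array_2d)
```

seaborn additionally picks a readable text color per cell; this sketch uses white throughout, which may be hard to read on the lightest cells.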

Real World Example
We will work with a dataset of Alcohol Consumption downloaded from here.
Python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

drinksdf = pd.read_csv('data-files/drinks.csv', 
    skiprows=1, 
    names = ['country', 'beer', 'spirit', 
             'wine', 'alcohol', 'continent']) 

# correlate the numeric columns only
sns.heatmap(drinksdf.select_dtypes('number').corr(),
            annot=True, cmap='YlGnBu')

heatmap-drinks

With the above, we can make a couple of quick assessments:

  • There is a strong correlation between beer and alcohol, and thus a strong overlap there.
  • Wine and spirit are almost uncorrelated, so it would be rare to find a place where wine and spirit consumption are equally high. One would be preferred over the other.

Notice that the upper and lower halves on either side of the diagonal are the same: the correlation of A with B is the same as that of B with A. Further, the correlation of A with A is always 1. We can therefore make a small tweak to make the plot more presentable and avoid any confusion.

Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

drinksdf = pd.read_csv(
    'data-files/drinks.csv', 
    skiprows=1, 
    names = ['country', 'beer', 'spirit', 
             'wine', 'alcohol', 'continent']) 

# correlation (numeric columns only) and boolean mask
drinks_cr = drinksdf.select_dtypes('number').corr()
drinks_mask = np.triu(np.ones_like(drinks_cr, dtype=bool))

# remove the last ones on both axes
drinks_cr = drinks_cr.iloc[1:,:-1]
drinks_mask = drinks_mask[1:, :-1]

sns.heatmap(drinks_cr, 
        mask=drinks_mask,
        annot=True,
        cmap='coolwarm')

heatmap-masked

It is the same correlation data, but only the needed part is shown.

Data Image

It helps in displaying data as an image, i.e. on a 2D regular raster.

Quote:

Images are internally just arrays. Any 2D numpy array can be displayed as an image.

Python
import numpy as np
import matplotlib.pyplot as plt

M,N = 25,30
data = np.random.random((M,N)) 
plt.imshow(data)

data-image

Real World Example
  1. Let’s read an image and then try to display it back to see how it looks:
    Python
    import cv2
    import matplotlib.pyplot as plt
    
    img = cv2.imread('data-files/babygroot.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # print(img.shape)
    # output => (500, 359, 3)
    
    plt.imshow(img)

baby-groot

The code reads the image as an array and then draws it as a plot, which turns out to be the same as the image. Since images are like any other plot, we can draw other objects (like annotations) on top of them.

SubPlots | fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

Generally, it is used in comparing multiple variables (in pairs) against each other. With multiple plots stacked against each other in the same figure, it helps in quick assessment for correlation and distribution for a pair.

Quote:

Parameters of plt.subplot() are: number of rows, number of columns and the index of the subplot

(Indexes are counted row-wise, starting with 1)

The widths of the subplots can differ when GridSpec is used.

Python
import numpy as np
import matplotlib.pyplot as plt
import math

# data setup
x = np.arange(1, 100, 5)
y1 = x**2
y2 = 2*x+4
y3 = [ math.sqrt(i) for i in x]  
y4 = [ math.log(j) for j in x] 

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

ax1.plot(x, y1)
ax1.set_title('f(x) = quadratic')
ax1.grid()

ax2.plot(x, y2)
ax2.set_title('f(x) = linear')
ax2.grid()

ax3.plot(x, y3)
ax3.set_title('f(x) = square root')
ax3.grid()

ax4.plot(x, y4)
ax4.set_title('f(x) = log')
ax4.grid()

fig.tight_layout()
plt.show()

sub-plot

We can stack up an m x n view of the variables and have a quick look at how they are correlated. With the above, we can quickly assess that the parameters in the second graph are linearly correlated.
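The GridSpec remark in the quote above can be sketched by passing `gridspec_kw` to `plt.subplots`. The 3:1 ratio and the sine data below are arbitrary choices for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Two subplots in one row; the left one is three times as wide
fig, (ax_wide, ax_narrow) = plt.subplots(
    1, 2, gridspec_kw={'width_ratios': [3, 1]})

ax_wide.plot(x, y)
ax_wide.set_title('signal')
ax_wide.grid()

# Distribution of the same values, drawn sideways
ax_narrow.hist(y, orientation='horizontal')
ax_narrow.set_title('distribution')
ax_narrow.grid()
```

A wide main plot with a narrow companion is a common layout when one view carries the detail and the other only a summary. `height_ratios` works the same way for rows.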

Data Representation

Plot Anatomy

The below picture will help with plots terminology and representation:

matplotlib-plot-anatomy

Credit: matplotlib.org

The figure above is the base space where the entire plot happens. Most of the parameters can be customized for better representation. For specific details, look here.

Plot Annotations

It helps in highlighting a few key findings or indicators on a plot. For advanced annotations, look here.

Python
import numpy as np
import matplotlib.pyplot as plt

# A simple parabolic data
x = np.arange(-4, 4, 0.02)
y = x**2

# Setup plot with data
fig, ax = plt.subplots()
ax.plot(x, y)

# Setup axes
ax.set_xlim(-4,4)
ax.set_ylim(-1,8)

# Visual titles
ax.set_title('Annotation Sample')
ax.set_xlabel('X-values')
ax.set_ylabel('Parabolic values')

# Annotation
# 1. Highlighting specific data on the x,y data
ax.annotate('minimum of \n the parabola',
            xy=(0, 0),
            xycoords='data',
            xytext=(2, 3),
            arrowprops=
                dict(facecolor='red', shrink=0.04),
                horizontalalignment='left',
                verticalalignment='top')

# 2. Highlighting specific data on the x/y axis
bbox_yproperties = dict(
    boxstyle="round,pad=0.4", fc="w", ec="k", lw=2)
ax.annotate('Covers 70% of y-plot range',
            xy=(0, 0.7),
            xycoords='axes fraction',
            xytext=(0.2, 0.7),
            bbox=bbox_yproperties,
            arrowprops=
                dict(facecolor='green', shrink=0.04),
                horizontalalignment='left',
                verticalalignment='center')

bbox_xproperties = dict(
    boxstyle="round,pad=0.4", fc="w", ec="k", lw=2)
ax.annotate('Covers 30% of x-plot range',
            xy=(0.3, 0),
            xycoords='axes fraction',
            xytext=(0.1, 0.4),
            bbox=bbox_xproperties,
            arrowprops=
                dict(facecolor='blue', shrink=0.04),
                horizontalalignment='left',
                verticalalignment='center')

plt.show()

matplotlib-annotation

Plot Style | plt.style.use('style')

It helps in customizing the representation of a plot, such as colors, fonts and line thickness. Default styles are applied if no customization is defined. Apart from ad hoc customization, we can also choose one of the predefined template styles and apply it.

Python
import matplotlib.pyplot as plt

# To know all existing styles with package
for style in plt.style.available:
    print(style)
Quote:

Solarize_Light2, _classic_test_patch, bmh, classic, dark_background, fast, fivethirtyeight, ggplot, grayscale, seaborn, seaborn-bright, seaborn-colorblind, seaborn-dark, seaborn-dark-palette, seaborn-darkgrid, seaborn-deep, seaborn-muted, seaborn-notebook, seaborn-paper, seaborn-pastel, seaborn-poster, seaborn-talk, seaborn-ticks, seaborn-white, seaborn-whitegrid, tableau-colorblind10

- predefined styles available for use

More details around customization are here.

Python
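import numpy as np
import matplotlib.pyplot as plt

# To use a defined style for plot
# (on Matplotlib 3.6+, this style is named 'seaborn-v0_8')
plt.style.use('seaborn')

# OR
with plt.style.context('Solarize_Light2'):
    plt.plot(np.sin(np.linspace(0, 2 * np.pi)), 'r-o')
plt.show()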

matplotlib-style-ex

Saving Plots | plt.savefig()

It helps in saving the figure with its plot as an image file with the given parameters. Parameter details are here. A relative file name is saved to the current working directory.

Python
plt.savefig('plot.png', dpi=300, bbox_inches='tight')
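As a slightly fuller sketch, the same figure can be written in more than one format through `fig.savefig`, with the format inferred from the file extension. The temporary directory and the file names below are placeholders chosen for illustration.

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.linspace(0, 2 * np.pi, 50)
ax.plot(x, np.sin(x))
ax.set_title('sin(x)')

# Placeholder output location; any writable directory works
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, 'plot.png')
pdf_path = os.path.join(out_dir, 'plot.pdf')

# Raster output: dpi controls the pixel resolution
fig.savefig(png_path, dpi=150, bbox_inches='tight')
# Vector output: stays sharp at any zoom level
fig.savefig(pdf_path, bbox_inches='tight')

plt.close(fig)  # release the figure once it is saved
```

When using `plt.savefig`, call it before `plt.show()`; with some backends the figure is discarded once the window is closed, which leaves an empty image.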

Additional Usages of plots

Data Imputation

It helps in filling missing data with some reasonable data as many statistical or machine learning packages do not work with data containing null values.

Quote:

Data interpolation can be defined to use pre-defined functions such as linear, quadratic or cubic

Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.randn(20,1))
df = df.where(df<0.5)

fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.plot(df)
ax1.set_title('f(x) = data missing')
ax1.grid()

ax2.plot(df.interpolate())
ax2.set_title('f(x) = data interpolated')
ax2.grid()

fig.tight_layout()
plt.show()

data-interpolate

With the above, we see the missing data replaced by values that the DataFrame interpolates (linearly, by default) from the nearest valid previous and next data points.
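A tiny worked example shows what the default linear interpolation does to the gaps. The series values are made up so that the result is easy to check by hand.

```python
import numpy as np
import pandas as pd

s = pd.Series([np.nan, 1.0, np.nan, 3.0, np.nan, np.nan, 6.0])

# Gaps between valid points are filled along a straight line
linear = s.interpolate(method='linear')
print(linear.tolist())   # [nan, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# A leading gap has no previous value, so it stays NaN;
# one option is to back-fill it from the first valid value
filled = linear.bfill()
print(filled.tolist())   # [1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Other methods such as `'quadratic'` or `'cubic'` can be passed to `interpolate` as well, but they require SciPy to be installed.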

Animation

At times, it helps to present the data as an animation. At a high level, the data is plugged into a loop, with small changes in each step translating into a moving view.

Python
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import animation

fig = plt.figure()

def f(x, y):
    return np.sin(x) + np.cos(y)

x = np.linspace(0, 2 * np.pi, 80)
y = np.linspace(0, 2 * np.pi, 70).reshape(-1, 1)

im = plt.imshow(f(x, y), animated=True)

def updatefig(*args):
    global x, y
    x += np.pi / 5.
    y += np.pi / 10.
    im.set_array(f(x, y))
    return im,

ani = animation.FuncAnimation(
    fig, updatefig, interval=100, blit=True)
plt.show()

animation

3-D Plotting

If needed, we can also have an interactive 3-D plot though it might be slow with large datasets.

Python
import numpy as np
import matplotlib.pyplot as plt

def randrange(n, vmin, vmax):
     return (vmax-vmin)*np.random.rand(n) + vmin

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')
n = 200
for c, m, zl in [('g', 'o', +1), ('r', '^', -1)]:
    xs = randrange(n, 0, 50)
    ys = randrange(n, 0, 100)
    zs = xs+zl*ys  
    ax.scatter(xs, ys, zs, c=c, marker=m)

ax.set_xlabel('X data')
ax.set_ylabel('Y data')
ax.set_zlabel('Z data')
plt.show()

3d-plot

Cheat Sheet

A one-page representation of the key features for quick lookup or revision:

matplotlib-cheatsheet

Credit: DataCamp

Download the PDF version of the cheat sheet from here.
For overall reference and more details, check out https://matplotlib.org/.

The entire Jupyter notebook, with more samples, can be downloaded or forked from my GitHub at https://github.com/sandeep-mewara/data-visualization.

Keep learning!

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)