Show Sidebar Hide Sidebar

2D Density Plots in Pandas

How to make a 2D density plot from a pandas DataFrame. Examples of density plots with kernel density estimation, custom color scales, and smoothing.

# Learn about API authentication here: https://plot.ly/pandas/getting-started
# Find your api_key here: https://plot.ly/settings/api

import plotly.plotly as py
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import colorlover as cl
from scipy.stats import gaussian_kde

# Read the iris CSV from the plotly datasets repository into a DataFrame.
df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/iris.csv'
)
# Bare expression: previews the first rows when run as a notebook cell.
df.head()

# Look up colorlover's sequential 'Blues' scale under the '9' key and pair
# each color with a position evenly spaced over [0, 1] (first color at 0.0,
# last color at 1.0), giving a list of [position, color] entries.
scl = cl.scales['9']['seq']['Blues']
colorscale = [
    [float(step) / float(len(scl) - 1), color]
    for step, color in enumerate(scl)
]
# Bare expression: echoes the colorscale when run as a notebook cell.
colorscale

def kde_scipy(x, x_grid, bandwidth=0.2):
    """Evaluate a 1-D Gaussian kernel density estimate of ``x`` on ``x_grid``.

    Parameters
    ----------
    x : array-like with a ``std(ddof=...)`` method (e.g. a NumPy array)
        The observed samples.
    x_grid : array-like
        Points at which the estimated density is evaluated.
    bandwidth : float, optional
        Kernel standard deviation expressed in the units of ``x``.

    Returns
    -------
    numpy.ndarray
        Estimated density at each point of ``x_grid``.
    """
    # A scalar ``bw_method`` is a factor that scipy multiplies by the sample
    # standard deviation, so divide by that deviation to make the kernel
    # width equal to ``bandwidth`` in data units.
    relative_bw = bandwidth / x.std(ddof=1)
    estimator = gaussian_kde(x, bw_method=relative_bw)
    density = estimator.evaluate(x_grid)
    return density

# Evaluation grids for the marginal KDE curves: 100 evenly spaced points
# spanning the observed minimum-to-maximum range of each plotted column.
x_grid = np.linspace(
    start=df['SepalWidth'].min(),
    stop=df['SepalWidth'].max(),
    num=100,
)
y_grid = np.linspace(
    start=df['PetalLength'].min(),
    stop=df['PetalLength'].max(),
    num=100,
)

# Main panel: 2D histogram contour of SepalWidth against PetalLength, drawn
# on the default x/y axes with the custom colorscale (color bar hidden).
trace1 = go.Histogram2dContour(
    x=df['SepalWidth'],
    y=df['PetalLength'],
    name='density',
    ncontours=20,
    colorscale=colorscale,
    showscale=False
)

# Marginal distribution of SepalWidth: a density-normalized histogram placed
# on the secondary y axis ('y2') so it shares the main panel's x axis.
trace2 = go.Histogram(
    x=df['SepalWidth'],
    name='x density',
    yaxis='y2',
    histnorm='probability density',
    marker=dict(color='rgb(217, 217, 217)'),
    nbinsx=25
)

# Smoothed KDE curve over the SepalWidth histogram.
# `.values` yields the column's NumPy array; the previously used
# `Series.as_matrix()` was deprecated in pandas 0.23 and removed in 1.0.
trace2s = go.Scatter(
    x=x_grid,
    y=kde_scipy(df['SepalWidth'].values, x_grid),
    yaxis='y2',
    line=dict(color='rgb(31, 119, 180)'),
    fill='tonexty',
)

# Marginal distribution of PetalLength: a density-normalized histogram placed
# on the secondary x axis ('x2') so it shares the main panel's y axis.
trace3 = go.Histogram(
    y=df['PetalLength'],
    name='y density',
    xaxis='x2',
    histnorm='probability density',
    marker=dict(color='rgb(217, 217, 217)'),
    nbinsy=50
)

# Smoothed KDE curve over the PetalLength histogram; x and y are swapped
# relative to trace2s because this marginal is drawn sideways.
trace3s = go.Scatter(
    y=y_grid,
    x=kde_scipy(df['PetalLength'].values, y_grid),
    xaxis='x2',
    line=dict(color='rgb(31, 119, 180)'),
    fill='tonextx',
)

# Trace order is kept as-is: each KDE curve directly follows its histogram.
data = [trace1, trace2, trace2s, trace3, trace3s]

# Figure layout: a fixed 700x700 canvas with no legend. The primary axes take
# the lower-left ~75% of the canvas for the 2D density; the secondary axes
# (xaxis2 / yaxis2) take the remaining strips for the marginal plots.
layout = go.Layout(
    showlegend=False,
    autosize=False,
    width=700,
    height=700,
    hovermode='closest',
    bargap=0,

    # Primary axes: framed (mirrored axis line), no grid, no tick marks.
    xaxis={
        'domain': [0, 0.746],
        'linewidth': 2,
        'linecolor': '#444',
        'title': 'SepalWidth',
        'showgrid': False,
        'zeroline': False,
        'ticks': '',
        'showline': True,
        'mirror': True,
    },
    yaxis={
        'domain': [0, 0.746],
        'linewidth': 2,
        'linecolor': '#444',
        'title': 'PetalLength',
        'showgrid': False,
        'zeroline': False,
        'ticks': '',
        'showline': True,
        'mirror': True,
    },

    # Secondary axes for the marginal plots: unlabeled and without grid.
    xaxis2={
        'domain': [0.75, 1],
        'showgrid': False,
        'zeroline': False,
        'ticks': '',
        'showticklabels': False,
    },
    yaxis2={
        'domain': [0.75, 1],
        'showgrid': False,
        'zeroline': False,
        'ticks': '',
        'showticklabels': False,
    },
)

# Combine the trace list and the layout into a single figure object.
fig = go.Figure(data=data, layout=layout)

# Alternative for an IPython notebook (left disabled here): render the
# figure inline with iplot instead of calling plot.
# py.iplot(fig, filename='pandas-2d-density-plot', height=700)

# Send the figure to Plotly under the given filename and keep the call's
# return value in `url` (presumably the address of the hosted plot; this
# requires configured Plotly credentials, see the links at the top).
url = py.plot(fig, filename='pandas-2d-density-plot')
Still need help?
Contact Us

For a guaranteed 24-hour response time, upgrade to a Developer Support Plan.