Skip to content

Instantly share code, notes, and snippets.

@tomron
Last active August 6, 2024 13:33
Show Gist options
  • Save tomron/8798256fcee5438edd58c17654adf443 to your computer and use it in GitHub Desktop.
Save tomron/8798256fcee5438edd58c17654adf443 to your computer and use it in GitHub Desktop.
A nicer seasonal decomposition chart using Plotly.
from statsmodels.tsa.seasonal import seasonal_decompose
import plotly.tools as tls
def plotSeasonalDecompose(
        x,
        model='additive',
        filt=None,
        period=None,
        two_sided=True,
        extrapolate_trend=0,
        title="Seasonal Decomposition"):
    """
    Plot a seasonal decomposition of a time series as four stacked panels.

    The series is decomposed with statsmodels' ``seasonal_decompose`` and the
    observed, trend, seasonal and residual components are drawn as line
    traces in a 4x1 Plotly subplot grid.

    :param x: Time series.
    :param title: Figure title shown above the subplots.
    The remaining parameters are forwarded unchanged to ``seasonal_decompose``;
    see their documentation here -
    https://www.statsmodels.org/stable/generated/statsmodels.tsa.seasonal.seasonal_decompose.html
    :return: A ``plotly.graph_objects.Figure``.

    Example -
        import pandas as pd
        from datetime import datetime
        import PlotTimeSeries
        s = pd.DataFrame(list(range(1, 11))*10,
                         index=pd.date_range(start=datetime(2010, 1, 1), periods=100))
        fig = PlotTimeSeries.plotSeasonalDecompose(s)
        fig.show()
    """
    # The module-level imports of this snippet only bring in ``plotly.tools``,
    # so ``make_subplots`` and ``go`` would otherwise be unbound here.
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    result = seasonal_decompose(
        x, model=model, filt=filt, period=period,
        two_sided=two_sided, extrapolate_trend=extrapolate_trend)

    # Pandas input yields components with an ``.index``; plain array input
    # does not, so fall back to positional x values in that case.
    index = getattr(result.observed, "index", None)
    if index is None:
        index = list(range(len(result.observed)))

    # (subplot title, component values) in top-to-bottom display order.
    components = [
        ("Observed", result.observed),
        ("Trend", result.trend),
        ("Seasonal", result.seasonal),
        ("Residuals", result.resid),
    ]
    fig = make_subplots(
        rows=len(components), cols=1,
        subplot_titles=[name for name, _ in components])
    for row, (name, values) in enumerate(components, start=1):
        fig.add_trace(
            go.Scatter(x=index, y=values, mode='lines', name=name),
            row=row, col=1,
        )
    # Previously ``title`` was accepted but silently ignored.
    fig.update_layout(title=title)
    return fig
@joslinmartinez
Copy link

Nice !!

@till90
Copy link

till90 commented Oct 11, 2022

@chrimaho
I have modified your code so that it takes a DataFrame and plots the decomposition of multiple columns side by side.



def plot_seasonal_decompose(title: str = "Seasonal Decomposition", df: "pd.DataFrame" = None,
                            model: str = 'additive', period: int = 12):
    """Plot the seasonal decomposition of every column of *df* side by side.

    Each DataFrame column gets its own subplot column. The four rows hold the
    observed, trend, seasonal and residual components returned by
    ``seasonal_decompose`` for that column.

    :param title: Figure title (rendered bold and centred).
    :param df: DataFrame whose columns are each decomposed independently.
    :param model: Decomposition model forwarded to ``seasonal_decompose``.
    :param period: Seasonal period forwarded to ``seasonal_decompose``.
    :return: A ``plotly.graph_objects.Figure``.
    :raises ValueError: If *df* is ``None`` or has no columns.
    """
    # Imported locally: in this file the plotly imports appear only after the
    # first call to this function, so the module-level names are not yet bound.
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    if df is None or len(df.columns) == 0:
        raise ValueError("df must be a DataFrame with at least one column")

    row_labels = ["Observed", "Trend", "Seasonal", "Residuals"]
    # subplot_titles fill row-wise, so these label the top row (one per column).
    fig = make_subplots(rows=len(row_labels), cols=len(df.columns),
                        subplot_titles=[str(c) for c in df.columns])

    # DataFrame.iteritems() was removed in pandas 2.0; items() is equivalent.
    for col, (_, series) in enumerate(df.items(), start=1):
        decomposition = seasonal_decompose(series, model=model, period=period)
        traces = [
            (decomposition.observed, 'Observed'),
            (decomposition.trend, 'Trend'),
            (decomposition.seasonal, 'Seasonal'),
            (decomposition.resid, 'Residual'),
        ]
        for row, (values, trace_name) in enumerate(traces, start=1):
            fig.add_trace(
                go.Scatter(x=series.index, y=values, mode="lines", name=trace_name),
                row=row,
                col=col,
            )

    # Label each row once, on the left-most subplot column.
    for row, label in enumerate(row_labels, start=1):
        fig.update_yaxes(title_text=label, row=row, col=1)

    fig.update_layout(
        height=900, title=f'<b>{title}</b>', margin={'t': 100}, title_x=0.5, showlegend=False
    )

    return fig
# NOTE(review): `df` is not defined anywhere in this snippet; the caller must
# create a DataFrame of time series columns before running these lines.
fig = plot_seasonal_decompose(title="Seasonal Decomposition", df=df)
fig.show()

newplot(5)

@saravanansaminathan
Copy link

saravanansaminathan commented Aug 6, 2024

from plotly.subplots import make_subplots
import plotly.graph_objects as go
from statsmodels.tsa.seasonal import DecomposeResult, seasonal_decompose

def plot_seasonal_decompose(result: "DecomposeResult", dates: "pd.Series" = None,
                            title: str = "Seasonal Decomposition"):
    """Plot a statsmodels decomposition result as four stacked line charts.

    :param result: Output of ``seasonal_decompose``.
    :param dates: Optional x-axis values (e.g. a date column). When omitted, the
        positions ``0 .. len(result.observed) - 1`` are used.
    :param title: Figure title, centred above the subplots.
    :return: A ``plotly.graph_objects.Figure`` with the legend hidden.
    """
    # Annotations are strings because pandas is only imported after this
    # definition in the file. No ``import numpy as np`` is visible anywhere in
    # the file, so the original ``np.arange`` raised NameError without ``dates``.
    import numpy as np

    x_values = dates if dates is not None else np.arange(len(result.observed))

    # (subplot title, trace name, component values) in display order.
    panels = [
        ("Observed", "Observed", result.observed),
        ("Trend", "Trend", result.trend),
        ("Seasonal", "Seasonal", result.seasonal),
        ("Residuals", "Residual", result.resid),
    ]
    fig = make_subplots(
        rows=len(panels),
        cols=1,
        subplot_titles=[panel_title for panel_title, _, _ in panels],
    )
    for row, (_, trace_name, values) in enumerate(panels, start=1):
        fig.add_trace(
            go.Scatter(x=x_values, y=values, mode="lines", name=trace_name),
            row=row,
            col=1,
        )
    fig.update_layout(
        height=900, title=title, margin={'t': 100}, title_x=0.5, showlegend=False
    )
    return fig

import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose
# Example: monthly airline passenger counts, decomposed with a 12-month period
# and plotted against the CSV's "Month" column.
data = pd.read_csv(
    "https://raw.githubusercontent.com/swilsonmfc/pandas/main/AirPassengers.csv"
)
decomposition = seasonal_decompose(
    data['#Passengers'],
    model='additive',
    period=12,
)
fig = plot_seasonal_decompose(decomposition, dates=data['Month'])
fig.show()

Updated code combining the above comments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment