done
This commit is contained in:
@ -0,0 +1,170 @@
|
||||
"""
|
||||
The `trendline_functions` module contains functions which are called by Plotly Express
|
||||
when the `trendline` argument is used. Valid values for `trendline` are the names of the
|
||||
functions in this module, and the value of the `trendline_options` argument to PX
|
||||
functions is passed in as the first argument to these functions when called.
|
||||
|
||||
Note that the functions in this module are not meant to be called directly, and are
|
||||
exposed as part of the public API for documentation purposes.
|
||||
"""
|
||||
|
||||
__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]
|
||||
|
||||
|
||||
def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
|
||||
"""Ordinary Least Squares (OLS) trendline function
|
||||
|
||||
Requires `statsmodels` to be installed.
|
||||
|
||||
This trendline function causes fit results to be stored within the figure,
|
||||
accessible via the `plotly.express.get_trendline_results` function. The fit results
|
||||
are the output of the `statsmodels.api.OLS` function.
|
||||
|
||||
Valid keys for the `trendline_options` dict are:
|
||||
|
||||
- `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
|
||||
the origin but if `True` a y-intercept is fitted.
|
||||
|
||||
- `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
|
||||
respect to the base 10 logarithm of the input. Note that this means no zeros can
|
||||
be present in the input.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
valid_options = ["add_constant", "log_x", "log_y"]
|
||||
for k in trendline_options.keys():
|
||||
if k not in valid_options:
|
||||
raise ValueError(
|
||||
"OLS trendline_options keys must be one of [%s] but got '%s'"
|
||||
% (", ".join(valid_options), k)
|
||||
)
|
||||
|
||||
import statsmodels.api as sm
|
||||
|
||||
add_constant = trendline_options.get("add_constant", True)
|
||||
log_x = trendline_options.get("log_x", False)
|
||||
log_y = trendline_options.get("log_y", False)
|
||||
|
||||
if log_y:
|
||||
if np.any(y <= 0):
|
||||
raise ValueError(
|
||||
"Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
|
||||
)
|
||||
y = np.log10(y)
|
||||
y_label = "log10(%s)" % y_label
|
||||
if log_x:
|
||||
if np.any(x <= 0):
|
||||
raise ValueError(
|
||||
"Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
|
||||
)
|
||||
x = np.log10(x)
|
||||
x_label = "log10(%s)" % x_label
|
||||
if add_constant:
|
||||
x = sm.add_constant(x)
|
||||
fit_results = sm.OLS(y, x, missing="drop").fit()
|
||||
y_out = fit_results.predict()
|
||||
if log_y:
|
||||
y_out = np.power(10, y_out)
|
||||
hover_header = "<b>OLS trendline</b><br>"
|
||||
if len(fit_results.params) == 2:
|
||||
hover_header += "%s = %g * %s + %g<br>" % (
|
||||
y_label,
|
||||
fit_results.params[1],
|
||||
x_label,
|
||||
fit_results.params[0],
|
||||
)
|
||||
elif not add_constant:
|
||||
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
|
||||
else:
|
||||
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
|
||||
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
|
||||
return y_out, hover_header, fit_results
|
||||
|
||||
|
||||
def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
|
||||
"""LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function
|
||||
|
||||
Requires `statsmodels` to be installed.
|
||||
|
||||
Valid keys for the `trendline_options` dict are:
|
||||
|
||||
- `frac` (`float`, default `0.6666666`): the `frac` parameter from the
|
||||
`statsmodels.api.nonparametric.lowess` function
|
||||
"""
|
||||
|
||||
valid_options = ["frac"]
|
||||
for k in trendline_options.keys():
|
||||
if k not in valid_options:
|
||||
raise ValueError(
|
||||
"LOWESS trendline_options keys must be one of [%s] but got '%s'"
|
||||
% (", ".join(valid_options), k)
|
||||
)
|
||||
|
||||
import statsmodels.api as sm
|
||||
|
||||
frac = trendline_options.get("frac", 0.6666666)
|
||||
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
|
||||
hover_header = "<b>LOWESS trendline</b><br><br>"
|
||||
return y_out, hover_header, None
|
||||
|
||||
|
||||
def _pandas(mode, trendline_options, x_raw, y, non_missing):
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
msg = "Trendline requires pandas to be installed"
|
||||
raise ImportError(msg)
|
||||
|
||||
modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
|
||||
trendline_options = trendline_options.copy()
|
||||
function_name = trendline_options.pop("function", "mean")
|
||||
function_args = trendline_options.pop("function_args", dict())
|
||||
|
||||
series = pd.Series(np.copy(y), index=x_raw.to_pandas())
|
||||
|
||||
# TODO: Narwhals Series/DataFrame do not support rolling, ewm nor expanding, therefore
|
||||
# it fallbacks to pandas Series independently of the original type.
|
||||
# Plotly issue: https://github.com/plotly/plotly.py/issues/4834
|
||||
# Narwhals issue: https://github.com/narwhals-dev/narwhals/issues/1254
|
||||
agg = getattr(series, mode) # e.g. series.rolling
|
||||
agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
|
||||
function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
|
||||
y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
|
||||
y_out = y_out[non_missing]
|
||||
hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
|
||||
return y_out, hover_header, None
|
||||
|
||||
|
||||
def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
|
||||
"""Rolling trendline function
|
||||
|
||||
The value of the `function` key of the `trendline_options` dict is the function to
|
||||
use (defaults to `mean`) and the value of the `function_args` key are taken to be
|
||||
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
|
||||
keyword arguments into the `pandas.Series.rolling` function.
|
||||
"""
|
||||
return _pandas("rolling", trendline_options, x_raw, y, non_missing)
|
||||
|
||||
|
||||
def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
|
||||
"""Expanding trendline function
|
||||
|
||||
The value of the `function` key of the `trendline_options` dict is the function to
|
||||
use (defaults to `mean`) and the value of the `function_args` key are taken to be
|
||||
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
|
||||
keyword arguments into the `pandas.Series.expanding` function.
|
||||
"""
|
||||
return _pandas("expanding", trendline_options, x_raw, y, non_missing)
|
||||
|
||||
|
||||
def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
|
||||
"""Exponentially Weighted Moment (EWM) trendline function
|
||||
|
||||
The value of the `function` key of the `trendline_options` dict is the function to
|
||||
use (defaults to `mean`) and the value of the `function_args` key are taken to be
|
||||
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
|
||||
keyword arguments into the `pandas.Series.ewm` function.
|
||||
"""
|
||||
return _pandas("ewm", trendline_options, x_raw, y, non_missing)
|
Reference in New Issue
Block a user