2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions

plotly/figure_factory/_2d_density.py
@@ -0,0 +1,155 @@
from numbers import Number
import plotly.exceptions
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
def make_linear_colorscale(colors):
"""
Makes a list of colors into a colorscale-acceptable form
For documentation regarding the form of the output, see
https://plot.ly/python/reference/#mesh3d-colorscale
"""
scale = 1.0 / (len(colors) - 1)
return [[i * scale, color] for i, color in enumerate(colors)]
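For reference, a minimal sketch of the colorscale form this helper produces: three input colors become three evenly spaced stops.
>>> make_linear_colorscale(['rgb(255,0,0)', 'rgb(0,255,0)', 'rgb(0,0,255)'])
[[0.0, 'rgb(255,0,0)'], [0.5, 'rgb(0,255,0)'], [1.0, 'rgb(0,0,255)']]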
def create_2d_density(
x,
y,
colorscale="Earth",
ncontours=20,
hist_color=(0, 0, 0.5),
point_color=(0, 0, 0.5),
point_size=2,
title="2D Density Plot",
height=600,
width=600,
):
"""
**deprecated**, use instead
:func:`plotly.express.density_heatmap`.
:param (list|array) x: x-axis data for plot generation
:param (list|array) y: y-axis data for plot generation
:param (str|tuple|list) colorscale: either a plotly scale name, an rgb
or hex color, a color tuple or a list or tuple of colors. An rgb
color is of the form 'rgb(x, y, z)' where x, y, z belong to the
interval [0, 255] and a color tuple is a tuple of the form
(a, b, c) where a, b and c belong to [0, 1]. If colorscale is a
list, it must contain only the aforementioned valid color types as its
members.
:param (int) ncontours: the number of 2D contours to draw on the plot
:param (str|tuple) hist_color: the color of the plotted histograms
:param (str|tuple) point_color: the color of the scatter points
:param (int) point_size: the size of the scatter points
:param (str) title: set the title for the plot
:param (float) height: the height of the chart
:param (float) width: the width of the chart
Examples
--------
Example 1: Simple 2D Density Plot
>>> from plotly.figure_factory import create_2d_density
>>> import numpy as np
>>> # Make data points
>>> t = np.linspace(-1,1.2,2000)
>>> x = (t**3)+(0.3*np.random.randn(2000))
>>> y = (t**6)+(0.3*np.random.randn(2000))
>>> # Create a figure
>>> fig = create_2d_density(x, y)
>>> # Plot the data
>>> fig.show()
Example 2: Using Parameters
>>> from plotly.figure_factory import create_2d_density
>>> import numpy as np
>>> # Make data points
>>> t = np.linspace(-1,1.2,2000)
>>> x = (t**3)+(0.3*np.random.randn(2000))
>>> y = (t**6)+(0.3*np.random.randn(2000))
>>> # Create custom colorscale
>>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
... (1, 1, 0.2), (0.98,0.98,0.98)]
>>> # Create a figure
>>> fig = create_2d_density(x, y, colorscale=colorscale,
... hist_color='rgb(255, 237, 222)', point_size=3)
>>> # Plot the data
>>> fig.show()
"""
# validate x and y are filled with numbers only
for array in [x, y]:
if not all(isinstance(element, Number) for element in array):
raise plotly.exceptions.PlotlyError(
"All elements of your 'x' and 'y' lists must be numbers."
)
# validate x and y are the same length
if len(x) != len(y):
raise plotly.exceptions.PlotlyError(
"Both lists 'x' and 'y' must be the same length."
)
colorscale = clrs.validate_colors(colorscale, "rgb")
colorscale = make_linear_colorscale(colorscale)
# validate hist_color and point_color
hist_color = clrs.validate_colors(hist_color, "rgb")
point_color = clrs.validate_colors(point_color, "rgb")
trace1 = graph_objs.Scatter(
x=x,
y=y,
mode="markers",
name="points",
marker=dict(color=point_color[0], size=point_size, opacity=0.4),
)
trace2 = graph_objs.Histogram2dContour(
x=x,
y=y,
name="density",
ncontours=ncontours,
colorscale=colorscale,
reversescale=True,
showscale=False,
)
trace3 = graph_objs.Histogram(
x=x, name="x density", marker=dict(color=hist_color[0]), yaxis="y2"
)
trace4 = graph_objs.Histogram(
y=y, name="y density", marker=dict(color=hist_color[0]), xaxis="x2"
)
data = [trace1, trace2, trace3, trace4]
layout = graph_objs.Layout(
showlegend=False,
autosize=False,
title=title,
height=height,
width=width,
xaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
yaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
margin=dict(t=50),
hovermode="closest",
bargap=0,
xaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
yaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
)
fig = graph_objs.Figure(data=data, layout=layout)
return fig

plotly/figure_factory/__init__.py
@@ -0,0 +1,69 @@
# ruff: noqa: E402
from plotly import optional_imports
# Require that numpy exists for figure_factory
np = optional_imports.get_module("numpy")
if np is None:
raise ImportError(
"""\
The figure factory module requires the numpy package"""
)
from plotly.figure_factory._2d_density import create_2d_density
from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
from plotly.figure_factory._bullet import create_bullet
from plotly.figure_factory._candlestick import create_candlestick
from plotly.figure_factory._dendrogram import create_dendrogram
from plotly.figure_factory._distplot import create_distplot
from plotly.figure_factory._facet_grid import create_facet_grid
from plotly.figure_factory._gantt import create_gantt
from plotly.figure_factory._ohlc import create_ohlc
from plotly.figure_factory._quiver import create_quiver
from plotly.figure_factory._scatterplot import create_scatterplotmatrix
from plotly.figure_factory._streamline import create_streamline
from plotly.figure_factory._table import create_table
from plotly.figure_factory._trisurf import create_trisurf
from plotly.figure_factory._violin import create_violin
if optional_imports.get_module("pandas") is not None:
from plotly.figure_factory._county_choropleth import create_choropleth
from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox
else:
def create_choropleth(*args, **kwargs):
raise ImportError("Please install pandas to use `create_choropleth`")
def create_hexbin_mapbox(*args, **kwargs):
raise ImportError("Please install pandas to use `create_hexbin_mapbox`")
if optional_imports.get_module("skimage") is not None:
from plotly.figure_factory._ternary_contour import create_ternary_contour
else:
def create_ternary_contour(*args, **kwargs):
raise ImportError("Please install scikit-image to use `create_ternary_contour`")
__all__ = [
"create_2d_density",
"create_annotated_heatmap",
"create_bullet",
"create_candlestick",
"create_choropleth",
"create_dendrogram",
"create_distplot",
"create_facet_grid",
"create_gantt",
"create_hexbin_mapbox",
"create_ohlc",
"create_quiver",
"create_scatterplotmatrix",
"create_streamline",
"create_table",
"create_ternary_contour",
"create_trisurf",
"create_violin",
]
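A quick sketch of how the optional-import guards above behave when pandas is missing: the stub is still importable, but calling it raises the ImportError defined above (hypothetical session).
>>> from plotly.figure_factory import create_choropleth
>>> create_choropleth()  # doctest: +SKIP
Traceback (most recent call last):
    ...
ImportError: Please install pandas to use `create_choropleth`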

plotly/figure_factory/_annotated_heatmap.py
@@ -0,0 +1,307 @@
import plotly.colors as clrs
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
from plotly.validator_cache import ValidatorCache
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
def validate_annotated_heatmap(z, x, y, annotation_text):
"""
Annotated-heatmap-specific validations
Check that if a text matrix is supplied, it has the same
dimensions as the z matrix.
See FigureFactory.create_annotated_heatmap() for params
:raises: (PlotlyError) If z and text matrices do not have the same
dimensions.
"""
if annotation_text is not None and isinstance(annotation_text, list):
utils.validate_equal_length(z, annotation_text)
for lst in range(len(z)):
if len(z[lst]) != len(annotation_text[lst]):
raise exceptions.PlotlyError(
"z and text should have the same dimensions"
)
if x:
if len(x) != len(z[0]):
raise exceptions.PlotlyError(
"oops, the x list that you "
"provided does not match the "
"width of your z matrix "
)
if y:
if len(y) != len(z):
raise exceptions.PlotlyError(
"oops, the y list that you "
"provided does not match the "
"length of your z matrix "
)
def create_annotated_heatmap(
z,
x=None,
y=None,
annotation_text=None,
colorscale="Plasma",
font_colors=None,
showscale=False,
reversescale=False,
**kwargs,
):
"""
**deprecated**, use instead
:func:`plotly.express.imshow`.
Function that creates annotated heatmaps
This function adds annotations to each cell of the heatmap.
:param (list[list]|ndarray) z: z matrix to create heatmap.
:param (list) x: x axis labels.
:param (list) y: y axis labels.
:param (list[list]|ndarray) annotation_text: Text strings for
annotations. Should have the same dimensions as the z matrix. If no
text is added, the values of the z matrix are annotated. Default =
z matrix values.
:param (list|str) colorscale: heatmap colorscale.
:param (list) font_colors: List of two color strings: [min_text_color,
max_text_color] where min_text_color is applied to annotations for
heatmap values < (max_value + min_value)/2. If font_colors is not
defined, the colors are defined logically as black or white
depending on the heatmap's colorscale.
:param (bool) showscale: Display colorscale. Default = False
:param (bool) reversescale: Reverse colorscale. Default = False
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
These kwargs describe other attributes about the annotated Heatmap
trace such as the colorscale. For more information on valid kwargs
call help(plotly.graph_objs.Heatmap)
Example 1: Simple annotated heatmap with default configuration
>>> import plotly.figure_factory as ff
>>> z = [[0.300000, 0.00000, 0.65, 0.300000],
... [1, 0.100005, 0.45, 0.4300],
... [0.300000, 0.00000, 0.65, 0.300000],
... [1, 0.100005, 0.45, 0.00000]]
>>> fig = ff.create_annotated_heatmap(z)
>>> fig.show()
"""
# Avoiding mutables in the call signature
font_colors = font_colors if font_colors is not None else []
validate_annotated_heatmap(z, x, y, annotation_text)
# validate colorscale
colorscale_validator = ValidatorCache.get_validator("heatmap", "colorscale")
colorscale = colorscale_validator.validate_coerce(colorscale)
annotations = _AnnotatedHeatmap(
z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
).make_annotations()
if x or y:
trace = dict(
type="heatmap",
z=z,
x=x,
y=y,
colorscale=colorscale,
showscale=showscale,
reversescale=reversescale,
**kwargs,
)
layout = dict(
annotations=annotations,
xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
yaxis=dict(ticks="", dtick=1, ticksuffix=" "),
)
else:
trace = dict(
type="heatmap",
z=z,
colorscale=colorscale,
showscale=showscale,
reversescale=reversescale,
**kwargs,
)
layout = dict(
annotations=annotations,
xaxis=dict(
ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
),
yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False),
)
data = [trace]
return graph_objs.Figure(data=data, layout=layout)
def to_rgb_color_list(color_str, default):
color_str = color_str.strip()
if color_str.startswith("rgb"):
return [int(v) for v in color_str.strip("rgba()").split(",")]
elif color_str.startswith("#"):
return clrs.hex_to_rgb(color_str)
else:
return default
def should_use_black_text(background_color):
return (
background_color[0] * 0.299
+ background_color[1] * 0.587
+ background_color[2] * 0.114
) > 186
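A short worked example of the luminance test above (weights 0.299/0.587/0.114, cutoff 186): a bright yellow background crosses the threshold and gets black text, while a dark blue stays below it and gets white text.
>>> should_use_black_text([255, 220, 0])  # 0.299*255 + 0.587*220 + 0.114*0 ≈ 205
True
>>> should_use_black_text([0, 116, 217])  # ≈ 93, below 186
False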
class _AnnotatedHeatmap(object):
"""
Refer to TraceFactory.create_annotated_heatmap() for docstring
"""
def __init__(
self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
):
self.z = z
if x:
self.x = x
else:
self.x = range(len(z[0]))
if y:
self.y = y
else:
self.y = range(len(z))
if annotation_text is not None:
self.annotation_text = annotation_text
else:
self.annotation_text = self.z
self.colorscale = colorscale
self.reversescale = reversescale
self.font_colors = font_colors
if np and isinstance(self.z, np.ndarray):
self.zmin = np.amin(self.z)
self.zmax = np.amax(self.z)
else:
self.zmin = min([v for row in self.z for v in row])
self.zmax = max([v for row in self.z for v in row])
if kwargs.get("zmin", None) is not None:
self.zmin = kwargs["zmin"]
if kwargs.get("zmax", None) is not None:
self.zmax = kwargs["zmax"]
self.zmid = (self.zmax + self.zmin) / 2
if kwargs.get("zmid", None) is not None:
self.zmid = kwargs["zmid"]
def get_text_color(self):
"""
Get font color for annotations.
The annotated heatmap can feature two text colors: min_text_color and
max_text_color. The min_text_color is applied to annotations for
heatmap values < (max_value + min_value)/2. The user can define these
two colors. Otherwise the colors are defined logically as black or
white depending on the heatmap's colorscale.
:rtype (string, string) min_text_color, max_text_color: text
color for annotations for heatmap values <
(max_value + min_value)/2 and text color for annotations for
heatmap values >= (max_value + min_value)/2
"""
# Plotly colorscales ranging from a lighter shade to a darker shade
colorscales = [
"Greys",
"Greens",
"Blues",
"YIGnBu",
"YIOrRd",
"RdBu",
"Picnic",
"Jet",
"Hot",
"Blackbody",
"Earth",
"Electric",
"Viridis",
"Cividis",
]
# Plotly colorscales ranging from a darker shade to a lighter shade
colorscales_reverse = ["Reds"]
white = "#FFFFFF"
black = "#000000"
if self.font_colors:
min_text_color = self.font_colors[0]
max_text_color = self.font_colors[-1]
elif self.colorscale in colorscales and self.reversescale:
min_text_color = black
max_text_color = white
elif self.colorscale in colorscales:
min_text_color = white
max_text_color = black
elif self.colorscale in colorscales_reverse and self.reversescale:
min_text_color = white
max_text_color = black
elif self.colorscale in colorscales_reverse:
min_text_color = black
max_text_color = white
elif isinstance(self.colorscale, list):
min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])
# swap min/max colors if reverse scale
if self.reversescale:
min_col, max_col = max_col, min_col
if should_use_black_text(min_col):
min_text_color = black
else:
min_text_color = white
if should_use_black_text(max_col):
max_text_color = black
else:
max_text_color = white
else:
min_text_color = black
max_text_color = black
return min_text_color, max_text_color
def make_annotations(self):
"""
Get annotations for each cell of the heatmap with graph_objs.Annotation
:rtype (list[dict]) annotations: list of annotations for each cell of
the heatmap
"""
min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
annotations = []
for n, row in enumerate(self.z):
for m, val in enumerate(row):
font_color = min_text_color if val < self.zmid else max_text_color
annotations.append(
graph_objs.layout.Annotation(
text=str(self.annotation_text[n][m]),
x=self.x[m],
y=self.y[n],
xref="x1",
yref="y1",
font=dict(color=font_color),
showarrow=False,
)
)
return annotations
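As a supplement to the docstring example above, a minimal sketch of overriding the automatic text colors via font_colors (two colors: one for low values, one for high values); the data here is made up.
>>> import plotly.figure_factory as ff
>>> z = [[0.1, 0.9], [0.6, 0.3]]
>>> fig = ff.create_annotated_heatmap(z, colorscale='Viridis',
...                                   font_colors=['white', 'black'])
>>> fig.show()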

plotly/figure_factory/_bullet.py
@@ -0,0 +1,366 @@
import math
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.figure_factory import utils
import plotly
import plotly.graph_objs as go
pd = optional_imports.get_module("pandas")
def _bullet(
df,
markers,
measures,
ranges,
subtitles,
titles,
orientation,
range_colors,
measure_colors,
horizontal_spacing,
vertical_spacing,
scatter_options,
layout_options,
):
num_of_lanes = len(df)
num_of_rows = num_of_lanes if orientation == "h" else 1
num_of_cols = 1 if orientation == "h" else num_of_lanes
if not horizontal_spacing:
horizontal_spacing = 1.0 / num_of_lanes
if not vertical_spacing:
vertical_spacing = 1.0 / num_of_lanes
fig = plotly.subplots.make_subplots(
num_of_rows,
num_of_cols,
print_grid=False,
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
)
# layout
fig["layout"].update(
dict(shapes=[]),
title="Bullet Chart",
height=600,
width=1000,
showlegend=False,
barmode="stack",
annotations=[],
margin=dict(l=120 if orientation == "h" else 80),
)
# update layout
fig["layout"].update(layout_options)
if orientation == "h":
width_axis = "yaxis"
length_axis = "xaxis"
else:
width_axis = "xaxis"
length_axis = "yaxis"
for key in fig["layout"]:
if "xaxis" in key or "yaxis" in key:
fig["layout"][key]["showgrid"] = False
fig["layout"][key]["zeroline"] = False
if length_axis in key:
fig["layout"][key]["tickwidth"] = 1
if width_axis in key:
fig["layout"][key]["showticklabels"] = False
fig["layout"][key]["range"] = [0, 1]
# narrow domain if 1 bar
if num_of_lanes <= 1:
fig["layout"][width_axis + "1"]["domain"] = [0.4, 0.6]
if not range_colors:
range_colors = ["rgb(200, 200, 200)", "rgb(245, 245, 245)"]
if not measure_colors:
measure_colors = ["rgb(31, 119, 180)", "rgb(176, 196, 221)"]
for row in range(num_of_lanes):
# ranges bars
for idx in range(len(df.iloc[row]["ranges"])):
inter_colors = clrs.n_colors(
range_colors[0], range_colors[1], len(df.iloc[row]["ranges"]), "rgb"
)
x = (
[sorted(df.iloc[row]["ranges"])[-1 - idx]]
if orientation == "h"
else [0]
)
y = (
[0]
if orientation == "h"
else [sorted(df.iloc[row]["ranges"])[-1 - idx]]
)
bar = go.Bar(
x=x,
y=y,
marker=dict(color=inter_colors[-1 - idx]),
name="ranges",
hoverinfo="x" if orientation == "h" else "y",
orientation=orientation,
width=2,
base=0,
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
)
fig.add_trace(bar)
# measures bars
for idx in range(len(df.iloc[row]["measures"])):
inter_colors = clrs.n_colors(
measure_colors[0],
measure_colors[1],
len(df.iloc[row]["measures"]),
"rgb",
)
x = (
[sorted(df.iloc[row]["measures"])[-1 - idx]]
if orientation == "h"
else [0.5]
)
y = (
[0.5]
if orientation == "h"
else [sorted(df.iloc[row]["measures"])[-1 - idx]]
)
bar = go.Bar(
x=x,
y=y,
marker=dict(color=inter_colors[-1 - idx]),
name="measures",
hoverinfo="x" if orientation == "h" else "y",
orientation=orientation,
width=0.4,
base=0,
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
)
fig.add_trace(bar)
# markers
x = df.iloc[row]["markers"] if orientation == "h" else [0.5]
y = [0.5] if orientation == "h" else df.iloc[row]["markers"]
markers = go.Scatter(
x=x,
y=y,
name="markers",
hoverinfo="x" if orientation == "h" else "y",
xaxis="x{}".format(row + 1),
yaxis="y{}".format(row + 1),
**scatter_options,
)
fig.add_trace(markers)
# titles and subtitles
title = df.iloc[row]["titles"]
if "subtitles" in df:
subtitle = "<br>{}".format(df.iloc[row]["subtitles"])
else:
subtitle = ""
label = "<b>{}</b>".format(title) + subtitle
annot = utils.annotation_dict_for_label(
label,
(num_of_lanes - row if orientation == "h" else row + 1),
num_of_lanes,
vertical_spacing if orientation == "h" else horizontal_spacing,
"row" if orientation == "h" else "col",
True if orientation == "h" else False,
False,
)
fig["layout"]["annotations"] += (annot,)
return fig
def create_bullet(
data,
markers=None,
measures=None,
ranges=None,
subtitles=None,
titles=None,
orientation="h",
range_colors=("rgb(200, 200, 200)", "rgb(245, 245, 245)"),
measure_colors=("rgb(31, 119, 180)", "rgb(176, 196, 221)"),
horizontal_spacing=None,
vertical_spacing=None,
scatter_options={},
**layout_options,
):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Indicator`.
:param (pd.DataFrame | list | tuple) data: either a list/tuple of
dictionaries or a pandas DataFrame.
:param (str) markers: the column name or dictionary key for the markers in
each subplot.
:param (str) measures: the column name or dictionary key for the measure
bars in each subplot. This bar usually represents the quantitative
measure of performance, usually a list of two values [a, b] and are
the blue bars in the foreground of each subplot by default.
:param (str) ranges: the column name or dictionary key for the qualitative
ranges of performance, usually a 3-item list [bad, okay, good]. They
correspond to the grey bars in the background of each chart.
:param (str) subtitles: the column name or dictionary key for the subtitle
of each subplot chart. The subplots are displayed right underneath
each title.
:param (str) titles: the column name or dictionary key for the main label
of each subplot chart.
:param (str) orientation: if 'h', the bars are placed horizontally as
rows. If 'v' the bars are placed vertically in the chart.
:param (list) range_colors: a tuple of two colors between which all
the rectangles for the range are drawn. These rectangles are meant to
be qualitative indicators against which the marker and measure bars
are compared.
Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
:param (list) measure_colors: a tuple of two colors which is used to color
the thin quantitative bars in the bullet chart.
Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
:param (float) horizontal_spacing: see the 'horizontal_spacing' param in
plotly.tools.make_subplots. Ranges between 0 and 1.
:param (float) vertical_spacing: see the 'vertical_spacing' param in
plotly.tools.make_subplots. Ranges between 0 and 1.
:param (dict) scatter_options: describes attributes for the scatter trace
in each subplot such as name and marker size. Call
help(plotly.graph_objs.Scatter) for more information on valid params.
:param layout_options: describes attributes for the layout of the figure
such as title, height and width. Call help(plotly.graph_objs.Layout)
for more information on valid params.
Example 1: Use a Dictionary
>>> import plotly.figure_factory as ff
>>> data = [
... {"label": "revenue", "sublabel": "us$, in thousands",
... "range": [150, 225, 300], "performance": [220,270], "point": [250]},
... {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
... "performance": [21, 23], "point": [26]},
... {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
... "performance": [100,320],"point": [550]},
... {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
... "performance": [1000, 1650],"point": [2100]},
... {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
... "performance": [3.2, 4.7], "point": [4.4]}
... ]
>>> fig = ff.create_bullet(
... data, titles='label', subtitles='sublabel', markers='point',
... measures='performance', ranges='range', orientation='h',
... title='my simple bullet chart'
... )
>>> fig.show()
Example 2: Use a DataFrame with Custom Colors
>>> import plotly.figure_factory as ff
>>> import pandas as pd
>>> data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')
>>> fig = ff.create_bullet(
... data, titles='title', markers='markers', measures='measures',
... orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
... scatter_options={'marker': {'symbol': 'circle'}}, width=700)
>>> fig.show()
"""
# validate df
if not pd:
raise ImportError("'pandas' must be installed for this figure factory.")
if utils.is_sequence(data):
if not all(isinstance(item, dict) for item in data):
raise exceptions.PlotlyError(
"Every entry of the data argument list, tuple, etc must "
"be a dictionary."
)
elif not isinstance(data, pd.DataFrame):
raise exceptions.PlotlyError(
"You must input a pandas DataFrame, or a list of dictionaries."
)
# make DataFrame from data with correct column headers
col_names = ["titles", "subtitle", "markers", "measures", "ranges"]
if utils.is_sequence(data):
df = pd.DataFrame(
[
[d[titles] for d in data] if titles else [""] * len(data),
[d[subtitles] for d in data] if subtitles else [""] * len(data),
[d[markers] for d in data] if markers else [[]] * len(data),
[d[measures] for d in data] if measures else [[]] * len(data),
[d[ranges] for d in data] if ranges else [[]] * len(data),
],
index=col_names,
)
elif isinstance(data, pd.DataFrame):
df = pd.DataFrame(
[
data[titles].tolist() if titles else [""] * len(data),
data[subtitles].tolist() if subtitles else [""] * len(data),
data[markers].tolist() if markers else [[]] * len(data),
data[measures].tolist() if measures else [[]] * len(data),
data[ranges].tolist() if ranges else [[]] * len(data),
],
index=col_names,
)
df = pd.DataFrame.transpose(df)
# make sure ranges, measures, 'markers' are not NAN or NONE
for needed_key in ["ranges", "measures", "markers"]:
for idx, r in enumerate(df[needed_key]):
try:
r_is_nan = math.isnan(r)
if r_is_nan or r is None:
df[needed_key][idx] = []
except TypeError:
pass
# validate custom colors
for colors_list in [range_colors, measure_colors]:
if colors_list:
if len(colors_list) != 2:
raise exceptions.PlotlyError(
"Both 'range_colors' or 'measure_colors' must be a list "
"of two valid colors."
)
clrs.validate_colors(colors_list)
colors_list = clrs.convert_colors_to_same_type(colors_list, "rgb")[0]
# default scatter options
default_scatter = {
"marker": {"size": 12, "symbol": "diamond-tall", "color": "rgb(0, 0, 0)"}
}
if scatter_options == {}:
scatter_options.update(default_scatter)
else:
# add default options to scatter_options if they are not present
for k in default_scatter["marker"]:
if k not in scatter_options["marker"]:
scatter_options["marker"][k] = default_scatter["marker"][k]
fig = _bullet(
df,
markers,
measures,
ranges,
subtitles,
titles,
orientation,
range_colors,
measure_colors,
horizontal_spacing,
vertical_spacing,
scatter_options,
layout_options,
)
return fig
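One detail worth illustrating from the code above: a user-supplied scatter_options dict is merged with the defaults, so a partially specified marker keeps the default size and color. A minimal sketch with made-up data:
>>> import plotly.figure_factory as ff
>>> data = [{"label": "revenue", "range": [150, 225, 300],
...          "performance": [220, 270], "point": [250]}]
>>> fig = ff.create_bullet(data, titles='label', markers='point',
...                        measures='performance', ranges='range',
...                        scatter_options={'marker': {'symbol': 'circle'}})
>>> fig.data[-1].marker  # doctest: +SKIP
scatter.Marker({'color': 'rgb(0, 0, 0)', 'size': 12, 'symbol': 'circle'})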

plotly/figure_factory/_candlestick.py
@@ -0,0 +1,277 @@
from plotly.figure_factory import utils
from plotly.figure_factory._ohlc import (
_DEFAULT_INCREASING_COLOR,
_DEFAULT_DECREASING_COLOR,
validate_ohlc,
)
from plotly.graph_objs import graph_objs
def make_increasing_candle(open, high, low, close, dates, **kwargs):
"""
Makes boxplot trace for increasing candlesticks
make_increasing_candle() and make_decreasing_candle() separate the
increasing traces from the decreasing traces so kwargs (such as
color) can be passed separately to increasing or decreasing traces
when direction is set to 'increasing' or 'decreasing' in
FigureFactory.create_candlestick()
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (list) candle_incr_data: list of the box trace for
increasing candlesticks.
"""
increase_x, increase_y = _Candlestick(
open, high, low, close, dates, **kwargs
).get_candle_increase()
if "line" in kwargs:
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
else:
kwargs.setdefault("fillcolor", _DEFAULT_INCREASING_COLOR)
if "name" in kwargs:
kwargs.setdefault("showlegend", True)
else:
kwargs.setdefault("showlegend", False)
kwargs.setdefault("name", "Increasing")
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR))
candle_incr_data = dict(
type="box",
x=increase_x,
y=increase_y,
whiskerwidth=0,
boxpoints=False,
**kwargs,
)
return [candle_incr_data]
def make_decreasing_candle(open, high, low, close, dates, **kwargs):
"""
Makes boxplot trace for decreasing candlesticks
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to decreasing trace via
plotly.graph_objs.Scatter.
:rtype (list) candle_decr_data: list of the box trace for
decreasing candlesticks.
"""
decrease_x, decrease_y = _Candlestick(
open, high, low, close, dates, **kwargs
).get_candle_decrease()
if "line" in kwargs:
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
else:
kwargs.setdefault("fillcolor", _DEFAULT_DECREASING_COLOR)
kwargs.setdefault("showlegend", False)
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR))
kwargs.setdefault("name", "Decreasing")
candle_decr_data = dict(
type="box",
x=decrease_x,
y=decrease_y,
whiskerwidth=0,
boxpoints=False,
**kwargs,
)
return [candle_decr_data]
def create_candlestick(open, high, low, close, dates=None, direction="both", **kwargs):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Candlestick`
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param (string) direction: direction can be 'increasing', 'decreasing',
or 'both'. When the direction is 'increasing', the returned figure
consists of all candlesticks where the close value is greater than
the corresponding open value, and when the direction is
'decreasing', the returned figure consists of all candlesticks
where the close value is less than or equal to the corresponding
open value. When the direction is 'both', both increasing and
decreasing candlesticks are returned. Default: 'both'
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
These kwargs describe other attributes about the ohlc Scatter trace
such as the color or the legend name. For more information on valid
kwargs call help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of candlestick chart figure.
Example 1: Simple candlestick chart from a Pandas DataFrame
>>> from plotly.figure_factory import create_candlestick
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index)
>>> fig.show()
Example 2: Customize the candlestick colors
>>> from plotly.figure_factory import create_candlestick
>>> from plotly.graph_objs import Line, Marker
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> # Make increasing candlesticks and customize their color and name
>>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index,
... direction='increasing', name='AAPL',
... marker=Marker(color='rgb(150, 200, 250)'),
... line=Line(color='rgb(150, 200, 250)'))
>>> # Make decreasing candlesticks and customize their color and name
>>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
... dates=df.index,
... direction='decreasing',
... marker=Marker(color='rgb(128, 128, 128)'),
... line=Line(color='rgb(128, 128, 128)'))
>>> # Initialize the figure
>>> fig = fig_increasing
>>> # Add the decreasing candlestick data to the figure
>>> fig.add_trace(fig_decreasing['data']) # doctest: +SKIP
>>> fig.show()
Example 3: Candlestick chart with datetime objects
>>> from plotly.figure_factory import create_candlestick
>>> from datetime import datetime
>>> # Add data
>>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
>>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
>>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
>>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
>>> dates = [datetime(year=2013, month=10, day=10),
... datetime(year=2013, month=11, day=10),
... datetime(year=2013, month=12, day=10),
... datetime(year=2014, month=1, day=10),
... datetime(year=2014, month=2, day=10)]
>>> # Create ohlc
>>> fig = create_candlestick(open_data, high_data,
... low_data, close_data, dates=dates)
>>> fig.show()
"""
if dates is not None:
utils.validate_equal_length(open, high, low, close, dates)
else:
utils.validate_equal_length(open, high, low, close)
validate_ohlc(open, high, low, close, direction, **kwargs)
if direction == "increasing":
candle_incr_data = make_increasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_incr_data
elif direction == "decreasing":
candle_decr_data = make_decreasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_decr_data
else:
candle_incr_data = make_increasing_candle(
open, high, low, close, dates, **kwargs
)
candle_decr_data = make_decreasing_candle(
open, high, low, close, dates, **kwargs
)
data = candle_incr_data + candle_decr_data
layout = graph_objs.Layout()
return graph_objs.Figure(data=data, layout=layout)
class _Candlestick(object):
"""
Refer to FigureFactory.create_candlestick() for docstring.
"""
def __init__(self, open, high, low, close, dates, **kwargs):
self.open = open
self.high = high
self.low = low
self.close = close
if dates is not None:
self.x = dates
else:
self.x = [x for x in range(len(self.open))]
self.get_candle_increase()
def get_candle_increase(self):
"""
Separate increasing data from decreasing data.
The data is increasing when close value > open value
and decreasing when the close value <= open value.
"""
increase_y = []
increase_x = []
for index in range(len(self.open)):
if self.close[index] > self.open[index]:
increase_y.append(self.low[index])
increase_y.append(self.open[index])
increase_y.append(self.close[index])
increase_y.append(self.close[index])
increase_y.append(self.close[index])
increase_y.append(self.high[index])
increase_x.append(self.x[index])
increase_x = [[x, x, x, x, x, x] for x in increase_x]
increase_x = utils.flatten(increase_x)
return increase_x, increase_y
def get_candle_decrease(self):
"""
Separate increasing data from decreasing data.
The data is increasing when close value > open value
and decreasing when the close value <= open value.
"""
decrease_y = []
decrease_x = []
for index in range(len(self.open)):
if self.close[index] <= self.open[index]:
decrease_y.append(self.low[index])
decrease_y.append(self.open[index])
decrease_y.append(self.close[index])
decrease_y.append(self.close[index])
decrease_y.append(self.close[index])
decrease_y.append(self.high[index])
decrease_x.append(self.x[index])
decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
decrease_x = utils.flatten(decrease_x)
return decrease_x, decrease_y
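To make the box-trace encoding above concrete: each candle contributes six y values [low, open, close, close, close, high] against six copies of its x position, so the whiskers span low to high and the repeated close values pull the box and median toward the open/close range. A minimal sketch for one increasing candle:
>>> cs = _Candlestick([33.0], [33.6], [32.8], [33.3], dates=None)
>>> cs.get_candle_increase()
([0, 0, 0, 0, 0, 0], [32.8, 33.0, 33.3, 33.3, 33.3, 33.6])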

File diff suppressed because it is too large

plotly/figure_factory/_dendrogram.py
@@ -0,0 +1,395 @@
from collections import OrderedDict
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
scp = optional_imports.get_module("scipy")
sch = optional_imports.get_module("scipy.cluster.hierarchy")
scs = optional_imports.get_module("scipy.spatial")
def create_dendrogram(
X,
orientation="bottom",
labels=None,
colorscale=None,
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
):
"""
Function that returns a dendrogram Plotly figure object. This is a thin
wrapper around scipy.cluster.hierarchy.dendrogram.
See also https://dash.plot.ly/dash-bio/clustergram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
:param (list) labels: List of axis category labels(observation labels)
:param (list) colorscale: Optional colorscale for the dendrogram tree.
Requires 8 colors to be specified, the 7th of
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
and 6th are used twice as often as the others.
Given a shorter list, the missing values are
replaced with defaults and with a longer list the
extra values are ignored.
:param (function) distfun: Function to compute the pairwise distance from
the observations
:param (function) linkagefun: Function to compute the linkage matrix from
the pairwise distances
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
clusters
:param (double) color_threshold: Value at which the separation of clusters will be made
Example 1: Simple bottom oriented dendrogram
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(10,10)
>>> fig = create_dendrogram(X)
>>> fig.show()
Example 2: Dendrogram to put on the left of the heatmap
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(5,5)
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
>>> dendro.show()
Example 3: Dendrogram with Pandas
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> import pandas as pd
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
>>> fig = create_dendrogram(df, labels=Index)
>>> fig.show()
"""
if not scp or not scs or not sch:
raise ImportError(
"FigureFactory.create_dendrogram requires scipy, \
scipy.spatial and scipy.cluster.hierarchy"
)
s = X.shape
if len(s) != 2:
exceptions.PlotlyError("X should be 2-dimensional array.")
if distfun is None:
distfun = scs.distance.pdist
dendrogram = _Dendrogram(
X,
orientation,
labels,
colorscale,
distfun=distfun,
linkagefun=linkagefun,
hovertext=hovertext,
color_threshold=color_threshold,
)
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
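A short sketch of the distfun/linkagefun hooks, which the docstring describes but does not demonstrate; this assumes scipy is installed and uses city-block distance with average linkage on random data.
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> from scipy.spatial.distance import pdist
>>> from scipy.cluster import hierarchy
>>> X = np.random.rand(10, 10)
>>> fig = create_dendrogram(
...     X,
...     distfun=lambda x: pdist(x, metric='cityblock'),
...     linkagefun=lambda d: hierarchy.linkage(d, 'average'))
>>> fig.show()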
class _Dendrogram(object):
"""Refer to FigureFactory.create_dendrogram() for docstring."""
def __init__(
self,
X,
orientation="bottom",
labels=None,
colorscale=None,
width=np.inf,
height=np.inf,
xaxis="xaxis",
yaxis="yaxis",
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
):
self.orientation = orientation
self.labels = labels
self.xaxis = xaxis
self.yaxis = yaxis
self.data = []
self.leaves = []
self.sign = {self.xaxis: 1, self.yaxis: 1}
self.layout = {self.xaxis: {}, self.yaxis: {}}
if self.orientation in ["left", "bottom"]:
self.sign[self.xaxis] = 1
else:
self.sign[self.xaxis] = -1
if self.orientation in ["right", "bottom"]:
self.sign[self.yaxis] = 1
else:
self.sign[self.yaxis] = -1
if distfun is None:
distfun = scs.distance.pdist
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
X, colorscale, distfun, linkagefun, hovertext, color_threshold
)
self.labels = ordered_labels
self.leaves = leaves
yvals_flat = yvals.flatten()
xvals_flat = xvals.flatten()
self.zero_vals = []
for i in range(len(yvals_flat)):
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
self.zero_vals.append(xvals_flat[i])
if len(self.zero_vals) > len(yvals) + 1:
# If zero_vals has more entries than there are leaves, some of the
# values are spurious and come from identical samples: three or more
# identical samples drive the y value of their splitting center to 0,
# so those centers are mistakenly picked up as leaves.
l_border = int(min(self.zero_vals))
r_border = int(max(self.zero_vals))
correct_leaves_pos = range(
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
)
# Regenerate the leaf positions from self.zero_vals at equal intervals.
self.zero_vals = [v for v in correct_leaves_pos]
self.zero_vals.sort()
self.layout = self.set_figure_layout(width, height)
self.data = dd_traces
def get_color_dict(self, colorscale):
"""
Returns colorscale used for dendrogram tree clusters.
:param (list) colorscale: Colors to use for the plot in rgb format.
:rtype (dict): A dict of default colors mapped to the user colorscale.
"""
# These are the color codes returned for dendrograms
# We're replacing them with nicer colors
# This list is the colors that can be used by dendrogram, which were
# determined as the combination of the default above_threshold_color and
# the default color palette (see scipy/cluster/hierarchy.py)
d = {
"r": "red",
"g": "green",
"b": "blue",
"c": "cyan",
"m": "magenta",
"y": "yellow",
"k": "black",
# TODO: 'w' doesn't seem to be in the default color
# palette in scipy/cluster/hierarchy.py
"w": "white",
}
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
if colorscale is None:
rgb_colorscale = [
"rgb(0,116,217)", # blue
"rgb(35,205,205)", # cyan
"rgb(61,153,112)", # green
"rgb(40,35,35)", # black
"rgb(133,20,75)", # magenta
"rgb(255,65,54)", # red
"rgb(255,255,255)", # white
"rgb(255,220,0)", # yellow
]
else:
rgb_colorscale = colorscale
for i in range(len(default_colors.keys())):
k = list(default_colors.keys())[i] # PY3 won't index keys
if i < len(rgb_colorscale):
default_colors[k] = rgb_colorscale[i]
# add support for the cyclic color format introduced in scipy >= 1.5.0:
# before that, the colors were named 'r', 'b', 'y', etc.; now they are
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
# scipy version, we map the new color names to the old ones as far as
# possible. This mapping was found by inspecting scipy/cluster/hierarchy.py
# (see comment above).
new_old_color_map = [
("C0", "b"),
("C1", "g"),
("C2", "r"),
("C3", "c"),
("C4", "m"),
("C5", "y"),
("C6", "k"),
("C7", "g"),
("C8", "r"),
("C9", "c"),
]
for nc, oc in new_old_color_map:
try:
default_colors[nc] = default_colors[oc]
except KeyError:
# it could happen that the old color isn't found (if a custom
# colorscale was specified), in this case we set it to an
# arbitrary default.
default_colors[nc] = "rgb(0,116,217)"
return default_colors
def set_axis_layout(self, axis_key):
"""
Sets and returns default axis object for dendrogram figure.
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', 'yaxis1', etc.
:rtype (dict): An axis_key dictionary with set parameters.
"""
axis_defaults = {
"type": "linear",
"ticks": "outside",
"mirror": "allticks",
"rangemode": "tozero",
"showticklabels": True,
"zeroline": False,
"showgrid": False,
"showline": True,
}
if len(self.labels) != 0:
axis_key_labels = self.xaxis
if self.orientation in ["left", "right"]:
axis_key_labels = self.yaxis
if axis_key_labels not in self.layout:
self.layout[axis_key_labels] = {}
self.layout[axis_key_labels]["tickvals"] = [
zv * self.sign[axis_key] for zv in self.zero_vals
]
self.layout[axis_key_labels]["ticktext"] = self.labels
self.layout[axis_key_labels]["tickmode"] = "array"
self.layout[axis_key].update(axis_defaults)
return self.layout[axis_key]
def set_figure_layout(self, width, height):
"""
Sets and returns default layout object for dendrogram figure.
"""
self.layout.update(
{
"showlegend": False,
"autosize": False,
"hovermode": "closest",
"width": width,
"height": height,
}
)
self.set_axis_layout(self.xaxis)
self.set_axis_layout(self.yaxis)
return self.layout
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
)
icoord = np.array(P["icoord"])
dcoord = np.array(P["dcoord"])
ordered_labels = np.array(P["ivl"])
color_list = np.array(P["color_list"])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ["top", "bottom"]:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ["top", "bottom"]:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
hovertext_label = None
if hovertext:
hovertext_label = hovertext[i]
trace = dict(
type="scatter",
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode="lines",
marker=dict(color=colors[color_key]),
text=hovertext_label,
hoverinfo="text",
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ""
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ""
trace["xaxis"] = f"x{x_index}"
trace["yaxis"] = f"y{y_index}"
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]

plotly/figure_factory/_distplot.py
@@ -0,0 +1,441 @@
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
pd = optional_imports.get_module("pandas")
scipy = optional_imports.get_module("scipy")
scipy_stats = optional_imports.get_module("scipy.stats")
DEFAULT_HISTNORM = "probability density"
ALTERNATIVE_HISTNORM = "probability"
def validate_distplot(hist_data, curve_type):
"""
Distplot-specific validations
:raises: (PlotlyError) If hist_data is not a list of lists
:raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
'normal').
"""
hist_data_types = (list,)
if np:
hist_data_types += (np.ndarray,)
if pd:
hist_data_types += (pd.core.series.Series,)
if not isinstance(hist_data[0], hist_data_types):
raise exceptions.PlotlyError(
"Oops, this function was written "
"to handle multiple datasets, if "
"you want to plot just one, make "
"sure your hist_data variable is "
"still a list of lists, i.e. x = "
"[1, 2, 3] -> x = [[1, 2, 3]]"
)
curve_opts = ("kde", "normal")
if curve_type not in curve_opts:
raise exceptions.PlotlyError("curve_type must be defined as 'kde' or 'normal'")
if not scipy:
raise ImportError("FigureFactory.create_distplot requires scipy")
def create_distplot(
hist_data,
group_labels,
bin_size=1.0,
curve_type="kde",
colors=None,
rug_text=None,
histnorm=DEFAULT_HISTNORM,
show_hist=True,
show_curve=True,
show_rug=True,
):
"""
Function that creates a distplot similar to seaborn.distplot;
**this function is deprecated**, use instead :mod:`plotly.express`
functions, for example
>>> import plotly.express as px
>>> tips = px.data.tips()
>>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
... hover_data=tips.columns)
>>> fig.show()
The distplot can be composed of all or any combination of the following
3 components: (1) histogram, (2) curve: (a) kernel density estimation
or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
(from multiple datasets) can be created in the same plot.
:param (list[list]) hist_data: Use list of lists to plot multiple data
sets on the same plot.
:param (list[str]) group_labels: Names for each data set.
:param (list[float]|float) bin_size: Size of histogram bins.
Default = 1.
:param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
:param (str) histnorm: 'probability density' or 'probability'
Default = 'probability density'
:param (bool) show_hist: Add histogram to distplot? Default = True
:param (bool) show_curve: Add curve to distplot? Default = True
:param (bool) show_rug: Add rug to distplot? Default = True
:param (list[str]) colors: Colors for traces.
:param (list[list]) rug_text: Hovertext values for the rug plot.
:return (dict): Representation of a distplot figure.
Example 1: Simple distplot of 1 data set
>>> from plotly.figure_factory import create_distplot
>>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
... 3.5, 4.1, 4.4, 4.5, 4.5,
... 5.0, 5.0, 5.2, 5.5, 5.5,
... 5.5, 5.5, 5.5, 6.1, 7.0]]
>>> group_labels = ['distplot example']
>>> fig = create_distplot(hist_data, group_labels)
>>> fig.show()
Example 2: Two data sets and added rug text
>>> from plotly.figure_factory import create_distplot
>>> # Add histogram data
>>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
... -0.9, -0.07, 1.95, 0.9, -0.2,
... -0.5, 0.3, 0.4, -0.37, 0.6]
>>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
... 1.0, 0.8, 1.7, 0.5, 0.8,
... -0.3, 1.2, 0.56, 0.3, 2.2]
>>> # Group data together
>>> hist_data = [hist1_x, hist2_x]
>>> group_labels = ['2012', '2013']
>>> # Add text
>>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
... 'f1', 'g1', 'h1', 'i1', 'j1',
... 'k1', 'l1', 'm1', 'n1', 'o1']
>>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
... 'f2', 'g2', 'h2', 'i2', 'j2',
... 'k2', 'l2', 'm2', 'n2', 'o2']
>>> # Group text together
>>> rug_text_all = [rug_text_1, rug_text_2]
>>> # Create distplot
>>> fig = create_distplot(
... hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
>>> # Add title
>>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
>>> fig.show()
Example 3: Plot with normal curve and hide rug plot
>>> from plotly.figure_factory import create_distplot
>>> import numpy as np
>>> x1 = np.random.randn(190)
>>> x2 = np.random.randn(200)+1
>>> x3 = np.random.randn(200)-1
>>> x4 = np.random.randn(210)+2
>>> hist_data = [x1, x2, x3, x4]
>>> group_labels = ['2012', '2013', '2014', '2015']
>>> fig = create_distplot(
... hist_data, group_labels, curve_type='normal',
... show_rug=False, bin_size=.4)
Example 4: Distplot with Pandas
>>> from plotly.figure_factory import create_distplot
>>> import numpy as np
>>> import pandas as pd
>>> df = pd.DataFrame({'2012': np.random.randn(200),
... '2013': np.random.randn(200)+1})
>>> fig = create_distplot([df[c] for c in df.columns], df.columns)
>>> fig.show()
"""
if colors is None:
colors = []
if rug_text is None:
rug_text = []
validate_distplot(hist_data, curve_type)
utils.validate_equal_length(hist_data, group_labels)
if isinstance(bin_size, (float, int)):
bin_size = [bin_size] * len(hist_data)
data = []
if show_hist:
hist = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_hist()
data.append(hist)
if show_curve:
if curve_type == "normal":
curve = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_normal()
else:
curve = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_kde()
data.append(curve)
if show_rug:
rug = _Distplot(
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
).make_rug()
data.append(rug)
layout = graph_objs.Layout(
barmode="overlay",
hovermode="closest",
legend=dict(traceorder="reversed"),
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
yaxis1=dict(domain=[0.35, 1], anchor="free", position=0.0),
yaxis2=dict(domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False),
)
else:
layout = graph_objs.Layout(
barmode="overlay",
hovermode="closest",
legend=dict(traceorder="reversed"),
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
yaxis1=dict(domain=[0.0, 1], anchor="free", position=0.0),
)
data = sum(data, [])
return graph_objs.Figure(data=data, layout=layout)
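A brief supplement to the examples above: bin_size may be a per-trace list, and histnorm='probability' triggers the ALTERNATIVE_HISTNORM scaling of the curves below. The data and group labels here are arbitrary.
>>> from plotly.figure_factory import create_distplot
>>> import numpy as np
>>> hist_data = [np.random.randn(200), np.random.randn(200) + 2]
>>> group_labels = ['control', 'treatment']
>>> fig = create_distplot(hist_data, group_labels,
...                       bin_size=[0.25, 0.5], histnorm='probability')
>>> fig.show()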
class _Distplot(object):
"""
Refer to TraceFactory.create_distplot() for docstring
"""
def __init__(
self,
hist_data,
histnorm,
group_labels,
bin_size,
curve_type,
colors,
rug_text,
show_hist,
show_curve,
):
self.hist_data = hist_data
self.histnorm = histnorm
self.group_labels = group_labels
self.bin_size = bin_size
self.show_hist = show_hist
self.show_curve = show_curve
self.trace_number = len(hist_data)
if rug_text:
self.rug_text = rug_text
else:
self.rug_text = [None] * self.trace_number
self.start = []
self.end = []
if colors:
self.colors = colors
else:
self.colors = [
"rgb(31, 119, 180)",
"rgb(255, 127, 14)",
"rgb(44, 160, 44)",
"rgb(214, 39, 40)",
"rgb(148, 103, 189)",
"rgb(140, 86, 75)",
"rgb(227, 119, 194)",
"rgb(127, 127, 127)",
"rgb(188, 189, 34)",
"rgb(23, 190, 207)",
]
self.curve_x = [None] * self.trace_number
self.curve_y = [None] * self.trace_number
for trace in self.hist_data:
self.start.append(min(trace) * 1.0)
self.end.append(max(trace) * 1.0)
def make_hist(self):
"""
Makes the histogram(s) for FigureFactory.create_distplot().
:rtype (list) hist: list of histogram representations
"""
hist = [None] * self.trace_number
for index in range(self.trace_number):
hist[index] = dict(
type="histogram",
x=self.hist_data[index],
xaxis="x1",
yaxis="y1",
histnorm=self.histnorm,
name=self.group_labels[index],
legendgroup=self.group_labels[index],
marker=dict(color=self.colors[index % len(self.colors)]),
autobinx=False,
xbins=dict(
start=self.start[index],
end=self.end[index],
size=self.bin_size[index],
),
opacity=0.7,
)
return hist
def make_kde(self):
"""
Makes the kernel density estimation(s) for create_distplot().
This is called when curve_type = 'kde' in create_distplot().
:rtype (list) curve: list of kde representations
"""
curve = [None] * self.trace_number
for index in range(self.trace_number):
self.curve_x[index] = [
self.start[index] + x * (self.end[index] - self.start[index]) / 500
for x in range(500)
]
self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])(
self.curve_x[index]
)
if self.histnorm == ALTERNATIVE_HISTNORM:
self.curve_y[index] *= self.bin_size[index]
for index in range(self.trace_number):
curve[index] = dict(
type="scatter",
x=self.curve_x[index],
y=self.curve_y[index],
xaxis="x1",
yaxis="y1",
mode="lines",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=False if self.show_hist else True,
marker=dict(color=self.colors[index % len(self.colors)]),
)
return curve
def make_normal(self):
"""
Makes the normal curve(s) for create_distplot().
This is called when curve_type = 'normal' in create_distplot().
:rtype (list) curve: list of normal curve representations
"""
curve = [None] * self.trace_number
mean = [None] * self.trace_number
sd = [None] * self.trace_number
for index in range(self.trace_number):
mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index])
self.curve_x[index] = [
self.start[index] + x * (self.end[index] - self.start[index]) / 500
for x in range(500)
]
self.curve_y[index] = scipy_stats.norm.pdf(
self.curve_x[index], loc=mean[index], scale=sd[index]
)
if self.histnorm == ALTERNATIVE_HISTNORM:
self.curve_y[index] *= self.bin_size[index]
for index in range(self.trace_number):
curve[index] = dict(
type="scatter",
x=self.curve_x[index],
y=self.curve_y[index],
xaxis="x1",
yaxis="y1",
mode="lines",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=False if self.show_hist else True,
marker=dict(color=self.colors[index % len(self.colors)]),
)
return curve
def make_rug(self):
"""
Makes the rug plot(s) for create_distplot().
:rtype (list) rug: list of rug plot representations
"""
rug = [None] * self.trace_number
for index in range(self.trace_number):
rug[index] = dict(
type="scatter",
x=self.hist_data[index],
y=([self.group_labels[index]] * len(self.hist_data[index])),
xaxis="x1",
yaxis="y2",
mode="markers",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=(False if self.show_hist or self.show_curve else True),
text=self.rug_text[index],
marker=dict(
color=self.colors[index % len(self.colors)], symbol="line-ns-open"
),
)
return rug

File diff suppressed because it is too large

File diff suppressed because it is too large

plotly/figure_factory/_hexbin_mapbox.py
@@ -0,0 +1,526 @@
from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
import narwhals.stable.v1 as nw
import numpy as np
def _project_latlon_to_wgs84(lat, lon):
"""
Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
"""
x = lon * np.pi / 180
y = np.arctanh(np.sin(lat * np.pi / 180))
return x, y
def _project_wgs84_to_latlon(x, y):
"""
Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
"""
lon = x * 180 / np.pi
lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
return lat, lon
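The two helpers above are inverses of each other, so a round trip recovers the original coordinates up to floating point, e.g. for lat=45, lon=30:
>>> x, y = _project_latlon_to_wgs84(45.0, 30.0)
>>> lat, lon = _project_wgs84_to_latlon(x, y)
>>> round(lat, 6), round(lon, 6)
(45.0, 30.0)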
def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
"""
Get the mapbox zoom level given bounds and a figure dimension
Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds
"""
scale = (
2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
)
WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
ZOOM_MAX = 18
def latRad(lat):
sin = np.sin(lat * np.pi / 180)
radX2 = np.log((1 + sin) / (1 - sin)) / 2
return max(min(radX2, np.pi), -np.pi) / 2
def zoom(mapPx, worldPx, fraction):
return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)
latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi
lngDiff = lon_max - lon_min
lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360
latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)
return min(latZoom, lngZoom, ZOOM_MAX)
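A sketch of how this helper is called: bounds in degrees plus the figure's pixel dimensions, yielding a fractional zoom level capped at ZOOM_MAX; the bounding box below is made up.
>>> zoom = _getBoundsZoomLevel(
...     lon_min=-74.3, lon_max=-73.7, lat_min=40.5, lat_max=40.9,
...     mapDim={"width": 600, "height": 600})
>>> 0 < zoom <= 18
True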
def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
"""
Computes the aggregation at hexagonal bin level.
Also defines the coordinates of the hexagons for plotting.
The binning is inspired by matplotlib's implementation.
Parameters
----------
x : np.ndarray
Array of x values (shape N)
y : np.ndarray
Array of y values (shape N)
x_range : np.ndarray
Min and max x (shape 2)
y_range : np.ndarray
Min and max y (shape 2)
color : np.ndarray
Metric to aggregate at hexagon level (shape N)
nx : int
Number of hexagons horizontally
agg_func : function
Numpy compatible aggregator, this function must take a one-dimensional
np.ndarray as input and output a scalar
min_count : int
Minimum number of points in the hexagon for the hexagon to be displayed
Returns
-------
np.ndarray
X coordinates of each hexagon (shape M x 6)
np.ndarray
Y coordinates of each hexagon (shape M x 6)
np.ndarray
Centers of the hexagons (shape M x 2)
np.ndarray
Aggregated value in each hexagon (shape M)
"""
xmin = x_range.min()
xmax = x_range.max()
ymin = y_range.min()
ymax = y_range.max()
# In the x-direction, the hexagons exactly cover the region from
# xmin to xmax. Need some padding to avoid roundoff errors.
padding = 1.0e-9 * (xmax - xmin)
xmin -= padding
xmax += padding
Dx = xmax - xmin
Dy = ymax - ymin
if Dx == 0 and Dy > 0:
dx = Dy / nx
elif Dx == 0 and Dy == 0:
dx, _ = _project_latlon_to_wgs84(1, 1)
else:
dx = Dx / nx
dy = dx * np.sqrt(3)
ny = np.ceil(Dy / dy).astype(int)
# Center the hexagons vertically since we only want regular hexagons
ymin -= (ymin + dy * ny - ymax) / 2
x = (x - xmin) / dx
y = (y - ymin) / dy
ix1 = np.round(x).astype(int)
iy1 = np.round(y).astype(int)
ix2 = np.floor(x).astype(int)
iy2 = np.floor(y).astype(int)
nx1 = nx + 1
ny1 = ny + 1
nx2 = nx
ny2 = ny
n = nx1 * ny1 + nx2 * ny2
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
bdist = d1 < d2
if color is None:
lattice1 = np.zeros((nx1, ny1))
lattice2 = np.zeros((nx2, ny2))
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
if min_count is not None:
lattice1[lattice1 < min_count] = np.nan
lattice2[lattice2 < min_count] = np.nan
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
good_idxs = ~np.isnan(accum)
else:
if min_count is None:
min_count = 1
# create accumulation arrays
lattice1 = np.empty((nx1, ny1), dtype=object)
for i in range(nx1):
for j in range(ny1):
lattice1[i, j] = []
lattice2 = np.empty((nx2, ny2), dtype=object)
for i in range(nx2):
for j in range(ny2):
lattice2[i, j] = []
for i in range(len(x)):
if bdist[i]:
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
lattice1[ix1[i], iy1[i]].append(color[i])
else:
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
lattice2[ix2[i], iy2[i]].append(color[i])
for i in range(nx1):
for j in range(ny1):
vals = lattice1[i, j]
if len(vals) >= min_count:
lattice1[i, j] = agg_func(vals)
else:
lattice1[i, j] = np.nan
for i in range(nx2):
for j in range(ny2):
vals = lattice2[i, j]
if len(vals) >= min_count:
lattice2[i, j] = agg_func(vals)
else:
lattice2[i, j] = np.nan
accum = np.hstack(
(lattice1.astype(float).ravel(), lattice2.astype(float).ravel())
)
good_idxs = ~np.isnan(accum)
aggregated_value = accum[good_idxs]
centers = np.zeros((n, 2), float)
centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
centers[:, 0] *= dx
centers[:, 1] *= dy
centers[:, 0] += xmin
centers[:, 1] += ymin
centers = centers[good_idxs]
# Define normalised regular hexagon coordinates
hx = [0, 0.5, 0.5, 0, -0.5, -0.5]
hy = [
-0.5 / np.cos(np.pi / 6),
-0.5 * np.tan(np.pi / 6),
0.5 * np.tan(np.pi / 6),
0.5 / np.cos(np.pi / 6),
0.5 * np.tan(np.pi / 6),
-0.5 * np.tan(np.pi / 6),
]
# Number of hexagons needed
m = len(centers)
# Coordinates for all hexagonal patches
hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0])
hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1])
return hxs, hys, centers, aggregated_value
def _compute_wgs84_hexbin(
lat=None,
lon=None,
lat_range=None,
lon_range=None,
color=None,
nx=None,
agg_func=None,
min_count=None,
native_namespace=None,
):
"""
Computes the lat-lon aggregation at hexagonal bin level.
Latitude and longitude need to be projected to WGS84 before aggregating
in order to display regular hexagons on the map.
Parameters
----------
lat : np.ndarray
Array of latitudes (shape N)
lon : np.ndarray
Array of longitudes (shape N)
lat_range : np.ndarray
Min and max latitudes (shape 2)
lon_range : np.ndarray
Min and max longitudes (shape 2)
color : np.ndarray
Metric to aggregate at hexagon level (shape N)
nx : int
Number of hexagons horizontally
agg_func : function
Numpy-compatible aggregator; this function must take a one-dimensional
np.ndarray as input and output a scalar
min_count : int
Minimum number of points in the hexagon for the hexagon to be displayed
Returns
-------
np.ndarray
Lat coordinates of each hexagon (shape M x 6)
np.ndarray
Lon coordinates of each hexagon (shape M x 6)
nw.Series
Unique id for each hexagon, to be used in the geojson data (shape M)
np.ndarray
Aggregated value in each hexagon (shape M)
"""
# Project to WGS 84
x, y = _project_latlon_to_wgs84(lat, lon)
if lat_range is None:
lat_range = np.array([lat.min(), lat.max()])
if lon_range is None:
lon_range = np.array([lon.min(), lon.max()])
x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
hxs, hys, centers, aggregated_value = _compute_hexbin(
x, y, x_range, y_range, color, nx, agg_func, min_count
)
# Convert back to lat-lon
hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)
# Create unique feature id based on hexagon center
centers = centers.astype(str)
hexagons_ids = (
nw.from_dict(
{"x1": centers[:, 0], "x2": centers[:, 1]},
native_namespace=native_namespace,
)
.select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=","))
.get_column("hexagons_ids")
)
return hexagons_lats, hexagons_lons, hexagons_ids, aggregated_value
def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
"""
Creates a geojson of hexagonal features based on the outputs of
_compute_wgs84_hexbin
"""
features = []
if ids is None:
ids = np.arange(len(hexagons_lats))
for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
points = np.array([lon, lat]).T.tolist()
points.append(points[0])
features.append(
dict(
type="Feature",
id=idx,
geometry=dict(type="Polygon", coordinates=[points]),
)
)
return dict(type="FeatureCollection", features=features)
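# Each emitted feature has the shape
# {"type": "Feature", "id": <hexagon id>,
#  "geometry": {"type": "Polygon", "coordinates": [[[lon0, lat0], ..., [lon5, lat5], [lon0, lat0]]]}},
# i.e. a closed six-vertex ring, which is what the choropleth trace below expects
# for its geojson/locations pair.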
def create_hexbin_mapbox(
data_frame=None,
lat=None,
lon=None,
color=None,
nx_hexagon=5,
agg_func=None,
animation_frame=None,
color_discrete_sequence=None,
color_discrete_map={},
labels={},
color_continuous_scale=None,
range_color=None,
color_continuous_midpoint=None,
opacity=None,
zoom=None,
center=None,
mapbox_style=None,
title=None,
template=None,
width=None,
height=None,
min_count=None,
show_original_data=False,
original_data_marker=None,
):
"""
Returns a figure aggregating scattered points into connected hexagons
"""
args = build_dataframe(args=locals(), constructor=None)
native_namespace = nw.get_native_namespace(args["data_frame"])
if agg_func is None:
agg_func = np.mean
lat_range = (
args["data_frame"]
.select(
nw.min(args["lat"]).name.suffix("_min"),
nw.max(args["lat"]).name.suffix("_max"),
)
.to_numpy()
.squeeze()
)
lon_range = (
args["data_frame"]
.select(
nw.min(args["lon"]).name.suffix("_min"),
nw.max(args["lon"]).name.suffix("_max"),
)
.to_numpy()
.squeeze()
)
hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
lat=args["data_frame"].get_column(args["lat"]).to_numpy(),
lon=args["data_frame"].get_column(args["lon"]).to_numpy(),
lat_range=lat_range,
lon_range=lon_range,
color=None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
native_namespace=native_namespace,
)
geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)
if zoom is None:
if height is None and width is None:
mapDim = dict(height=450, width=450)
elif height is None and width is not None:
mapDim = dict(height=450, width=width)
elif height is not None and width is None:
mapDim = dict(height=height, width=height)
else:
mapDim = dict(height=height, width=width)
zoom = _getBoundsZoomLevel(
lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim
)
if center is None:
center = dict(lat=lat_range.mean(), lon=lon_range.mean())
if args["animation_frame"] is not None:
groups = dict(
args["data_frame"]
.group_by(args["animation_frame"], drop_null_keys=True)
.__iter__()
)
else:
groups = {(0,): args["data_frame"]}
agg_data_frame_list = []
for key, df in groups.items():
_, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
lat=df.get_column(args["lat"]).to_numpy(),
lon=df.get_column(args["lon"]).to_numpy(),
lat_range=lat_range,
lon_range=lon_range,
color=df.get_column(args["color"]).to_numpy() if args["color"] else None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
native_namespace=native_namespace,
)
agg_data_frame_list.append(
nw.from_dict(
{
"frame": [key[0]] * len(hexagons_ids),
"locations": hexagons_ids,
"color": aggregated_value,
},
native_namespace=native_namespace,
)
)
agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns(
color=nw.col("color").cast(nw.Int64)
)
if range_color is None:
range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]
fig = choropleth_mapbox(
data_frame=agg_data_frame.to_native(),
geojson=geojson,
locations="locations",
color="color",
hover_data={"color": True, "locations": False, "frame": False},
animation_frame=("frame" if args["animation_frame"] is not None else None),
color_discrete_sequence=color_discrete_sequence,
color_discrete_map=color_discrete_map,
labels=labels,
color_continuous_scale=color_continuous_scale,
range_color=range_color,
color_continuous_midpoint=color_continuous_midpoint,
opacity=opacity,
zoom=zoom,
center=center,
mapbox_style=mapbox_style,
title=title,
template=template,
width=width,
height=height,
)
if show_original_data:
original_fig = scatter_mapbox(
data_frame=(
args["data_frame"].sort(
by=args["animation_frame"], descending=False, nulls_last=True
)
if args["animation_frame"] is not None
else args["data_frame"]
).to_native(),
lat=args["lat"],
lon=args["lon"],
animation_frame=args["animation_frame"],
)
original_fig.data[0].hoverinfo = "skip"
original_fig.data[0].hovertemplate = None
original_fig.data[0].marker = original_data_marker
fig.add_trace(original_fig.data[0])
if args["animation_frame"] is not None:
for i in range(len(original_fig.frames)):
original_fig.frames[i].data[0].hoverinfo = "skip"
original_fig.frames[i].data[0].hovertemplate = None
original_fig.frames[i].data[0].marker = original_data_marker
fig.frames[i].data = [
fig.frames[i].data[0],
original_fig.frames[i].data[0],
]
return fig
create_hexbin_mapbox.__doc__ = make_docstring(
create_hexbin_mapbox,
override_dict=dict(
nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
agg_func=[
"function",
"Numpy array aggregator, it must take as input a 1D array",
"and output a scalar value.",
],
min_count=[
"int",
"Minimum number of points in a hexagon for it to be displayed.",
"If None and color is not set, display all hexagons.",
"If None and color is set, only display hexagons that contain points.",
],
show_original_data=[
"bool",
"Whether to show the original data on top of the hexbin aggregation.",
],
original_data_marker=["dict", "Scattermapbox marker options."],
),
)
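# A minimal usage sketch of create_hexbin_mapbox(); the plotly.express carshare
# dataset and its "centroid_lat"/"centroid_lon" column names are assumptions used
# for illustration only and are not referenced anywhere in this module.
if __name__ == "__main__":
    import plotly.express as px
    from plotly.figure_factory import create_hexbin_mapbox

    df = px.data.carshare()
    fig = create_hexbin_mapbox(
        data_frame=df,
        lat="centroid_lat",
        lon="centroid_lon",
        nx_hexagon=10,
        opacity=0.5,
        labels={"color": "Point Count"},
        min_count=1,
        mapbox_style="open-street-map",
    )
    fig.show()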

View File

@@ -0,0 +1,295 @@
from plotly import exceptions
from plotly.graph_objs import graph_objs
from plotly.figure_factory import utils
# Default colours for finance charts
_DEFAULT_INCREASING_COLOR = "#3D9970" # http://clrs.cc
_DEFAULT_DECREASING_COLOR = "#FF4136"
def validate_ohlc(open, high, low, close, direction, **kwargs):
"""
ohlc and candlestick specific validations
Specifically, this checks that the high value is the greatest value and
the low value is the lowest value in each unit.
See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
for params
:raises: (PlotlyError) If the high value is not the greatest value in
each unit.
:raises: (PlotlyError) If the low value is not the lowest value in each
unit.
:raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
"""
for lst in [open, low, close]:
for index in range(len(high)):
if high[index] < lst[index]:
raise exceptions.PlotlyError(
"Oops! Looks like some of "
"your high values are less "
"the corresponding open, "
"low, or close values. "
"Double check that your data "
"is entered in O-H-L-C order"
)
for lst in [open, high, close]:
for index in range(len(low)):
if low[index] > lst[index]:
raise exceptions.PlotlyError(
"Oops! Looks like some of "
"your low values are greater "
"than the corresponding high"
", open, or close values. "
"Double check that your data "
"is entered in O-H-L-C order"
)
direction_opts = ("increasing", "decreasing", "both")
if direction not in direction_opts:
raise exceptions.PlotlyError(
"direction must be defined as 'increasing', 'decreasing', or 'both'"
)
def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
"""
Makes increasing ohlc sticks
make_increasing_ohlc() and make_decreasing_ohlc() separate the
increasing trace from the decreasing trace so kwargs (such as
color) can be passed separately to increasing or decreasing traces
when direction is set to 'increasing' or 'decreasing' in
FigureFactory.create_candlestick()
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
sticks.
"""
(flat_increase_x, flat_increase_y, text_increase) = _OHLC(
open, high, low, close, dates
).get_increase()
if "name" in kwargs:
showlegend = True
else:
kwargs.setdefault("name", "Increasing")
showlegend = False
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR, width=1))
kwargs.setdefault("text", text_increase)
ohlc_incr = dict(
type="scatter",
x=flat_increase_x,
y=flat_increase_y,
mode="lines",
showlegend=showlegend,
**kwargs,
)
return ohlc_incr
def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
"""
Makes decreasing ohlc sticks
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param kwargs: kwargs to be passed to increasing trace via
plotly.graph_objs.Scatter.
:rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
sticks.
"""
(flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC(
open, high, low, close, dates
).get_decrease()
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR, width=1))
kwargs.setdefault("text", text_decrease)
kwargs.setdefault("showlegend", False)
kwargs.setdefault("name", "Decreasing")
ohlc_decr = dict(
type="scatter", x=flat_decrease_x, y=flat_decrease_y, mode="lines", **kwargs
)
return ohlc_decr
def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Ohlc`
:param (list) open: opening values
:param (list) high: high values
:param (list) low: low values
:param (list) close: closing values
:param (list) dates: list of datetime objects. Default: None
:param (string) direction: direction can be 'increasing', 'decreasing',
or 'both'. When the direction is 'increasing', the returned figure
consists of all units where the close value is greater than the
corresponding open value, and when the direction is 'decreasing',
the returned figure consists of all units where the close value is
less than or equal to the corresponding open value. When the
direction is 'both', both increasing and decreasing units are
returned. Default: 'both'
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
These kwargs describe other attributes about the ohlc Scatter trace
such as the color or the legend name. For more information on valid
kwargs call help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of an ohlc chart figure.
Example 1: Simple OHLC chart from a Pandas DataFrame
>>> from plotly.figure_factory import create_ohlc
>>> from datetime import datetime
>>> import pandas as pd
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
>>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index)
>>> fig.show()
"""
if dates is not None:
utils.validate_equal_length(open, high, low, close, dates)
else:
utils.validate_equal_length(open, high, low, close)
validate_ohlc(open, high, low, close, direction, **kwargs)
if direction == "increasing":
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_incr]
elif direction == "decreasing":
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_decr]
else:
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
data = [ohlc_incr, ohlc_decr]
layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest")
return graph_objs.Figure(data=data, layout=layout)
class _OHLC(object):
"""
Refer to FigureFactory.create_ohlc() for docstring.
"""
def __init__(self, open, high, low, close, dates, **kwargs):
self.open = open
self.high = high
self.low = low
self.close = close
self.empty = [None] * len(open)
self.dates = dates
self.all_x = []
self.all_y = []
self.increase_x = []
self.increase_y = []
self.decrease_x = []
self.decrease_y = []
self.get_all_xy()
self.separate_increase_decrease()
def get_all_xy(self):
"""
Zip data to create OHLC shape
OHLC shape: low to high vertical bar with
horizontal branches for open and close values.
If dates were added, the smallest date difference is calculated and
multiplied by .2 to get the length of the open and close branches.
If no date data was provided, the x-axis is a list of integers and the
length of the open and close branches is .2.
"""
self.all_y = list(
zip(
self.open,
self.open,
self.high,
self.low,
self.close,
self.close,
self.empty,
)
)
if self.dates is not None:
date_dif = []
for i in range(len(self.dates) - 1):
date_dif.append(self.dates[i + 1] - self.dates[i])
date_dif_min = (min(date_dif)) / 5
self.all_x = [
[x - date_dif_min, x, x, x, x, x + date_dif_min, None]
for x in self.dates
]
else:
self.all_x = [
[x - 0.2, x, x, x, x, x + 0.2, None] for x in range(len(self.open))
]
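# Worked example: for a single bar with open=1, high=3, low=0, close=2 and no
# dates, all_y becomes [(1, 1, 3, 0, 2, 2, None)] and all_x becomes
# [[-0.2, 0, 0, 0, 0, 0.2, None]]: open branch, vertical high-low bar, close
# branch, with None separating this bar from the next one.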
def separate_increase_decrease(self):
"""
Separate data into two groups: increase and decrease
(1) Increase, where close > open and
(2) Decrease, where close <= open
"""
for index in range(len(self.open)):
if self.close[index] is None:
pass
elif self.close[index] > self.open[index]:
self.increase_x.append(self.all_x[index])
self.increase_y.append(self.all_y[index])
else:
self.decrease_x.append(self.all_x[index])
self.decrease_y.append(self.all_y[index])
def get_increase(self):
"""
Flatten increase data and get increase text
:rtype (list, list, list): flat_increase_x: x-values for the increasing
trace, flat_increase_y: y-values for the increasing trace and
text_increase: hovertext for the increasing trace
"""
flat_increase_x = utils.flatten(self.increase_x)
flat_increase_y = utils.flatten(self.increase_y)
text_increase = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
len(self.increase_x)
)
return flat_increase_x, flat_increase_y, text_increase
def get_decrease(self):
"""
Flatten decrease data and get decrease text
:rtype (list, list, list): flat_decrease_x: x-values for the decreasing
trace, flat_decrease_y: y-values for the decreasing trace and
text_decrease: hovertext for the decreasing trace
"""
flat_decrease_x = utils.flatten(self.decrease_x)
flat_decrease_y = utils.flatten(self.decrease_y)
text_decrease = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
len(self.decrease_x)
)
return flat_decrease_x, flat_decrease_y, text_decrease

View File

@@ -0,0 +1,265 @@
import math
from plotly import exceptions
from plotly.graph_objs import graph_objs
from plotly.figure_factory import utils
def create_quiver(
x, y, u, v, scale=0.1, arrow_scale=0.3, angle=math.pi / 9, scaleratio=None, **kwargs
):
"""
Returns data for a quiver plot.
:param (list|ndarray) x: x coordinates of the arrow locations
:param (list|ndarray) y: y coordinates of the arrow locations
:param (list|ndarray) u: x components of the arrow vectors
:param (list|ndarray) v: y components of the arrow vectors
:param (float in [0,1]) scale: scales size of the arrows (ideally to
avoid overlap). Default = .1
:param (float in [0,1]) arrow_scale: value multiplied to length of barb
to get length of arrowhead. Default = .3
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (positive float) scaleratio: the ratio between the scale of the y-axis
and the scale of the x-axis (scale_y / scale_x). Default = None, the
scale ratio is not fixed.
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
for more information on valid kwargs call
help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of quiver figure.
Example 1: Trivial Quiver
>>> from plotly.figure_factory import create_quiver
>>> import math
>>> # 1 Arrow from (0,0) to (1,1)
>>> fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
>>> fig.show()
Example 2: Quiver plot using meshgrid
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> import math
>>> # Add data
>>> x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
>>> u = np.cos(x)*y
>>> v = np.sin(x)*y
>>> #Create quiver
>>> fig = create_quiver(x, y, u, v)
>>> fig.show()
Example 3: Styling the quiver plot
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> import math
>>> # Add data
>>> x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
... np.arange(-math.pi, math.pi, .5))
>>> u = np.cos(x)*y
>>> v = np.sin(x)*y
>>> # Create quiver
>>> fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
... name='Wind Velocity', line=dict(width=1))
>>> # Add title to layout
>>> fig.update_layout(title='Quiver Plot') # doctest: +SKIP
>>> fig.show()
Example 4: Forcing a fixed scale ratio to maintain the arrow length
>>> from plotly.figure_factory import create_quiver
>>> import numpy as np
>>> # Add data
>>> x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
>>> u = x
>>> v = y
>>> angle = np.arctan(v / u)
>>> norm = 0.25
>>> u = norm * np.cos(angle)
>>> v = norm * np.sin(angle)
>>> # Create quiver with a fixed scale ratio
>>> fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)
>>> fig.show()
"""
utils.validate_equal_length(x, y, u, v)
utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
if scaleratio is None:
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle)
else:
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio)
barb_x, barb_y = quiver_obj.get_barbs()
arrow_x, arrow_y = quiver_obj.get_quiver_arrows()
quiver_plot = graph_objs.Scatter(
x=barb_x + arrow_x, y=barb_y + arrow_y, mode="lines", **kwargs
)
data = [quiver_plot]
if scaleratio is None:
layout = graph_objs.Layout(hovermode="closest")
else:
layout = graph_objs.Layout(
hovermode="closest", yaxis=dict(scaleratio=scaleratio, scaleanchor="x")
)
return graph_objs.Figure(data=data, layout=layout)
class _Quiver(object):
"""
Refer to FigureFactory.create_quiver() for docstring
"""
def __init__(self, x, y, u, v, scale, arrow_scale, angle, scaleratio=1, **kwargs):
try:
x = utils.flatten(x)
except exceptions.PlotlyError:
pass
try:
y = utils.flatten(y)
except exceptions.PlotlyError:
pass
try:
u = utils.flatten(u)
except exceptions.PlotlyError:
pass
try:
v = utils.flatten(v)
except exceptions.PlotlyError:
pass
self.x = x
self.y = y
self.u = u
self.v = v
self.scale = scale
self.scaleratio = scaleratio
self.arrow_scale = arrow_scale
self.angle = angle
self.end_x = []
self.end_y = []
self.scale_uv()
barb_x, barb_y = self.get_barbs()
arrow_x, arrow_y = self.get_quiver_arrows()
def scale_uv(self):
"""
Scales u and v to avoid overlap of the arrows.
u and v are added to x and y to get the
endpoints of the arrows so a smaller scale value will
result in less overlap of arrows.
"""
self.u = [i * self.scale * self.scaleratio for i in self.u]
self.v = [i * self.scale for i in self.v]
def get_barbs(self):
"""
Creates x and y startpoint and endpoint pairs
After finding the endpoint of each barb this zips startpoint and
endpoint pairs to create 2 lists: x_values for barbs and y values
for barbs
:rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
x_value pairs separated by a None to create the barb of the arrow,
and list of startpoint and endpoint y_value pairs separated by a
None to create the barb of the arrow.
"""
self.end_x = [i + j for i, j in zip(self.x, self.u)]
self.end_y = [i + j for i, j in zip(self.y, self.v)]
empty = [None] * len(self.x)
barb_x = utils.flatten(zip(self.x, self.end_x, empty))
barb_y = utils.flatten(zip(self.y, self.end_y, empty))
return barb_x, barb_y
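# Worked example: with x=[0], y=[0] and scaled u=[1], v=[1] this gives
# end_x=[1], end_y=[1], so barb_x == [0, 1, None] and barb_y == [0, 1, None];
# the None entries keep each barb as a separate segment within one Scatter trace.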
def get_quiver_arrows(self):
"""
Creates lists of x and y values to plot the arrows
Gets length of each barb then calculates the length of each side of
the arrow. Gets angle of barb and applies angle to each side of the
arrowhead. Next uses arrow_scale to scale the length of arrowhead and
creates x and y values for arrowhead point1 and point2. Finally x and y
values for point1, endpoint and point2s for each arrowhead are
separated by a None and zipped to create lists of x and y values for
the arrows.
:rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
x_values separated by a None to create the arrowhead and list of
point1, endpoint, point2 y_values separated by a None to create
the barb of the arrow.
"""
dif_x = [i - j for i, j in zip(self.end_x, self.x)]
dif_y = [i - j for i, j in zip(self.end_y, self.y)]
# Get barb lengths (default arrow length = 30% of barb length)
barb_len = [None] * len(self.x)
for index in range(len(barb_len)):
barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index])
# Make arrow lengths
arrow_len = [i * self.arrow_scale for i in barb_len]
# Get barb angles
barb_ang = [None] * len(self.x)
for index in range(len(barb_ang)):
barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio)
# Set angles to create arrow
ang1 = [i + self.angle for i in barb_ang]
ang2 = [i - self.angle for i in barb_ang]
cos_ang1 = [None] * len(ang1)
for index in range(len(ang1)):
cos_ang1[index] = math.cos(ang1[index])
seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
sin_ang1 = [None] * len(ang1)
for index in range(len(ang1)):
sin_ang1[index] = math.sin(ang1[index])
seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
cos_ang2 = [None] * len(ang2)
for index in range(len(ang2)):
cos_ang2[index] = math.cos(ang2[index])
seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
sin_ang2 = [None] * len(ang2)
for index in range(len(ang2)):
sin_ang2[index] = math.sin(ang2[index])
seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
# Set coordinates to create arrow
point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)]
point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)]
point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
# Combine lists to create arrow
empty = [None] * len(self.end_x)
arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
return arrow_x, arrow_y

File diff suppressed because it is too large

View File

@@ -0,0 +1,406 @@
import math
from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
np = optional_imports.get_module("numpy")
def validate_streamline(x, y):
"""
Streamline-specific validations
Specifically, this checks that x and y are both evenly spaced,
and that the package numpy is available.
See FigureFactory.create_streamline() for params
:raises: (ImportError) If numpy is not available.
:raises: (PlotlyError) If x is not evenly spaced.
:raises: (PlotlyError) If y is not evenly spaced.
"""
if np is False:
raise ImportError("FigureFactory.create_streamline requires numpy")
for index in range(len(x) - 1):
if abs((x[index + 1] - x[index]) - (x[1] - x[0])) > 0.0001:
raise exceptions.PlotlyError(
"x must be a 1 dimensional, evenly spaced array"
)
for index in range(len(y) - 1):
if abs((y[index + 1] - y[index]) - (y[1] - y[0])) > 0.0001:
raise exceptions.PlotlyError(
"y must be a 1 dimensional, evenly spaced array"
)
def create_streamline(
x, y, u, v, density=1, angle=math.pi / 9, arrow_scale=0.09, **kwargs
):
"""
Returns data for a streamline plot.
:param (list|ndarray) x: 1 dimensional, evenly spaced list or array
:param (list|ndarray) y: 1 dimensional, evenly spaced list or array
:param (ndarray) u: 2 dimensional array
:param (ndarray) v: 2 dimensional array
:param (float|int) density: controls the density of streamlines in
plot. This is multiplied by 30 to scale similarly to other
available streamline functions such as matplotlib.
Default = 1
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
Default = .09
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
for more information on valid kwargs call
help(plotly.graph_objs.Scatter)
:rtype (dict): returns a representation of streamline figure.
Example 1: Plot simple streamline and increase arrow size
>>> from plotly.figure_factory import create_streamline
>>> import plotly.graph_objects as go
>>> import numpy as np
>>> import math
>>> # Add data
>>> x = np.linspace(-3, 3, 100)
>>> y = np.linspace(-3, 3, 100)
>>> Y, X = np.meshgrid(x, y)
>>> u = -1 - X**2 + Y
>>> v = 1 + X - Y**2
>>> u = u.T # Transpose
>>> v = v.T # Transpose
>>> # Create streamline
>>> fig = create_streamline(x, y, u, v, arrow_scale=.1)
>>> fig.show()
Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
>>> from plotly.figure_factory import create_streamline
>>> import numpy as np
>>> import math
>>> # Add data
>>> N = 50
>>> x_start, x_end = -2.0, 2.0
>>> y_start, y_end = -1.0, 1.0
>>> x = np.linspace(x_start, x_end, N)
>>> y = np.linspace(y_start, y_end, N)
>>> X, Y = np.meshgrid(x, y)
>>> ss = 5.0
>>> x_s, y_s = -1.0, 0.0
>>> # Compute the velocity field on the mesh grid
>>> u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
>>> v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
>>> # Create streamline
>>> fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
>>> # Add source point
>>> point = go.Scatter(x=[x_s], y=[y_s], mode='markers',
... marker_size=14, name='source point')
>>> fig.add_trace(point) # doctest: +SKIP
>>> fig.show()
"""
utils.validate_equal_length(x, y)
utils.validate_equal_length(u, v)
validate_streamline(x, y)
utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
streamline_x, streamline_y = _Streamline(
x, y, u, v, density, angle, arrow_scale
).sum_streamlines()
arrow_x, arrow_y = _Streamline(
x, y, u, v, density, angle, arrow_scale
).get_streamline_arrows()
streamline = graph_objs.Scatter(
x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode="lines", **kwargs
)
data = [streamline]
layout = graph_objs.Layout(hovermode="closest")
return graph_objs.Figure(data=data, layout=layout)
class _Streamline(object):
"""
Refer to FigureFactory.create_streamline() for docstring
"""
def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs):
self.x = np.array(x)
self.y = np.array(y)
self.u = np.array(u)
self.v = np.array(v)
self.angle = angle
self.arrow_scale = arrow_scale
self.density = int(30 * density) # Scale similarly to other functions
self.delta_x = self.x[1] - self.x[0]
self.delta_y = self.y[1] - self.y[0]
self.val_x = self.x
self.val_y = self.y
# Set up spacing
self.blank = np.zeros((self.density, self.density))
self.spacing_x = len(self.x) / float(self.density - 1)
self.spacing_y = len(self.y) / float(self.density - 1)
self.trajectories = []
# Rescale speed onto axes-coordinates
self.u = self.u / (self.x[-1] - self.x[0])
self.v = self.v / (self.y[-1] - self.y[0])
self.speed = np.sqrt(self.u**2 + self.v**2)
# Rescale u and v for integrations.
self.u *= len(self.x)
self.v *= len(self.y)
self.st_x = []
self.st_y = []
self.get_streamlines()
streamline_x, streamline_y = self.sum_streamlines()
arrows_x, arrows_y = self.get_streamline_arrows()
def blank_pos(self, xi, yi):
"""
Set up positions for trajectories to be used with rk4 function.
"""
return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5))
def value_at(self, a, xi, yi):
"""
Bilinear interpolation of array ``a`` at the fractional grid position (xi, yi); based on Bokeh's streamline code
"""
if isinstance(xi, np.ndarray):
self.x = xi.astype(int)
self.y = yi.astype(int)
else:
self.val_x = int(xi)
self.val_y = int(yi)
a00 = a[self.val_y, self.val_x]
a01 = a[self.val_y, self.val_x + 1]
a10 = a[self.val_y + 1, self.val_x]
a11 = a[self.val_y + 1, self.val_x + 1]
xt = xi - self.val_x
yt = yi - self.val_y
a0 = a00 * (1 - xt) + a01 * xt
a1 = a10 * (1 - xt) + a11 * xt
return a0 * (1 - yt) + a1 * yt
def rk4_integrate(self, x0, y0):
"""
RK4 forward and back trajectories from the initial conditions.
Adapted from Bokeh's streamline code; uses the Runge-Kutta method to fill
x and y trajectories, then checks the trajectory length (s, in units of axes)
"""
def f(xi, yi):
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
ui = self.value_at(self.u, xi, yi)
vi = self.value_at(self.v, xi, yi)
return ui * dt_ds, vi * dt_ds
def g(xi, yi):
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
ui = self.value_at(self.u, xi, yi)
vi = self.value_at(self.v, xi, yi)
return -ui * dt_ds, -vi * dt_ds
def check(xi, yi):
return (0 <= xi < len(self.x) - 1) and (0 <= yi < len(self.y) - 1)
xb_changes = []
yb_changes = []
def rk4(x0, y0, f):
ds = 0.01
stotal = 0
xi = x0
yi = y0
xb, yb = self.blank_pos(xi, yi)
xf_traj = []
yf_traj = []
while check(xi, yi):
xf_traj.append(xi)
yf_traj.append(yi)
try:
k1x, k1y = f(xi, yi)
k2x, k2y = f(xi + 0.5 * ds * k1x, yi + 0.5 * ds * k1y)
k3x, k3y = f(xi + 0.5 * ds * k2x, yi + 0.5 * ds * k2y)
k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
except IndexError:
break
xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0
yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0
if not check(xi, yi):
break
stotal += ds
new_xb, new_yb = self.blank_pos(xi, yi)
if new_xb != xb or new_yb != yb:
if self.blank[new_yb, new_xb] == 0:
self.blank[new_yb, new_xb] = 1
xb_changes.append(new_xb)
yb_changes.append(new_yb)
xb = new_xb
yb = new_yb
else:
break
if stotal > 2:
break
return stotal, xf_traj, yf_traj
sf, xf_traj, yf_traj = rk4(x0, y0, f)
sb, xb_traj, yb_traj = rk4(x0, y0, g)
stotal = sf + sb
x_traj = xb_traj[::-1] + xf_traj[1:]
y_traj = yb_traj[::-1] + yf_traj[1:]
if len(x_traj) < 1:
return None
if stotal > 0.2:
initxb, inityb = self.blank_pos(x0, y0)
self.blank[inityb, initxb] = 1
return x_traj, y_traj
else:
for xb, yb in zip(xb_changes, yb_changes):
self.blank[yb, xb] = 0
return None
def traj(self, xb, yb):
"""
Integrate trajectories
:param (int) xb: results of passing xi through self.blank_pos
:param (int) yb: results of passing yi through self.blank_pos
Calculate each trajectory based on rk4 integrate method.
"""
if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
return
if self.blank[yb, xb] == 0:
t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
if t is not None:
self.trajectories.append(t)
def get_streamlines(self):
"""
Get streamlines by building trajectory set.
"""
for indent in range(self.density // 2):
for xi in range(self.density - 2 * indent):
self.traj(xi + indent, indent)
self.traj(xi + indent, self.density - 1 - indent)
self.traj(indent, xi + indent)
self.traj(self.density - 1 - indent, xi + indent)
self.st_x = [
np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories
]
self.st_y = [
np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories
]
for index in range(len(self.st_x)):
self.st_x[index] = self.st_x[index].tolist()
self.st_x[index].append(np.nan)
for index in range(len(self.st_y)):
self.st_y[index] = self.st_y[index].tolist()
self.st_y[index].append(np.nan)
def get_streamline_arrows(self):
"""
Makes an arrow for each streamline.
Gets angle of streamline at 1/3 mark and creates arrow coordinates
based off of user defined angle and arrow_scale.
:param (array) st_x: x-values for all streamlines
:param (array) st_y: y-values for all streamlines
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
Default = .09
:rtype (list, list) arrows_x: x-values to create arrowhead and
arrows_y: y-values to create arrowhead
"""
arrow_end_x = np.empty((len(self.st_x)))
arrow_end_y = np.empty((len(self.st_y)))
arrow_start_x = np.empty((len(self.st_x)))
arrow_start_y = np.empty((len(self.st_y)))
for index in range(len(self.st_x)):
arrow_end_x[index] = self.st_x[index][int(len(self.st_x[index]) / 3)]
arrow_start_x[index] = self.st_x[index][
(int(len(self.st_x[index]) / 3)) - 1
]
arrow_end_y[index] = self.st_y[index][int(len(self.st_y[index]) / 3)]
arrow_start_y[index] = self.st_y[index][
(int(len(self.st_y[index]) / 3)) - 1
]
dif_x = arrow_end_x - arrow_start_x
dif_y = arrow_end_y - arrow_start_y
orig_err = np.geterr()
np.seterr(divide="ignore", invalid="ignore")
streamline_ang = np.arctan(dif_y / dif_x)
np.seterr(**orig_err)
ang1 = streamline_ang + (self.angle)
ang2 = streamline_ang - (self.angle)
seg1_x = np.cos(ang1) * self.arrow_scale
seg1_y = np.sin(ang1) * self.arrow_scale
seg2_x = np.cos(ang2) * self.arrow_scale
seg2_y = np.sin(ang2) * self.arrow_scale
point1_x = np.empty((len(dif_x)))
point1_y = np.empty((len(dif_y)))
point2_x = np.empty((len(dif_x)))
point2_y = np.empty((len(dif_y)))
for index in range(len(dif_x)):
if dif_x[index] >= 0:
point1_x[index] = arrow_end_x[index] - seg1_x[index]
point1_y[index] = arrow_end_y[index] - seg1_y[index]
point2_x[index] = arrow_end_x[index] - seg2_x[index]
point2_y[index] = arrow_end_y[index] - seg2_y[index]
else:
point1_x[index] = arrow_end_x[index] + seg1_x[index]
point1_y[index] = arrow_end_y[index] + seg1_y[index]
point2_x[index] = arrow_end_x[index] + seg2_x[index]
point2_y[index] = arrow_end_y[index] + seg2_y[index]
space = np.empty((len(point1_x)))
space[:] = np.nan
# Interleave point1, arrow tip, point2 and the NaN separator column-wise
arrows_x = np.array([point1_x, arrow_end_x, point2_x, space])
arrows_x = arrows_x.flatten("F")
arrows_x = arrows_x.tolist()
# Interleave point1, arrow tip, point2 and the NaN separator column-wise
arrows_y = np.array([point1_y, arrow_end_y, point2_y, space])
arrows_y = arrows_y.flatten("F")
arrows_y = arrows_y.tolist()
return arrows_x, arrows_y
def sum_streamlines(self):
"""
Makes all streamlines readable as a single trace.
:rtype (list, list): streamline_x: all x values for each streamline
combined into single list and streamline_y: all y values for each
streamline combined into single list
"""
streamline_x = sum(self.st_x, [])
streamline_y = sum(self.st_y, [])
return streamline_x, streamline_y

View File

@@ -0,0 +1,280 @@
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
pd = optional_imports.get_module("pandas")
def validate_table(table_text, font_colors):
"""
Table-specific validations
Check that font_colors is supplied correctly (1, 3, or len(text)
colors).
:raises: (PlotlyError) If font_colors is supplied incorrectly.
See FigureFactory.create_table() for params
"""
font_colors_len_options = [1, 3, len(table_text)]
if len(font_colors) not in font_colors_len_options:
raise exceptions.PlotlyError(
"Oops, font_colors should be a list of length 1, 3 or len(text)"
)
def create_table(
table_text,
colorscale=None,
font_colors=None,
index=False,
index_title="",
annotation_offset=0.45,
height_constant=30,
hoverinfo="none",
**kwargs,
):
"""
Function that creates data tables.
See also the plotly.graph_objects trace
:class:`plotly.graph_objects.Table`
:param (pandas.DataFrame | list[list]) table_text: data for table.
:param (str|list[list]) colorscale: Colorscale for table where the
color at value 0 is the header color, .5 is the first table color
and 1 is the second table color. (Set .5 and 1 to avoid the striped
table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
[1, '#ffffff']]
:param (list) font_colors: Color for fonts in table. Can be a single
color, three colors, or a color for each row in the table.
Default=['#000000'] (black text for the entire table)
:param (int) height_constant: Constant multiplied by # of rows to
create table height. Default=30.
:param (bool) index: Create (header-colored) index column index from
Pandas dataframe or list[0] for each list in text. Default=False.
:param (string) index_title: Title for index column. Default=''.
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
These kwargs describe other attributes about the annotated Heatmap
trace such as the colorscale. For more information on valid kwargs
call help(plotly.graph_objs.Heatmap)
Example 1: Simple Plotly Table
>>> from plotly.figure_factory import create_table
>>> text = [['Country', 'Year', 'Population'],
... ['US', 2000, 282200000],
... ['Canada', 2000, 27790000],
... ['US', 2010, 309000000],
... ['Canada', 2010, 34000000]]
>>> table = create_table(text)
>>> table.show()
Example 2: Table with Custom Coloring
>>> from plotly.figure_factory import create_table
>>> text = [['Country', 'Year', 'Population'],
... ['US', 2000, 282200000],
... ['Canada', 2000, 27790000],
... ['US', 2010, 309000000],
... ['Canada', 2010, 34000000]]
>>> table = create_table(text,
... colorscale=[[0, '#000000'],
... [.5, '#80beff'],
... [1, '#cce5ff']],
... font_colors=['#ffffff', '#000000',
... '#000000'])
>>> table.show()
Example 3: Simple Plotly Table with Pandas
>>> from plotly.figure_factory import create_table
>>> import pandas as pd
>>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
>>> df_p = df[0:25]
>>> table_simple = create_table(df_p)
>>> table_simple.show()
"""
# Avoiding mutables in the call signature
colorscale = (
colorscale
if colorscale is not None
else [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]]
)
font_colors = (
font_colors if font_colors is not None else ["#ffffff", "#000000", "#000000"]
)
validate_table(table_text, font_colors)
table_matrix = _Table(
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
).get_table_matrix()
annotations = _Table(
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
).make_table_annotations()
trace = dict(
type="heatmap",
z=table_matrix,
opacity=0.75,
colorscale=colorscale,
showscale=False,
hoverinfo=hoverinfo,
**kwargs,
)
data = [trace]
layout = dict(
annotations=annotations,
height=len(table_matrix) * height_constant + 50,
margin=dict(t=0, b=0, r=0, l=0),
yaxis=dict(
autorange="reversed",
zeroline=False,
gridwidth=2,
ticks="",
dtick=1,
tick0=0.5,
showticklabels=False,
),
xaxis=dict(
zeroline=False,
gridwidth=2,
ticks="",
dtick=1,
tick0=-0.5,
showticklabels=False,
),
)
return graph_objs.Figure(data=data, layout=layout)
class _Table(object):
"""
Refer to FigureFactory.create_table() for docstring
"""
def __init__(
self,
table_text,
colorscale,
font_colors,
index,
index_title,
annotation_offset,
**kwargs,
):
if pd and isinstance(table_text, pd.DataFrame):
headers = table_text.columns.tolist()
table_text_index = table_text.index.tolist()
table_text = table_text.values.tolist()
table_text.insert(0, headers)
if index:
table_text_index.insert(0, index_title)
for i in range(len(table_text)):
table_text[i].insert(0, table_text_index[i])
self.table_text = table_text
self.colorscale = colorscale
self.font_colors = font_colors
self.index = index
self.annotation_offset = annotation_offset
self.x = range(len(table_text[0]))
self.y = range(len(table_text))
def get_table_matrix(self):
"""
Create z matrix to make heatmap with striped table coloring
:rtype (list[list]) table_matrix: z matrix to make heatmap with striped
table coloring.
"""
header = [0] * len(self.table_text[0])
odd_row = [0.5] * len(self.table_text[0])
even_row = [1] * len(self.table_text[0])
table_matrix = [None] * len(self.table_text)
table_matrix[0] = header
for i in range(1, len(self.table_text), 2):
table_matrix[i] = odd_row
for i in range(2, len(self.table_text), 2):
table_matrix[i] = even_row
if self.index:
for array in table_matrix:
array[0] = 0
return table_matrix
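# Worked example: for table_text with a header row and three data rows the z
# matrix is [[0, ...], [0.5, ...], [1, ...], [0.5, ...]]; with index=True the
# first entry of every row is reset to 0 so the index column takes the header
# colour.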
def get_table_font_color(self):
"""
Fill font-color array.
Table text color can vary by row so this extends a single color or
creates an array to set a header color and two alternating colors to
create the striped table pattern.
:rtype (list[list]) all_font_colors: list of font colors for each row
in table.
"""
if len(self.font_colors) == 1:
all_font_colors = self.font_colors * len(self.table_text)
elif len(self.font_colors) == 3:
all_font_colors = list(range(len(self.table_text)))
all_font_colors[0] = self.font_colors[0]
for i in range(1, len(self.table_text), 2):
all_font_colors[i] = self.font_colors[1]
for i in range(2, len(self.table_text), 2):
all_font_colors[i] = self.font_colors[2]
elif len(self.font_colors) == len(self.table_text):
all_font_colors = self.font_colors
else:
all_font_colors = ["#000000"] * len(self.table_text)
return all_font_colors
def make_table_annotations(self):
"""
Generate annotations to fill in table text
:rtype (list) annotations: list of annotations for each cell of the
table.
"""
all_font_colors = self.get_table_font_color()
annotations = []
for n, row in enumerate(self.table_text):
for m, val in enumerate(row):
# Bold text in header and index
format_text = (
"<b>" + str(val) + "</b>"
if n == 0 or self.index and m < 1
else str(val)
)
# Match font color of index to font color of header
font_color = (
self.font_colors[0] if self.index and m == 0 else all_font_colors[n]
)
annotations.append(
graph_objs.layout.Annotation(
text=format_text,
x=self.x[m] - self.annotation_offset,
y=self.y[n],
xref="x1",
yref="y1",
align="left",
xanchor="left",
font=dict(color=font_color),
showarrow=False,
)
)
return annotations

View File

@@ -0,0 +1,692 @@
import plotly.colors as clrs
from plotly.graph_objs import graph_objs as go
from plotly import exceptions
from plotly import optional_imports
from skimage import measure
np = optional_imports.get_module("numpy")
scipy_interp = optional_imports.get_module("scipy.interpolate")
# -------------------------- Layout ------------------------------
def _ternary_layout(
title="Ternary contour plot", width=550, height=525, pole_labels=["a", "b", "c"]
):
"""
Layout of ternary contour plot, to be passed to ``go.FigureWidget``
object.
Parameters
==========
title : str or None
Title of ternary plot
width : int
Figure width.
height : int
Figure height.
pole_labels : str, default ['a', 'b', 'c']
Names of the three poles of the triangle.
"""
return dict(
title=title,
width=width,
height=height,
ternary=dict(
sum=1,
aaxis=dict(
title=dict(text=pole_labels[0]), min=0.01, linewidth=2, ticks="outside"
),
baxis=dict(
title=dict(text=pole_labels[1]), min=0.01, linewidth=2, ticks="outside"
),
caxis=dict(
title=dict(text=pole_labels[2]), min=0.01, linewidth=2, ticks="outside"
),
),
showlegend=False,
)
# ------------- Transformations of coordinates -------------------
def _replace_zero_coords(ternary_data, delta=0.0005):
"""
Replaces zero ternary coordinates with delta and normalizes the new
triplets (a, b, c).
Parameters
----------
ternary_data : ndarray of shape (N, 3)
delta : float
Small float to regularize logarithm.
Notes
-----
Implements a method
by J. A. Martin-Fernandez, C. Barcelo-Vidal, V. Pawlowsky-Glahn,
Dealing with zeros and missing values in compositional data sets
using nonparametric imputation, Mathematical Geology 35 (2003),
pp 253-278.
"""
zero_mask = ternary_data == 0
is_any_coord_zero = np.any(zero_mask, axis=0)
unity_complement = 1 - delta * is_any_coord_zero
if np.any(unity_complement < 0):
raise ValueError(
"The provided value of delta led to negative "
"ternary coords. Set a smaller delta"
)
ternary_data = np.where(zero_mask, delta, unity_complement * ternary_data)
return ternary_data
def _ilr_transform(barycentric):
"""
Perform Isometric Log-Ratio on barycentric (compositional) data.
Parameters
----------
barycentric: ndarray of shape (3, N)
Barycentric coordinates.
References
----------
"An algebraic method to compute isometric logratio transformation and
back transformation of compositional data", Jarauta-Bragulat, E.,
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
Intl Assoc for Math Geology, 2003, pp 31-30.
"""
barycentric = np.asarray(barycentric)
x_0 = np.log(barycentric[0] / barycentric[1]) / np.sqrt(2)
x_1 = (
1.0 / np.sqrt(6) * np.log(barycentric[0] * barycentric[1] / barycentric[2] ** 2)
)
ilr_tdata = np.stack((x_0, x_1))
return ilr_tdata
def _ilr_inverse(x):
"""
Perform inverse Isometric Log-Ratio (ILR) transform to retrieve
barycentric (compositional) data.
Parameters
----------
x : array of shape (2, N)
Coordinates in ILR space.
References
----------
"An algebraic method to compute isometric logratio transformation and
back transformation of compositional data", Jarauta-Bragulat, E.,
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
Intl Assoc for Math Geology, 2003, pp 31-30.
"""
x = np.array(x)
matrix = np.array([[0.5, 1, 1.0], [-0.5, 1, 1.0], [0.0, 0.0, 1.0]])
s = np.sqrt(2) / 2
t = np.sqrt(3 / 2)
Sk = np.einsum("ik, kj -> ij", np.array([[s, t], [-s, t]]), x)
Z = -np.log(1 + np.exp(Sk).sum(axis=0))
log_barycentric = np.einsum(
"ik, kj -> ij", matrix, np.stack((2 * s * x[0], t * x[1], Z))
)
iilr_tdata = np.exp(log_barycentric)
return iilr_tdata
def _transform_barycentric_cartesian():
"""
Returns the transformation matrix from barycentric to Cartesian
coordinates and conversely.
"""
# reference triangle
tri_verts = np.array([[0.5, np.sqrt(3) / 2], [0, 0], [1, 0]])
M = np.array([tri_verts[:, 0], tri_verts[:, 1], np.ones(3)])
return M, np.linalg.inv(M)
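# M maps barycentric weights (a, b, c) with a + b + c = 1 to homogeneous Cartesian
# coordinates (x, y, 1) of the reference triangle with vertices (0.5, sqrt(3)/2),
# (0, 0) and (1, 0); the returned inverse matrix performs the reverse mapping.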
def _prepare_barycentric_coord(b_coords):
"""
Check ternary coordinates and return the right barycentric coordinates.
"""
if not isinstance(b_coords, (list, np.ndarray)):
raise ValueError(
"Data should be either an array of shape (n,m),"
"or a list of n m-lists, m=2 or 3"
)
b_coords = np.asarray(b_coords)
if b_coords.shape[0] not in (2, 3):
raise ValueError(
"A point should have 2 (a, b) or 3 (a, b, c)barycentric coordinates"
)
if (
(len(b_coords) == 3)
and not np.allclose(b_coords.sum(axis=0), 1, rtol=0.01)
and not np.allclose(b_coords.sum(axis=0), 100, rtol=0.01)
):
msg = "The sum of coordinates should be 1 or 100 for all data points"
raise ValueError(msg)
if len(b_coords) == 2:
A, B = b_coords
C = 1 - (A + B)
else:
A, B, C = b_coords / b_coords.sum(axis=0)
if np.any(np.stack((A, B, C)) < 0):
raise ValueError("Barycentric coordinates should be positive.")
return np.stack((A, B, C))
def _compute_grid(coordinates, values, interp_mode="ilr"):
"""
Transform data points with Cartesian or ILR mapping, then compute
interpolation on a regular grid.
Parameters
==========
coordinates : array-like
Barycentric coordinates of data points.
values : 1-d array-like
Data points, field to be represented as contours.
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours.
"""
if interp_mode == "cartesian":
M, invM = _transform_barycentric_cartesian()
coord_points = np.einsum("ik, kj -> ij", M, coordinates)
elif interp_mode == "ilr":
coordinates = _replace_zero_coords(coordinates)
coord_points = _ilr_transform(coordinates)
else:
raise ValueError("interp_mode should be cartesian or ilr")
xx, yy = coord_points[:2]
x_min, x_max = xx.min(), xx.max()
y_min, y_max = yy.min(), yy.max()
n_interp = max(200, int(np.sqrt(len(values))))
gr_x = np.linspace(x_min, x_max, n_interp)
gr_y = np.linspace(y_min, y_max, n_interp)
grid_x, grid_y = np.meshgrid(gr_x, gr_y)
# Use cubic interpolation; grid points outside the convex hull of the data
# points are left as NaN by griddata.
grid_z = scipy_interp.griddata(
coord_points[:2].T, values, (grid_x, grid_y), method="cubic"
)
return grid_z, gr_x, gr_y
# ----------------------- Contour traces ----------------------
def _polygon_area(x, y):
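# Shoelace formula for the (unsigned) area of the polygon with vertices (x, y)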
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
def _colors(ncontours, colormap=None):
"""
Return a list of ``ncontours`` colors from the ``colormap`` colorscale.
"""
if colormap in clrs.PLOTLY_SCALES:
cmap = clrs.PLOTLY_SCALES[colormap]
else:
raise exceptions.PlotlyError(
"Colorscale must be a valid Plotly Colorscale."
"The available colorscale names are {}".format(clrs.PLOTLY_SCALES.keys())
)
values = np.linspace(0, 1, ncontours)
vals_cmap = np.array([pair[0] for pair in cmap])
cols = np.array([pair[1] for pair in cmap])
inds = np.searchsorted(vals_cmap, values)
if "#" in cols[0]: # for Viridis
cols = [clrs.label_rgb(clrs.hex_to_rgb(col)) for col in cols]
colors = [cols[0]]
for ind, val in zip(inds[1:], values[1:]):
val1, val2 = vals_cmap[ind - 1], vals_cmap[ind]
interm = (val - val1) / (val2 - val1)
col = clrs.find_intermediate_color(
cols[ind - 1], cols[ind], interm, colortype="rgb"
)
colors.append(col)
return colors
def _is_invalid_contour(x, y):
"""
Utility function for _contour_trace
Contours with an area of the order of 1 pixel are considered spurious.
"""
too_small = np.all(np.abs(x - x[0]) < 2) and np.all(np.abs(y - y[0]) < 2)
return too_small
def _extract_contours(im, values, colors):
"""
Utility function for _contour_trace.
In ``im`` only one part of the domain has valid values (corresponding
to a subdomain where barycentric coordinates are well defined). When
computing contours, we need to assign values outside of this domain.
We can choose a value either smaller than all the values inside the
valid domain, or larger. This value must be chosen with caution so that
no spurious contours are added. For example, if the boundary of the valid
domain has large values and the outer value is set to a small one, all
intermediate contours will be added at the boundary.
Therefore, we compute the two sets of contours (with an outer value
smaller or larger than all values in the valid domain), and choose
the value resulting in a smaller total number of contours. There might
be a faster way to do this, but it works...
"""
mask_nan = np.isnan(im)
im_min, im_max = (
im[np.logical_not(mask_nan)].min(),
im[np.logical_not(mask_nan)].max(),
)
zz_min = np.copy(im)
zz_min[mask_nan] = 2 * im_min
zz_max = np.copy(im)
zz_max[mask_nan] = 2 * im_max
all_contours1, all_values1, all_areas1, all_colors1 = [], [], [], []
all_contours2, all_values2, all_areas2, all_colors2 = [], [], [], []
for i, val in enumerate(values):
contour_level1 = measure.find_contours(zz_min, val)
contour_level2 = measure.find_contours(zz_max, val)
all_contours1.extend(contour_level1)
all_contours2.extend(contour_level2)
all_values1.extend([val] * len(contour_level1))
all_values2.extend([val] * len(contour_level2))
all_areas1.extend(
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level1]
)
all_areas2.extend(
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level2]
)
all_colors1.extend([colors[i]] * len(contour_level1))
all_colors2.extend([colors[i]] * len(contour_level2))
if len(all_contours1) <= len(all_contours2):
return all_contours1, all_values1, all_areas1, all_colors1
else:
return all_contours2, all_values2, all_areas2, all_colors2
def _add_outer_contour(
all_contours,
all_values,
all_areas,
all_colors,
values,
val_outer,
v_min,
v_max,
colors,
color_min,
color_max,
):
"""
Utility function for _contour_trace
Adds the background color to fill gaps outside of computed contours.
To compute the background color, the color of the contour with largest
area (``val_outer``) is used. As background color, we choose the next
color value in the direction of the extrema of the colormap.
Then we add information for the outer contour for the different lists
provided as arguments.
A discrete colormap with all used colors is also returned (to be used
by colorscale trace).
"""
# The exact value of outer contour is not used when defining the trace
outer_contour = 20 * np.array([[0, 0, 1], [0, 1, 0.5]]).T
all_contours = [outer_contour] + all_contours
delta_values = np.diff(values)[0]
values = np.concatenate(
([values[0] - delta_values], values, [values[-1] + delta_values])
)
colors = np.concatenate(([color_min], colors, [color_max]))
index = np.nonzero(values == val_outer)[0][0]
if index < len(values) / 2:
index -= 1
else:
index += 1
all_colors = [colors[index]] + all_colors
all_values = [values[index]] + all_values
all_areas = [0] + all_areas
used_colors = [color for color in colors if color in all_colors]
# Define discrete colorscale
color_number = len(used_colors)
scale = np.linspace(0, 1, color_number + 1)
discrete_cm = []
for i, color in enumerate(used_colors):
discrete_cm.append([scale[i], used_colors[i]])
discrete_cm.append([scale[i + 1], used_colors[i]])
discrete_cm.append([scale[color_number], used_colors[color_number - 1]])
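# At this point discrete_cm has a piecewise-constant form; e.g. for two
# used colors c0, c1 it is [[0, c0], [0.5, c0], [0.5, c1], [1.0, c1],
# [1.0, c1]] (the last entry is repeated by the append just above).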
return all_contours, all_values, all_areas, all_colors, discrete_cm
def _contour_trace(
x,
y,
z,
ncontours=None,
colorscale="Electric",
linecolor="rgb(150,150,150)",
interp_mode="llr",
coloring=None,
v_min=0,
v_max=1,
):
"""
Contour trace in Cartesian coordinates.
Parameters
==========
x, y : array-like
Cartesian coordinates
z : array-like
Field to be represented as contours.
ncontours : int or None
Number of contours to display (determined automatically if None).
colorscale : None or str (Plotly colormap)
colorscale of the contours.
linecolor : rgb color
Color used for lines. If ``colorscale`` is not None, line colors are
determined from ``colorscale`` instead.
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours. If 'ilr',
an ILR (Isometric Log-Ratio) transform of the compositional data is
performed. If 'cartesian', contours are determined in Cartesian space.
coloring : None or 'lines'
How to display contour. Filled contours if None, lines if ``lines``.
v_min, v_max : float
Bounds of the interval of values used for the colorscale.
Notes
=====
"""
# Prepare colors
# We do not use the extreme colors: for example, for a single contour
# the color will be the middle point of the colormap
colors = _colors(ncontours + 2, colorscale)
# Values used for contours; the extrema are not used.
# For example, for a binary array [0, 1], the value of
# the contour for ncontours=1 is 0.5.
values = np.linspace(v_min, v_max, ncontours + 2)
color_min, color_max = colors[0], colors[-1]
colors = colors[1:-1]
values = values[1:-1]
# Color of line contours
if linecolor is None:
linecolor = "rgb(150, 150, 150)"
else:
colors = [linecolor] * ncontours
# Retrieve all contours
all_contours, all_values, all_areas, all_colors = _extract_contours(
z, values, colors
)
# Now sort contours by decreasing area
order = np.argsort(all_areas)[::-1]
# Add outer contour
all_contours, all_values, all_areas, all_colors, discrete_cm = _add_outer_contour(
all_contours,
all_values,
all_areas,
all_colors,
values,
all_values[order[0]],
v_min,
v_max,
colors,
color_min,
color_max,
)
order = np.concatenate(([0], order + 1))
# Compute traces, in the order of decreasing area
traces = []
M, invM = _transform_barycentric_cartesian()
dx = (x.max() - x.min()) / x.size
dy = (y.max() - y.min()) / y.size
for index in order:
y_contour, x_contour = all_contours[index].T
val = all_values[index]
if interp_mode == "cartesian":
bar_coords = np.dot(
invM,
np.stack((dx * x_contour, dy * y_contour, np.ones(x_contour.shape))),
)
elif interp_mode == "ilr":
bar_coords = _ilr_inverse(
np.stack((dx * x_contour + x.min(), dy * y_contour + y.min()))
)
if index == 0: # outer triangle
a = np.array([1, 0, 0])
b = np.array([0, 1, 0])
c = np.array([0, 0, 1])
else:
a, b, c = bar_coords
if _is_invalid_contour(x_contour, y_contour):
continue
_col = all_colors[index] if coloring == "lines" else linecolor
trace = dict(
type="scatterternary",
a=a,
b=b,
c=c,
mode="lines",
line=dict(color=_col, shape="spline", width=1),
fill="toself",
fillcolor=all_colors[index],
showlegend=True,
hoverinfo="skip",
name="%.3f" % val,
)
if coloring == "lines":
trace["fill"] = None
traces.append(trace)
return traces, discrete_cm
# -------------------- Figure Factory for ternary contour -------------
def create_ternary_contour(
coordinates,
values,
pole_labels=["a", "b", "c"],
width=500,
height=500,
ncontours=None,
showscale=False,
coloring=None,
colorscale="Bluered",
linecolor=None,
title=None,
interp_mode="ilr",
showmarkers=False,
):
"""
Ternary contour plot.
Parameters
----------
coordinates : list or ndarray
Barycentric coordinates of shape (2, N) or (3, N) where N is the
number of data points. The sum of the 3 coordinates is expected
to be 1 for all data points.
values : array-like
Data points of field to be represented as contours.
pole_labels : str, default ['a', 'b', 'c']
Names of the three poles of the triangle.
width : int
Figure width.
height : int
Figure height.
ncontours : int or None
Number of contours to display (determined automatically if None).
showscale : bool, default False
If True, a colorbar showing the color scale is displayed.
coloring : None or 'lines'
How to display contour. Filled contours if None, lines if ``lines``.
colorscale : None or str (Plotly colormap)
colorscale of the contours.
linecolor : None or rgb color
Color used for lines. ``colorscale`` has to be set to None, otherwise
line colors are determined from ``colorscale``.
title : str or None
Title of ternary plot
interp_mode : 'ilr' (default) or 'cartesian'
Defines how data are interpolated to compute contours. If 'ilr',
an ILR (Isometric Log-Ratio) transform of the compositional data is
performed. If 'cartesian', contours are determined in Cartesian space.
showmarkers : bool, default False
If True, markers corresponding to input compositional points are
superimposed on contours, using the same colorscale.
Examples
========
Example 1: ternary contour plot with filled contours
>>> import plotly.figure_factory as ff
>>> import numpy as np
>>> # Define coordinates
>>> a, b = np.mgrid[0:1:20j, 0:1:20j]
>>> mask = a + b <= 1
>>> a = a[mask].ravel()
>>> b = b[mask].ravel()
>>> c = 1 - a - b
>>> # Values to be displayed as contours
>>> z = a * b * c
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z)
>>> fig.show()
It is also possible to give only two barycentric coordinates for each
point, since the sum of the three coordinates is one:
>>> fig = ff.create_ternary_contour(np.stack((a, b)), z)
Example 2: ternary contour plot with line contours
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines')
Example 3: customize number of contours
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8)
Example 4: superimpose contour plot and original data as markers
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines',
... showmarkers=True)
Example 5: customize title and pole labels
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z,
... title='Ternary plot',
...                                 pole_labels=['clay', 'quartz', 'feldspar'])
"""
if scipy_interp is None:
raise ImportError(
"""\
The create_ternary_contour figure factory requires the scipy package"""
)
sk_measure = optional_imports.get_module("skimage")
if sk_measure is None:
raise ImportError(
"""\
The create_ternary_contour figure factory requires the scikit-image
package"""
)
if colorscale is None:
showscale = False
if ncontours is None:
ncontours = 5
coordinates = _prepare_barycentric_coord(coordinates)
v_min, v_max = values.min(), values.max()
grid_z, gr_x, gr_y = _compute_grid(coordinates, values, interp_mode=interp_mode)
layout = _ternary_layout(
pole_labels=pole_labels, width=width, height=height, title=title
)
contour_trace, discrete_cm = _contour_trace(
gr_x,
gr_y,
grid_z,
ncontours=ncontours,
colorscale=colorscale,
linecolor=linecolor,
interp_mode=interp_mode,
coloring=coloring,
v_min=v_min,
v_max=v_max,
)
fig = go.Figure(data=contour_trace, layout=layout)
opacity = 1 if showmarkers else 0
a, b, c = coordinates
hovertemplate = (
pole_labels[0]
+ ": %{a:.3f}<br>"
+ pole_labels[1]
+ ": %{b:.3f}<br>"
+ pole_labels[2]
+ ": %{c:.3f}<br>"
"z: %{marker.color:.3f}<extra></extra>"
)
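# With the default pole_labels the template above evaluates to:
# "a: %{a:.3f}<br>b: %{b:.3f}<br>c: %{c:.3f}<br>z: %{marker.color:.3f}<extra></extra>"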
fig.add_scatterternary(
a=a,
b=b,
c=c,
mode="markers",
marker={
"color": values,
"colorscale": colorscale,
"line": {"color": "rgb(120, 120, 120)", "width": int(coloring != "lines")},
},
opacity=opacity,
hovertemplate=hovertemplate,
)
if showscale:
if not showmarkers:
colorscale = discrete_cm
colorbar = dict(
{
"type": "scatterternary",
"a": [None],
"b": [None],
"c": [None],
"marker": {
"cmin": values.min(),
"cmax": values.max(),
"colorscale": colorscale,
"showscale": True,
},
"mode": "markers",
}
)
fig.add_trace(colorbar)
return fig

View File

@ -0,0 +1,509 @@
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
np = optional_imports.get_module("numpy")
def map_face2color(face, colormap, scale, vmin, vmax):
"""
Normalize facecolor values by vmin/vmax and return rgb-color strings
This function takes a tuple color along with a colormap and a minimum
(vmin) and maximum (vmax) range of possible mean distances for the
given parametrized surface. It returns an rgb color based on the mean
distance between vmin and vmax
"""
if vmin >= vmax:
raise exceptions.PlotlyError(
"Incorrect relation between vmin "
"and vmax. The vmin value cannot be "
"bigger than or equal to the value "
"of vmax."
)
if len(colormap) == 1:
# color each triangle face with the same color in colormap
face_color = colormap[0]
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
if face == vmax:
# pick last color in colormap
face_color = colormap[-1]
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
else:
if scale is None:
# find the normalized distance t of a triangle face between
# vmin and vmax where the distance is between 0 and 1
t = (face - vmin) / float((vmax - vmin))
low_color_index = int(t / (1.0 / (len(colormap) - 1)))
face_color = clrs.find_intermediate_color(
colormap[low_color_index],
colormap[low_color_index + 1],
t * (len(colormap) - 1) - low_color_index,
)
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
else:
# find the face color for a non-linearly interpolated scale
t = (face - vmin) / float((vmax - vmin))
low_color_index = 0
for k in range(len(scale) - 1):
if scale[k] <= t < scale[k + 1]:
break
low_color_index += 1
low_scale_val = scale[low_color_index]
high_scale_val = scale[low_color_index + 1]
face_color = clrs.find_intermediate_color(
colormap[low_color_index],
colormap[low_color_index + 1],
(t - low_scale_val) / (high_scale_val - low_scale_val),
)
face_color = clrs.convert_to_RGB_255(face_color)
face_color = clrs.label_rgb(face_color)
return face_color
def trisurf(
x,
y,
z,
simplices,
show_colorbar,
edges_color,
scale,
colormap=None,
color_func=None,
plot_edges=False,
x_edge=None,
y_edge=None,
z_edge=None,
facecolor=None,
):
"""
Refer to FigureFactory.create_trisurf() for docstring
"""
# numpy import check
if not np:
raise ImportError("FigureFactory._trisurf() requires numpy imported.")
points3D = np.vstack((x, y, z)).T
simplices = np.atleast_2d(simplices)
# vertices of the surface triangles
tri_vertices = points3D[simplices]
# Define colors for the triangle faces
if color_func is None:
# mean values of z-coordinates of triangle vertices
mean_dists = tri_vertices[:, :, 2].mean(-1)
elif isinstance(color_func, (list, np.ndarray)):
# Pre-computed list / array of values to map onto color
if len(color_func) != len(simplices):
raise ValueError(
"If color_func is a list/array, it must "
"be the same length as simplices."
)
# convert all colors in color_func to rgb
for index in range(len(color_func)):
if isinstance(color_func[index], str):
if "#" in color_func[index]:
foo = clrs.hex_to_rgb(color_func[index])
color_func[index] = clrs.label_rgb(foo)
if isinstance(color_func[index], tuple):
foo = clrs.convert_to_RGB_255(color_func[index])
color_func[index] = clrs.label_rgb(foo)
mean_dists = np.asarray(color_func)
else:
# apply user inputted function to calculate
# custom coloring for triangle vertices
mean_dists = []
for triangle in tri_vertices:
dists = []
for vertex in triangle:
dist = color_func(vertex[0], vertex[1], vertex[2])
dists.append(dist)
mean_dists.append(np.mean(dists))
mean_dists = np.asarray(mean_dists)
# Check if facecolors are already strings and can be skipped
if isinstance(mean_dists[0], str):
facecolor = mean_dists
else:
min_mean_dists = np.min(mean_dists)
max_mean_dists = np.max(mean_dists)
if facecolor is None:
facecolor = []
for index in range(len(mean_dists)):
color = map_face2color(
mean_dists[index], colormap, scale, min_mean_dists, max_mean_dists
)
facecolor.append(color)
# Convert facecolor to an array so the output is consistent across Python versions
facecolor = np.asarray(facecolor)
ii, jj, kk = simplices.T
triangles = graph_objs.Mesh3d(
x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name=""
)
mean_dists_are_numbers = not isinstance(mean_dists[0], str)
if mean_dists_are_numbers and show_colorbar is True:
# make a colorscale from the colors
colorscale = clrs.make_colorscale(colormap, scale)
colorscale = clrs.convert_colorscale_to_rgb(colorscale)
colorbar = graph_objs.Scatter3d(
x=x[:1],
y=y[:1],
z=z[:1],
mode="markers",
marker=dict(
size=0.1,
color=[min_mean_dists, max_mean_dists],
colorscale=colorscale,
showscale=True,
),
hoverinfo="none",
showlegend=False,
)
# the triangle sides are not plotted
if plot_edges is False:
if mean_dists_are_numbers and show_colorbar is True:
return [triangles, colorbar]
else:
return [triangles]
# define the lists x_edge, y_edge and z_edge of the x, y and z
# coordinates of edge end points for each triangle
# None separates data corresponding to two consecutive triangles
is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
if any(is_none):
if not all(is_none):
raise ValueError(
"If any (x_edge, y_edge, z_edge) is None, all must be None"
)
else:
x_edge = []
y_edge = []
z_edge = []
# Pull indices we care about, then add a None column to separate tris
ixs_triangles = [0, 1, 2, 0]
pull_edges = tri_vertices[:, ixs_triangles, :]
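# tri_vertices has shape (ntri, 3, 3); repeating the first vertex via
# ixs_triangles closes each triangle, so pull_edges has shape (ntri, 4, 3)
# and each hstack below yields an (ntri, 5) array once the None column
# is appended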
x_edge_pull = np.hstack(
[pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])]
)
y_edge_pull = np.hstack(
[pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])]
)
z_edge_pull = np.hstack(
[pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])]
)
# Now unravel the edges into a 1-d vector for plotting
x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
if not (len(x_edge) == len(y_edge) == len(z_edge)):
raise exceptions.PlotlyError(
"The lengths of x_edge, y_edge and z_edge are not the same."
)
# define the lines for plotting
lines = graph_objs.Scatter3d(
x=x_edge,
y=y_edge,
z=z_edge,
mode="lines",
line=graph_objs.scatter3d.Line(color=edges_color, width=1.5),
showlegend=False,
)
if mean_dists_are_numbers and show_colorbar is True:
return [triangles, lines, colorbar]
else:
return [triangles, lines]
def create_trisurf(
x,
y,
z,
simplices,
colormap=None,
show_colorbar=True,
scale=None,
color_func=None,
title="Trisurf Plot",
plot_edges=True,
showbackground=True,
backgroundcolor="rgb(230, 230, 230)",
gridcolor="rgb(255, 255, 255)",
zerolinecolor="rgb(255, 255, 255)",
edges_color="rgb(50, 50, 50)",
height=800,
width=800,
aspectratio=None,
):
"""
Returns figure for a triangulated surface plot
:param (array) x: data values of x in a 1D array
:param (array) y: data values of y in a 1D array
:param (array) z: data values of z in a 1D array
:param (array) simplices: an array of shape (ntri, 3) where ntri is
the number of triangles in the triangulation. Each row of the
array contains the indices of the vertices of each triangle
:param (str|tuple|list) colormap: either a plotly scale name, an rgb
or hex color, a color tuple or a list of colors. An rgb color is
of the form 'rgb(x, y, z)' where x, y, z belong to the interval
[0, 255] and a color tuple is a tuple of the form (a, b, c) where
a, b and c belong to [0, 1]. If colormap is a list, it must
contain the valid color types aforementioned as its members
:param (bool) show_colorbar: determines if colorbar is visible
:param (list|array) scale: sets the scale values to be used if a non-
linearly interpolated colormap is desired. If left as None, a
linear interpolation between the colors will be executed
:param (function|list) color_func: The parameter that determines the
coloring of the surface. Takes either a function with 3 arguments
x, y, z or a list/array of color values the same length as
simplices. If None, coloring will only depend on the z axis
:param (str) title: title of the plot
:param (bool) plot_edges: determines if the triangles on the trisurf
are visible
:param (bool) showbackground: makes background in plot visible
:param (str) backgroundcolor: color of background. Takes a string of
the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
:param (str) gridcolor: color of the gridlines besides the axes. Takes
a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
inclusive
:param (str) zerolinecolor: color of the axes. Takes a string of the
form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
:param (str) edges_color: color of the edges, if plot_edges is True
:param (int|float) height: the height of the plot (in pixels)
:param (int|float) width: the width of the plot (in pixels)
:param (dict) aspectratio: a dictionary of the aspect ratio values for
the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
Example 1: Sphere
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 20)
>>> v = np.linspace(0, np.pi, 20)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> x = np.sin(v)*np.cos(u)
>>> y = np.sin(v)*np.sin(u)
>>> z = np.cos(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
... simplices=simplices)
Example 2: Torus
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 20)
>>> v = np.linspace(0, 2*np.pi, 20)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> x = (3 + (np.cos(v)))*np.cos(u)
>>> y = (3 + (np.cos(v)))*np.sin(u)
>>> z = np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
... simplices=simplices)
Example 3: Mobius Band
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u = np.linspace(0, 2*np.pi, 24)
>>> v = np.linspace(-1, 1, 8)
>>> u,v = np.meshgrid(u,v)
>>> u = u.flatten()
>>> v = v.flatten()
>>> tp = 1 + 0.5*v*np.cos(u/2.)
>>> x = tp*np.cos(u)
>>> y = tp*np.sin(u)
>>> z = 0.5*v*np.sin(u/2.)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
... simplices=simplices)
Example 4: Using a Custom Colormap Function with Light Cone
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u=np.linspace(-np.pi, np.pi, 30)
>>> v=np.linspace(-np.pi, np.pi, 30)
>>> u,v=np.meshgrid(u,v)
>>> u=u.flatten()
>>> v=v.flatten()
>>> x = u
>>> y = u*np.cos(v)
>>> z = u*np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> # Define distance function
>>> def dist_origin(x, y, z):
... return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
>>> # Create a figure
>>> fig1 = create_trisurf(x=x, y=y, z=z,
... colormap=['#FFFFFF', '#E4FFFE',
... '#A4F6F9', '#FF99FE',
... '#BA52ED'],
... scale=[0, 0.6, 0.71, 0.89, 1],
... simplices=simplices,
... color_func=dist_origin)
Example 5: Enter color_func as a list of colors
>>> # Necessary Imports for Trisurf
>>> import numpy as np
>>> from scipy.spatial import Delaunay
>>> import random
>>> from plotly.figure_factory import create_trisurf
>>> from plotly.graph_objs import graph_objs
>>> # Make data for plot
>>> u=np.linspace(-np.pi, np.pi, 30)
>>> v=np.linspace(-np.pi, np.pi, 30)
>>> u,v=np.meshgrid(u,v)
>>> u=u.flatten()
>>> v=v.flatten()
>>> x = u
>>> y = u*np.cos(v)
>>> z = u*np.sin(v)
>>> points2D = np.vstack([u,v]).T
>>> tri = Delaunay(points2D)
>>> simplices = tri.simplices
>>> colors = []
>>> color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
>>> for index in range(len(simplices)):
... colors.append(random.choice(color_choices))
>>> fig = create_trisurf(
... x, y, z, simplices,
... color_func=colors,
... show_colorbar=True,
... edges_color='rgb(2, 85, 180)',
... title=' Modern Art'
... )
"""
if aspectratio is None:
aspectratio = {"x": 1, "y": 1, "z": 1}
# Validate colormap
clrs.validate_colors(colormap)
colormap, scale = clrs.convert_colors_to_same_type(
colormap, colortype="tuple", return_default_colors=True, scale=scale
)
data1 = trisurf(
x,
y,
z,
simplices,
show_colorbar=show_colorbar,
color_func=color_func,
colormap=colormap,
scale=scale,
edges_color=edges_color,
plot_edges=plot_edges,
)
axis = dict(
showbackground=showbackground,
backgroundcolor=backgroundcolor,
gridcolor=gridcolor,
zerolinecolor=zerolinecolor,
)
layout = graph_objs.Layout(
title=title,
width=width,
height=height,
scene=graph_objs.layout.Scene(
xaxis=graph_objs.layout.scene.XAxis(**axis),
yaxis=graph_objs.layout.scene.YAxis(**axis),
zaxis=graph_objs.layout.scene.ZAxis(**axis),
aspectratio=dict(
x=aspectratio["x"], y=aspectratio["y"], z=aspectratio["z"]
),
),
)
return graph_objs.Figure(data=data1, layout=layout)

View File

@ -0,0 +1,704 @@
from numbers import Number
from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.graph_objs import graph_objs
from plotly.subplots import make_subplots
pd = optional_imports.get_module("pandas")
np = optional_imports.get_module("numpy")
scipy_stats = optional_imports.get_module("scipy.stats")
def calc_stats(data):
"""
Calculate statistics for use in violin plot.
"""
x = np.asarray(data, float)
vals_min = np.min(x)
vals_max = np.max(x)
q2 = np.percentile(x, 50, interpolation="linear")
q1 = np.percentile(x, 25, interpolation="lower")
q3 = np.percentile(x, 75, interpolation="higher")
iqr = q3 - q1
whisker_dist = 1.5 * iqr
# to prevent drawing whiskers outside the data interval,
# the whisker positions are defined as:
d1 = np.min(x[x >= (q1 - whisker_dist)])
d2 = np.max(x[x <= (q3 + whisker_dist)])
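# Worked example (illustrative only): for data = 1, 2, ..., 10 this gives
# q1=3, q2=5.5, q3=8 (note the 'lower'/'higher' interpolation for q1/q3),
# and d1=1, d2=10 since no point lies beyond 1.5 * IQR of the quartiles.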
return {
"min": vals_min,
"max": vals_max,
"q1": q1,
"q2": q2,
"q3": q3,
"d1": d1,
"d2": d2,
}
def make_half_violin(x, y, fillcolor="#1f77b4", linecolor="rgb(0, 0, 0)"):
"""
Produces one half of a violin plot as a sideways probability-density trace.
"""
text = [
"(pdf(y), y)=(" + "{:0.2f}".format(x[i]) + ", " + "{:0.2f}".format(y[i]) + ")"
for i in range(len(x))
]
return graph_objs.Scatter(
x=x,
y=y,
mode="lines",
name="",
text=text,
fill="tonextx",
fillcolor=fillcolor,
line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape="spline"),
hoverinfo="text",
opacity=0.5,
)
def make_violin_rugplot(vals, pdf_max, distance, color="#1f77b4"):
"""
Returns a rugplot fig for a violin plot.
"""
return graph_objs.Scatter(
y=vals,
x=[-pdf_max - distance] * len(vals),
marker=graph_objs.scatter.Marker(color=color, symbol="line-ew-open"),
mode="markers",
name="",
showlegend=False,
hoverinfo="y",
)
def make_non_outlier_interval(d1, d2):
"""
Returns the scatter trace for the non-outlier interval of a violin plot.
"""
return graph_objs.Scatter(
x=[0, 0],
y=[d1, d2],
name="",
mode="lines",
line=graph_objs.scatter.Line(width=1.5, color="rgb(0,0,0)"),
)
def make_quartiles(q1, q3):
"""
Makes the upper and lower quartiles for a violin plot.
"""
return graph_objs.Scatter(
x=[0, 0],
y=[q1, q3],
text=[
"lower-quartile: " + "{:0.2f}".format(q1),
"upper-quartile: " + "{:0.2f}".format(q3),
],
mode="lines",
line=graph_objs.scatter.Line(width=4, color="rgb(0,0,0)"),
hoverinfo="text",
)
def make_median(q2):
"""
Formats the 'median' hovertext for a violin plot.
"""
return graph_objs.Scatter(
x=[0],
y=[q2],
text=["median: " + "{:0.2f}".format(q2)],
mode="markers",
marker=dict(symbol="square", color="rgb(255,255,255)"),
hoverinfo="text",
)
def make_XAxis(xaxis_title, xaxis_range):
"""
Makes the x-axis for a violin plot.
"""
xaxis = graph_objs.layout.XAxis(
title=xaxis_title,
range=xaxis_range,
showgrid=False,
zeroline=False,
showline=False,
mirror=False,
ticks="",
showticklabels=False,
)
return xaxis
def make_YAxis(yaxis_title):
"""
Makes the y-axis for a violin plot.
"""
yaxis = graph_objs.layout.YAxis(
title=yaxis_title,
showticklabels=True,
autorange=True,
ticklen=4,
showline=True,
zeroline=False,
showgrid=False,
mirror=False,
)
return yaxis
def violinplot(vals, fillcolor="#1f77b4", rugplot=True):
"""
Refer to FigureFactory.create_violin() for docstring.
"""
vals = np.asarray(vals, float)
# summary statistics
vals_min = calc_stats(vals)["min"]
vals_max = calc_stats(vals)["max"]
q1 = calc_stats(vals)["q1"]
q2 = calc_stats(vals)["q2"]
q3 = calc_stats(vals)["q3"]
d1 = calc_stats(vals)["d1"]
d2 = calc_stats(vals)["d2"]
# kernel density estimation of pdf
pdf = scipy_stats.gaussian_kde(vals)
# grid over the data interval
xx = np.linspace(vals_min, vals_max, 100)
# evaluate the pdf at the grid xx
yy = pdf(xx)
max_pdf = np.max(yy)
# distance from the violin plot to rugplot
distance = (2.0 * max_pdf) / 10 if rugplot else 0
# range for x values in the plot
plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
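# the x-range spans the left half-violin (negative pdf values), the rugplot
# offset and a small 0.1 margin on both sides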
plot_data = [
make_half_violin(-yy, xx, fillcolor=fillcolor),
make_half_violin(yy, xx, fillcolor=fillcolor),
make_non_outlier_interval(d1, d2),
make_quartiles(q1, q3),
make_median(q2),
]
if rugplot:
plot_data.append(
make_violin_rugplot(vals, max_pdf, distance=distance, color=fillcolor)
)
return plot_data, plot_xrange
def violin_no_colorscale(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot without colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
color_index = 0
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], float)
if color_index >= len(colors):
color_index = 0
plot_data, plot_xrange = violinplot(
vals, fillcolor=colors[color_index], rugplot=rugplot
)
for item in plot_data:
fig.append_trace(item, 1, k + 1)
color_index += 1
# add violin plot labels
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def violin_colorscale(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot with colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
# make sure all group names are keys in group_stats
for group in group_name:
if group not in group_stats:
raise exceptions.PlotlyError(
"All values/groups in the index "
"column must be represented "
"as a key in group_stats."
)
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
# prepare low and high color for colorscale
lowcolor = clrs.color_parser(colors[0], clrs.unlabel_rgb)
highcolor = clrs.color_parser(colors[1], clrs.unlabel_rgb)
# find min and max values in group_stats
group_stats_values = []
for key in group_stats:
group_stats_values.append(group_stats[key])
max_value = max(group_stats_values)
min_value = min(group_stats_values)
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], float)
# find intermediate color from colorscale
intermed = (group_stats[gr] - min_value) / (max_value - min_value)
intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)
plot_data, plot_xrange = violinplot(
vals, fillcolor="rgb{}".format(intermed_color), rugplot=rugplot
)
for item in plot_data:
fig.append_trace(item, 1, k + 1)
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# add colorbar to plot
trace_dummy = graph_objs.Scatter(
x=[0],
y=[0],
mode="markers",
marker=dict(
size=2,
cmin=min_value,
cmax=max_value,
colorscale=[[0, colors[0]], [1, colors[1]]],
showscale=True,
),
showlegend=False,
)
fig.append_trace(trace_dummy, 1, L)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def violin_dict(
data,
data_header,
group_header,
colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
):
"""
Refer to FigureFactory.create_violin() for docstring.
Returns fig for violin plot without colorscale.
"""
# collect all group names
group_name = []
for name in data[group_header]:
if name not in group_name:
group_name.append(name)
if sort:
group_name.sort()
# check if all group names appear in colors dict
for group in group_name:
if group not in colors:
raise exceptions.PlotlyError(
"If colors is a dictionary, all "
"the group names must appear as "
"keys in colors."
)
gb = data.groupby([group_header])
L = len(group_name)
fig = make_subplots(
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
)
for k, gr in enumerate(group_name):
vals = np.asarray(gb.get_group(gr)[data_header], float)
plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], rugplot=rugplot)
for item in plot_data:
fig.append_trace(item, 1, k + 1)
# add violin plot labels
fig["layout"].update(
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
)
# set the sharey axis style
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
fig["layout"].update(
title=title,
showlegend=False,
hovermode="closest",
autosize=False,
height=height,
width=width,
)
return fig
def create_violin(
data,
data_header=None,
group_header=None,
colors=None,
use_colorscale=False,
group_stats=None,
rugplot=True,
sort=False,
height=450,
width=600,
title="Violin and Rug Plot",
):
"""
**deprecated**, use instead the plotly.graph_objects trace
:class:`plotly.graph_objects.Violin`.
:param (list|array) data: accepts either a list of numerical values,
a list of dictionaries all with identical keys and at least one
column of numeric values, or a pandas dataframe with at least one
column of numbers.
:param (str) data_header: the header of the data column to be used
from an inputted pandas dataframe. Not applicable if 'data' is
a list of numeric values.
:param (str) group_header: applicable if grouping data by a variable.
'group_header' must be set to the name of the grouping variable.
:param (str|tuple|list|dict) colors: either a plotly scale name,
an rgb or hex color, a color tuple, a list of colors or a
dictionary. An rgb color is of the form 'rgb(x, y, z)' where
x, y and z belong to the interval [0, 255] and a color tuple is a
tuple of the form (a, b, c) where a, b and c belong to [0, 1].
If colors is a list, it must contain valid color types as its
members.
:param (bool) use_colorscale: only applicable if grouping by another
variable. Will implement a colorscale based on the first 2 colors
of param colors. This means colors must be a list with at least 2
colors in it (Plotly colorscales are accepted since they map to a
list of two rgb colors). Default = False
:param (dict) group_stats: a dictionary where each key is a unique
value from the group_header column in data. Each value must be a
number and will be used to color the violin plots if a colorscale
is being used.
:param (bool) rugplot: determines if a rugplot is drawn on the violin plot.
Default = True
:param (bool) sort: determines if violins are sorted
alphabetically (True) or by input order (False). Default = False
:param (float) height: the height of the violin plot.
:param (float) width: the width of the violin plot.
:param (str) title: the title of the violin plot.
Example 1: Single Violin Plot
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> from scipy import stats
>>> # create list of random values
>>> data_list = np.random.randn(100)
>>> # create violin fig
>>> fig = create_violin(data_list, colors='#604d9e')
>>> # plot
>>> fig.show()
Example 2: Multiple Violin Plots with Qualitative Coloring
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> import pandas as pd
>>> from scipy import stats
>>> # create dataframe
>>> np.random.seed(619517)
>>> Nr=250
>>> y = np.random.randn(Nr)
>>> gr = np.random.choice(list("ABCDE"), Nr)
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
>>> for i, letter in enumerate("ABCDE"):
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
>>> # create violin fig
>>> fig = create_violin(df, data_header='Score', group_header='Group',
... sort=True, height=600, width=1000)
>>> # plot
>>> fig.show()
Example 3: Violin Plots with Colorscale
>>> from plotly.figure_factory import create_violin
>>> import plotly.graph_objs as graph_objects
>>> import numpy as np
>>> import pandas as pd
>>> from scipy import stats
>>> # create dataframe
>>> np.random.seed(619517)
>>> Nr=250
>>> y = np.random.randn(Nr)
>>> gr = np.random.choice(list("ABCDE"), Nr)
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
>>> for i, letter in enumerate("ABCDE"):
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
>>> # define header params
>>> data_header = 'Score'
>>> group_header = 'Group'
>>> # make groupby object with pandas
>>> group_stats = {}
>>> groupby_data = df.groupby([group_header])
>>> for group in "ABCDE":
... data_from_group = groupby_data.get_group(group)[data_header]
... # take a stat of the grouped data
... stat = np.median(data_from_group)
... # add to dictionary
... group_stats[group] = stat
>>> # create violin fig
>>> fig = create_violin(df, data_header='Score', group_header='Group',
... height=600, width=1000, use_colorscale=True,
... group_stats=group_stats)
>>> # plot
>>> fig.show()
"""
# Validate colors
if isinstance(colors, dict):
valid_colors = clrs.validate_colors_dict(colors, "rgb")
else:
valid_colors = clrs.validate_colors(colors, "rgb")
# validate data and choose plot type
if group_header is None:
if isinstance(data, list):
if len(data) <= 0:
raise exceptions.PlotlyError(
"If data is a list, it must be "
"nonempty and contain either "
"numbers or dictionaries."
)
if not all(isinstance(element, Number) for element in data):
raise exceptions.PlotlyError(
"If data is a list, it must contain only numbers."
)
if pd and isinstance(data, pd.core.frame.DataFrame):
if data_header is None:
raise exceptions.PlotlyError(
"data_header must be the "
"column name with the "
"desired numeric data for "
"the violin plot."
)
data = data[data_header].values.tolist()
# call the plotting functions
plot_data, plot_xrange = violinplot(
data, fillcolor=valid_colors[0], rugplot=rugplot
)
layout = graph_objs.Layout(
title=title,
autosize=False,
font=graph_objs.layout.Font(size=11),
height=height,
showlegend=False,
width=width,
xaxis=make_XAxis("", plot_xrange),
yaxis=make_YAxis(""),
hovermode="closest",
)
layout["yaxis"].update(dict(showline=False, showticklabels=False, ticks=""))
fig = graph_objs.Figure(data=plot_data, layout=layout)
return fig
else:
if not isinstance(data, pd.core.frame.DataFrame):
raise exceptions.PlotlyError(
"Error. You must use a pandas "
"DataFrame if you are using a "
"group header."
)
if data_header is None:
raise exceptions.PlotlyError(
"data_header must be the column "
"name with the desired numeric "
"data for the violin plot."
)
if use_colorscale is False:
if isinstance(valid_colors, dict):
# validate colors dict choice below
fig = violin_dict(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig
else:
fig = violin_no_colorscale(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig
else:
if isinstance(valid_colors, dict):
raise exceptions.PlotlyError(
"The colors param cannot be "
"a dictionary if you are "
"using a colorscale."
)
if len(valid_colors) < 2:
raise exceptions.PlotlyError(
"colors must be a list with "
"at least 2 colors. A "
"Plotly scale is allowed."
)
if not isinstance(group_stats, dict):
raise exceptions.PlotlyError(
"Your group_stats param must be a dictionary."
)
fig = violin_colorscale(
data,
data_header,
group_header,
valid_colors,
use_colorscale,
group_stats,
rugplot,
sort,
height,
width,
title,
)
return fig

View File

@ -0,0 +1,249 @@
from collections.abc import Sequence
from plotly import exceptions
def is_sequence(obj):
return isinstance(obj, Sequence) and not isinstance(obj, str)
def validate_index(index_vals):
"""
Validates if a list contains all numbers or all strings
:raises: (PlotlyError) If there are any two items in the list whose
types differ
"""
from numbers import Number
if isinstance(index_vals[0], Number):
if not all(isinstance(item, Number) for item in index_vals):
raise exceptions.PlotlyError(
"Error in indexing column. "
"Make sure all entries of each "
"column are all numbers or "
"all strings."
)
elif isinstance(index_vals[0], str):
if not all(isinstance(item, str) for item in index_vals):
raise exceptions.PlotlyError(
"Error in indexing column. "
"Make sure all entries of each "
"column are all numbers or "
"all strings."
)
def validate_dataframe(array):
"""
Validates all strings or numbers in each dataframe column
:raises: (PlotlyError) If there are any two items in any list whose
types differ
"""
from numbers import Number
for vector in array:
if isinstance(vector[0], Number):
if not all(isinstance(item, Number) for item in vector):
raise exceptions.PlotlyError(
"Error in dataframe. "
"Make sure all entries of "
"each column are either "
"numbers or strings."
)
elif isinstance(vector[0], str):
if not all(isinstance(item, str) for item in vector):
raise exceptions.PlotlyError(
"Error in dataframe. "
"Make sure all entries of "
"each column are either "
"numbers or strings."
)
def validate_equal_length(*args):
"""
Validates that data lists or ndarrays are the same length.
:raises: (PlotlyError) If any data lists are not the same length.
"""
length = len(args[0])
if any(len(lst) != length for lst in args):
raise exceptions.PlotlyError(
"Oops! Your data lists or ndarrays should be the same length."
)
def validate_positive_scalars(**kwargs):
"""
Validates that all values given in key/val pairs are positive.
Accepts kwargs to improve Exception messages.
:raises: (ValueError) If any value is <= 0; (PlotlyError) if a value is not a number.
"""
for key, val in kwargs.items():
try:
if val <= 0:
raise ValueError("{} must be > 0, got {}".format(key, val))
except TypeError:
raise exceptions.PlotlyError("{} must be a number, got {}".format(key, val))
def flatten(array):
"""
Uses list comprehension to flatten array
:param (array): An iterable to flatten
:raises (PlotlyError): If iterable is not nested.
:rtype (list): The flattened list.
"""
try:
return [item for sublist in array for item in sublist]
except TypeError:
raise exceptions.PlotlyError(
"Your data array could not be "
"flattened! Make sure your data is "
"entered as lists or ndarrays!"
)
def endpts_to_intervals(endpts):
"""
Returns a list of intervals for categorical colormaps
Accepts a list or tuple of sequentially increasing numbers and returns
a list representation of the mathematical intervals with these numbers
as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
:raises: (PlotlyError) If input is not a list or tuple
:raises: (PlotlyError) If the input contains a string
:raises: (PlotlyError) If any number does not increase after the
previous one in the sequence
"""
length = len(endpts)
# Check if endpts is a list or tuple
if not isinstance(endpts, (tuple, list)):
raise exceptions.PlotlyError(
"The intervals_endpts argument must "
"be a list or tuple of a sequence "
"of increasing numbers."
)
# Check if endpts contains only numbers
for item in endpts:
if isinstance(item, str):
raise exceptions.PlotlyError(
"The intervals_endpts argument "
"must be a list or tuple of a "
"sequence of increasing "
"numbers."
)
# Check if numbers in endpts are increasing
for k in range(length - 1):
if endpts[k] >= endpts[k + 1]:
raise exceptions.PlotlyError(
"The intervals_endpts argument "
"must be a list or tuple of a "
"sequence of increasing "
"numbers."
)
else:
intervals = []
# add -inf to intervals
intervals.append([float("-inf"), endpts[0]])
for k in range(length - 1):
interval = []
interval.append(endpts[k])
interval.append(endpts[k + 1])
intervals.append(interval)
# add +inf to intervals
intervals.append([endpts[length - 1], float("inf")])
return intervals
def annotation_dict_for_label(
text,
lane,
num_of_lanes,
subplot_spacing,
row_col="col",
flipped=True,
right_side=True,
text_color="#0f0f0f",
):
"""
Returns annotation dict for label of n labels of a 1xn or nx1 subplot.
:param (str) text: the text for a label.
:param (int) lane: the label number for text. From 1 to n inclusive.
:param (int) num_of_lanes: the number 'n' of rows or columns in subplot.
:param (float) subplot_spacing: the value for the horizontal_spacing and
vertical_spacing params in your plotly.tools.make_subplots() call.
:param (str) row_col: choose whether labels are placed along rows or
columns.
:param (bool) flipped: flips text by 90 degrees. Text is printed
horizontally if set to True and row_col='row', or if False and
row_col='col'.
:param (bool) right_side: only applicable if row_col is set to 'row'.
:param (str) text_color: color of the text.
"""
temp = (1 - (num_of_lanes - 1) * subplot_spacing) / (num_of_lanes)
if not flipped:
xanchor = "center"
yanchor = "middle"
if row_col == "col":
x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
y = 1.03
textangle = 0
elif row_col == "row":
y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
x = 1.03
textangle = 90
else:
if row_col == "col":
xanchor = "center"
yanchor = "bottom"
x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
y = 1.0
textangle = 270
elif row_col == "row":
yanchor = "middle"
y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
if right_side:
x = 1.0
xanchor = "left"
else:
x = -0.01
xanchor = "right"
textangle = 0
annotation_dict = dict(
textangle=textangle,
xanchor=xanchor,
yanchor=yanchor,
x=x,
y=y,
showarrow=False,
xref="paper",
yref="paper",
text=text,
font=dict(size=13, color=text_color),
)
return annotation_dict
def list_of_options(iterable, conj="and", period=True):
"""
Returns an English listing of objects separated by commas ','
For example, ['foo', 'bar', 'baz'] becomes 'foo, bar and baz'
if the conjunction 'and' is selected.
"""
if len(iterable) < 2:
raise exceptions.PlotlyError(
"Your list or tuple must contain at least 2 items."
)
template = (len(iterable) - 2) * "{}, " + "{} " + conj + " {}" + period * "."
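# For three items with conj='and' and period=True the template is
# "{}, {} and {}.", so ['foo', 'bar', 'baz'] formats to "foo, bar and baz."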
return template.format(*iterable)