done
This commit is contained in:
@ -0,0 +1,155 @@
|
||||
from numbers import Number
|
||||
|
||||
import plotly.exceptions
|
||||
|
||||
import plotly.colors as clrs
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
|
||||
def make_linear_colorscale(colors):
|
||||
"""
|
||||
Makes a list of colors into a colorscale-acceptable form
|
||||
|
||||
For documentation regarding to the form of the output, see
|
||||
https://plot.ly/python/reference/#mesh3d-colorscale
|
||||
"""
|
||||
scale = 1.0 / (len(colors) - 1)
|
||||
return [[i * scale, color] for i, color in enumerate(colors)]
|
||||
|
||||
|
||||
def create_2d_density(
|
||||
x,
|
||||
y,
|
||||
colorscale="Earth",
|
||||
ncontours=20,
|
||||
hist_color=(0, 0, 0.5),
|
||||
point_color=(0, 0, 0.5),
|
||||
point_size=2,
|
||||
title="2D Density Plot",
|
||||
height=600,
|
||||
width=600,
|
||||
):
|
||||
"""
|
||||
**deprecated**, use instead
|
||||
:func:`plotly.express.density_heatmap`.
|
||||
|
||||
:param (list|array) x: x-axis data for plot generation
|
||||
:param (list|array) y: y-axis data for plot generation
|
||||
:param (str|tuple|list) colorscale: either a plotly scale name, an rgb
|
||||
or hex color, a color tuple or a list or tuple of colors. An rgb
|
||||
color is of the form 'rgb(x, y, z)' where x, y, z belong to the
|
||||
interval [0, 255] and a color tuple is a tuple of the form
|
||||
(a, b, c) where a, b and c belong to [0, 1]. If colormap is a
|
||||
list, it must contain the valid color types aforementioned as its
|
||||
members.
|
||||
:param (int) ncontours: the number of 2D contours to draw on the plot
|
||||
:param (str) hist_color: the color of the plotted histograms
|
||||
:param (str) point_color: the color of the scatter points
|
||||
:param (str) point_size: the color of the scatter points
|
||||
:param (str) title: set the title for the plot
|
||||
:param (float) height: the height of the chart
|
||||
:param (float) width: the width of the chart
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Example 1: Simple 2D Density Plot
|
||||
|
||||
>>> from plotly.figure_factory import create_2d_density
|
||||
>>> import numpy as np
|
||||
|
||||
>>> # Make data points
|
||||
>>> t = np.linspace(-1,1.2,2000)
|
||||
>>> x = (t**3)+(0.3*np.random.randn(2000))
|
||||
>>> y = (t**6)+(0.3*np.random.randn(2000))
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig = create_2d_density(x, y)
|
||||
|
||||
>>> # Plot the data
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: Using Parameters
|
||||
|
||||
>>> from plotly.figure_factory import create_2d_density
|
||||
|
||||
>>> import numpy as np
|
||||
|
||||
>>> # Make data points
|
||||
>>> t = np.linspace(-1,1.2,2000)
|
||||
>>> x = (t**3)+(0.3*np.random.randn(2000))
|
||||
>>> y = (t**6)+(0.3*np.random.randn(2000))
|
||||
|
||||
>>> # Create custom colorscale
|
||||
>>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
|
||||
... (1, 1, 0.2), (0.98,0.98,0.98)]
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig = create_2d_density(x, y, colorscale=colorscale,
|
||||
... hist_color='rgb(255, 237, 222)', point_size=3)
|
||||
|
||||
>>> # Plot the data
|
||||
>>> fig.show()
|
||||
"""
|
||||
|
||||
# validate x and y are filled with numbers only
|
||||
for array in [x, y]:
|
||||
if not all(isinstance(element, Number) for element in array):
|
||||
raise plotly.exceptions.PlotlyError(
|
||||
"All elements of your 'x' and 'y' lists must be numbers."
|
||||
)
|
||||
|
||||
# validate x and y are the same length
|
||||
if len(x) != len(y):
|
||||
raise plotly.exceptions.PlotlyError(
|
||||
"Both lists 'x' and 'y' must be the same length."
|
||||
)
|
||||
|
||||
colorscale = clrs.validate_colors(colorscale, "rgb")
|
||||
colorscale = make_linear_colorscale(colorscale)
|
||||
|
||||
# validate hist_color and point_color
|
||||
hist_color = clrs.validate_colors(hist_color, "rgb")
|
||||
point_color = clrs.validate_colors(point_color, "rgb")
|
||||
|
||||
trace1 = graph_objs.Scatter(
|
||||
x=x,
|
||||
y=y,
|
||||
mode="markers",
|
||||
name="points",
|
||||
marker=dict(color=point_color[0], size=point_size, opacity=0.4),
|
||||
)
|
||||
trace2 = graph_objs.Histogram2dContour(
|
||||
x=x,
|
||||
y=y,
|
||||
name="density",
|
||||
ncontours=ncontours,
|
||||
colorscale=colorscale,
|
||||
reversescale=True,
|
||||
showscale=False,
|
||||
)
|
||||
trace3 = graph_objs.Histogram(
|
||||
x=x, name="x density", marker=dict(color=hist_color[0]), yaxis="y2"
|
||||
)
|
||||
trace4 = graph_objs.Histogram(
|
||||
y=y, name="y density", marker=dict(color=hist_color[0]), xaxis="x2"
|
||||
)
|
||||
data = [trace1, trace2, trace3, trace4]
|
||||
|
||||
layout = graph_objs.Layout(
|
||||
showlegend=False,
|
||||
autosize=False,
|
||||
title=title,
|
||||
height=height,
|
||||
width=width,
|
||||
xaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
|
||||
yaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
|
||||
margin=dict(t=50),
|
||||
hovermode="closest",
|
||||
bargap=0,
|
||||
xaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
|
||||
yaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
|
||||
)
|
||||
|
||||
fig = graph_objs.Figure(data=data, layout=layout)
|
||||
return fig
|
@ -0,0 +1,69 @@
|
||||
# ruff: noqa: E402
|
||||
|
||||
from plotly import optional_imports
|
||||
|
||||
# Require that numpy exists for figure_factory
|
||||
np = optional_imports.get_module("numpy")
|
||||
if np is None:
|
||||
raise ImportError(
|
||||
"""\
|
||||
The figure factory module requires the numpy package"""
|
||||
)
|
||||
|
||||
|
||||
from plotly.figure_factory._2d_density import create_2d_density
|
||||
from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
|
||||
from plotly.figure_factory._bullet import create_bullet
|
||||
from plotly.figure_factory._candlestick import create_candlestick
|
||||
from plotly.figure_factory._dendrogram import create_dendrogram
|
||||
from plotly.figure_factory._distplot import create_distplot
|
||||
from plotly.figure_factory._facet_grid import create_facet_grid
|
||||
from plotly.figure_factory._gantt import create_gantt
|
||||
from plotly.figure_factory._ohlc import create_ohlc
|
||||
from plotly.figure_factory._quiver import create_quiver
|
||||
from plotly.figure_factory._scatterplot import create_scatterplotmatrix
|
||||
from plotly.figure_factory._streamline import create_streamline
|
||||
from plotly.figure_factory._table import create_table
|
||||
from plotly.figure_factory._trisurf import create_trisurf
|
||||
from plotly.figure_factory._violin import create_violin
|
||||
|
||||
if optional_imports.get_module("pandas") is not None:
|
||||
from plotly.figure_factory._county_choropleth import create_choropleth
|
||||
from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox
|
||||
else:
|
||||
|
||||
def create_choropleth(*args, **kwargs):
|
||||
raise ImportError("Please install pandas to use `create_choropleth`")
|
||||
|
||||
def create_hexbin_mapbox(*args, **kwargs):
|
||||
raise ImportError("Please install pandas to use `create_hexbin_mapbox`")
|
||||
|
||||
|
||||
if optional_imports.get_module("skimage") is not None:
|
||||
from plotly.figure_factory._ternary_contour import create_ternary_contour
|
||||
else:
|
||||
|
||||
def create_ternary_contour(*args, **kwargs):
|
||||
raise ImportError("Please install scikit-image to use `create_ternary_contour`")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_2d_density",
|
||||
"create_annotated_heatmap",
|
||||
"create_bullet",
|
||||
"create_candlestick",
|
||||
"create_choropleth",
|
||||
"create_dendrogram",
|
||||
"create_distplot",
|
||||
"create_facet_grid",
|
||||
"create_gantt",
|
||||
"create_hexbin_mapbox",
|
||||
"create_ohlc",
|
||||
"create_quiver",
|
||||
"create_scatterplotmatrix",
|
||||
"create_streamline",
|
||||
"create_table",
|
||||
"create_ternary_contour",
|
||||
"create_trisurf",
|
||||
"create_violin",
|
||||
]
|
@ -0,0 +1,307 @@
|
||||
import plotly.colors as clrs
|
||||
from plotly import exceptions, optional_imports
|
||||
from plotly.figure_factory import utils
|
||||
from plotly.graph_objs import graph_objs
|
||||
from plotly.validator_cache import ValidatorCache
|
||||
|
||||
# Optional imports, may be None for users that only use our core functionality.
|
||||
np = optional_imports.get_module("numpy")
|
||||
|
||||
|
||||
def validate_annotated_heatmap(z, x, y, annotation_text):
|
||||
"""
|
||||
Annotated-heatmap-specific validations
|
||||
|
||||
Check that if a text matrix is supplied, it has the same
|
||||
dimensions as the z matrix.
|
||||
|
||||
See FigureFactory.create_annotated_heatmap() for params
|
||||
|
||||
:raises: (PlotlyError) If z and text matrices do not have the same
|
||||
dimensions.
|
||||
"""
|
||||
if annotation_text is not None and isinstance(annotation_text, list):
|
||||
utils.validate_equal_length(z, annotation_text)
|
||||
for lst in range(len(z)):
|
||||
if len(z[lst]) != len(annotation_text[lst]):
|
||||
raise exceptions.PlotlyError(
|
||||
"z and text should have the same dimensions"
|
||||
)
|
||||
|
||||
if x:
|
||||
if len(x) != len(z[0]):
|
||||
raise exceptions.PlotlyError(
|
||||
"oops, the x list that you "
|
||||
"provided does not match the "
|
||||
"width of your z matrix "
|
||||
)
|
||||
|
||||
if y:
|
||||
if len(y) != len(z):
|
||||
raise exceptions.PlotlyError(
|
||||
"oops, the y list that you "
|
||||
"provided does not match the "
|
||||
"length of your z matrix "
|
||||
)
|
||||
|
||||
|
||||
def create_annotated_heatmap(
|
||||
z,
|
||||
x=None,
|
||||
y=None,
|
||||
annotation_text=None,
|
||||
colorscale="Plasma",
|
||||
font_colors=None,
|
||||
showscale=False,
|
||||
reversescale=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
**deprecated**, use instead
|
||||
:func:`plotly.express.imshow`.
|
||||
|
||||
Function that creates annotated heatmaps
|
||||
|
||||
This function adds annotations to each cell of the heatmap.
|
||||
|
||||
:param (list[list]|ndarray) z: z matrix to create heatmap.
|
||||
:param (list) x: x axis labels.
|
||||
:param (list) y: y axis labels.
|
||||
:param (list[list]|ndarray) annotation_text: Text strings for
|
||||
annotations. Should have the same dimensions as the z matrix. If no
|
||||
text is added, the values of the z matrix are annotated. Default =
|
||||
z matrix values.
|
||||
:param (list|str) colorscale: heatmap colorscale.
|
||||
:param (list) font_colors: List of two color strings: [min_text_color,
|
||||
max_text_color] where min_text_color is applied to annotations for
|
||||
heatmap values < (max_value - min_value)/2. If font_colors is not
|
||||
defined, the colors are defined logically as black or white
|
||||
depending on the heatmap's colorscale.
|
||||
:param (bool) showscale: Display colorscale. Default = False
|
||||
:param (bool) reversescale: Reverse colorscale. Default = False
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
|
||||
These kwargs describe other attributes about the annotated Heatmap
|
||||
trace such as the colorscale. For more information on valid kwargs
|
||||
call help(plotly.graph_objs.Heatmap)
|
||||
|
||||
Example 1: Simple annotated heatmap with default configuration
|
||||
|
||||
>>> import plotly.figure_factory as ff
|
||||
|
||||
>>> z = [[0.300000, 0.00000, 0.65, 0.300000],
|
||||
... [1, 0.100005, 0.45, 0.4300],
|
||||
... [0.300000, 0.00000, 0.65, 0.300000],
|
||||
... [1, 0.100005, 0.45, 0.00000]]
|
||||
|
||||
>>> fig = ff.create_annotated_heatmap(z)
|
||||
>>> fig.show()
|
||||
"""
|
||||
|
||||
# Avoiding mutables in the call signature
|
||||
font_colors = font_colors if font_colors is not None else []
|
||||
validate_annotated_heatmap(z, x, y, annotation_text)
|
||||
|
||||
# validate colorscale
|
||||
colorscale_validator = ValidatorCache.get_validator("heatmap", "colorscale")
|
||||
colorscale = colorscale_validator.validate_coerce(colorscale)
|
||||
|
||||
annotations = _AnnotatedHeatmap(
|
||||
z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
|
||||
).make_annotations()
|
||||
|
||||
if x or y:
|
||||
trace = dict(
|
||||
type="heatmap",
|
||||
z=z,
|
||||
x=x,
|
||||
y=y,
|
||||
colorscale=colorscale,
|
||||
showscale=showscale,
|
||||
reversescale=reversescale,
|
||||
**kwargs,
|
||||
)
|
||||
layout = dict(
|
||||
annotations=annotations,
|
||||
xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
|
||||
yaxis=dict(ticks="", dtick=1, ticksuffix=" "),
|
||||
)
|
||||
else:
|
||||
trace = dict(
|
||||
type="heatmap",
|
||||
z=z,
|
||||
colorscale=colorscale,
|
||||
showscale=showscale,
|
||||
reversescale=reversescale,
|
||||
**kwargs,
|
||||
)
|
||||
layout = dict(
|
||||
annotations=annotations,
|
||||
xaxis=dict(
|
||||
ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
|
||||
),
|
||||
yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False),
|
||||
)
|
||||
|
||||
data = [trace]
|
||||
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
def to_rgb_color_list(color_str, default):
|
||||
color_str = color_str.strip()
|
||||
if color_str.startswith("rgb"):
|
||||
return [int(v) for v in color_str.strip("rgba()").split(",")]
|
||||
elif color_str.startswith("#"):
|
||||
return clrs.hex_to_rgb(color_str)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def should_use_black_text(background_color):
|
||||
return (
|
||||
background_color[0] * 0.299
|
||||
+ background_color[1] * 0.587
|
||||
+ background_color[2] * 0.114
|
||||
) > 186
|
||||
|
||||
|
||||
class _AnnotatedHeatmap(object):
|
||||
"""
|
||||
Refer to TraceFactory.create_annotated_heatmap() for docstring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
|
||||
):
|
||||
self.z = z
|
||||
if x:
|
||||
self.x = x
|
||||
else:
|
||||
self.x = range(len(z[0]))
|
||||
if y:
|
||||
self.y = y
|
||||
else:
|
||||
self.y = range(len(z))
|
||||
if annotation_text is not None:
|
||||
self.annotation_text = annotation_text
|
||||
else:
|
||||
self.annotation_text = self.z
|
||||
self.colorscale = colorscale
|
||||
self.reversescale = reversescale
|
||||
self.font_colors = font_colors
|
||||
|
||||
if np and isinstance(self.z, np.ndarray):
|
||||
self.zmin = np.amin(self.z)
|
||||
self.zmax = np.amax(self.z)
|
||||
else:
|
||||
self.zmin = min([v for row in self.z for v in row])
|
||||
self.zmax = max([v for row in self.z for v in row])
|
||||
|
||||
if kwargs.get("zmin", None) is not None:
|
||||
self.zmin = kwargs["zmin"]
|
||||
if kwargs.get("zmax", None) is not None:
|
||||
self.zmax = kwargs["zmax"]
|
||||
|
||||
self.zmid = (self.zmax + self.zmin) / 2
|
||||
|
||||
if kwargs.get("zmid", None) is not None:
|
||||
self.zmid = kwargs["zmid"]
|
||||
|
||||
def get_text_color(self):
|
||||
"""
|
||||
Get font color for annotations.
|
||||
|
||||
The annotated heatmap can feature two text colors: min_text_color and
|
||||
max_text_color. The min_text_color is applied to annotations for
|
||||
heatmap values < (max_value - min_value)/2. The user can define these
|
||||
two colors. Otherwise the colors are defined logically as black or
|
||||
white depending on the heatmap's colorscale.
|
||||
|
||||
:rtype (string, string) min_text_color, max_text_color: text
|
||||
color for annotations for heatmap values <
|
||||
(max_value - min_value)/2 and text color for annotations for
|
||||
heatmap values >= (max_value - min_value)/2
|
||||
"""
|
||||
# Plotly colorscales ranging from a lighter shade to a darker shade
|
||||
colorscales = [
|
||||
"Greys",
|
||||
"Greens",
|
||||
"Blues",
|
||||
"YIGnBu",
|
||||
"YIOrRd",
|
||||
"RdBu",
|
||||
"Picnic",
|
||||
"Jet",
|
||||
"Hot",
|
||||
"Blackbody",
|
||||
"Earth",
|
||||
"Electric",
|
||||
"Viridis",
|
||||
"Cividis",
|
||||
]
|
||||
# Plotly colorscales ranging from a darker shade to a lighter shade
|
||||
colorscales_reverse = ["Reds"]
|
||||
|
||||
white = "#FFFFFF"
|
||||
black = "#000000"
|
||||
if self.font_colors:
|
||||
min_text_color = self.font_colors[0]
|
||||
max_text_color = self.font_colors[-1]
|
||||
elif self.colorscale in colorscales and self.reversescale:
|
||||
min_text_color = black
|
||||
max_text_color = white
|
||||
elif self.colorscale in colorscales:
|
||||
min_text_color = white
|
||||
max_text_color = black
|
||||
elif self.colorscale in colorscales_reverse and self.reversescale:
|
||||
min_text_color = white
|
||||
max_text_color = black
|
||||
elif self.colorscale in colorscales_reverse:
|
||||
min_text_color = black
|
||||
max_text_color = white
|
||||
elif isinstance(self.colorscale, list):
|
||||
min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
|
||||
max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])
|
||||
|
||||
# swap min/max colors if reverse scale
|
||||
if self.reversescale:
|
||||
min_col, max_col = max_col, min_col
|
||||
|
||||
if should_use_black_text(min_col):
|
||||
min_text_color = black
|
||||
else:
|
||||
min_text_color = white
|
||||
|
||||
if should_use_black_text(max_col):
|
||||
max_text_color = black
|
||||
else:
|
||||
max_text_color = white
|
||||
else:
|
||||
min_text_color = black
|
||||
max_text_color = black
|
||||
return min_text_color, max_text_color
|
||||
|
||||
def make_annotations(self):
|
||||
"""
|
||||
Get annotations for each cell of the heatmap with graph_objs.Annotation
|
||||
|
||||
:rtype (list[dict]) annotations: list of annotations for each cell of
|
||||
the heatmap
|
||||
"""
|
||||
min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
|
||||
annotations = []
|
||||
for n, row in enumerate(self.z):
|
||||
for m, val in enumerate(row):
|
||||
font_color = min_text_color if val < self.zmid else max_text_color
|
||||
annotations.append(
|
||||
graph_objs.layout.Annotation(
|
||||
text=str(self.annotation_text[n][m]),
|
||||
x=self.x[m],
|
||||
y=self.y[n],
|
||||
xref="x1",
|
||||
yref="y1",
|
||||
font=dict(color=font_color),
|
||||
showarrow=False,
|
||||
)
|
||||
)
|
||||
return annotations
|
366
lib/python3.11/site-packages/plotly/figure_factory/_bullet.py
Normal file
366
lib/python3.11/site-packages/plotly/figure_factory/_bullet.py
Normal file
@ -0,0 +1,366 @@
|
||||
import math
|
||||
|
||||
from plotly import exceptions, optional_imports
|
||||
import plotly.colors as clrs
|
||||
from plotly.figure_factory import utils
|
||||
|
||||
import plotly
|
||||
import plotly.graph_objs as go
|
||||
|
||||
pd = optional_imports.get_module("pandas")
|
||||
|
||||
|
||||
def _bullet(
|
||||
df,
|
||||
markers,
|
||||
measures,
|
||||
ranges,
|
||||
subtitles,
|
||||
titles,
|
||||
orientation,
|
||||
range_colors,
|
||||
measure_colors,
|
||||
horizontal_spacing,
|
||||
vertical_spacing,
|
||||
scatter_options,
|
||||
layout_options,
|
||||
):
|
||||
num_of_lanes = len(df)
|
||||
num_of_rows = num_of_lanes if orientation == "h" else 1
|
||||
num_of_cols = 1 if orientation == "h" else num_of_lanes
|
||||
if not horizontal_spacing:
|
||||
horizontal_spacing = 1.0 / num_of_lanes
|
||||
if not vertical_spacing:
|
||||
vertical_spacing = 1.0 / num_of_lanes
|
||||
fig = plotly.subplots.make_subplots(
|
||||
num_of_rows,
|
||||
num_of_cols,
|
||||
print_grid=False,
|
||||
horizontal_spacing=horizontal_spacing,
|
||||
vertical_spacing=vertical_spacing,
|
||||
)
|
||||
|
||||
# layout
|
||||
fig["layout"].update(
|
||||
dict(shapes=[]),
|
||||
title="Bullet Chart",
|
||||
height=600,
|
||||
width=1000,
|
||||
showlegend=False,
|
||||
barmode="stack",
|
||||
annotations=[],
|
||||
margin=dict(l=120 if orientation == "h" else 80),
|
||||
)
|
||||
|
||||
# update layout
|
||||
fig["layout"].update(layout_options)
|
||||
|
||||
if orientation == "h":
|
||||
width_axis = "yaxis"
|
||||
length_axis = "xaxis"
|
||||
else:
|
||||
width_axis = "xaxis"
|
||||
length_axis = "yaxis"
|
||||
|
||||
for key in fig["layout"]:
|
||||
if "xaxis" in key or "yaxis" in key:
|
||||
fig["layout"][key]["showgrid"] = False
|
||||
fig["layout"][key]["zeroline"] = False
|
||||
if length_axis in key:
|
||||
fig["layout"][key]["tickwidth"] = 1
|
||||
if width_axis in key:
|
||||
fig["layout"][key]["showticklabels"] = False
|
||||
fig["layout"][key]["range"] = [0, 1]
|
||||
|
||||
# narrow domain if 1 bar
|
||||
if num_of_lanes <= 1:
|
||||
fig["layout"][width_axis + "1"]["domain"] = [0.4, 0.6]
|
||||
|
||||
if not range_colors:
|
||||
range_colors = ["rgb(200, 200, 200)", "rgb(245, 245, 245)"]
|
||||
if not measure_colors:
|
||||
measure_colors = ["rgb(31, 119, 180)", "rgb(176, 196, 221)"]
|
||||
|
||||
for row in range(num_of_lanes):
|
||||
# ranges bars
|
||||
for idx in range(len(df.iloc[row]["ranges"])):
|
||||
inter_colors = clrs.n_colors(
|
||||
range_colors[0], range_colors[1], len(df.iloc[row]["ranges"]), "rgb"
|
||||
)
|
||||
x = (
|
||||
[sorted(df.iloc[row]["ranges"])[-1 - idx]]
|
||||
if orientation == "h"
|
||||
else [0]
|
||||
)
|
||||
y = (
|
||||
[0]
|
||||
if orientation == "h"
|
||||
else [sorted(df.iloc[row]["ranges"])[-1 - idx]]
|
||||
)
|
||||
bar = go.Bar(
|
||||
x=x,
|
||||
y=y,
|
||||
marker=dict(color=inter_colors[-1 - idx]),
|
||||
name="ranges",
|
||||
hoverinfo="x" if orientation == "h" else "y",
|
||||
orientation=orientation,
|
||||
width=2,
|
||||
base=0,
|
||||
xaxis="x{}".format(row + 1),
|
||||
yaxis="y{}".format(row + 1),
|
||||
)
|
||||
fig.add_trace(bar)
|
||||
|
||||
# measures bars
|
||||
for idx in range(len(df.iloc[row]["measures"])):
|
||||
inter_colors = clrs.n_colors(
|
||||
measure_colors[0],
|
||||
measure_colors[1],
|
||||
len(df.iloc[row]["measures"]),
|
||||
"rgb",
|
||||
)
|
||||
x = (
|
||||
[sorted(df.iloc[row]["measures"])[-1 - idx]]
|
||||
if orientation == "h"
|
||||
else [0.5]
|
||||
)
|
||||
y = (
|
||||
[0.5]
|
||||
if orientation == "h"
|
||||
else [sorted(df.iloc[row]["measures"])[-1 - idx]]
|
||||
)
|
||||
bar = go.Bar(
|
||||
x=x,
|
||||
y=y,
|
||||
marker=dict(color=inter_colors[-1 - idx]),
|
||||
name="measures",
|
||||
hoverinfo="x" if orientation == "h" else "y",
|
||||
orientation=orientation,
|
||||
width=0.4,
|
||||
base=0,
|
||||
xaxis="x{}".format(row + 1),
|
||||
yaxis="y{}".format(row + 1),
|
||||
)
|
||||
fig.add_trace(bar)
|
||||
|
||||
# markers
|
||||
x = df.iloc[row]["markers"] if orientation == "h" else [0.5]
|
||||
y = [0.5] if orientation == "h" else df.iloc[row]["markers"]
|
||||
markers = go.Scatter(
|
||||
x=x,
|
||||
y=y,
|
||||
name="markers",
|
||||
hoverinfo="x" if orientation == "h" else "y",
|
||||
xaxis="x{}".format(row + 1),
|
||||
yaxis="y{}".format(row + 1),
|
||||
**scatter_options,
|
||||
)
|
||||
|
||||
fig.add_trace(markers)
|
||||
|
||||
# titles and subtitles
|
||||
title = df.iloc[row]["titles"]
|
||||
if "subtitles" in df:
|
||||
subtitle = "<br>{}".format(df.iloc[row]["subtitles"])
|
||||
else:
|
||||
subtitle = ""
|
||||
label = "<b>{}</b>".format(title) + subtitle
|
||||
annot = utils.annotation_dict_for_label(
|
||||
label,
|
||||
(num_of_lanes - row if orientation == "h" else row + 1),
|
||||
num_of_lanes,
|
||||
vertical_spacing if orientation == "h" else horizontal_spacing,
|
||||
"row" if orientation == "h" else "col",
|
||||
True if orientation == "h" else False,
|
||||
False,
|
||||
)
|
||||
fig["layout"]["annotations"] += (annot,)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_bullet(
|
||||
data,
|
||||
markers=None,
|
||||
measures=None,
|
||||
ranges=None,
|
||||
subtitles=None,
|
||||
titles=None,
|
||||
orientation="h",
|
||||
range_colors=("rgb(200, 200, 200)", "rgb(245, 245, 245)"),
|
||||
measure_colors=("rgb(31, 119, 180)", "rgb(176, 196, 221)"),
|
||||
horizontal_spacing=None,
|
||||
vertical_spacing=None,
|
||||
scatter_options={},
|
||||
**layout_options,
|
||||
):
|
||||
"""
|
||||
**deprecated**, use instead the plotly.graph_objects trace
|
||||
:class:`plotly.graph_objects.Indicator`.
|
||||
|
||||
:param (pd.DataFrame | list | tuple) data: either a list/tuple of
|
||||
dictionaries or a pandas DataFrame.
|
||||
:param (str) markers: the column name or dictionary key for the markers in
|
||||
each subplot.
|
||||
:param (str) measures: the column name or dictionary key for the measure
|
||||
bars in each subplot. This bar usually represents the quantitative
|
||||
measure of performance, usually a list of two values [a, b] and are
|
||||
the blue bars in the foreground of each subplot by default.
|
||||
:param (str) ranges: the column name or dictionary key for the qualitative
|
||||
ranges of performance, usually a 3-item list [bad, okay, good]. They
|
||||
correspond to the grey bars in the background of each chart.
|
||||
:param (str) subtitles: the column name or dictionary key for the subtitle
|
||||
of each subplot chart. The subplots are displayed right underneath
|
||||
each title.
|
||||
:param (str) titles: the column name or dictionary key for the main label
|
||||
of each subplot chart.
|
||||
:param (bool) orientation: if 'h', the bars are placed horizontally as
|
||||
rows. If 'v' the bars are placed vertically in the chart.
|
||||
:param (list) range_colors: a tuple of two colors between which all
|
||||
the rectangles for the range are drawn. These rectangles are meant to
|
||||
be qualitative indicators against which the marker and measure bars
|
||||
are compared.
|
||||
Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
|
||||
:param (list) measure_colors: a tuple of two colors which is used to color
|
||||
the thin quantitative bars in the bullet chart.
|
||||
Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
|
||||
:param (float) horizontal_spacing: see the 'horizontal_spacing' param in
|
||||
plotly.tools.make_subplots. Ranges between 0 and 1.
|
||||
:param (float) vertical_spacing: see the 'vertical_spacing' param in
|
||||
plotly.tools.make_subplots. Ranges between 0 and 1.
|
||||
:param (dict) scatter_options: describes attributes for the scatter trace
|
||||
in each subplot such as name and marker size. Call
|
||||
help(plotly.graph_objs.Scatter) for more information on valid params.
|
||||
:param layout_options: describes attributes for the layout of the figure
|
||||
such as title, height and width. Call help(plotly.graph_objs.Layout)
|
||||
for more information on valid params.
|
||||
|
||||
Example 1: Use a Dictionary
|
||||
|
||||
>>> import plotly.figure_factory as ff
|
||||
|
||||
>>> data = [
|
||||
... {"label": "revenue", "sublabel": "us$, in thousands",
|
||||
... "range": [150, 225, 300], "performance": [220,270], "point": [250]},
|
||||
... {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
|
||||
... "performance": [21, 23], "point": [26]},
|
||||
... {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
|
||||
... "performance": [100,320],"point": [550]},
|
||||
... {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
|
||||
... "performance": [1000, 1650],"point": [2100]},
|
||||
... {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
|
||||
... "performance": [3.2, 4.7], "point": [4.4]}
|
||||
... ]
|
||||
|
||||
>>> fig = ff.create_bullet(
|
||||
... data, titles='label', subtitles='sublabel', markers='point',
|
||||
... measures='performance', ranges='range', orientation='h',
|
||||
... title='my simple bullet chart'
|
||||
... )
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: Use a DataFrame with Custom Colors
|
||||
|
||||
>>> import plotly.figure_factory as ff
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')
|
||||
|
||||
>>> fig = ff.create_bullet(
|
||||
... data, titles='title', markers='markers', measures='measures',
|
||||
... orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
|
||||
... scatter_options={'marker': {'symbol': 'circle'}}, width=700)
|
||||
>>> fig.show()
|
||||
"""
|
||||
# validate df
|
||||
if not pd:
|
||||
raise ImportError("'pandas' must be installed for this figure factory.")
|
||||
|
||||
if utils.is_sequence(data):
|
||||
if not all(isinstance(item, dict) for item in data):
|
||||
raise exceptions.PlotlyError(
|
||||
"Every entry of the data argument list, tuple, etc must "
|
||||
"be a dictionary."
|
||||
)
|
||||
|
||||
elif not isinstance(data, pd.DataFrame):
|
||||
raise exceptions.PlotlyError(
|
||||
"You must input a pandas DataFrame, or a list of dictionaries."
|
||||
)
|
||||
|
||||
# make DataFrame from data with correct column headers
|
||||
col_names = ["titles", "subtitle", "markers", "measures", "ranges"]
|
||||
if utils.is_sequence(data):
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
[d[titles] for d in data] if titles else [""] * len(data),
|
||||
[d[subtitles] for d in data] if subtitles else [""] * len(data),
|
||||
[d[markers] for d in data] if markers else [[]] * len(data),
|
||||
[d[measures] for d in data] if measures else [[]] * len(data),
|
||||
[d[ranges] for d in data] if ranges else [[]] * len(data),
|
||||
],
|
||||
index=col_names,
|
||||
)
|
||||
elif isinstance(data, pd.DataFrame):
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
data[titles].tolist() if titles else [""] * len(data),
|
||||
data[subtitles].tolist() if subtitles else [""] * len(data),
|
||||
data[markers].tolist() if markers else [[]] * len(data),
|
||||
data[measures].tolist() if measures else [[]] * len(data),
|
||||
data[ranges].tolist() if ranges else [[]] * len(data),
|
||||
],
|
||||
index=col_names,
|
||||
)
|
||||
df = pd.DataFrame.transpose(df)
|
||||
|
||||
# make sure ranges, measures, 'markers' are not NAN or NONE
|
||||
for needed_key in ["ranges", "measures", "markers"]:
|
||||
for idx, r in enumerate(df[needed_key]):
|
||||
try:
|
||||
r_is_nan = math.isnan(r)
|
||||
if r_is_nan or r is None:
|
||||
df[needed_key][idx] = []
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
# validate custom colors
|
||||
for colors_list in [range_colors, measure_colors]:
|
||||
if colors_list:
|
||||
if len(colors_list) != 2:
|
||||
raise exceptions.PlotlyError(
|
||||
"Both 'range_colors' or 'measure_colors' must be a list "
|
||||
"of two valid colors."
|
||||
)
|
||||
clrs.validate_colors(colors_list)
|
||||
colors_list = clrs.convert_colors_to_same_type(colors_list, "rgb")[0]
|
||||
|
||||
# default scatter options
|
||||
default_scatter = {
|
||||
"marker": {"size": 12, "symbol": "diamond-tall", "color": "rgb(0, 0, 0)"}
|
||||
}
|
||||
|
||||
if scatter_options == {}:
|
||||
scatter_options.update(default_scatter)
|
||||
else:
|
||||
# add default options to scatter_options if they are not present
|
||||
for k in default_scatter["marker"]:
|
||||
if k not in scatter_options["marker"]:
|
||||
scatter_options["marker"][k] = default_scatter["marker"][k]
|
||||
|
||||
fig = _bullet(
|
||||
df,
|
||||
markers,
|
||||
measures,
|
||||
ranges,
|
||||
subtitles,
|
||||
titles,
|
||||
orientation,
|
||||
range_colors,
|
||||
measure_colors,
|
||||
horizontal_spacing,
|
||||
vertical_spacing,
|
||||
scatter_options,
|
||||
layout_options,
|
||||
)
|
||||
|
||||
return fig
|
@ -0,0 +1,277 @@
|
||||
from plotly.figure_factory import utils
|
||||
from plotly.figure_factory._ohlc import (
|
||||
_DEFAULT_INCREASING_COLOR,
|
||||
_DEFAULT_DECREASING_COLOR,
|
||||
validate_ohlc,
|
||||
)
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
|
||||
def make_increasing_candle(open, high, low, close, dates, **kwargs):
|
||||
"""
|
||||
Makes boxplot trace for increasing candlesticks
|
||||
|
||||
_make_increasing_candle() and _make_decreasing_candle separate the
|
||||
increasing traces from the decreasing traces so kwargs (such as
|
||||
color) can be passed separately to increasing or decreasing traces
|
||||
when direction is set to 'increasing' or 'decreasing' in
|
||||
FigureFactory.create_candlestick()
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing values
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param kwargs: kwargs to be passed to increasing trace via
|
||||
plotly.graph_objs.Scatter.
|
||||
|
||||
:rtype (list) candle_incr_data: list of the box trace for
|
||||
increasing candlesticks.
|
||||
"""
|
||||
increase_x, increase_y = _Candlestick(
|
||||
open, high, low, close, dates, **kwargs
|
||||
).get_candle_increase()
|
||||
|
||||
if "line" in kwargs:
|
||||
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
|
||||
else:
|
||||
kwargs.setdefault("fillcolor", _DEFAULT_INCREASING_COLOR)
|
||||
if "name" in kwargs:
|
||||
kwargs.setdefault("showlegend", True)
|
||||
else:
|
||||
kwargs.setdefault("showlegend", False)
|
||||
kwargs.setdefault("name", "Increasing")
|
||||
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR))
|
||||
|
||||
candle_incr_data = dict(
|
||||
type="box",
|
||||
x=increase_x,
|
||||
y=increase_y,
|
||||
whiskerwidth=0,
|
||||
boxpoints=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [candle_incr_data]
|
||||
|
||||
|
||||
def make_decreasing_candle(open, high, low, close, dates, **kwargs):
|
||||
"""
|
||||
Makes boxplot trace for decreasing candlesticks
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing values
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param kwargs: kwargs to be passed to decreasing trace via
|
||||
plotly.graph_objs.Scatter.
|
||||
|
||||
:rtype (list) candle_decr_data: list of the box trace for
|
||||
decreasing candlesticks.
|
||||
"""
|
||||
|
||||
decrease_x, decrease_y = _Candlestick(
|
||||
open, high, low, close, dates, **kwargs
|
||||
).get_candle_decrease()
|
||||
|
||||
if "line" in kwargs:
|
||||
kwargs.setdefault("fillcolor", kwargs["line"]["color"])
|
||||
else:
|
||||
kwargs.setdefault("fillcolor", _DEFAULT_DECREASING_COLOR)
|
||||
kwargs.setdefault("showlegend", False)
|
||||
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR))
|
||||
kwargs.setdefault("name", "Decreasing")
|
||||
|
||||
candle_decr_data = dict(
|
||||
type="box",
|
||||
x=decrease_x,
|
||||
y=decrease_y,
|
||||
whiskerwidth=0,
|
||||
boxpoints=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [candle_decr_data]
|
||||
|
||||
|
||||
def create_candlestick(open, high, low, close, dates=None, direction="both", **kwargs):
|
||||
"""
|
||||
**deprecated**, use instead the plotly.graph_objects trace
|
||||
:class:`plotly.graph_objects.Candlestick`
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing values
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param (string) direction: direction can be 'increasing', 'decreasing',
|
||||
or 'both'. When the direction is 'increasing', the returned figure
|
||||
consists of all candlesticks where the close value is greater than
|
||||
the corresponding open value, and when the direction is
|
||||
'decreasing', the returned figure consists of all candlesticks
|
||||
where the close value is less than or equal to the corresponding
|
||||
open value. When the direction is 'both', both increasing and
|
||||
decreasing candlesticks are returned. Default: 'both'
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
|
||||
These kwargs describe other attributes about the ohlc Scatter trace
|
||||
such as the color or the legend name. For more information on valid
|
||||
kwargs call help(plotly.graph_objs.Scatter)
|
||||
|
||||
:rtype (dict): returns a representation of candlestick chart figure.
|
||||
|
||||
Example 1: Simple candlestick chart from a Pandas DataFrame
|
||||
|
||||
>>> from plotly.figure_factory import create_candlestick
|
||||
>>> from datetime import datetime
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
|
||||
>>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
|
||||
... dates=df.index)
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: Customize the candlestick colors
|
||||
|
||||
>>> from plotly.figure_factory import create_candlestick
|
||||
>>> from plotly.graph_objs import Line, Marker
|
||||
>>> from datetime import datetime
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
|
||||
|
||||
>>> # Make increasing candlesticks and customize their color and name
|
||||
>>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
|
||||
... dates=df.index,
|
||||
... direction='increasing', name='AAPL',
|
||||
... marker=Marker(color='rgb(150, 200, 250)'),
|
||||
... line=Line(color='rgb(150, 200, 250)'))
|
||||
|
||||
>>> # Make decreasing candlesticks and customize their color and name
|
||||
>>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
|
||||
... dates=df.index,
|
||||
... direction='decreasing',
|
||||
... marker=Marker(color='rgb(128, 128, 128)'),
|
||||
... line=Line(color='rgb(128, 128, 128)'))
|
||||
|
||||
>>> # Initialize the figure
|
||||
>>> fig = fig_increasing
|
||||
|
||||
>>> # Add decreasing data with .extend()
|
||||
>>> fig.add_trace(fig_decreasing['data']) # doctest: +SKIP
|
||||
>>> fig.show()
|
||||
|
||||
Example 3: Candlestick chart with datetime objects
|
||||
|
||||
>>> from plotly.figure_factory import create_candlestick
|
||||
|
||||
>>> from datetime import datetime
|
||||
|
||||
>>> # Add data
|
||||
>>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
|
||||
>>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
|
||||
>>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
|
||||
>>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
|
||||
>>> dates = [datetime(year=2013, month=10, day=10),
|
||||
... datetime(year=2013, month=11, day=10),
|
||||
... datetime(year=2013, month=12, day=10),
|
||||
... datetime(year=2014, month=1, day=10),
|
||||
... datetime(year=2014, month=2, day=10)]
|
||||
|
||||
>>> # Create ohlc
|
||||
>>> fig = create_candlestick(open_data, high_data,
|
||||
... low_data, close_data, dates=dates)
|
||||
>>> fig.show()
|
||||
"""
|
||||
if dates is not None:
|
||||
utils.validate_equal_length(open, high, low, close, dates)
|
||||
else:
|
||||
utils.validate_equal_length(open, high, low, close)
|
||||
validate_ohlc(open, high, low, close, direction, **kwargs)
|
||||
|
||||
if direction == "increasing":
|
||||
candle_incr_data = make_increasing_candle(
|
||||
open, high, low, close, dates, **kwargs
|
||||
)
|
||||
data = candle_incr_data
|
||||
elif direction == "decreasing":
|
||||
candle_decr_data = make_decreasing_candle(
|
||||
open, high, low, close, dates, **kwargs
|
||||
)
|
||||
data = candle_decr_data
|
||||
else:
|
||||
candle_incr_data = make_increasing_candle(
|
||||
open, high, low, close, dates, **kwargs
|
||||
)
|
||||
candle_decr_data = make_decreasing_candle(
|
||||
open, high, low, close, dates, **kwargs
|
||||
)
|
||||
data = candle_incr_data + candle_decr_data
|
||||
|
||||
layout = graph_objs.Layout()
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _Candlestick(object):
|
||||
"""
|
||||
Refer to FigureFactory.create_candlestick() for docstring.
|
||||
"""
|
||||
|
||||
def __init__(self, open, high, low, close, dates, **kwargs):
|
||||
self.open = open
|
||||
self.high = high
|
||||
self.low = low
|
||||
self.close = close
|
||||
if dates is not None:
|
||||
self.x = dates
|
||||
else:
|
||||
self.x = [x for x in range(len(self.open))]
|
||||
self.get_candle_increase()
|
||||
|
||||
def get_candle_increase(self):
|
||||
"""
|
||||
Separate increasing data from decreasing data.
|
||||
|
||||
The data is increasing when close value > open value
|
||||
and decreasing when the close value <= open value.
|
||||
"""
|
||||
increase_y = []
|
||||
increase_x = []
|
||||
for index in range(len(self.open)):
|
||||
if self.close[index] > self.open[index]:
|
||||
increase_y.append(self.low[index])
|
||||
increase_y.append(self.open[index])
|
||||
increase_y.append(self.close[index])
|
||||
increase_y.append(self.close[index])
|
||||
increase_y.append(self.close[index])
|
||||
increase_y.append(self.high[index])
|
||||
increase_x.append(self.x[index])
|
||||
|
||||
increase_x = [[x, x, x, x, x, x] for x in increase_x]
|
||||
increase_x = utils.flatten(increase_x)
|
||||
|
||||
return increase_x, increase_y
|
||||
|
||||
def get_candle_decrease(self):
|
||||
"""
|
||||
Separate increasing data from decreasing data.
|
||||
|
||||
The data is increasing when close value > open value
|
||||
and decreasing when the close value <= open value.
|
||||
"""
|
||||
decrease_y = []
|
||||
decrease_x = []
|
||||
for index in range(len(self.open)):
|
||||
if self.close[index] <= self.open[index]:
|
||||
decrease_y.append(self.low[index])
|
||||
decrease_y.append(self.open[index])
|
||||
decrease_y.append(self.close[index])
|
||||
decrease_y.append(self.close[index])
|
||||
decrease_y.append(self.close[index])
|
||||
decrease_y.append(self.high[index])
|
||||
decrease_x.append(self.x[index])
|
||||
|
||||
decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
|
||||
decrease_x = utils.flatten(decrease_x)
|
||||
|
||||
return decrease_x, decrease_y
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,395 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from plotly import exceptions, optional_imports
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
# Optional imports, may be None for users that only use our core functionality.
|
||||
np = optional_imports.get_module("numpy")
|
||||
scp = optional_imports.get_module("scipy")
|
||||
sch = optional_imports.get_module("scipy.cluster.hierarchy")
|
||||
scs = optional_imports.get_module("scipy.spatial")
|
||||
|
||||
|
||||
def create_dendrogram(
|
||||
X,
|
||||
orientation="bottom",
|
||||
labels=None,
|
||||
colorscale=None,
|
||||
distfun=None,
|
||||
linkagefun=lambda x: sch.linkage(x, "complete"),
|
||||
hovertext=None,
|
||||
color_threshold=None,
|
||||
):
|
||||
"""
|
||||
Function that returns a dendrogram Plotly figure object. This is a thin
|
||||
wrapper around scipy.cluster.hierarchy.dendrogram.
|
||||
|
||||
See also https://dash.plot.ly/dash-bio/clustergram.
|
||||
|
||||
:param (ndarray) X: Matrix of observations as array of arrays
|
||||
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
|
||||
:param (list) labels: List of axis category labels(observation labels)
|
||||
:param (list) colorscale: Optional colorscale for the dendrogram tree.
|
||||
Requires 8 colors to be specified, the 7th of
|
||||
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
|
||||
and 6th are used twice as often as the others.
|
||||
Given a shorter list, the missing values are
|
||||
replaced with defaults and with a longer list the
|
||||
extra values are ignored.
|
||||
:param (function) distfun: Function to compute the pairwise distance from
|
||||
the observations
|
||||
:param (function) linkagefun: Function to compute the linkage matrix from
|
||||
the pairwise distances
|
||||
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
|
||||
clusters
|
||||
:param (double) color_threshold: Value at which the separation of clusters will be made
|
||||
|
||||
Example 1: Simple bottom oriented dendrogram
|
||||
|
||||
>>> from plotly.figure_factory import create_dendrogram
|
||||
|
||||
>>> import numpy as np
|
||||
|
||||
>>> X = np.random.rand(10,10)
|
||||
>>> fig = create_dendrogram(X)
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: Dendrogram to put on the left of the heatmap
|
||||
|
||||
>>> from plotly.figure_factory import create_dendrogram
|
||||
|
||||
>>> import numpy as np
|
||||
|
||||
>>> X = np.random.rand(5,5)
|
||||
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
|
||||
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
|
||||
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
|
||||
>>> dendro.show()
|
||||
|
||||
Example 3: Dendrogram with Pandas
|
||||
|
||||
>>> from plotly.figure_factory import create_dendrogram
|
||||
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
|
||||
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
|
||||
>>> fig = create_dendrogram(df, labels=Index)
|
||||
>>> fig.show()
|
||||
"""
|
||||
if not scp or not scs or not sch:
|
||||
raise ImportError(
|
||||
"FigureFactory.create_dendrogram requires scipy, \
|
||||
scipy.spatial and scipy.hierarchy"
|
||||
)
|
||||
|
||||
s = X.shape
|
||||
if len(s) != 2:
|
||||
exceptions.PlotlyError("X should be 2-dimensional array.")
|
||||
|
||||
if distfun is None:
|
||||
distfun = scs.distance.pdist
|
||||
|
||||
dendrogram = _Dendrogram(
|
||||
X,
|
||||
orientation,
|
||||
labels,
|
||||
colorscale,
|
||||
distfun=distfun,
|
||||
linkagefun=linkagefun,
|
||||
hovertext=hovertext,
|
||||
color_threshold=color_threshold,
|
||||
)
|
||||
|
||||
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
|
||||
|
||||
|
||||
class _Dendrogram(object):
|
||||
"""Refer to FigureFactory.create_dendrogram() for docstring."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
X,
|
||||
orientation="bottom",
|
||||
labels=None,
|
||||
colorscale=None,
|
||||
width=np.inf,
|
||||
height=np.inf,
|
||||
xaxis="xaxis",
|
||||
yaxis="yaxis",
|
||||
distfun=None,
|
||||
linkagefun=lambda x: sch.linkage(x, "complete"),
|
||||
hovertext=None,
|
||||
color_threshold=None,
|
||||
):
|
||||
self.orientation = orientation
|
||||
self.labels = labels
|
||||
self.xaxis = xaxis
|
||||
self.yaxis = yaxis
|
||||
self.data = []
|
||||
self.leaves = []
|
||||
self.sign = {self.xaxis: 1, self.yaxis: 1}
|
||||
self.layout = {self.xaxis: {}, self.yaxis: {}}
|
||||
|
||||
if self.orientation in ["left", "bottom"]:
|
||||
self.sign[self.xaxis] = 1
|
||||
else:
|
||||
self.sign[self.xaxis] = -1
|
||||
|
||||
if self.orientation in ["right", "bottom"]:
|
||||
self.sign[self.yaxis] = 1
|
||||
else:
|
||||
self.sign[self.yaxis] = -1
|
||||
|
||||
if distfun is None:
|
||||
distfun = scs.distance.pdist
|
||||
|
||||
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
|
||||
X, colorscale, distfun, linkagefun, hovertext, color_threshold
|
||||
)
|
||||
|
||||
self.labels = ordered_labels
|
||||
self.leaves = leaves
|
||||
yvals_flat = yvals.flatten()
|
||||
xvals_flat = xvals.flatten()
|
||||
|
||||
self.zero_vals = []
|
||||
|
||||
for i in range(len(yvals_flat)):
|
||||
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
|
||||
self.zero_vals.append(xvals_flat[i])
|
||||
|
||||
if len(self.zero_vals) > len(yvals) + 1:
|
||||
# If the length of zero_vals is larger than the length of yvals,
|
||||
# it means that there are wrong vals because of the identicial samples.
|
||||
# Three and more identicial samples will make the yvals of spliting
|
||||
# center into 0 and it will accidentally take it as leaves.
|
||||
l_border = int(min(self.zero_vals))
|
||||
r_border = int(max(self.zero_vals))
|
||||
correct_leaves_pos = range(
|
||||
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
|
||||
)
|
||||
# Regenerating the leaves pos from the self.zero_vals with equally intervals.
|
||||
self.zero_vals = [v for v in correct_leaves_pos]
|
||||
|
||||
self.zero_vals.sort()
|
||||
self.layout = self.set_figure_layout(width, height)
|
||||
self.data = dd_traces
|
||||
|
||||
def get_color_dict(self, colorscale):
|
||||
"""
|
||||
Returns colorscale used for dendrogram tree clusters.
|
||||
|
||||
:param (list) colorscale: Colors to use for the plot in rgb format.
|
||||
:rtype (dict): A dict of default colors mapped to the user colorscale.
|
||||
|
||||
"""
|
||||
|
||||
# These are the color codes returned for dendrograms
|
||||
# We're replacing them with nicer colors
|
||||
# This list is the colors that can be used by dendrogram, which were
|
||||
# determined as the combination of the default above_threshold_color and
|
||||
# the default color palette (see scipy/cluster/hierarchy.py)
|
||||
d = {
|
||||
"r": "red",
|
||||
"g": "green",
|
||||
"b": "blue",
|
||||
"c": "cyan",
|
||||
"m": "magenta",
|
||||
"y": "yellow",
|
||||
"k": "black",
|
||||
# TODO: 'w' doesn't seem to be in the default color
|
||||
# palette in scipy/cluster/hierarchy.py
|
||||
"w": "white",
|
||||
}
|
||||
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
|
||||
|
||||
if colorscale is None:
|
||||
rgb_colorscale = [
|
||||
"rgb(0,116,217)", # blue
|
||||
"rgb(35,205,205)", # cyan
|
||||
"rgb(61,153,112)", # green
|
||||
"rgb(40,35,35)", # black
|
||||
"rgb(133,20,75)", # magenta
|
||||
"rgb(255,65,54)", # red
|
||||
"rgb(255,255,255)", # white
|
||||
"rgb(255,220,0)", # yellow
|
||||
]
|
||||
else:
|
||||
rgb_colorscale = colorscale
|
||||
|
||||
for i in range(len(default_colors.keys())):
|
||||
k = list(default_colors.keys())[i] # PY3 won't index keys
|
||||
if i < len(rgb_colorscale):
|
||||
default_colors[k] = rgb_colorscale[i]
|
||||
|
||||
# add support for cyclic format colors as introduced in scipy===1.5.0
|
||||
# before this, the colors were named 'r', 'b', 'y' etc., now they are
|
||||
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
|
||||
# scipy version, we try as much as possible to map the new colors to the
|
||||
# old colors
|
||||
# this mapping was found by inpecting scipy/cluster/hierarchy.py (see
|
||||
# comment above).
|
||||
new_old_color_map = [
|
||||
("C0", "b"),
|
||||
("C1", "g"),
|
||||
("C2", "r"),
|
||||
("C3", "c"),
|
||||
("C4", "m"),
|
||||
("C5", "y"),
|
||||
("C6", "k"),
|
||||
("C7", "g"),
|
||||
("C8", "r"),
|
||||
("C9", "c"),
|
||||
]
|
||||
for nc, oc in new_old_color_map:
|
||||
try:
|
||||
default_colors[nc] = default_colors[oc]
|
||||
except KeyError:
|
||||
# it could happen that the old color isn't found (if a custom
|
||||
# colorscale was specified), in this case we set it to an
|
||||
# arbitrary default.
|
||||
default_colors[nc] = "rgb(0,116,217)"
|
||||
|
||||
return default_colors
|
||||
|
||||
def set_axis_layout(self, axis_key):
|
||||
"""
|
||||
Sets and returns default axis object for dendrogram figure.
|
||||
|
||||
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
|
||||
:rtype (dict): An axis_key dictionary with set parameters.
|
||||
|
||||
"""
|
||||
axis_defaults = {
|
||||
"type": "linear",
|
||||
"ticks": "outside",
|
||||
"mirror": "allticks",
|
||||
"rangemode": "tozero",
|
||||
"showticklabels": True,
|
||||
"zeroline": False,
|
||||
"showgrid": False,
|
||||
"showline": True,
|
||||
}
|
||||
|
||||
if len(self.labels) != 0:
|
||||
axis_key_labels = self.xaxis
|
||||
if self.orientation in ["left", "right"]:
|
||||
axis_key_labels = self.yaxis
|
||||
if axis_key_labels not in self.layout:
|
||||
self.layout[axis_key_labels] = {}
|
||||
self.layout[axis_key_labels]["tickvals"] = [
|
||||
zv * self.sign[axis_key] for zv in self.zero_vals
|
||||
]
|
||||
self.layout[axis_key_labels]["ticktext"] = self.labels
|
||||
self.layout[axis_key_labels]["tickmode"] = "array"
|
||||
|
||||
self.layout[axis_key].update(axis_defaults)
|
||||
|
||||
return self.layout[axis_key]
|
||||
|
||||
def set_figure_layout(self, width, height):
|
||||
"""
|
||||
Sets and returns default layout object for dendrogram figure.
|
||||
|
||||
"""
|
||||
self.layout.update(
|
||||
{
|
||||
"showlegend": False,
|
||||
"autosize": False,
|
||||
"hovermode": "closest",
|
||||
"width": width,
|
||||
"height": height,
|
||||
}
|
||||
)
|
||||
|
||||
self.set_axis_layout(self.xaxis)
|
||||
self.set_axis_layout(self.yaxis)
|
||||
|
||||
return self.layout
|
||||
|
||||
def get_dendrogram_traces(
|
||||
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
|
||||
):
|
||||
"""
|
||||
Calculates all the elements needed for plotting a dendrogram.
|
||||
|
||||
:param (ndarray) X: Matrix of observations as array of arrays
|
||||
:param (list) colorscale: Color scale for dendrogram tree clusters
|
||||
:param (function) distfun: Function to compute the pairwise distance
|
||||
from the observations
|
||||
:param (function) linkagefun: Function to compute the linkage matrix
|
||||
from the pairwise distances
|
||||
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
|
||||
:rtype (tuple): Contains all the traces in the following order:
|
||||
(a) trace_list: List of Plotly trace objects for dendrogram tree
|
||||
(b) icoord: All X points of the dendrogram tree as array of arrays
|
||||
with length 4
|
||||
(c) dcoord: All Y points of the dendrogram tree as array of arrays
|
||||
with length 4
|
||||
(d) ordered_labels: leaf labels in the order they are going to
|
||||
appear on the plot
|
||||
(e) P['leaves']: left-to-right traversal of the leaves
|
||||
|
||||
"""
|
||||
d = distfun(X)
|
||||
Z = linkagefun(d)
|
||||
P = sch.dendrogram(
|
||||
Z,
|
||||
orientation=self.orientation,
|
||||
labels=self.labels,
|
||||
no_plot=True,
|
||||
color_threshold=color_threshold,
|
||||
)
|
||||
|
||||
icoord = np.array(P["icoord"])
|
||||
dcoord = np.array(P["dcoord"])
|
||||
ordered_labels = np.array(P["ivl"])
|
||||
color_list = np.array(P["color_list"])
|
||||
colors = self.get_color_dict(colorscale)
|
||||
|
||||
trace_list = []
|
||||
|
||||
for i in range(len(icoord)):
|
||||
# xs and ys are arrays of 4 points that make up the '∩' shapes
|
||||
# of the dendrogram tree
|
||||
if self.orientation in ["top", "bottom"]:
|
||||
xs = icoord[i]
|
||||
else:
|
||||
xs = dcoord[i]
|
||||
|
||||
if self.orientation in ["top", "bottom"]:
|
||||
ys = dcoord[i]
|
||||
else:
|
||||
ys = icoord[i]
|
||||
color_key = color_list[i]
|
||||
hovertext_label = None
|
||||
if hovertext:
|
||||
hovertext_label = hovertext[i]
|
||||
trace = dict(
|
||||
type="scatter",
|
||||
x=np.multiply(self.sign[self.xaxis], xs),
|
||||
y=np.multiply(self.sign[self.yaxis], ys),
|
||||
mode="lines",
|
||||
marker=dict(color=colors[color_key]),
|
||||
text=hovertext_label,
|
||||
hoverinfo="text",
|
||||
)
|
||||
|
||||
try:
|
||||
x_index = int(self.xaxis[-1])
|
||||
except ValueError:
|
||||
x_index = ""
|
||||
|
||||
try:
|
||||
y_index = int(self.yaxis[-1])
|
||||
except ValueError:
|
||||
y_index = ""
|
||||
|
||||
trace["xaxis"] = f"x{x_index}"
|
||||
trace["yaxis"] = f"y{y_index}"
|
||||
|
||||
trace_list.append(trace)
|
||||
|
||||
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
|
441
lib/python3.11/site-packages/plotly/figure_factory/_distplot.py
Normal file
441
lib/python3.11/site-packages/plotly/figure_factory/_distplot.py
Normal file
@ -0,0 +1,441 @@
|
||||
from plotly import exceptions, optional_imports
|
||||
from plotly.figure_factory import utils
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
# Optional imports, may be None for users that only use our core functionality.
|
||||
np = optional_imports.get_module("numpy")
|
||||
pd = optional_imports.get_module("pandas")
|
||||
scipy = optional_imports.get_module("scipy")
|
||||
scipy_stats = optional_imports.get_module("scipy.stats")
|
||||
|
||||
|
||||
DEFAULT_HISTNORM = "probability density"
|
||||
ALTERNATIVE_HISTNORM = "probability"
|
||||
|
||||
|
||||
def validate_distplot(hist_data, curve_type):
|
||||
"""
|
||||
Distplot-specific validations
|
||||
|
||||
:raises: (PlotlyError) If hist_data is not a list of lists
|
||||
:raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
|
||||
'normal').
|
||||
"""
|
||||
hist_data_types = (list,)
|
||||
if np:
|
||||
hist_data_types += (np.ndarray,)
|
||||
if pd:
|
||||
hist_data_types += (pd.core.series.Series,)
|
||||
|
||||
if not isinstance(hist_data[0], hist_data_types):
|
||||
raise exceptions.PlotlyError(
|
||||
"Oops, this function was written "
|
||||
"to handle multiple datasets, if "
|
||||
"you want to plot just one, make "
|
||||
"sure your hist_data variable is "
|
||||
"still a list of lists, i.e. x = "
|
||||
"[1, 2, 3] -> x = [[1, 2, 3]]"
|
||||
)
|
||||
|
||||
curve_opts = ("kde", "normal")
|
||||
if curve_type not in curve_opts:
|
||||
raise exceptions.PlotlyError("curve_type must be defined as 'kde' or 'normal'")
|
||||
|
||||
if not scipy:
|
||||
raise ImportError("FigureFactory.create_distplot requires scipy")
|
||||
|
||||
|
||||
def create_distplot(
|
||||
hist_data,
|
||||
group_labels,
|
||||
bin_size=1.0,
|
||||
curve_type="kde",
|
||||
colors=None,
|
||||
rug_text=None,
|
||||
histnorm=DEFAULT_HISTNORM,
|
||||
show_hist=True,
|
||||
show_curve=True,
|
||||
show_rug=True,
|
||||
):
|
||||
"""
|
||||
Function that creates a distplot similar to seaborn.distplot;
|
||||
**this function is deprecated**, use instead :mod:`plotly.express`
|
||||
functions, for example
|
||||
|
||||
>>> import plotly.express as px
|
||||
>>> tips = px.data.tips()
|
||||
>>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
|
||||
... hover_data=tips.columns)
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
The distplot can be composed of all or any combination of the following
|
||||
3 components: (1) histogram, (2) curve: (a) kernel density estimation
|
||||
or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
|
||||
(from multiple datasets) can be created in the same plot.
|
||||
|
||||
:param (list[list]) hist_data: Use list of lists to plot multiple data
|
||||
sets on the same plot.
|
||||
:param (list[str]) group_labels: Names for each data set.
|
||||
:param (list[float]|float) bin_size: Size of histogram bins.
|
||||
Default = 1.
|
||||
:param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
|
||||
:param (str) histnorm: 'probability density' or 'probability'
|
||||
Default = 'probability density'
|
||||
:param (bool) show_hist: Add histogram to distplot? Default = True
|
||||
:param (bool) show_curve: Add curve to distplot? Default = True
|
||||
:param (bool) show_rug: Add rug to distplot? Default = True
|
||||
:param (list[str]) colors: Colors for traces.
|
||||
:param (list[list]) rug_text: Hovertext values for rug_plot,
|
||||
:return (dict): Representation of a distplot figure.
|
||||
|
||||
Example 1: Simple distplot of 1 data set
|
||||
|
||||
>>> from plotly.figure_factory import create_distplot
|
||||
|
||||
>>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
|
||||
... 3.5, 4.1, 4.4, 4.5, 4.5,
|
||||
... 5.0, 5.0, 5.2, 5.5, 5.5,
|
||||
... 5.5, 5.5, 5.5, 6.1, 7.0]]
|
||||
>>> group_labels = ['distplot example']
|
||||
>>> fig = create_distplot(hist_data, group_labels)
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
Example 2: Two data sets and added rug text
|
||||
|
||||
>>> from plotly.figure_factory import create_distplot
|
||||
>>> # Add histogram data
|
||||
>>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
|
||||
... -0.9, -0.07, 1.95, 0.9, -0.2,
|
||||
... -0.5, 0.3, 0.4, -0.37, 0.6]
|
||||
>>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
|
||||
... 1.0, 0.8, 1.7, 0.5, 0.8,
|
||||
... -0.3, 1.2, 0.56, 0.3, 2.2]
|
||||
|
||||
>>> # Group data together
|
||||
>>> hist_data = [hist1_x, hist2_x]
|
||||
|
||||
>>> group_labels = ['2012', '2013']
|
||||
|
||||
>>> # Add text
|
||||
>>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
|
||||
... 'f1', 'g1', 'h1', 'i1', 'j1',
|
||||
... 'k1', 'l1', 'm1', 'n1', 'o1']
|
||||
|
||||
>>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
|
||||
... 'f2', 'g2', 'h2', 'i2', 'j2',
|
||||
... 'k2', 'l2', 'm2', 'n2', 'o2']
|
||||
|
||||
>>> # Group text together
|
||||
>>> rug_text_all = [rug_text_1, rug_text_2]
|
||||
|
||||
>>> # Create distplot
|
||||
>>> fig = create_distplot(
|
||||
... hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
|
||||
|
||||
>>> # Add title
|
||||
>>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
Example 3: Plot with normal curve and hide rug plot
|
||||
|
||||
>>> from plotly.figure_factory import create_distplot
|
||||
>>> import numpy as np
|
||||
|
||||
>>> x1 = np.random.randn(190)
|
||||
>>> x2 = np.random.randn(200)+1
|
||||
>>> x3 = np.random.randn(200)-1
|
||||
>>> x4 = np.random.randn(210)+2
|
||||
|
||||
>>> hist_data = [x1, x2, x3, x4]
|
||||
>>> group_labels = ['2012', '2013', '2014', '2015']
|
||||
|
||||
>>> fig = create_distplot(
|
||||
... hist_data, group_labels, curve_type='normal',
|
||||
... show_rug=False, bin_size=.4)
|
||||
|
||||
|
||||
Example 4: Distplot with Pandas
|
||||
|
||||
>>> from plotly.figure_factory import create_distplot
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> df = pd.DataFrame({'2012': np.random.randn(200),
|
||||
... '2013': np.random.randn(200)+1})
|
||||
>>> fig = create_distplot([df[c] for c in df.columns], df.columns)
|
||||
>>> fig.show()
|
||||
"""
|
||||
if colors is None:
|
||||
colors = []
|
||||
if rug_text is None:
|
||||
rug_text = []
|
||||
|
||||
validate_distplot(hist_data, curve_type)
|
||||
utils.validate_equal_length(hist_data, group_labels)
|
||||
|
||||
if isinstance(bin_size, (float, int)):
|
||||
bin_size = [bin_size] * len(hist_data)
|
||||
|
||||
data = []
|
||||
if show_hist:
|
||||
hist = _Distplot(
|
||||
hist_data,
|
||||
histnorm,
|
||||
group_labels,
|
||||
bin_size,
|
||||
curve_type,
|
||||
colors,
|
||||
rug_text,
|
||||
show_hist,
|
||||
show_curve,
|
||||
).make_hist()
|
||||
|
||||
data.append(hist)
|
||||
|
||||
if show_curve:
|
||||
if curve_type == "normal":
|
||||
curve = _Distplot(
|
||||
hist_data,
|
||||
histnorm,
|
||||
group_labels,
|
||||
bin_size,
|
||||
curve_type,
|
||||
colors,
|
||||
rug_text,
|
||||
show_hist,
|
||||
show_curve,
|
||||
).make_normal()
|
||||
else:
|
||||
curve = _Distplot(
|
||||
hist_data,
|
||||
histnorm,
|
||||
group_labels,
|
||||
bin_size,
|
||||
curve_type,
|
||||
colors,
|
||||
rug_text,
|
||||
show_hist,
|
||||
show_curve,
|
||||
).make_kde()
|
||||
|
||||
data.append(curve)
|
||||
|
||||
if show_rug:
|
||||
rug = _Distplot(
|
||||
hist_data,
|
||||
histnorm,
|
||||
group_labels,
|
||||
bin_size,
|
||||
curve_type,
|
||||
colors,
|
||||
rug_text,
|
||||
show_hist,
|
||||
show_curve,
|
||||
).make_rug()
|
||||
|
||||
data.append(rug)
|
||||
layout = graph_objs.Layout(
|
||||
barmode="overlay",
|
||||
hovermode="closest",
|
||||
legend=dict(traceorder="reversed"),
|
||||
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
|
||||
yaxis1=dict(domain=[0.35, 1], anchor="free", position=0.0),
|
||||
yaxis2=dict(domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False),
|
||||
)
|
||||
else:
|
||||
layout = graph_objs.Layout(
|
||||
barmode="overlay",
|
||||
hovermode="closest",
|
||||
legend=dict(traceorder="reversed"),
|
||||
xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
|
||||
yaxis1=dict(domain=[0.0, 1], anchor="free", position=0.0),
|
||||
)
|
||||
|
||||
data = sum(data, [])
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _Distplot(object):
|
||||
"""
|
||||
Refer to TraceFactory.create_distplot() for docstring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hist_data,
|
||||
histnorm,
|
||||
group_labels,
|
||||
bin_size,
|
||||
curve_type,
|
||||
colors,
|
||||
rug_text,
|
||||
show_hist,
|
||||
show_curve,
|
||||
):
|
||||
self.hist_data = hist_data
|
||||
self.histnorm = histnorm
|
||||
self.group_labels = group_labels
|
||||
self.bin_size = bin_size
|
||||
self.show_hist = show_hist
|
||||
self.show_curve = show_curve
|
||||
self.trace_number = len(hist_data)
|
||||
if rug_text:
|
||||
self.rug_text = rug_text
|
||||
else:
|
||||
self.rug_text = [None] * self.trace_number
|
||||
|
||||
self.start = []
|
||||
self.end = []
|
||||
if colors:
|
||||
self.colors = colors
|
||||
else:
|
||||
self.colors = [
|
||||
"rgb(31, 119, 180)",
|
||||
"rgb(255, 127, 14)",
|
||||
"rgb(44, 160, 44)",
|
||||
"rgb(214, 39, 40)",
|
||||
"rgb(148, 103, 189)",
|
||||
"rgb(140, 86, 75)",
|
||||
"rgb(227, 119, 194)",
|
||||
"rgb(127, 127, 127)",
|
||||
"rgb(188, 189, 34)",
|
||||
"rgb(23, 190, 207)",
|
||||
]
|
||||
self.curve_x = [None] * self.trace_number
|
||||
self.curve_y = [None] * self.trace_number
|
||||
|
||||
for trace in self.hist_data:
|
||||
self.start.append(min(trace) * 1.0)
|
||||
self.end.append(max(trace) * 1.0)
|
||||
|
||||
def make_hist(self):
|
||||
"""
|
||||
Makes the histogram(s) for FigureFactory.create_distplot().
|
||||
|
||||
:rtype (list) hist: list of histogram representations
|
||||
"""
|
||||
hist = [None] * self.trace_number
|
||||
|
||||
for index in range(self.trace_number):
|
||||
hist[index] = dict(
|
||||
type="histogram",
|
||||
x=self.hist_data[index],
|
||||
xaxis="x1",
|
||||
yaxis="y1",
|
||||
histnorm=self.histnorm,
|
||||
name=self.group_labels[index],
|
||||
legendgroup=self.group_labels[index],
|
||||
marker=dict(color=self.colors[index % len(self.colors)]),
|
||||
autobinx=False,
|
||||
xbins=dict(
|
||||
start=self.start[index],
|
||||
end=self.end[index],
|
||||
size=self.bin_size[index],
|
||||
),
|
||||
opacity=0.7,
|
||||
)
|
||||
return hist
|
||||
|
||||
def make_kde(self):
|
||||
"""
|
||||
Makes the kernel density estimation(s) for create_distplot().
|
||||
|
||||
This is called when curve_type = 'kde' in create_distplot().
|
||||
|
||||
:rtype (list) curve: list of kde representations
|
||||
"""
|
||||
curve = [None] * self.trace_number
|
||||
for index in range(self.trace_number):
|
||||
self.curve_x[index] = [
|
||||
self.start[index] + x * (self.end[index] - self.start[index]) / 500
|
||||
for x in range(500)
|
||||
]
|
||||
self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])(
|
||||
self.curve_x[index]
|
||||
)
|
||||
|
||||
if self.histnorm == ALTERNATIVE_HISTNORM:
|
||||
self.curve_y[index] *= self.bin_size[index]
|
||||
|
||||
for index in range(self.trace_number):
|
||||
curve[index] = dict(
|
||||
type="scatter",
|
||||
x=self.curve_x[index],
|
||||
y=self.curve_y[index],
|
||||
xaxis="x1",
|
||||
yaxis="y1",
|
||||
mode="lines",
|
||||
name=self.group_labels[index],
|
||||
legendgroup=self.group_labels[index],
|
||||
showlegend=False if self.show_hist else True,
|
||||
marker=dict(color=self.colors[index % len(self.colors)]),
|
||||
)
|
||||
return curve
|
||||
|
||||
def make_normal(self):
|
||||
"""
|
||||
Makes the normal curve(s) for create_distplot().
|
||||
|
||||
This is called when curve_type = 'normal' in create_distplot().
|
||||
|
||||
:rtype (list) curve: list of normal curve representations
|
||||
"""
|
||||
curve = [None] * self.trace_number
|
||||
mean = [None] * self.trace_number
|
||||
sd = [None] * self.trace_number
|
||||
|
||||
for index in range(self.trace_number):
|
||||
mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index])
|
||||
self.curve_x[index] = [
|
||||
self.start[index] + x * (self.end[index] - self.start[index]) / 500
|
||||
for x in range(500)
|
||||
]
|
||||
self.curve_y[index] = scipy_stats.norm.pdf(
|
||||
self.curve_x[index], loc=mean[index], scale=sd[index]
|
||||
)
|
||||
|
||||
if self.histnorm == ALTERNATIVE_HISTNORM:
|
||||
self.curve_y[index] *= self.bin_size[index]
|
||||
|
||||
for index in range(self.trace_number):
|
||||
curve[index] = dict(
|
||||
type="scatter",
|
||||
x=self.curve_x[index],
|
||||
y=self.curve_y[index],
|
||||
xaxis="x1",
|
||||
yaxis="y1",
|
||||
mode="lines",
|
||||
name=self.group_labels[index],
|
||||
legendgroup=self.group_labels[index],
|
||||
showlegend=False if self.show_hist else True,
|
||||
marker=dict(color=self.colors[index % len(self.colors)]),
|
||||
)
|
||||
return curve
|
||||
|
||||
def make_rug(self):
|
||||
"""
|
||||
Makes the rug plot(s) for create_distplot().
|
||||
|
||||
:rtype (list) rug: list of rug plot representations
|
||||
"""
|
||||
rug = [None] * self.trace_number
|
||||
for index in range(self.trace_number):
|
||||
rug[index] = dict(
|
||||
type="scatter",
|
||||
x=self.hist_data[index],
|
||||
y=([self.group_labels[index]] * len(self.hist_data[index])),
|
||||
xaxis="x1",
|
||||
yaxis="y2",
|
||||
mode="markers",
|
||||
name=self.group_labels[index],
|
||||
legendgroup=self.group_labels[index],
|
||||
showlegend=(False if self.show_hist or self.show_curve else True),
|
||||
text=self.rug_text[index],
|
||||
marker=dict(
|
||||
color=self.colors[index % len(self.colors)], symbol="line-ns-open"
|
||||
),
|
||||
)
|
||||
return rug
|
1195
lib/python3.11/site-packages/plotly/figure_factory/_facet_grid.py
Normal file
1195
lib/python3.11/site-packages/plotly/figure_factory/_facet_grid.py
Normal file
File diff suppressed because it is too large
Load Diff
1034
lib/python3.11/site-packages/plotly/figure_factory/_gantt.py
Normal file
1034
lib/python3.11/site-packages/plotly/figure_factory/_gantt.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,526 @@
|
||||
from plotly.express._core import build_dataframe
|
||||
from plotly.express._doc import make_docstring
|
||||
from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
|
||||
import narwhals.stable.v1 as nw
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _project_latlon_to_wgs84(lat, lon):
|
||||
"""
|
||||
Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
|
||||
"""
|
||||
x = lon * np.pi / 180
|
||||
y = np.arctanh(np.sin(lat * np.pi / 180))
|
||||
return x, y
|
||||
|
||||
|
||||
def _project_wgs84_to_latlon(x, y):
|
||||
"""
|
||||
Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
|
||||
"""
|
||||
lon = x * 180 / np.pi
|
||||
lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
|
||||
return lat, lon
|
||||
|
||||
|
||||
def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
|
||||
"""
|
||||
Get the mapbox zoom level given bounds and a figure dimension
|
||||
Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds
|
||||
"""
|
||||
|
||||
scale = (
|
||||
2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
|
||||
)
|
||||
WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
|
||||
ZOOM_MAX = 18
|
||||
|
||||
def latRad(lat):
|
||||
sin = np.sin(lat * np.pi / 180)
|
||||
radX2 = np.log((1 + sin) / (1 - sin)) / 2
|
||||
return max(min(radX2, np.pi), -np.pi) / 2
|
||||
|
||||
def zoom(mapPx, worldPx, fraction):
|
||||
return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)
|
||||
|
||||
latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi
|
||||
|
||||
lngDiff = lon_max - lon_min
|
||||
lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360
|
||||
|
||||
latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
|
||||
lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)
|
||||
|
||||
return min(latZoom, lngZoom, ZOOM_MAX)
|
||||
|
||||
|
||||
def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
|
||||
"""
|
||||
Computes the aggregation at hexagonal bin level.
|
||||
Also defines the coordinates of the hexagons for plotting.
|
||||
The binning is inspired by matplotlib's implementation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
Array of x values (shape N)
|
||||
y : np.ndarray
|
||||
Array of y values (shape N)
|
||||
x_range : np.ndarray
|
||||
Min and max x (shape 2)
|
||||
y_range : np.ndarray
|
||||
Min and max y (shape 2)
|
||||
color : np.ndarray
|
||||
Metric to aggregate at hexagon level (shape N)
|
||||
nx : int
|
||||
Number of hexagons horizontally
|
||||
agg_func : function
|
||||
Numpy compatible aggregator, this function must take a one-dimensional
|
||||
np.ndarray as input and output a scalar
|
||||
min_count : int
|
||||
Minimum number of points in the hexagon for the hexagon to be displayed
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
X coordinates of each hexagon (shape M x 6)
|
||||
np.ndarray
|
||||
Y coordinates of each hexagon (shape M x 6)
|
||||
np.ndarray
|
||||
Centers of the hexagons (shape M x 2)
|
||||
np.ndarray
|
||||
Aggregated value in each hexagon (shape M)
|
||||
|
||||
"""
|
||||
xmin = x_range.min()
|
||||
xmax = x_range.max()
|
||||
ymin = y_range.min()
|
||||
ymax = y_range.max()
|
||||
|
||||
# In the x-direction, the hexagons exactly cover the region from
|
||||
# xmin to xmax. Need some padding to avoid roundoff errors.
|
||||
padding = 1.0e-9 * (xmax - xmin)
|
||||
xmin -= padding
|
||||
xmax += padding
|
||||
|
||||
Dx = xmax - xmin
|
||||
Dy = ymax - ymin
|
||||
if Dx == 0 and Dy > 0:
|
||||
dx = Dy / nx
|
||||
elif Dx == 0 and Dy == 0:
|
||||
dx, _ = _project_latlon_to_wgs84(1, 1)
|
||||
else:
|
||||
dx = Dx / nx
|
||||
dy = dx * np.sqrt(3)
|
||||
ny = np.ceil(Dy / dy).astype(int)
|
||||
|
||||
# Center the hexagons vertically since we only want regular hexagons
|
||||
ymin -= (ymin + dy * ny - ymax) / 2
|
||||
|
||||
x = (x - xmin) / dx
|
||||
y = (y - ymin) / dy
|
||||
ix1 = np.round(x).astype(int)
|
||||
iy1 = np.round(y).astype(int)
|
||||
ix2 = np.floor(x).astype(int)
|
||||
iy2 = np.floor(y).astype(int)
|
||||
|
||||
nx1 = nx + 1
|
||||
ny1 = ny + 1
|
||||
nx2 = nx
|
||||
ny2 = ny
|
||||
n = nx1 * ny1 + nx2 * ny2
|
||||
|
||||
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
|
||||
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
|
||||
bdist = d1 < d2
|
||||
|
||||
if color is None:
|
||||
lattice1 = np.zeros((nx1, ny1))
|
||||
lattice2 = np.zeros((nx2, ny2))
|
||||
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
|
||||
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
|
||||
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
|
||||
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
|
||||
if min_count is not None:
|
||||
lattice1[lattice1 < min_count] = np.nan
|
||||
lattice2[lattice2 < min_count] = np.nan
|
||||
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
|
||||
good_idxs = ~np.isnan(accum)
|
||||
else:
|
||||
if min_count is None:
|
||||
min_count = 1
|
||||
|
||||
# create accumulation arrays
|
||||
lattice1 = np.empty((nx1, ny1), dtype=object)
|
||||
for i in range(nx1):
|
||||
for j in range(ny1):
|
||||
lattice1[i, j] = []
|
||||
lattice2 = np.empty((nx2, ny2), dtype=object)
|
||||
for i in range(nx2):
|
||||
for j in range(ny2):
|
||||
lattice2[i, j] = []
|
||||
|
||||
for i in range(len(x)):
|
||||
if bdist[i]:
|
||||
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
|
||||
lattice1[ix1[i], iy1[i]].append(color[i])
|
||||
else:
|
||||
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
|
||||
lattice2[ix2[i], iy2[i]].append(color[i])
|
||||
|
||||
for i in range(nx1):
|
||||
for j in range(ny1):
|
||||
vals = lattice1[i, j]
|
||||
if len(vals) >= min_count:
|
||||
lattice1[i, j] = agg_func(vals)
|
||||
else:
|
||||
lattice1[i, j] = np.nan
|
||||
for i in range(nx2):
|
||||
for j in range(ny2):
|
||||
vals = lattice2[i, j]
|
||||
if len(vals) >= min_count:
|
||||
lattice2[i, j] = agg_func(vals)
|
||||
else:
|
||||
lattice2[i, j] = np.nan
|
||||
|
||||
accum = np.hstack(
|
||||
(lattice1.astype(float).ravel(), lattice2.astype(float).ravel())
|
||||
)
|
||||
good_idxs = ~np.isnan(accum)
|
||||
|
||||
agreggated_value = accum[good_idxs]
|
||||
|
||||
centers = np.zeros((n, 2), float)
|
||||
centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
|
||||
centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
|
||||
centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
|
||||
centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
|
||||
centers[:, 0] *= dx
|
||||
centers[:, 1] *= dy
|
||||
centers[:, 0] += xmin
|
||||
centers[:, 1] += ymin
|
||||
centers = centers[good_idxs]
|
||||
|
||||
# Define normalised regular hexagon coordinates
|
||||
hx = [0, 0.5, 0.5, 0, -0.5, -0.5]
|
||||
hy = [
|
||||
-0.5 / np.cos(np.pi / 6),
|
||||
-0.5 * np.tan(np.pi / 6),
|
||||
0.5 * np.tan(np.pi / 6),
|
||||
0.5 / np.cos(np.pi / 6),
|
||||
0.5 * np.tan(np.pi / 6),
|
||||
-0.5 * np.tan(np.pi / 6),
|
||||
]
|
||||
|
||||
# Number of hexagons needed
|
||||
m = len(centers)
|
||||
|
||||
# Coordinates for all hexagonal patches
|
||||
hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0])
|
||||
hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1])
|
||||
|
||||
return hxs, hys, centers, agreggated_value
|
||||
|
||||
|
||||
def _compute_wgs84_hexbin(
|
||||
lat=None,
|
||||
lon=None,
|
||||
lat_range=None,
|
||||
lon_range=None,
|
||||
color=None,
|
||||
nx=None,
|
||||
agg_func=None,
|
||||
min_count=None,
|
||||
native_namespace=None,
|
||||
):
|
||||
"""
|
||||
Computes the lat-lon aggregation at hexagonal bin level.
|
||||
Latitude and longitude need to be projected to WGS84 before aggregating
|
||||
in order to display regular hexagons on the map.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lat : np.ndarray
|
||||
Array of latitudes (shape N)
|
||||
lon : np.ndarray
|
||||
Array of longitudes (shape N)
|
||||
lat_range : np.ndarray
|
||||
Min and max latitudes (shape 2)
|
||||
lon_range : np.ndarray
|
||||
Min and max longitudes (shape 2)
|
||||
color : np.ndarray
|
||||
Metric to aggregate at hexagon level (shape N)
|
||||
nx : int
|
||||
Number of hexagons horizontally
|
||||
agg_func : function
|
||||
Numpy compatible aggregator, this function must take a one-dimensional
|
||||
np.ndarray as input and output a scalar
|
||||
min_count : int
|
||||
Minimum number of points in the hexagon for the hexagon to be displayed
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Lat coordinates of each hexagon (shape M x 6)
|
||||
np.ndarray
|
||||
Lon coordinates of each hexagon (shape M x 6)
|
||||
nw.Series
|
||||
Unique id for each hexagon, to be used in the geojson data (shape M)
|
||||
np.ndarray
|
||||
Aggregated value in each hexagon (shape M)
|
||||
|
||||
"""
|
||||
# Project to WGS 84
|
||||
x, y = _project_latlon_to_wgs84(lat, lon)
|
||||
|
||||
if lat_range is None:
|
||||
lat_range = np.array([lat.min(), lat.max()])
|
||||
if lon_range is None:
|
||||
lon_range = np.array([lon.min(), lon.max()])
|
||||
|
||||
x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
|
||||
|
||||
hxs, hys, centers, agreggated_value = _compute_hexbin(
|
||||
x, y, x_range, y_range, color, nx, agg_func, min_count
|
||||
)
|
||||
|
||||
# Convert back to lat-lon
|
||||
hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)
|
||||
|
||||
# Create unique feature id based on hexagon center
|
||||
centers = centers.astype(str)
|
||||
hexagons_ids = (
|
||||
nw.from_dict(
|
||||
{"x1": centers[:, 0], "x2": centers[:, 1]},
|
||||
native_namespace=native_namespace,
|
||||
)
|
||||
.select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=","))
|
||||
.get_column("hexagons_ids")
|
||||
)
|
||||
|
||||
return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value
|
||||
|
||||
|
||||
def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
|
||||
"""
|
||||
Creates a geojson of hexagonal features based on the outputs of
|
||||
_compute_wgs84_hexbin
|
||||
"""
|
||||
features = []
|
||||
if ids is None:
|
||||
ids = np.arange(len(hexagons_lats))
|
||||
for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
|
||||
points = np.array([lon, lat]).T.tolist()
|
||||
points.append(points[0])
|
||||
features.append(
|
||||
dict(
|
||||
type="Feature",
|
||||
id=idx,
|
||||
geometry=dict(type="Polygon", coordinates=[points]),
|
||||
)
|
||||
)
|
||||
return dict(type="FeatureCollection", features=features)
|
||||
|
||||
|
||||
def create_hexbin_mapbox(
|
||||
data_frame=None,
|
||||
lat=None,
|
||||
lon=None,
|
||||
color=None,
|
||||
nx_hexagon=5,
|
||||
agg_func=None,
|
||||
animation_frame=None,
|
||||
color_discrete_sequence=None,
|
||||
color_discrete_map={},
|
||||
labels={},
|
||||
color_continuous_scale=None,
|
||||
range_color=None,
|
||||
color_continuous_midpoint=None,
|
||||
opacity=None,
|
||||
zoom=None,
|
||||
center=None,
|
||||
mapbox_style=None,
|
||||
title=None,
|
||||
template=None,
|
||||
width=None,
|
||||
height=None,
|
||||
min_count=None,
|
||||
show_original_data=False,
|
||||
original_data_marker=None,
|
||||
):
|
||||
"""
|
||||
Returns a figure aggregating scattered points into connected hexagons
|
||||
"""
|
||||
args = build_dataframe(args=locals(), constructor=None)
|
||||
native_namespace = nw.get_native_namespace(args["data_frame"])
|
||||
if agg_func is None:
|
||||
agg_func = np.mean
|
||||
|
||||
lat_range = (
|
||||
args["data_frame"]
|
||||
.select(
|
||||
nw.min(args["lat"]).name.suffix("_min"),
|
||||
nw.max(args["lat"]).name.suffix("_max"),
|
||||
)
|
||||
.to_numpy()
|
||||
.squeeze()
|
||||
)
|
||||
|
||||
lon_range = (
|
||||
args["data_frame"]
|
||||
.select(
|
||||
nw.min(args["lon"]).name.suffix("_min"),
|
||||
nw.max(args["lon"]).name.suffix("_max"),
|
||||
)
|
||||
.to_numpy()
|
||||
.squeeze()
|
||||
)
|
||||
|
||||
hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
|
||||
lat=args["data_frame"].get_column(args["lat"]).to_numpy(),
|
||||
lon=args["data_frame"].get_column(args["lon"]).to_numpy(),
|
||||
lat_range=lat_range,
|
||||
lon_range=lon_range,
|
||||
color=None,
|
||||
nx=nx_hexagon,
|
||||
agg_func=agg_func,
|
||||
min_count=min_count,
|
||||
native_namespace=native_namespace,
|
||||
)
|
||||
|
||||
geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)
|
||||
|
||||
if zoom is None:
|
||||
if height is None and width is None:
|
||||
mapDim = dict(height=450, width=450)
|
||||
elif height is None and width is not None:
|
||||
mapDim = dict(height=450, width=width)
|
||||
elif height is not None and width is None:
|
||||
mapDim = dict(height=height, width=height)
|
||||
else:
|
||||
mapDim = dict(height=height, width=width)
|
||||
zoom = _getBoundsZoomLevel(
|
||||
lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim
|
||||
)
|
||||
|
||||
if center is None:
|
||||
center = dict(lat=lat_range.mean(), lon=lon_range.mean())
|
||||
|
||||
if args["animation_frame"] is not None:
|
||||
groups = dict(
|
||||
args["data_frame"]
|
||||
.group_by(args["animation_frame"], drop_null_keys=True)
|
||||
.__iter__()
|
||||
)
|
||||
else:
|
||||
groups = {(0,): args["data_frame"]}
|
||||
|
||||
agg_data_frame_list = []
|
||||
for key, df in groups.items():
|
||||
_, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
|
||||
lat=df.get_column(args["lat"]).to_numpy(),
|
||||
lon=df.get_column(args["lon"]).to_numpy(),
|
||||
lat_range=lat_range,
|
||||
lon_range=lon_range,
|
||||
color=df.get_column(args["color"]).to_numpy() if args["color"] else None,
|
||||
nx=nx_hexagon,
|
||||
agg_func=agg_func,
|
||||
min_count=min_count,
|
||||
native_namespace=native_namespace,
|
||||
)
|
||||
agg_data_frame_list.append(
|
||||
nw.from_dict(
|
||||
{
|
||||
"frame": [key[0]] * len(hexagons_ids),
|
||||
"locations": hexagons_ids,
|
||||
"color": aggregated_value,
|
||||
},
|
||||
native_namespace=native_namespace,
|
||||
)
|
||||
)
|
||||
|
||||
agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns(
|
||||
color=nw.col("color").cast(nw.Int64)
|
||||
)
|
||||
|
||||
if range_color is None:
|
||||
range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]
|
||||
|
||||
fig = choropleth_mapbox(
|
||||
data_frame=agg_data_frame.to_native(),
|
||||
geojson=geojson,
|
||||
locations="locations",
|
||||
color="color",
|
||||
hover_data={"color": True, "locations": False, "frame": False},
|
||||
animation_frame=("frame" if args["animation_frame"] is not None else None),
|
||||
color_discrete_sequence=color_discrete_sequence,
|
||||
color_discrete_map=color_discrete_map,
|
||||
labels=labels,
|
||||
color_continuous_scale=color_continuous_scale,
|
||||
range_color=range_color,
|
||||
color_continuous_midpoint=color_continuous_midpoint,
|
||||
opacity=opacity,
|
||||
zoom=zoom,
|
||||
center=center,
|
||||
mapbox_style=mapbox_style,
|
||||
title=title,
|
||||
template=template,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
if show_original_data:
|
||||
original_fig = scatter_mapbox(
|
||||
data_frame=(
|
||||
args["data_frame"].sort(
|
||||
by=args["animation_frame"], descending=False, nulls_last=True
|
||||
)
|
||||
if args["animation_frame"] is not None
|
||||
else args["data_frame"]
|
||||
).to_native(),
|
||||
lat=args["lat"],
|
||||
lon=args["lon"],
|
||||
animation_frame=args["animation_frame"],
|
||||
)
|
||||
original_fig.data[0].hoverinfo = "skip"
|
||||
original_fig.data[0].hovertemplate = None
|
||||
original_fig.data[0].marker = original_data_marker
|
||||
|
||||
fig.add_trace(original_fig.data[0])
|
||||
|
||||
if args["animation_frame"] is not None:
|
||||
for i in range(len(original_fig.frames)):
|
||||
original_fig.frames[i].data[0].hoverinfo = "skip"
|
||||
original_fig.frames[i].data[0].hovertemplate = None
|
||||
original_fig.frames[i].data[0].marker = original_data_marker
|
||||
|
||||
fig.frames[i].data = [
|
||||
fig.frames[i].data[0],
|
||||
original_fig.frames[i].data[0],
|
||||
]
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
create_hexbin_mapbox.__doc__ = make_docstring(
|
||||
create_hexbin_mapbox,
|
||||
override_dict=dict(
|
||||
nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
|
||||
agg_func=[
|
||||
"function",
|
||||
"Numpy array aggregator, it must take as input a 1D array",
|
||||
"and output a scalar value.",
|
||||
],
|
||||
min_count=[
|
||||
"int",
|
||||
"Minimum number of points in a hexagon for it to be displayed.",
|
||||
"If None and color is not set, display all hexagons.",
|
||||
"If None and color is set, only display hexagons that contain points.",
|
||||
],
|
||||
show_original_data=[
|
||||
"bool",
|
||||
"Whether to show the original data on top of the hexbin aggregation.",
|
||||
],
|
||||
original_data_marker=["dict", "Scattermapbox marker options."],
|
||||
),
|
||||
)
|
295
lib/python3.11/site-packages/plotly/figure_factory/_ohlc.py
Normal file
295
lib/python3.11/site-packages/plotly/figure_factory/_ohlc.py
Normal file
@ -0,0 +1,295 @@
|
||||
from plotly import exceptions
|
||||
from plotly.graph_objs import graph_objs
|
||||
from plotly.figure_factory import utils
|
||||
|
||||
|
||||
# Default colours for finance charts
|
||||
_DEFAULT_INCREASING_COLOR = "#3D9970" # http://clrs.cc
|
||||
_DEFAULT_DECREASING_COLOR = "#FF4136"
|
||||
|
||||
|
||||
def validate_ohlc(open, high, low, close, direction, **kwargs):
|
||||
"""
|
||||
ohlc and candlestick specific validations
|
||||
|
||||
Specifically, this checks that the high value is the greatest value and
|
||||
the low value is the lowest value in each unit.
|
||||
|
||||
See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
|
||||
for params
|
||||
|
||||
:raises: (PlotlyError) If the high value is not the greatest value in
|
||||
each unit.
|
||||
:raises: (PlotlyError) If the low value is not the lowest value in each
|
||||
unit.
|
||||
:raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
|
||||
"""
|
||||
for lst in [open, low, close]:
|
||||
for index in range(len(high)):
|
||||
if high[index] < lst[index]:
|
||||
raise exceptions.PlotlyError(
|
||||
"Oops! Looks like some of "
|
||||
"your high values are less "
|
||||
"the corresponding open, "
|
||||
"low, or close values. "
|
||||
"Double check that your data "
|
||||
"is entered in O-H-L-C order"
|
||||
)
|
||||
|
||||
for lst in [open, high, close]:
|
||||
for index in range(len(low)):
|
||||
if low[index] > lst[index]:
|
||||
raise exceptions.PlotlyError(
|
||||
"Oops! Looks like some of "
|
||||
"your low values are greater "
|
||||
"than the corresponding high"
|
||||
", open, or close values. "
|
||||
"Double check that your data "
|
||||
"is entered in O-H-L-C order"
|
||||
)
|
||||
|
||||
direction_opts = ("increasing", "decreasing", "both")
|
||||
if direction not in direction_opts:
|
||||
raise exceptions.PlotlyError(
|
||||
"direction must be defined as 'increasing', 'decreasing', or 'both'"
|
||||
)
|
||||
|
||||
|
||||
def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
|
||||
"""
|
||||
Makes increasing ohlc sticks
|
||||
|
||||
_make_increasing_ohlc() and _make_decreasing_ohlc separate the
|
||||
increasing trace from the decreasing trace so kwargs (such as
|
||||
color) can be passed separately to increasing or decreasing traces
|
||||
when direction is set to 'increasing' or 'decreasing' in
|
||||
FigureFactory.create_candlestick()
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing values
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param kwargs: kwargs to be passed to increasing trace via
|
||||
plotly.graph_objs.Scatter.
|
||||
|
||||
:rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
|
||||
sticks.
|
||||
"""
|
||||
(flat_increase_x, flat_increase_y, text_increase) = _OHLC(
|
||||
open, high, low, close, dates
|
||||
).get_increase()
|
||||
|
||||
if "name" in kwargs:
|
||||
showlegend = True
|
||||
else:
|
||||
kwargs.setdefault("name", "Increasing")
|
||||
showlegend = False
|
||||
|
||||
kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR, width=1))
|
||||
kwargs.setdefault("text", text_increase)
|
||||
|
||||
ohlc_incr = dict(
|
||||
type="scatter",
|
||||
x=flat_increase_x,
|
||||
y=flat_increase_y,
|
||||
mode="lines",
|
||||
showlegend=showlegend,
|
||||
**kwargs,
|
||||
)
|
||||
return ohlc_incr
|
||||
|
||||
|
||||
def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
|
||||
"""
|
||||
Makes decreasing ohlc sticks
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing values
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param kwargs: kwargs to be passed to increasing trace via
|
||||
plotly.graph_objs.Scatter.
|
||||
|
||||
:rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
|
||||
sticks.
|
||||
"""
|
||||
(flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC(
|
||||
open, high, low, close, dates
|
||||
).get_decrease()
|
||||
|
||||
kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR, width=1))
|
||||
kwargs.setdefault("text", text_decrease)
|
||||
kwargs.setdefault("showlegend", False)
|
||||
kwargs.setdefault("name", "Decreasing")
|
||||
|
||||
ohlc_decr = dict(
|
||||
type="scatter", x=flat_decrease_x, y=flat_decrease_y, mode="lines", **kwargs
|
||||
)
|
||||
return ohlc_decr
|
||||
|
||||
|
||||
def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs):
|
||||
"""
|
||||
**deprecated**, use instead the plotly.graph_objects trace
|
||||
:class:`plotly.graph_objects.Ohlc`
|
||||
|
||||
:param (list) open: opening values
|
||||
:param (list) high: high values
|
||||
:param (list) low: low values
|
||||
:param (list) close: closing
|
||||
:param (list) dates: list of datetime objects. Default: None
|
||||
:param (string) direction: direction can be 'increasing', 'decreasing',
|
||||
or 'both'. When the direction is 'increasing', the returned figure
|
||||
consists of all units where the close value is greater than the
|
||||
corresponding open value, and when the direction is 'decreasing',
|
||||
the returned figure consists of all units where the close value is
|
||||
less than or equal to the corresponding open value. When the
|
||||
direction is 'both', both increasing and decreasing units are
|
||||
returned. Default: 'both'
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Scatter.
|
||||
These kwargs describe other attributes about the ohlc Scatter trace
|
||||
such as the color or the legend name. For more information on valid
|
||||
kwargs call help(plotly.graph_objs.Scatter)
|
||||
|
||||
:rtype (dict): returns a representation of an ohlc chart figure.
|
||||
|
||||
Example 1: Simple OHLC chart from a Pandas DataFrame
|
||||
|
||||
>>> from plotly.figure_factory import create_ohlc
|
||||
>>> from datetime import datetime
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
|
||||
>>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index)
|
||||
>>> fig.show()
|
||||
"""
|
||||
if dates is not None:
|
||||
utils.validate_equal_length(open, high, low, close, dates)
|
||||
else:
|
||||
utils.validate_equal_length(open, high, low, close)
|
||||
validate_ohlc(open, high, low, close, direction, **kwargs)
|
||||
|
||||
if direction == "increasing":
|
||||
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
|
||||
data = [ohlc_incr]
|
||||
elif direction == "decreasing":
|
||||
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
|
||||
data = [ohlc_decr]
|
||||
else:
|
||||
ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
|
||||
ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
|
||||
data = [ohlc_incr, ohlc_decr]
|
||||
|
||||
layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest")
|
||||
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _OHLC(object):
|
||||
"""
|
||||
Refer to FigureFactory.create_ohlc_increase() for docstring.
|
||||
"""
|
||||
|
||||
def __init__(self, open, high, low, close, dates, **kwargs):
|
||||
self.open = open
|
||||
self.high = high
|
||||
self.low = low
|
||||
self.close = close
|
||||
self.empty = [None] * len(open)
|
||||
self.dates = dates
|
||||
|
||||
self.all_x = []
|
||||
self.all_y = []
|
||||
self.increase_x = []
|
||||
self.increase_y = []
|
||||
self.decrease_x = []
|
||||
self.decrease_y = []
|
||||
self.get_all_xy()
|
||||
self.separate_increase_decrease()
|
||||
|
||||
def get_all_xy(self):
|
||||
"""
|
||||
Zip data to create OHLC shape
|
||||
|
||||
OHLC shape: low to high vertical bar with
|
||||
horizontal branches for open and close values.
|
||||
If dates were added, the smallest date difference is calculated and
|
||||
multiplied by .2 to get the length of the open and close branches.
|
||||
If no date data was provided, the x-axis is a list of integers and the
|
||||
length of the open and close branches is .2.
|
||||
"""
|
||||
self.all_y = list(
|
||||
zip(
|
||||
self.open,
|
||||
self.open,
|
||||
self.high,
|
||||
self.low,
|
||||
self.close,
|
||||
self.close,
|
||||
self.empty,
|
||||
)
|
||||
)
|
||||
if self.dates is not None:
|
||||
date_dif = []
|
||||
for i in range(len(self.dates) - 1):
|
||||
date_dif.append(self.dates[i + 1] - self.dates[i])
|
||||
date_dif_min = (min(date_dif)) / 5
|
||||
self.all_x = [
|
||||
[x - date_dif_min, x, x, x, x, x + date_dif_min, None]
|
||||
for x in self.dates
|
||||
]
|
||||
else:
|
||||
self.all_x = [
|
||||
[x - 0.2, x, x, x, x, x + 0.2, None] for x in range(len(self.open))
|
||||
]
|
||||
|
||||
def separate_increase_decrease(self):
|
||||
"""
|
||||
Separate data into two groups: increase and decrease
|
||||
|
||||
(1) Increase, where close > open and
|
||||
(2) Decrease, where close <= open
|
||||
"""
|
||||
for index in range(len(self.open)):
|
||||
if self.close[index] is None:
|
||||
pass
|
||||
elif self.close[index] > self.open[index]:
|
||||
self.increase_x.append(self.all_x[index])
|
||||
self.increase_y.append(self.all_y[index])
|
||||
else:
|
||||
self.decrease_x.append(self.all_x[index])
|
||||
self.decrease_y.append(self.all_y[index])
|
||||
|
||||
def get_increase(self):
|
||||
"""
|
||||
Flatten increase data and get increase text
|
||||
|
||||
:rtype (list, list, list): flat_increase_x: x-values for the increasing
|
||||
trace, flat_increase_y: y=values for the increasing trace and
|
||||
text_increase: hovertext for the increasing trace
|
||||
"""
|
||||
flat_increase_x = utils.flatten(self.increase_x)
|
||||
flat_increase_y = utils.flatten(self.increase_y)
|
||||
text_increase = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
|
||||
len(self.increase_x)
|
||||
)
|
||||
|
||||
return flat_increase_x, flat_increase_y, text_increase
|
||||
|
||||
def get_decrease(self):
|
||||
"""
|
||||
Flatten decrease data and get decrease text
|
||||
|
||||
:rtype (list, list, list): flat_decrease_x: x-values for the decreasing
|
||||
trace, flat_decrease_y: y=values for the decreasing trace and
|
||||
text_decrease: hovertext for the decreasing trace
|
||||
"""
|
||||
flat_decrease_x = utils.flatten(self.decrease_x)
|
||||
flat_decrease_y = utils.flatten(self.decrease_y)
|
||||
text_decrease = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
|
||||
len(self.decrease_x)
|
||||
)
|
||||
|
||||
return flat_decrease_x, flat_decrease_y, text_decrease
|
265
lib/python3.11/site-packages/plotly/figure_factory/_quiver.py
Normal file
265
lib/python3.11/site-packages/plotly/figure_factory/_quiver.py
Normal file
@ -0,0 +1,265 @@
|
||||
import math
|
||||
|
||||
from plotly import exceptions
|
||||
from plotly.graph_objs import graph_objs
|
||||
from plotly.figure_factory import utils
|
||||
|
||||
|
||||
def create_quiver(
|
||||
x, y, u, v, scale=0.1, arrow_scale=0.3, angle=math.pi / 9, scaleratio=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Returns data for a quiver plot.
|
||||
|
||||
:param (list|ndarray) x: x coordinates of the arrow locations
|
||||
:param (list|ndarray) y: y coordinates of the arrow locations
|
||||
:param (list|ndarray) u: x components of the arrow vectors
|
||||
:param (list|ndarray) v: y components of the arrow vectors
|
||||
:param (float in [0,1]) scale: scales size of the arrows(ideally to
|
||||
avoid overlap). Default = .1
|
||||
:param (float in [0,1]) arrow_scale: value multiplied to length of barb
|
||||
to get length of arrowhead. Default = .3
|
||||
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
|
||||
:param (positive float) scaleratio: the ratio between the scale of the y-axis
|
||||
and the scale of the x-axis (scale_y / scale_x). Default = None, the
|
||||
scale ratio is not fixed.
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
|
||||
for more information on valid kwargs call
|
||||
help(plotly.graph_objs.Scatter)
|
||||
|
||||
:rtype (dict): returns a representation of quiver figure.
|
||||
|
||||
Example 1: Trivial Quiver
|
||||
|
||||
>>> from plotly.figure_factory import create_quiver
|
||||
>>> import math
|
||||
|
||||
>>> # 1 Arrow from (0,0) to (1,1)
|
||||
>>> fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
Example 2: Quiver plot using meshgrid
|
||||
|
||||
>>> from plotly.figure_factory import create_quiver
|
||||
|
||||
>>> import numpy as np
|
||||
>>> import math
|
||||
|
||||
>>> # Add data
|
||||
>>> x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
|
||||
>>> u = np.cos(x)*y
|
||||
>>> v = np.sin(x)*y
|
||||
|
||||
>>> #Create quiver
|
||||
>>> fig = create_quiver(x, y, u, v)
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
Example 3: Styling the quiver plot
|
||||
|
||||
>>> from plotly.figure_factory import create_quiver
|
||||
>>> import numpy as np
|
||||
>>> import math
|
||||
|
||||
>>> # Add data
|
||||
>>> x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
|
||||
... np.arange(-math.pi, math.pi, .5))
|
||||
>>> u = np.cos(x)*y
|
||||
>>> v = np.sin(x)*y
|
||||
|
||||
>>> # Create quiver
|
||||
>>> fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
|
||||
... name='Wind Velocity', line=dict(width=1))
|
||||
|
||||
>>> # Add title to layout
|
||||
>>> fig.update_layout(title='Quiver Plot') # doctest: +SKIP
|
||||
>>> fig.show()
|
||||
|
||||
|
||||
Example 4: Forcing a fix scale ratio to maintain the arrow length
|
||||
|
||||
>>> from plotly.figure_factory import create_quiver
|
||||
>>> import numpy as np
|
||||
|
||||
>>> # Add data
|
||||
>>> x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
|
||||
>>> u = x
|
||||
>>> v = y
|
||||
>>> angle = np.arctan(v / u)
|
||||
>>> norm = 0.25
|
||||
>>> u = norm * np.cos(angle)
|
||||
>>> v = norm * np.sin(angle)
|
||||
|
||||
>>> # Create quiver with a fix scale ratio
|
||||
>>> fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)
|
||||
>>> fig.show()
|
||||
"""
|
||||
utils.validate_equal_length(x, y, u, v)
|
||||
utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
|
||||
|
||||
if scaleratio is None:
|
||||
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle)
|
||||
else:
|
||||
quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio)
|
||||
|
||||
barb_x, barb_y = quiver_obj.get_barbs()
|
||||
arrow_x, arrow_y = quiver_obj.get_quiver_arrows()
|
||||
|
||||
quiver_plot = graph_objs.Scatter(
|
||||
x=barb_x + arrow_x, y=barb_y + arrow_y, mode="lines", **kwargs
|
||||
)
|
||||
|
||||
data = [quiver_plot]
|
||||
|
||||
if scaleratio is None:
|
||||
layout = graph_objs.Layout(hovermode="closest")
|
||||
else:
|
||||
layout = graph_objs.Layout(
|
||||
hovermode="closest", yaxis=dict(scaleratio=scaleratio, scaleanchor="x")
|
||||
)
|
||||
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _Quiver(object):
|
||||
"""
|
||||
Refer to FigureFactory.create_quiver() for docstring
|
||||
"""
|
||||
|
||||
def __init__(self, x, y, u, v, scale, arrow_scale, angle, scaleratio=1, **kwargs):
|
||||
try:
|
||||
x = utils.flatten(x)
|
||||
except exceptions.PlotlyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
y = utils.flatten(y)
|
||||
except exceptions.PlotlyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
u = utils.flatten(u)
|
||||
except exceptions.PlotlyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
v = utils.flatten(v)
|
||||
except exceptions.PlotlyError:
|
||||
pass
|
||||
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.u = u
|
||||
self.v = v
|
||||
self.scale = scale
|
||||
self.scaleratio = scaleratio
|
||||
self.arrow_scale = arrow_scale
|
||||
self.angle = angle
|
||||
self.end_x = []
|
||||
self.end_y = []
|
||||
self.scale_uv()
|
||||
barb_x, barb_y = self.get_barbs()
|
||||
arrow_x, arrow_y = self.get_quiver_arrows()
|
||||
|
||||
def scale_uv(self):
|
||||
"""
|
||||
Scales u and v to avoid overlap of the arrows.
|
||||
|
||||
u and v are added to x and y to get the
|
||||
endpoints of the arrows so a smaller scale value will
|
||||
result in less overlap of arrows.
|
||||
"""
|
||||
self.u = [i * self.scale * self.scaleratio for i in self.u]
|
||||
self.v = [i * self.scale for i in self.v]
|
||||
|
||||
def get_barbs(self):
|
||||
"""
|
||||
Creates x and y startpoint and endpoint pairs
|
||||
|
||||
After finding the endpoint of each barb this zips startpoint and
|
||||
endpoint pairs to create 2 lists: x_values for barbs and y values
|
||||
for barbs
|
||||
|
||||
:rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
|
||||
x_value pairs separated by a None to create the barb of the arrow,
|
||||
and list of startpoint and endpoint y_value pairs separated by a
|
||||
None to create the barb of the arrow.
|
||||
"""
|
||||
self.end_x = [i + j for i, j in zip(self.x, self.u)]
|
||||
self.end_y = [i + j for i, j in zip(self.y, self.v)]
|
||||
empty = [None] * len(self.x)
|
||||
barb_x = utils.flatten(zip(self.x, self.end_x, empty))
|
||||
barb_y = utils.flatten(zip(self.y, self.end_y, empty))
|
||||
return barb_x, barb_y
|
||||
|
||||
def get_quiver_arrows(self):
|
||||
"""
|
||||
Creates lists of x and y values to plot the arrows
|
||||
|
||||
Gets length of each barb then calculates the length of each side of
|
||||
the arrow. Gets angle of barb and applies angle to each side of the
|
||||
arrowhead. Next uses arrow_scale to scale the length of arrowhead and
|
||||
creates x and y values for arrowhead point1 and point2. Finally x and y
|
||||
values for point1, endpoint and point2s for each arrowhead are
|
||||
separated by a None and zipped to create lists of x and y values for
|
||||
the arrows.
|
||||
|
||||
:rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
|
||||
x_values separated by a None to create the arrowhead and list of
|
||||
point1, endpoint, point2 y_values separated by a None to create
|
||||
the barb of the arrow.
|
||||
"""
|
||||
dif_x = [i - j for i, j in zip(self.end_x, self.x)]
|
||||
dif_y = [i - j for i, j in zip(self.end_y, self.y)]
|
||||
|
||||
# Get barb lengths(default arrow length = 30% barb length)
|
||||
barb_len = [None] * len(self.x)
|
||||
for index in range(len(barb_len)):
|
||||
barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index])
|
||||
|
||||
# Make arrow lengths
|
||||
arrow_len = [None] * len(self.x)
|
||||
arrow_len = [i * self.arrow_scale for i in barb_len]
|
||||
|
||||
# Get barb angles
|
||||
barb_ang = [None] * len(self.x)
|
||||
for index in range(len(barb_ang)):
|
||||
barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio)
|
||||
|
||||
# Set angles to create arrow
|
||||
ang1 = [i + self.angle for i in barb_ang]
|
||||
ang2 = [i - self.angle for i in barb_ang]
|
||||
|
||||
cos_ang1 = [None] * len(ang1)
|
||||
for index in range(len(ang1)):
|
||||
cos_ang1[index] = math.cos(ang1[index])
|
||||
seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
|
||||
|
||||
sin_ang1 = [None] * len(ang1)
|
||||
for index in range(len(ang1)):
|
||||
sin_ang1[index] = math.sin(ang1[index])
|
||||
seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
|
||||
|
||||
cos_ang2 = [None] * len(ang2)
|
||||
for index in range(len(ang2)):
|
||||
cos_ang2[index] = math.cos(ang2[index])
|
||||
seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
|
||||
|
||||
sin_ang2 = [None] * len(ang2)
|
||||
for index in range(len(ang2)):
|
||||
sin_ang2[index] = math.sin(ang2[index])
|
||||
seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
|
||||
|
||||
# Set coordinates to create arrow
|
||||
for index in range(len(self.end_x)):
|
||||
point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)]
|
||||
point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
|
||||
point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)]
|
||||
point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
|
||||
|
||||
# Combine lists to create arrow
|
||||
empty = [None] * len(self.end_x)
|
||||
arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
|
||||
arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
|
||||
return arrow_x, arrow_y
|
1135
lib/python3.11/site-packages/plotly/figure_factory/_scatterplot.py
Normal file
1135
lib/python3.11/site-packages/plotly/figure_factory/_scatterplot.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,406 @@
|
||||
import math
|
||||
|
||||
from plotly import exceptions, optional_imports
|
||||
from plotly.figure_factory import utils
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
np = optional_imports.get_module("numpy")
|
||||
|
||||
|
||||
def validate_streamline(x, y):
|
||||
"""
|
||||
Streamline-specific validations
|
||||
|
||||
Specifically, this checks that x and y are both evenly spaced,
|
||||
and that the package numpy is available.
|
||||
|
||||
See FigureFactory.create_streamline() for params
|
||||
|
||||
:raises: (ImportError) If numpy is not available.
|
||||
:raises: (PlotlyError) If x is not evenly spaced.
|
||||
:raises: (PlotlyError) If y is not evenly spaced.
|
||||
"""
|
||||
if np is False:
|
||||
raise ImportError("FigureFactory.create_streamline requires numpy")
|
||||
for index in range(len(x) - 1):
|
||||
if ((x[index + 1] - x[index]) - (x[1] - x[0])) > 0.0001:
|
||||
raise exceptions.PlotlyError(
|
||||
"x must be a 1 dimensional, evenly spaced array"
|
||||
)
|
||||
for index in range(len(y) - 1):
|
||||
if ((y[index + 1] - y[index]) - (y[1] - y[0])) > 0.0001:
|
||||
raise exceptions.PlotlyError(
|
||||
"y must be a 1 dimensional, evenly spaced array"
|
||||
)
|
||||
|
||||
|
||||
def create_streamline(
|
||||
x, y, u, v, density=1, angle=math.pi / 9, arrow_scale=0.09, **kwargs
|
||||
):
|
||||
"""
|
||||
Returns data for a streamline plot.
|
||||
|
||||
:param (list|ndarray) x: 1 dimensional, evenly spaced list or array
|
||||
:param (list|ndarray) y: 1 dimensional, evenly spaced list or array
|
||||
:param (ndarray) u: 2 dimensional array
|
||||
:param (ndarray) v: 2 dimensional array
|
||||
:param (float|int) density: controls the density of streamlines in
|
||||
plot. This is multiplied by 30 to scale similiarly to other
|
||||
available streamline functions such as matplotlib.
|
||||
Default = 1
|
||||
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
|
||||
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
|
||||
Default = .09
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Scatter
|
||||
for more information on valid kwargs call
|
||||
help(plotly.graph_objs.Scatter)
|
||||
|
||||
:rtype (dict): returns a representation of streamline figure.
|
||||
|
||||
Example 1: Plot simple streamline and increase arrow size
|
||||
|
||||
>>> from plotly.figure_factory import create_streamline
|
||||
>>> import plotly.graph_objects as go
|
||||
>>> import numpy as np
|
||||
>>> import math
|
||||
|
||||
>>> # Add data
|
||||
>>> x = np.linspace(-3, 3, 100)
|
||||
>>> y = np.linspace(-3, 3, 100)
|
||||
>>> Y, X = np.meshgrid(x, y)
|
||||
>>> u = -1 - X**2 + Y
|
||||
>>> v = 1 + X - Y**2
|
||||
>>> u = u.T # Transpose
|
||||
>>> v = v.T # Transpose
|
||||
|
||||
>>> # Create streamline
|
||||
>>> fig = create_streamline(x, y, u, v, arrow_scale=.1)
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
|
||||
|
||||
>>> from plotly.figure_factory import create_streamline
|
||||
>>> import numpy as np
|
||||
>>> import math
|
||||
|
||||
>>> # Add data
|
||||
>>> N = 50
|
||||
>>> x_start, x_end = -2.0, 2.0
|
||||
>>> y_start, y_end = -1.0, 1.0
|
||||
>>> x = np.linspace(x_start, x_end, N)
|
||||
>>> y = np.linspace(y_start, y_end, N)
|
||||
>>> X, Y = np.meshgrid(x, y)
|
||||
>>> ss = 5.0
|
||||
>>> x_s, y_s = -1.0, 0.0
|
||||
|
||||
>>> # Compute the velocity field on the mesh grid
|
||||
>>> u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
|
||||
>>> v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
|
||||
|
||||
>>> # Create streamline
|
||||
>>> fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
|
||||
|
||||
>>> # Add source point
|
||||
>>> point = go.Scatter(x=[x_s], y=[y_s], mode='markers',
|
||||
... marker_size=14, name='source point')
|
||||
|
||||
>>> fig.add_trace(point) # doctest: +SKIP
|
||||
>>> fig.show()
|
||||
"""
|
||||
utils.validate_equal_length(x, y)
|
||||
utils.validate_equal_length(u, v)
|
||||
validate_streamline(x, y)
|
||||
utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
|
||||
|
||||
streamline_x, streamline_y = _Streamline(
|
||||
x, y, u, v, density, angle, arrow_scale
|
||||
).sum_streamlines()
|
||||
arrow_x, arrow_y = _Streamline(
|
||||
x, y, u, v, density, angle, arrow_scale
|
||||
).get_streamline_arrows()
|
||||
|
||||
streamline = graph_objs.Scatter(
|
||||
x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode="lines", **kwargs
|
||||
)
|
||||
|
||||
data = [streamline]
|
||||
layout = graph_objs.Layout(hovermode="closest")
|
||||
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _Streamline(object):
|
||||
"""
|
||||
Refer to FigureFactory.create_streamline() for docstring
|
||||
"""
|
||||
|
||||
def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs):
|
||||
self.x = np.array(x)
|
||||
self.y = np.array(y)
|
||||
self.u = np.array(u)
|
||||
self.v = np.array(v)
|
||||
self.angle = angle
|
||||
self.arrow_scale = arrow_scale
|
||||
self.density = int(30 * density) # Scale similarly to other functions
|
||||
self.delta_x = self.x[1] - self.x[0]
|
||||
self.delta_y = self.y[1] - self.y[0]
|
||||
self.val_x = self.x
|
||||
self.val_y = self.y
|
||||
|
||||
# Set up spacing
|
||||
self.blank = np.zeros((self.density, self.density))
|
||||
self.spacing_x = len(self.x) / float(self.density - 1)
|
||||
self.spacing_y = len(self.y) / float(self.density - 1)
|
||||
self.trajectories = []
|
||||
|
||||
# Rescale speed onto axes-coordinates
|
||||
self.u = self.u / (self.x[-1] - self.x[0])
|
||||
self.v = self.v / (self.y[-1] - self.y[0])
|
||||
self.speed = np.sqrt(self.u**2 + self.v**2)
|
||||
|
||||
# Rescale u and v for integrations.
|
||||
self.u *= len(self.x)
|
||||
self.v *= len(self.y)
|
||||
self.st_x = []
|
||||
self.st_y = []
|
||||
self.get_streamlines()
|
||||
streamline_x, streamline_y = self.sum_streamlines()
|
||||
arrows_x, arrows_y = self.get_streamline_arrows()
|
||||
|
||||
def blank_pos(self, xi, yi):
|
||||
"""
|
||||
Set up positions for trajectories to be used with rk4 function.
|
||||
"""
|
||||
return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5))
|
||||
|
||||
def value_at(self, a, xi, yi):
|
||||
"""
|
||||
Set up for RK4 function, based on Bokeh's streamline code
|
||||
"""
|
||||
if isinstance(xi, np.ndarray):
|
||||
self.x = xi.astype(int)
|
||||
self.y = yi.astype(int)
|
||||
else:
|
||||
self.val_x = int(xi)
|
||||
self.val_y = int(yi)
|
||||
a00 = a[self.val_y, self.val_x]
|
||||
a01 = a[self.val_y, self.val_x + 1]
|
||||
a10 = a[self.val_y + 1, self.val_x]
|
||||
a11 = a[self.val_y + 1, self.val_x + 1]
|
||||
xt = xi - self.val_x
|
||||
yt = yi - self.val_y
|
||||
a0 = a00 * (1 - xt) + a01 * xt
|
||||
a1 = a10 * (1 - xt) + a11 * xt
|
||||
return a0 * (1 - yt) + a1 * yt
|
||||
|
||||
def rk4_integrate(self, x0, y0):
|
||||
"""
|
||||
RK4 forward and back trajectories from the initial conditions.
|
||||
|
||||
Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
|
||||
x and y trajectories then checks length of traj (s in units of axes)
|
||||
"""
|
||||
|
||||
def f(xi, yi):
|
||||
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
|
||||
ui = self.value_at(self.u, xi, yi)
|
||||
vi = self.value_at(self.v, xi, yi)
|
||||
return ui * dt_ds, vi * dt_ds
|
||||
|
||||
def g(xi, yi):
|
||||
dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
|
||||
ui = self.value_at(self.u, xi, yi)
|
||||
vi = self.value_at(self.v, xi, yi)
|
||||
return -ui * dt_ds, -vi * dt_ds
|
||||
|
||||
def check(xi, yi):
|
||||
return (0 <= xi < len(self.x) - 1) and (0 <= yi < len(self.y) - 1)
|
||||
|
||||
xb_changes = []
|
||||
yb_changes = []
|
||||
|
||||
def rk4(x0, y0, f):
|
||||
ds = 0.01
|
||||
stotal = 0
|
||||
xi = x0
|
||||
yi = y0
|
||||
xb, yb = self.blank_pos(xi, yi)
|
||||
xf_traj = []
|
||||
yf_traj = []
|
||||
while check(xi, yi):
|
||||
xf_traj.append(xi)
|
||||
yf_traj.append(yi)
|
||||
try:
|
||||
k1x, k1y = f(xi, yi)
|
||||
k2x, k2y = f(xi + 0.5 * ds * k1x, yi + 0.5 * ds * k1y)
|
||||
k3x, k3y = f(xi + 0.5 * ds * k2x, yi + 0.5 * ds * k2y)
|
||||
k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
|
||||
except IndexError:
|
||||
break
|
||||
xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0
|
||||
yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0
|
||||
if not check(xi, yi):
|
||||
break
|
||||
stotal += ds
|
||||
new_xb, new_yb = self.blank_pos(xi, yi)
|
||||
if new_xb != xb or new_yb != yb:
|
||||
if self.blank[new_yb, new_xb] == 0:
|
||||
self.blank[new_yb, new_xb] = 1
|
||||
xb_changes.append(new_xb)
|
||||
yb_changes.append(new_yb)
|
||||
xb = new_xb
|
||||
yb = new_yb
|
||||
else:
|
||||
break
|
||||
if stotal > 2:
|
||||
break
|
||||
return stotal, xf_traj, yf_traj
|
||||
|
||||
sf, xf_traj, yf_traj = rk4(x0, y0, f)
|
||||
sb, xb_traj, yb_traj = rk4(x0, y0, g)
|
||||
stotal = sf + sb
|
||||
x_traj = xb_traj[::-1] + xf_traj[1:]
|
||||
y_traj = yb_traj[::-1] + yf_traj[1:]
|
||||
|
||||
if len(x_traj) < 1:
|
||||
return None
|
||||
if stotal > 0.2:
|
||||
initxb, inityb = self.blank_pos(x0, y0)
|
||||
self.blank[inityb, initxb] = 1
|
||||
return x_traj, y_traj
|
||||
else:
|
||||
for xb, yb in zip(xb_changes, yb_changes):
|
||||
self.blank[yb, xb] = 0
|
||||
return None
|
||||
|
||||
def traj(self, xb, yb):
|
||||
"""
|
||||
Integrate trajectories
|
||||
|
||||
:param (int) xb: results of passing xi through self.blank_pos
|
||||
:param (int) xy: results of passing yi through self.blank_pos
|
||||
|
||||
Calculate each trajectory based on rk4 integrate method.
|
||||
"""
|
||||
|
||||
if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
|
||||
return
|
||||
if self.blank[yb, xb] == 0:
|
||||
t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
|
||||
if t is not None:
|
||||
self.trajectories.append(t)
|
||||
|
||||
def get_streamlines(self):
|
||||
"""
|
||||
Get streamlines by building trajectory set.
|
||||
"""
|
||||
for indent in range(self.density // 2):
|
||||
for xi in range(self.density - 2 * indent):
|
||||
self.traj(xi + indent, indent)
|
||||
self.traj(xi + indent, self.density - 1 - indent)
|
||||
self.traj(indent, xi + indent)
|
||||
self.traj(self.density - 1 - indent, xi + indent)
|
||||
|
||||
self.st_x = [
|
||||
np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories
|
||||
]
|
||||
self.st_y = [
|
||||
np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories
|
||||
]
|
||||
|
||||
for index in range(len(self.st_x)):
|
||||
self.st_x[index] = self.st_x[index].tolist()
|
||||
self.st_x[index].append(np.nan)
|
||||
|
||||
for index in range(len(self.st_y)):
|
||||
self.st_y[index] = self.st_y[index].tolist()
|
||||
self.st_y[index].append(np.nan)
|
||||
|
||||
def get_streamline_arrows(self):
|
||||
"""
|
||||
Makes an arrow for each streamline.
|
||||
|
||||
Gets angle of streamline at 1/3 mark and creates arrow coordinates
|
||||
based off of user defined angle and arrow_scale.
|
||||
|
||||
:param (array) st_x: x-values for all streamlines
|
||||
:param (array) st_y: y-values for all streamlines
|
||||
:param (angle in radians) angle: angle of arrowhead. Default = pi/9
|
||||
:param (float in [0,1]) arrow_scale: value to scale length of arrowhead
|
||||
Default = .09
|
||||
:rtype (list, list) arrows_x: x-values to create arrowhead and
|
||||
arrows_y: y-values to create arrowhead
|
||||
"""
|
||||
arrow_end_x = np.empty((len(self.st_x)))
|
||||
arrow_end_y = np.empty((len(self.st_y)))
|
||||
arrow_start_x = np.empty((len(self.st_x)))
|
||||
arrow_start_y = np.empty((len(self.st_y)))
|
||||
for index in range(len(self.st_x)):
|
||||
arrow_end_x[index] = self.st_x[index][int(len(self.st_x[index]) / 3)]
|
||||
arrow_start_x[index] = self.st_x[index][
|
||||
(int(len(self.st_x[index]) / 3)) - 1
|
||||
]
|
||||
arrow_end_y[index] = self.st_y[index][int(len(self.st_y[index]) / 3)]
|
||||
arrow_start_y[index] = self.st_y[index][
|
||||
(int(len(self.st_y[index]) / 3)) - 1
|
||||
]
|
||||
|
||||
dif_x = arrow_end_x - arrow_start_x
|
||||
dif_y = arrow_end_y - arrow_start_y
|
||||
|
||||
orig_err = np.geterr()
|
||||
np.seterr(divide="ignore", invalid="ignore")
|
||||
streamline_ang = np.arctan(dif_y / dif_x)
|
||||
np.seterr(**orig_err)
|
||||
|
||||
ang1 = streamline_ang + (self.angle)
|
||||
ang2 = streamline_ang - (self.angle)
|
||||
|
||||
seg1_x = np.cos(ang1) * self.arrow_scale
|
||||
seg1_y = np.sin(ang1) * self.arrow_scale
|
||||
seg2_x = np.cos(ang2) * self.arrow_scale
|
||||
seg2_y = np.sin(ang2) * self.arrow_scale
|
||||
|
||||
point1_x = np.empty((len(dif_x)))
|
||||
point1_y = np.empty((len(dif_y)))
|
||||
point2_x = np.empty((len(dif_x)))
|
||||
point2_y = np.empty((len(dif_y)))
|
||||
|
||||
for index in range(len(dif_x)):
|
||||
if dif_x[index] >= 0:
|
||||
point1_x[index] = arrow_end_x[index] - seg1_x[index]
|
||||
point1_y[index] = arrow_end_y[index] - seg1_y[index]
|
||||
point2_x[index] = arrow_end_x[index] - seg2_x[index]
|
||||
point2_y[index] = arrow_end_y[index] - seg2_y[index]
|
||||
else:
|
||||
point1_x[index] = arrow_end_x[index] + seg1_x[index]
|
||||
point1_y[index] = arrow_end_y[index] + seg1_y[index]
|
||||
point2_x[index] = arrow_end_x[index] + seg2_x[index]
|
||||
point2_y[index] = arrow_end_y[index] + seg2_y[index]
|
||||
|
||||
space = np.empty((len(point1_x)))
|
||||
space[:] = np.nan
|
||||
|
||||
# Combine arrays into array
|
||||
arrows_x = np.array([point1_x, arrow_end_x, point2_x, space])
|
||||
arrows_x = arrows_x.flatten("F")
|
||||
arrows_x = arrows_x.tolist()
|
||||
|
||||
# Combine arrays into array
|
||||
arrows_y = np.array([point1_y, arrow_end_y, point2_y, space])
|
||||
arrows_y = arrows_y.flatten("F")
|
||||
arrows_y = arrows_y.tolist()
|
||||
|
||||
return arrows_x, arrows_y
|
||||
|
||||
def sum_streamlines(self):
|
||||
"""
|
||||
Makes all streamlines readable as a single trace.
|
||||
|
||||
:rtype (list, list): streamline_x: all x values for each streamline
|
||||
combined into single list and streamline_y: all y values for each
|
||||
streamline combined into single list
|
||||
"""
|
||||
streamline_x = sum(self.st_x, [])
|
||||
streamline_y = sum(self.st_y, [])
|
||||
return streamline_x, streamline_y
|
280
lib/python3.11/site-packages/plotly/figure_factory/_table.py
Normal file
280
lib/python3.11/site-packages/plotly/figure_factory/_table.py
Normal file
@ -0,0 +1,280 @@
|
||||
from plotly import exceptions, optional_imports
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
pd = optional_imports.get_module("pandas")
|
||||
|
||||
|
||||
def validate_table(table_text, font_colors):
|
||||
"""
|
||||
Table-specific validations
|
||||
|
||||
Check that font_colors is supplied correctly (1, 3, or len(text)
|
||||
colors).
|
||||
|
||||
:raises: (PlotlyError) If font_colors is supplied incorretly.
|
||||
|
||||
See FigureFactory.create_table() for params
|
||||
"""
|
||||
font_colors_len_options = [1, 3, len(table_text)]
|
||||
if len(font_colors) not in font_colors_len_options:
|
||||
raise exceptions.PlotlyError(
|
||||
"Oops, font_colors should be a list of length 1, 3 or len(text)"
|
||||
)
|
||||
|
||||
|
||||
def create_table(
|
||||
table_text,
|
||||
colorscale=None,
|
||||
font_colors=None,
|
||||
index=False,
|
||||
index_title="",
|
||||
annotation_offset=0.45,
|
||||
height_constant=30,
|
||||
hoverinfo="none",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Function that creates data tables.
|
||||
|
||||
See also the plotly.graph_objects trace
|
||||
:class:`plotly.graph_objects.Table`
|
||||
|
||||
:param (pandas.Dataframe | list[list]) text: data for table.
|
||||
:param (str|list[list]) colorscale: Colorscale for table where the
|
||||
color at value 0 is the header color, .5 is the first table color
|
||||
and 1 is the second table color. (Set .5 and 1 to avoid the striped
|
||||
table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
|
||||
[1, '#ffffff']]
|
||||
:param (list) font_colors: Color for fonts in table. Can be a single
|
||||
color, three colors, or a color for each row in the table.
|
||||
Default=['#000000'] (black text for the entire table)
|
||||
:param (int) height_constant: Constant multiplied by # of rows to
|
||||
create table height. Default=30.
|
||||
:param (bool) index: Create (header-colored) index column index from
|
||||
Pandas dataframe or list[0] for each list in text. Default=False.
|
||||
:param (string) index_title: Title for index column. Default=''.
|
||||
:param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
|
||||
These kwargs describe other attributes about the annotated Heatmap
|
||||
trace such as the colorscale. For more information on valid kwargs
|
||||
call help(plotly.graph_objs.Heatmap)
|
||||
|
||||
Example 1: Simple Plotly Table
|
||||
|
||||
>>> from plotly.figure_factory import create_table
|
||||
|
||||
>>> text = [['Country', 'Year', 'Population'],
|
||||
... ['US', 2000, 282200000],
|
||||
... ['Canada', 2000, 27790000],
|
||||
... ['US', 2010, 309000000],
|
||||
... ['Canada', 2010, 34000000]]
|
||||
|
||||
>>> table = create_table(text)
|
||||
>>> table.show()
|
||||
|
||||
Example 2: Table with Custom Coloring
|
||||
|
||||
>>> from plotly.figure_factory import create_table
|
||||
>>> text = [['Country', 'Year', 'Population'],
|
||||
... ['US', 2000, 282200000],
|
||||
... ['Canada', 2000, 27790000],
|
||||
... ['US', 2010, 309000000],
|
||||
... ['Canada', 2010, 34000000]]
|
||||
>>> table = create_table(text,
|
||||
... colorscale=[[0, '#000000'],
|
||||
... [.5, '#80beff'],
|
||||
... [1, '#cce5ff']],
|
||||
... font_colors=['#ffffff', '#000000',
|
||||
... '#000000'])
|
||||
>>> table.show()
|
||||
|
||||
Example 3: Simple Plotly Table with Pandas
|
||||
|
||||
>>> from plotly.figure_factory import create_table
|
||||
>>> import pandas as pd
|
||||
>>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
|
||||
>>> df_p = df[0:25]
|
||||
>>> table_simple = create_table(df_p)
|
||||
>>> table_simple.show()
|
||||
|
||||
"""
|
||||
|
||||
# Avoiding mutables in the call signature
|
||||
colorscale = (
|
||||
colorscale
|
||||
if colorscale is not None
|
||||
else [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]]
|
||||
)
|
||||
font_colors = (
|
||||
font_colors if font_colors is not None else ["#ffffff", "#000000", "#000000"]
|
||||
)
|
||||
|
||||
validate_table(table_text, font_colors)
|
||||
table_matrix = _Table(
|
||||
table_text,
|
||||
colorscale,
|
||||
font_colors,
|
||||
index,
|
||||
index_title,
|
||||
annotation_offset,
|
||||
**kwargs,
|
||||
).get_table_matrix()
|
||||
annotations = _Table(
|
||||
table_text,
|
||||
colorscale,
|
||||
font_colors,
|
||||
index,
|
||||
index_title,
|
||||
annotation_offset,
|
||||
**kwargs,
|
||||
).make_table_annotations()
|
||||
|
||||
trace = dict(
|
||||
type="heatmap",
|
||||
z=table_matrix,
|
||||
opacity=0.75,
|
||||
colorscale=colorscale,
|
||||
showscale=False,
|
||||
hoverinfo=hoverinfo,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
data = [trace]
|
||||
layout = dict(
|
||||
annotations=annotations,
|
||||
height=len(table_matrix) * height_constant + 50,
|
||||
margin=dict(t=0, b=0, r=0, l=0),
|
||||
yaxis=dict(
|
||||
autorange="reversed",
|
||||
zeroline=False,
|
||||
gridwidth=2,
|
||||
ticks="",
|
||||
dtick=1,
|
||||
tick0=0.5,
|
||||
showticklabels=False,
|
||||
),
|
||||
xaxis=dict(
|
||||
zeroline=False,
|
||||
gridwidth=2,
|
||||
ticks="",
|
||||
dtick=1,
|
||||
tick0=-0.5,
|
||||
showticklabels=False,
|
||||
),
|
||||
)
|
||||
return graph_objs.Figure(data=data, layout=layout)
|
||||
|
||||
|
||||
class _Table(object):
|
||||
"""
|
||||
Refer to TraceFactory.create_table() for docstring
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_text,
|
||||
colorscale,
|
||||
font_colors,
|
||||
index,
|
||||
index_title,
|
||||
annotation_offset,
|
||||
**kwargs,
|
||||
):
|
||||
if pd and isinstance(table_text, pd.DataFrame):
|
||||
headers = table_text.columns.tolist()
|
||||
table_text_index = table_text.index.tolist()
|
||||
table_text = table_text.values.tolist()
|
||||
table_text.insert(0, headers)
|
||||
if index:
|
||||
table_text_index.insert(0, index_title)
|
||||
for i in range(len(table_text)):
|
||||
table_text[i].insert(0, table_text_index[i])
|
||||
self.table_text = table_text
|
||||
self.colorscale = colorscale
|
||||
self.font_colors = font_colors
|
||||
self.index = index
|
||||
self.annotation_offset = annotation_offset
|
||||
self.x = range(len(table_text[0]))
|
||||
self.y = range(len(table_text))
|
||||
|
||||
def get_table_matrix(self):
|
||||
"""
|
||||
Create z matrix to make heatmap with striped table coloring
|
||||
|
||||
:rtype (list[list]) table_matrix: z matrix to make heatmap with striped
|
||||
table coloring.
|
||||
"""
|
||||
header = [0] * len(self.table_text[0])
|
||||
odd_row = [0.5] * len(self.table_text[0])
|
||||
even_row = [1] * len(self.table_text[0])
|
||||
table_matrix = [None] * len(self.table_text)
|
||||
table_matrix[0] = header
|
||||
for i in range(1, len(self.table_text), 2):
|
||||
table_matrix[i] = odd_row
|
||||
for i in range(2, len(self.table_text), 2):
|
||||
table_matrix[i] = even_row
|
||||
if self.index:
|
||||
for array in table_matrix:
|
||||
array[0] = 0
|
||||
return table_matrix
|
||||
|
||||
def get_table_font_color(self):
|
||||
"""
|
||||
Fill font-color array.
|
||||
|
||||
Table text color can vary by row so this extends a single color or
|
||||
creates an array to set a header color and two alternating colors to
|
||||
create the striped table pattern.
|
||||
|
||||
:rtype (list[list]) all_font_colors: list of font colors for each row
|
||||
in table.
|
||||
"""
|
||||
if len(self.font_colors) == 1:
|
||||
all_font_colors = self.font_colors * len(self.table_text)
|
||||
elif len(self.font_colors) == 3:
|
||||
all_font_colors = list(range(len(self.table_text)))
|
||||
all_font_colors[0] = self.font_colors[0]
|
||||
for i in range(1, len(self.table_text), 2):
|
||||
all_font_colors[i] = self.font_colors[1]
|
||||
for i in range(2, len(self.table_text), 2):
|
||||
all_font_colors[i] = self.font_colors[2]
|
||||
elif len(self.font_colors) == len(self.table_text):
|
||||
all_font_colors = self.font_colors
|
||||
else:
|
||||
all_font_colors = ["#000000"] * len(self.table_text)
|
||||
return all_font_colors
|
||||
|
||||
def make_table_annotations(self):
|
||||
"""
|
||||
Generate annotations to fill in table text
|
||||
|
||||
:rtype (list) annotations: list of annotations for each cell of the
|
||||
table.
|
||||
"""
|
||||
all_font_colors = _Table.get_table_font_color(self)
|
||||
annotations = []
|
||||
for n, row in enumerate(self.table_text):
|
||||
for m, val in enumerate(row):
|
||||
# Bold text in header and index
|
||||
format_text = (
|
||||
"<b>" + str(val) + "</b>"
|
||||
if n == 0 or self.index and m < 1
|
||||
else str(val)
|
||||
)
|
||||
# Match font color of index to font color of header
|
||||
font_color = (
|
||||
self.font_colors[0] if self.index and m == 0 else all_font_colors[n]
|
||||
)
|
||||
annotations.append(
|
||||
graph_objs.layout.Annotation(
|
||||
text=format_text,
|
||||
x=self.x[m] - self.annotation_offset,
|
||||
y=self.y[n],
|
||||
xref="x1",
|
||||
yref="y1",
|
||||
align="left",
|
||||
xanchor="left",
|
||||
font=dict(color=font_color),
|
||||
showarrow=False,
|
||||
)
|
||||
)
|
||||
return annotations
|
@ -0,0 +1,692 @@
|
||||
import plotly.colors as clrs
|
||||
from plotly.graph_objs import graph_objs as go
|
||||
from plotly import exceptions
|
||||
from plotly import optional_imports
|
||||
|
||||
from skimage import measure
|
||||
|
||||
np = optional_imports.get_module("numpy")
|
||||
scipy_interp = optional_imports.get_module("scipy.interpolate")
|
||||
|
||||
# -------------------------- Layout ------------------------------
|
||||
|
||||
|
||||
def _ternary_layout(
|
||||
title="Ternary contour plot", width=550, height=525, pole_labels=["a", "b", "c"]
|
||||
):
|
||||
"""
|
||||
Layout of ternary contour plot, to be passed to ``go.FigureWidget``
|
||||
object.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
title : str or None
|
||||
Title of ternary plot
|
||||
width : int
|
||||
Figure width.
|
||||
height : int
|
||||
Figure height.
|
||||
pole_labels : str, default ['a', 'b', 'c']
|
||||
Names of the three poles of the triangle.
|
||||
"""
|
||||
return dict(
|
||||
title=title,
|
||||
width=width,
|
||||
height=height,
|
||||
ternary=dict(
|
||||
sum=1,
|
||||
aaxis=dict(
|
||||
title=dict(text=pole_labels[0]), min=0.01, linewidth=2, ticks="outside"
|
||||
),
|
||||
baxis=dict(
|
||||
title=dict(text=pole_labels[1]), min=0.01, linewidth=2, ticks="outside"
|
||||
),
|
||||
caxis=dict(
|
||||
title=dict(text=pole_labels[2]), min=0.01, linewidth=2, ticks="outside"
|
||||
),
|
||||
),
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
|
||||
# ------------- Transformations of coordinates -------------------
|
||||
|
||||
|
||||
def _replace_zero_coords(ternary_data, delta=0.0005):
|
||||
"""
|
||||
Replaces zero ternary coordinates with delta and normalize the new
|
||||
triplets (a, b, c).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
ternary_data : ndarray of shape (N, 3)
|
||||
|
||||
delta : float
|
||||
Small float to regularize logarithm.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Implements a method
|
||||
by J. A. Martin-Fernandez, C. Barcelo-Vidal, V. Pawlowsky-Glahn,
|
||||
Dealing with zeros and missing values in compositional data sets
|
||||
using nonparametric imputation, Mathematical Geology 35 (2003),
|
||||
pp 253-278.
|
||||
"""
|
||||
zero_mask = ternary_data == 0
|
||||
is_any_coord_zero = np.any(zero_mask, axis=0)
|
||||
|
||||
unity_complement = 1 - delta * is_any_coord_zero
|
||||
if np.any(unity_complement) < 0:
|
||||
raise ValueError(
|
||||
"The provided value of delta led to negative"
|
||||
"ternary coords.Set a smaller delta"
|
||||
)
|
||||
ternary_data = np.where(zero_mask, delta, unity_complement * ternary_data)
|
||||
return ternary_data
|
||||
|
||||
|
||||
def _ilr_transform(barycentric):
|
||||
"""
|
||||
Perform Isometric Log-Ratio on barycentric (compositional) data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
barycentric: ndarray of shape (3, N)
|
||||
Barycentric coordinates.
|
||||
|
||||
References
|
||||
----------
|
||||
"An algebraic method to compute isometric logratio transformation and
|
||||
back transformation of compositional data", Jarauta-Bragulat, E.,
|
||||
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
|
||||
Intl Assoc for Math Geology, 2003, pp 31-30.
|
||||
"""
|
||||
barycentric = np.asarray(barycentric)
|
||||
x_0 = np.log(barycentric[0] / barycentric[1]) / np.sqrt(2)
|
||||
x_1 = (
|
||||
1.0 / np.sqrt(6) * np.log(barycentric[0] * barycentric[1] / barycentric[2] ** 2)
|
||||
)
|
||||
ilr_tdata = np.stack((x_0, x_1))
|
||||
return ilr_tdata
|
||||
|
||||
|
||||
def _ilr_inverse(x):
|
||||
"""
|
||||
Perform inverse Isometric Log-Ratio (ILR) transform to retrieve
|
||||
barycentric (compositional) data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array of shape (2, N)
|
||||
Coordinates in ILR space.
|
||||
|
||||
References
|
||||
----------
|
||||
"An algebraic method to compute isometric logratio transformation and
|
||||
back transformation of compositional data", Jarauta-Bragulat, E.,
|
||||
Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
|
||||
Intl Assoc for Math Geology, 2003, pp 31-30.
|
||||
"""
|
||||
x = np.array(x)
|
||||
matrix = np.array([[0.5, 1, 1.0], [-0.5, 1, 1.0], [0.0, 0.0, 1.0]])
|
||||
s = np.sqrt(2) / 2
|
||||
t = np.sqrt(3 / 2)
|
||||
Sk = np.einsum("ik, kj -> ij", np.array([[s, t], [-s, t]]), x)
|
||||
Z = -np.log(1 + np.exp(Sk).sum(axis=0))
|
||||
log_barycentric = np.einsum(
|
||||
"ik, kj -> ij", matrix, np.stack((2 * s * x[0], t * x[1], Z))
|
||||
)
|
||||
iilr_tdata = np.exp(log_barycentric)
|
||||
return iilr_tdata
|
||||
|
||||
|
||||
def _transform_barycentric_cartesian():
|
||||
"""
|
||||
Returns the transformation matrix from barycentric to Cartesian
|
||||
coordinates and conversely.
|
||||
"""
|
||||
# reference triangle
|
||||
tri_verts = np.array([[0.5, np.sqrt(3) / 2], [0, 0], [1, 0]])
|
||||
M = np.array([tri_verts[:, 0], tri_verts[:, 1], np.ones(3)])
|
||||
return M, np.linalg.inv(M)
|
||||
|
||||
|
||||
def _prepare_barycentric_coord(b_coords):
|
||||
"""
|
||||
Check ternary coordinates and return the right barycentric coordinates.
|
||||
"""
|
||||
if not isinstance(b_coords, (list, np.ndarray)):
|
||||
raise ValueError(
|
||||
"Data should be either an array of shape (n,m),"
|
||||
"or a list of n m-lists, m=2 or 3"
|
||||
)
|
||||
b_coords = np.asarray(b_coords)
|
||||
if b_coords.shape[0] not in (2, 3):
|
||||
raise ValueError(
|
||||
"A point should have 2 (a, b) or 3 (a, b, c)barycentric coordinates"
|
||||
)
|
||||
if (
|
||||
(len(b_coords) == 3)
|
||||
and not np.allclose(b_coords.sum(axis=0), 1, rtol=0.01)
|
||||
and not np.allclose(b_coords.sum(axis=0), 100, rtol=0.01)
|
||||
):
|
||||
msg = "The sum of coordinates should be 1 or 100 for all data points"
|
||||
raise ValueError(msg)
|
||||
|
||||
if len(b_coords) == 2:
|
||||
A, B = b_coords
|
||||
C = 1 - (A + B)
|
||||
else:
|
||||
A, B, C = b_coords / b_coords.sum(axis=0)
|
||||
if np.any(np.stack((A, B, C)) < 0):
|
||||
raise ValueError("Barycentric coordinates should be positive.")
|
||||
return np.stack((A, B, C))
|
||||
|
||||
|
||||
def _compute_grid(coordinates, values, interp_mode="ilr"):
|
||||
"""
|
||||
Transform data points with Cartesian or ILR mapping, then Compute
|
||||
interpolation on a regular grid.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
coordinates : array-like
|
||||
Barycentric coordinates of data points.
|
||||
values : 1-d array-like
|
||||
Data points, field to be represented as contours.
|
||||
interp_mode : 'ilr' (default) or 'cartesian'
|
||||
Defines how data are interpolated to compute contours.
|
||||
"""
|
||||
if interp_mode == "cartesian":
|
||||
M, invM = _transform_barycentric_cartesian()
|
||||
coord_points = np.einsum("ik, kj -> ij", M, coordinates)
|
||||
elif interp_mode == "ilr":
|
||||
coordinates = _replace_zero_coords(coordinates)
|
||||
coord_points = _ilr_transform(coordinates)
|
||||
else:
|
||||
raise ValueError("interp_mode should be cartesian or ilr")
|
||||
xx, yy = coord_points[:2]
|
||||
x_min, x_max = xx.min(), xx.max()
|
||||
y_min, y_max = yy.min(), yy.max()
|
||||
n_interp = max(200, int(np.sqrt(len(values))))
|
||||
gr_x = np.linspace(x_min, x_max, n_interp)
|
||||
gr_y = np.linspace(y_min, y_max, n_interp)
|
||||
grid_x, grid_y = np.meshgrid(gr_x, gr_y)
|
||||
# We use cubic interpolation, except outside of the convex hull
|
||||
# of data points where we use nearest neighbor values.
|
||||
grid_z = scipy_interp.griddata(
|
||||
coord_points[:2].T, values, (grid_x, grid_y), method="cubic"
|
||||
)
|
||||
return grid_z, gr_x, gr_y
|
||||
|
||||
|
||||
# ----------------------- Contour traces ----------------------
|
||||
|
||||
|
||||
def _polygon_area(x, y):
|
||||
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
|
||||
|
||||
|
||||
def _colors(ncontours, colormap=None):
|
||||
"""
|
||||
Return a list of ``ncontours`` colors from the ``colormap`` colorscale.
|
||||
"""
|
||||
if colormap in clrs.PLOTLY_SCALES.keys():
|
||||
cmap = clrs.PLOTLY_SCALES[colormap]
|
||||
else:
|
||||
raise exceptions.PlotlyError(
|
||||
"Colorscale must be a valid Plotly Colorscale."
|
||||
"The available colorscale names are {}".format(clrs.PLOTLY_SCALES.keys())
|
||||
)
|
||||
values = np.linspace(0, 1, ncontours)
|
||||
vals_cmap = np.array([pair[0] for pair in cmap])
|
||||
cols = np.array([pair[1] for pair in cmap])
|
||||
inds = np.searchsorted(vals_cmap, values)
|
||||
if "#" in cols[0]: # for Viridis
|
||||
cols = [clrs.label_rgb(clrs.hex_to_rgb(col)) for col in cols]
|
||||
|
||||
colors = [cols[0]]
|
||||
for ind, val in zip(inds[1:], values[1:]):
|
||||
val1, val2 = vals_cmap[ind - 1], vals_cmap[ind]
|
||||
interm = (val - val1) / (val2 - val1)
|
||||
col = clrs.find_intermediate_color(
|
||||
cols[ind - 1], cols[ind], interm, colortype="rgb"
|
||||
)
|
||||
colors.append(col)
|
||||
return colors
|
||||
|
||||
|
||||
def _is_invalid_contour(x, y):
|
||||
"""
|
||||
Utility function for _contour_trace
|
||||
|
||||
Contours with an area of the order as 1 pixel are considered spurious.
|
||||
"""
|
||||
too_small = np.all(np.abs(x - x[0]) < 2) and np.all(np.abs(y - y[0]) < 2)
|
||||
return too_small
|
||||
|
||||
|
||||
def _extract_contours(im, values, colors):
|
||||
"""
|
||||
Utility function for _contour_trace.
|
||||
|
||||
In ``im`` only one part of the domain has valid values (corresponding
|
||||
to a subdomain where barycentric coordinates are well defined). When
|
||||
computing contours, we need to assign values outside of this domain.
|
||||
We can choose a value either smaller than all the values inside the
|
||||
valid domain, or larger. This value must be chose with caution so that
|
||||
no spurious contours are added. For example, if the boundary of the valid
|
||||
domain has large values and the outer value is set to a small one, all
|
||||
intermediate contours will be added at the boundary.
|
||||
|
||||
Therefore, we compute the two sets of contours (with an outer value
|
||||
smaller of larger than all values in the valid domain), and choose
|
||||
the value resulting in a smaller total number of contours. There might
|
||||
be a faster way to do this, but it works...
|
||||
"""
|
||||
mask_nan = np.isnan(im)
|
||||
im_min, im_max = (
|
||||
im[np.logical_not(mask_nan)].min(),
|
||||
im[np.logical_not(mask_nan)].max(),
|
||||
)
|
||||
zz_min = np.copy(im)
|
||||
zz_min[mask_nan] = 2 * im_min
|
||||
zz_max = np.copy(im)
|
||||
zz_max[mask_nan] = 2 * im_max
|
||||
all_contours1, all_values1, all_areas1, all_colors1 = [], [], [], []
|
||||
all_contours2, all_values2, all_areas2, all_colors2 = [], [], [], []
|
||||
for i, val in enumerate(values):
|
||||
contour_level1 = measure.find_contours(zz_min, val)
|
||||
contour_level2 = measure.find_contours(zz_max, val)
|
||||
all_contours1.extend(contour_level1)
|
||||
all_contours2.extend(contour_level2)
|
||||
all_values1.extend([val] * len(contour_level1))
|
||||
all_values2.extend([val] * len(contour_level2))
|
||||
all_areas1.extend(
|
||||
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level1]
|
||||
)
|
||||
all_areas2.extend(
|
||||
[_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level2]
|
||||
)
|
||||
all_colors1.extend([colors[i]] * len(contour_level1))
|
||||
all_colors2.extend([colors[i]] * len(contour_level2))
|
||||
if len(all_contours1) <= len(all_contours2):
|
||||
return all_contours1, all_values1, all_areas1, all_colors1
|
||||
else:
|
||||
return all_contours2, all_values2, all_areas2, all_colors2
|
||||
|
||||
|
||||
def _add_outer_contour(
|
||||
all_contours,
|
||||
all_values,
|
||||
all_areas,
|
||||
all_colors,
|
||||
values,
|
||||
val_outer,
|
||||
v_min,
|
||||
v_max,
|
||||
colors,
|
||||
color_min,
|
||||
color_max,
|
||||
):
|
||||
"""
|
||||
Utility function for _contour_trace
|
||||
|
||||
Adds the background color to fill gaps outside of computed contours.
|
||||
|
||||
To compute the background color, the color of the contour with largest
|
||||
area (``val_outer``) is used. As background color, we choose the next
|
||||
color value in the direction of the extrema of the colormap.
|
||||
|
||||
Then we add information for the outer contour for the different lists
|
||||
provided as arguments.
|
||||
|
||||
A discrete colormap with all used colors is also returned (to be used
|
||||
by colorscale trace).
|
||||
"""
|
||||
# The exact value of outer contour is not used when defining the trace
|
||||
outer_contour = 20 * np.array([[0, 0, 1], [0, 1, 0.5]]).T
|
||||
all_contours = [outer_contour] + all_contours
|
||||
delta_values = np.diff(values)[0]
|
||||
values = np.concatenate(
|
||||
([values[0] - delta_values], values, [values[-1] + delta_values])
|
||||
)
|
||||
colors = np.concatenate(([color_min], colors, [color_max]))
|
||||
index = np.nonzero(values == val_outer)[0][0]
|
||||
if index < len(values) / 2:
|
||||
index -= 1
|
||||
else:
|
||||
index += 1
|
||||
all_colors = [colors[index]] + all_colors
|
||||
all_values = [values[index]] + all_values
|
||||
all_areas = [0] + all_areas
|
||||
used_colors = [color for color in colors if color in all_colors]
|
||||
# Define discrete colorscale
|
||||
color_number = len(used_colors)
|
||||
scale = np.linspace(0, 1, color_number + 1)
|
||||
discrete_cm = []
|
||||
for i, color in enumerate(used_colors):
|
||||
discrete_cm.append([scale[i], used_colors[i]])
|
||||
discrete_cm.append([scale[i + 1], used_colors[i]])
|
||||
discrete_cm.append([scale[color_number], used_colors[color_number - 1]])
|
||||
|
||||
return all_contours, all_values, all_areas, all_colors, discrete_cm
|
||||
|
||||
|
||||
def _contour_trace(
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
ncontours=None,
|
||||
colorscale="Electric",
|
||||
linecolor="rgb(150,150,150)",
|
||||
interp_mode="llr",
|
||||
coloring=None,
|
||||
v_min=0,
|
||||
v_max=1,
|
||||
):
|
||||
"""
|
||||
Contour trace in Cartesian coordinates.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
x, y : array-like
|
||||
Cartesian coordinates
|
||||
z : array-like
|
||||
Field to be represented as contours.
|
||||
ncontours : int or None
|
||||
Number of contours to display (determined automatically if None).
|
||||
colorscale : None or str (Plotly colormap)
|
||||
colorscale of the contours.
|
||||
linecolor : rgb color
|
||||
Color used for lines. If ``colorscale`` is not None, line colors are
|
||||
determined from ``colorscale`` instead.
|
||||
interp_mode : 'ilr' (default) or 'cartesian'
|
||||
Defines how data are interpolated to compute contours. If 'irl',
|
||||
ILR (Isometric Log-Ratio) of compositional data is performed. If
|
||||
'cartesian', contours are determined in Cartesian space.
|
||||
coloring : None or 'lines'
|
||||
How to display contour. Filled contours if None, lines if ``lines``.
|
||||
vmin, vmax : float
|
||||
Bounds of interval of values used for the colorspace
|
||||
|
||||
Notes
|
||||
=====
|
||||
"""
|
||||
# Prepare colors
|
||||
# We do not take extrema, for example for one single contour
|
||||
# the color will be the middle point of the colormap
|
||||
colors = _colors(ncontours + 2, colorscale)
|
||||
# Values used for contours, extrema are not used
|
||||
# For example for a binary array [0, 1], the value of
|
||||
# the contour for ncontours=1 is 0.5.
|
||||
values = np.linspace(v_min, v_max, ncontours + 2)
|
||||
color_min, color_max = colors[0], colors[-1]
|
||||
colors = colors[1:-1]
|
||||
values = values[1:-1]
|
||||
|
||||
# Color of line contours
|
||||
if linecolor is None:
|
||||
linecolor = "rgb(150, 150, 150)"
|
||||
else:
|
||||
colors = [linecolor] * ncontours
|
||||
|
||||
# Retrieve all contours
|
||||
all_contours, all_values, all_areas, all_colors = _extract_contours(
|
||||
z, values, colors
|
||||
)
|
||||
|
||||
# Now sort contours by decreasing area
|
||||
order = np.argsort(all_areas)[::-1]
|
||||
|
||||
# Add outer contour
|
||||
all_contours, all_values, all_areas, all_colors, discrete_cm = _add_outer_contour(
|
||||
all_contours,
|
||||
all_values,
|
||||
all_areas,
|
||||
all_colors,
|
||||
values,
|
||||
all_values[order[0]],
|
||||
v_min,
|
||||
v_max,
|
||||
colors,
|
||||
color_min,
|
||||
color_max,
|
||||
)
|
||||
order = np.concatenate(([0], order + 1))
|
||||
|
||||
# Compute traces, in the order of decreasing area
|
||||
traces = []
|
||||
M, invM = _transform_barycentric_cartesian()
|
||||
dx = (x.max() - x.min()) / x.size
|
||||
dy = (y.max() - y.min()) / y.size
|
||||
for index in order:
|
||||
y_contour, x_contour = all_contours[index].T
|
||||
val = all_values[index]
|
||||
if interp_mode == "cartesian":
|
||||
bar_coords = np.dot(
|
||||
invM,
|
||||
np.stack((dx * x_contour, dy * y_contour, np.ones(x_contour.shape))),
|
||||
)
|
||||
elif interp_mode == "ilr":
|
||||
bar_coords = _ilr_inverse(
|
||||
np.stack((dx * x_contour + x.min(), dy * y_contour + y.min()))
|
||||
)
|
||||
if index == 0: # outer triangle
|
||||
a = np.array([1, 0, 0])
|
||||
b = np.array([0, 1, 0])
|
||||
c = np.array([0, 0, 1])
|
||||
else:
|
||||
a, b, c = bar_coords
|
||||
if _is_invalid_contour(x_contour, y_contour):
|
||||
continue
|
||||
|
||||
_col = all_colors[index] if coloring == "lines" else linecolor
|
||||
trace = dict(
|
||||
type="scatterternary",
|
||||
a=a,
|
||||
b=b,
|
||||
c=c,
|
||||
mode="lines",
|
||||
line=dict(color=_col, shape="spline", width=1),
|
||||
fill="toself",
|
||||
fillcolor=all_colors[index],
|
||||
showlegend=True,
|
||||
hoverinfo="skip",
|
||||
name="%.3f" % val,
|
||||
)
|
||||
if coloring == "lines":
|
||||
trace["fill"] = None
|
||||
traces.append(trace)
|
||||
|
||||
return traces, discrete_cm
|
||||
|
||||
|
||||
# -------------------- Figure Factory for ternary contour -------------
|
||||
|
||||
|
||||
def create_ternary_contour(
|
||||
coordinates,
|
||||
values,
|
||||
pole_labels=["a", "b", "c"],
|
||||
width=500,
|
||||
height=500,
|
||||
ncontours=None,
|
||||
showscale=False,
|
||||
coloring=None,
|
||||
colorscale="Bluered",
|
||||
linecolor=None,
|
||||
title=None,
|
||||
interp_mode="ilr",
|
||||
showmarkers=False,
|
||||
):
|
||||
"""
|
||||
Ternary contour plot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
coordinates : list or ndarray
|
||||
Barycentric coordinates of shape (2, N) or (3, N) where N is the
|
||||
number of data points. The sum of the 3 coordinates is expected
|
||||
to be 1 for all data points.
|
||||
values : array-like
|
||||
Data points of field to be represented as contours.
|
||||
pole_labels : str, default ['a', 'b', 'c']
|
||||
Names of the three poles of the triangle.
|
||||
width : int
|
||||
Figure width.
|
||||
height : int
|
||||
Figure height.
|
||||
ncontours : int or None
|
||||
Number of contours to display (determined automatically if None).
|
||||
showscale : bool, default False
|
||||
If True, a colorbar showing the color scale is displayed.
|
||||
coloring : None or 'lines'
|
||||
How to display contour. Filled contours if None, lines if ``lines``.
|
||||
colorscale : None or str (Plotly colormap)
|
||||
colorscale of the contours.
|
||||
linecolor : None or rgb color
|
||||
Color used for lines. ``colorscale`` has to be set to None, otherwise
|
||||
line colors are determined from ``colorscale``.
|
||||
title : str or None
|
||||
Title of ternary plot
|
||||
interp_mode : 'ilr' (default) or 'cartesian'
|
||||
Defines how data are interpolated to compute contours. If 'irl',
|
||||
ILR (Isometric Log-Ratio) of compositional data is performed. If
|
||||
'cartesian', contours are determined in Cartesian space.
|
||||
showmarkers : bool, default False
|
||||
If True, markers corresponding to input compositional points are
|
||||
superimposed on contours, using the same colorscale.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
Example 1: ternary contour plot with filled contours
|
||||
|
||||
>>> import plotly.figure_factory as ff
|
||||
>>> import numpy as np
|
||||
>>> # Define coordinates
|
||||
>>> a, b = np.mgrid[0:1:20j, 0:1:20j]
|
||||
>>> mask = a + b <= 1
|
||||
>>> a = a[mask].ravel()
|
||||
>>> b = b[mask].ravel()
|
||||
>>> c = 1 - a - b
|
||||
>>> # Values to be displayed as contours
|
||||
>>> z = a * b * c
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z)
|
||||
>>> fig.show()
|
||||
|
||||
It is also possible to give only two barycentric coordinates for each
|
||||
point, since the sum of the three coordinates is one:
|
||||
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b)), z)
|
||||
|
||||
|
||||
Example 2: ternary contour plot with line contours
|
||||
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines')
|
||||
|
||||
Example 3: customize number of contours
|
||||
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8)
|
||||
|
||||
Example 4: superimpose contour plot and original data as markers
|
||||
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines',
|
||||
... showmarkers=True)
|
||||
|
||||
Example 5: customize title and pole labels
|
||||
|
||||
>>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z,
|
||||
... title='Ternary plot',
|
||||
... pole_labels=['clay', 'quartz', 'fledspar'])
|
||||
"""
|
||||
if scipy_interp is None:
|
||||
raise ImportError(
|
||||
"""\
|
||||
The create_ternary_contour figure factory requires the scipy package"""
|
||||
)
|
||||
sk_measure = optional_imports.get_module("skimage")
|
||||
if sk_measure is None:
|
||||
raise ImportError(
|
||||
"""\
|
||||
The create_ternary_contour figure factory requires the scikit-image
|
||||
package"""
|
||||
)
|
||||
if colorscale is None:
|
||||
showscale = False
|
||||
if ncontours is None:
|
||||
ncontours = 5
|
||||
coordinates = _prepare_barycentric_coord(coordinates)
|
||||
v_min, v_max = values.min(), values.max()
|
||||
grid_z, gr_x, gr_y = _compute_grid(coordinates, values, interp_mode=interp_mode)
|
||||
|
||||
layout = _ternary_layout(
|
||||
pole_labels=pole_labels, width=width, height=height, title=title
|
||||
)
|
||||
|
||||
contour_trace, discrete_cm = _contour_trace(
|
||||
gr_x,
|
||||
gr_y,
|
||||
grid_z,
|
||||
ncontours=ncontours,
|
||||
colorscale=colorscale,
|
||||
linecolor=linecolor,
|
||||
interp_mode=interp_mode,
|
||||
coloring=coloring,
|
||||
v_min=v_min,
|
||||
v_max=v_max,
|
||||
)
|
||||
|
||||
fig = go.Figure(data=contour_trace, layout=layout)
|
||||
|
||||
opacity = 1 if showmarkers else 0
|
||||
a, b, c = coordinates
|
||||
hovertemplate = (
|
||||
pole_labels[0]
|
||||
+ ": %{a:.3f}<br>"
|
||||
+ pole_labels[1]
|
||||
+ ": %{b:.3f}<br>"
|
||||
+ pole_labels[2]
|
||||
+ ": %{c:.3f}<br>"
|
||||
"z: %{marker.color:.3f}<extra></extra>"
|
||||
)
|
||||
|
||||
fig.add_scatterternary(
|
||||
a=a,
|
||||
b=b,
|
||||
c=c,
|
||||
mode="markers",
|
||||
marker={
|
||||
"color": values,
|
||||
"colorscale": colorscale,
|
||||
"line": {"color": "rgb(120, 120, 120)", "width": int(coloring != "lines")},
|
||||
},
|
||||
opacity=opacity,
|
||||
hovertemplate=hovertemplate,
|
||||
)
|
||||
if showscale:
|
||||
if not showmarkers:
|
||||
colorscale = discrete_cm
|
||||
colorbar = dict(
|
||||
{
|
||||
"type": "scatterternary",
|
||||
"a": [None],
|
||||
"b": [None],
|
||||
"c": [None],
|
||||
"marker": {
|
||||
"cmin": values.min(),
|
||||
"cmax": values.max(),
|
||||
"colorscale": colorscale,
|
||||
"showscale": True,
|
||||
},
|
||||
"mode": "markers",
|
||||
}
|
||||
)
|
||||
fig.add_trace(colorbar)
|
||||
|
||||
return fig
|
509
lib/python3.11/site-packages/plotly/figure_factory/_trisurf.py
Normal file
509
lib/python3.11/site-packages/plotly/figure_factory/_trisurf.py
Normal file
@ -0,0 +1,509 @@
|
||||
from plotly import exceptions, optional_imports
|
||||
import plotly.colors as clrs
|
||||
from plotly.graph_objs import graph_objs
|
||||
|
||||
np = optional_imports.get_module("numpy")
|
||||
|
||||
|
||||
def map_face2color(face, colormap, scale, vmin, vmax):
|
||||
"""
|
||||
Normalize facecolor values by vmin/vmax and return rgb-color strings
|
||||
|
||||
This function takes a tuple color along with a colormap and a minimum
|
||||
(vmin) and maximum (vmax) range of possible mean distances for the
|
||||
given parametrized surface. It returns an rgb color based on the mean
|
||||
distance between vmin and vmax
|
||||
|
||||
"""
|
||||
if vmin >= vmax:
|
||||
raise exceptions.PlotlyError(
|
||||
"Incorrect relation between vmin "
|
||||
"and vmax. The vmin value cannot be "
|
||||
"bigger than or equal to the value "
|
||||
"of vmax."
|
||||
)
|
||||
if len(colormap) == 1:
|
||||
# color each triangle face with the same color in colormap
|
||||
face_color = colormap[0]
|
||||
face_color = clrs.convert_to_RGB_255(face_color)
|
||||
face_color = clrs.label_rgb(face_color)
|
||||
return face_color
|
||||
if face == vmax:
|
||||
# pick last color in colormap
|
||||
face_color = colormap[-1]
|
||||
face_color = clrs.convert_to_RGB_255(face_color)
|
||||
face_color = clrs.label_rgb(face_color)
|
||||
return face_color
|
||||
else:
|
||||
if scale is None:
|
||||
# find the normalized distance t of a triangle face between
|
||||
# vmin and vmax where the distance is between 0 and 1
|
||||
t = (face - vmin) / float((vmax - vmin))
|
||||
low_color_index = int(t / (1.0 / (len(colormap) - 1)))
|
||||
|
||||
face_color = clrs.find_intermediate_color(
|
||||
colormap[low_color_index],
|
||||
colormap[low_color_index + 1],
|
||||
t * (len(colormap) - 1) - low_color_index,
|
||||
)
|
||||
|
||||
face_color = clrs.convert_to_RGB_255(face_color)
|
||||
face_color = clrs.label_rgb(face_color)
|
||||
else:
|
||||
# find the face color for a non-linearly interpolated scale
|
||||
t = (face - vmin) / float((vmax - vmin))
|
||||
|
||||
low_color_index = 0
|
||||
for k in range(len(scale) - 1):
|
||||
if scale[k] <= t < scale[k + 1]:
|
||||
break
|
||||
low_color_index += 1
|
||||
|
||||
low_scale_val = scale[low_color_index]
|
||||
high_scale_val = scale[low_color_index + 1]
|
||||
|
||||
face_color = clrs.find_intermediate_color(
|
||||
colormap[low_color_index],
|
||||
colormap[low_color_index + 1],
|
||||
(t - low_scale_val) / (high_scale_val - low_scale_val),
|
||||
)
|
||||
|
||||
face_color = clrs.convert_to_RGB_255(face_color)
|
||||
face_color = clrs.label_rgb(face_color)
|
||||
return face_color
|
||||
|
||||
|
||||
def trisurf(
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
simplices,
|
||||
show_colorbar,
|
||||
edges_color,
|
||||
scale,
|
||||
colormap=None,
|
||||
color_func=None,
|
||||
plot_edges=False,
|
||||
x_edge=None,
|
||||
y_edge=None,
|
||||
z_edge=None,
|
||||
facecolor=None,
|
||||
):
|
||||
"""
|
||||
Refer to FigureFactory.create_trisurf() for docstring
|
||||
"""
|
||||
# numpy import check
|
||||
if not np:
|
||||
raise ImportError("FigureFactory._trisurf() requires numpy imported.")
|
||||
points3D = np.vstack((x, y, z)).T
|
||||
simplices = np.atleast_2d(simplices)
|
||||
|
||||
# vertices of the surface triangles
|
||||
tri_vertices = points3D[simplices]
|
||||
|
||||
# Define colors for the triangle faces
|
||||
if color_func is None:
|
||||
# mean values of z-coordinates of triangle vertices
|
||||
mean_dists = tri_vertices[:, :, 2].mean(-1)
|
||||
elif isinstance(color_func, (list, np.ndarray)):
|
||||
# Pre-computed list / array of values to map onto color
|
||||
if len(color_func) != len(simplices):
|
||||
raise ValueError(
|
||||
"If color_func is a list/array, it must "
|
||||
"be the same length as simplices."
|
||||
)
|
||||
|
||||
# convert all colors in color_func to rgb
|
||||
for index in range(len(color_func)):
|
||||
if isinstance(color_func[index], str):
|
||||
if "#" in color_func[index]:
|
||||
foo = clrs.hex_to_rgb(color_func[index])
|
||||
color_func[index] = clrs.label_rgb(foo)
|
||||
|
||||
if isinstance(color_func[index], tuple):
|
||||
foo = clrs.convert_to_RGB_255(color_func[index])
|
||||
color_func[index] = clrs.label_rgb(foo)
|
||||
|
||||
mean_dists = np.asarray(color_func)
|
||||
else:
|
||||
# apply user inputted function to calculate
|
||||
# custom coloring for triangle vertices
|
||||
mean_dists = []
|
||||
for triangle in tri_vertices:
|
||||
dists = []
|
||||
for vertex in triangle:
|
||||
dist = color_func(vertex[0], vertex[1], vertex[2])
|
||||
dists.append(dist)
|
||||
mean_dists.append(np.mean(dists))
|
||||
mean_dists = np.asarray(mean_dists)
|
||||
|
||||
# Check if facecolors are already strings and can be skipped
|
||||
if isinstance(mean_dists[0], str):
|
||||
facecolor = mean_dists
|
||||
else:
|
||||
min_mean_dists = np.min(mean_dists)
|
||||
max_mean_dists = np.max(mean_dists)
|
||||
|
||||
if facecolor is None:
|
||||
facecolor = []
|
||||
for index in range(len(mean_dists)):
|
||||
color = map_face2color(
|
||||
mean_dists[index], colormap, scale, min_mean_dists, max_mean_dists
|
||||
)
|
||||
facecolor.append(color)
|
||||
|
||||
# Make sure facecolor is a list so output is consistent across Pythons
|
||||
facecolor = np.asarray(facecolor)
|
||||
ii, jj, kk = simplices.T
|
||||
|
||||
triangles = graph_objs.Mesh3d(
|
||||
x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name=""
|
||||
)
|
||||
|
||||
mean_dists_are_numbers = not isinstance(mean_dists[0], str)
|
||||
|
||||
if mean_dists_are_numbers and show_colorbar is True:
|
||||
# make a colorscale from the colors
|
||||
colorscale = clrs.make_colorscale(colormap, scale)
|
||||
colorscale = clrs.convert_colorscale_to_rgb(colorscale)
|
||||
|
||||
colorbar = graph_objs.Scatter3d(
|
||||
x=x[:1],
|
||||
y=y[:1],
|
||||
z=z[:1],
|
||||
mode="markers",
|
||||
marker=dict(
|
||||
size=0.1,
|
||||
color=[min_mean_dists, max_mean_dists],
|
||||
colorscale=colorscale,
|
||||
showscale=True,
|
||||
),
|
||||
hoverinfo="none",
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
# the triangle sides are not plotted
|
||||
if plot_edges is False:
|
||||
if mean_dists_are_numbers and show_colorbar is True:
|
||||
return [triangles, colorbar]
|
||||
else:
|
||||
return [triangles]
|
||||
|
||||
# define the lists x_edge, y_edge and z_edge, of x, y, resp z
|
||||
# coordinates of edge end points for each triangle
|
||||
# None separates data corresponding to two consecutive triangles
|
||||
is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
|
||||
if any(is_none):
|
||||
if not all(is_none):
|
||||
raise ValueError(
|
||||
"If any (x_edge, y_edge, z_edge) is None, all must be None"
|
||||
)
|
||||
else:
|
||||
x_edge = []
|
||||
y_edge = []
|
||||
z_edge = []
|
||||
|
||||
# Pull indices we care about, then add a None column to separate tris
|
||||
ixs_triangles = [0, 1, 2, 0]
|
||||
pull_edges = tri_vertices[:, ixs_triangles, :]
|
||||
x_edge_pull = np.hstack(
|
||||
[pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])]
|
||||
)
|
||||
y_edge_pull = np.hstack(
|
||||
[pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])]
|
||||
)
|
||||
z_edge_pull = np.hstack(
|
||||
[pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])]
|
||||
)
|
||||
|
||||
# Now unravel the edges into a 1-d vector for plotting
|
||||
x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
|
||||
y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
|
||||
z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
|
||||
|
||||
if not (len(x_edge) == len(y_edge) == len(z_edge)):
|
||||
raise exceptions.PlotlyError(
|
||||
"The lengths of x_edge, y_edge and z_edge are not the same."
|
||||
)
|
||||
|
||||
# define the lines for plotting
|
||||
lines = graph_objs.Scatter3d(
|
||||
x=x_edge,
|
||||
y=y_edge,
|
||||
z=z_edge,
|
||||
mode="lines",
|
||||
line=graph_objs.scatter3d.Line(color=edges_color, width=1.5),
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
if mean_dists_are_numbers and show_colorbar is True:
|
||||
return [triangles, lines, colorbar]
|
||||
else:
|
||||
return [triangles, lines]
|
||||
|
||||
|
||||
def create_trisurf(
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
simplices,
|
||||
colormap=None,
|
||||
show_colorbar=True,
|
||||
scale=None,
|
||||
color_func=None,
|
||||
title="Trisurf Plot",
|
||||
plot_edges=True,
|
||||
showbackground=True,
|
||||
backgroundcolor="rgb(230, 230, 230)",
|
||||
gridcolor="rgb(255, 255, 255)",
|
||||
zerolinecolor="rgb(255, 255, 255)",
|
||||
edges_color="rgb(50, 50, 50)",
|
||||
height=800,
|
||||
width=800,
|
||||
aspectratio=None,
|
||||
):
|
||||
"""
|
||||
Returns figure for a triangulated surface plot
|
||||
|
||||
:param (array) x: data values of x in a 1D array
|
||||
:param (array) y: data values of y in a 1D array
|
||||
:param (array) z: data values of z in a 1D array
|
||||
:param (array) simplices: an array of shape (ntri, 3) where ntri is
|
||||
the number of triangles in the triangularization. Each row of the
|
||||
array contains the indicies of the verticies of each triangle
|
||||
:param (str|tuple|list) colormap: either a plotly scale name, an rgb
|
||||
or hex color, a color tuple or a list of colors. An rgb color is
|
||||
of the form 'rgb(x, y, z)' where x, y, z belong to the interval
|
||||
[0, 255] and a color tuple is a tuple of the form (a, b, c) where
|
||||
a, b and c belong to [0, 1]. If colormap is a list, it must
|
||||
contain the valid color types aforementioned as its members
|
||||
:param (bool) show_colorbar: determines if colorbar is visible
|
||||
:param (list|array) scale: sets the scale values to be used if a non-
|
||||
linearly interpolated colormap is desired. If left as None, a
|
||||
linear interpolation between the colors will be excecuted
|
||||
:param (function|list) color_func: The parameter that determines the
|
||||
coloring of the surface. Takes either a function with 3 arguments
|
||||
x, y, z or a list/array of color values the same length as
|
||||
simplices. If None, coloring will only depend on the z axis
|
||||
:param (str) title: title of the plot
|
||||
:param (bool) plot_edges: determines if the triangles on the trisurf
|
||||
are visible
|
||||
:param (bool) showbackground: makes background in plot visible
|
||||
:param (str) backgroundcolor: color of background. Takes a string of
|
||||
the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
|
||||
:param (str) gridcolor: color of the gridlines besides the axes. Takes
|
||||
a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
|
||||
inclusive
|
||||
:param (str) zerolinecolor: color of the axes. Takes a string of the
|
||||
form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
|
||||
:param (str) edges_color: color of the edges, if plot_edges is True
|
||||
:param (int|float) height: the height of the plot (in pixels)
|
||||
:param (int|float) width: the width of the plot (in pixels)
|
||||
:param (dict) aspectratio: a dictionary of the aspect ratio values for
|
||||
the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
|
||||
|
||||
Example 1: Sphere
|
||||
|
||||
>>> # Necessary Imports for Trisurf
|
||||
>>> import numpy as np
|
||||
>>> from scipy.spatial import Delaunay
|
||||
|
||||
>>> from plotly.figure_factory import create_trisurf
|
||||
>>> from plotly.graph_objs import graph_objs
|
||||
|
||||
>>> # Make data for plot
|
||||
>>> u = np.linspace(0, 2*np.pi, 20)
|
||||
>>> v = np.linspace(0, np.pi, 20)
|
||||
>>> u,v = np.meshgrid(u,v)
|
||||
>>> u = u.flatten()
|
||||
>>> v = v.flatten()
|
||||
|
||||
>>> x = np.sin(v)*np.cos(u)
|
||||
>>> y = np.sin(v)*np.sin(u)
|
||||
>>> z = np.cos(v)
|
||||
|
||||
>>> points2D = np.vstack([u,v]).T
|
||||
>>> tri = Delaunay(points2D)
|
||||
>>> simplices = tri.simplices
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
|
||||
... simplices=simplices)
|
||||
|
||||
Example 2: Torus
|
||||
|
||||
>>> # Necessary Imports for Trisurf
|
||||
>>> import numpy as np
|
||||
>>> from scipy.spatial import Delaunay
|
||||
|
||||
>>> from plotly.figure_factory import create_trisurf
|
||||
>>> from plotly.graph_objs import graph_objs
|
||||
|
||||
>>> # Make data for plot
|
||||
>>> u = np.linspace(0, 2*np.pi, 20)
|
||||
>>> v = np.linspace(0, 2*np.pi, 20)
|
||||
>>> u,v = np.meshgrid(u,v)
|
||||
>>> u = u.flatten()
|
||||
>>> v = v.flatten()
|
||||
|
||||
>>> x = (3 + (np.cos(v)))*np.cos(u)
|
||||
>>> y = (3 + (np.cos(v)))*np.sin(u)
|
||||
>>> z = np.sin(v)
|
||||
|
||||
>>> points2D = np.vstack([u,v]).T
|
||||
>>> tri = Delaunay(points2D)
|
||||
>>> simplices = tri.simplices
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
|
||||
... simplices=simplices)
|
||||
|
||||
Example 3: Mobius Band
|
||||
|
||||
>>> # Necessary Imports for Trisurf
|
||||
>>> import numpy as np
|
||||
>>> from scipy.spatial import Delaunay
|
||||
|
||||
>>> from plotly.figure_factory import create_trisurf
|
||||
>>> from plotly.graph_objs import graph_objs
|
||||
|
||||
>>> # Make data for plot
|
||||
>>> u = np.linspace(0, 2*np.pi, 24)
|
||||
>>> v = np.linspace(-1, 1, 8)
|
||||
>>> u,v = np.meshgrid(u,v)
|
||||
>>> u = u.flatten()
|
||||
>>> v = v.flatten()
|
||||
|
||||
>>> tp = 1 + 0.5*v*np.cos(u/2.)
|
||||
>>> x = tp*np.cos(u)
|
||||
>>> y = tp*np.sin(u)
|
||||
>>> z = 0.5*v*np.sin(u/2.)
|
||||
|
||||
>>> points2D = np.vstack([u,v]).T
|
||||
>>> tri = Delaunay(points2D)
|
||||
>>> simplices = tri.simplices
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
|
||||
... simplices=simplices)
|
||||
|
||||
Example 4: Using a Custom Colormap Function with Light Cone
|
||||
|
||||
>>> # Necessary Imports for Trisurf
|
||||
>>> import numpy as np
|
||||
>>> from scipy.spatial import Delaunay
|
||||
|
||||
>>> from plotly.figure_factory import create_trisurf
|
||||
>>> from plotly.graph_objs import graph_objs
|
||||
|
||||
>>> # Make data for plot
|
||||
>>> u=np.linspace(-np.pi, np.pi, 30)
|
||||
>>> v=np.linspace(-np.pi, np.pi, 30)
|
||||
>>> u,v=np.meshgrid(u,v)
|
||||
>>> u=u.flatten()
|
||||
>>> v=v.flatten()
|
||||
|
||||
>>> x = u
|
||||
>>> y = u*np.cos(v)
|
||||
>>> z = u*np.sin(v)
|
||||
|
||||
>>> points2D = np.vstack([u,v]).T
|
||||
>>> tri = Delaunay(points2D)
|
||||
>>> simplices = tri.simplices
|
||||
|
||||
>>> # Define distance function
|
||||
>>> def dist_origin(x, y, z):
|
||||
... return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
|
||||
|
||||
>>> # Create a figure
|
||||
>>> fig1 = create_trisurf(x=x, y=y, z=z,
|
||||
... colormap=['#FFFFFF', '#E4FFFE',
|
||||
... '#A4F6F9', '#FF99FE',
|
||||
... '#BA52ED'],
|
||||
... scale=[0, 0.6, 0.71, 0.89, 1],
|
||||
... simplices=simplices,
|
||||
... color_func=dist_origin)
|
||||
|
||||
Example 5: Enter color_func as a list of colors
|
||||
|
||||
>>> # Necessary Imports for Trisurf
|
||||
>>> import numpy as np
|
||||
>>> from scipy.spatial import Delaunay
|
||||
>>> import random
|
||||
|
||||
>>> from plotly.figure_factory import create_trisurf
|
||||
>>> from plotly.graph_objs import graph_objs
|
||||
|
||||
>>> # Make data for plot
|
||||
>>> u=np.linspace(-np.pi, np.pi, 30)
|
||||
>>> v=np.linspace(-np.pi, np.pi, 30)
|
||||
>>> u,v=np.meshgrid(u,v)
|
||||
>>> u=u.flatten()
|
||||
>>> v=v.flatten()
|
||||
|
||||
>>> x = u
|
||||
>>> y = u*np.cos(v)
|
||||
>>> z = u*np.sin(v)
|
||||
|
||||
>>> points2D = np.vstack([u,v]).T
|
||||
>>> tri = Delaunay(points2D)
|
||||
>>> simplices = tri.simplices
|
||||
|
||||
|
||||
>>> colors = []
|
||||
>>> color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
|
||||
|
||||
>>> for index in range(len(simplices)):
|
||||
... colors.append(random.choice(color_choices))
|
||||
|
||||
>>> fig = create_trisurf(
|
||||
... x, y, z, simplices,
|
||||
... color_func=colors,
|
||||
... show_colorbar=True,
|
||||
... edges_color='rgb(2, 85, 180)',
|
||||
... title=' Modern Art'
|
||||
... )
|
||||
"""
|
||||
if aspectratio is None:
|
||||
aspectratio = {"x": 1, "y": 1, "z": 1}
|
||||
|
||||
# Validate colormap
|
||||
clrs.validate_colors(colormap)
|
||||
colormap, scale = clrs.convert_colors_to_same_type(
|
||||
colormap, colortype="tuple", return_default_colors=True, scale=scale
|
||||
)
|
||||
|
||||
data1 = trisurf(
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
simplices,
|
||||
show_colorbar=show_colorbar,
|
||||
color_func=color_func,
|
||||
colormap=colormap,
|
||||
scale=scale,
|
||||
edges_color=edges_color,
|
||||
plot_edges=plot_edges,
|
||||
)
|
||||
|
||||
axis = dict(
|
||||
showbackground=showbackground,
|
||||
backgroundcolor=backgroundcolor,
|
||||
gridcolor=gridcolor,
|
||||
zerolinecolor=zerolinecolor,
|
||||
)
|
||||
layout = graph_objs.Layout(
|
||||
title=title,
|
||||
width=width,
|
||||
height=height,
|
||||
scene=graph_objs.layout.Scene(
|
||||
xaxis=graph_objs.layout.scene.XAxis(**axis),
|
||||
yaxis=graph_objs.layout.scene.YAxis(**axis),
|
||||
zaxis=graph_objs.layout.scene.ZAxis(**axis),
|
||||
aspectratio=dict(
|
||||
x=aspectratio["x"], y=aspectratio["y"], z=aspectratio["z"]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return graph_objs.Figure(data=data1, layout=layout)
|
704
lib/python3.11/site-packages/plotly/figure_factory/_violin.py
Normal file
704
lib/python3.11/site-packages/plotly/figure_factory/_violin.py
Normal file
@ -0,0 +1,704 @@
|
||||
from numbers import Number
|
||||
|
||||
from plotly import exceptions, optional_imports
|
||||
import plotly.colors as clrs
|
||||
from plotly.graph_objs import graph_objs
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
pd = optional_imports.get_module("pandas")
|
||||
np = optional_imports.get_module("numpy")
|
||||
scipy_stats = optional_imports.get_module("scipy.stats")
|
||||
|
||||
|
||||
def calc_stats(data):
|
||||
"""
|
||||
Calculate statistics for use in violin plot.
|
||||
"""
|
||||
x = np.asarray(data, float)
|
||||
vals_min = np.min(x)
|
||||
vals_max = np.max(x)
|
||||
q2 = np.percentile(x, 50, interpolation="linear")
|
||||
q1 = np.percentile(x, 25, interpolation="lower")
|
||||
q3 = np.percentile(x, 75, interpolation="higher")
|
||||
iqr = q3 - q1
|
||||
whisker_dist = 1.5 * iqr
|
||||
|
||||
# in order to prevent drawing whiskers outside the interval
|
||||
# of data one defines the whisker positions as:
|
||||
d1 = np.min(x[x >= (q1 - whisker_dist)])
|
||||
d2 = np.max(x[x <= (q3 + whisker_dist)])
|
||||
return {
|
||||
"min": vals_min,
|
||||
"max": vals_max,
|
||||
"q1": q1,
|
||||
"q2": q2,
|
||||
"q3": q3,
|
||||
"d1": d1,
|
||||
"d2": d2,
|
||||
}
|
||||
|
||||
|
||||
def make_half_violin(x, y, fillcolor="#1f77b4", linecolor="rgb(0, 0, 0)"):
|
||||
"""
|
||||
Produces a sideways probability distribution fig violin plot.
|
||||
"""
|
||||
text = [
|
||||
"(pdf(y), y)=(" + "{:0.2f}".format(x[i]) + ", " + "{:0.2f}".format(y[i]) + ")"
|
||||
for i in range(len(x))
|
||||
]
|
||||
|
||||
return graph_objs.Scatter(
|
||||
x=x,
|
||||
y=y,
|
||||
mode="lines",
|
||||
name="",
|
||||
text=text,
|
||||
fill="tonextx",
|
||||
fillcolor=fillcolor,
|
||||
line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape="spline"),
|
||||
hoverinfo="text",
|
||||
opacity=0.5,
|
||||
)
|
||||
|
||||
|
||||
def make_violin_rugplot(vals, pdf_max, distance, color="#1f77b4"):
|
||||
"""
|
||||
Returns a rugplot fig for a violin plot.
|
||||
"""
|
||||
return graph_objs.Scatter(
|
||||
y=vals,
|
||||
x=[-pdf_max - distance] * len(vals),
|
||||
marker=graph_objs.scatter.Marker(color=color, symbol="line-ew-open"),
|
||||
mode="markers",
|
||||
name="",
|
||||
showlegend=False,
|
||||
hoverinfo="y",
|
||||
)
|
||||
|
||||
|
||||
def make_non_outlier_interval(d1, d2):
|
||||
"""
|
||||
Returns the scatterplot fig of most of a violin plot.
|
||||
"""
|
||||
return graph_objs.Scatter(
|
||||
x=[0, 0],
|
||||
y=[d1, d2],
|
||||
name="",
|
||||
mode="lines",
|
||||
line=graph_objs.scatter.Line(width=1.5, color="rgb(0,0,0)"),
|
||||
)
|
||||
|
||||
|
||||
def make_quartiles(q1, q3):
|
||||
"""
|
||||
Makes the upper and lower quartiles for a violin plot.
|
||||
"""
|
||||
return graph_objs.Scatter(
|
||||
x=[0, 0],
|
||||
y=[q1, q3],
|
||||
text=[
|
||||
"lower-quartile: " + "{:0.2f}".format(q1),
|
||||
"upper-quartile: " + "{:0.2f}".format(q3),
|
||||
],
|
||||
mode="lines",
|
||||
line=graph_objs.scatter.Line(width=4, color="rgb(0,0,0)"),
|
||||
hoverinfo="text",
|
||||
)
|
||||
|
||||
|
||||
def make_median(q2):
|
||||
"""
|
||||
Formats the 'median' hovertext for a violin plot.
|
||||
"""
|
||||
return graph_objs.Scatter(
|
||||
x=[0],
|
||||
y=[q2],
|
||||
text=["median: " + "{:0.2f}".format(q2)],
|
||||
mode="markers",
|
||||
marker=dict(symbol="square", color="rgb(255,255,255)"),
|
||||
hoverinfo="text",
|
||||
)
|
||||
|
||||
|
||||
def make_XAxis(xaxis_title, xaxis_range):
|
||||
"""
|
||||
Makes the x-axis for a violin plot.
|
||||
"""
|
||||
xaxis = graph_objs.layout.XAxis(
|
||||
title=xaxis_title,
|
||||
range=xaxis_range,
|
||||
showgrid=False,
|
||||
zeroline=False,
|
||||
showline=False,
|
||||
mirror=False,
|
||||
ticks="",
|
||||
showticklabels=False,
|
||||
)
|
||||
return xaxis
|
||||
|
||||
|
||||
def make_YAxis(yaxis_title):
|
||||
"""
|
||||
Makes the y-axis for a violin plot.
|
||||
"""
|
||||
yaxis = graph_objs.layout.YAxis(
|
||||
title=yaxis_title,
|
||||
showticklabels=True,
|
||||
autorange=True,
|
||||
ticklen=4,
|
||||
showline=True,
|
||||
zeroline=False,
|
||||
showgrid=False,
|
||||
mirror=False,
|
||||
)
|
||||
return yaxis
|
||||
|
||||
|
||||
def violinplot(vals, fillcolor="#1f77b4", rugplot=True):
|
||||
"""
|
||||
Refer to FigureFactory.create_violin() for docstring.
|
||||
"""
|
||||
vals = np.asarray(vals, float)
|
||||
# summary statistics
|
||||
vals_min = calc_stats(vals)["min"]
|
||||
vals_max = calc_stats(vals)["max"]
|
||||
q1 = calc_stats(vals)["q1"]
|
||||
q2 = calc_stats(vals)["q2"]
|
||||
q3 = calc_stats(vals)["q3"]
|
||||
d1 = calc_stats(vals)["d1"]
|
||||
d2 = calc_stats(vals)["d2"]
|
||||
|
||||
# kernel density estimation of pdf
|
||||
pdf = scipy_stats.gaussian_kde(vals)
|
||||
# grid over the data interval
|
||||
xx = np.linspace(vals_min, vals_max, 100)
|
||||
# evaluate the pdf at the grid xx
|
||||
yy = pdf(xx)
|
||||
max_pdf = np.max(yy)
|
||||
# distance from the violin plot to rugplot
|
||||
distance = (2.0 * max_pdf) / 10 if rugplot else 0
|
||||
# range for x values in the plot
|
||||
plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
|
||||
plot_data = [
|
||||
make_half_violin(-yy, xx, fillcolor=fillcolor),
|
||||
make_half_violin(yy, xx, fillcolor=fillcolor),
|
||||
make_non_outlier_interval(d1, d2),
|
||||
make_quartiles(q1, q3),
|
||||
make_median(q2),
|
||||
]
|
||||
if rugplot:
|
||||
plot_data.append(
|
||||
make_violin_rugplot(vals, max_pdf, distance=distance, color=fillcolor)
|
||||
)
|
||||
return plot_data, plot_xrange
|
||||
|
||||
|
||||
def violin_no_colorscale(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
):
|
||||
"""
|
||||
Refer to FigureFactory.create_violin() for docstring.
|
||||
|
||||
Returns fig for violin plot without colorscale.
|
||||
|
||||
"""
|
||||
|
||||
# collect all group names
|
||||
group_name = []
|
||||
for name in data[group_header]:
|
||||
if name not in group_name:
|
||||
group_name.append(name)
|
||||
if sort:
|
||||
group_name.sort()
|
||||
|
||||
gb = data.groupby([group_header])
|
||||
L = len(group_name)
|
||||
|
||||
fig = make_subplots(
|
||||
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
|
||||
)
|
||||
color_index = 0
|
||||
for k, gr in enumerate(group_name):
|
||||
vals = np.asarray(gb.get_group(gr)[data_header], float)
|
||||
if color_index >= len(colors):
|
||||
color_index = 0
|
||||
plot_data, plot_xrange = violinplot(
|
||||
vals, fillcolor=colors[color_index], rugplot=rugplot
|
||||
)
|
||||
for item in plot_data:
|
||||
fig.append_trace(item, 1, k + 1)
|
||||
color_index += 1
|
||||
|
||||
# add violin plot labels
|
||||
fig["layout"].update(
|
||||
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
|
||||
)
|
||||
|
||||
# set the sharey axis style
|
||||
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
|
||||
fig["layout"].update(
|
||||
title=title,
|
||||
showlegend=False,
|
||||
hovermode="closest",
|
||||
autosize=False,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def violin_colorscale(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
):
|
||||
"""
|
||||
Refer to FigureFactory.create_violin() for docstring.
|
||||
|
||||
Returns fig for violin plot with colorscale.
|
||||
|
||||
"""
|
||||
|
||||
# collect all group names
|
||||
group_name = []
|
||||
for name in data[group_header]:
|
||||
if name not in group_name:
|
||||
group_name.append(name)
|
||||
if sort:
|
||||
group_name.sort()
|
||||
|
||||
# make sure all group names are keys in group_stats
|
||||
for group in group_name:
|
||||
if group not in group_stats:
|
||||
raise exceptions.PlotlyError(
|
||||
"All values/groups in the index "
|
||||
"column must be represented "
|
||||
"as a key in group_stats."
|
||||
)
|
||||
|
||||
gb = data.groupby([group_header])
|
||||
L = len(group_name)
|
||||
|
||||
fig = make_subplots(
|
||||
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
|
||||
)
|
||||
|
||||
# prepare low and high color for colorscale
|
||||
lowcolor = clrs.color_parser(colors[0], clrs.unlabel_rgb)
|
||||
highcolor = clrs.color_parser(colors[1], clrs.unlabel_rgb)
|
||||
|
||||
# find min and max values in group_stats
|
||||
group_stats_values = []
|
||||
for key in group_stats:
|
||||
group_stats_values.append(group_stats[key])
|
||||
|
||||
max_value = max(group_stats_values)
|
||||
min_value = min(group_stats_values)
|
||||
|
||||
for k, gr in enumerate(group_name):
|
||||
vals = np.asarray(gb.get_group(gr)[data_header], float)
|
||||
|
||||
# find intermediate color from colorscale
|
||||
intermed = (group_stats[gr] - min_value) / (max_value - min_value)
|
||||
intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)
|
||||
|
||||
plot_data, plot_xrange = violinplot(
|
||||
vals, fillcolor="rgb{}".format(intermed_color), rugplot=rugplot
|
||||
)
|
||||
for item in plot_data:
|
||||
fig.append_trace(item, 1, k + 1)
|
||||
fig["layout"].update(
|
||||
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
|
||||
)
|
||||
# add colorbar to plot
|
||||
trace_dummy = graph_objs.Scatter(
|
||||
x=[0],
|
||||
y=[0],
|
||||
mode="markers",
|
||||
marker=dict(
|
||||
size=2,
|
||||
cmin=min_value,
|
||||
cmax=max_value,
|
||||
colorscale=[[0, colors[0]], [1, colors[1]]],
|
||||
showscale=True,
|
||||
),
|
||||
showlegend=False,
|
||||
)
|
||||
fig.append_trace(trace_dummy, 1, L)
|
||||
|
||||
# set the sharey axis style
|
||||
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
|
||||
fig["layout"].update(
|
||||
title=title,
|
||||
showlegend=False,
|
||||
hovermode="closest",
|
||||
autosize=False,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def violin_dict(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
):
|
||||
"""
|
||||
Refer to FigureFactory.create_violin() for docstring.
|
||||
|
||||
Returns fig for violin plot without colorscale.
|
||||
|
||||
"""
|
||||
|
||||
# collect all group names
|
||||
group_name = []
|
||||
for name in data[group_header]:
|
||||
if name not in group_name:
|
||||
group_name.append(name)
|
||||
|
||||
if sort:
|
||||
group_name.sort()
|
||||
|
||||
# check if all group names appear in colors dict
|
||||
for group in group_name:
|
||||
if group not in colors:
|
||||
raise exceptions.PlotlyError(
|
||||
"If colors is a dictionary, all "
|
||||
"the group names must appear as "
|
||||
"keys in colors."
|
||||
)
|
||||
|
||||
gb = data.groupby([group_header])
|
||||
L = len(group_name)
|
||||
|
||||
fig = make_subplots(
|
||||
rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
|
||||
)
|
||||
|
||||
for k, gr in enumerate(group_name):
|
||||
vals = np.asarray(gb.get_group(gr)[data_header], float)
|
||||
plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], rugplot=rugplot)
|
||||
for item in plot_data:
|
||||
fig.append_trace(item, 1, k + 1)
|
||||
|
||||
# add violin plot labels
|
||||
fig["layout"].update(
|
||||
{"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
|
||||
)
|
||||
|
||||
# set the sharey axis style
|
||||
fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
|
||||
fig["layout"].update(
|
||||
title=title,
|
||||
showlegend=False,
|
||||
hovermode="closest",
|
||||
autosize=False,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_violin(
|
||||
data,
|
||||
data_header=None,
|
||||
group_header=None,
|
||||
colors=None,
|
||||
use_colorscale=False,
|
||||
group_stats=None,
|
||||
rugplot=True,
|
||||
sort=False,
|
||||
height=450,
|
||||
width=600,
|
||||
title="Violin and Rug Plot",
|
||||
):
|
||||
"""
|
||||
**deprecated**, use instead the plotly.graph_objects trace
|
||||
:class:`plotly.graph_objects.Violin`.
|
||||
|
||||
:param (list|array) data: accepts either a list of numerical values,
|
||||
a list of dictionaries all with identical keys and at least one
|
||||
column of numeric values, or a pandas dataframe with at least one
|
||||
column of numbers.
|
||||
:param (str) data_header: the header of the data column to be used
|
||||
from an inputted pandas dataframe. Not applicable if 'data' is
|
||||
a list of numeric values.
|
||||
:param (str) group_header: applicable if grouping data by a variable.
|
||||
'group_header' must be set to the name of the grouping variable.
|
||||
:param (str|tuple|list|dict) colors: either a plotly scale name,
|
||||
an rgb or hex color, a color tuple, a list of colors or a
|
||||
dictionary. An rgb color is of the form 'rgb(x, y, z)' where
|
||||
x, y and z belong to the interval [0, 255] and a color tuple is a
|
||||
tuple of the form (a, b, c) where a, b and c belong to [0, 1].
|
||||
If colors is a list, it must contain valid color types as its
|
||||
members.
|
||||
:param (bool) use_colorscale: only applicable if grouping by another
|
||||
variable. Will implement a colorscale based on the first 2 colors
|
||||
of param colors. This means colors must be a list with at least 2
|
||||
colors in it (Plotly colorscales are accepted since they map to a
|
||||
list of two rgb colors). Default = False
|
||||
:param (dict) group_stats: a dictionary where each key is a unique
|
||||
value from the group_header column in data. Each value must be a
|
||||
number and will be used to color the violin plots if a colorscale
|
||||
is being used.
|
||||
:param (bool) rugplot: determines if a rugplot is draw on violin plot.
|
||||
Default = True
|
||||
:param (bool) sort: determines if violins are sorted
|
||||
alphabetically (True) or by input order (False). Default = False
|
||||
:param (float) height: the height of the violin plot.
|
||||
:param (float) width: the width of the violin plot.
|
||||
:param (str) title: the title of the violin plot.
|
||||
|
||||
Example 1: Single Violin Plot
|
||||
|
||||
>>> from plotly.figure_factory import create_violin
|
||||
>>> import plotly.graph_objs as graph_objects
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from scipy import stats
|
||||
|
||||
>>> # create list of random values
|
||||
>>> data_list = np.random.randn(100)
|
||||
|
||||
>>> # create violin fig
|
||||
>>> fig = create_violin(data_list, colors='#604d9e')
|
||||
|
||||
>>> # plot
|
||||
>>> fig.show()
|
||||
|
||||
Example 2: Multiple Violin Plots with Qualitative Coloring
|
||||
|
||||
>>> from plotly.figure_factory import create_violin
|
||||
>>> import plotly.graph_objs as graph_objects
|
||||
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
>>> from scipy import stats
|
||||
|
||||
>>> # create dataframe
|
||||
>>> np.random.seed(619517)
|
||||
>>> Nr=250
|
||||
>>> y = np.random.randn(Nr)
|
||||
>>> gr = np.random.choice(list("ABCDE"), Nr)
|
||||
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
|
||||
|
||||
>>> for i, letter in enumerate("ABCDE"):
|
||||
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
|
||||
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
|
||||
|
||||
>>> # create violin fig
|
||||
>>> fig = create_violin(df, data_header='Score', group_header='Group',
|
||||
... sort=True, height=600, width=1000)
|
||||
|
||||
>>> # plot
|
||||
>>> fig.show()
|
||||
|
||||
Example 3: Violin Plots with Colorscale
|
||||
|
||||
>>> from plotly.figure_factory import create_violin
|
||||
>>> import plotly.graph_objs as graph_objects
|
||||
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
>>> from scipy import stats
|
||||
|
||||
>>> # create dataframe
|
||||
>>> np.random.seed(619517)
|
||||
>>> Nr=250
|
||||
>>> y = np.random.randn(Nr)
|
||||
>>> gr = np.random.choice(list("ABCDE"), Nr)
|
||||
>>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
|
||||
|
||||
>>> for i, letter in enumerate("ABCDE"):
|
||||
... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
|
||||
>>> df = pd.DataFrame(dict(Score=y, Group=gr))
|
||||
|
||||
>>> # define header params
|
||||
>>> data_header = 'Score'
|
||||
>>> group_header = 'Group'
|
||||
|
||||
>>> # make groupby object with pandas
|
||||
>>> group_stats = {}
|
||||
>>> groupby_data = df.groupby([group_header])
|
||||
|
||||
>>> for group in "ABCDE":
|
||||
... data_from_group = groupby_data.get_group(group)[data_header]
|
||||
... # take a stat of the grouped data
|
||||
... stat = np.median(data_from_group)
|
||||
... # add to dictionary
|
||||
... group_stats[group] = stat
|
||||
|
||||
>>> # create violin fig
|
||||
>>> fig = create_violin(df, data_header='Score', group_header='Group',
|
||||
... height=600, width=1000, use_colorscale=True,
|
||||
... group_stats=group_stats)
|
||||
|
||||
>>> # plot
|
||||
>>> fig.show()
|
||||
"""
|
||||
|
||||
# Validate colors
|
||||
if isinstance(colors, dict):
|
||||
valid_colors = clrs.validate_colors_dict(colors, "rgb")
|
||||
else:
|
||||
valid_colors = clrs.validate_colors(colors, "rgb")
|
||||
|
||||
# validate data and choose plot type
|
||||
if group_header is None:
|
||||
if isinstance(data, list):
|
||||
if len(data) <= 0:
|
||||
raise exceptions.PlotlyError(
|
||||
"If data is a list, it must be "
|
||||
"nonempty and contain either "
|
||||
"numbers or dictionaries."
|
||||
)
|
||||
|
||||
if not all(isinstance(element, Number) for element in data):
|
||||
raise exceptions.PlotlyError(
|
||||
"If data is a list, it must contain only numbers."
|
||||
)
|
||||
|
||||
if pd and isinstance(data, pd.core.frame.DataFrame):
|
||||
if data_header is None:
|
||||
raise exceptions.PlotlyError(
|
||||
"data_header must be the "
|
||||
"column name with the "
|
||||
"desired numeric data for "
|
||||
"the violin plot."
|
||||
)
|
||||
|
||||
data = data[data_header].values.tolist()
|
||||
|
||||
# call the plotting functions
|
||||
plot_data, plot_xrange = violinplot(
|
||||
data, fillcolor=valid_colors[0], rugplot=rugplot
|
||||
)
|
||||
|
||||
layout = graph_objs.Layout(
|
||||
title=title,
|
||||
autosize=False,
|
||||
font=graph_objs.layout.Font(size=11),
|
||||
height=height,
|
||||
showlegend=False,
|
||||
width=width,
|
||||
xaxis=make_XAxis("", plot_xrange),
|
||||
yaxis=make_YAxis(""),
|
||||
hovermode="closest",
|
||||
)
|
||||
layout["yaxis"].update(dict(showline=False, showticklabels=False, ticks=""))
|
||||
|
||||
fig = graph_objs.Figure(data=plot_data, layout=layout)
|
||||
|
||||
return fig
|
||||
|
||||
else:
|
||||
if not isinstance(data, pd.core.frame.DataFrame):
|
||||
raise exceptions.PlotlyError(
|
||||
"Error. You must use a pandas "
|
||||
"DataFrame if you are using a "
|
||||
"group header."
|
||||
)
|
||||
|
||||
if data_header is None:
|
||||
raise exceptions.PlotlyError(
|
||||
"data_header must be the column "
|
||||
"name with the desired numeric "
|
||||
"data for the violin plot."
|
||||
)
|
||||
|
||||
if use_colorscale is False:
|
||||
if isinstance(valid_colors, dict):
|
||||
# validate colors dict choice below
|
||||
fig = violin_dict(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
valid_colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
)
|
||||
return fig
|
||||
else:
|
||||
fig = violin_no_colorscale(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
valid_colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
)
|
||||
return fig
|
||||
else:
|
||||
if isinstance(valid_colors, dict):
|
||||
raise exceptions.PlotlyError(
|
||||
"The colors param cannot be "
|
||||
"a dictionary if you are "
|
||||
"using a colorscale."
|
||||
)
|
||||
|
||||
if len(valid_colors) < 2:
|
||||
raise exceptions.PlotlyError(
|
||||
"colors must be a list with "
|
||||
"at least 2 colors. A "
|
||||
"Plotly scale is allowed."
|
||||
)
|
||||
|
||||
if not isinstance(group_stats, dict):
|
||||
raise exceptions.PlotlyError(
|
||||
"Your group_stats param must be a dictionary."
|
||||
)
|
||||
|
||||
fig = violin_colorscale(
|
||||
data,
|
||||
data_header,
|
||||
group_header,
|
||||
valid_colors,
|
||||
use_colorscale,
|
||||
group_stats,
|
||||
rugplot,
|
||||
sort,
|
||||
height,
|
||||
width,
|
||||
title,
|
||||
)
|
||||
return fig
|
249
lib/python3.11/site-packages/plotly/figure_factory/utils.py
Normal file
249
lib/python3.11/site-packages/plotly/figure_factory/utils.py
Normal file
@ -0,0 +1,249 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from plotly import exceptions
|
||||
|
||||
|
||||
def is_sequence(obj):
|
||||
return isinstance(obj, Sequence) and not isinstance(obj, str)
|
||||
|
||||
|
||||
def validate_index(index_vals):
|
||||
"""
|
||||
Validates if a list contains all numbers or all strings
|
||||
|
||||
:raises: (PlotlyError) If there are any two items in the list whose
|
||||
types differ
|
||||
"""
|
||||
from numbers import Number
|
||||
|
||||
if isinstance(index_vals[0], Number):
|
||||
if not all(isinstance(item, Number) for item in index_vals):
|
||||
raise exceptions.PlotlyError(
|
||||
"Error in indexing column. "
|
||||
"Make sure all entries of each "
|
||||
"column are all numbers or "
|
||||
"all strings."
|
||||
)
|
||||
|
||||
elif isinstance(index_vals[0], str):
|
||||
if not all(isinstance(item, str) for item in index_vals):
|
||||
raise exceptions.PlotlyError(
|
||||
"Error in indexing column. "
|
||||
"Make sure all entries of each "
|
||||
"column are all numbers or "
|
||||
"all strings."
|
||||
)
|
||||
|
||||
|
||||
def validate_dataframe(array):
|
||||
"""
|
||||
Validates all strings or numbers in each dataframe column
|
||||
|
||||
:raises: (PlotlyError) If there are any two items in any list whose
|
||||
types differ
|
||||
"""
|
||||
from numbers import Number
|
||||
|
||||
for vector in array:
|
||||
if isinstance(vector[0], Number):
|
||||
if not all(isinstance(item, Number) for item in vector):
|
||||
raise exceptions.PlotlyError(
|
||||
"Error in dataframe. "
|
||||
"Make sure all entries of "
|
||||
"each column are either "
|
||||
"numbers or strings."
|
||||
)
|
||||
elif isinstance(vector[0], str):
|
||||
if not all(isinstance(item, str) for item in vector):
|
||||
raise exceptions.PlotlyError(
|
||||
"Error in dataframe. "
|
||||
"Make sure all entries of "
|
||||
"each column are either "
|
||||
"numbers or strings."
|
||||
)
|
||||
|
||||
|
||||
def validate_equal_length(*args):
|
||||
"""
|
||||
Validates that data lists or ndarrays are the same length.
|
||||
|
||||
:raises: (PlotlyError) If any data lists are not the same length.
|
||||
"""
|
||||
length = len(args[0])
|
||||
if any(len(lst) != length for lst in args):
|
||||
raise exceptions.PlotlyError(
|
||||
"Oops! Your data lists or ndarrays should be the same length."
|
||||
)
|
||||
|
||||
|
||||
def validate_positive_scalars(**kwargs):
|
||||
"""
|
||||
Validates that all values given in key/val pairs are positive.
|
||||
|
||||
Accepts kwargs to improve Exception messages.
|
||||
|
||||
:raises: (PlotlyError) If any value is < 0 or raises.
|
||||
"""
|
||||
for key, val in kwargs.items():
|
||||
try:
|
||||
if val <= 0:
|
||||
raise ValueError("{} must be > 0, got {}".format(key, val))
|
||||
except TypeError:
|
||||
raise exceptions.PlotlyError("{} must be a number, got {}".format(key, val))
|
||||
|
||||
|
||||
def flatten(array):
|
||||
"""
|
||||
Uses list comprehension to flatten array
|
||||
|
||||
:param (array): An iterable to flatten
|
||||
:raises (PlotlyError): If iterable is not nested.
|
||||
:rtype (list): The flattened list.
|
||||
"""
|
||||
try:
|
||||
return [item for sublist in array for item in sublist]
|
||||
except TypeError:
|
||||
raise exceptions.PlotlyError(
|
||||
"Your data array could not be "
|
||||
"flattened! Make sure your data is "
|
||||
"entered as lists or ndarrays!"
|
||||
)
|
||||
|
||||
|
||||
def endpts_to_intervals(endpts):
|
||||
"""
|
||||
Returns a list of intervals for categorical colormaps
|
||||
|
||||
Accepts a list or tuple of sequentially increasing numbers and returns
|
||||
a list representation of the mathematical intervals with these numbers
|
||||
as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
|
||||
|
||||
:raises: (PlotlyError) If input is not a list or tuple
|
||||
:raises: (PlotlyError) If the input contains a string
|
||||
:raises: (PlotlyError) If any number does not increase after the
|
||||
previous one in the sequence
|
||||
"""
|
||||
length = len(endpts)
|
||||
# Check if endpts is a list or tuple
|
||||
if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
|
||||
raise exceptions.PlotlyError(
|
||||
"The intervals_endpts argument must "
|
||||
"be a list or tuple of a sequence "
|
||||
"of increasing numbers."
|
||||
)
|
||||
# Check if endpts contains only numbers
|
||||
for item in endpts:
|
||||
if isinstance(item, str):
|
||||
raise exceptions.PlotlyError(
|
||||
"The intervals_endpts argument "
|
||||
"must be a list or tuple of a "
|
||||
"sequence of increasing "
|
||||
"numbers."
|
||||
)
|
||||
# Check if numbers in endpts are increasing
|
||||
for k in range(length - 1):
|
||||
if endpts[k] >= endpts[k + 1]:
|
||||
raise exceptions.PlotlyError(
|
||||
"The intervals_endpts argument "
|
||||
"must be a list or tuple of a "
|
||||
"sequence of increasing "
|
||||
"numbers."
|
||||
)
|
||||
else:
|
||||
intervals = []
|
||||
# add -inf to intervals
|
||||
intervals.append([float("-inf"), endpts[0]])
|
||||
for k in range(length - 1):
|
||||
interval = []
|
||||
interval.append(endpts[k])
|
||||
interval.append(endpts[k + 1])
|
||||
intervals.append(interval)
|
||||
# add +inf to intervals
|
||||
intervals.append([endpts[length - 1], float("inf")])
|
||||
return intervals
|
||||
|
||||
|
||||
def annotation_dict_for_label(
|
||||
text,
|
||||
lane,
|
||||
num_of_lanes,
|
||||
subplot_spacing,
|
||||
row_col="col",
|
||||
flipped=True,
|
||||
right_side=True,
|
||||
text_color="#0f0f0f",
|
||||
):
|
||||
"""
|
||||
Returns annotation dict for label of n labels of a 1xn or nx1 subplot.
|
||||
|
||||
:param (str) text: the text for a label.
|
||||
:param (int) lane: the label number for text. From 1 to n inclusive.
|
||||
:param (int) num_of_lanes: the number 'n' of rows or columns in subplot.
|
||||
:param (float) subplot_spacing: the value for the horizontal_spacing and
|
||||
vertical_spacing params in your plotly.tools.make_subplots() call.
|
||||
:param (str) row_col: choose whether labels are placed along rows or
|
||||
columns.
|
||||
:param (bool) flipped: flips text by 90 degrees. Text is printed
|
||||
horizontally if set to True and row_col='row', or if False and
|
||||
row_col='col'.
|
||||
:param (bool) right_side: only applicable if row_col is set to 'row'.
|
||||
:param (str) text_color: color of the text.
|
||||
"""
|
||||
temp = (1 - (num_of_lanes - 1) * subplot_spacing) / (num_of_lanes)
|
||||
if not flipped:
|
||||
xanchor = "center"
|
||||
yanchor = "middle"
|
||||
if row_col == "col":
|
||||
x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
|
||||
y = 1.03
|
||||
textangle = 0
|
||||
elif row_col == "row":
|
||||
y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
|
||||
x = 1.03
|
||||
textangle = 90
|
||||
else:
|
||||
if row_col == "col":
|
||||
xanchor = "center"
|
||||
yanchor = "bottom"
|
||||
x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
|
||||
y = 1.0
|
||||
textangle = 270
|
||||
elif row_col == "row":
|
||||
yanchor = "middle"
|
||||
y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
|
||||
if right_side:
|
||||
x = 1.0
|
||||
xanchor = "left"
|
||||
else:
|
||||
x = -0.01
|
||||
xanchor = "right"
|
||||
textangle = 0
|
||||
|
||||
annotation_dict = dict(
|
||||
textangle=textangle,
|
||||
xanchor=xanchor,
|
||||
yanchor=yanchor,
|
||||
x=x,
|
||||
y=y,
|
||||
showarrow=False,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
text=text,
|
||||
font=dict(size=13, color=text_color),
|
||||
)
|
||||
return annotation_dict
|
||||
|
||||
|
||||
def list_of_options(iterable, conj="and", period=True):
|
||||
"""
|
||||
Returns an English listing of objects seperated by commas ','
|
||||
|
||||
For example, ['foo', 'bar', 'baz'] becomes 'foo, bar and baz'
|
||||
if the conjunction 'and' is selected.
|
||||
"""
|
||||
if len(iterable) < 2:
|
||||
raise exceptions.PlotlyError(
|
||||
"Your list or tuple must contain at least 2 items."
|
||||
)
|
||||
template = (len(iterable) - 2) * "{}, " + "{} " + conj + " {}" + period * "."
|
||||
return template.format(*iterable)
|
Reference in New Issue
Block a user