Skip to content

Commit 381a816

Browse files
authored
Add color option to plot_points (#357)
* Add color option to `plot_points` * Make plot_points more accessible in api * Fix formatting
1 parent 0e135b6 commit 381a816

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

src/gemdat/plots/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
jumps_vs_time,
2222
msd_per_element,
2323
plot_3d,
24+
plot_3d_points,
2425
radial_distribution,
2526
rectilinear,
2627
shape,
@@ -44,6 +45,7 @@
4445
'msd_per_element',
4546
'msd_per_element',
4647
'plot_3d',
48+
'plot_3d_points',
4749
'polar',
4850
'radial_distribution',
4951
'rectilinear',

src/gemdat/plots/plotly/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ._jumps_vs_distance import jumps_vs_distance
1616
from ._jumps_vs_time import jumps_vs_time
1717
from ._msd_per_element import msd_per_element
18-
from ._plot3d import plot_3d
18+
from ._plot3d import plot_3d, plot_3d_points
1919
from ._polar import polar
2020
from ._radial_distribution import radial_distribution
2121
from ._rectilinear import rectilinear
@@ -37,6 +37,7 @@
3737
'jumps_vs_time',
3838
'msd_per_element',
3939
'plot_3d',
40+
'plot_3d_points',
4041
'polar',
4142
'radial_distribution',
4243
'rectilinear',

src/gemdat/plots/plotly/_plot3d.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Collection
4-
from typing import TYPE_CHECKING, Sequence
4+
from typing import TYPE_CHECKING, Optional, Sequence
55

66
import numpy as np
77
import plotly.express as px
@@ -65,7 +65,14 @@ def plot_lattice_vectors(lattice: Lattice, *, fig: go.Figure):
6565
)
6666

6767

68-
def plot_points(points: np.ndarray, labels: Sequence, *, fig: go.Figure, point_size: int = 5):
68+
def plot_3d_points(
69+
points: np.ndarray,
70+
labels: Sequence,
71+
*,
72+
fig: go.Figure,
73+
point_size: int = 5,
74+
colors: Optional[dict[str, str]] = None,
75+
):
6976
"""Plot points using plotly.
7077
7178
Parameters
@@ -78,13 +85,18 @@ def plot_points(points: np.ndarray, labels: Sequence, *, fig: go.Figure, point_s
7885
Plotly figure to add traces to
7986
point_size : int, optional
8087
Size of the points
88+
colors: dict, optional
89+
Mapping of colors for the each label.
90+
See the following link for a list of accepted colours:
91+
https://developer.mozilla.org/en-US/docs/Web/CSS/named-color
8192
"""
8293
assert len(points) == len(labels)
8394

84-
colors = {
85-
label: px.colors.sample_colorscale('rainbow', [i / (len(labels) - 1)])
86-
for i, label in enumerate(labels)
87-
}
95+
if not colors:
96+
colors = {
97+
label: px.colors.sample_colorscale('rainbow', [i / (len(labels) - 1)])
98+
for i, label in enumerate(labels)
99+
}
88100

89101
for i, (x, y, z) in enumerate(points):
90102
label = labels[i]
@@ -120,7 +132,7 @@ def plot_structure(structure: Structure, *, lattice: Lattice | None = None, fig:
120132
else:
121133
cart_coords = structure.cart_coords
122134

123-
plot_points(cart_coords, labels=structure.labels, fig=fig)
135+
plot_3d_points(cart_coords, labels=structure.labels, fig=fig)
124136
plot_lattice_vectors(structure.lattice, fig=fig)
125137

126138

0 commit comments

Comments
 (0)