1
1
from __future__ import annotations
2
2
3
3
from collections .abc import Collection
4
- from typing import TYPE_CHECKING , Sequence
4
+ from typing import TYPE_CHECKING , Optional , Sequence
5
5
6
6
import numpy as np
7
7
import plotly .express as px
@@ -65,7 +65,14 @@ def plot_lattice_vectors(lattice: Lattice, *, fig: go.Figure):
65
65
)
66
66
67
67
68
- def plot_points (points : np .ndarray , labels : Sequence , * , fig : go .Figure , point_size : int = 5 ):
68
+ def plot_3d_points (
69
+ points : np .ndarray ,
70
+ labels : Sequence ,
71
+ * ,
72
+ fig : go .Figure ,
73
+ point_size : int = 5 ,
74
+ colors : Optional [dict [str , str ]] = None ,
75
+ ):
69
76
"""Plot points using plotly.
70
77
71
78
Parameters
@@ -78,13 +85,18 @@ def plot_points(points: np.ndarray, labels: Sequence, *, fig: go.Figure, point_s
78
85
Plotly figure to add traces to
79
86
point_size : int, optional
80
87
Size of the points
88
+ colors: dict, optional
89
+ Mapping of colors for the each label.
90
+ See the following link for a list of accepted colours:
91
+ https://developer.mozilla.org/en-US/docs/Web/CSS/named-color
81
92
"""
82
93
assert len (points ) == len (labels )
83
94
84
- colors = {
85
- label : px .colors .sample_colorscale ('rainbow' , [i / (len (labels ) - 1 )])
86
- for i , label in enumerate (labels )
87
- }
95
+ if not colors :
96
+ colors = {
97
+ label : px .colors .sample_colorscale ('rainbow' , [i / (len (labels ) - 1 )])
98
+ for i , label in enumerate (labels )
99
+ }
88
100
89
101
for i , (x , y , z ) in enumerate (points ):
90
102
label = labels [i ]
@@ -120,7 +132,7 @@ def plot_structure(structure: Structure, *, lattice: Lattice | None = None, fig:
120
132
else :
121
133
cart_coords = structure .cart_coords
122
134
123
- plot_points (cart_coords , labels = structure .labels , fig = fig )
135
+ plot_3d_points (cart_coords , labels = structure .labels , fig = fig )
124
136
plot_lattice_vectors (structure .lattice , fig = fig )
125
137
126
138
0 commit comments