Skip to content

Commit 13b2d8b

Browse files
migrate code from @ppseverin
1 parent 2710cf4 commit 13b2d8b

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def _check_and_prepare_data(data, config):
9494

9595
return dates, opens, highs, lows, closes, volumes
9696

97+
def _label_validator(label_value):
98+
if isinstance(label_value,str):
99+
return True
100+
elif not isinstance(label_value,(tuple,list)):
101+
return False
102+
for label in label_value:
103+
if not isinstance(label,str):
104+
return False
105+
return True
106+
97107
def _get_valid_plot_types(plottype=None):
98108

99109
_alias_types = {

src/mplfinance/plotting.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from mplfinance import _styles
3434

35-
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
35+
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
3636
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
3737
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
3838
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
@@ -1093,32 +1093,41 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
10931093
mark = apdict['marker']
10941094
color = apdict['color']
10951095
alpha = apdict['alpha']
1096+
labels = apdict['labels']
10961097
edgecolors = apdict['edgecolors']
10971098
linewidths = apdict['linewidths']
10981099

10991100
if isinstance(mark,(list,tuple,np.ndarray)):
11001101
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
11011102
else:
11021103
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1104+
if labels is not None:
1105+
ax.legend(labels=labels)
11031106
elif aptype == 'bar':
11041107
width = 0.8 if apdict['width'] is None else apdict['width']
11051108
bottom = apdict['bottom']
11061109
color = apdict['color']
11071110
alpha = apdict['alpha']
11081111
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
1112+
if apdict['labels'] is not None:
1113+
ax.legend(labels=apdict['labels'])
11091114
elif aptype == 'line':
11101115
ls = apdict['linestyle']
11111116
color = apdict['color']
11121117
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11131118
alpha = apdict['alpha']
11141119
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1120+
if apdict['labels'] is not None:
1121+
ax.legend(labels=apdict['labels'])
11151122
elif aptype == 'step':
11161123
stepwhere = apdict['stepwhere']
11171124
ls = apdict['linestyle']
11181125
color = apdict['color']
11191126
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11201127
alpha = apdict['alpha']
11211128
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1129+
if apdict['labels'] is not None:
1130+
ax.legend(labels=apdict['labels'])
11221131
else:
11231132
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
11241133

@@ -1371,6 +1380,9 @@ def _valid_addplot_kwargs():
13711380
'fill_between': { 'Default' : None, # added by Wen
13721381
'Description' : " fill region",
13731382
'Validator' : _fill_between_validator },
1383+
"labels" : { 'Default' : None,
1384+
'Description' : 'Labels for the added plot.',
1385+
'Validator' : _label_validator },
13741386

13751387
}
13761388

0 commit comments

Comments
 (0)