Skip to content

Commit f608812

Browse files
allow addplot label to be a list when data is a dataframe
1 parent ca3054e commit f608812

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,14 @@ def _check_and_prepare_data(data, config):
9797

9898

9999
def _label_validator(label_value):
100-
''' Validates the input of label for the added plots.
101-
label_value may be a str or a list of str.
100+
''' Validates the input of [legend] label for added plots.
101+
label_value may be a str or a sequence of str.
102102
'''
103103
if isinstance(label_value,str):
104104
return True
105+
if isinstance(label_value,(list,tuple,np.ndarray)):
106+
if all([isinstance(v,str) for v in label_value]):
107+
return True
105108
return False
106109

107110

src/mplfinance/plotting.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -777,12 +777,14 @@ def plot( data, **kwargs ):
777777
else:
778778
havedf = False # must be a single series or array
779779
apdata = [apdata,] # make it iterable
780+
colcount = 0
780781
for column in apdata:
781782
ydata = apdata.loc[:,column] if havedf else column
782-
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
783+
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount)
783784
_addplot_apply_supplements(ax,apdict,xdates)
784-
if apdict["label"]: # not supported for aptype == 'ohlc' or 'candle'
785-
contains_legend_label.append(ax)
785+
colcount += 1
786+
if apdict["label"]: # not supported for aptype == 'ohlc' or 'candle'
787+
contains_legend_label.append(ax)
786788
for ax in set(contains_legend_label): # there might be duplicates
787789
ax.legend()
788790

@@ -1072,7 +1074,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
10721074
ax.autoscale_view()
10731075
return ax
10741076

1075-
def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
1077+
def _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount):
10761078
external_axes_mode = apdict['ax'] is not None
10771079
if not external_axes_mode:
10781080
secondary_y = False
@@ -1094,7 +1096,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
10941096
ax = apdict['ax']
10951097

10961098
aptype = apdict['type']
1097-
label = apdict['label']
1099+
if isinstance(apdict['label'],(list,tuple,np.ndarray)):
1100+
label = apdict['label'][colcount]
1101+
else: # isinstance(...,str)
1102+
label = apdict['label']
10981103
if aptype == 'scatter':
10991104
size = apdict['markersize']
11001105
mark = apdict['marker']

0 commit comments

Comments
 (0)