Skip to content

Commit b14c30b

Browse files
add new api for labels
1 parent 3a9ce92 commit b14c30b

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,18 @@ def _check_and_prepare_data(data, config):
9595
return dates, opens, highs, lows, closes, volumes
9696

9797
def _label_validator(label_value):
98-
''' Validates the input of labels for the added plots.
98+
''' Validates the input of label for the added plots.
9999
label_value may be a str or a list of str.
100100
'''
101101
if isinstance(label_value,str):
102102
return True
103-
elif not isinstance(label_value,(tuple,list)):
104-
return False
105-
for label in label_value:
106-
if not isinstance(label,str):
107-
return False
108-
return True
103+
return False
104+
# elif not isinstance(label_value,(tuple,list)):
105+
# return False
106+
# for label in label_value:
107+
# if not isinstance(label,str):
108+
# return False
109+
# return True
109110

110111
def _get_valid_plot_types(plottype=None):
111112

src/mplfinance/plotting.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,8 @@ def plot( data, **kwargs ):
752752

753753
elif not _list_of_dict(addplot):
754754
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))
755+
756+
contains_legend_label=[] # a list of axes that contains legend labels
755757

756758
for apdict in addplot:
757759

@@ -779,6 +781,10 @@ def plot( data, **kwargs ):
779781
ydata = apdata.loc[:,column] if havedf else column
780782
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
781783
_addplot_apply_supplements(ax,apdict,xdates)
784+
if apdict["label"]: # not supported for aptype == 'ohlc' or 'candle'
785+
contains_legend_label.append(ax)
786+
for ax in set(contains_legend_label): # there will be duplicates,
787+
ax.legend() # but its ok to call ax.legend() multiple times
782788

783789
# fill_between is NOT supported for external_axes_mode
784790
# (caller can easily call ax.fill_between() themselves).
@@ -1088,46 +1094,38 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
10881094
ax = apdict['ax']
10891095

10901096
aptype = apdict['type']
1097+
label = apdict['label']
10911098
if aptype == 'scatter':
10921099
size = apdict['markersize']
10931100
mark = apdict['marker']
10941101
color = apdict['color']
10951102
alpha = apdict['alpha']
1096-
labels = apdict['labels']
10971103
edgecolors = apdict['edgecolors']
10981104
linewidths = apdict['linewidths']
10991105

11001106
if isinstance(mark,(list,tuple,np.ndarray)):
11011107
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1102-
else:
1103-
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1104-
if labels is not None:
1105-
ax.legend(labels=labels)
1108+
else:
1109+
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
11061110
elif aptype == 'bar':
11071111
width = 0.8 if apdict['width'] is None else apdict['width']
11081112
bottom = apdict['bottom']
11091113
color = apdict['color']
11101114
alpha = apdict['alpha']
1111-
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
1112-
if apdict['labels'] is not None:
1113-
ax.legend(labels=apdict['labels'])
1115+
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
11141116
elif aptype == 'line':
11151117
ls = apdict['linestyle']
11161118
color = apdict['color']
11171119
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11181120
alpha = apdict['alpha']
1119-
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1120-
if apdict['labels'] is not None:
1121-
ax.legend(labels=apdict['labels'])
1121+
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11221122
elif aptype == 'step':
11231123
stepwhere = apdict['stepwhere']
11241124
ls = apdict['linestyle']
11251125
color = apdict['color']
11261126
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11271127
alpha = apdict['alpha']
1128-
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1129-
if apdict['labels'] is not None:
1130-
ax.legend(labels=apdict['labels'])
1128+
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11311129
else:
11321130
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
11331131

@@ -1380,8 +1378,8 @@ def _valid_addplot_kwargs():
13801378
'fill_between': { 'Default' : None, # added by Wen
13811379
'Description' : " fill region",
13821380
'Validator' : _fill_between_validator },
1383-
"labels" : { 'Default' : None,
1384-
'Description' : 'Labels for the added plot.',
1381+
"label" : { 'Default' : None,
1382+
'Description' : 'Label for the added plot. One per added plot.',
13851383
'Validator' : _label_validator },
13861384

13871385
}

0 commit comments

Comments
 (0)