Skip to content

Commit 48da00a

Browse files
committed
add grid parameter to scatter
1 parent 2aa3416 commit 48da00a

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

pca/pca.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,19 @@ class pca:
7676
* Documentation: https://erdogant.github.io/pca/
7777
7878
"""
79+
def __init__(self,
80+
n_components=0.95,
81+
n_feat=25,
82+
method='pca',
83+
alpha=0.05,
84+
multipletests='fdr_bh',
85+
n_std=3,
86+
onehot=False,
87+
normalize=False,
88+
detect_outliers=['ht2', 'spe'],
89+
random_state=None,
90+
verbose=3):
7991

80-
def __init__(self, n_components=0.95, n_feat=25, method='pca', alpha=0.05, multipletests='fdr_bh', n_std=3, onehot=False, normalize=False, detect_outliers=['ht2', 'spe'], random_state=None, verbose=3):
8192
"""Initialize pca with user-defined parameters."""
8293
if isinstance(detect_outliers, str): detect_outliers = [detect_outliers]
8394
if detect_outliers is not None: detect_outliers=list(map(str.lower, detect_outliers))
@@ -543,6 +554,7 @@ def scatter(self,
543554
visible=True,
544555
fig=None,
545556
ax=None,
557+
grid=True,
546558
y=None, # deprecated
547559
label=None, # deprecated
548560
verbose=3):
@@ -658,13 +670,15 @@ def scatter(self,
658670
marker=marker,
659671
jitter=jitter,
660672
density=density,
673+
opaque_type='per_class',
661674
density_on_top=density_on_top,
662675
gradient=gradient,
663676
cmap=cmap,
664677
legend=legend,
665678
fontcolor=fontcolor,
666679
fontsize=fontsize,
667680
fontweight=fontweight,
681+
grid=grid,
668682
dpi=dpi,
669683
figsize=figsize,
670684
visible=visible,
@@ -708,6 +722,7 @@ def biplot(self,
708722
fig=None,
709723
ax=None,
710724
dpi=100,
725+
grid=True,
711726
y=None, # deprecated
712727
label=None, # deprecated
713728
verbose=None):
@@ -839,7 +854,7 @@ def biplot(self,
839854
# Pre-processing
840855
labels, topfeat, n_feat = self._fig_preprocessing(labels, n_feat, d3)
841856
# Scatterplot
842-
fig, ax = self.scatter(labels=labels, legend=legend, PC=PC, SPE=SPE, HT2=HT2, cmap=cmap, visible=visible, figsize=figsize, alpha=alpha, title=title, gradient=gradient, fig=fig, ax=ax, c=c, s=s, jitter=jitter, marker=marker, fontcolor=fontcolor, fontweight=fontweight, fontsize=fontsize, edgecolor=edgecolor, density=density, density_on_top=density_on_top, dpi=dpi, verbose=verbose)
857+
fig, ax = self.scatter(labels=labels, legend=legend, PC=PC, SPE=SPE, HT2=HT2, cmap=cmap, visible=visible, figsize=figsize, alpha=alpha, title=title, gradient=gradient, fig=fig, ax=ax, c=c, s=s, jitter=jitter, marker=marker, fontcolor=fontcolor, fontweight=fontweight, fontsize=fontsize, edgecolor=edgecolor, density=density, density_on_top=density_on_top, dpi=dpi, grid=grid, verbose=verbose)
843858
# Add the loadings with arrow to the plot
844859
fig, ax = _plot_loadings(self, topfeat, n_feat, PC, d3, arrowdict, fig, ax, verbose)
845860
# Plot

0 commit comments

Comments
 (0)