Skip to content

Commit a053564

Browse files
authored
Merge pull request #612 from Sichao25/pred
Debug dynast.py
2 parents c99d0c3 + 152a9db commit a053564

File tree

1 file changed

+49
-43
lines changed

1 file changed

+49
-43
lines changed

dynamo/preprocessing/dynast.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@ def lambda_correction(
1919
"""Use lambda (cell-wise detection rate) to estimate the labelled RNA.
2020
2121
Args:
22-
adata: an adata object generated from dynast.
23-
lambda_key: the key to the cell-wise detection rate. Defaults to "lambda".
24-
inplace: whether to inplace update the layers. If False, new layers that append '_corrected" to the existing
22+
adata: An adata object generated from dynast.
23+
lambda_key: The key to the cell-wise detection rate. Defaults to "lambda".
24+
inplace: Whether to inplace update the layers. If False, new layers that append '_corrected" to the existing
2525
will be used to store the updated data. Defaults to True.
26-
copy: whether to copy the adata object or update adata object inplace. Defaults to False.
26+
copy: Whether to copy the adata object or update adata object inplace. Defaults to False.
2727
2828
Raises:
29-
ValueError: the `lambda_key` cannot be found in `adata.obs`
29+
ValueError: The `lambda_key` cannot be found in `adata.obs`.
30+
ValueError: The adata object has to include labeling layers.
3031
ValueError: `data_type` is set to 'splicing_labeling' but the existing layers in the adata object don't meet the
3132
requirements.
3233
ValueError: `data_type` is set to 'labeling' but the existing layers in the adata object don't meet the
3334
requirements.
35+
3436
Returns:
3537
A new AnnData object that are updated with lambda corrected layers if `copy` is true. Otherwise, return None.
3638
"""
@@ -54,18 +56,22 @@ def lambda_correction(
5456
logger.info("identify the data type..", indent_level=1)
5557
all_layers = adata.layers.keys()
5658

57-
has_ul = np.any([i.contains("ul_") for i in all_layers])
58-
has_un = np.any([i.contains("un_") for i in all_layers])
59-
has_sl = np.any([i.contains("sl_") for i in all_layers])
60-
has_sn = np.any([i.contains("sn_") for i in all_layers])
59+
has_ul = np.any(["ul_" in i for i in all_layers])
60+
has_un = np.any(["un_" in i for i in all_layers])
61+
has_sl = np.any(["sl_" in i for i in all_layers])
62+
has_sn = np.any(["sn_" in i for i in all_layers])
6163

62-
has_l = np.any([i.contains("_l_") for i in all_layers])
63-
has_n = np.any([i.contains("_n_") for i in all_layers])
64+
has_l = np.any(["_l_" in i for i in all_layers])
65+
has_n = np.any(["_n_" in i for i in all_layers])
6466

65-
if sum(has_ul + has_un + has_sl + has_sn) == 4:
67+
if np.count_nonzero([has_ul, has_un, has_sl, has_sn]) == 4:
6668
datatype = "splicing_labeling"
67-
elif sum(has_l + has_n):
69+
elif np.count_nonzero([has_l, has_n]):
6870
datatype = "labeling"
71+
else:
72+
raise ValueError(
73+
"the adata object has to include labeling layers."
74+
)
6975

7076
logger.info(f"the data type identified is {datatype}", indent_level=2)
7177

@@ -74,44 +80,44 @@ def lambda_correction(
7480
layers, match_tot_layer = [], []
7581
for layer in all_layers:
7682
if "ul_" in layer:
77-
layers += layer
78-
match_tot_layer += "unspliced"
83+
layers.append(layer)
84+
match_tot_layer.append("unspliced")
7985
elif "un_" in layer:
80-
layers += layer
81-
match_tot_layer += "unspliced"
86+
layers.append(layer)
87+
match_tot_layer.append("unspliced")
8288
elif "sl_" in layer:
83-
layers += layer
84-
match_tot_layer += "spliced"
89+
layers.append(layer)
90+
match_tot_layer.append("spliced")
8591
elif "sn_" in layer:
86-
layers += layer
87-
match_tot_layer += "spliced"
92+
layers.append(layer)
93+
match_tot_layer.append("spliced")
8894
elif "spliced" in layer:
89-
layers += layer
95+
layers.append(layer)
9096
elif "unspliced" in layer:
91-
layers += layer
97+
layers.append(layer)
9298

93-
if len(layers) != 6:
94-
raise ValueError(
95-
"the adata object has to include ul, un, sl, sn, unspliced, spliced, "
96-
"six relevant layers for splicing and labeling quantified datasets."
97-
)
99+
if len(layers) != 6:
100+
raise ValueError(
101+
"the adata object has to include ul, un, sl, sn, unspliced, spliced, "
102+
"six relevant layers for splicing and labeling quantified datasets."
103+
)
98104
elif datatype == "labeling":
99105
layers, match_tot_layer = [], []
100106
for layer in all_layers:
101107
if "_l_" in layer:
102-
layers += layer
103-
match_tot_layer += ["total"]
108+
layers.append(layer)
109+
match_tot_layer.append("total")
104110
elif "_n_" in layer:
105-
layers += layer
106-
match_tot_layer += ["total"]
111+
layers.append(layer)
112+
match_tot_layer.append("total")
107113
elif "total" in layer:
108-
layers += layer
114+
layers.append(layer)
109115

110-
if len(layers) != 3:
111-
raise ValueError(
112-
"the adata object has to include labeled, unlabeled, three relevant layers for labeling quantified "
113-
"datasets."
114-
)
116+
if len(layers) != 3:
117+
raise ValueError(
118+
"the adata object has to include labeled, unlabeled, three relevant layers for labeling quantified "
119+
"datasets."
120+
)
115121

116122
logger.info("detection rate correction starts", indent_level=1)
117123
for i, layer in enumerate(main_tqdm(layers, desc="iterating all relevant layers")):
@@ -133,9 +139,9 @@ def lambda_correction(
133139

134140
else:
135141
if inplace:
136-
adata.layers[layer] = cur_total - adata.layers[layer[i - 1]]
142+
adata.layers[layer] = cur_total - adata.layers[layers[i - 1]]
137143
else:
138-
adata.layers[layer + "_corrected"] = cur_total - adata.layers[layer[i - 1]]
144+
adata.layers[layer + "_corrected"] = cur_total - adata.layers[layers[i - 1]]
139145

140146
logger.finish_progress(progress_name="lambda_correction")
141147

@@ -148,15 +154,15 @@ def sparse_mimmax(A: csr_matrix, B: csr_matrix, type="min") -> csr_matrix:
148154
"""Return the element-wise minimum/maximum of sparse matrices `A` and `B`.
149155
150156
Args:
151-
A: The first sparse matrix
152-
B: The second sparse matrix
157+
A: The first sparse matrix.
158+
B: The second sparse matrix.
153159
type: The type of calculation, either "min" or "max". Defaults to "min".
154160
155161
Returns:
156162
A sparse matrix that contain the element-wise maximal or minimal of two sparse matrices.
157163
"""
158164

159165
AgtB = (A < B).astype(int) if type == "min" else (A > B).astype(int)
160-
M = AgtB.multiply(A - B) + B
166+
M = np.multiply(AgtB, A - B) + B
161167

162168
return M

0 commit comments

Comments
 (0)