Description
Expected Behavior:
When the GraphST class is initialized with an adata object that has not been preprocessed, the preprocessing step should select the top dim_input highly variable genes, so that the feature matrix matches the model's expected input dimensionality.
Actual Behavior:
The preprocess function hardcodes the n_top_genes parameter to 3000, regardless of the dim_input value provided. This results in the model being initialized with a dim_input that doesn't correspond to the number of highly variable genes selected, leading to a mismatch in dimensions.
Relevant Code Snippet:
def preprocess(adata):
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.scale(adata, zero_center=False, max_value=10)
When I tried to set dim_input=1500:
model = GraphST.GraphST(adata, datatype='Stereo', device=torch.device('cuda:0'), dim_input=1500)
I still got 3000 HVGs, and model.dim_input was reported as 3000:
model.dim_input
3000
model.adata.obsm['feat'].shape
(21103, 3000)
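To make this kind of mismatch fail loudly instead of silently, a small guard could compare the requested dimension with the width of the feature matrix. This is only a sketch: check_feature_dim is a hypothetical helper I made up for illustration and is not part of GraphST.

```python
def check_feature_dim(feat_shape, dim_input):
    """Raise if the feature matrix width differs from the requested dim_input.

    Hypothetical helper for illustration only; not part of GraphST.
    """
    n_features = feat_shape[1]
    if n_features != dim_input:
        raise ValueError(
            f"dim_input={dim_input} but the feature matrix has {n_features} columns"
        )
    return n_features


# Shapes taken from the report above: 3000 columns although 1500 were requested.
try:
    check_feature_dim((21103, 3000), 1500)
except ValueError as err:
    print(err)
```

Calling something like this right after preprocessing would have flagged the case above immediately.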
Proposed Solution:
Modify the preprocess function to accept dim_input and pass it as n_top_genes when selecting highly variable genes. Any caller of preprocess would then need to pass dim_input along. Here's the corrected code:
def preprocess(adata, dim_input):
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=dim_input)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.scale(adata, zero_center=False, max_value=10)
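The intended relationship between dim_input and the feature matrix can be illustrated without scanpy. The sketch below is a toy stand-in: it ranks genes by plain per-gene variance in NumPy rather than using the seurat_v3 flavor, and select_top_variable_genes is a hypothetical name. It only shows that the number of selected genes should follow the requested dimension.

```python
import numpy as np


def select_top_variable_genes(counts, n_top_genes):
    """Keep the n_top_genes columns with the largest variance.

    Toy stand-in for HVG selection (plain variance, not seurat_v3).
    """
    variances = counts.var(axis=0)
    # Indices of the n_top_genes most variable columns, kept in original order.
    top_idx = np.sort(np.argsort(variances)[::-1][:n_top_genes])
    return counts[:, top_idx]


rng = np.random.default_rng(0)
counts = rng.poisson(lam=2.0, size=(50, 200)).astype(float)  # 50 cells x 200 genes

dim_input = 120
feat = select_top_variable_genes(counts, dim_input)
print(feat.shape)  # (50, 120)
```

With the proposed change, model.adata.obsm['feat'] should likewise end up with dim_input columns instead of a fixed 3000.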