Skip to content

GTNNWR模型维度报错 #13

@wxha1994qy

Description

@wxha1994qy

在python3.7+torch1.13.1+cu117环境下,我使用gnnwr0.010版本完全按照“https://mydde.deep-time.org/org-portal/MyDDE/project/64f976f12a7491e657f02793”的例子运行了[gtnnwr]模型,但是无法使用GPU参加计算。于是我使用了最新的GNNWR模型(0.1.16),输入条件与以上链接相同,但是在forward方法中的x = torch.reshape(x, shape=(batch * height, x.shape【2】))里报错tuple index out of range,猜测是x少了一个维度,但是输入和条件与例子完全一致。我的python和PyTorch版本:python.8+torch2.4.1+cu121,是模型的参数输入除了什么问题吗?求大神指正。。。以下是我的输入条件:
import numpy as np
import pandas as pd
from gnnwr.datasets import init_dataset
from gnnwr.models import GTNNWR

data = pd.read_csv('E:\03code\my_project\gitHub_code\huigui\GNNWR\data\demo_data_gtnnwr.csv')
data["id"] = np.arange(len(data))
train_dataset, val_dataset, test_dataset = init_dataset(data=data,
test_ratio=0.15,
valid_ratio=0.1,
x_column=['refl_b01', 'refl_b02',
'refl_b03','refl_b04','refl_b05',
'refl_b07'],
y_column=['SiO3'],
spatial_column=['proj_x', 'proj_y'],
temp_column=['day'],
id_column=['id'],
sample_seed=48,
batch_size=128)
gtnnwr = GTNNWR(train_dataset, val_dataset, test_dataset, [[3], [2048, 512, 256,32]],optimizer='Adam')

gtnnwr.add_graph()

gtnnwr.run(50000,8000)
gtnnwr.result()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions