Replies: 3 comments 2 replies
-
您好,我这边没有复现出这个报错,我的环境如下:
请您参考Issue: Bug report模板给出复现环境及步骤: Describe the bug(问题描述) To Reproduce(复现步骤)
Operating environment(运行环境):
Additional context |
Beta Was this translation helpful? Give feedback.
-
哦我可能知道了,您是直接用的 文档里这个只是个示例,实际使用时是需要输入linear_feature_columns和dnn_feature_columns的,您可以参考下run_classification_criteo.py里初始化模型的用法。 |
Beta Was this translation helpful? Give feedback.
-
感谢我已按照格式更新 - 跟随着run_classification_criteo.py exmaple 同样还是报错 |
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
Uh oh!
There was an error while loading. Please reload this page.
-
首先非常感谢这个deepctr torch这个package 可以非常快速的试各种model 但是我在读取train好的model 使用predict这个function会报错
Operating environment(运行环境):
python version 3.8.8
torch version 1.8.1
deepctr-torch version 0.2.7
请您参考Issue: Bug report模板给出复现环境及步骤:
Describe the bug(问题描述)
使用读取存储的模型用predict这个function的时候会有error
具体的error 信息:
NotImplementedError Traceback (most recent call last)
in
----> 1 reload_model.predict(train_model_input)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/models/basemodel.py in predict(self, x, batch_size)
340 x = x_test[0].to(self.device).float()
341
--> 342 y_pred = model(x).cpu().data.numpy() # .squeeze()
343 pred_ans.append(y_pred)
344
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/models/deepfm.py in forward(self, X)
76
77 if self.use_dnn:
---> 78 dnn_input = combined_dnn_input(
79 sparse_embedding_list, dense_value_list)
80 dnn_output = self.dnn(dnn_input)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/inputs.py in combined_dnn_input(sparse_embedding_list, dense_value_list)
136 return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
137 else:
--> 138 raise NotImplementedError
139
140
NotImplementedError:
To Reproduce(复现步骤)
跟着run_classification_criteo.py的example
在这之后加入save/load model的步骤 以下是在example code
新增加的那部分,其他都保持一致
torch.save(model, "test.h5")
reload_model = torch.load("test.h5")
reload_model.predict(train_model_input) ## 此处报错
Additional context
同时也试过
都是同样的报错
Beta Was this translation helpful? Give feedback.
All reactions