-
Notifications
You must be signed in to change notification settings - Fork 257
PatchTSMixer模型迁移 #1611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PatchTSMixer模型迁移 #1611
Conversation
@@ -0,0 +1,20 @@ | |||
# Copyright 2023 The HuggingFace Team. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto确认过了吗,是不是ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在都加上了
from ...test_configuration_common import ConfigTester | ||
from ...test_modeling_common import ModelTesterMixin,floats_tensor,ids_tensor | ||
TOLERANCE = 1e-4 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
通过的截图附一下
Parameters: | ||
data (`mindspore.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): | ||
input for Batch norm calculation | ||
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.XXTensor改掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更改
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""PyTorch PatchTSMixer model.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pytorch -> MindSpore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更改
config = PatchTSMixerConfig(**self.__class__.params) | ||
enc = PatchTSMixerEncoder(config) | ||
output = enc(self.__class__.enc_data) | ||
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RUN_SLOW测试结果也要附上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clamp_min接口改为clamp(min=)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rebase最新代码哈 |
6个用例挂了,自己修复掉,用mindspore2.4 daily包测试 |
No description provided.