[TRL] Provide SFT method and an example #1434
mindnlp/trl/examples/sft_test.py
Outdated
@@ -0,0 +1,25 @@
# imports
from mindnlp.dataset import load_dataset
Put this in the llm/trl directory.
# import mindspore.numpy as np
# import mindspore as ms
from mindspore import ops
Leave the diffusers part aside for now; it should not show up in the code.
# Manage and control training processes in a distributed training environment

import mindspore as ms
from mindspore import nn
Change these uniformly to from mindnlp.core import nn, ops
All subsequent interfaces should use the ones in mindnlp.core.
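A minimal sketch of how this import change could be applied mechanically across the new files. `rewrite_imports` and its regular expression are hypothetical helpers written for illustration and are not part of mindnlp; only the target line `from mindnlp.core import nn, ops` comes from the comment above. It handles top-level imports only and leaves `import mindspore as ms` untouched.

```python
import re

# Target import, taken verbatim from the review comment.
TARGET = "from mindnlp.core import nn, ops"

# Hypothetical pattern: top-level `from mindspore import nn|ops`
# or `import mindspore.nn as nn` / `import mindspore.ops as ops`.
_PATTERN = re.compile(
    r"^(from mindspore import (?:nn|ops)"
    r"|import mindspore\.(?P<m>nn|ops) as (?P=m))\s*$"
)


def rewrite_imports(source: str) -> str:
    """Replace direct mindspore nn/ops imports with one mindnlp.core import."""
    out = []
    emitted = False
    for line in source.splitlines():
        if _PATTERN.match(line):
            # Emit the replacement once, at the position of the first match.
            if not emitted:
                out.append(TARGET)
                emitted = True
            continue
        out.append(line)
    return "\n".join(out)
```

Calls into `nn` and `ops` further down each file would still need a manual check, since this sketch only touches the import lines.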
mindnlp/trl/trainer/sft_trainer.py
Outdated
@@ -0,0 +1,721 @@
'''Copyright 2023 The HuggingFace Team. All rights reserved.
Put this file in mindnlp/engine/trainer.
mindnlp/trl/trainer/sft_trainer.py
Outdated
import mindspore as ms
from mindspore import nn
from mindspore.dataset import Dataset, transforms
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
Same as above: use mindnlp.core for all of these.
mindnlp/trl/trainer/utils.py
Outdated
import mindspore.numpy as np
# import pandas as pd
import mindspore as ms
import mindspore.ops as ops
Same as above.
pylint did not pass.
llm/trl/kto_test.py
Outdated
'''Copyright 2024 The HuggingFace Inc. team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Don't name it test; call it run_kto.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C,R
Don't disable everything.
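One way to address this, sketched under assumptions: replace the blanket `# pylint: disable=C,R` with disables for specific, named checks placed where they are needed. `build_config` and its parameters are hypothetical placeholders, not code from this PR; `missing-module-docstring` and `too-many-arguments` are standard pylint message names.

```python
# Disable a single named check for this file instead of whole categories.
# pylint: disable=missing-module-docstring


def build_config(lr, batch_size, epochs, seed, output_dir, max_len):  # pylint: disable=too-many-arguments
    """Collect hypothetical training options into a plain dict."""
    return {
        "lr": lr,
        "batch_size": batch_size,
        "epochs": epochs,
        "seed": seed,
        "output_dir": output_dir,
        "max_len": max_len,
    }
```

Scoping each disable to one line or one message keeps the remaining convention and refactor checks active for the rest of the file.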
mindnlp/trl/models/modeling_base.py
Outdated
from copy import deepcopy
from typing import Optional

# import torch
There should be no torch code.
mindnlp/trl/models/modeling_base.py
Outdated
LocalEntryNotFoundError,
RepositoryNotFoundError,
)
# from safetensors.torch import load_file as safe_load_file
from mindnlp.core.serialization import safe_load_file
sft_test.py
Outdated
@@ -0,0 +1,27 @@
'''
This file is an example for sft method.
'''
Delete this.
5a74091
to
f52a716
###PR content###
Provides the SFT method and the configuration files it requires, and provides an example of calling SFT in the examples folder.