Skip to content

Commit 2169479

Browse files
authored
add save repo and file
add save repo and file
2 parents da7d0ba + f7a96cb commit 2169479

18 files changed

+243
-51
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
SHELL=/bin/bash
2-
PROJECT_NAME=openplugin
2+
PROJECT_NAME=huggingface_tool
33
PROJECT_PATH=${PROJECT_NAME}/
4-
PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} tests examples -type f -name "*.py")
4+
PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} tests -type f -name "*.py")
55

66
check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
77
check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ Tools for loading, upload, managing huggingface models and datasets
1515

1616
First, log in with `huggingface-cli login` (you can create or find your token on the [settings](https://huggingface.co/settings/tokens) page).
1717

18+
- Download and save a repo with: `htool save-repo <repo_id> <save_dir> -r <model/dataset>`. The `-r` option selects the repo type (`model` or `dataset`); it defaults to `model`.
19+
- For example: `htool save-repo OpenRL/tizero ./tizero`
20+
- For example: `htool save-repo OpenRL/DeepFakeFace ./DeepFakeFace -r dataset`
21+
- Download and save a file with: `htool save-file <repo_id>:<remote_filepath> <save_dir> -r <model/dataset>`. The `-r` option selects the repo type (`model` or `dataset`); it defaults to `model`.
22+
- For example: `htool save-file OpenRL/tizero:actor.pt ./tizero`
23+
- For example: `htool save-file OpenRL/DeepFakeFace:README.md ./DeepFakeFace -r dataset`
1824
- Download and save transformer models with: `htool save-model <model_class> <model_name> <save_dir>`
1925
- For example: `htool save-model AutoModelForCausalLM gpt2 ./gpt2`
2026
- Download and save tokenizer with: `htool save-tk <tokenizer_name> <save_dir>`

huggingface_tool/cli/cli.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717
""""""
18-
import os
18+
1919

2020
import click
2121
from click.core import Context, Option
@@ -30,9 +30,9 @@ def red(text: str):
3030

3131

3232
def print_version(
33-
ctx: Context,
34-
param: Option,
35-
value: bool,
33+
ctx: Context,
34+
param: Option,
35+
value: bool,
3636
) -> None:
3737
if not value or ctx.resilient_parsing:
3838
return
@@ -42,9 +42,9 @@ def print_version(
4242

4343

4444
def print_system_info(
45-
ctx: Context,
46-
param: Option,
47-
value: bool,
45+
ctx: Context,
46+
param: Option,
47+
value: bool,
4848
) -> None:
4949
if not value or ctx.resilient_parsing:
5050
return
@@ -88,18 +88,70 @@ def cli(ctx):
8888
@click.argument("save_dir")
8989
def save_dm(model_name, save_dir):
9090
from huggingface_tool.savers.diffusion_model_saver import DiffusionModelSaver
91+
9192
saver = DiffusionModelSaver(model_name)
9293
if saver.load():
9394
saver.save(save_dir)
9495
else:
9596
saver.logger.info("Model not found")
9697

9798

99+
@cli.command()
@click.argument("file_name")
@click.argument("save_dir")
@click.option(
    "--repo_type",
    "-r",
    type=click.Choice(["model", "dataset"]),
    default="model",
    help="repo type",
)
def save_file(file_name, save_dir, repo_type):
    # CLI command: build a FileSaver from `file_name` and `repo_type`, then
    # save into `save_dir` if loading succeeds; otherwise log a message.
    #
    # Documented with comments rather than a docstring on purpose: click
    # would surface a docstring as this command's --help text.
    #
    # The saver import is function-local, as in the other commands of this
    # module (presumably to keep CLI start-up light -- TODO confirm).
    from huggingface_tool.savers.file_saver import FileSaver

    file_saver = FileSaver(file_name, repo_type)
    if not file_saver.load():
        file_saver.logger.info("File not found")
        return
    file_saver.save(save_dir)
122+
123+
124+
@cli.command()
@click.argument("repo_name")
@click.argument("save_dir")
@click.option(
    "--repo_type",
    "-r",
    type=click.Choice(["model", "dataset"]),
    default="model",
    help="repo type",
)
def save_repo(repo_name, save_dir, repo_type):
    # CLI command: build a RepoSaver from `repo_name` and `repo_type`, then
    # save into `save_dir` if loading succeeds; otherwise log a message.
    #
    # Documented with comments rather than a docstring on purpose: click
    # would surface a docstring as this command's --help text.
    from huggingface_tool.savers.repo_saver import RepoSaver

    repo_saver = RepoSaver(repo_name, repo_type)
    if not repo_saver.load():
        repo_saver.logger.info("Repo not found")
        return
    repo_saver.save(save_dir)
147+
148+
98149
@cli.command()
99150
@click.argument("tokenizer_name")
100151
@click.argument("save_dir")
101152
def save_tk(tokenizer_name, save_dir):
102153
from huggingface_tool.savers.tokenizer_saver import TokenizerSaver
154+
103155
saver = TokenizerSaver(tokenizer_name)
104156
if saver.load():
105157
saver.save(save_dir)
@@ -112,6 +164,7 @@ def save_tk(tokenizer_name, save_dir):
112164
@click.argument("save_dir")
113165
def save_data(dataset_name, save_dir):
114166
from huggingface_tool.savers.dataset_saver import DatasetSaver
167+
115168
saver = DatasetSaver(dataset_name)
116169
if saver.load():
117170
saver.save(save_dir)
@@ -125,28 +178,33 @@ def save_data(dataset_name, save_dir):
125178
@click.argument("save_dir")
126179
def save_model(model_class, model_name, save_dir):
    # CLI command body: build a ModelSaver for `model_class` / `model_name`
    # and save it into `save_dir` when loading succeeds.
    #
    # Documented with comments rather than a docstring on purpose: click
    # would surface a docstring as this command's --help text.
    from huggingface_tool.savers.model_saver import ModelSaver

    saver = ModelSaver(model_class, model_name)
    if saver.load():
        saver.save(save_dir)
    else:
        # This command loads a model, so the failure message must say
        # "Model" (it previously reported "Dataset not found").
        saver.logger.info("Model not found")
133187

188+
134189
@cli.command()
@click.argument("dataset_dir")
@click.argument("dataset_name")
def upload_data(dataset_dir, dataset_name):
    # CLI command: build a DatasetUploader for `dataset_dir` / `dataset_name`
    # and push it when `check()` passes; otherwise log a message.
    #
    # Documented with comments rather than a docstring on purpose: click
    # would surface a docstring as this command's --help text.
    from huggingface_tool.uploaders.dataset_uploader import DatasetUploader

    dataset_uploader = DatasetUploader(dataset_dir, dataset_name)
    if not dataset_uploader.check():
        dataset_uploader.logger.info("Dataset not valid")
        return
    dataset_uploader.push()
144200

201+
145202
@cli.command()
146203
@click.argument("model_dir")
147204
@click.argument("model_name")
148205
def upload_model(model_dir, model_name):
149206
from huggingface_tool.uploaders.model_uploader import ModelUploader
207+
150208
uploader = ModelUploader(model_dir, model_name)
151209
if uploader.check():
152210
uploader.push()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
from abc import ABC
19+
20+
from huggingface_tool.savers.base_saver import BaseSaver
21+
22+
23+
class BaseAPISaver(BaseSaver, ABC):
    """Abstract saver base that additionally records a ``repo_type``.

    Subclasses still have to provide ``_load`` and ``save``; this class only
    extends the constructor.
    """

    def __init__(self, name: str, repo_type: str):
        """Store ``repo_type`` and let ``BaseSaver`` initialise from ``name``.

        Args:
            name: Identifier forwarded unchanged to ``BaseSaver.__init__``.
            repo_type: Repo kind; kept on the instance as ``self.repo_type``.
        """
        self.repo_type = repo_type
        super().__init__(name)

huggingface_tool/savers/base_model_saver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616

1717
""""""
1818
from abc import ABC
19+
1920
from huggingface_tool.savers.base_saver import BaseSaver
2021

21-
class BaseModelSaver(BaseSaver,ABC):
22+
23+
class BaseModelSaver(BaseSaver, ABC):
2224
def save(self, save_dir: str):
2325
if self.loaded_object is None:
2426
self.logger.info("No model loaded, cannot save")
2527
return
2628
self.loaded_object.save_pretrained(save_dir)
27-

huggingface_tool/savers/base_saver.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@
2020

2121
from huggingface_tool.utils.logger import Logger
2222

23+
2324
class BaseSaver(ABC):
24-
def __init__(self, name:str):
25+
def __init__(self, name: str):
2526
self.logger = Logger()
2627
self.name = name
2728
self.loaded_object = None
2829

29-
30-
3130
def load(self) -> bool:
3231
try:
3332
self.loaded_object = self._load(self.name)
@@ -37,9 +36,9 @@ def load(self) -> bool:
3736
return True
3837

3938
@abstractmethod
40-
def _load(self,name:str):
39+
def _load(self, name: str):
4140
raise NotImplementedError
4241

4342
@abstractmethod
44-
def save(self,save_dir: str):
43+
def save(self, save_dir: str):
4544
raise NotImplementedError

huggingface_tool/savers/dataset_saver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
from huggingface_tool.savers.base_saver import BaseSaver
2121

22+
2223
class DatasetSaver(BaseSaver):
23-
def _load(self,name:str):
24+
def _load(self, name: str):
2425
return datasets.load_dataset(name)
2526

2627
def save(self, save_dir: str):
2728
if self.loaded_object is None:
2829
self.logger.info("No dataset loaded, cannot save")
2930
return
3031
self.loaded_object.save_to_disk(save_dir)
31-

huggingface_tool/savers/diffusion_model_saver.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
# limitations under the License.
1616

1717
""""""
18-
from huggingface_tool.savers.base_model_saver import BaseModelSaver
19-
2018
import torch
2119
from diffusers import StableDiffusionPipeline
2220

21+
from huggingface_tool.savers.base_model_saver import BaseModelSaver
22+
23+
2324
class DiffusionModelSaver(BaseModelSaver):
24-
def _load(self, name)->bool:
25+
def _load(self, name) -> bool:
2526
return StableDiffusionPipeline.from_pretrained(name, torch_dtype=torch.float16)

huggingface_tool/savers/file_saver.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
from pathlib import Path
19+
20+
from huggingface_hub import hf_hub_download
21+
22+
from huggingface_tool.savers.base_api_saver import BaseAPISaver
23+
24+
25+
class FileSaver(BaseAPISaver):
    """Saver that downloads a single file from a Hugging Face Hub repo.

    The ``name`` given to the constructor must have the form
    ``<repo_id>:<remote_filepath>``, e.g. ``OpenRL/tizero:actor.pt``.
    """

    def _load(self, name: str):
        """Parse ``name`` into download parameters; nothing is downloaded here.

        Args:
            name: ``<repo_id>:<remote_filepath>`` string.

        Returns:
            A dict with keys ``"repo_id"``, ``"file"`` and ``"repo_type"``.
            ``"repo_type"`` is ``"dataset"`` for dataset repos and ``None``
            otherwise.

        Raises:
            ValueError: If ``name`` is not exactly two non-empty parts
                separated by a single ``:``.
        """
        parts = name.split(":")
        # Raise explicitly instead of using `assert`: assertions are stripped
        # when Python runs with -O, which would let malformed input through.
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Expected '<repo_id>:<remote_filepath>', got {name!r}"
            )
        repo_id, file = parts
        repo_type = self.repo_type if self.repo_type == "dataset" else None
        return {"repo_id": repo_id, "file": file, "repo_type": repo_type}

    def save(self, name):
        """Download the parsed file into the directory ``name``.

        Args:
            name: Target directory; created (with parents) if missing.
        """
        # Same guard as the other savers: without it, subscripting a None
        # `loaded_object` would raise an unhelpful TypeError.
        if self.loaded_object is None:
            self.logger.info("No file loaded, cannot save")
            return
        local_dir = Path(name)
        local_dir.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id=self.loaded_object["repo_id"],
            filename=self.loaded_object["file"],
            repo_type=self.loaded_object["repo_type"],
            local_dir=local_dir,
        )

huggingface_tool/savers/model_saver.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,27 @@
2222
model_class_dict = {
2323
"AutoModelForSeq2SeqLM": transformers.AutoModelForSeq2SeqLM,
2424
"AutoModelForCausalLM": transformers.AutoModelForCausalLM,
25-
"AutoModelForSequenceClassification": transformers.AutoModelForSequenceClassification,
25+
"AutoModelForSequenceClassification": (
26+
transformers.AutoModelForSequenceClassification
27+
),
2628
"AutoModelForQuestionAnswering": transformers.AutoModelForQuestionAnswering,
2729
"AutoModelForTokenClassification": transformers.AutoModelForTokenClassification,
2830
"AutoModelForMultipleChoice": transformers.AutoModelForMultipleChoice,
29-
"AutoModelForNextSentencePrediction": transformers.AutoModelForNextSentencePrediction,
31+
"AutoModelForNextSentencePrediction": (
32+
transformers.AutoModelForNextSentencePrediction
33+
),
3034
"AutoModelForPreTraining": transformers.AutoModelForPreTraining,
3135
"AutoModelForMaskedLM": transformers.AutoModelForMaskedLM,
32-
"AutoModelForTableQuestionAnswering": transformers.AutoModelForTableQuestionAnswering,
36+
"AutoModelForTableQuestionAnswering": (
37+
transformers.AutoModelForTableQuestionAnswering
38+
),
3339
}
3440

41+
3542
class ModelSaver(BaseModelSaver):
36-
def __init__(self, model_class:str, name:str):
43+
def __init__(self, model_class: str, name: str):
3744
super().__init__(name)
3845
self.model_class = model_class
3946

40-
def _load(self,name):
41-
return model_class_dict[self.model_class].from_pretrained(name)
47+
def _load(self, name):
48+
return model_class_dict[self.model_class].from_pretrained(name)

huggingface_tool/savers/repo_saver.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
from pathlib import Path
20+
21+
from huggingface_hub import snapshot_download
22+
23+
from huggingface_tool.savers.base_api_saver import BaseAPISaver
24+
25+
26+
class RepoSaver(BaseAPISaver):
    """Saver that downloads a Hugging Face Hub repo via ``snapshot_download``."""

    def _load(self, name: str):
        """Build the download parameters for ``name``; nothing is downloaded here.

        Args:
            name: Repo id, e.g. ``OpenRL/tizero``.

        Returns:
            A dict with keys ``"repo_id"`` and ``"repo_type"``.
            ``"repo_type"`` is ``"dataset"`` for dataset repos and ``None``
            otherwise.
        """
        repo_type = self.repo_type if self.repo_type == "dataset" else None
        return {"repo_id": name, "repo_type": repo_type}

    def save(self, name):
        """Download the repo into the directory ``name``.

        Args:
            name: Target directory passed to ``snapshot_download`` as
                ``local_dir``.
        """
        # Same guard as the other savers: without it, subscripting a None
        # `loaded_object` would raise an unhelpful TypeError.
        if self.loaded_object is None:
            self.logger.info("No repo loaded, cannot save")
            return
        local_dir = Path(name)
        snapshot_download(
            repo_id=self.loaded_object["repo_id"],
            repo_type=self.loaded_object["repo_type"],
            local_dir=local_dir,
        )

0 commit comments

Comments
 (0)