Skip to content

Commit 90c7887

Browse files
authored
use GE backend for ms2.3 (#785) (#795)
1 parent 2a07185 commit 90c7887

File tree

8 files changed

+16
-0
lines changed

8 files changed

+16
-0
lines changed

examples/clip/clip/clip.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def load(name: str, device: str = "Ascend", mode: int = 1, download_root: str =
126126
take as its input.
127127
"""
128128
ms.set_context(device_target=device, mode=mode)
129+
if mode == ms.GRAPH_MODE:
130+
ms.set_context(jit_config={"jit_level": "O2"})
129131
if name in _MODELS:
130132
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
131133
ckp_dict = load_checkpoint(model_path)

examples/det/ssd/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def train(args):
2222
"""main train function"""
2323

2424
ms.set_context(mode=args.mode)
25+
if args.mode == ms.GRAPH_MODE:
26+
ms.set_context(jit_config={"jit_level": "O2"})
2527

2628
if args.distribute:
2729
init()

examples/finetune/finetune.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def finetune_train(args):
3030
"""main train function"""
3131

3232
ms.set_context(mode=args.mode)
33+
if args.mode == ms.GRAPH_MODE:
34+
ms.set_context(jit_config={"jit_level": "O2"})
3335
if args.distribute:
3436
init()
3537
device_num = get_group_size()

examples/open_clip/test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def parse_args(args):
5858
def main(args):
5959
args = parse_args(args)
6060
ms.set_context(device_target=args.device_target, mode=args.mode)
61+
if args.mode == ms.GRAPH_MODE:
62+
ms.set_context(jit_config={"jit_level": "O2"})
6163
model, preprocess_train, preprocess_val = create_model_and_transforms(
6264
args.model_name,
6365
args.pretrained,

examples/seg/deeplabv3/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
def train(args):
2525
"""main train function"""
2626
ms.set_context(mode=args.mode)
27+
if args.mode == ms.GRAPH_MODE:
28+
ms.set_context(jit_config={"jit_level": "O2"})
2729

2830
if args.distribute:
2931
init()

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
def main():
3030
args = parse_args()
3131
ms.set_context(mode=args.mode)
32+
if args.mode == ms.GRAPH_MODE:
33+
ms.set_context(jit_config={"jit_level": "O2"})
3234
if args.distribute:
3335
init()
3436
rank_id, device_num = get_rank(), get_group_size()

train_with_func.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def main():
5858
args = parse_args()
5959
args = check_args(args)
6060
ms.set_context(mode=args.mode)
61+
if args.mode == ms.GRAPH_MODE:
62+
ms.set_context(jit_config={"jit_level": "O2"})
6163
if args.distribute:
6264
init()
6365
rank_id, device_num = get_rank(), get_group_size()

validate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def check_batch_size(num_samples, ori_batch_size=32, refine=True):
2727

2828
def validate(args):
2929
ms.set_context(mode=args.mode)
30+
if args.mode == ms.GRAPH_MODE:
31+
ms.set_context(jit_config={"jit_level": "O2"})
3032

3133
# create dataset
3234
dataset_eval = create_dataset(

0 commit comments

Comments
 (0)