Skip to content

Commit 8e3c3c0

Browse files
authored
fix bug in pt/translation example (#1128)
1 parent 2234981 commit 8e3c3c0

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_benchmark.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ function init_params {
5252

5353
# run_benchmark
5454
function run_benchmark {
55-
extra_cmd=''
56-
55+
extra_cmd='None'
5756
if [[ ${mode} == "accuracy" ]]; then
5857
mode_cmd=" --accuracy_only"
5958
elif [[ ${mode} == "benchmark" ]]; then
@@ -65,10 +64,9 @@ function run_benchmark {
6564

6665
if [ "${topology}" = "t5_WMT_en_ro" ];then
6766
model_name_or_path='t5-small'
68-
extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en"
67+
extra_cmd='translate English to Romanian: '
6968
elif [ "${topology}" = "marianmt_WMT_en_ro" ]; then
7069
model_name_or_path='Helsinki-NLP/opus-mt-en-ro'
71-
extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en"
7270
fi
7371

7472
if [[ ${int8} == "true" ]]; then
@@ -82,9 +80,12 @@ function run_benchmark {
8280
--predict_with_generate \
8381
--per_device_eval_batch_size ${batch_size} \
8482
--output_dir ${tuned_checkpoint} \
85-
--source_prefix "translate English to Romanian: " \
83+
--source_lang en \
84+
--target_lang ro \
85+
--dataset_name wmt16 \
86+
--dataset_config_name ro-en\
8687
${mode_cmd} \
87-
${extra_cmd}
88+
--source_prefix "$extra_cmd"
8889
}
8990

9091
main "$@"

examples/pytorch/nlp/huggingface_models/translation/quantization/ptq_dynamic/eager/run_tuning.sh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,17 @@ function init_params {
3737

3838
# run_tuning
3939
function run_tuning {
40-
extra_cmd=''
40+
extra_cmd='None'
4141
batch_size=16
4242
model_type='bert'
4343

4444
if [ "${topology}" = "t5_WMT_en_ro" ];then
4545
model_name_or_path='t5-small'
4646
model_type='t5'
47-
extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en"
47+
extra_cmd='translate English to Romanian: '
4848
elif [ "${topology}" = "marianmt_WMT_en_ro" ]; then
4949
model_name_or_path='Helsinki-NLP/opus-mt-en-ro'
5050
model_type='marianmt'
51-
extra_cmd="--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en"
5251
fi
5352

5453
sed -i "/: bert/s|name:.*|name: $model_type|g" conf.yaml
@@ -61,10 +60,13 @@ function run_tuning {
6160
--predict_with_generate \
6261
--per_device_eval_batch_size ${batch_size} \
6362
--output_dir ${tuned_checkpoint} \
64-
--source_prefix "translate English to Romanian: " \
63+
--source_lang en \
64+
--target_lang ro \
65+
--dataset_name wmt16 \
66+
--dataset_config_name ro-en\
6567
--tune \
6668
--overwrite_output_dir \
67-
$extra_cmd
69+
--source_prefix "$extra_cmd"
6870
}
6971

7072
main "$@"

0 commit comments

Comments
 (0)