
Commit 523341a

[AOTI] Pass device explicitly (pytorch#858)
Follow-up to pytorch/torchchat#815 to unblock migration to a newer version of PyTorch, where AOTI seems to have lost the ability to error out when one attempts to load a CPU model on GPU; see https://github.com/pytorch/torchchat/actions/runs/9391753397/job/25913830802 for an example. Work around this by adding a `-d ${DEVICE}` option to `aoti_run`.
1 parent 5ff3222

3 files changed: +20 -5 lines


.github/workflows/hqq-dtype.yml

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ jobs:
     python generate.py --dtype ${DTYPE} --device ${DEVICE} --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
     .ci/scripts/check_gibberish ./output_aoti

-   ./cmake-out/aoti_run ${MODEL_DIR}/${MODEL_NAME}.so -z ${TOKENIZER_PATH} -i "${PROMPT}" > ./output_runner_aoti
+   ./cmake-out/aoti_run ${MODEL_DIR}/${MODEL_NAME}.so -d ${DEVICE} -z ${TOKENIZER_PATH} -i "${PROMPT}" > ./output_runner_aoti
    cat ./output_runner_aoti
    # .ci/scripts/check_gibberish ./output_runner_aoti --no-extract
@@ -77,7 +77,7 @@ jobs:
     python generate.py --dtype ${DTYPE} --device ${DEVICE} --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
     .ci/scripts/check_gibberish ./output_aoti

-   ./cmake-out/aoti_run ${MODEL_DIR}/${MODEL_NAME}.so -z ${TOKENIZER_PATH} -i "${PROMPT}" > ./output_runner_aoti
+   ./cmake-out/aoti_run ${MODEL_DIR}/${MODEL_NAME}.so -d ${DEVICE} -z ${TOKENIZER_PATH} -i "${PROMPT}" > ./output_runner_aoti
    cat ./output_runner_aoti
    # .ci/scripts/check_gibberish ./output_runner_aoti --no-extract

.github/workflows/runner-cuda-dtype.yml

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ jobs:
     python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-dso-path /tmp/model.so

-   ./cmake-out/aoti_run /tmp/model.so -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
+   ./cmake-out/aoti_run /tmp/model.so -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"

    echo "**********************************************"
    echo "******** INT4 HQQ group-wise quantized *******"

runner/run.cpp

Lines changed: 17 additions & 2 deletions
@@ -136,10 +136,10 @@ void build_transformer(

 #ifdef __AOTI_MODEL__
 #ifdef USE_CUDA
-  try {
+  if (aoti_device.type() == torch::kCUDA) {
     t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
     aoti_device = torch::Device(torch::kCUDA);
-  } catch (std::runtime_error& e) {
+  } else {
 #else
   {
 #endif
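
Since the hunk leaves the CPU branch and the `aoti_device` global off-screen, here is a sketch of how the selection logic in `build_transformer` reads after this change. The global's default and the CPU fallback body are assumptions drawn from the surrounding run.cpp, not part of the diff:

    // Assumed pre-existing global that the new -d flag mutates; when -d is
    // omitted it presumably keeps its CPU default.
    torch::Device aoti_device = torch::Device(torch::kCPU);

    #ifdef __AOTI_MODEL__
    #ifdef USE_CUDA
      // Before this change the CUDA runner was constructed inside a try block,
      // and the std::runtime_error thrown when a CPU-exported model was loaded
      // on GPU fell through to the CPU path. Newer PyTorch no longer throws in
      // that case, so the requested device is checked explicitly instead.
      if (aoti_device.type() == torch::kCUDA) {
        t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
        aoti_device = torch::Device(torch::kCUDA);
      } else {
    #else
      {
    #endif
        // CPU fallback: assumed from the surrounding file, not shown in the hunk.
        t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
      }
    #endif // __AOTI_MODEL__

The net effect is that device selection is driven by the caller rather than by whether the CUDA load happens to fail.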
@@ -811,6 +811,7 @@ void error_usage() {
       " -v <int> (optional) vocab size, default is model-specific.\n");
   fprintf(
       stderr, " -l <int> (optional) llama version (2 or 3), default 2.\n");
+  fprintf(stderr, " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
   exit(EXIT_FAILURE);
 }

@@ -880,6 +881,20 @@ int main(int argc, char* argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
       llama_ver = atoi(argv[i + 1]);
+#ifdef __AOTI_MODEL__
+    } else if (argv[i][1] == 'd') {
+#ifdef USE_CUDA
+      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
+        aoti_device = torch::Device(torch::kCUDA);
+      } else
+#endif
+          if (strcasecmp(argv[i + 1], "CPU") == 0) {
+        aoti_device = torch::Device(torch::kCPU);
+      } else {
+        fprintf(stderr, "Unknown device %s", argv[i + 1]);
+        exit(1);
+      }
+#endif
     } else {
       error_usage();
     }
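
The `else` that chains across the `#endif` in this hunk is easy to misread, so the same `-d` handler is spelled out below with annotations; `strcasecmp` is the POSIX case-insensitive comparison from <strings.h>, and `aoti_device` is the global consulted by `build_transformer`:

    } else if (argv[i][1] == 'd') {
    #ifdef USE_CUDA
      // In a CUDA build this branch expands to if / else-if / else; in a
      // CPU-only build the CUDA test drops out and only "CPU" is accepted.
      if (strcasecmp(argv[i + 1], "CUDA") == 0) {  // case-insensitive: "cuda" works too
        aoti_device = torch::Device(torch::kCUDA);
      } else  // deliberately brace-less: chains into the if below when the CUDA test fails
    #endif
      if (strcasecmp(argv[i + 1], "CPU") == 0) {
        aoti_device = torch::Device(torch::kCPU);
      } else {
        fprintf(stderr, "Unknown device %s", argv[i + 1]);
        exit(1);
      }

This is also why the workflow changes above pass `-d ${DEVICE}` (or `-d CUDA`) explicitly: with the try/catch gone, nothing else tells the runner which device the model was exported for.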
