diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 1f4a226933..0eaacea835 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -2,9 +2,6 @@ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" - schedule: - # Check for updates to GitHub Actions every week - interval: "weekly" open-pull-requests-limit: 2 reviewers: - "Yikun" diff --git a/.github/workflows/nightly_benchmarks.yaml b/.github/workflows/nightly_benchmarks.yaml index 2b9c062957..5c94c6a235 100644 --- a/.github/workflows/nightly_benchmarks.yaml +++ b/.github/workflows/nightly_benchmarks.yaml @@ -18,11 +18,7 @@ name: 'Benchmarks / Performance' # This workflow runs nightly benchmarks for vllm-ascend. -on: - schedule: - # Run at 02:00 everyday - - cron: '00 18 * * *' - +on: workflow_dispatch: # Allow manual triggering of the workflow diff --git a/.github/workflows/vllm_ascend_doctest.yaml b/.github/workflows/vllm_ascend_doctest.yaml index 67f98fbaf7..866588807b 100644 --- a/.github/workflows/vllm_ascend_doctest.yaml +++ b/.github/workflows/vllm_ascend_doctest.yaml @@ -29,9 +29,6 @@ on: - 'tests/e2e/doctests/**' - 'tests/e2e/common.sh' - 'tests/e2e/run_doctests.sh' - schedule: - # Runs every 4 hours - - cron: '0 */4 * * *' # Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly # declared as "shell: bash -el {0}" on steps that need to be properly activated. diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 540680dd2f..97d316e755 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -18,8 +18,6 @@ name: 'test' on: - schedule: - - cron: '0 23 * * *' pull_request: branches: - 'main' @@ -44,12 +42,6 @@ defaults: run: shell: bash -el {0} -# only cancel in-progress runs of the same workflow -# and ignore the lint / 1 card / 4 cards test type -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - jobs: lint: runs-on: ubuntu-latest @@ -114,171 +106,32 @@ jobs: echo "::add-matcher::.github/workflows/matchers/mypy.json" tools/mypy.sh 1 ${{ matrix.python-version }} - ut: - needs: [lint] - name: unit test - if: ${{ needs.lint.result == 'success' }} - runs-on: ubuntu-latest - container: - image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 - env: - VLLM_LOGGING_LEVEL: ERROR - VLLM_USE_MODELSCOPE: True - strategy: - matrix: - vllm_version: [main, v0.9.1] - steps: - - name: Install packages - run: | - apt-get update -y - apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev - - - name: Checkout vllm-project/vllm repo - uses: actions/checkout@v4 - with: - repository: vllm-project/vllm - ref: ${{ matrix.vllm_version }} - path: ./vllm-empty - - - name: Install vllm-project/vllm from source - working-directory: ./vllm-empty - run: | - VLLM_TARGET_DEVICE=empty python3 -m pip install . --extra-index https://download.pytorch.org/whl/cpu/ - python3 -m pip uninstall -y triton - - - name: Checkout vllm-project/vllm-ascend repo - uses: actions/checkout@v4 - - - name: Install vllm-project/vllm-ascend - run: | - export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib - python3 -m pip install -r requirements-dev.txt --extra-index https://download.pytorch.org/whl/cpu/ - python3 -m pip install -v . 
--extra-index https://download.pytorch.org/whl/cpu/ - - - name: Run unit test for V1 Engine - env: - VLLM_USE_V1: 1 - VLLM_WORKER_MULTIPROC_METHOD: spawn - TORCH_DEVICE_BACKEND_AUTOLOAD: 0 - run: | - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib - pytest -sv tests/ut - e2e: needs: [lint] if: ${{ needs.lint.result == 'success' }} strategy: max-parallel: 2 matrix: - os: [linux-arm64-npu-1] - vllm_version: [main, v0.9.1] - name: singlecard e2e test - runs-on: ${{ matrix.os }} - container: - # TODO(yikun): Remove m.daocloud.io prefix when infra proxy ready - image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 - env: - VLLM_LOGGING_LEVEL: ERROR - steps: - - name: Check npu and CANN info - run: | - npu-smi info - cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info - - - name: Config mirrors - run: | - sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list - pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - apt-get update -y - apt install git -y - git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf https://github.com/ - - - name: Checkout vllm-project/vllm-ascend repo - uses: actions/checkout@v4 - - - name: Install system dependencies - run: | - apt-get -y install `cat packages.txt` - apt-get -y install gcc g++ cmake libnuma-dev - - - name: Checkout vllm-project/vllm repo - uses: actions/checkout@v4 - with: - repository: vllm-project/vllm - ref: ${{ matrix.vllm_version }} - path: ./vllm-empty - - - name: Install vllm-project/vllm from source - working-directory: ./vllm-empty - run: | - VLLM_TARGET_DEVICE=empty pip install -e . - - - name: Install vllm-project/vllm-ascend - env: - PIP_EXTRA_INDEX_URL: https://mirrors.huaweicloud.com/ascend/repos/pypi - run: | - pip install -r requirements-dev.txt - pip install -v -e . 
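For reference, the install sequence these CI steps encode can be reproduced outside the workflow. The sketch below is a minimal local equivalent, not an official recipe: the checkout ref, mirror, and extra-index URLs are simply the ones used in the job above and may need adjusting for your environment.

```bash
# Install vLLM from source without device kernels, then install vllm-ascend from source.
git clone --branch v0.9.1 https://github.com/vllm-project/vllm.git vllm-empty
cd vllm-empty && VLLM_TARGET_DEVICE=empty pip install -e . && cd ..

# From a vllm-ascend checkout:
export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi
pip install -r requirements-dev.txt
pip install -v -e .
```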
- - - name: Run e2e test for V1 Engine - env: - VLLM_USE_V1: 1 - VLLM_WORKER_MULTIPROC_METHOD: spawn - VLLM_USE_MODELSCOPE: True - run: | - pytest -sv tests/e2e/singlecard/test_offline_inference.py - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/singlecard/test_ilama_lora.py - # TODO(sss): guided decoding doesn't work, fix it later - # pytest -sv tests/e2e/singlecard/test_guided_decoding.py - pytest -sv tests/e2e/singlecard/test_camem.py - pytest -sv tests/e2e/singlecard/ \ - --ignore=tests/e2e/singlecard/test_offline_inference.py \ - --ignore=tests/e2e/singlecard/test_ilama_lora.py \ - --ignore=tests/e2e/singlecard/test_guided_decoding.py \ - --ignore=tests/e2e/singlecard/test_camem.py - - - name: Run e2e test on V0 engine - if: ${{ github.event_name == 'schedule' }} - env: - VLLM_USE_V1: 0 - VLLM_USE_MODELSCOPE: True - run: | - pytest -sv tests/e2e/singlecard/test_offline_inference.py - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/singlecard/test_ilama_lora.py - # guided decoding doesn't work, fix it later - # pytest -sv tests/e2e/singlecard/test_guided_decoding.py - pytest -sv tests/e2e/singlecard/test_camem.py - pytest -sv tests/e2e/singlecard/test_prompt_embedding.py - pytest -sv tests/e2e/singlecard/ \ - --ignore=tests/e2e/singlecard/test_offline_inference.py \ - --ignore=tests/e2e/singlecard/test_ilama_lora.py \ - --ignore=tests/e2e/singlecard/test_guided_decoding.py \ - --ignore=tests/e2e/singlecard/test_camem.py \ - --ignore=tests/e2e/singlecard/test_prompt_embedding.py \ - --ignore=tests/e2e/singlecard/core/test_ascend_scheduler.py \ - --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py - - e2e-4-cards: - needs: [e2e] - if: ${{ needs.e2e.result == 'success' }} - strategy: - max-parallel: 1 - matrix: - os: [linux-arm64-npu-4] - vllm_version: [main, v0.9.1] - name: multicard e2e test + os: [linux-arm64-npu-1, linux-arm64-npu-4] + vllm_version: [v0.9.1] + concurrency: + group: > + ${{ + matrix.os == 'linux-arm64-npu-4' + && github.event.pull_request.number + && format('pr-{0}-limit-npu-4', github.event.pull_request.number) + || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_version, github.event.pull_request.number) + }} + cancel-in-progress: false + name: vLLM Ascend test runs-on: ${{ matrix.os }} container: # TODO(yikun): Remove m.daocloud.io prefix when infra proxy ready image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 env: + HF_ENDPOINT: https://hf-mirror.com + HF_TOKEN: ${{ secrets.HF_TOKEN }} VLLM_LOGGING_LEVEL: ERROR - VLLM_USE_MODELSCOPE: True steps: - name: Check npu and CANN info run: | @@ -324,32 +177,64 @@ jobs: env: VLLM_USE_V1: 1 VLLM_WORKER_MULTIPROC_METHOD: spawn - VLLM_USE_MODELSCOPE: True run: | - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py - # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. - # To avoid oom, we need to run the test in a single process. 
- pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 - pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py + if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then + VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py + # guided decoding doesn't work, fix it later + # pytest -sv tests/singlecard/test_guided_decoding.py.py + # test_ascend_config.py should be ran separately because it will regenerate the global config many times. + pytest -sv tests/singlecard/test_ascend_config.py + pytest -sv tests/singlecard/test_camem.py + pytest -sv tests/singlecard/core/test_ascend_scheduler.py + pytest -sv tests/singlecard/core/test_ascend_scheduler_e2e.py + pytest -sv tests/singlecard/ \ + --ignore=tests/singlecard/test_offline_inference.py \ + --ignore=tests/singlecard/test_guided_decoding.py \ + --ignore=tests/singlecard/test_ascend_config.py \ + --ignore=tests/singlecard/test_camem.py \ + --ignore=tests/singlecard/core/test_ascend_scheduler.py \ + --ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py + else + pytest -sv tests/multicard/test_ilama_lora_tp2.py + # To avoid oom, we need to run the test in a single process. + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_w4a8_deepseek.py::test_deepseek_W4A8 + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py + fi - name: Run vllm-project/vllm-ascend test on V0 engine if: ${{ github.event_name == 'schedule' }} env: VLLM_USE_V1: 0 - VLLM_USE_MODELSCOPE: True run: | - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py - # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. - # To avoid oom, we need to run the test in a single process. 
- pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 - pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py + if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then + VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py + # guided decoding doesn't work, fix it later + # pytest -sv tests/singlecard/test_guided_decoding.py.py + pytest -sv tests/singlecard/test_camem.py + # test_ascend_config.py should be ran separately because it will regenerate the global config many times. + pytest -sv tests/singlecard/test_ascend_config.py + pytest -sv tests/singlecard/test_prompt_embedding.py + pytest -sv tests/singlecard/ \ + --ignore=tests/singlecard/test_offline_inference.py \ + --ignore=tests/singlecard/test_guided_decoding.py \ + --ignore=tests/singlecard/test_camem.py \ + --ignore=tests/singlecard/test_ascend_config.py \ + --ignore=tests/singlecard/test_prompt_embedding.py \ + --ignore=tests/singlecard/core/test_ascend_scheduler.py \ + --ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py + else + pytest -sv tests/multicard/test_ilama_lora_tp2.py + # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error. + # To avoid oom, we need to run the test in a single process. + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py + fi diff --git a/.github/workflows/vllm_ascend_test_long_term.yaml b/.github/workflows/vllm_ascend_test_long_term.yaml index e249849e19..1b07fbc3be 100644 --- a/.github/workflows/vllm_ascend_test_long_term.yaml +++ b/.github/workflows/vllm_ascend_test_long_term.yaml @@ -17,9 +17,6 @@ name: 'e2e test / long-term-test' on: - schedule: - # Runs at 23:00 UTC (7:00 AM Beijing) every day - - cron: '0 23 * * *' pull_request: types: [ labeled ] @@ -43,7 +40,7 @@ jobs: max-parallel: 2 matrix: os: [linux-arm64-npu-1, linux-arm64-npu-4] - vllm_version: [main, v0.9.1] + vllm_version: [v0.9.1] name: vLLM Ascend long term test runs-on: ${{ matrix.os }} container: @@ -97,13 +94,17 @@ jobs: - name: Run vllm-project/vllm-ascend long term test run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then - # spec decode test - VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py + # v0 spec decode test + # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process + # pytest -sv tests/long_term/spec_decode_v0 
--ignore=tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py + # v1 spec decode test + # TODO: revert me when test_v1_mtp_correctness.py is fixed + VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed - # VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py - VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process - pytest -sv tests/e2e/long_term/spec_decode --ignore=tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py - pytest -sv tests/e2e/long_term/test_accuracy.py + # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_spec_decode.py + # accuracy test single card + pytest -sv tests/long_term/test_accuracy.py else - VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py + # accuracy test multi card + VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py fi diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml index 932b3e59b3..015c88c170 100644 --- a/.github/workflows/vllm_ascend_test_pd.yaml +++ b/.github/workflows/vllm_ascend_test_pd.yaml @@ -17,9 +17,6 @@ name: 'e2e test / pd-disaggregation' on: - schedule: - # Runs at 23:00 UTC (7:00 AM Beijing) every day - - cron: '0 23 * * *' pull_request: types: [ labeled ] @@ -41,7 +38,7 @@ jobs: if: ${{ contains(github.event.pull_request.labels.*.name, 'pd-test') && contains(github.event.pull_request.labels.*.name, 'ready-for-test') || github.event_name == 'schedule' }} strategy: matrix: - vllm_verison: [main, v0.9.1] + vllm_verison: [v0.9.1] name: vLLM Ascend prefilling decoding disaggregation test runs-on: linux-arm64-npu-static-8 @@ -106,3 +103,7 @@ jobs: - name: Run vllm-project/vllm-ascend PD Disaggregation test run: | pytest -sv tests/e2e/pd_disaggreate/test_pd_e2e.py + + - name: Run vllm-project/vllm-ascend PD Disaggregation edge test + run: | + bash tests/e2e/pd_disaggreate/run_edge_case_test.sh \ No newline at end of file diff --git a/README.md b/README.md index 7d0966c8d4..bdc6a1bb00 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l - Software: * Python >= 3.9, < 3.12 * CANN >= 8.1.RC1 - * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250528 + * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250619 * vLLM (the same version as vllm-ascend) ## Getting Started diff --git a/README.zh.md b/README.zh.md index 2d2062a8b4..afe2c76143 100644 --- a/README.zh.md +++ b/README.zh.md @@ -39,7 +39,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP - 软件: * Python >= 3.9, < 3.12 * CANN >= 8.1.RC1 - * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250528 + * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250619 * vLLM (与vllm-ascend版本一致) ## 开始使用 diff --git a/docs/source/faqs.md b/docs/source/faqs.md index 1de3befb2d..be6d689eff 100644 --- a/docs/source/faqs.md +++ b/docs/source/faqs.md @@ -114,7 +114,7 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam - **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. 
For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html). -### 15. Failed to enable NPU graph mode when running DeepSeek? +### 16. Failed to enable NPU graph mode when running DeepSeek? You may encounter the following error if running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and Graph mode only support {32, 64, 128}, **Thus this is not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. The NPU graph mode support on DeepSeek-V2-Lite will be done in the future. And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tensor parallel split, num_heads / num_kv_heads in {32, 64, 128}. @@ -123,3 +123,6 @@ And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tenso [rank0]: RuntimeError: EZ9999: Inner Error! [rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218] ``` + +### 17. Failed to reinstall vllm-ascend from source after uninstalling vllm-ascend? +You may encounter the problem of C compilation failure when reinstalling vllm-ascend from source using pip. If the installation fails, it is recommended to use `python setup.py install` to install, or use `python setup.py clean` to clear the cache. diff --git a/docs/source/installation.md b/docs/source/installation.md index c290f7e5f7..b7eb611a03 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -12,7 +12,7 @@ This document describes how to install vllm-ascend manually. | Software | Supported version | Note | |---------------|----------------------------------|-------------------------------------------| | CANN | >= 8.1.RC1 | Required for vllm-ascend and torch-npu | - | torch-npu | >= 2.5.1.post1.dev20250528 | Required for vllm-ascend | + | torch-npu | >= 2.5.1.post1.dev20250619 | Required for vllm-ascend | | torch | >= 2.5.1 | Required for torch-npu and vllm | You have 2 way to install: @@ -246,8 +246,7 @@ for output in outputs: Then run: ```bash -# Try `export VLLM_USE_MODELSCOPE=true` and `pip install modelscope` -# to speed up download if huggingface is not reachable. +# export VLLM_USE_MODELSCOPE=true to speed up download if huggingface is not reachable. python example.py ``` diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index d4756ef5e1..2a0194209b 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -28,7 +28,6 @@ The following table lists the additional configuration options available in vLLM |-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------| | `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode | | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | -| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. 
| | `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | diff --git a/examples/disaggregate_prefill_v1/README.md b/examples/disaggregate_prefill_v1/README.md new file mode 100644 index 0000000000..544d5ba020 --- /dev/null +++ b/examples/disaggregate_prefill_v1/README.md @@ -0,0 +1,234 @@ +# Disaggregated Prefill-Decode Deployment Guide + +## Overview +This demo document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, uses 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUS for two decode nodes (D1/D2). + +## Prerequisites +- Ascend NPU environment with vLLM 0.9.1 installed +- Network interfaces configured for distributed communication (eg: eth0) +- Model weights located at `/data01/deepseek_r1_w8a8_zhw` + +## Rank table generation +The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. The following command generates a rank table for all nodes with 16 cards prefill and 16 cards decode: + +Run the following command on every node to generate the rank table: +```shell +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \ + --npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16 +``` +Rank table will generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json` + +## Start disaggregated vLLM-ascend service +Execution Sequence +- 4 configured node ip are: 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 +- Start Prefill on Node 1 (P1) +- Start Prefill on Node 2 (P2) +- Start Decode on Node 1 (D1) +- Start Decode on Node 2 (D2) +- Start proxy server on Node1 + +* Run prefill server P1 on first node +```shell +export HCCL_IF_IP=172.19.32.175 # node ip +export GLOO_SOCKET_IFNAME="eth0" # network card name +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run prefill server P2 on second node +```shell +export HCCL_IF_IP=172.19.241.49 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export 
DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", \ + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run decode server d1 on third node +```shell +export HCCL_IF_IP=172.19.123.51 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run decode server d2 on last node +```shell +export HCCL_IF_IP=172.19.190.36 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + 
'{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run proxy server on the first node +```shell +cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1 +python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002 +``` + +* Verification +Check service health using the proxy server endpoint: +```shell +curl http://localhost:1025/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "deepseek", + "prompt": "Who are you?", + "max_tokens": 100, + "temperature": 0 + }' +``` + +* Performance +Test performance with vllm benchmark +```shell +cd /vllm-workspace/vllm/benchmarks +python3 benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --random-input-len 4096 \ + --random-output-len 1536 \ + --num-prompts 256 \ + --ignore-eos \ + --model deepseek \ + --tokenizer /data01/deepseek_r1_w8a8_zhw \ + --host localhost \ + --port 8000 \ + --endpoint /v1/completions \ + --max-concurrency 4 \ + --request-rate 4 +``` \ No newline at end of file diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.py b/examples/disaggregate_prefill_v1/gen_ranktable.py new file mode 100644 index 0000000000..d170f3ba06 --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.py @@ -0,0 +1,120 @@ +import argparse +import json +import os + +import torch.distributed as dist + +from vllm_ascend.soc_info import NPUSocInfo + +parser = argparse.ArgumentParser( + description="Arguments of rank table generator", ) +parser.add_argument("--local-host", type=str, required=True, help="local ip") +parser.add_argument("--prefill-device-cnt", + type=int, + required=True, + help="number of prefill devices") +parser.add_argument("--decode-device-cnt", + type=int, + required=True, + help="number of decode devices") +args = parser.parse_args() +local_host = args.local_host +prefill_device_cnt = args.prefill_device_cnt +decode_device_cnt = args.decode_device_cnt + +print("enter py") + +hccn_tool_path = os.environ.get("HCCN_TOOL_PATH", + "/usr/local/Ascend/driver/tools/hccn_tool") +master_addr = os.environ.get("MASTER_ADDR") +master_port = os.environ.get("MASTER_PORT") +rank = os.environ.get("RANK") +local_rank = os.environ.get("LOCAL_RANK") +# This variable is set by torchrun, +# and is different from WORLD_SIZE in gen_rank_table.sh. 
+world_size = os.environ.get("WORLD_SIZE") +soc_info = NPUSocInfo() + + +def get_cmd_stdout(cmd): + import subprocess + return subprocess.run(cmd, capture_output=True, + shell=True).stdout.decode("utf-8").strip() + + +print(f"local_host: {local_host}") +print("gen ranktable.json") + +num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split( + ":")[1].strip() +num_cards = int(num_cards) +chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split( + "\n")[0].split(":")[1].strip() +chips_per_card = int(chips_per_card) + +# generate local device list for local rank 0, and gather it to all ranks +local_device_list: list[dict[str, str]] = list() +if local_rank == "0": + super_pod_id = "0" + for card_id in range(num_cards): + for chip_id in range(chips_per_card): + device_id = card_id * chips_per_card + chip_id + if soc_info.is_a3: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr" + ).split(":")[1].strip() + super_device_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID" + ).split(":")[1].strip() + super_pod_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"" + ).split(":")[1].strip() + else: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr" + ).split(":")[1].strip() + + device_info = { + "server_id": local_host, + "device_id": str(device_id), + "device_ip": str(device_ip), + } + if soc_info.is_a3: + device_info.update({ + "super_pod_id": str(super_pod_id), + "super_device_id": str(super_device_id) + }) + local_device_list.append(device_info) + +dist.init_process_group(backend=dist.Backend.GLOO) +global_device_list = [None] * dist.get_world_size() +dist.all_gather_object(global_device_list, local_device_list) +global_device_list = [ + device_info for device_list in global_device_list + for device_info in device_list # type: ignore[attr-defined] +] +cnt = 1 +for device_info in global_device_list: # type: ignore[assignment] + device_info["cluster_id"] = str(cnt) + cnt += 1 +assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \ +"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster" +ranktable = { + "version": + "1.2", + "server_count": + str(world_size), + "prefill_device_list": + global_device_list[:prefill_device_cnt], + "decode_device_list": + global_device_list[prefill_device_cnt:prefill_device_cnt + + decode_device_cnt], + "status": + "completed" +} + +if local_rank == '0': + with open("ranktable.json", "w") as f: + json.dump(ranktable, f, indent=4) + + print("gen ranktable.json done") diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.sh b/examples/disaggregate_prefill_v1/gen_ranktable.sh new file mode 100644 index 0000000000..33d4a32e8d --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} + +NPUS_PER_NODE=8 +while [[ $# -gt 0 ]]; do + case "$1" in + --ips) + shift + while [[ $# -gt 0 && ! 
"$1" == --* ]]; do + IPs+=("$1") + shift + done + ;; + --npus-per-node) + shift + NPUS_PER_NODE="$1" + shift + ;; + --network-card-name) + shift + NETWORK_CARD_NAME="$1" + shift + ;; + --prefill-device-cnt) + shift + PREFILL_DEVICE_CNT="$1" + shift + ;; + --decode-device-cnt) + shift + DECODE_DEVICE_CNT="$1" + shift + ;; + esac +done +LOCAL_HOSTS=($(hostname -I)) +LOCAL_HOST="127.0.0.1" +MASTER_ADDR=${IPs[0]} +MASTER_PORT=6657 +NNODES=${#IPs[@]} +NODE_RANK="8" +for i in "${!IPs[@]}"; do + ip="${IPs[$i]}" + for local_host in "${LOCAL_HOSTS[@]}"; do + if [[ "$local_host" == "$ip" ]]; then + LOCAL_HOST=$local_host + NODE_RANK=$i + break 2 + fi + done +done + +if [[ $NODE_RANK == "" ]];then + echo "[Error] para \"NODE_RANK\" must be defined" + exit 1 +fi + +WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES)) +RANKSTART=`expr $NPUS_PER_NODE \* $NODE_RANK` + +echo "========>param:" +echo "LOCAL_HOST": $LOCAL_HOST +echo "WORLD_SIZE: " $WORLD_SIZE +echo "RANKSTART": $RANKSTART +echo "NNODES": $NNODES +echo "NODE_RANK": $NODE_RANK +echo "===============" + +if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then + GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \ + --nproc_per_node 1 \ + --nnodes ${NNODES} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT +fi \ No newline at end of file diff --git a/examples/disaggregate_prefill_v1/run_server.sh b/examples/disaggregate_prefill_v1/run_server.sh new file mode 100644 index 0000000000..37cf6d3aee --- /dev/null +++ b/examples/disaggregate_prefill_v1/run_server.sh @@ -0,0 +1,32 @@ +export HCCL_IF_IP=141.61.39.117 +export GLOO_SOCKET_IFNAME="enp48s3u1u1" +export TP_SOCKET_IFNAME="enp48s3u1u1" +export HCCL_SOCKET_IFNAME="enp48s3u1u1" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table + +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 + +export VLLM_USE_V1=1 + +vllm serve model_path \ + --host 0.0.0.0 \ + --port 20002 \ + --tensor-parallel-size 1\ + --seed 1024 \ + --served-model-name dsv3 \ + --max-model-len 2000 \ + ---max-num-batched-tokens 2000 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": 0, + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3" + }' \ + --additional-config \ + '{"enable_graph_mode": "True"}'\ diff --git a/examples/disaggregate_prefill_v1/toy_proxy_server.py b/examples/disaggregate_prefill_v1/toy_proxy_server.py new file mode 100644 index 0000000000..4478073f74 --- /dev/null +++ b/examples/disaggregate_prefill_v1/toy_proxy_server.py @@ -0,0 +1,261 @@ +# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import itertools +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. 
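+
+    The request body is rewritten for the prefill stage: kv_transfer_params
+    marks the request for remote decode, streaming is disabled, and
+    max_tokens is forced to 1 so the prefill instance does essentially no
+    decoding before the request is handed to a decode instance.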
+ """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py index b06c52d8c5..37a14d5f7b 100644 --- a/examples/dp_offline/data_parallel.py +++ b/examples/dp_offline/data_parallel.py @@ -1,85 +1,226 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py # SPDX-License-Identifier: Apache-2.0 -# usage: -# python examples/offline_inference_data_parallel.py -# we need to have a launcher to create multiple data parallel -# ranks. 
And each rank will create a vLLM instance to process its own prompts. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Usage: +Single node: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 + +Multi-node: + Node 0 (assume the node has ip of 10.99.48.128): + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=0 \ + --master-addr=10.99.48.128 \ + --master-port=13345 + Node 1: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=1 \ + --master-addr=10.99.48.128 \ + --master-port=13345 +""" -import gc import os +from time import sleep + +from vllm import LLM, SamplingParams +from vllm.utils import get_open_port + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="Data Parallel Inference") + parser.add_argument( + "--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path", + ) + parser.add_argument("--dp-size", + type=int, + default=2, + help="Data parallel size") + parser.add_argument("--tp-size", + type=int, + default=2, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") + parser.add_argument("--enforce-eager", + action="store_true", + help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action="store_true", + help="Trust remote code.") + return parser.parse_args() -def main(): - dp_rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - dp_size = int(os.environ['WORLD_SIZE']) - master_addr = os.environ['MASTER_ADDR'] - master_port = os.environ['MASTER_PORT'] - tp_size = 1 - etp_size = 1 - os.environ["VLLM_DP_RANK"] = str(dp_rank) +def main( + model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + GPUs_per_dp_rank, + enforce_eager, + trust_remote_code, +): + os.environ["VLLM_DP_RANK"] = str(global_dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) - os.environ["VLLM_DP_MASTER_IP"] = master_addr - os.environ["VLLM_DP_MASTER_PORT"] = master_port - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join( - str(i) - for i in range(local_rank * tp_size, (local_rank + 1) * tp_size)) + os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip + os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) - import torch - from vllm import LLM, SamplingParams - from vllm.distributed.parallel_state import ( - destroy_distributed_environment, destroy_model_parallel) + # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the + # engine processes. + # Sample prompts. prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", - ] * 4 + ] * 100 - promts_per_rank = len(prompts) // dp_size - start = dp_rank * promts_per_rank - end = start + promts_per_rank - prompts = prompts[start:end] + # with DP, each rank should process different prompts. 
+ # usually all the DP ranks process a full dataset, + # and each rank processes a different part of the dataset. + floor = len(prompts) // dp_size + remainder = len(prompts) % dp_size + + # Distribute prompts into even groups. + def start(rank): + return rank * floor + min(rank, remainder) + + prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)] if len(prompts) == 0: + # if any rank has no prompts to process, + # we need to set a placeholder prompt prompts = ["Placeholder"] - print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts") - num_seqs = len(prompts) + print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts") + + # Create a sampling params object. + # since we are doing data parallel, every rank can have different + # sampling params. here we set different max_tokens for different + # ranks for demonstration. + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=32, + ) - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=4, - min_tokens=4) # Create an LLM. - llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - tensor_parallel_size=tp_size, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=num_seqs, - additional_config={ - 'expert_tensor_parallel_size': etp_size, - 'torchair_graph_config': { - 'enabled': False, - }, - }) + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + trust_remote_code=trust_remote_code, + distributed_executor_backend="mp", + max_model_len=2048, + max_num_batched_tokens=2048, + max_num_seqs=16, + enable_prefix_caching=False, + enable_expert_parallel=True, + gpu_memory_utilization=0.9, + additional_config={ + "ascend_scheduler_config": { + "enabled": True + }, + "torchair_graph_config": { + "enabled": False, + "enable_multistream_shared_expert": False + }, + }, + ) outputs = llm.generate(prompts, sampling_params) - for output in outputs: + # Print the outputs. + for i, output in enumerate(outputs): + if i >= 5: + # print only 5 outputs + break prompt = output.prompt generated_text = output.outputs[0].text - print(f"DP rank {dp_rank}, Prompt: {prompt!r}, " + print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " f"Generated text: {generated_text!r}") - del llm - destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() + # Give engines time to pause their processing loops before exiting. + sleep(1) if __name__ == "__main__": - main() + args = parse_args() + + dp_size = args.dp_size + tp_size = args.tp_size + node_size = args.node_size + node_rank = args.node_rank + + if node_size == 1: + dp_master_ip = "127.0.0.1" + dp_master_port = get_open_port() + else: + dp_master_ip = args.master_addr + dp_master_port = args.master_port + + assert dp_size % node_size == 0, "dp_size should be divisible by node_size" + dp_per_node = dp_size // node_size + + from multiprocessing import Process + + procs = [] + for local_dp_rank, global_dp_rank in enumerate( + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): + proc = Process( + target=main, + args=( + args.model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + args.enforce_eager, + args.trust_remote_code, + ), + ) + proc.start() + procs.append(proc) + exit_code = 0 + for proc in procs: + proc.join(timeout=3000) + if proc.exitcode is None: + print( + f"Killing process {proc.pid} that didn't stop within 5 minutes." 
+ ) + proc.kill() + exit_code = 1 + elif proc.exitcode: + exit_code = proc.exitcode + + exit(exit_code) diff --git a/examples/dp_offline/run_dp.sh b/examples/dp_offline/run_dp.sh index 405df604a4..508d966651 100644 --- a/examples/dp_offline/run_dp.sh +++ b/examples/dp_offline/run_dp.sh @@ -1,19 +1,28 @@ +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* + +ifname="ifname" +local_ip="local ip" +master_addr="master ip" +model_path="path to model ckpt" + export HCCL_IF_IP=${local_ip} export GLOO_SOCKET_IFNAME=${ifname} export TP_SOCKET_IFNAME=${ifname} export HCCL_SOCKET_IFNAME=${ifname} -# dp_size = node_size * dp_per_node -node_size=1 -node_rank=0 -dp_per_node=4 -master_addr=127.0.0.1 -master_port=12345 - -rm -rf ./.torchair_cache/ -rm -rf ./dynamo_* -rm -rf /root/ascend/log/debug/plog/* +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 +# export VLLM_VERSION=0.9.0 -torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \ - --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \ - data_parallel.py +python data_parallel.py \ + --model=${model_path} \ + --dp-size=4 \ + --tp-size=4 \ + --enforce-eager \ + --trust-remote-code \ + --node-size=1 \ + --node-rank=0 \ + --master-addr=${master_addr} \ + --master-port=13345 diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py index d8153e38ca..dd8ee9aeb1 100644 --- a/examples/offline_dualbatch_overlap_npu.py +++ b/examples/offline_dualbatch_overlap_npu.py @@ -20,6 +20,7 @@ def main(): tensor_parallel_size=2, max_model_len=4096, trust_remote_code=True, + enable_expert_parallel=True, additional_config={ "torchair_graph_config": { "enabled": False @@ -27,7 +28,6 @@ def main(): "ascend_scheduler_config": { "enabled": True }, - "expert_tensor_parallel_size": 1 }) # Generate texts from the prompts. 
The output is a list of RequestOutput diff --git a/examples/run_dp_server.sh b/examples/run_dp_server.sh index e2bf4c8158..eb3cfbf510 100644 --- a/examples/run_dp_server.sh +++ b/examples/run_dp_server.sh @@ -1,3 +1,7 @@ +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* + export HCCL_IF_IP=2.0.0.0 export GLOO_SOCKET_IFNAME="enp189s0f0" export TP_SOCKET_IFNAME="enp189s0f0" @@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0" export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 -export VLLM_USE_V1=0 - -export ASCEND_RT_VISIBLE_DEVICES=0,1 -export VLLM_DP_SIZE=2 -export VLLM_DP_RANK=0 -export VLLM_DP_MASTER_IP="2.0.0.0" -export VLLM_DP_MASTER_PORT=40001 -export VLLM_DP_PROXY_IP="2.0.0.0" -export VLLM_DP_PROXY_PORT=30002 -export VLLM_DP_MONITOR_PORT=30003 -export VLLM_HTTP_PORT=20001 +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 vllm serve /data/weights/Qwen2.5-0.5B-Instruct \ --host 0.0.0.0 \ - --port 20001 \ - --tensor-parallel-size 1 \ - --seed 1024 \ + --port 20002 \ --served-model-name Qwen \ - --max-model-len 2000 \ - --max-num-batched-tokens 2000 \ + --data-parallel-size 4 \ + --data-parallel-size-local 4 \ + --data-parallel-address 2.0.0.0 \ + --data-parallel-rpc-port 13389 \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --max-num-seqs 16 \ + --max-model-len 4096 \ + --max-num-batched-tokens 4096 \ + --gpu-memory-utilization 0.9 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ \ No newline at end of file + --enforce-eager \ + --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}' diff --git a/examples/run_dp_with_cached_graph_etp16.sh b/examples/run_dp_with_cached_graph_etp16.sh new file mode 100644 index 0000000000..5f1d3b782b --- /dev/null +++ b/examples/run_dp_with_cached_graph_etp16.sh @@ -0,0 +1,25 @@ +export HCCL_IF_IP=2.0.0.0 +export GLOO_SOCKET_IFNAME="enp189s0f0" +export TP_SOCKET_IFNAME="enp189s0f0" +export HCCL_SOCKET_IFNAME="enp189s0f0" + +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 +# export VLLM_VERSION=0.9.0 + +nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \ + --host 0.0.0.0 \ + --port 20002 \ + --quantization ascend \ + -dp=2 \ + -tp=8 \ + --no-enable-prefix-caching \ + --max-num-seqs 24 \ + --max-model-len 4096 \ + --max-num-batched-tokens 4096 \ + --gpu-memory-utilization 0.96 \ + --trust-remote-code \ + --distributed-executor-backend=mp \ + --additional-config '{"torchair_graph_config":{"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true}}' \ + & > run.log & +disown diff --git a/pyproject.toml b/pyproject.toml index 514b755c32..fc7c7c2c71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,5 +19,7 @@ requires = [ "msgpack", "quart", "numba", + # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470 + "transformers<4.53.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index eadb96f1e9..6d84ec658c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,7 @@ numba # Install torch_npu --pre --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi -torch-npu==2.5.1.post1.dev20250528 +torch-npu==2.5.1.post1.dev20250619 + +# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470 +transformers<4.53.0 diff --git 
a/tests/e2e/pd_disaggreate/run_edge_case_test.sh b/tests/e2e/pd_disaggreate/run_edge_case_test.sh new file mode 100644 index 0000000000..2f5b1bc8e9 --- /dev/null +++ b/tests/e2e/pd_disaggreate/run_edge_case_test.sh @@ -0,0 +1,141 @@ +#!/bin/bash +export LCCL_DETERMINISTIC=1 +export HCCL_DETERMINISTIC=true +export CLOSE_MATMUL_K_SHIFT=1 +export VLLM_USE_V1=1 + +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen2.5-0.5B-Instruct" +) + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Gen ranktable +RANKTABLE_PATH=${GIT_ROOT}/examples/disaggregate_prefill_v1/ranktable.json +if [ -f "$RANKTABLE_PATH" ]; then + rm "$RANKTABLE_PATH" +fi +cd ${GIT_ROOT}/examples/disaggregate_prefill_v1 +LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'` +bash gen_ranktable.sh --ips $LOCAL_HOST --network-card-name enp189s0f0 --prefill-device-cnt 1 --decode-device-cnt 1 +cd - +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="$RANKTABLE_PATH" + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/health > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Helper to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == *"deepseek"* ]]; then + extra_args="--trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Start prefill instance + PREFILL_PORT=8001 + + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \ + --port $PREFILL_PORT \ + --seed 1024 \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Start decode instance + DECODE_PORT=8002 + + # Build the command with or without model-specific args + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \ + --port $DECODE_PORT \ + --seed 1024 \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Wait for all instances to start + echo "Waiting for prefill instance on port $PREFILL_PORT to start..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance on port $DECODE_PORT to start..."
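Note on the escaped `--kv-transfer-config` strings in the script above: each instance receives one JSON object, and the two commands differ only in `kv_role`. A minimal sketch (illustration only, not part of the script) that rebuilds the same payloads with `json.dumps`:

```python
import json

# Illustration only: mirrors the fields embedded in run_edge_case_test.sh above.
# The prefill instance acts as kv_producer, the decode instance as kv_consumer;
# every other field is identical between the two `vllm serve` commands.
def kv_transfer_config(role: str) -> str:
    return json.dumps({
        "kv_connector": "LLMDataDistCMgrConnector",
        "kv_role": role,  # "kv_producer" or "kv_consumer"
        "kv_buffer_device": "npu",
        "kv_parallel_size": "1",
        "kv_port": "20001",
        "engine_id": "0",
        "kv_connector_module_path":
        "vllm_ascend.distributed.llmdatadist_c_mgr_connector",
    })

print(kv_transfer_config("kv_producer"))  # value passed to the prefill instance
print(kv_transfer_config("kv_consumer"))  # value passed to the decode instance
```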
+ wait_for_server $DECODE_PORT + + # Build the command for the proxy server with all the hosts and ports + PROXY_PORT=8192 + PROXY_CMD="python ${GIT_ROOT}/examples/disaggregate_prefill_v1/toy_proxy_server.py --port $PROXY_PORT" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORT}" + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/e2e/pd_disaggreate/test_edge_cases.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" \ No newline at end of file diff --git a/tests/e2e/pd_disaggreate/test_edge_cases.py b/tests/e2e/pd_disaggreate/test_edge_cases.py new file mode 100644 index 0000000000..fe53ddc6db --- /dev/null +++ b/tests/e2e/pd_disaggreate/test_edge_cases.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# This code is from: https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +import os + +import openai + +PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_PORT = os.getenv("PROXY_PORT", None) + +if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: + raise ValueError( + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + +LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 +PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 +SHORT_PROMPT = "Red Hat is " + + +def test_edge_cases(): + # Set the OpenAI API key and base URL + decode_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{DECODE_PORT}/v1", + ) + prefill_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PREFILL_PORT}/v1", + ) + proxy_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PROXY_PORT}/v1", + ) + + # Get the list of models + models = decode_client.models.list() + MODEL = models.data[0].id + + # (1) Check that we can handle a very short prompt, + # less than the length of the block size. + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"SMALL PROMPT: {proxy_response=}") + print(f"SMALL PROMPT: {prefill_response=}") + assert proxy_response == prefill_response + + # (2) Check that we can handle a full prefix cache + # hit on the D worker but not on the P worker. + # (2a): prime the D worker. 
+ completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + decode_response = completion.choices[0].text + # (2b): send via the P/D setup + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + print(f"FULL CACHE HIT: {proxy_response=}") + assert proxy_response == decode_response + + # (3) Check that we can handle a partial prefix cache + # hit on the D worker. + completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"PARTIAL CACHE HIT: {proxy_response=}") + assert proxy_response == prefill_response \ No newline at end of file diff --git a/tests/e2e/long_term/spec_decode/__init__.py b/tests/long_term/spec_decode_v0/__init__.py similarity index 100% rename from tests/e2e/long_term/spec_decode/__init__.py rename to tests/long_term/spec_decode_v0/__init__.py diff --git a/tests/e2e/long_term/spec_decode/conftest.py b/tests/long_term/spec_decode_v0/conftest.py similarity index 100% rename from tests/e2e/long_term/spec_decode/conftest.py rename to tests/long_term/spec_decode_v0/conftest.py diff --git a/tests/e2e/long_term/spec_decode/e2e/__init__.py b/tests/long_term/spec_decode_v0/e2e/__init__.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/__init__.py rename to tests/long_term/spec_decode_v0/e2e/__init__.py diff --git a/tests/e2e/long_term/spec_decode/e2e/conftest.py b/tests/long_term/spec_decode_v0/e2e/conftest.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/conftest.py rename to tests/long_term/spec_decode_v0/e2e/conftest.py diff --git a/tests/long_term/spec_decode_v0/e2e/test_eagle_correctness.py b/tests/long_term/spec_decode_v0/e2e/test_eagle_correctness.py new file mode 100644 index 0000000000..88ee6bb3d2 --- /dev/null +++ b/tests/long_term/spec_decode_v0/e2e/test_eagle_correctness.py @@ -0,0 +1,344 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. 
+ +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, EAGLE would not break the +correctness for the target model outputs. +""" + +import pytest + +from tests.long_term.spec_decode_v0.e2e.conftest import \ + run_equality_correctness_test + +# main model +MAIN_MODEL = "JackFram/llama-68m" + +# speculative model +SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" + +# max. number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 4 + +# precision +# TODO The vLLM here uses float32, but some op on the vllm-ascend +# do not support float32, such as ROPE, When it is fixed, it is +# recommended to change this to float32. +PRECISION = "float16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, +}]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) + + +@pytest.mark.skipif(True, reason="Open it when graph mode ready.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_cuda_graph( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.skipif(True, reason="Open it when preempt ready.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. 
+ """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py b/tests/long_term/spec_decode_v0/e2e/test_medusa_correctness.py similarity index 97% rename from tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py rename to tests/long_term/spec_decode_v0/e2e/test_medusa_correctness.py index e0c2efd7af..48b22f72c4 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/long_term/spec_decode_v0/e2e/test_medusa_correctness.py @@ -41,9 +41,15 @@ import pytest +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py from tests.e2e.long_term.spec_decode.e2e.conftest import \ run_equality_correctness_test from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill +======== +from tests.long_term.spec_decode_v0.e2e.conftest import \ + run_equality_correctness_test +from tests.long_term.spec_decode_v0.utils import maybe_enable_chunked_prefill +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/e2e/test_medusa_correctness.py # main model # lmsys/vicuna-7b-v1.3 was to be used but it's causing diff --git a/tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py b/tests/long_term/spec_decode_v0/e2e/test_mlp_correctness.py similarity index 98% rename from tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py rename to tests/long_term/spec_decode_v0/e2e/test_mlp_correctness.py index 56db617755..4473058a5e 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/long_term/spec_decode_v0/e2e/test_mlp_correctness.py @@ -41,9 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import \ pad_vocab_size # noqa: F401 +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py from tests.e2e.long_term.spec_decode.e2e.conftest import \ run_equality_correctness_test from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill +======== +from tests.long_term.spec_decode_v0.e2e.conftest import \ + run_equality_correctness_test +from tests.long_term.spec_decode_v0.utils import maybe_enable_chunked_prefill +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/e2e/test_mlp_correctness.py # main model MAIN_MODEL = "JackFram/llama-160m" diff --git a/tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py b/tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py rename to tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py diff --git a/tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py b/tests/long_term/spec_decode_v0/e2e/test_ngram_correctness.py similarity index 97% rename from tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py rename to tests/long_term/spec_decode_v0/e2e/test_ngram_correctness.py index b99187fe37..2317a35141 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/long_term/spec_decode_v0/e2e/test_ngram_correctness.py @@ -44,9 +44,15 @@ import pytest +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py from tests.e2e.long_term.spec_decode.e2e.conftest import \ run_equality_correctness_test from tests.e2e.long_term.spec_decode.utils import 
maybe_enable_chunked_prefill +======== +from tests.long_term.spec_decode_v0.e2e.conftest import \ + run_equality_correctness_test +from tests.long_term.spec_decode_v0.utils import maybe_enable_chunked_prefill +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/e2e/test_ngram_correctness.py @pytest.mark.parametrize( diff --git a/tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py b/tests/long_term/spec_decode_v0/test_dynamic_spec_decode.py similarity index 93% rename from tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py rename to tests/long_term/spec_decode_v0/test_dynamic_spec_decode.py index 8e9480ea26..bf18a5ae1c 100644 --- a/tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py +++ b/tests/long_term/spec_decode_v0/test_dynamic_spec_decode.py @@ -27,8 +27,13 @@ from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.top1_proposer import Top1Proposer +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py from tests.e2e.long_term.spec_decode.test_utils import mock_spec_decode_sampler from tests.e2e.long_term.spec_decode.utils import create_batch, mock_worker +======== +from tests.long_term.spec_decode_v0.test_utils import mock_spec_decode_sampler +from tests.long_term.spec_decode_v0.utils import create_batch, mock_worker +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/test_dynamic_spec_decode.py @pytest.mark.parametrize('queue_size', [4]) diff --git a/tests/e2e/long_term/spec_decode/test_multi_step_worker.py b/tests/long_term/spec_decode_v0/test_multi_step_worker.py similarity index 99% rename from tests/e2e/long_term/spec_decode/test_multi_step_worker.py rename to tests/long_term/spec_decode_v0/test_multi_step_worker.py index b3017a987e..96617a34bc 100644 --- a/tests/e2e/long_term/spec_decode/test_multi_step_worker.py +++ b/tests/long_term/spec_decode_v0/test_multi_step_worker.py @@ -29,7 +29,11 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/test_multi_step_worker.py from tests.e2e.long_term.spec_decode.utils import ( +======== +from tests.long_term.spec_decode_v0.utils import ( +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/test_multi_step_worker.py assert_logprobs_dict_allclose, create_batch, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) diff --git a/tests/e2e/long_term/spec_decode/test_ngram_worker.py b/tests/long_term/spec_decode_v0/test_ngram_worker.py similarity index 97% rename from tests/e2e/long_term/spec_decode/test_ngram_worker.py rename to tests/long_term/spec_decode_v0/test_ngram_worker.py index 078a4d2bed..e43bc90e22 100644 --- a/tests/e2e/long_term/spec_decode/test_ngram_worker.py +++ b/tests/long_term/spec_decode_v0/test_ngram_worker.py @@ -22,7 +22,11 @@ from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.top1_proposer import Top1Proposer +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/test_ngram_worker.py from tests.e2e.long_term.spec_decode.utils import ( +======== +from tests.long_term.spec_decode_v0.utils import ( +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/test_ngram_worker.py create_seq_group_metadata_from_prompts, create_worker) diff --git a/tests/e2e/long_term/spec_decode/test_spec_decode_worker.py b/tests/long_term/spec_decode_v0/test_spec_decode_worker.py similarity index 98% rename from 
tests/e2e/long_term/spec_decode/test_spec_decode_worker.py rename to tests/long_term/spec_decode_v0/test_spec_decode_worker.py index 94a1bcf1e7..a0094ca8b6 100644 --- a/tests/e2e/long_term/spec_decode/test_spec_decode_worker.py +++ b/tests/long_term/spec_decode_v0/test_spec_decode_worker.py @@ -35,10 +35,17 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) +<<<<<<<< HEAD:tests/e2e/long_term/spec_decode/test_spec_decode_worker.py from tests.e2e.long_term.spec_decode.test_utils import mock_spec_decode_sampler from tests.e2e.long_term.spec_decode.utils import (create_batch, create_sampler_output_list, create_worker, mock_worker) +======== +from tests.long_term.spec_decode_v0.test_utils import mock_spec_decode_sampler +from tests.long_term.spec_decode_v0.utils import (create_batch, + create_sampler_output_list, + create_worker, mock_worker) +>>>>>>>> upstream/v0.9.1-dev:tests/long_term/spec_decode_v0/test_spec_decode_worker.py from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner from vllm_ascend.worker.worker import NPUWorker diff --git a/tests/e2e/long_term/spec_decode/test_utils.py b/tests/long_term/spec_decode_v0/test_utils.py similarity index 100% rename from tests/e2e/long_term/spec_decode/test_utils.py rename to tests/long_term/spec_decode_v0/test_utils.py diff --git a/tests/e2e/long_term/spec_decode/utils.py b/tests/long_term/spec_decode_v0/utils.py similarity index 100% rename from tests/e2e/long_term/spec_decode/utils.py rename to tests/long_term/spec_decode_v0/utils.py diff --git a/tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py similarity index 54% rename from tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py rename to tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py index 2219a6f552..3b5e1986f2 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py +++ b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py @@ -63,7 +63,10 @@ def test_mtp_correctness( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True) + ref_llm = LLM(model=model_name, + max_model_len=256, + gpu_memory_utilization=0.8, + enforce_eager=True) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -74,6 +77,7 @@ def test_mtp_correctness( "num_speculative_tokens": 1, }, max_model_len=256, + gpu_memory_utilization=0.8, enforce_eager=True) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 @@ -90,3 +94,62 @@ def test_mtp_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) del spec_llm + + +def test_mtp_torchair_correctness( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using mtp speculative decoding. 
+ ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name, + max_model_len=256, + enforce_eager=False, + additional_config={ + "torchair_graph_config": { + "enabled": True + }, + "ascend_scheduler_config": { + "enabled": True + }, + }) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + + spec_llm = LLM(model=model_name, + trust_remote_code=True, + enforce_eager=False, + speculative_config={ + "method": "deepseek_mtp", + "num_speculative_tokens": 1, + }, + additional_config={ + "torchair_graph_config": { + "enabled": True + }, + "ascend_scheduler_config": { + "enabled": True + }, + }) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + del spec_llm diff --git a/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py b/tests/long_term/spec_decode_v1/test_v1_spec_decode.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py rename to tests/long_term/spec_decode_v1/test_v1_spec_decode.py diff --git a/tests/e2e/long_term/test_accuracy.py b/tests/long_term/test_accuracy.py similarity index 100% rename from tests/e2e/long_term/test_accuracy.py rename to tests/long_term/test_accuracy.py diff --git a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py b/tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py similarity index 97% rename from tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py rename to tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py index 27986cb149..3a9068ff6b 100644 --- a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py +++ b/tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py @@ -38,7 +38,7 @@ def run_test(model_name, queue, more_args=None): - model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4" + model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4,enforce_eager=True" if more_args is not None: model_args = f"{model_args},{more_args}" results = lm_eval.simple_evaluate( diff --git a/tests/multicard/test_data_parallel.py b/tests/multicard/test_data_parallel.py index 6c0a20de97..cabeac2846 100644 --- a/tests/multicard/test_data_parallel.py +++ b/tests/multicard/test_data_parallel.py @@ -16,7 +16,6 @@ # """ Compare the outputs of vLLM with and without aclgraph. - Run `pytest tests/multicard/test_data_parallel.py`. 
""" @@ -30,6 +29,7 @@ MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] +@pytest.mark.skipif(True, reason="OPEN ME when dp is supported on A2") @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="Data parallel only support on v1") @pytest.mark.parametrize("model", MODELS) diff --git a/tests/e2e/multicard/test_dynamic_npugraph_batchsize.py b/tests/multicard/test_dynamic_npugraph_batchsize.py similarity index 100% rename from tests/e2e/multicard/test_dynamic_npugraph_batchsize.py rename to tests/multicard/test_dynamic_npugraph_batchsize.py diff --git a/tests/e2e/multicard/test_ilama_lora_tp2.py b/tests/multicard/test_ilama_lora_tp2.py similarity index 83% rename from tests/e2e/multicard/test_ilama_lora_tp2.py rename to tests/multicard/test_ilama_lora_tp2.py index e743141b7a..e61ce250c8 100644 --- a/tests/e2e/multicard/test_ilama_lora_tp2.py +++ b/tests/multicard/test_ilama_lora_tp2.py @@ -1,8 +1,8 @@ import pytest from tests.conftest import VllmRunner -from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, - MODEL_PATH, do_sample) +from tests.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, MODEL_PATH, + do_sample) @pytest.mark.parametrize("distributed_executor_backend", ["mp"]) diff --git a/tests/multicard/test_model_qwen3_w4a8.py b/tests/multicard/test_model_qwen3_w4a8.py new file mode 100644 index 0000000000..e059863638 --- /dev/null +++ b/tests/multicard/test_model_qwen3_w4a8.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +"""Compare the outputs of vLLM when using W4A8 quantization on qwen3 models. + +Run `pytest tests/multicard/test_model_qwen3_w4a8.py`. +""" +import os + +import pytest +from modelscope import snapshot_download # type: ignore +from vllm import LLM, SamplingParams + +MODELS = ["vllm-ascend/Qwen3-8B-W4A8"] +PROMPTS = [ + "Hello, my name is", + "The future of AI is", +] + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="w4a8_dynamic is not supported on v0") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [16]) +def test_qwen3_model_with_w4a8_linear_method(model: str, + max_tokens: int) -> None: + messages = [[{"role": "user", "content": prompt}] for prompt in PROMPTS] + sampling_params = SamplingParams( + max_tokens=max_tokens, + temperature=0.0, + ) + llm = LLM( + model=snapshot_download(model), + max_model_len=1024, + tensor_parallel_size=2, + enforce_eager=True, + quantization="ascend", + ) + vllm_outputs = llm.chat( + messages, + sampling_params, + chat_template_kwargs={"enable_thinking": False}, + ) + golden_outputs = [ + "Hello! 
My name is Qwen, and I'm a large language model developed", + "The future of AI is a topic of great interest and debate, with many possibilities", + ] + assert len(vllm_outputs) == len(golden_outputs) + for vllm_output, golden_output in zip(vllm_outputs, golden_outputs): + assert vllm_output.outputs[0].text == golden_output + print(f"Generated text: {vllm_output.outputs[0].text!r}") diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py similarity index 72% rename from tests/e2e/multicard/test_offline_inference_distributed.py rename to tests/multicard/test_offline_inference_distributed.py index f5ec2c872b..df69ff8e3d 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -23,8 +23,10 @@ import os from unittest.mock import patch +import pytest from modelscope import snapshot_download # type: ignore from vllm import SamplingParams +from vllm.model_executor.models.registry import ModelRegistry from tests.conftest import VllmRunner @@ -83,6 +85,10 @@ def test_models_distributed_topk() -> None: vllm_model.generate(example_prompts, sampling_params) +@pytest.mark.skip( + reason= + "deepseek dbo dose not consider the support on half precision float, will enable this ut after we actually support it" +) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) def test_models_distributed_DeepSeek_dbo(): example_prompts = ["The president of the United States is"] * 41 @@ -94,6 +100,33 @@ def test_models_distributed_DeepSeek_dbo(): tensor_parallel_size=4, distributed_executor_backend="mp", ) as vllm_model: + model_arch = 'DeepseekV2ForCausalLM' + registed_models = ModelRegistry.models + assert registed_models[ + model_arch].module_name == "vllm_ascend.models.deepseek_dbo" + assert registed_models[ + model_arch].class_name == "CustomDeepseekDBOForCausalLM" + vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in") +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_distributed_DeepSeekV3_dbo(): + example_prompts = ["The president of the United States is"] * 41 + dtype = "half" + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + model_arch = 'DeepseekV3ForCausalLM' + registed_models = ModelRegistry.models + assert registed_models[ + model_arch].module_name == "vllm_ascend.models.deepseek_dbo" + assert registed_models[ + model_arch].class_name == "CustomDeepseekDBOForCausalLM" vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/multicard/test_pyhccl_distributed.py b/tests/multicard/test_pyhccl_distributed.py similarity index 100% rename from tests/e2e/multicard/test_pyhccl_distributed.py rename to tests/multicard/test_pyhccl_distributed.py diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/multicard/test_torchair_graph_mode.py similarity index 85% rename from tests/e2e/multicard/test_torchair_graph_mode.py rename to tests/multicard/test_torchair_graph_mode.py index d06ec7de22..96fa92ef4b 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/multicard/test_torchair_graph_mode.py @@ -30,10 +30,13 @@ @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="torchair graph is not supported on v0") -def 
test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"]) +def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch, + VLLM_ASCEND_ENABLE_DBO): with monkeypatch.context() as m: m.setenv("VLLM_USE_MODELSCOPE", "True") m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO) example_prompts = [ "Hello, my name is", @@ -68,10 +71,10 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch): # inaccurate. This will only change if accuracy improves with the # official weights of DeepSeek-V3. golden_results = [ - 'Hello, my name is feasibility伸 spazio debtor添', - 'The president of the United States is begg"""\n杭州风和 bestimm', - 'The capital of France is frequentlyশามalinkAllowed', - 'The future of AI is deleting俯احت怎么样了حراف', + 'Hello, my name is下载早点向前很有่อง', + 'The president of the United States isSender)## physiological Albany', + 'The capital of France is Rocky转角 hospitalizedinterval sparked', + 'The future of AI is её asegο BIOS一扫', ] assert len(golden_results) == len(vllm_output) diff --git a/tests/multicard/test_w4a8_deepseek.py b/tests/multicard/test_w4a8_deepseek.py new file mode 100644 index 0000000000..98a6f0c17e --- /dev/null +++ b/tests/multicard/test_w4a8_deepseek.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
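The torchair-graph test above now runs twice, with dual-batch overlap disabled and enabled, by parametrizing the `VLLM_ASCEND_ENABLE_DBO` environment variable. A minimal sketch of that pattern in isolation; the test name and body below are placeholders, only the environment variable comes from the diff:

```python
import pytest

# Sketch of the parametrized-env-var pattern used in
# test_e2e_deepseekv3_with_torchair above.
@pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"])
def test_with_and_without_dbo(monkeypatch: pytest.MonkeyPatch,
                              VLLM_ASCEND_ENABLE_DBO: str) -> None:
    with monkeypatch.context() as m:
        # Set the toggle before the engine is constructed so each
        # parametrization exercises a different code path.
        m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO)
        # ... build the runner and compare generations, as in the test above ...
```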
+ +import os + +import pytest + +from tests.conftest import VllmRunner + + +@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in") +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="w4a8_dynamic is not supported on v0") +def test_deepseek_W4A8(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + dtype = "bfloat16" + max_tokens = 5 + with VllmRunner( + "vllm-ascend/DeepSeek-R1-w4a8-pruning", + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=True, + quantization="ascend", + enable_expert_parallel=True, + additional_config={ + "torchair_graph_config": { + "enabled": False, + }, + "ascend_scheduler_config": { + "enabled": True, + } + }, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(prompts, max_tokens) + + golden_results = [ + 'Hello, my name is逸研究发现IPPudsimentary', + 'The president of the United States is逸 Ban Corporealistically', + 'The capital of France is逸 Ban Corporealistically', + 'The future of AI is逸 Ban Corporealistically', + ] + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") diff --git a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py b/tests/ops/test_vocabparallelembedding.py similarity index 100% rename from tests/e2e/singlecard/ops/test_vocabparallelembedding.py rename to tests/ops/test_vocabparallelembedding.py diff --git a/tests/e2e/singlecard/__init__.py b/tests/singlecard/__init__.py similarity index 100% rename from tests/e2e/singlecard/__init__.py rename to tests/singlecard/__init__.py diff --git a/tests/e2e/singlecard/compile/__init__.py b/tests/singlecard/compile/__init__.py similarity index 100% rename from tests/e2e/singlecard/compile/__init__.py rename to tests/singlecard/compile/__init__.py diff --git a/tests/e2e/singlecard/compile/test_simple.py b/tests/singlecard/compile/test_simple.py similarity index 100% rename from tests/e2e/singlecard/compile/test_simple.py rename to tests/singlecard/compile/test_simple.py diff --git a/tests/e2e/singlecard/core/__init__.py b/tests/singlecard/core/__init__.py similarity index 100% rename from tests/e2e/singlecard/core/__init__.py rename to tests/singlecard/core/__init__.py diff --git a/tests/e2e/singlecard/core/test_ascend_scheduler.py b/tests/singlecard/core/test_ascend_scheduler.py similarity index 95% rename from tests/e2e/singlecard/core/test_ascend_scheduler.py rename to tests/singlecard/core/test_ascend_scheduler.py index 7d9c1b1ef5..c382ebdf40 100644 --- a/tests/e2e/singlecard/core/test_ascend_scheduler.py +++ b/tests/singlecard/core/test_ascend_scheduler.py @@ -98,7 +98,11 @@ def create_scheduler( ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests - kv_cache_tensors=[], + **({ + "tensors": {} + } if vllm_version_is("0.9.0") else { + "kv_cache_tensors": [] + }), kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, @@ -141,8 +145,8 @@ def create_requests(num_requests: int, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, **({ - "pooling_params": None - } if not vllm_version_is("0.9.1") else {}), + "arrival_time": 0.0 + } if vllm_version_is("0.9.0") else {}), ) 
requests.append(request) return requests @@ -258,9 +262,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + ) scheduler.update_from_output(output, model_runner_output) # Schedule the next step. All three requests are running. @@ -284,10 +286,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) - + ) scheduler.update_from_output(output1, model_runner_output) output2 = scheduler.schedule() assert len(scheduler.running) == 3 @@ -338,10 +337,7 @@ def test_stop_via_update_from_output(): 11]], # First request hits EOS, second continues spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -389,10 +385,7 @@ def test_stop_via_update_from_output(): [13, 14]], # First request hits stop token spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -439,10 +432,7 @@ def test_stop_via_update_from_output(): [13]], # First request exceeds max_tokens spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -484,10 +474,7 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -537,10 +524,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) - + ) scheduler.update_from_output(scheduler_output0, model_runner_output) # Schedule the next step. @@ -557,10 +541,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) - + ) scheduler.update_from_output(scheduler_output1, model_runner_output) @@ -584,6 +565,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): 1. Speculated tokens get scheduled correctly 2. 
Spec decoding stats properly count number of draft and accepted tokens """ + if vllm_version_is("0.9.0"): + return num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) @@ -610,10 +593,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) - + ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -652,10 +632,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) - + ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -750,9 +727,7 @@ def make_output(scheduler: AscendScheduler): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + ) def assert_scheduler_empty(scheduler: AscendScheduler): @@ -769,10 +744,11 @@ def assert_scheduler_empty(scheduler: AscendScheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 + if not vllm_version_is("0.9.0"): + assert len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].num_cached_block) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) @@ -813,4 +789,4 @@ def test_memory_leak(): scheduler.update_from_output(scheduler_output, model_runner_output) # Confirm no memory leak. 
- assert_scheduler_empty(scheduler) + assert_scheduler_empty(scheduler) \ No newline at end of file diff --git a/tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py b/tests/singlecard/core/test_ascend_scheduler_e2e.py similarity index 100% rename from tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py rename to tests/singlecard/core/test_ascend_scheduler_e2e.py diff --git a/tests/e2e/singlecard/ops/__init__.py b/tests/singlecard/ops/__init__.py similarity index 100% rename from tests/e2e/singlecard/ops/__init__.py rename to tests/singlecard/ops/__init__.py diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/singlecard/ops/test_fused_moe.py similarity index 100% rename from tests/e2e/singlecard/ops/test_fused_moe.py rename to tests/singlecard/ops/test_fused_moe.py diff --git a/tests/e2e/singlecard/ops/test_multi_step.py b/tests/singlecard/ops/test_multi_step.py similarity index 100% rename from tests/e2e/singlecard/ops/test_multi_step.py rename to tests/singlecard/ops/test_multi_step.py diff --git a/tests/e2e/singlecard/ops/test_rotary_embedding.py b/tests/singlecard/ops/test_rotary_embedding.py similarity index 100% rename from tests/e2e/singlecard/ops/test_rotary_embedding.py rename to tests/singlecard/ops/test_rotary_embedding.py diff --git a/tests/e2e/singlecard/sample/__init__.py b/tests/singlecard/sample/__init__.py similarity index 100% rename from tests/e2e/singlecard/sample/__init__.py rename to tests/singlecard/sample/__init__.py diff --git a/tests/e2e/singlecard/sample/test_rejection_sampler.py b/tests/singlecard/sample/test_rejection_sampler.py similarity index 100% rename from tests/e2e/singlecard/sample/test_rejection_sampler.py rename to tests/singlecard/sample/test_rejection_sampler.py diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/singlecard/test_aclgraph.py similarity index 100% rename from tests/e2e/singlecard/test_aclgraph.py rename to tests/singlecard/test_aclgraph.py diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py new file mode 100644 index 0000000000..7344371d87 --- /dev/null +++ b/tests/singlecard/test_ascend_config.py @@ -0,0 +1,188 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
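Several of the scheduler-test hunks above gate constructor keywords on the installed vLLM version rather than branching the whole call. A small sketch of that pattern, assuming only `vllm_version_is` from `vllm_ascend.utils` as used in those tests; the helper name below is hypothetical:

```python
from vllm_ascend.utils import vllm_version_is


def kv_cache_config_kwargs(num_blocks: int) -> dict:
    # vLLM 0.9.0 named the field `tensors`; later releases renamed it to
    # `kv_cache_tensors`, so the keyword is chosen at call time and merged
    # with ** instead of duplicating the KVCacheConfig(...) call.
    versioned = ({"tensors": {}} if vllm_version_is("0.9.0")
                 else {"kv_cache_tensors": []})
    return {"num_blocks": num_blocks, **versioned}
```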
+import os + +import pytest + +from tests.conftest import VllmRunner +from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config, + init_ascend_config) + + +def _clean_up_ascend_config(func): + + def wrapper(*args, **kwargs): + clear_ascend_config() + func(*args, **kwargs) + clear_ascend_config() + + return wrapper + + +@_clean_up_ascend_config +def test_run_without_ascend_config(): + with VllmRunner("facebook/opt-125m"): + ascend_config = get_ascend_config() + + assert not ascend_config.torchair_graph_config.enabled + assert not ascend_config.torchair_graph_config.use_cached_graph + assert ascend_config.torchair_graph_config.graph_batch_sizes == [] + assert not ascend_config.torchair_graph_config.graph_batch_sizes_init + assert not ascend_config.ascend_scheduler_config.enabled + + +@_clean_up_ascend_config +def test_run_with_ascend_config(): + if os.getenv("VLLM_USE_V1") == "0": + pytest.skip("graph only works on v1") + + input_additional_config_1 = { + "torchair_graph_config": { + # torchair graph only works with deepseek. The e2e test should be added + # in multicard test with deepseek models. + "enabled": False, + "use_cached_graph": True, + "graph_batch_sizes": [1, 2, 4, 8], + "graph_batch_sizes_init": False, + "enable_multistream_moe": True, + "enable_multistream_mla": True, + }, + "ascend_scheduler_config": { + "enabled": True, + "enable_chunked_prefill": True, + }, + } + + # check passed with eager mode + with VllmRunner("facebook/opt-125m", + enforce_eager=True, + additional_config=input_additional_config_1): + ascend_config = get_ascend_config() + + assert not ascend_config.torchair_graph_config.enabled + assert ascend_config.torchair_graph_config.use_cached_graph + assert ascend_config.torchair_graph_config.graph_batch_sizes == [ + 1, 2, 4, 8 + ] + assert not ascend_config.torchair_graph_config.graph_batch_sizes_init + assert ascend_config.torchair_graph_config.enable_multistream_mla + assert ascend_config.torchair_graph_config.enable_multistream_moe + assert ascend_config.ascend_scheduler_config.enabled + assert ascend_config.ascend_scheduler_config.enable_chunked_prefill + + +@_clean_up_ascend_config +def test_ascend_config_init_error(): + # ascend_config should be initialized first + with pytest.raises(RuntimeError): + _ = get_ascend_config() + + +@_clean_up_ascend_config +def test_ascend_config_load_error(): + if os.getenv("VLLM_USE_V1") == "0": + pytest.skip("graph only works on v1") + # graph_batch_sizes should be list. + with pytest.raises(TypeError): + input_additional_config_fake_1 = { + "torchair_graph_config": { + "graph_batch_sizes": "fake_size", + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_1): + pass + + # graph_batch_sizes_init should not be True when graph_batch_sizes is not empty. + with pytest.raises(ValueError): + input_additional_config_fake_2 = { + "torchair_graph_config": { + "graph_batch_sizes": [1, 2, 4, 8], + "graph_batch_sizes_init": True, + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_2): + pass + + # torchair graph only works with deepseek. 
+ with pytest.raises(NotImplementedError): + input_additional_config_fake_2 = { + "torchair_graph_config": { + "enabled": True, + }, + } + with VllmRunner("facebook/opt-125m", + enforce_eager=False, + additional_config=input_additional_config_fake_2): + pass + + # torchair graph should not be enabled with eager mode + with pytest.raises(RuntimeError): + input_additional_config_fake_3 = { + "torchair_graph_config": { + "enabled": True, + }, + } + with VllmRunner("facebook/opt-125m", + enforce_eager=True, + additional_config=input_additional_config_fake_3): + pass + + +@_clean_up_ascend_config +def test_check_ascend_config_v0(): + if os.getenv("VLLM_USE_V1") == "1": + pytest.skip("graph only works on v1, this is the test for v0") + with pytest.raises(NotImplementedError): + input_additional_config_fake_1 = { + "torchair_graph_config": { + "enabled": True, + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_1): + pass + + +@_clean_up_ascend_config +def test_ascend_config_refresh(): + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + # set additional_config with none + init_ascend_config(vllm_config) + + input_additional_config = { + "torchair_graph_config": { + "enabled": False, + "use_cached_graph": True, + "graph_batch_sizes": [1, 2, 4, 8], + "graph_batch_sizes_init": False, + }, + "refresh": True, + } + + # refresh ascend config + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config): + ascend_config = get_ascend_config() + + assert not ascend_config.torchair_graph_config.enabled + assert ascend_config.torchair_graph_config.use_cached_graph + assert ascend_config.torchair_graph_config.graph_batch_sizes == [ + 1, 2, 4, 8 + ] + assert not ascend_config.torchair_graph_config.graph_batch_sizes_init diff --git a/tests/e2e/singlecard/test_camem.py b/tests/singlecard/test_camem.py similarity index 100% rename from tests/e2e/singlecard/test_camem.py rename to tests/singlecard/test_camem.py diff --git a/tests/e2e/singlecard/test_chunked.py b/tests/singlecard/test_chunked.py similarity index 100% rename from tests/e2e/singlecard/test_chunked.py rename to tests/singlecard/test_chunked.py diff --git a/tests/e2e/singlecard/test_guided_decoding.py b/tests/singlecard/test_guided_decoding.py similarity index 100% rename from tests/e2e/singlecard/test_guided_decoding.py rename to tests/singlecard/test_guided_decoding.py diff --git a/tests/e2e/singlecard/test_ilama_lora.py b/tests/singlecard/test_ilama_lora.py similarity index 100% rename from tests/e2e/singlecard/test_ilama_lora.py rename to tests/singlecard/test_ilama_lora.py diff --git a/tests/e2e/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py similarity index 98% rename from tests/e2e/singlecard/test_offline_inference.py rename to tests/singlecard/test_offline_inference.py index de69612279..cd65a24969 100644 --- a/tests/e2e/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -43,6 +43,10 @@ ] os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +QUANTIZATION_MODELS = [ + "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8", +] + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half", "float16"]) diff --git a/tests/e2e/singlecard/test_profile_execute_duration.py b/tests/singlecard/test_profile_execute_duration.py similarity index 100% rename from tests/e2e/singlecard/test_profile_execute_duration.py rename to 
tests/singlecard/test_profile_execute_duration.py diff --git a/tests/e2e/singlecard/test_prompt_embedding.py b/tests/singlecard/test_prompt_embedding.py similarity index 100% rename from tests/e2e/singlecard/test_prompt_embedding.py rename to tests/singlecard/test_prompt_embedding.py diff --git a/tests/e2e/singlecard/test_pyhccl.py b/tests/singlecard/test_pyhccl.py similarity index 100% rename from tests/e2e/singlecard/test_pyhccl.py rename to tests/singlecard/test_pyhccl.py diff --git a/tests/e2e/singlecard/test_sampler.py b/tests/singlecard/test_sampler.py similarity index 100% rename from tests/e2e/singlecard/test_sampler.py rename to tests/singlecard/test_sampler.py diff --git a/tests/e2e/singlecard/test_scheduler.py b/tests/singlecard/test_scheduler.py similarity index 95% rename from tests/e2e/singlecard/test_scheduler.py rename to tests/singlecard/test_scheduler.py index b3adf945bf..8021f0306c 100644 --- a/tests/e2e/singlecard/test_scheduler.py +++ b/tests/singlecard/test_scheduler.py @@ -31,7 +31,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm_ascend.core.scheduler import AscendScheduler -from vllm_ascend.utils import vllm_version_is EOS_TOKEN_ID = 50256 @@ -131,9 +130,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - **({ - "pooling_params": None - } if not vllm_version_is("0.9.1") else {}), ) requests.append(request) return requests @@ -241,10 +237,7 @@ def test_stop_via_update_from_output(): 11]], # First request hits EOS, second continues spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -290,10 +283,7 @@ def test_stop_via_update_from_output(): [13, 14]], # First request hits stop token spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -338,10 +328,7 @@ def test_stop_via_update_from_output(): [13]], # First request exceeds max_tokens spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) @@ -382,10 +369,7 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}, - **({ - "pooler_output": [] - } if not vllm_version_is("0.9.1") else {})) + prompt_logprobs_dict={}) scheduler.update_from_output(scheduler_output, model_output) diff --git a/tests/ut/fake_weight/config.json b/tests/ut/fake_weight/config.json deleted file mode 100644 index b3fb716a30..0000000000 --- a/tests/ut/fake_weight/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "_name_or_path": "facebook/opt-125m", - "activation_dropout": 0.0, - "activation_function": "relu", - "architectures": [ - "OPTForCausalLM" - ], - "attention_dropout": 0.0, - "bos_token_id": 2, - "do_layer_norm_before": true, - "dropout": 0.1, - "eos_token_id": 2, - "ffn_dim": 3072, - "hidden_size": 768, - "init_std": 0.02, - "layerdrop": 0.0, - "max_position_embeddings": 2048, - "model_type": "opt", - "num_attention_heads": 12, - "num_hidden_layers": 12, - "pad_token_id": 1, - "prefix": "", - "torch_dtype": "float16", - 
"transformers_version": "4.21.0.dev0", - "use_cache": true, - "vocab_size": 50272, - "word_embed_proj_dim": 768 -} diff --git a/tests/ut/kv_connector/test_llmdatadist_connector.py b/tests/ut/kv_connector/test_llmdatadist_connector.py new file mode 100644 index 0000000000..94650f43e9 --- /dev/null +++ b/tests/ut/kv_connector/test_llmdatadist_connector.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +from tests.ut.kv_connector.utils import (create_request, create_scheduler, + create_vllm_config) +from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \ + LLMDataDistCMgrConnectorMetadata + + +def test_basic_inferface(): + """Unit test for basic LLMDataDistCMgrConnector interface functionality.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]): + assert block_id == block.block_id diff --git a/tests/ut/kv_connector/test_remote_decode_lifecycle.py b/tests/ut/kv_connector/test_remote_decode_lifecycle.py new file mode 100644 index 0000000000..d321490f65 --- /dev/null +++ b/tests/ut/kv_connector/test_remote_decode_lifecycle.py @@ -0,0 +1,123 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/conftest.py +# + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from tests.ut.kv_connector.utils import (assert_scheduler_empty, + create_model_runner_output, + create_request, create_scheduler, + create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a Remote Decode request.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Ensure the request is finished after 1 tokens. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + output = engine_core_outputs[0].outputs[0] + assert output.finish_reason == FinishReason.LENGTH + assert output.kv_transfer_params is not None + + # Request freed in Scheduler and blocks should be freed + assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_prefix_cache_lifecycle(): + """Test that remote decode params still works with a prefix cache hit.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 3 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote_a = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote_a) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote_a], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + ##################### + # Actual Test: confirm we send all blocks. + + # Send the KV Transfer. + NUM_EXTERNAL_FULL_BLOCKS -= 1 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote_b = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote_b) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote_b]) + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco[0].outputs[0].kv_transfer_params + # Ensure we send all block ids, even if there is a cache hit. + assert (len( + kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + + 1)) + + scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py new file mode 100644 index 0000000000..c8629d5993 --- /dev/null +++ b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py @@ -0,0 +1,242 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/conftest.py +# +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from tests.ut.kv_connector.utils import (assert_scheduler_empty, + create_model_runner_output, + create_request, create_scheduler, + create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a remote prefill.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed/scheduled toks ... + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks < + START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] + for block in blocks: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert not engine_core_outputs or not engine_core_outputs[0].outputs + + # STEP (2): + # (2a): schedule(): nothing happens! + scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. 
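+ # After the KVs are received and the request is scheduled, every full
+ # block should carry a prefix-cache hash; only the trailing half block
+ # stays unhashed, so the count should match NUM_EXTERNAL_FULL_BLOCKS.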
+ num_hashed_blocks = 0 + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] + for block in blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_no_spurious_prefix_caching(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert len(scheduler.waiting) == 1 + + remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_remote.request_id] + + # Remote blocks should not be cached. + for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_full_block_prompt(): + """Test that we handle a prompt that is the full block size.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Initialize a recv. + scheduler_output = scheduler.schedule() + # All blocks should be allocated. + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # STEP (2): Recv. 
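+ # The request stays in the waiting queue while the transfer is in flight;
+ # the model runner signals completion through finished_recving.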
+ scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # # STEP (3): Run as usual. + scheduler_output = scheduler.schedule() + + # We need to recompute the final token of the prompt to generate + # the first new token, so we should not have a new block. + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_TOKENS - 1) + assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + + model_runner_output = create_model_runner_output([request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py new file mode 100644 index 0000000000..a8f65ff429 --- /dev/null +++ b/tests/ut/kv_connector/utils.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +import os +from typing import Any, Optional + +import torch +from vllm import SamplingParams +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 +os.environ["VLLM_USE_V1"] = "1" + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_kv_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. 
The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 1024, + block_size: int = 128, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="LLMDataDistCMgrConnector", + kv_role="kv_both", + kv_connector_module_path= + "vllm_ascend.distributed.llmdatadist_connector_v1_a3") + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float16, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 128, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, +) -> Request: + """Make dummy request for testing.""" + + kv_transfer_params: Optional[dict[str, Any]] = None + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = dict(do_remote_prefill=False, + do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = dict(do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list( + range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + remote_tp_size=1) + + max_tokens = 1 if do_remote_decode else max_tokens + sampling_params = SamplingParams(max_tokens=max_tokens) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. 
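+ # One sampled token is emitted per request (EOS when use_eos is set) so
+ # the scheduler can either continue or finish each request.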
+ req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/tests/ut/ops/test_expert_load_balancer.py b/tests/ut/ops/test_expert_load_balancer.py deleted file mode 100644 index 3b7a69ddd4..0000000000 --- a/tests/ut/ops/test_expert_load_balancer.py +++ /dev/null @@ -1,147 +0,0 @@ -# fused moe ops test will hit the infer_schema error, we need add the patch -# here to make the test pass. -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - -import json -import unittest -from typing import List, TypedDict -from unittest import mock - -import torch - -from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer - - -class Device(TypedDict): - device_id: int - device_expert: List[int] - - -class Layer(TypedDict): - layer_id: int - device_count: int - device_list: List[Device] - - -class MockData(TypedDict): - moe_layer_count: int - layer_list: List[Layer] - - -MOCK_DATA: MockData = { - "moe_layer_count": - 1, - "layer_list": [{ - "layer_id": - 0, - "device_count": - 2, - "device_list": [{ - "device_id": 0, - "device_expert": [7, 2, 0, 3, 5] - }, { - "device_id": 1, - "device_expert": [6, 1, 4, 7, 2] - }] - }] -} - - -class TestExpertLoadBalancer(unittest.TestCase): - - def setUp(self): - json_file = "expert_map.json" - with open(json_file, 'w') as f: - json.dump(MOCK_DATA, f) - - self.expert_load_balancer = ExpertLoadBalancer(json_file, - global_expert_num=8) - - def test_init(self): - - self.assertIsInstance(self.expert_load_balancer.expert_map_tensor, - torch.Tensor) - self.assertEqual(self.expert_load_balancer.layers_num, - MOCK_DATA["moe_layer_count"]) - self.assertEqual(self.expert_load_balancer.ranks_num, - MOCK_DATA["layer_list"][0]["device_count"]) - - def test_generate_index_dicts(self): - tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]) - result = self.expert_load_balancer.generate_index_dicts(tensor_2d) - expected_result = [{ - 7: 0, - 2: 1, - 0: 2, - 3: 3, - 5: 4 - }, { - 6: 5, - 1: 6, - 4: 7, - 7: 8, - 2: 9 - }] - self.assertEqual(result, expected_result) - - def test_generate_expert_placement_map(self): - expert_placement_map = self.expert_load_balancer.generate_expert_placement_map( - ) - self.assertEqual(expert_placement_map.shape, - (self.expert_load_balancer.layers_num, - self.expert_load_balancer.ranks_num, 8)) - self.assertTrue(torch.all(expert_placement_map >= -1)) - - def test_generate_log2phy_expert_map(self): - layer_id = 0 - log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map( - layer_id) - self.assertEqual(log2phy_map.shape, - (self.expert_load_balancer.ranks_num, 8)) - self.assertTrue(torch.all(log2phy_map >= -1)) - - @mock.patch("torch_npu.npu._lazy_init") - @mock.patch("torch.npu.current_device", return_value="cpu") - def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init): - layer_id = 0 - rank_id = 0 - rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map( - layer_id, rank_id) - self.assertEqual(rank_local_expert_num, 5) - 
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0], - dtype=torch.int32).to( - rank_expert_map.device) - self.assertTrue(rank_expert_map.equal(expected_tensor)) - - rank_id = 1 - rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map( - layer_id, rank_id) - expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3], - dtype=torch.int32).to( - rank_expert_map.device) - self.assertTrue(rank_expert_map.equal(expected_tensor)) - - def test_get_rank_log2phy_map(self): - layer_id = 0 - rank_id = 0 - log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( - layer_id, rank_id) - expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0], - dtype=torch.int32).to( - log2phy_map.device) - self.assertTrue(log2phy_map.equal(expected_tensor)) - - rank_id = 1 - log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( - layer_id, rank_id) - expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8], - dtype=torch.int32).to( - log2phy_map.device) - self.assertTrue(log2phy_map.equal(expected_tensor)) - - def test_get_global_redundant_expert_num(self): - redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num( - ) - expected_redundant_expert_num = len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \ - MOCK_DATA["layer_list"][0]["device_count"] - 8 - self.assertEqual(redundant_expert_num, expected_redundant_expert_num) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py deleted file mode 100644 index 5ec4dd72cc..0000000000 --- a/tests/ut/test_ascend_config.py +++ /dev/null @@ -1,244 +0,0 @@ -import os -import unittest -from unittest import mock - -from transformers import PretrainedConfig -from vllm.config import ModelConfig, VllmConfig - -from vllm_ascend.ascend_config import (check_ascend_config, - clear_ascend_config, get_ascend_config, - init_ascend_config) - - -class TestAscendConfig(unittest.TestCase): - - @staticmethod - def _clean_up_ascend_config(func): - - def wrapper(*args, **kwargs): - clear_ascend_config() - func(*args, **kwargs) - clear_ascend_config() - - return wrapper - - @_clean_up_ascend_config - def test_init_ascend_config_without_additional_config(self): - test_vllm_config = VllmConfig() - # No additional config given, check the default value here. 
- ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(ascend_config.expert_tensor_parallel_size, 0) - self.assertIsNone(ascend_config.expert_map_path) - - torchair_graph_config = ascend_config.torchair_graph_config - self.assertFalse(torchair_graph_config.enabled) - self.assertFalse(torchair_graph_config.use_cached_graph) - self.assertEqual(torchair_graph_config.graph_batch_sizes, []) - self.assertFalse(torchair_graph_config.graph_batch_sizes_init) - self.assertFalse(torchair_graph_config.enable_multistream_mla) - self.assertFalse(torchair_graph_config.enable_multistream_moe) - self.assertTrue(torchair_graph_config.enable_view_optimize) - self.assertFalse(torchair_graph_config.enable_kv_nz) - - ascend_scheduler_config = ascend_config.ascend_scheduler_config - self.assertFalse(ascend_scheduler_config.enabled) - - @_clean_up_ascend_config - def test_init_ascend_config_with_additional_config(self): - test_vllm_config = VllmConfig() - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": True, - "graph_batch_sizes": [1, 2, 4], - "graph_batch_sizes_init": False, - "enable_multistream_mla": True, - "enable_multistream_moe": True, - "enable_view_optimize": True, - "enable_kv_nz": True - }, - "ascend_scheduler_config": { - "enabled": True - }, - "expert_tensor_parallel_size": 1, - "expert_map_path": "test_expert_map_path", - "refresh": True - } - ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(ascend_config.expert_tensor_parallel_size, 1) - self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") - - torchair_graph_config = ascend_config.torchair_graph_config - self.assertTrue(torchair_graph_config.enabled) - self.assertTrue(torchair_graph_config.use_cached_graph) - self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4]) - self.assertFalse(torchair_graph_config.graph_batch_sizes_init) - self.assertTrue(torchair_graph_config.enable_multistream_mla) - self.assertTrue(torchair_graph_config.enable_multistream_moe) - self.assertTrue(torchair_graph_config.enable_view_optimize) - self.assertTrue(torchair_graph_config.enable_kv_nz) - - ascend_scheduler_config = ascend_config.ascend_scheduler_config - self.assertTrue(ascend_scheduler_config.enabled) - - @_clean_up_ascend_config - def test_init_ascend_config_with_refresh(self): - test_vllm_config = VllmConfig() - ascend_config = init_ascend_config(test_vllm_config) - self.assertFalse(ascend_config.torchair_graph_config.enabled) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - ascend_config = init_ascend_config(test_vllm_config) - self.assertFalse(ascend_config.torchair_graph_config.enabled) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True, - } - ascend_config = init_ascend_config(test_vllm_config) - self.assertTrue(ascend_config.torchair_graph_config.enabled) - - @_clean_up_ascend_config - def test_init_ascend_config_with_wrong_input(self): - test_vllm_config = VllmConfig() - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - "graph_batch_sizes": "fake_size", - }, - "refresh": True, - } - with self.assertRaises(TypeError): - init_ascend_config(test_vllm_config) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "graph_batch_sizes": [1, 2, 4, 8], - "graph_batch_sizes_init": True, - }, - "refresh": True, - } - with 
self.assertRaises(ValueError): - init_ascend_config(test_vllm_config) - - @_clean_up_ascend_config - def test_get_ascend_config(self): - test_vllm_config = VllmConfig() - ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(get_ascend_config(), ascend_config) - - @_clean_up_ascend_config - def test_get_ascend_config_without_init(self): - with self.assertRaises(RuntimeError): - get_ascend_config() - - @_clean_up_ascend_config - def test_clear_ascend_config(self): - test_vllm_config = VllmConfig() - ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(get_ascend_config(), ascend_config) - clear_ascend_config() - with self.assertRaises(RuntimeError): - get_ascend_config() - - @_clean_up_ascend_config - def test_check_ascend_config_pass(self): - test_vllm_config = VllmConfig() - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - # For V1 engine - with mock.patch.dict(os.environ, {"VLLM_USE_V1": "1"}): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - @_clean_up_ascend_config - def test_check_ascend_config_wrong_case(self): - test_vllm_config = VllmConfig() - # For V0 engine - with mock.patch.dict(os.environ, {"VLLM_USE_V1": "0"}): - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "ascend_scheduler_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, True) - # For V1 engine - with mock.patch.dict(os.environ, {"VLLM_USE_V1": "1"}): - # torchair + eager mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - enforce_eager = True - check_ascend_config(test_vllm_config, enforce_eager) - # torchair + non deepseek model - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - model_path = os.path.join(os.path.dirname(__file__), - "fake_weight") - fake_model_config = ModelConfig(model=model_path) - fake_model_config.hf_config = PretrainedConfig() - fake_model_config.hf_config.model_type = "llama" - test_vllm_config.model_config = fake_model_config - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - # aclgraph + deepseek model - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True - } - model_path = os.path.join(os.path.dirname(__file__), - "fake_weight") - fake_model_config = ModelConfig(model=model_path) - fake_model_config.hf_config = PretrainedConfig() - fake_model_config.hf_config.model_type = "deepseek" - test_vllm_config.model_config = fake_model_config - init_ascend_config(test_vllm_config) - 
check_ascend_config(test_vllm_config, False) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py deleted file mode 100644 index fdffa2a0fd..0000000000 --- a/tests/ut/worker/test_worker_v1.py +++ /dev/null @@ -1 +0,0 @@ -# placeholder diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d8b87c6952..eb5b09c4fb 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -36,11 +36,12 @@ def __init__(self, vllm_config): self.ascend_scheduler_config = AscendSchedulerConfig( ascend_scheduler_config) - self.expert_tensor_parallel_size = int( - additional_config.get("expert_tensor_parallel_size", 0)) self.expert_map_path = additional_config.get("expert_map_path", None) + self.dynamic_eplb = additional_config.get("dynamic_eplb", False) self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) + self.enable_weight_nz_layout = additional_config.get( + "enable_weight_nz_layout", False) class TorchairGraphConfig: @@ -138,6 +139,12 @@ def check_ascend_config(vllm_config, enforce_eager): else: # torchair_graph case if ascend_config.torchair_graph_config.enabled: + # torchair_graph is not supported for V1 without mla currently. + if envs.VLLM_MLA_DISABLE: + logger.warning( + "Torchair graph mode is still experimental and not supported for V1 without mla currently, " + "it has been disabled automatically.") + ascend_config.torchair_graph_config.enabled = False # torchair_graph is supported for deepseek model only currently. if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py new file mode 100644 index 0000000000..75fd71c859 --- /dev/null +++ b/vllm_ascend/ascend_forward_context.py @@ -0,0 +1,79 @@ +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional + +import torch +from vllm.config import VllmConfig +from vllm.distributed import get_dp_group +from vllm.forward_context import get_forward_context, set_forward_context + + +class FusedMoEState(Enum): + AllGather = 0 + All2All = 1 + MC2 = 2 + + +# TODO(zzzzwwjj): add soc_version to choose branch +def get_fused_moe_state(ep_size: int, with_prefill: bool): + if ep_size == 1: + return FusedMoEState.AllGather + # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. + elif ep_size < 16 or with_prefill: + return FusedMoEState.All2All + else: + return FusedMoEState.MC2 + + +@contextmanager +def set_ascend_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False): + """A context manager that stores the current forward context, + can be attention metadata, etc. + We add some additional param into forward_context. 
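+ Extra attributes populated on the context here: with_prefill,
+ fused_moe_state, in_profile_run, capturing and max_tokens_across_dp.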
+ """ + with set_forward_context(attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + forward_context = get_forward_context() + forward_context.with_prefill = with_prefill + + ep_size = torch.distributed.get_world_size( + ) if vllm_config.parallel_config.enable_expert_parallel else 1 + + fused_moe_state = get_fused_moe_state(ep_size, with_prefill) + + forward_context.fused_moe_state = fused_moe_state + + forward_context.in_profile_run = in_profile_run + + # NOTE: This cannot be set using set_forward_context + # due to multiple warmups before actual capturing + forward_context.capturing = False + + dp_world_size = get_dp_group().world_size + if dp_world_size > 1 and forward_context.dp_metadata is not None: + forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( + ) + elif num_tokens is not None: + forward_context.max_tokens_across_dp = num_tokens + elif attn_metadata is not None: + if hasattr(attn_metadata, 'num_actual_tokens'): + forward_context.max_tokens_across_dp = attn_metadata.num_actual_tokens + else: + forward_context.max_tokens_across_dp = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + else: + forward_context.max_tokens_across_dp = None + + try: + yield + finally: + pass diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e6a2376786..3417bb87fb 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,12 +24,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.ops.attention import vanilla_chunked_prefill +from vllm_ascend.utils import get_graph_params class AscendAttentionBackend(AttentionBackend): @@ -114,6 +118,7 @@ class AscendMetadata: query_start_loc: torch.Tensor query_lens: torch.Tensor seq_lens: torch.Tensor + seq_lens_list: list # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None # (num_tokens,). The indices of the token slots that input tokens will be @@ -133,7 +138,7 @@ class AscendMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
- with_prefill_across_dp: bool = False + enable_dbo_across_dp: bool = False class AscendAttentionMetadataBuilder: @@ -149,23 +154,26 @@ def build(self, num_reqs, num_actual_tokens, max_query_len, - common_prefix_len, - with_prefill_across_dp: bool = False): + common_attn_metadata: CommonAttentionMetadata, + enable_dbo_across_dp: bool = False, + *args, + **kwargs): block_table = self.runner.input_batch.block_table[0].get_device_tensor( ) block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # TODO: Refactor these two param to common metadata in runners, + # preparing for the hybrid KV groups feature + query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens + seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list + + slot_mapping = self.runner.slot_mapping[:num_actual_tokens] attn_mask = self.runner.attn_mask attn_state = self.runner.attn_state - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - non_blocking=True) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -173,11 +181,40 @@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, + seq_lens_list=seq_lens_list, max_query_len=max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - with_prefill_across_dp=with_prefill_across_dp) + enable_dbo_across_dp=enable_dbo_across_dp) + return attn_metadata + + def build_dummy_metadata(self, num_actual_tokens, num_reqs, + num_scheduled_tokens, attn_state): + if attn_state == AscendAttentionState.DecodeOnly: + # NOTE: We only need to pay attention to seq_lens_list and block_table here + common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] * + num_reqs) + + block_table = self.runner.input_batch.block_table[0].block_table + block_table[:num_reqs, 0] = torch.arange(1, + num_reqs + 1, + device=block_table.device, + dtype=block_table.dtype) + + attn_metadata = self.build( + num_reqs=num_reqs, + num_actual_tokens=num_actual_tokens, + max_query_len=num_scheduled_tokens.max(), + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly state" + ) + + attn_metadata.attn_state = attn_state return attn_metadata @@ -217,6 +254,10 @@ def __init__( self.key_cache = None self.value_cache = None + vllm_config = get_current_vllm_config() + self.full_graph = vllm_config.compilation_config.full_cuda_graph + self.block_size = vllm_config.cache_config.block_size + def forward( self, layer: AttentionLayer, @@ -228,21 +269,7 @@ def forward( output: Optional[torch.Tensor] = None, trace_flag: bool = True, ) -> torch.Tensor: - """Forward pass with Ascend attention. 
- Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache: shape = [2, num_blocks, block_size, - num_kv_heads, head_size] - key_cache = [num_blocks, block_size, - num_kv_heads, head_size] - value_cache = [num_blocks, block_size, - num_kv_heads, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [batch_size * seq_len, num_heads, head_size] - """ + """Forward pass with Ascend attention.""" num_tokens = query.shape[0] if output is None: output = torch.empty(num_tokens, @@ -275,7 +302,7 @@ def forward( # TODO: Remove this contiguous in the future. value = value.contiguous() - if kv_cache.numel() > 0: + if len(kv_cache) > 0: if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping @@ -307,11 +334,13 @@ def forward( assert attn_metadata is not None assert attn_metadata.attn_mask is not None compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] torch_npu._npu_flash_attention_qlens( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - block_table=attn_metadata.block_tables, + block_table=block_table, mask=compress_mask, seq_len=attn_metadata.query_lens, context_lens=attn_metadata.seq_lens, @@ -320,16 +349,92 @@ def forward( scale_value=self.scale, out=output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) + if self.full_graph: + graph_params = get_graph_params() + q = query.view(num_tokens, -1, self.hidden_size) + k = self.key_cache.view( # type: ignore + -1, self.block_size, + self.num_kv_heads * self.head_size) + v = self.value_cache.view( # type: ignore + -1, self.block_size, + self.num_kv_heads * self.head_size) + actual_seq_lens = attn_metadata.seq_lens_list + attn_args = { + "query": q, + "key": k, + "value": v, + "actual_seq_lengths_kv": actual_seq_lens, + "block_table": attn_metadata.block_tables, + "num_heads": self.num_heads, + "scale": self.scale, + "input_layout": "BSH", + "num_key_value_heads": self.num_kv_heads, + "block_size": self.block_size, + } + + # Prepare tensors for attention output + # TODO: Refactor this to step-level instead of layer-level + attn_output = torch.empty(num_tokens, + 1, + self.hidden_size, + dtype=output.dtype, + device=output.device) + softmax_lse = torch.empty(num_tokens, + dtype=output.dtype, + device=output.device) + + # Get workspace from cache or calculate it if not present. 
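+ # The workspace is keyed by num_tokens, so later replays of a captured
+ # graph with the same batch size can reuse it instead of re-allocating.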
+ workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + **attn_args) + graph_params.workspaces[num_tokens] = workspace + + forward_context = get_forward_context() + if not forward_context.capturing: + # Execute attention kernel directly in non-capturing mode + torch.ops.npu.npu_fused_infer_attention_score.out( + workspace=workspace, + out=[attn_output, softmax_lse], + **attn_args) + else: + # Handle graph capturing mode + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + graph_params.attn_params[num_tokens].append( + (q, k, v, actual_seq_lens, + attn_metadata.block_tables, self.num_heads, + self.scale, self.num_kv_heads, attn_output, + softmax_lse)) + + torch.npu.graph_task_group_begin(stream) + torch.ops.npu.npu_fused_infer_attention_score.out( + workspace=workspace, + out=[attn_output, softmax_lse], + **attn_args) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + + # Reshape output to match the expected format + output.copy_( + attn_output.view(num_tokens, self.num_heads, + self.head_size)) + else: + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) # Normal V1 situation. else: # use chunked prefill for head size 192 scenario, like deepseek diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 189aa38e89..816d93c028 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,34 +13,23 @@ UnquantizedLinearMethod) from vllm.utils import cdiv, round_down +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch, + npu_wait_tensor) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch -@dataclass -class CommonAttentionMetadata: - """ - Attention metadata attributes that can be shared by layers in different KV - cache groups and thus having different block table. 
- """ - - query_start_loc: torch.Tensor - """(batch_size + 1,), the start location of each request in query Tensor""" - seq_lens: torch.Tensor - """(batch_size,), the length of each request including both computed tokens - and newly scheduled tokens""" - - class AscendMLABackend(AttentionBackend): accept_output_buffer: bool = True @@ -103,6 +92,7 @@ class AscendMLADecodeMetadata: seq_lens: torch.Tensor max_seq_lens: int seq_lens_list: list[int] + actual_seq_q_lens: Optional[list[int]] = None attn_mask: Optional[torch.Tensor] = None @@ -136,8 +126,8 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. - max_num_tokens_across_dp: int = 0 - with_prefill_across_dp: bool = False + enable_dbo_across_dp: bool = False + is_mtp_model: bool = False query_lens: Optional[list[int]] = None # The dimension of the attention heads @@ -290,7 +280,7 @@ def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs + assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" if isinstance(self.runner.graph_block_tables, np.ndarray): graph_block_tables = torch.zeros((max_batch_size, max_blocks), @@ -312,8 +302,12 @@ def _get_graph_runner_block_tables( return graph_block_tables[:num_seqs, :max_blocks] - def build_dummy(self, num_reqs: int, - num_actual_tokens: int) -> AscendMLAMetadata: + def build_torchair_graph_dummy( + self, + num_reqs: int, + num_actual_tokens: int, + is_mtp_model: bool = False, + ) -> AscendMLAMetadata: device = self.runner.device _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), @@ -321,11 +315,13 @@ def build_dummy(self, num_reqs: int, device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) - input_positions = torch.zeros(num_reqs, + num_tokens = num_reqs * self.runner.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = seq_lens.tolist() + input_positions = torch.zeros(num_tokens, dtype=torch.int32, device=device).long() - slot_mapping = torch.full((num_reqs, ), + slot_mapping = torch.full((num_tokens, ), PAD_SLOT_ID, dtype=torch.int32, device=device) @@ -333,28 +329,38 @@ def build_dummy(self, num_reqs: int, -1, dtype=torch.int32, device=device) + if self.runner.speculative_config is not None and\ + self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model: + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), + seq_lens_list=seq_lens_list, max_seq_lens=1, - attn_mask=self.runner.spec_attn_mask) + attn_mask=self.runner.spec_attn_mask, + actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], + ) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=1, - num_decode_tokens=1, + num_decode_tokens=num_decode_tokens, num_prefills=0, attn_mask=self.runner.attn_mask, - attn_state=AscendAttentionState.DecodeOnly, + 
attn_state=attn_state, prefill=None, decode=decode_metadata, query_start_loc=query_start_loc, seq_lens=seq_lens, block_tables=block_table, + is_mtp_model=is_mtp_model, ) def build( @@ -364,9 +370,10 @@ def build( max_query_len: int, common_attn_metadata: CommonAttentionMetadata, common_prefix_len: Optional[int] = None, - graph_pad_size: int = -1, - max_num_tokens_across_dp: int = 0, - with_prefill_across_dp: bool = False, + num_token_pad_size: int = -1, + num_reqs_pad_size: int = 0, + enable_dbo_across_dp: bool = False, + is_mtp_model: bool = False, ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs @@ -450,8 +457,9 @@ def build( ) decode_metadata = None - use_torchair_graph = graph_pad_size != -1 + use_torchair_graph = num_token_pad_size != -1 if self._num_decodes > 0: + actual_seq_q_lens = None max_seq_lens = seq_lens[:self._num_decodes].max().item() seq_lens = seq_lens[:self._num_decode_tokens] input_positions = input_positions[:self._num_decode_tokens] @@ -460,41 +468,48 @@ def build( AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: - num_seqs = len(seq_lens) - if graph_pad_size != 0: - pad_value = 1 - padded_seq_lens = seq_lens.tolist() + [pad_value - ] * graph_pad_size + if num_token_pad_size != 0: + pad_value = 0 + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size else: padded_seq_lens = seq_lens.tolist() seq_lens = torch.from_numpy( np.array(padded_seq_lens).astype(np.int32)) - padding = torch.full((graph_pad_size, ), + seq_lens_list = padded_seq_lens + padding = torch.full((num_token_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype, device=slot_mapping.device) slot_mapping = torch.cat([slot_mapping, padding]) block_table_padding = torch.zeros( - (graph_pad_size, ) + block_table.shape[1:], + (num_reqs_pad_size, ) + block_table.shape[1:], dtype=block_table.dtype, device=block_table.device) block_table = torch.cat([block_table, block_table_padding], dim=0) block_table = self._get_graph_runner_block_tables( - num_seqs + graph_pad_size, block_table) - padding_0 = torch.zeros(graph_pad_size, + num_reqs + num_reqs_pad_size, block_table) + padding_0 = torch.zeros(num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device) input_positions = torch.cat([input_positions, padding_0]) + actual_seq_q_lens = query_start_loc[1:].tolist( + ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs + + num_reqs_pad_size] + else: + seq_lens_list = seq_lens.tolist() decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), + seq_lens_list=seq_lens_list, max_seq_lens=max_seq_lens, - attn_mask=self.runner.spec_attn_mask) + attn_mask=self.runner.spec_attn_mask, + actual_seq_q_lens=actual_seq_q_lens, + ) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -511,8 +526,8 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - max_num_tokens_across_dp=max_num_tokens_across_dp, - with_prefill_across_dp=with_prefill_across_dp, + enable_dbo_across_dp=enable_dbo_across_dp, + is_mtp_model=is_mtp_model, ) @@ -570,15 +585,6 @@ def __init__( self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 - # TODO: support numHeads / numKvHeads < 16 in MLA kernel - if self.torchair_graph_enabled: - assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \ - ("The allowed number of queries per kv when enabling both MLA and Graph 
mode" - " only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite," - " as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1," - " please make sure after the tensor parallel split, num_heads / num_kv_heads in " - "{32, 64, 128}.") - def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -651,20 +657,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + if get_ascend_config().enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, + ACL_FORMAT_FRACTAL_NZ) + self.W_UK_T.data = torch_npu.npu_format_cast( + self.W_UK_T.data, ACL_FORMAT_FRACTAL_NZ) def _compute_prefill_context( self, query: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, ): + assert len(kv_c_and_k_pe_cache) > 1 prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse @@ -674,21 +683,22 @@ def _compute_prefill_context( q_nope = query[..., :self.qk_nope_head_dim] seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) - latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim - cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim] - cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:] + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] seq_len = torch.stack([seq_len1, seq_len2]) kv_c_normed = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, latent_kv_dim, dtype=query.dtype, device=query.device) k_pe = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, rope_dim, dtype=query.dtype, device=query.device) @@ -738,10 +748,11 @@ def _forward_prefill( query: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 num_tokens = query.size(0) attn_output = torch.empty(num_tokens, @@ -758,7 +769,8 @@ def _forward_prefill( if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output_torch = torch.empty(num_tokens, self.num_heads * self.v_head_dim, @@ -783,7 +795,8 @@ def _forward_prefill( causal=True) elif attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ]: attn_lse = torch.empty(self.num_heads, num_tokens, @@ -833,15 +846,12 @@ def _forward_prefill( 
num_kv_heads=self.num_heads, out=attn_output) attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" - ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output = attn_output_torch @@ -934,44 +944,17 @@ def _forward_decode( q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1) - num_tokens = q.size(0) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) + num_tokens = q_nope.size(0) if self.running_in_graph: - # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] - if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 - q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1), - self.spec_token_num + 1, self.num_heads, - -1) - q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1), - self.spec_token_num + 1, self.num_heads, -1) - if not self.enable_kv_nz: - q_nope = q_nope.transpose(1, 2).contiguous() - q_pe = q_pe.transpose(1, 2).contiguous() - sparse_mode = 3 - spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore - else: - if self.enable_kv_nz: - q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - else: - q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) - sparse_mode = 0 - spec_attn_mask = None # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] block_size = kv_c_and_k_pe_cache[0].shape[1] + actual_seq_lengths = None if self.enable_kv_nz: k_nope = k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank // 16, block_size, 16) @@ -985,6 +968,26 @@ def _forward_decode( self.qk_rope_head_dim) input_layout = "BNSD" + # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] + if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + assert num_tokens % self.spec_token_num == 0 + # [bs * q_seq_len, num_heads_per_rank, dim] + input_layout = "TND" + q_nope = q_nope.view(num_tokens, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_q_lens + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + attn_output, _ = torch_npu.npu_fused_infer_attention_score( q_nope, k_nope, @@ -1002,18 +1005,37 @@ def _forward_decode( block_table=decode_meta.block_table, block_size=block_size, actual_seq_lengths_kv=decode_meta.seq_lens_list, - ) + 
actual_seq_lengths=actual_seq_lengths) else: - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=kv_c_and_k_pe_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode.block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) + # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will + # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become + # public available + assert len(kv_c_and_k_pe_cache) > 1 + if envs.VLLM_ASCEND_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) + else: + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode. + block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: return self._v_up_proj_and_o_proj(attn_output) @@ -1029,7 +1051,7 @@ def forward( hidden_states_or_q_c: torch.Tensor, # query in unified attn hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, + kv_cache: Tuple[torch.Tensor], attn_metadata: M, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -1037,16 +1059,17 @@ def forward( if attn_metadata is None: # Profiling run. return output + # mtp model is not support for graph mode yet + self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] num_actual_toks = attn_metadata.num_actual_tokens if k_pe is None and not self.running_in_graph: - if not self.torchair_graph_enabled: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else: kv_c_normed = hidden_states_or_kv_c_normed assert attn_metadata.num_decodes is not None and \ @@ -1065,19 +1088,20 @@ def forward( if not self.running_in_graph: hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - if not self.torchair_graph_enabled: - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - k_pe = k_pe[:num_actual_toks, ...] 
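To make the cache-layout change in the MLA decode path above concrete: the KV cache is now handed around as a (nope_cache, rope_cache) tuple, and only the legacy `_npu_paged_attention_mla` fallback re-fuses the two tensors with `torch.cat`. The snippet below is a self-contained illustration of that layout only; the block, head, and rank sizes are assumed values, not taken from this patch.

import torch

# Assumed, illustrative sizes (DeepSeek-style MLA dimensions).
num_blocks, block_size, num_kv_heads = 4, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, qk_rope_head_dim)
kv_c_and_k_pe_cache = (nope_cache, rope_cache)   # the new tuple form

# Newer kernels (e.g. the MLA_PA path) consume the two caches separately; the
# legacy paged-attention fallback still expects a single fused cache, so the
# decode path concatenates them along the last dimension on demand.
k_cache = torch.cat([kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
assert k_cache.shape == (num_blocks, block_size, num_kv_heads,
                         kv_lora_rank + qk_rope_head_dim)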
- k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] + # if not self.torchair_graph_enabled: + k_pe = k_pe[:num_actual_toks, ...] + k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] else: decode_hs_or_q_c = hidden_states_or_q_c if has_decode: decode_k_nope = None assert attn_metadata.decode is not None if self.running_in_graph: - seq_len = self.rotary_emb.max_position_embeddings + seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor cos = self.rotary_emb.cos_cached[:seq_len].to( dtype=decode_hs_or_q_c.dtype) sin = self.rotary_emb.sin_cached[:seq_len].to( @@ -1111,9 +1135,7 @@ def forward( else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) + decode_q_pe.contiguous(), decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ @@ -1122,7 +1144,7 @@ def forward( prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] - seq_len = self.rotary_emb.max_position_embeddings + seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor cos = self.rotary_emb.cos_cached[:seq_len].to( dtype=prefill_q_pe.dtype) sin = self.rotary_emb.sin_cached[:seq_len].to( @@ -1134,22 +1156,24 @@ def forward( prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( - hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) + prefill_hs, cos, sin, kv_cache, + attn_metadata.slot_mapping[num_decode_tokens:]) kv_c_normed = prefill_k_nope[:num_actual_toks, ...] - prefill_k_c_normed = prefill_k_nope[num_decode_tokens:] + prefill_k_c_normed = prefill_k_nope prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) else: prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) + prefill_q_pe.contiguous(), prefill_k_pe) + + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" if self.torchair_graph_enabled: - if len(kv_cache) > 0 and kv_cache[0].numel( + if kv_cache[0].numel( ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: slots = attn_metadata.slot_mapping # NOTE: Separate the kv cache in advance to avoid OOM or other issues @@ -1159,16 +1183,15 @@ def forward( key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slots) - elif kv_cache.numel() > 0: - key = torch.cat([ - kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), - k_pe - ], - dim=2) - torch_npu._npu_reshape_and_cache_siso( - key=key, - key_cache=kv_cache, - slot_indices=attn_metadata.slot_mapping.flatten()) + else: + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) if has_prefill: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py new file mode 100644 index 0000000000..c2b7bc156a --- /dev/null +++ b/vllm_ascend/attention/utils.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class AscendCommonAttentionMetadata: + """ + Attention metadata attributes that can be shared by layers in different KV + cache groups and thus having different block table. 
+ """ + + query_start_loc: torch.Tensor = None + """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: Optional[torch.Tensor] = None + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + query_lens: Optional[torch.Tensor] = None + """(batch_size,), the length of each request including only the newly + scheduled tokens""" + seq_lens_list: Optional[list] = None + """(num_input_tokens,), note that this is specifically for FIA kernel""" diff --git a/vllm_ascend/compilation/piecewise_backend.py b/vllm_ascend/compilation/piecewise_backend.py index c6a800b3d8..aafe639373 100644 --- a/vllm_ascend/compilation/piecewise_backend.py +++ b/vllm_ascend/compilation/piecewise_backend.py @@ -28,9 +28,13 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.utils import weak_ref_tensors +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.utils import get_graph_params, set_graph_params + @dataclasses.dataclass class ConcreteSizeEntry: @@ -95,6 +99,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + if self.compilation_config.full_cuda_graph: + self.update_stream = torch.npu.Stream() + set_graph_params(self.aclgraph_capture_sizes) + # the entries for different shapes that we need to either # compile or capture aclgraph self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} @@ -116,7 +124,40 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def update_attn_params(self, graph_params, forward_context, runtime_shape): + for layer_idx in range(len(graph_params.handles[runtime_shape])): + query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[ + runtime_shape][layer_idx] + block_table = forward_context.attn_metadata.block_tables + actual_seq_lens = forward_context.attn_metadata.seq_lens_list + + with torch.npu.stream(self.update_stream): + torch.npu.graph_task_update_begin( + self.update_stream, + graph_params.handles[runtime_shape][layer_idx]) + torch.ops.npu.npu_fused_infer_attention_score.out( + query, + key, + value, + workspace=graph_params.workspaces[runtime_shape], + actual_seq_lengths_kv=actual_seq_lens, + block_table=block_table, + num_heads=num_heads, + scale=scale, + input_layout="BSH", + num_key_value_heads=num_kv_heads, + block_size=128, + out=[output, softmax_lse], + ) + torch.npu.graph_task_update_end(self.update_stream) + + graph_params.events[runtime_shape][layer_idx].record( + self.update_stream) + def __call__(self, *args) -> Any: + forward_context = get_forward_context() + graph_params = get_graph_params() + if not self.first_run_finished: self.first_run_finished = True self.check_for_ending_compilation() @@ -127,6 +168,11 @@ def __call__(self, *args) -> Any: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) + if (getattr(forward_context.attn_metadata, "attn_state", + None) != AscendAttentionState.DecodeOnly + and self.compilation_config.full_cuda_graph): + return self.compiled_graph_for_general_shape(*args) + entry = self.concrete_size_entries[runtime_shape] if 
entry.runnable is None: @@ -189,6 +235,7 @@ def __call__(self, *args) -> Any: patch("torch.npu.empty_cache", lambda: None)) # mind-exploding: carefully manage the reference and memory. + forward_context.capturing = True with torch.npu.graph(aclgraph, pool=self.graph_pool): # `output` is managed by pytorch's aclgraph pool output = entry.runnable(*args) @@ -222,4 +269,9 @@ def __call__(self, *args) -> Any: ) entry.aclgraph.replay() + + if self.compilation_config.full_cuda_graph: + self.update_attn_params(graph_params, forward_context, + runtime_shape) + return entry.output diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 2fa31c264c..3f1477c9f9 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -23,7 +23,6 @@ from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.utils import cdiv -from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs @@ -87,14 +86,11 @@ def skip_cur_request(): self.waiting.popleft() skipped_waiting_requests.appendleft(request) - num_prealloc_computed_tokens = 0 # P/D: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING - num_prealloc_computed_tokens = ( - request.num_computed_tokens) else: skip_cur_request() continue @@ -112,8 +108,8 @@ def skip_cur_request(): load_kv_async = False # Get already-cached tokens. - if num_prealloc_computed_tokens == 0: - new_computed_blocks, num_native_computed_tokens = \ + if request.num_computed_tokens == 0: + new_computed_blocks, num_new_local_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) @@ -121,18 +117,17 @@ def skip_cur_request(): if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) + request, num_new_local_computed_tokens)) # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + + num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. - new_computed_blocks = KVCacheBlocks.create_empty() - num_native_computed_tokens = 0 - - # Total computed tokens (allocated in prior step). - num_computed_tokens = num_prealloc_computed_tokens + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens # P/D: loading remote KV, do not allocate for new work. if load_kv_async: @@ -142,9 +137,6 @@ def skip_cur_request(): # Number of tokens to be scheduled. else: prompt_limit = self._get_prompt_limit(request) - # Get already-cached tokens. - computed_blocks, num_computed_tokens = ( - self.kv_cache_manager.get_computed_blocks(request)) # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
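The `update_attn_params()` hook added in `piecewise_backend.py` above patches fresh block tables and sequence lengths into an already-captured ACL graph from a dedicated update stream, using per-shape state registered through `get_graph_params()`/`set_graph_params()` from `vllm_ascend.utils`. The sketch below only illustrates the assumed shape of that per-runtime-shape bookkeeping; the class and field names are guesses for illustration, not the actual implementation.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class GraphParamsSketch:
    # One bucket per captured runtime shape; each bucket keeps one entry per
    # attention layer (task handle, packed attention params, event, ...).
    handles: Dict[int, List[Any]] = field(default_factory=dict)
    attn_params: Dict[int, List[Any]] = field(default_factory=dict)
    workspaces: Dict[int, Any] = field(default_factory=dict)
    events: Dict[int, List[Any]] = field(default_factory=dict)


_graph_params: Optional[GraphParamsSketch] = None


def set_graph_params(capture_sizes) -> None:
    # Called once when full-graph capture is enabled, so attention layers can
    # register their handles and params while each shape is being captured.
    global _graph_params
    _graph_params = GraphParamsSketch(
        handles={s: [] for s in capture_sizes},
        attn_params={s: [] for s in capture_sizes},
        workspaces={s: None for s in capture_sizes},
        events={s: [] for s in capture_sizes},
    )


def get_graph_params() -> Optional[GraphParamsSketch]:
    return _graph_params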
@@ -172,7 +164,7 @@ def skip_cur_request():
                     skip_cur_request()
                     continue
                 assert num_new_tokens > 0
-                blocks = computed_blocks.blocks[0]
+                blocks = new_computed_blocks.blocks[0]
                 watermark = getattr(self.scheduler_config, "watermark", 0.01)
                 if not self._check_watermark_for_prefill(request,
                                                          num_new_tokens,
@@ -184,8 +176,8 @@ def skip_cur_request():
             new_blocks = self.kv_cache_manager.allocate_slots(
                 request,
                 num_new_tokens + num_external_computed_tokens,
-                num_native_computed_tokens,
-                new_computed_blocks=computed_blocks,
+                num_new_local_computed_tokens,
+                new_computed_blocks=new_computed_blocks,
                 num_lookahead_tokens=self.num_lookahead_tokens,
                 delay_cache_blocks=load_kv_async)
             if new_blocks is None:
@@ -195,8 +187,7 @@ def skip_cur_request():
             # KVConnector: update internal state after allocation.
             # This information is used to determine if a load is
             # needed for this request.
-            if num_external_computed_tokens:
-                assert self.connector is not None
+            if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
                     new_computed_blocks + new_blocks,
@@ -210,6 +201,7 @@ def skip_cur_request():
                 skipped_waiting_requests.appendleft(request)
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
+
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED,
@@ -509,3 +501,40 @@ def update_from_output(
 
         return super().update_from_output(scheduler_output,
                                           model_runner_output)
+
+    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+        """
+        KV Connector: check if the request_id is finished_recving.
+
+        The finished_recving_kv_req_ids list is populated
+        on the previous steps()'s update_from_output based
+        on the worker side connector.
+
+        When the kv transfer is ready, we cache the blocks
+        and the request state will be moved back to WAITING from
+        WAITING_FOR_REMOTE_KV.
+        """
+        assert self.connector is not None
+        if request.request_id not in self.finished_recving_kv_req_ids:
+            return False
+
+        # Now that the blocks are ready, actually cache them.
+        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+        num_computed_tokens = len(block_ids) * self.block_size
+        # Handle the case where the number of request tokens is less than one block.
+        num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+        if num_computed_tokens == request.num_tokens:
+            num_computed_tokens -= 1
+
+        # This will cache the blocks if caching is enabled.
+        # Note: vllm fixed this on its main branch, but the issue remains on v0.9.1, so we adopt
+        # the change for 0.9.1 here without cherry-picking it back to the main branch of vllm-ascend.
+        if self.kv_cache_manager.enable_caching:
+            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+        # Update the request state for scheduling.
+        request.num_computed_tokens = num_computed_tokens
+
+        # Return that we are ready.
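The arithmetic in `_update_waiting_for_remote_kv()` above caps the block-aligned remote transfer at the request length and then backs off by one token, so the final token is still scheduled for local computation. A standalone restatement with toy numbers; the helper name is made up for illustration.

def computed_tokens_after_remote_kv(num_blocks: int, block_size: int,
                                    num_request_tokens: int) -> int:
    num_computed = num_blocks * block_size
    # Handle requests shorter than one block.
    num_computed = min(num_computed, num_request_tokens)
    # If everything arrived, leave one token to be (re)computed locally.
    if num_computed == num_request_tokens:
        num_computed -= 1
    return num_computed


assert computed_tokens_after_remote_kv(2, 128, 300) == 256
assert computed_tokens_after_remote_kv(1, 128, 100) == 99   # short-prompt case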
+ self.finished_recving_kv_req_ids.remove(request.request_id) + return True diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 88c2f2199b..d7be705c2b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -25,3 +25,8 @@ KVConnectorFactory.register_connector( "AscendSimpleConnector", "vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "LLMDataDistCMgrConnector", + "vllm_ascend.distributed.llmdatadist_c_mgr_connector", + "LLMDataDistCMgrConnector") diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py new file mode 100644 index 0000000000..34543cc05c --- /dev/null +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -0,0 +1,789 @@ +import contextlib +import json +import math +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import llm_datadist # type: ignore +import msgspec +import torch +import zmq +from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, + LLMException, LLMRole) +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.forward_context import ForwardContext +from vllm.utils import get_ip, logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_ascend import envs +from vllm_ascend.soc_info import NPUSocInfo + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32 +} + + +class LLMDataDistCMgrAgentMetadata(msgspec.Struct): + super_pod_id: str + server_id: str + device_id: str + device_ip: str + super_device_id: str + cluster_id: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: str + engine_id: str + remote_tp_size: str + + +class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req(self, request_id: str, local_block_ids: list[int], + kv_transfer_params: dict[str, Any]): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_tp_size=kv_transfer_params["remote_tp_size"], + ) + + +class LLMDataDistCMgrConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[ + LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler( + 
vllm_config, self.engine_id) + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches( + self, + kv_caches: dict[str, # type: ignore[override] + Tuple[torch.Tensor]]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + LLMDataDistCMgrConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata, **kwargs) -> None: + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + +class LLMDataDistCMgrConnectorScheduler(): + + def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.local_ip = get_ip() + # Can not retrieve the parallel config since it is not initialized. 
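The scheduler and worker code that follows derives per-rank ZMQ ports from the data-parallel/tensor-parallel layout on top of the `VLLM_LLMDD_RPC_PORT` base (default 5557). A standalone sketch of the worker-side arithmetic, ignoring the `dp_rank_local is None` fallback; the concrete numbers are illustrative only.

def llmdd_rpc_port(base_port: int, dp_rank_local: int, tp_size: int,
                   tp_rank: int) -> int:
    # One listener per TP rank inside each local DP rank, packed contiguously
    # above the base port.
    return base_port + dp_rank_local * tp_size + tp_rank


base = 5557                                   # default VLLM_LLMDD_RPC_PORT
assert llmdd_rpc_port(base, 0, 4, 0) == 5557
assert llmdd_rpc_port(base, 0, 4, 3) == 5560
assert llmdd_rpc_port(base, 1, 4, 0) == 5561  # next DP rank starts after 4 ports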
+ self.local_dp_rank = None + self.tp_size = None + dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + + self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT + + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}" + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_externel_tokens: int): + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}" + ) + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port", "remote_tp_size")): + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning("" \ + f"Invalid KVTransferParams {params}, This request will be discard") + else: + assert num_externel_tokens == 0 + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = LLMDataDistCMgrConnectorMetadata() + + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params) + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + + params = request.kv_transfer_params + logger.debug( + "LLMDataDistCMgrConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # note: NIXL transfer the full block only, but I don't see any reason to do that, so here + # we just transfer any data that computed from prefill node + # note: there might be some issue on this, check it if there is any unexpected result + computed_block_ids = block_ids + # If prompt < block_size, no xfer so free blocks immediately. 
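To summarize the handshake encoded here: `update_state_after_alloc()` above only queues a request for a remote pull when the prefill side has published a complete set of `remote_*` fields, which is what `request_finished()` returns just below. A minimal sketch of that validation; the helper and the placeholder values are illustrative, not part of the connector.

REQUIRED_REMOTE_KEYS = ("remote_block_ids", "remote_engine_id", "remote_host",
                        "remote_port", "remote_tp_size")


def is_remote_prefill_ready(params: dict) -> bool:
    # Mirrors the checks performed on the decode side before a request is
    # added to the set of requests that need a remote KV pull.
    return bool(params.get("do_remote_prefill")
                and params.get("remote_block_ids")
                and all(k in params for k in REQUIRED_REMOTE_KEYS))


example = dict(do_remote_prefill=True, remote_block_ids=[0, 1, 2],
               remote_engine_id="prefill-0", remote_host="10.0.0.1",
               remote_port="5557", remote_tp_size="4")
assert is_remote_prefill_ready(example)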
+ + return False, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.local_ip, + remote_port=self.port, + remote_tp_size=str( + self.vllm_config.parallel_config.tensor_parallel_size), + ) + + +class LLMDataDistCMgrConnectorWorker(): + """ + Implementation of Worker side methods + """ + + def __init__(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None + logger.info("Initialize the LLMDataDistCMgrConnectorWorker") + # we assume the local node only contains dp and tp, and tp will not communicate inter-node. + # for any scenario beyond this scope, the functionality of this connector is not guaranteed. + self.local_rank_on_node = get_world_group().rank % ( + vllm_config.parallel_config.data_parallel_size_local * + vllm_config.parallel_config.tensor_parallel_size) + self.local_rank = get_world_group().local_rank + self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_rank = get_tp_group().rank_in_group + self.rank = get_world_group().rank + self.local_ip = get_ip() + self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config + self.local_agent_metadata: Optional[ + LLMDataDistCMgrAgentMetadata] = None + self.vllm_config = vllm_config + + self.llm_datadist_role = None + self.llm_datadist_remote_role = None + if self.kv_transfer_config.kv_role == "kv_producer": + self.llm_datadist_role = LLMRole.PROMPT + self.llm_datadist_remote_role = LLMRole.DECODER + elif self.kv_transfer_config.kv_role == "kv_consumer": + self.llm_datadist_role = LLMRole.DECODER + self.llm_datadist_remote_role = LLMRole.PROMPT + else: + raise RuntimeError( + f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}" + ) + + # linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"} + self.linked_cluster: dict[Any, Any] = {} + self.prefill_device_list: list[tuple[int, int]] = [] + self.decode_device_list: list[tuple[int, int]] = [] + global_rank_table = self.read_offline_rank_table() + self.local_agent_metadata = self.read_agent_metadata( + global_rank_table, self.local_ip, self.local_rank_on_node, + self.llm_datadist_role) + self.llm_datadist = LLMDataDist(self.llm_datadist_role, + self.local_agent_metadata.cluster_id) + self.init_llm_datadist() + self.finished_reqs: set[str] = set() + self.soc_info = NPUSocInfo() + + def listen_for_agent_metadata_req(self, event: threading.Event): + assert self.local_agent_metadata is not None + port = envs.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank + url = f"tcp://0.0.0.0:{port}" + msg_encoder = msgspec.msgpack.Encoder() + msg_decoder = msgspec.msgpack.Decoder() + msg_to_send = msg_encoder.encode(self.local_agent_metadata) + logger.debug(f"Start to listen to address: {url}") + logger.debug( + f"The local agent metadata have {len(msg_to_send)} bytes here") + logger.info( + f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers" + ) + with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined] + event.set() + while True: + identity, _, msg = sock.recv_multipart() + decode_msg = 
msg_decoder.decode(msg) + if "cluster_id" in decode_msg: + decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) + logger.info( + f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" + ) + sock.send_multipart((identity, b"", msg_to_send)) + self.add_remote_agent(decode_msg) + else: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" + ) + + def init_llm_datadist(self): + assert self.local_agent_metadata is not None + llm_config = LLMConfig() + llm_config.device_id = self.local_rank + llm_config.sync_kv_timeout = 20000 + llm_config.enable_switch_role = True + llm_config.enable_cache_manager = True + llm_config.enable_remote_cache_accessible = True + llm_config_options = llm_config.generate_options() + self.llm_datadist.init(llm_config_options) + self.cache_manager = self.llm_datadist.cache_manager + logger.info( + f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}" + ) + + def read_offline_rank_table(self): + assert ( + envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + ), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH" + rank_table_path = envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + with open(rank_table_path, "r", encoding="utf-8") as f: + global_rank_table = json.load(f) + decode_device_list = global_rank_table["decode_device_list"] + for decode_device in decode_device_list: + server_id = decode_device["server_id"] + device_id = decode_device["device_id"] + self.decode_device_list.append((server_id, device_id)) + prefill_device_list = global_rank_table["prefill_device_list"] + for prefill_device in prefill_device_list: + server_id = prefill_device["server_id"] + device_id = prefill_device["device_id"] + self.prefill_device_list.append((server_id, device_id)) + + # global_rank_table = json.dumps(global_rank_table) + return global_rank_table + + def read_agent_metadata(self, global_rank_table, server_id, device_rank, + agent_role): + devices_type_list = [] + agent_metadata = None + if self.llm_datadist_role == LLMRole.PROMPT: + devices_type_list.append("prefill_device_list") + elif self.llm_datadist_role == LLMRole.DECODER: + devices_type_list.append("decode_device_list") + else: + devices_type_list.append("prefill_device_list") + devices_type_list.append("decode_device_list") + for device_type in devices_type_list: + device_list = global_rank_table[device_type] + device_list = [ + d for d in device_list if d.get("server_id") == server_id + ] + if len(device_list) <= device_rank: + continue + device_info = device_list[device_rank] + super_pod_id_ = device_info.get("super_pod_id", None) + server_id_ = device_info["server_id"] + device_id_ = device_info["device_id"] + device_ip_ = device_info["device_ip"] + super_device_id_ = device_info.get("super_device_id", None) + cluster_id_ = int(device_info["cluster_id"]) + agent_metadata = LLMDataDistCMgrAgentMetadata( + super_pod_id=super_pod_id_, + server_id=server_id_, + device_id=device_id_, + device_ip=device_ip_, + super_device_id=super_device_id_, + cluster_id=cluster_id_, + ) + assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table" + return agent_metadata + + def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + assert len(first_kv_cache_tuple) > 1 + assert 
self.local_agent_metadata is not None + kv_cache_dtype = first_kv_cache.dtype + self.use_mla: bool = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # MHA case. [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + self.block_len = math.prod(block_shape) + self.cache_addr: list[int] = [] + alignment = 2 * 1024 * 1024 + if self.use_mla: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + k_normed = None + k_pe = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe = cache_or_caches[0], cache_or_caches[1] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe) + self.cache_key = (cache_key_k_normed, cache_key_k_pe) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + self.cache = (cache_k_normed, cache_k_pe) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + else: + for cache_or_caches in kv_caches.values(): + for cache in cache_or_caches: + base_addr = cache.data_ptr() + assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + self.cache_addr.append(base_addr) + # register paged kv cache into the llm_cache manager + self.cache_desc = CacheDesc( + len(self.cache_addr), [*cache.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + self.cache_key = BlocksCacheKey( + cluster_id=int(self.local_agent_metadata.cluster_id)) + logger.info( + f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}" + ) + try: + self.cache = self.cache_manager.register_blocks_cache( + self.cache_desc, self.cache_addr, self.cache_key) + logger.info( + "LLMDataDistCMgrConnectorWorker: End of register Paged Cache." 
+ ) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + self.ready_event = threading.Event() + self.metadata_agent_listener_t = threading.Thread( + target=self.listen_for_agent_metadata_req, + args=(self.ready_event, ), + daemon=True, + name="metadata_agent_listener") + self.metadata_agent_listener_t.start() + self.ready_event.wait() + + def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata): + for req_id, meta in metadata.requests.items(): + logger.debug(f"Start to transmit {req_id}") + self._read_blocks(meta.local_block_ids, + meta.remote_block_ids, meta.remote_host, + int(meta.remote_port), meta.engine_id, req_id, + meta.remote_tp_size) + self.finished_reqs.add(req_id) + + def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: + assert self.local_agent_metadata is not None + remote_cluster_id = metadata.cluster_id + if remote_cluster_id in self.linked_cluster: + logger.debug( + f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection" + ) + return remote_cluster_id + remote_super_pod_id = metadata.super_pod_id + remote_server_id = metadata.server_id + is_same_server = remote_server_id == self.local_agent_metadata.server_id + is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id + if self.llm_datadist_role == LLMRole.PROMPT: + prefill_metadata = self.local_agent_metadata + decode_metadata = metadata + else: + prefill_metadata = metadata + decode_metadata = self.local_agent_metadata + comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}" + cluster_rank_info = { + prefill_metadata.cluster_id: 0, + decode_metadata.cluster_id: 1 + } + rank_table = {} + rank_table["version"] = "1.2" + rank_table["server_count"] = "1" if is_same_server else "2" + rank_table["status"] = "completed" + + # generate server_list for rank table + rank_table["server_list"] = [] # type: ignore[assignment] + decode_server_device_info = None + prefill_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", prefill_metadata.device_id + ), ("device_ip", prefill_metadata.device_ip + ), ("super_device_id", + prefill_metadata.super_device_id), ("rank_id", "0")] + if v is not None + }], + "server_id": + prefill_metadata.server_id + } + if is_same_server: + prefill_server_device_info["device"].append( # type: ignore[attr-defined] + { + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }) + else: + decode_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }], + "server_id": + decode_metadata.server_id + } + rank_table["server_list"].append( # type: ignore[attr-defined] + prefill_server_device_info) + if decode_server_device_info is not None: + rank_table["server_list"].append( # type: ignore[attr-defined] + decode_server_device_info) + + if self.soc_info.is_a3: + # generate super_pod_list for rank table + super_pod_list = [] + prefill_super_pod_info = { + "super_pod_id": prefill_metadata.super_pod_id, + 
"server_list": [{ + "server_id": prefill_metadata.server_id + }], + } + if is_same_pod and not is_same_server: + prefill_super_pod_info[ + "server_list"].append( # type: ignore[attr-defined] + {"server_id": decode_metadata.server_id}) + super_pod_list.append(prefill_super_pod_info) + if not is_same_pod: + decode_super_pod_id = { + "super_pod_id": decode_metadata.super_pod_id, + "server_list": [{ + "server_id": decode_metadata.server_id + }], + } + super_pod_list.append(decode_super_pod_id) + rank_table[ + "super_pod_list"] = super_pod_list # type: ignore[assignment] + logger.info( + f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}" + ) + logger.info(f"rank table \n{rank_table}") + logger.info(f"comm name: {comm_name}") + logger.info(f"cluster rank info: {cluster_rank_info}") + comm_id = self.llm_datadist.link(comm_name, cluster_rank_info, + json.dumps(rank_table)) + while True: + ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id) + if ret == llm_datadist.RegisterMemStatus.OK: + logger.info( + f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}" + ) + break + elif ret == llm_datadist.RegisterMemStatus.FAILED: + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}" + ) + time.sleep(1) + logger.info("Checking query_register_mem_status again") + self.linked_cluster.update({remote_cluster_id: comm_id}) + logger.info(f"cached linked cluster: {self.linked_cluster}") + logger.info( + f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !" + ) + return remote_cluster_id + + def remove_remote_agent(self, cluster_id: int): + if cluster_id not in self.linked_cluster: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list" + ) + comm_id = self.linked_cluster[cluster_id] + try: + self.llm_datadist.unlink(comm_id) + self.linked_cluster.pop(cluster_id) + except LLMException: + logger.error( + f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment" + ) + logger.info( + f"Successfully remove remote client with cluster id {cluster_id} !" 
+ ) + + def connect_to_remote_agent(self, host: str, port: int) -> int: + url = f"tcp://{host}:{port}" + logger.debug(f"Querying metadata from url: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode(self.local_agent_metadata) + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + logger.info("Try request remote metadata from socket......") + sock.send(msg_send) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder() + metadata = decoder.decode(metadata_bytes) + metadata = LLMDataDistCMgrAgentMetadata(**metadata) + logger.info(f"recving metadata: {metadata}") + cluster_id = self.add_remote_agent(metadata) + return cluster_id + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_ip: str, + remote_port: int, + remote_engine_id: str, + request_id: str, + remote_tp_size: str, + ): + # if remote_ip not in self.linked_cluster: + tp_offset = self.tp_rank % int(remote_tp_size) + remote_cluster_id = self.connect_to_remote_agent( + remote_ip, remote_port + tp_offset) + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + return + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + logger.info(f"remote cluster id is: {remote_cluster_id}") + if self.use_mla: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + else: + remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key, + self.cache, # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requuests.""" + import copy + req_ids_to_ret = copy.deepcopy(self.finished_reqs) + self.finished_reqs.clear() + if self.llm_datadist_role == LLMRole.PROMPT: + return 
req_ids_to_ret, None + else: + return None, req_ids_to_ret + + +# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None # type: ignore[name-defined] + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: # type: ignore[attr-defined] + socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined] + socket.bind(addr) + elif socket_type == zmq.REQ: # type: ignore[attr-defined] + socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined] + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py deleted file mode 100644 index 2778a6ef27..0000000000 --- a/vllm_ascend/distributed/parallel_state.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional - -import torch -from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, - init_model_parallel_group) - -# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for -# customize parallel solution -_EP: Optional[GroupCoordinator] = None -_ETP: Optional[GroupCoordinator] = None - - -def get_ep_group() -> GroupCoordinator: - assert _EP is not None, ("expert model parallel group is not initialized") - return _EP - - -def get_etp_group() -> GroupCoordinator: - assert _ETP is not None, ( - "expert tensor parallel group is not initialized") - return _ETP - - -def model_parallel_initialized(): - return (_ETP is not None and _EP is not None) - - -def init_ascend_model_parallel( - expert_parallel_size: int = 1, - expert_tensor_parallel_size: int = 1, - world_size: Optional[int] = None, - backend: Optional[str] = None, -): - if model_parallel_initialized(): - return - assert torch.distributed.is_initialized() - world_size = world_size or torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - num_expert_parallel_groups = expert_tensor_parallel_size - num_expert_tensor_parallel_groups = expert_parallel_size - - global _EP - group_ranks = [] - for i in range(num_expert_parallel_groups): - ranks = list(range(i, world_size, num_expert_parallel_groups)) - group_ranks.append(ranks) - - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") - - group_ranks = [] - global _ETP - for i in range(num_expert_tensor_parallel_groups): - ranks = list( - range(i * expert_tensor_parallel_size, - (i + 1) * expert_tensor_parallel_size)) - group_ranks.append(ranks) - - _ETP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="etp") - - -def destory_ascend_model_parallel(): - global _EP - if _EP: - _EP.destroy() - _EP = None - - global _ETP - if _ETP: - _ETP.destroy() - _ETP = None diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 02ecd6625b..27d0131720 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -116,6 +116,27 @@ # value to False to disable the optimized model. "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), + # `LLMDataDistCMgrConnector` required variable. 
`DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
+    # used by llmdatadist to build the communication topology for kv cache transfer. It is
+    # required when `LLMDataDistCMgrConnector` is used as the kv connector for disaggregated
+    # prefill/decode. The rank table can be generated with the script `gen_ranktable.sh`
+    # in vllm-ascend's examples folder.
+    "DISAGGREGATED_PREFILL_RANK_TABLE_PATH":
+    lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None),
+    # `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_IP` is the ip the
+    # rpc server listens on; it is used to receive the agent metadata from the
+    # remote worker.
+    "VLLM_ASCEND_LLMDD_RPC_IP":
+    lambda: os.getenv("VLLM_ASCEND_LLMDD_RPC_IP", "0.0.0.0"),
+    # `LLMDataDistCMgrConnector` required variable. `VLLM_LLMDD_RPC_PORT` is the port the
+    # rpc server listens on; it is used to receive the agent metadata from the
+    # remote worker.
+    "VLLM_LLMDD_RPC_PORT":
+    lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
+    # Whether to enable mla_pa for deepseek mla decode. This flag will be removed once the
+    # required torch_npu is publicly available, at which point mla_pa will become the default
+    # deepseek decode path.
+    "VLLM_ASCEND_MLA_PA":
+    lambda: int(os.getenv('VLLM_ASCEND_MLA_PA', 0))
 }
 
 # end-env-vars-definition
diff --git a/vllm_ascend/eplb/adaptor/abstract_adaptor.py b/vllm_ascend/eplb/adaptor/abstract_adaptor.py
new file mode 100644
index 0000000000..8513b69ea0
--- /dev/null
+++ b/vllm_ascend/eplb/adaptor/abstract_adaptor.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from abc import ABC, abstractmethod
+
+class EplbAdaptor(ABC):
+
+    def __init__(self, **args):
+        pass
+
+    @abstractmethod
+    def get_rank_expert_workload(self, num_moe_layers):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_init_expert_map(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_map(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_weight(self):
+        raise NotImplementedError
\ No newline at end of file
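The abstract interface above only fixes the hook names that the EPLB machinery calls; a concrete adaptor binds them to a real model. Below is a minimal sketch of such a subclass, purely for illustration: the tensor shapes, default sizes and method bodies are assumptions, and the actual implementation is the `VllmEplbAdaptor` added in the next file.

    import torch
    from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor

    class ToyEplbAdaptor(EplbAdaptor):
        """Toy adaptor that keeps its expert map and workload in plain tensors."""

        def __init__(self, num_moe_layers: int = 2, num_experts: int = 8, **args):
            super().__init__(**args)
            # one row of physical expert ids and one row of workload per MoE layer
            self.expert_map = torch.arange(num_experts).repeat(num_moe_layers, 1)
            self.workload = torch.zeros(num_moe_layers, num_experts)

        def get_rank_expert_workload(self, num_moe_layers):
            return self.workload[:num_moe_layers]

        def get_init_expert_map(self):
            return self.expert_map.clone()

        def do_update_expert_map(self, layer_id, updated_expert_map):
            self.expert_map[layer_id].copy_(updated_expert_map)

        def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
            pass  # a real adaptor copies the staged buffer into the expert weights here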
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
new file mode 100644
index 0000000000..585fcad7eb
--- /dev/null
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -0,0 +1,209 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import os
+import json
+import torch
+import random
+import torch.distributed as dist
+import numpy as np
+
+from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
+from vllm.logger import logger
+
+
+class VllmEplbAdaptor(EplbAdaptor):
+
+    def __init__(self, model, **args):
+        super().__init__(**args)
+        self.model = model
+        self.rank_id = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.param_dict = dict(self.model.named_parameters())
+        self.num_dense_layers = self.model.config.first_k_dense_replace
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.global_expert_num = self.model.config.n_routed_experts
+
+        # TODO: init self.expert_weight_names depending on the model type; only deepseek v3 w8a8 is supported here
+        self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
+                                    "w2_weight_scale", "w2_weight_offset"]
+
+        self.expert_map_per_layer = dict()      # reference to the on-device expert map, used for expert map updates
+        self.expert_map_per_layer_cpu = dict()  # CPU copy of the expert map to avoid frequent device synchronization
+        for layer_idx in range(self.num_moe_layers):
+            self.expert_map_per_layer[self.num_dense_layers + layer_idx] =\
+                self.model.get_expert_map(self.num_dense_layers + layer_idx)
+
+        # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
+        num_buffer_tensor = torch.where(self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
+        self.buffer_tensor_list = [[] for _ in range(num_buffer_tensor)]
+        self.init_buffer_tensor(num_buffer_tensor)
+
+        self.expert_param_per_layer = dict()
+        self.init_expert_param_per_layer()
+
+        self.log2phy_map_per_layer = dict()
+        for layer_idx in range(self.num_moe_layers):
+            self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
+                self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        self.all_topk_ids = []
+
+    def init_buffer_tensor(self, num_buffer_tensor):
+        for name in self.expert_weight_names:
+            complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
+            expert_tensor = self.param_dict[complete_name].data[0:num_buffer_tensor]
+            buffer_tensors = torch.empty_like(expert_tensor)
+            for buffer_id in range(num_buffer_tensor):
+                self.buffer_tensor_list[buffer_id].append(buffer_tensors[buffer_id])
+
+    def init_expert_param_per_layer(self):
+        num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) +\
+                                           ".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
+        for moe_layer_id in range(self.num_moe_layers):
+            layer_idx = self.num_dense_layers + moe_layer_id
+            self.expert_param_per_layer[layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                self.expert_param_per_layer[layer_idx].append(
+                    [self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts."
+ name].data[local_expert_id] + for name in self.expert_weight_names] + ) + + # def collect_topk_ids(self, dummy_run=False): + # if dummy_run: + # return + # self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers)) + + def get_rank_expert_workload(self) -> torch.Tensor: + self.moe_load = self.model.get_all_moe_loads() + return self.moe_load + + def get_init_expert_map(self, num_moe_layers): + expert_map = self.model.get_all_expert_map(num_moe_layers) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + gathered = torch.empty((world_size, *expert_map.shape), # [W, L, E] + dtype=expert_map.dtype, + device=expert_map.device) + + dist.all_gather_into_tensor(gathered, expert_map) + all_maps = gathered.permute(1, 0, 2) + all_expert_maps = all_maps.cpu() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \ + all_expert_maps[layer_idx][self.rank_id] + + return all_expert_maps + + def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path): + + try: + expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path) + expert_map_all = self.local2global(expert_map_tensor) + except (TypeError, FileNotFoundError, OSError): + expert_map_all = self.determine_expert_map_all() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[layer_idx+3] = \ + expert_map_all[layer_idx][self.rank_id] + return expert_map_all + + def _expert_file_to_tensor(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + logger.error(f"failed to read expert_map_path: {expert_map_path}") + + def do_update_expert_map(self, layer_id, updated_expert_map): + self.expert_map_per_layer[layer_id].copy_(updated_expert_map) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): + for expert_tensor, buffer_tensor in zip( + self.expert_param_per_layer[layer_id][local_expert_to_replace], + self.buffer_tensor_list[buffer_tensor_id] + ): + expert_tensor.copy_(buffer_tensor) + + def do_update_log2phy_map(self, layer_id, updated_log2phy_map): + if self.log2phy_map_per_layer[layer_id] is not None: + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[self.rank_id]) + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def determine_expert_map_all(self): + + local_num_experts = self.global_expert_num // self.world_size + + 
expert_map_all = torch.full( + (self.num_moe_layers, self.world_size, self.global_expert_num), + -1, + dtype=torch.int32 + ) + + for r in range(self.world_size): + if r < self.world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = self.global_expert_num + local_count = self.global_expert_num - r * local_num_experts + + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1) + + return expert_map_all \ No newline at end of file diff --git a/vllm_ascend/eplb/core/loader/abstract_loader.py b/vllm_ascend/eplb/core/loader/abstract_loader.py new file mode 100644 index 0000000000..b1bef11c5d --- /dev/null +++ b/vllm_ascend/eplb/core/loader/abstract_loader.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from abc import ABC, abstractmethod + +class ExpertWeightLoader: + + @abstractmethod + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/eplb/core/loader/device_transfer_loader.py b/vllm_ascend/eplb/core/loader/device_transfer_loader.py new file mode 100644 index 0000000000..579f653323 --- /dev/null +++ b/vllm_ascend/eplb/core/loader/device_transfer_loader.py @@ -0,0 +1,155 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +import torch +import torch.distributed as dist +from enum import Enum + +from vllm.logger import logger +from vllm_ascend.eplb.core.loader.abstract_loader import ExpertWeightLoader + +class ExpertWeightUpdateState(Enum): + WAITING = 0 # waiting for updated expert_map by EplbWorker + READY = 1 # ready for d2d expert weights updating + TRANSFERING = 2 # d2d finished and waiting for updating expert_map into model + +class D2DExpertWeightLoader(ExpertWeightLoader): + + def __init__(self, eplb_adaptor): + self.comm_op_list = None + self.eplb_adaptor = eplb_adaptor + + self.updated_expert_map = None + self.updated_log2phy_map = None + self.layer_id = -1 # layer id to be updated + self.state = ExpertWeightUpdateState.WAITING + self.recv_expert_list = [] + self.mock_flag = True + + def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, + updated_expert_map, layer_id): + # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task + if self.state != ExpertWeightUpdateState.WAITING: + logger.error("current d2d weight update tasks are on-going, cannot accept new weight update task") + return + + # If neither send nor receive task is needed for this layer on this rank, return + if not (expert_send_info or expert_recv_info): + return + + self.updated_expert_map = updated_expert_map + + self.layer_id = layer_id + self.comm_op_list = [] + for send_info in expert_send_info: + dst_rank, global_expert_id_to_send = send_info + local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item() + for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]: + self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank)) + + buffer_tensor_id = 0 + for recv_info in expert_recv_info: + recv_rank, global_expert_id_to_recv = recv_info + for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]: + self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) + local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item() + self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id)) + buffer_tensor_id += 1 + + self.state = ExpertWeightUpdateState.READY + + def set_log2phy_map(self, log2phy_map): + self.updated_log2phy_map = log2phy_map + + def asyn_expert_weight_transfer(self, reqs): + # Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched + if self.state != ExpertWeightUpdateState.READY: + return + + # set asynchronous stream for d2d expert weight transfer + if self.comm_op_list: + ret_list = dist.batch_isend_irecv(self.comm_op_list) + reqs.extend(ret_list) + + self.state = ExpertWeightUpdateState.TRANSFERING + + def update_expert_map_and_weight(self, reqs, redundant_enable): + # Only after send/recv tasks have been luanched, expert_map and weight can be updated + if self.state != ExpertWeightUpdateState.TRANSFERING: + return + + # Waiting for send/recv tasks finish + for req in reqs: + req.wait() + + if self.comm_op_list is not None: + self.comm_op_list = None + + # update expert_map + self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map) + + #update log2phy_map + if redundant_enable: + self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map) + + # update expert weight + buffer_tensor_id = 0 + for recv_expert_info in self.recv_expert_list: + local_expert_to_replace, buffer_tensor_id = recv_expert_info + 
self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id) + + logger.info(f"[EPLB] finished update expert weight for layer: {self.layer_id}") + + self.recv_expert_list = [] + self.updated_expert_map = None + self.layer_id = -1 + self.state = ExpertWeightUpdateState.WAITING + + def generate_mock_update_info(self, rank_id): + if rank_id == 0: + expert_send_info = [(1, 0)] + expert_recv_info = [(1, 64)] + updated_expert_map_list = [-1] + [i for i in range(1, 64)] + [0] + [j for j in [-1] * 191] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 1: + expert_send_info = [(0, 64)] + expert_recv_info = [(0, 0)] + updated_expert_map_list = [0] + [k for k in [-1] * 63] + [i for i in range(1, 64)] + [j for j in [-1] * 129] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 2: + expert_send_info = [(3, 128)] + expert_recv_info = [(3, 192)] + updated_expert_map_list = [k for k in [-1] * 129] + [i for i in range(1, 64)] + [0] + [j for j in [-1] * 63] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 3: + expert_send_info = [(2, 192)] + expert_recv_info = [(2, 128)] + updated_expert_map_list = [k for k in [-1] * 128] + [0] + [k for k in [-1] * 64] + [i for i in range(1, 64)] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + self.mock_flag = False + return (expert_send_info, expert_recv_info, updated_expert_map, layer_id) + + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError + diff --git a/vllm_ascend/eplb/core/policy/__init__.py b/vllm_ascend/eplb/core/policy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/eplb/core/policy/dynamic_ep.py b/vllm_ascend/eplb/core/policy/dynamic_ep.py new file mode 100644 index 0000000000..c081191aab --- /dev/null +++ b/vllm_ascend/eplb/core/policy/dynamic_ep.py @@ -0,0 +1,337 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
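+# Dynamic EPLB policy. It consumes two [num_layers, num_cards, experts_per_card]
+# tables: a workload table holding the measured heat of every expert slot and a
+# placement table holding the physical expert id of every slot (both use -1 for
+# slots that could not be collected). Hot experts are split into redundant
+# replicas and experts are then re-packed across cards to even out the per-card
+# load; rebalance_experts() returns whether the table changed, a per-layer
+# priority order and the new deployment.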
+from collections import defaultdict +import numpy as np + +from .eplb_policy import EplbPolicy, DynamicConfig + + +class DynamicTable: + # workload_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的热度 + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的热度 + # 对于收集不到的专家,填为 -1 + workload_table = None + + # placement_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的物理专家id + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的物理id + # 对于收集不到的专家,填为 -1 + placement_table = None + + +class DynamicEplb(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # 热点专家拆分为冗余专家 + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + if item_id in boxes[i]: + 
continue + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # 热点专家拆分为冗余专家 + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num,), dtype='object') + all_weights[: route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # 无冗余专家方案 + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t 
in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in global_deployment[layer_id][card_id]] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + + def rebalance_experts(self, current_expert_table, expert_workload): + + info = DynamicTable() + info.workload_table = np.array(expert_workload) + info.placement_table = np.array(current_expert_table) + layer_num, num_npus, experts_per_npu= info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # 计算负载均衡,部署冗余专家 + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + # 校验专家数量、卡数量、冗余专家数量不能超过卡数量 + if num_original_expert != expert_num: + raise ValueError(f"原始专家数量 {num_original_expert} 必须等于 expert_num {expert_num}") + + if num_npus <= 0: + raise ValueError("NPUs 数量必须大于 0") + + if num_npus < num_redundancy_expert: + raise 
ValueError(f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})")
+
+        # number of experts deployed on each card, including one redundant expert
+        global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
+        # iterate over the layers to obtain a placement that balances computation
+        max_heat_per_layer_after = np.zeros([layer_num])
+        for layer in range(layer_num):
+            # collect the expert ids and workloads of the current layer; the workloads are
+            # normalized and one redundant expert is added per card
+            weights = np.zeros((expert_num,), dtype='object')
+            for expert_id, workload_weight in enumerate(layer_workloads[layer]):
+                weights[expert_id] = (expert_id, workload_weight)
+
+            # compute a globally balanced placement for this layer
+            result, layer_deployment = self.original_compute_balanced_pack_redundancy(
+                weights, num_npus, num_redundancy_expert
+            )
+
+            global_deployment[layer] = layer_deployment
+            max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_weight'])['total_weight']
+
+        new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
+        # compute the per-layer priority
+        layer_changed_ratio = []
+        for layer_idx in range(layer_num):
+            layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
+
+        per_layer_priority = np.argsort(layer_changed_ratio)
+        npu_heat_all_after = sum(max_heat_per_layer_after)
+
+        change = 0
+        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
+            change = 1
+
+        return change, per_layer_priority, np.array(new_global_deployment).tolist()
+
diff --git a/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py b/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py
new file mode 100644
index 0000000000..775cf5f71d
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py
@@ -0,0 +1,842 @@
+# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+from collections import defaultdict
+import numpy as np
+from abc import abstractmethod
+
+
+class DynamicConfig:
+    placement_policy = None
+
+    max_transferred_expert_per_layer = 100
+    # maximum number of experts that may be moved per layer on a single host
+
+    ep_worldsize = 64  # number of dies that all experts in the cluster are spread across
+    num_die_per_host = 8  # number of dies on each host
+
+
+class EplbPolicy:
+    def __init__(self, config: DynamicConfig):
+        self.config = config
+
+    @abstractmethod
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        """
+        Take the expert workloads and return the expert replication and placement
+        under the given constraints.
+        INPUT:
+        current_expert_table: [layerId, rankId, expert_num_i]
+        expert_workload = expert_table[layer0][rankId][expert_num_i]
+
+        RETURNED: (res, expert_table)
+        res:
+        1 -- table_changed
+        0 -- not_changed
+
+        expert_table: [layerId, rankId, expert_num_i]
+        expert_num_i --- [0, MaxExpertPerRank]
+        expertID = expert_table[layer0][rankId][expert_num_i]
+        array_values:
+        [0, 1, 2, 3, 248]
+        [4, 5, 6, 7, 254]
+        [8, 9, 10, 11, 71]
+        ...
+ [252, 253, 254, 255, 0] + """ + pass + +class DynamicTable: + # workload_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的热度 + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的热度 + # 对于收集不到的专家,填为 -1 + workload_table = None + + # placement_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的物理专家id + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的物理id + # 对于收集不到的专家,填为 -1 + placement_table = None + + +class DynamicEplbV2(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # 热点专家拆分为冗余专家 + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < 
items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # 热点专家拆分为冗余专家 + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num,), dtype='object') + all_weights[: route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # 无冗余专家方案 + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if 
card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def calculate_initial_imbalance(global_deployment, new_layer_workloads): + + device_num = global_deployment.shape[1] + layer_imbalance = [] + expert_num = np.zeros_like(new_layer_workloads) + # 基于部署做更新负载 + for layer_id, layer in enumerate(global_deployment): + for device in layer: + for expert_id in device: + expert_num[layer_id][expert_id] += 1 + + for layer_id, layer in enumerate(global_deployment): + cur_layer_max_workload = 0 + total_workload = 0 + for box in layer: + box_workload = 0 + for expert_id in box: + update_workload = new_layer_workloads[layer_id][expert_id] / expert_num[layer_id][expert_id] + box_workload += update_workload + total_workload += update_workload + if cur_layer_max_workload < box_workload: + cur_layer_max_workload = box_workload + + cur_layer_imbalance = cur_layer_max_workload / (total_workload / device_num) + layer_imbalance.append(cur_layer_imbalance) + + return layer_imbalance + + @staticmethod + def compute_redundant_assignments(base_experts, num_redundant_experts, num_experts): + """ + 计算每个基础专家需要分配的冗余专家,并动态调整专家权重 + 返回冗余分配表和更新后的基础专家权重列表 + """ + redundant_assignments = [[] for _ in range(num_experts)] + current_weights = base_experts.copy() + + for i in range(num_redundant_experts): + # 按权重降序排序(使用稳定排序保持相同权重的顺序) + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + # 选择当前权重最高的专家 + target_expert = sorted_weights[0] + expert_id, original_weight = target_expert + + # 计算添加冗余后的新平均权重 + current_redundancy = len(redundant_assignments[expert_id]) + new_avg_weight = original_weight * (current_redundancy + 1) / (current_redundancy + 2) + + # 更新分配表和权重列表 + redundant_assignments[expert_id].append(num_experts + i) + current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) + + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + return redundant_assignments, sorted_weights + + @staticmethod + def 
prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts): + """ + 生产冗余专家的完整列表,并按权重降序排序 + """ + redundant_expert_list = np.empty(num_redundant_experts, dtype=object) + + # 填充冗余专家(使用对应基础专家的当前权重) + index = 0 + num_experts = len(redundant_assignments) + for expert_id in range(num_experts): + for _ in redundant_assignments[expert_id]: + redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id)) + index += 1 + + # 按权重降序排序 + sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind='stable')[::-1] + return [redundant_expert_list[i] for i in sorted_indices] + + @staticmethod + def non_redundant_expert_information(origin_deployment, updated_weights, num_radundant_experts): + + device_num = len(origin_deployment) + + device_assignments = [[] for _ in range(device_num)] + device_weights = [[] for _ in range(device_num)] + device_loads = [0] * device_num + device_counts = [0] * device_num + if num_radundant_experts: + start_id = 1 + else: + start_id = 0 + + # 统计卡上非冗余专家信息 + for box_id, box in enumerate(origin_deployment): + for i in range(start_id, len(box)): + device_assignments[box_id].append(box[i]) + cur_weight = next(weight for expert_id, weight in updated_weights if expert_id == box[i]) + device_weights[box_id].append(cur_weight) + device_loads[box_id] += cur_weight + device_counts[box_id] += 1 + + return device_assignments, device_weights, device_loads, device_counts + + @staticmethod + def recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads): + # 统计专家出现次数 + num_all_experts = [0] * len(layer_workloads) + num_devices = len(device_assignments) + for device_id in range(num_devices): + num_expert_per_npu = len(device_assignments[device_id]) + for idx in range(num_expert_per_npu): + num_all_experts[idx] += device_assignments[device_id][idx] + + for device_id in range(num_devices): + num_expert_per_npu = len(device_weights[device_id]) + total_weight = 0.0 + for idx in range(num_expert_per_npu): + expert_id = device_assignments[device_id][idx] + if num_all_experts[expert_id] == 0: + print("Error: Division by zero") + device_weights[device_id][idx] = layer_workloads[expert_id] / num_all_experts[expert_id] + total_weight += device_weights[device_id][idx] + device_loads[device_id] = total_weight + + return device_weights, device_loads + + @staticmethod + def distribute_redun_experts(self, layer_workloads, device_assignments, device_weights, device_loads, device_counts, redundant_expert_list, + items_per_device, expert_form_device, num_experts): + + num_devices = len(device_assignments) + com_between_devices = [{} for _ in range(num_devices)] + + for expert_id, weight in redundant_expert_list: + # 寻找最优设备(满足容量限制且负载最小) + candidate = -1 + for dev_id in range(num_devices): + # 保证设备内节点不同 + if expert_id in device_assignments[dev_id]: + continue + # 检查容量限制 + if device_counts[dev_id] < items_per_device: + # 选择负载最小的候选设备 + if candidate == -1 or device_loads[dev_id] < device_loads[candidate]: + candidate = dev_id + if candidate != -1: + # 分配专家到选定的设备 + device_assignments[candidate].insert(0, expert_id) + device_weights[candidate].insert(0, weight) + device_loads[candidate] += weight + device_counts[candidate] += 1 + + communication_box_index = expert_form_device[expert_id] + com_between_devices[candidate][communication_box_index] = expert_id + # 极端情况下存在冗余专家没装箱 导致箱子有空位 随机填入专家 待优化 + flag = False + for dev_id in range(num_devices): + # 检查容量限制 + if device_counts[dev_id] < items_per_device: + # 遍历合适的专家 + for 
expert_id in range(num_experts): + if expert_id not in device_assignments[dev_id]: + flag = True + # 随机初始化一个权重 + weight = 0.0 + # 和该专家相关的卡权重发生变化 待修改 + device_assignments[dev_id].insert(0, expert_id) + device_weights[dev_id].insert(0, weight) + device_loads[dev_id] += weight + device_counts[dev_id] += 1 + + communication_box_index = expert_form_device[expert_id] + com_between_devices[dev_id][communication_box_index] = expert_id + break + + if flag: + device_weights, device_loads = self.recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads) + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def redundancy_again(self, layer_workloads, origin_weights, num_redundant_experts, origin_deployment, expert_form_device, num_node, + is_node_redundant): + + # 每张卡上专家数量 + expert_num_per_device = origin_deployment.shape[1] + + num_experts = len(origin_weights) + if is_node_redundant: + num_experts = num_experts * num_node + + # 根据新负载重新计算冗余专家 + redundant_assignments, updated_weights = self.compute_redundant_assignments(origin_weights, + num_redundant_experts, + num_experts) + + # 收集冗余专家信息并排序 + redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts) + + # 收集重新计算冗余后卡上非冗余专家信息 + device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( + origin_deployment, updated_weights, num_redundant_experts) + + # 新计算的冗余专家进行分配 + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( + self, + layer_workloads, + device_assignments, + device_weights, + device_loads, + device_counts, + redundant_expert_list, + expert_num_per_device, + expert_form_device, + num_experts) + + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts): + """ + 生成最终分配报告并计算最大负载 + """ + report = [] + max_load = 0.0 + + for dev_id in range(len(device_assignments)): + current_load = device_loads[dev_id] + max_load = max(max_load, current_load) + + report.append({ + "device_id": dev_id + 1, + "assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id] + }) + + return report, max_load + + @staticmethod + def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, + com_between_devices): + + cur_device_deployment = cur_layer_result[cur_device_id]['assigned_experts'] + next_device_deployment = cur_layer_result[next_device_id]['assigned_experts'] + + cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] + next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + + # 两张卡上对应的两个专家进行交换 + cur_expert_id = cur_device_deployment[cur_exchange_index] + next_expert_id = next_device_deployment[next_exchange_index] + cur_device_deployment[cur_exchange_index] = next_expert_id + next_device_deployment[next_exchange_index] = cur_expert_id + + cur_expert_weight = cur_device_weight[cur_exchange_index] + next_expert_weight = next_device_weight[next_exchange_index] + cur_device_weight[cur_exchange_index] = next_expert_weight + next_device_weight[next_exchange_index] = cur_expert_weight + + cur_layer_result[cur_device_id]['total_load'] += next_expert_weight - cur_expert_weight + 
cur_layer_result[next_device_id]['total_load'] += cur_expert_weight - next_expert_weight + + # 记录这两卡进行了通信 + com_between_devices[cur_device_id][next_device_id] = next_expert_id + com_between_devices[next_device_id][cur_device_id] = cur_expert_id + + @staticmethod + # 分层调整冗余专家 + def redundant_expert_deployment(self, layer_workloads, original_deployment, expert_form_device, node_num, + is_node_redundant): + device_num, per_device_expert_num = original_deployment.shape + route_expert_num = layer_workloads.shape[0] + redundancy_expert_num = per_device_expert_num * device_num - route_expert_num + per_node_device_num = device_num // node_num + per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1) + per_node_redun_expert_num = redundancy_expert_num // node_num + + weights = np.zeros((route_expert_num,), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + weights[expert_id] = (expert_id, int(workload_weight)) + + if is_node_redundant: + + device_assignments = [] + device_weights = [] + device_loads = [] + device_counts = [] + com_between_devices = [] + + for node_id in range(node_num): + cur_node_weights = weights[ + node_id * per_node_route_expert_num: (node_id + 1) * per_node_route_expert_num] + cur_original_deployment = original_deployment[ + node_id * per_node_device_num: (node_id + 1) * per_node_device_num] + + cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( + self, + layer_workloads, + cur_node_weights, + per_node_redun_expert_num, + cur_original_deployment, + expert_form_device, + node_num, + is_node_redundant) + device_assignments += cur_device_assignments + device_weights += cur_device_weights + device_loads += cur_device_loads + device_counts += cur_device_counts + com_between_devices += cur_com_between_devices + + else: + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( + self, + layer_workloads, + weights, + redundancy_expert_num, + original_deployment, + expert_form_device, + node_num, + is_node_redundant) + # 生成报告 + report, max_load = self.generate_allocation_report(device_assignments, device_weights, device_loads, + device_counts) + + return report, max_load, com_between_devices + + @staticmethod + def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_exchanged_expert_id, + next_exchanged_expert_id, ave_workload, increment, num_redundancy_expert, cur_org_placement, next_org_placement): + + cur_device_weight = cur_device_result['expert_weights'] + next_device_weight = exchange_device_result['expert_weights'] + + cur_device_expert_id = cur_device_result['assigned_experts'] + next_device_expert_id = exchange_device_result['assigned_experts'] + + cur_device_total_weight = int(cur_device_result['total_load']) + next_device_total_weight = int(exchange_device_result['total_load']) + max_weight = max(cur_device_total_weight, next_device_total_weight) + + cur_exchange_index = -1 + next_exchange_index = -1 + + redun = False + if num_redundancy_expert != 0: + redun = True + + for index, weight in enumerate(cur_device_weight): + for next_index, next_weight in enumerate(next_device_weight): + # 跳过冗余专家 + if (index == 0 or next_index == 0) and redun : + continue + # 交换专家限制卡内专家不同 + change_flag = True + if ((cur_device_expert_id[index] in next_device_expert_id or next_device_expert_id[next_index] in cur_device_expert_id) or + (cur_org_placement[0] == 
next_device_expert_id[next_index] or next_org_placement[0] == cur_device_expert_id[index])): + change_flag = False + # 选择的专家不能是参与过交换的 + if (cur_device_expert_id[index] not in cur_exchanged_expert_id) and ( + next_device_expert_id[next_index] not in next_exchanged_expert_id) and change_flag: + cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight + next_total_weight_after_exchange = next_device_total_weight - next_weight + weight + exchange_max_weight = max(cur_total_weight_after_exchange, next_total_weight_after_exchange) + if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= ( + ave_workload * increment): + max_weight = exchange_max_weight + cur_exchange_index = index + next_exchange_index = next_index + + return cur_exchange_index, next_exchange_index + + @staticmethod + def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_result, com_between_devices, num_redundancy_expert, + org_placement_table, node_idx=0, per_node_device_num=0, is_node_redundant=False): + + if is_node_redundant: + # 拿出当前节点内设备的信息 + cur_devices_result = cur_layer_result[node_idx * per_node_device_num:(node_idx + 1) * per_node_device_num] + else: + # 拿取所有设备信息 + cur_devices_result = cur_layer_result + + devices_total_weight = [] + for device in cur_devices_result: + devices_total_weight.append((int(device['total_load']), device['device_id'] - 1)) + + # 当迭代次数超过100或负载最大的设备无法进行调整时退出 + exchange_frequency = 100 + while exchange_frequency > 0: + exchange_frequency -= 1 + + # 根据负载从小到大排序 + devices_total_weight.sort(key=lambda x: x[0]) + # 负载最大的设备id + max_weight_device_id = devices_total_weight[-1][1] + + exchange = False + # 按照负载从小到大依次取卡 + for index in range(0, len(devices_total_weight) - 1): + min_weight_device_id = devices_total_weight[index][1] + # 两个节点没有进行过通信 + if min_weight_device_id not in com_between_devices[max_weight_device_id]: + # 找到设备中交换过的专家id,(除了冗余之外通信过的id) + set_cur_com_expert_id = set(com_between_devices[max_weight_device_id].values()) + set_next_com_expert_id = set(com_between_devices[min_weight_device_id].values()) + if num_redundancy_expert != 0: + set_cur_device_expert_id = set(cur_layer_result[max_weight_device_id]['assigned_experts'][1:]) + set_next_device_expert_id = set(cur_layer_result[min_weight_device_id]['assigned_experts'][1:]) + else: + set_cur_device_expert_id = set(cur_layer_result[max_weight_device_id]['assigned_experts']) + set_next_device_expert_id = set(cur_layer_result[min_weight_device_id]['assigned_experts']) + + cur_exchanged_expert_id = set_cur_com_expert_id & set_cur_device_expert_id + next_exchanged_expert_id = set_next_com_expert_id & set_next_device_expert_id + + cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( + cur_layer_result[max_weight_device_id], + cur_layer_result[min_weight_device_id], + cur_exchanged_expert_id, + next_exchanged_expert_id, + ave_workload, + increment, + num_redundancy_expert, + org_placement_table[max_weight_device_id], + org_placement_table[min_weight_device_id]) + + # 有符合条件的专家进行交换 + if cur_exchange_index != -1: + self.exchange_expert(cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices) + + devices_total_weight[-1] = ( + cur_layer_result[max_weight_device_id]['total_load'], max_weight_device_id) + devices_total_weight[index] = ( + cur_layer_result[min_weight_device_id]['total_load'], min_weight_device_id) + exchange = True + break + + if not exchange: + break + + @staticmethod + 
def exchange_experts(self, layer_result, layer_com_between_devices, num_nodes, device_num, is_node_redundant, + ave_workload, increment, num_redundancy_expert, org_placement_table): + + global_deployment = [] + + if is_node_redundant: + per_node_device_num = device_num // num_nodes + for node_idx in range(num_nodes): + self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, + layer_com_between_devices, num_redundancy_expert, + org_placement_table, node_idx, per_node_device_num, is_node_redundant) + else: + self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert, org_placement_table) + + max_workload = 0 + for box in layer_result: + global_deployment.append(box['assigned_experts']) + if max_workload < box['total_load']: + max_workload = box['total_load'] + + global_deployment = np.array(global_deployment) + + return global_deployment, max_workload + + @staticmethod + def count_elements(self, lst): + count = 0 + for item in lst: + if isinstance(item, list): + count += self.count_elements(self, item) + else: + count += 1 + return count + + def rebalance_experts(self, current_expert_table, expert_workload): + # 输入:当前专家部署信息和对应的负载信息,形状为layer_num, num_npus, experts_per_npu + info = DynamicTable() + info.workload_table = expert_workload.numpy() + info.placement_table = current_expert_table.numpy() + layer_num, num_npus, experts_per_npu = info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + # 负载信息转化为 58 * 256 + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # 计算负载均衡,部署冗余专家 + num_node = num_npus / 8 + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + expert_from_device = np.zeros((layer_num, num_original_expert)) + # 校验专家数量、卡数量、冗余专家数量不能超过卡数量 + if num_original_expert != expert_num: + raise ValueError(f"原始专家数量 {num_original_expert} 必须等于 expert_num {expert_num}") + + if num_npus <= 0: + raise ValueError("NPUs 数量必须大于 0") + + if num_npus < num_redundancy_expert: + raise ValueError(f"NPUs 数量 {num_npus} 必须大于或等于冗余专家数量 {num_redundancy_expert}") + + # 每个卡部署的专家数量 一个冗余专家 + global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)] + # 统计更换数据集后的初始58层不均衡度 + layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads) + # 遍历获得每一层的放置策略,考虑计算均衡 + max_heat_per_layer_after = np.zeros([layer_num]) + sum_num = 0 + for layer in range(layer_num): + # 不均衡度小于特定阈值不调整 + if layer_initial_imbalance[layer] < 1.1: + global_deployment[layer] = info.placement_table[layer] + continue + + ave_workload = np.sum(layer_workloads[layer]) / num_npus + for device_id, device in enumerate(info.placement_table[layer]): + for index, expert_id in enumerate(device): + if index != 0: + expert_from_device[layer][expert_id] = device_id + + # 调整冗余专家 + result, max_workload, com_between_devices = self.redundant_expert_deployment(self, layer_workloads[layer], + info.placement_table[layer], + expert_from_device[layer], + num_node, False) + # 交换专家 + global_deployment[layer], new_max_workload = self.exchange_experts(self, result, com_between_devices, + num_node, num_npus, False, ave_workload, + 0.05, 
+                                                                               num_redundancy_expert, info.placement_table[layer])
+
+            for device_id in range(num_npus):
+                com_between_devices[device_id] = {int(key): int(value) for key, value in
+                                                  com_between_devices[device_id].items()}
+                sum_num += self.count_elements(self, com_between_devices[device_id])
+
+            max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_load'])['total_load']
+
+        # compute the per-layer priority
+        layer_changed_ratio = []
+        for layer_idx in range(layer_num):
+            layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
+
+        per_layer_priority = np.argsort(layer_changed_ratio)
+        npu_heat_all_after = sum(max_heat_per_layer_after)
+
+        change = 0
+        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
+            change = 1
+
+        return change, per_layer_priority, np.array(global_deployment).tolist()
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/policy/eplb_policy.py b/vllm_ascend/eplb/core/policy/eplb_policy.py
new file mode 100644
index 0000000000..1de60c348d
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/eplb_policy.py
@@ -0,0 +1,42 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from abc import abstractmethod
+
+
+class DynamicConfig:
+    placement_policy = None
+
+    max_transferred_expert_per_layer = 100
+    # maximum number of experts that can be moved per layer on a single host
+
+    ep_worldsize = 64  # number of dies the experts of the whole cluster are distributed across
+    num_die_per_host = 8  # number of dies on each host
+
+
+class EplbPolicy:
+    def __init__(self, config: DynamicConfig):
+        self.config = config
+
+    @abstractmethod
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        """
+        Given the expert workloads, return the expert replication and placement
+        that satisfy the relevant constraints.
+        INPUT:
+        current_expert_table: [layerId, rankId, expert_num_i]
+        expert_workload = expert_table[layer0][rankId][expert_num_i]
+
+        RETURNED: (res, expert_table)
+        res:
+        1 -- table_changed
+        0 -- not_changed
+
+        expert_table: [layerId, rankId, expert_num_i]
+        expert_num_i --- [0, MaxExpertPerRank]
+        expertID = expert_table[layer0][rankId][expert_num_i]
+        array_values:
+        [0, 1, 2, 3, 248]
+        [4, 5, 6, 7, 254]
+        [8, 9, 10, 11, 71]
+        ...
+        [252, 253, 254, 255, 0]
+        """
+        pass
diff --git a/vllm_ascend/eplb/core/policy/mock_load_balance.py b/vllm_ascend/eplb/core/policy/mock_load_balance.py
new file mode 100644
index 0000000000..6626d3fb5c
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/mock_load_balance.py
@@ -0,0 +1,29 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import copy
+import random
+import torch
+
+from .eplb_policy import EplbPolicy, DynamicConfig
+
+random.seed(42)
+
+class MockLoadBalance(EplbPolicy):
+    def __init__(self, config: DynamicConfig):
+        super().__init__(config)
+
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        new_table = copy.deepcopy(current_expert_table)
+        num_layers = len(current_expert_table)
+        num_card = len(current_expert_table[0])
+
+        for i in range(num_layers):
+            # randomly pick two cards
+            # indices = random.sample(range(num_card), 2)
+            indices = [3, 1]
+
+            # swap their redundant experts
+            expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
+            new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
+            new_table[i][indices[1]][-1] = expert_id_to_exchange
+
+        return 1, [-i for i in range(num_layers)], new_table
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py
new file mode 100644
index 0000000000..7ebd048ff9
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/policy_factory.py
@@ -0,0 +1,27 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024.
All rights reserved. +from .eplb_policy import EplbPolicy, DynamicConfig +from .mock_load_balance import MockLoadBalance +from .dynamic_ep import DynamicEplb +from .dynamic_ep_v2 import DynamicEplbV2 + + + +class PolicyFactory: + @staticmethod + def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy: + policy = { + # Constraint applying Dynamic EPLB policy V2: + # If there exists redundant expert: + # only one redundant expert can be placed in one NPU and its physical expert index must be 0 + + # Applying bipartite d2d expert weight update composing + 0:MockLoadBalance, # MockLoadBalance + 1:DynamicEplb, # Dynamic EPLB policy + 2:DynamicEplbV2, # Dynamic EPLB policy V2 + + # Applying greedy d2d expert weight update composing + 3:MockLoadBalance, # MockLoadBalance + 4:DynamicEplb, # Dynamic EPLB policy + 5:DynamicEplbV2, # Dynamic EPLB policy V2 + } + return policy.get(policy_type, MockLoadBalance)(config) diff --git a/vllm_ascend/eplb/core/worker/eplb_worker.py b/vllm_ascend/eplb/core/worker/eplb_worker.py new file mode 100644 index 0000000000..c4aa86a4ad --- /dev/null +++ b/vllm_ascend/eplb/core/worker/eplb_worker.py @@ -0,0 +1,408 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+#
+
+import time
+import numpy as np
+import networkx as nx
+import torch
+import torch_npu
+import logging
+import torch.distributed as dist
+from multiprocessing import Process, Queue, Manager
+from abc import ABC, abstractmethod
+from vllm.logger import logger
+
+from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory, DynamicConfig
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils
+
+
+class EplbWorker:
+
+    def __init__(self, shared_dict, policy_type, enable_d2d: bool = True, redundant_enable=0):
+        self.policy_type = policy_type
+        self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig())
+        self.shared_dict = shared_dict
+        self.old_expert_maps = None
+        self.enable_d2d = enable_d2d
+        self.redundant_enable = redundant_enable
+        self.rank_id = dist.get_rank()
+
+    def do_update(self):
+        # put data into the queue
+        # in process self.policy.generate_policy()
+        # get expert table && tensor
+
+        # async stream
+        # D2D
+        # H2D
+
+        # Get initial expert_map
+        if self.old_expert_maps is None:
+            self.old_expert_maps = self.get_init_expert_maps()
+            self.num_local_experts = self.old_expert_maps.max() + 1
+
+        # Get MOE load information
+        load_info = self.fetch_and_sum_load_info()
+        if load_info is None:
+            return
+
+        # derive the updated expert table from the load information
+        old_placement = self.global2local(self.old_expert_maps, self.num_local_experts)
+        changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
+
+        if not torch.is_tensor(new_placement):
+            new_placement = torch.tensor(new_placement)
+        self.check_expert_placement(old_placement, new_placement)
+        new_expert_maps = self.local2global(new_placement)
+        self.update_expert_map(new_expert_maps)
+        logger.debug("[EPLB Process] new_map differs, performing D2D")
+
+        update_info = self.compose_expert_update_info_bipartite(new_expert_maps, self.old_expert_maps)\
+            if self.policy_type <= 2 else self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps)
+        self.old_expert_maps = new_expert_maps
+        logger.info("EPLB Process compute complete")
+
+        packed_update_info = self.pack_update_info(update_info)
+
+        return packed_update_info
+
+    def check_expert_placement(self, old_placement, new_placement):
+        num_layers = old_placement.shape[0]
+        num_ranks = old_placement.shape[1]
+
+        for layer_id in range(num_layers):
+            # check if any logical expert is not placed on any rank
+            if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel():
+                logger.error(f"There is an expert not placed on any rank in layer {layer_id}")
+                new_placement[layer_id] = old_placement[layer_id]
+                continue
+
+            for rank_id in range(num_ranks):
+                new_placement_check = new_placement[layer_id][rank_id]
+                old_placement_check = old_placement[layer_id][rank_id]
+
+                # check if the same logical expert is placed more than once on one NPU
+                if new_placement_check.numel() != torch.unique(new_placement_check).numel():
+                    logger.error(f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid")
+                    new_placement[layer_id] = old_placement[layer_id]
+                    break
+
+                # check if any expert is moved inside one NPU
+                expert_not_move = torch.isin(new_placement_check, old_placement_check)
+                if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]):
+                    logger.error(f"There is expert movement inside an NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid")
+                    new_placement[layer_id] = old_placement[layer_id]
+                    break
+
+    def
compose_expert_update_info_bipartite(self, updated_expert_maps_org, current_expert_maps_org): + # transform numpy array to torch tensor + updated_expert_maps = updated_expert_maps_org.clone() + current_expert_maps = current_expert_maps_org.clone() + updated_expert_maps = np.array(updated_expert_maps) + current_expert_maps = np.array(current_expert_maps) + + num_layers = current_expert_maps.shape[0] + num_ranks = current_expert_maps.shape[1] + num_experts = current_expert_maps.shape[2] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + updated_expert_maps_this_layer_org = updated_expert_maps_org[layer_id] + + expert_send_info_this_layer = dict() + expert_recv_info_this_layer = dict() + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if (np.equal(updated_expert_maps_this_layer, + current_expert_maps_this_layer)).all(): + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = np.where((current_expert_maps_this_layer == -1) + & (updated_expert_maps_this_layer != -1)) + + # record src ranks for potential transfer + src_ranks_set = dict() + for idx in range(len(dst_rank_indices)): + expert_id = experts_to_recv[idx].item() + if expert_id not in src_ranks_set: + src_ranks_set[expert_id] = np.where( + current_expert_maps_this_layer[:, expert_id] != -1)[0] + + # loop until all experts are scheduled + while len(dst_rank_indices) > 0: + # construct bipartite graph + graph_expert_update = nx.Graph() + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + # add src ranks + src_rank_ids = src_ranks_set[expert_id] + graph_expert_update.add_nodes_from(src_rank_ids, bipartite=0) + # add dest rank + graph_expert_update.add_node(str(dst_rank_id), bipartite=1) + # add edges + for src_rank_id in src_rank_ids: + graph_expert_update.add_edge(src_rank_id, str(dst_rank_id)) + + # graph may not be connected + connected_components = list(nx.connected_components(graph_expert_update)) + all_matches = {} + # matching in this loop + for i, component in enumerate(connected_components): + subgraph = graph_expert_update.subgraph(component) + component_matching = nx.bipartite.maximum_matching(subgraph) + all_matches.update(component_matching) + + for src_rank, dst_rank in all_matches.items(): + dst_rank = int(dst_rank) + assert src_rank != dst_rank + if graph_expert_update.nodes[src_rank]['bipartite'] == 0: + # currently not scheduled experts in rank dst_rank + experts_v = experts_to_recv[np.where( + dst_rank_indices == dst_rank)] + # src: src_rank, dest: dst_rank, expert: expert_id + expert_id = np.intersect1d(experts_v, np.where( + current_expert_maps_this_layer[src_rank] != -1))[0] + + # record send/rcv pairs + if src_rank not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank] = [] + if dst_rank not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank] = [] + expert_send_info_this_layer[src_rank].append((dst_rank, expert_id)) + expert_recv_info_this_layer[dst_rank].append((src_rank, expert_id)) + + remove_index = np.where(np.logical_and( + dst_rank_indices == dst_rank, experts_to_recv == expert_id)) + + # update + dst_rank_indices = np.delete( + dst_rank_indices, remove_index) + experts_to_recv 
= np.delete(experts_to_recv, remove_index)
+
+            yield (expert_send_info_this_layer, expert_recv_info_this_layer,
+                   updated_expert_maps_this_layer_org, layer_id)
+
+    # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
+    def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps):
+        num_layers = current_expert_maps.shape[0]
+        num_ranks = current_expert_maps.shape[1]
+        num_experts = current_expert_maps.shape[2]
+
+        for layer_id in range(num_layers):
+            updated_expert_maps_this_layer = updated_expert_maps[layer_id]
+            current_expert_maps_this_layer = current_expert_maps[layer_id]
+
+            expert_send_info_this_layer = dict()
+            expert_recv_info_this_layer = dict()
+
+            # Guard Clause: if there is no expert weight update, avoid subsequent processing
+            if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer):
+                yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id)
+
+            # Parse expert_ids each rank needs to receive from other ranks
+            dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
+                & (updated_expert_maps_this_layer != -1))
+
+            # Parse expert_ids each rank needs to send to other ranks
+            src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \
+                & (updated_expert_maps_this_layer == -1))
+
+            for idx in range(len(dst_rank_indices)):
+                dst_rank_id = dst_rank_indices[idx].item()
+                expert_id = experts_to_recv[idx].item()
+                if dst_rank_id not in expert_recv_info_this_layer:
+                    expert_recv_info_this_layer[dst_rank_id] = []
+
+                if not torch.isin(torch.tensor(expert_id), experts_to_send).any():
+                    # if expert_id is not sent out from any npu, it will be copied from one npu holding this expert
+                    candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0]
+                else:
+                    candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id]
+
+                # TODO: improve the selection criterion for the NPU that sends expert_id, e.g. considering intra-node vs. inter-node transfer
+                src_rank_id = candidate_src_rank_indices[0].item()
+                if src_rank_id not in expert_send_info_this_layer:
+                    expert_send_info_this_layer[src_rank_id] = []
+
+                expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id))
+                expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id))
+
+            yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id)
+
+
+    def calculate_rebalance_experts(self, load_info, old_placement):
+        """
+        Compute the new_map via the rebalance_experts method of the policy instance.
+        """
+        if self.old_expert_maps is None:
+            return False, None, None
+
+        changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info)
+        return changed, priority, new_map
+
+    def get_init_expert_maps(self):
+        """
+        Read the initial expert_map from shared_dict.
+        """
+        return self.shared_dict.get("expert_maps", None)
+
+    def fetch_and_sum_load_info(self):
+        """
+        Each time the subprocess is awakened, read the latest moe_load
+        (shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
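+        Returns None when the load has not been populated yet; in that case
+        do_update() returns early without computing a new placement.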
+ """ + return self.shared_dict.get("moe_load", None) + + def update_expert_map(self, expert_maps): + + self.shared_dict["expert_maps"] = expert_maps + + def global2local(self, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def pack_update_info(self, update_info_generator): + """ + Pack a list of update info tuples for efficient IPC. + """ + send_all = [] + recv_all = [] + maps = [] + log2phy_all = [] + layer_ids = [] + + for send_info, recv_info, new_expert_map, layer_id in update_info_generator: + + send_info_this_rank = send_info[self.rank_id] if self.rank_id in send_info else [] + recv_info_this_rank = recv_info[self.rank_id] if self.rank_id in recv_info else [] + send_all.append(send_info_this_rank) + recv_all.append(recv_info_this_rank) + + maps.append(new_expert_map[self.rank_id].numpy().tolist()) + + if self.redundant_enable: + log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) + log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist()) + else: + log2phy_all.append([]) + + layer_ids.append(layer_id) + + return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids)) + +class EplbProcess: + def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, policy_type: int = 0, enable_d2d: bool = True): + """ + Args: + shared_dict: Cross-process shared dict returned by Manager().dict() + policy_type: Integer passed to PolicyFactory.generate_policy + enable_d2d: Whether to enable D2D loading + """ + self.shared_dict = shared_dict + self.policy_type = policy_type + self.enable_d2d = enable_d2d + self.planner_q = planner_q + self.block_update_q = block_update_q + self.redundant_enable = redundant_enable + + # Create EplbWorker instance + self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d, self.redundant_enable) + + + def worker_process(self, planner_q, block_update_q): + """ + Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. + """ + while True: + try: + + planner_q.get() + + packed_update_info = self.worker.do_update() + + while True: + if not block_update_q.empty(): + continue + block_update_q.put(packed_update_info) + break + + except Exception as e: + logger.warning(f"[EPLB subprocess Exiting due to error: {e}", exc_info=True) + break + + def _launch_process(self): + """ + Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). 
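+
+        Illustrative flow, as wired up in EplbUpdator (for reference only):
+            proc = eplb._launch_process()
+            planner_q.put(1)                   # wake the worker once per EPLB period
+            packed = block_update_q.get()      # packed result of EplbWorker.do_update()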
+ """ + proc = Process( + target=self.worker_process, + args=(self.planner_q,self.block_update_q), + daemon=True + ) + + proc.start() + return proc + diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py new file mode 100644 index 0000000000..02c03c7933 --- /dev/null +++ b/vllm_ascend/eplb/eplb_updator.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +import numpy +from typing import Dict, List +import torch.distributed as dist +import vllm.envs as envs +from multiprocessing import Queue, Manager + +from vllm.logger import logger +from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess +from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader +from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils + +class EplbUpdator: + + def __init__(self, expert_map_path): + self.init_eplb(expert_map_path) + + def set_adaptor(self, adaptor): + self.adaptor = adaptor + self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor) + self.num_moe_layers = self.adaptor.num_moe_layers + self.global_expert_num = self.adaptor.global_expert_num + + def init_eplb(self, expert_map_path): + self.num_expert_load_gather = 10 + self.periodic_load_gather = True + self.redundant_enable = (expert_map_path is not None) + self.num_iterations_eplb_update: torch.int64 = 130 + self.expert_map_path = expert_map_path + + try: + if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + except Exception as e: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + + self.expert_map_initialized = False + self.gate_eplb = True + + self.reqs = [] + self.update_info_all = [] + + self.cur_iterations: torch.int64 = 0 + + self.num_wait_worker_iterations: torch.int64 = 20 + + self.planner_block_queue = Queue() + self.block_update_queue = Queue(maxsize=1) + + self.manager = Manager() + self.shared_dict = self.manager.dict({ + # 当前rank_id的专家表[num_layers,num_experts] + "expert_map": None, + # 热度负载信息 [num_layers, world_size, num_experts] + "moe_load": None, + # 所有的专家表[num_layers, world_size, num_experts] + "expert_maps": None, + }) + + self.eplb = EplbProcess( + shared_dict = self.shared_dict, + planner_q = self.planner_block_queue, + block_update_q = self.block_update_queue, + redundant_enable = self.redundant_enable, + policy_type = 1, + enable_d2d = True + ) + + self.eplb_process = self.eplb._launch_process() + + logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})") + + def update_iteration(self): + self.cur_iterations += 1 + if self.cur_iterations == (self.num_iterations_eplb_update +\ + self.num_wait_worker_iterations + self.num_moe_layers): + if not self.gate_eplb: + self.cur_iterations = 0 + + def get_update_info_flag(self): + return 
+        return self.cur_iterations == (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+
+    def wakeup_eplb_worker_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update - 1)
+
+    def update_expert_weight_flag(self):
+        weight_update_counter = self.cur_iterations - (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+        return (weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers)
+
+    def get_init_expert_map(self):
+        try:
+            if not self.expert_map_initialized:
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
+                self.expert_map_initialized = True
+        except Exception as e:
+            logger.warning(f"[ModelRunner] Failed to load the initial expert map: {e}", exc_info=True)
+
+    def wakeup_eplb_worker(self):
+        self.planner_block_queue.put(1)
+
+    def forward_before(self):
+        if self.update_expert_weight_flag():
+            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
+            rank_id = torch.distributed.get_rank()
+            if self.redundant_enable:
+                log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
+                self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
+            updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
+            #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+            self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info,
+                updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers)
+
+            # set asynchronous stream for d2d expert weight update
+            self.reqs = []
+            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
+
+    def take_update_info_from_eplb_process(self):
+        # After the EPLB process has been triggered, fetch the update info it produced
+        if self.get_update_info_flag():
+            self.update_info_all = self.block_update_queue.get()
+
+
+    def forward_end(self):
+        if self.wakeup_eplb_worker_flag():
+            moe_load = self.compute_and_set_moe_load(is_clear=True)
+            self.wakeup_eplb_worker()
+
+        if self.update_expert_weight_flag():
+            self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
+
+        self.update_iteration()
+
+    def compute_and_set_moe_load(self, is_clear=False):
+        local_load = self.adaptor.get_rank_expert_workload()
+
+        self._gather_buffer = None
+        if dist.is_initialized():
+            self.world_size = dist.get_world_size()
+            self.device = local_load.device
+            if self._gather_buffer is None:
+                shape = (self.world_size, *local_load.shape)
+                self._gather_buffer = torch.empty(shape,
+                                                  dtype=local_load.dtype,
+                                                  device=self.device)
+
+            dist.all_gather_into_tensor(self._gather_buffer, local_load)
+
+            moe_load = self._gather_buffer.permute(1, 0, 2)
+            self.shared_dict["moe_load"] = moe_load.cpu()
+            logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
+        else:
+            moe_load = local_load.unsqueeze(1)
+            self.shared_dict["moe_load"] = moe_load.cpu()
+            logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
+        self.adaptor.model.clear_all_moe_loads()
+        return moe_load
+
+    def warm_up_eplb(self):
+
+        self.get_init_expert_map()
+        self.compute_and_set_moe_load()
+
+        src_tensor = torch.empty((1,), device=self.device)
+        self_rank = dist.get_rank()
+
+        comm_op_list = []
+
+        for dst_rank in range(self.world_size):
+            if dst_rank == self_rank:
+                continue
+            comm_op_list.append(
+                dist.P2POp(dist.isend, src_tensor, dst_rank)
+            )
+
+        for
src_rank in range(self.world_size): + if src_rank == self_rank: + continue + comm_op_list.append( + dist.P2POp(dist.irecv, src_tensor, src_rank) + ) + if comm_op_list: + reqs = dist.batch_isend_irecv(comm_op_list) + + for req in reqs: + req.wait() + + def unpack_update_batch(self, packed_update_info): + """ + Unpack the IPC batch back into original update_info_list. + """ + send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info + + maps = stacked_maps.unbind(0) + layer_ids = layer_id_tensor.tolist() + + if self.redundant_enable: + log2phy_list = stacked_log2phy.unbind(0) + else: + log2phy_list = [None] * len(maps) + + _zip = zip + _send = send_all + _recv = recv_all + _maps = maps + _l2p = log2phy_list + _lids = layer_ids + + recovered = [ + (_s, _r, _m, _lp, _lid) + for _s, _r, _m, _lp, _lid + in _zip(_send, _recv, _maps, _l2p, _lids) + ] + return recovered + + def get_expert_load(self) -> tuple: + expert_maps = self.shared_dict["expert_maps"] + moe_load = self.shared_dict["moe_load"] # Tensor [L, W, global_experts_num] + num_local_experts = expert_maps.max() + 1 + return moe_load, expert_maps, num_local_experts + + def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update}...") + self.num_expert_load_gather = num_expert_load_gather + self.num_iterations_eplb_update = num_iterations + logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update} success...") + + def shutdown(self): + """ + Clean up the EPLB process. + """ + if self.eplb_process.is_alive(): + self.eplb_process.terminate() + self.eplb_process.join() + logger.info("[ModelRunner] EPLB process terminated") diff --git a/vllm_ascend/eplb/tool/eplb_utils.py b/vllm_ascend/eplb/tool/eplb_utils.py new file mode 100644 index 0000000000..156f7a9b9d --- /dev/null +++ b/vllm_ascend/eplb/tool/eplb_utils.py @@ -0,0 +1,114 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import torch +import random + +class ExpertMapUtils(): + + @classmethod + def generate_index_dicts(cls, tensor_2d): + dict_list = [] + current_idx = 0 + + for row in tensor_2d: + value_to_index = {} + for i in range(row.size(0)): + value = row[i].item() + value_to_index[value] = current_idx + i + dict_list.append(value_to_index) + current_idx += row.size(0) + + return dict_list + + @classmethod + def generate_log2phy_map(cls, expert_map): + num_local_experts = expert_map.max() + 1 + log2phy_map = expert_map.clone() + num_ranks, num_global_expert = log2phy_map.shape + + row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks,\ + num_global_expert) * num_local_experts + log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1] + + for idx in range(num_global_expert): + positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0] + negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0] + num_rank_holding_expert = positive_rank_idx.size(0) + + if num_rank_holding_expert == 1: + log2phy_map[negative_rank_idx, idx] = torch.full((num_ranks - 1,), + log2phy_map[positive_rank_idx, idx].item(), + dtype=log2phy_map.dtype) + else: + random_list = [random.choice(log2phy_map[positive_rank_idx, idx]) + for _ in range(num_ranks - num_rank_holding_expert)] + log2phy_map[negative_rank_idx, idx] = torch.tensor(random_list,\ + dtype=log2phy_map.dtype) + + return log2phy_map + + @classmethod + def global2local(cls, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + + G, _ = placement.shape + device = placement.device + + pt_local = torch.full(( G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + g_idx, k_idx = valid.nonzero(as_tuple=True) + slot_idx = placement[g_idx, k_idx] + + pt_local[g_idx, slot_idx] = k_idx + + return pt_local + + @classmethod + def global2local_load(self, + workload: torch.Tensor, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + L, G, _ = placement.shape + device = placement.device + + wt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=workload.dtype, + device=device) + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + values = workload[l_idx, g_idx, k_idx] + + wt_local[l_idx, g_idx, slot_idx] = values + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return wt_local, pt_local \ No newline at end of file diff --git a/vllm_ascend/eplb/tool/generate_map.py b/vllm_ascend/eplb/tool/generate_map.py new file mode 100644 index 0000000000..b498e73a06 --- /dev/null +++ b/vllm_ascend/eplb/tool/generate_map.py @@ -0,0 +1,65 @@ +import numpy as np +import json +import argparse + + +def split_and_insert(n, k, m): + ''' + n: expert num + k: card num + m: redundant expert num, make sure m%k==0 + ''' + + A = np.arange(n) + + B = np.random.choice(n, size=m, replace=False) + + groups = np.array_split(A, k) + + for j in range(m // k): + for i in range(k): + groups[i] = np.append(groups[i], B[i + j * k]) + return np.concatenate(groups) + + +def random_generation(n_layer=58, n_expert=256, start_layer_idx=0, device_count=128, n_redundant=128, output_name=""): + expert_data = {} + expert_data["moe_layer_count"] = n_layer + layer_list = [] + for i in range(n_layer): + layer = {"layer_id": start_layer_idx + i, "device_count": device_count} + random_placement = split_and_insert(n_expert, 
device_count, n_redundant) + device_list = [] + step = random_placement.shape[0] // device_count + for j in range(device_count): + device = {} + device["device_id"] = j + device["device_expert"] = random_placement[j * step: (j + 1) * step].tolist() + device_list.append(device) + layer["device_list"] = device_list + layer_list.append(layer) + + expert_data["layer_list"] = layer_list + json_file_path = output_name + + with open(json_file_path, "w") as f: + json.dump(expert_data, f, indent=4) + + print(f"JSON file generated: {json_file_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="python generate_map.py --n_layers 2 --n_experts 256 --card_num 8 --n_redundant 8 --output expert_map.json") + parser.add_argument("--n_layers", type=int, required=True) + parser.add_argument("--n_experts", type=int, required=True) + parser.add_argument("--card_num", type=int, required=True) + parser.add_argument("--n_redundant", type=int, default=0) + parser.add_argument("--output", type=str, default="expert_map.json") + args = parser.parse_args() + + n_layers = args.n_layers + n_experts = args.n_experts + card_num = args.card_num + n_redundant = args.n_redundant + output = args.output + + random_generation(n_layers, n_experts, 0, card_num, n_redundant, output) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 490cd4ed5e..d85572b32c 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -35,14 +35,19 @@ def register_model(): ModelRegistry.register_model( "DeepseekV2ForCausalLM", "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + else: ModelRegistry.register_model( "DeepseekV2ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") ModelRegistry.register_model( "Qwen3MoeForCausalLM", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 6ab0837e37..000bd39ed5 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -25,38 +25,33 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union import torch import torch.distributed as dist import torch_npu # noqa: F401 -import vllm.envs as envs from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) 
+from vllm.model_executor.layers.linear import (ReplicatedLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_v2 import \ DeepseekV2ForCausalLM # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, - DeepseekV2DecoderLayer, - DeepseekV2MLAAttention) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer from vllm.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -64,7 +59,9 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention, + CustomDeepseekV2MLP) from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -74,8 +71,9 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import ( + AscendW8A8DynamicLinearMethod, apply_mlp) from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -83,16 +81,48 @@ class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__(hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + prefix=prefix) + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + def _forward_ms_mlp(self, x): current_ms_metadata = get_multistream_comm_context() assert current_ms_metadata is not None gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() + if self.is_dynamic_quant: + x, dynamic_scale = self.act_fn(gate_up) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + else: + x = self.act_fn(gate_up) x, _ = self.down_proj(x) - 
current_ms_metadata.after_comm_event.record() return x @@ -163,7 +193,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group - + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role = "kv_consumer" self.params_dtype = torch.get_default_dtype() ascend_config = get_ascend_config() @@ -173,39 +206,34 @@ def forward( self, hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + attn_metadata = forward_context.attn_metadata + # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = forward_context.in_profile_run - old_hidden_states = hidden_states.clone() + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - hidden_states = self.experts( + experts_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, is_prefill=is_prefill, top_k=CustomDeepseekDBOMoE.top_k, enable_force_load_balance=enable_force_load_balance, - ) * self.routed_scaling_factor - - if self.n_shared_experts is not None: - shared_output = self.shared_experts(old_hidden_states) + shared_experts=self.shared_experts) - if shared_output is not None: - hidden_states = hidden_states + shared_output + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) return hidden_states @@ -225,199 +253,6 @@ def _forward_ms_op_gate( router_logits, _ = self.gate(hidden_states) return router_logits - def _forward_ms_op_tp_allgather( - self, - hidden_states: torch.Tensor, - chunk_hidden_states: torch.Tensor, - num_tokens: int = 0, - ): - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - current_ms_metadata.after_comm_event.record() - return final_hidden_states - - -class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - 
q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - else: - hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: - forward_kwargs = {} - if envs.VLLM_USE_V1: - output_shape = hidden_states.shape - output = torch.empty(output_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - forward_kwargs['output'] = output - - output = self.mla_attn.impl.forward(self.mla_attn, - hidden_states_or_q_c, - hidden_states, None, kv_cache, - attn_metadata, - **forward_kwargs) - if envs.VLLM_USE_V1: - output = output.view(-1, output_shape[-1]) - return output - else: - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) - class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): @@ -440,10 +275,7 @@ def __init__( layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx # TODO: enable mla in vllm-ascend - if model_config.use_mla: - attn_cls = CustomDeepseekDBOMLAAttention - else: - attn_cls = DeepseekV2Attention + attn_cls = CustomDeepseekV2MLAAttention self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, @@ -461,6 +293,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + self.tp_size = get_tensor_model_parallel_world_size() + self.dp_size = get_dp_group().world_size + self.tp_group = get_tp_group().device_group + self.global_num_experts = config.n_routed_experts if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace @@ -566,7 +402,26 @@ def _forward_ms_layer( shared_outputs = [] router_logits = [] chunk_hidden_states = [] - + chunk_router_logits = [] + topk_weights = [] + topk_ids = [] + num_moe_tokens = [] + original_shapes = [] + expanded_row_idx = [] + scatter_size_list = [] + gather_size_list = [] + local_expert_idx = [] + scatter_sizes = [] + expanded_expert_idx = [] + sorted_local_expert_idx = [] + sorted_idx = [] + + global_num_experts = len( + self.mlp.experts.expert_map + ) if self.mlp.experts.expert_map is not None else self.global_num_experts + ep_group = get_ep_group() + local_num_experts = global_num_experts // 
ep_group.world_size
+        fused_moe_state = get_forward_context().fused_moe_state
         # block 1 : attention
         # block 2 : attn tp communication
         # the attn computation of microbatch 1 can be overlapped with the moe
@@ -631,88 +486,221 @@
             # when profile runs, force experts to load balanced tokens
             # to avoid high memory consumption on a single rank.
             # TODO: need a better flag to indicate whether in profile run or not.
-            if attn_metadata[i] is None:
-                # for profile run
-                is_prefill = True
-                enable_force_load_balance = True
-            else:
-                is_prefill = attn_metadata[i].num_prefills > 0
-                enable_force_load_balance = False
-
-            if self.mlp.tp_size > 1:
-                num_token, _ = hidden_states[i].shape
-                padded_num_tokens = (self.mlp.tp_size - num_token %
-                                     self.mlp.tp_size) % self.mlp.tp_size
-                if padded_num_tokens > 0:
-                    hidden_states[i] = nn.functional.pad(
-                        hidden_states[i], (0, 0, 0, padded_num_tokens))
-                chunk_hidden_state = torch.tensor_split(hidden_states[i],
-                                                        self.mlp.tp_size,
-                                                        dim=0)
-                chunk_hidden_states.append(chunk_hidden_state)
-                local_hidden_states = chunk_hidden_state[self.mlp.tp_rank]
-            else:
-                local_hidden_states = hidden_states[i]
-
-            router_logit = self.mlp._forward_ms_op_gate(local_hidden_states)
+            router_logit = self.mlp._forward_ms_op_gate(hidden_states[i])
             router_logits.append(router_logit)
             if CustomDeepseekDBOMoE.top_k:
                 real_top_k = CustomDeepseekDBOMoE.top_k
             else:
                 real_top_k = self.mlp.experts.top_k
+            if (self.tp_size > 1
+                    and fused_moe_state != FusedMoEState.AllGather):
+                if num_tokens[i] < self.tp_size:
+                    hidden_states[i] = nn.functional.pad(
+                        hidden_states[i],
+                        (0, 0, 0, self.tp_size - num_tokens[i]))
+                    router_logits[i] = nn.functional.pad(
+                        router_logits[i],
+                        (0, 0, 0, self.tp_size - num_tokens[i]))
+                chunk_hidden_state = torch.tensor_split(hidden_states[i],
+                                                        self.tp_size,
+                                                        dim=0)
+                chunk_hidden_states.append(chunk_hidden_state)
+                chunk_router_logit = torch.tensor_split(router_logits[i],
+                                                        self.tp_size,
+                                                        dim=0)
+                chunk_router_logits.append(chunk_router_logit)
+                tp_rank = get_tensor_model_parallel_rank()
+                hidden_states[i] = chunk_hidden_states[i][tp_rank]
+                router_logits[i] = chunk_router_logits[i][tp_rank]
+
+            if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
+                if attn_metadata[i] is not None:
+                    max_num_tokens_across_dp = attn_metadata[
+                        i].max_tokens_across_dp
+                    if num_tokens[i] < max_num_tokens_across_dp:
+                        hidden_states[i] = nn.functional.pad(
+                            hidden_states[i],
+                            (0, 0, 0,
+                             max_num_tokens_across_dp - num_tokens[i]))
+                        router_logits[i] = nn.functional.pad(
+                            router_logits[i],
+                            (0, 0, 0,
+                             max_num_tokens_across_dp - num_tokens[i]))
+                hidden_states[i] = get_dp_group().all_gather(
+                    hidden_states[i], 0)
+                router_logits[i] = get_dp_group().all_gather(
+                    router_logits[i], 0)
+
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if global_num_experts == 256:
+                topk_weight, topk_id, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits[i],
+                    k=real_top_k,  # top_k is currently 8
+                    bias=self.mlp.experts.e_score_correction_bias,
+                    k_group=self.mlp.experts.topk_group,  # fix: 4
+                    group_count=self.mlp.experts.num_expert_group,  # fix 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False,  # todo new api; whether to output the third tensor
+                    # y2_flag=False,  # old api; whether to output the third tensor
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weight, topk_id = self.mlp.experts.select_experts(
+                    hidden_states=hidden_states[i],
router_logits=router_logits[i], + top_k=real_top_k, + use_grouped_topk=self.mlp.experts.use_grouped_topk, + renormalize=self.mlp.experts.renormalize, + topk_group=self.mlp.experts.topk_group, + num_expert_group=self.mlp.experts.num_expert_group, + custom_routing_function=self.mlp.experts. + custom_routing_function, + scoring_func=self.mlp.experts.scoring_func, + e_score_correction_bias=self.mlp.experts. + e_score_correction_bias, + ) + topk_weight = topk_weight.to(hidden_states[i].dtype) + topk_weights.append(topk_weight) + topk_ids.append(topk_id) + original_shape = hidden_states[i].shape + original_shapes.append(original_shape) + if len(original_shapes[i]) == 3: + hidden_states[i] = hidden_states[i].view( + -1, hidden_states[i].shape[-1]) + num_token, _ = hidden_states[i].shape + num_moe_tokens.append(num_token) + device = hidden_states[i].device + + row_idx_len = num_moe_tokens[i] * real_top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(real_top_k, + -1).permute( + 1, 0).contiguous()) + hidden_states[ + i], expanded_row_idx_i, expanded_expert_idx_i = torch_npu.npu_moe_init_routing( + hidden_states[i], + row_idx=row_idx, + expert_idx=topk_ids[i], + active_num=num_moe_tokens[i]) + expanded_row_idx.append(expanded_row_idx_i) + expanded_expert_idx.append(expanded_expert_idx_i) - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( - local_hidden_states, router_logits[i], is_prefill, real_top_k, - enable_force_load_balance) - - # the following kernels will be submitted to the comm stream to overlap the computation of the - # moe computation of next microbatch and the attn computation of next layer context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH], + MSEventKey.MOE_ALL_TO_ALL], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.MOE_ALL_TO_ALL_FINISH], ) context.before_comm_event.record() with torch.npu.stream(ms_metadata.communicate_stream): context.before_comm_event.wait() - if self.mlp.experts.reduce_results and ( - self.mlp.experts.tp_size > 1 - or self.mlp.experts.ep_size > 1): - hidden_states[i] = tensor_model_parallel_all_reduce( - hidden_states[i]) - hidden_states[ - i] = hidden_states[i] * self.mlp.routed_scaling_factor + global_expert_tokens = torch.bincount( + expanded_expert_idx[i], minlength=global_num_experts) + scatter_size = global_expert_tokens.view( + ep_group.world_size, -1).sum(-1) + scatter_sizes.append(scatter_size) + gather_sizes = torch.empty_like(scatter_sizes[i]) + dist.all_to_all_single(gather_sizes, + scatter_sizes[i], + group=ep_group.device_group) + scatter_size_list_i = scatter_sizes[i].cpu().tolist() + gather_size_list_i = gather_sizes.cpu().tolist() + scatter_size_list.append(scatter_size_list_i) + gather_size_list.append(gather_size_list_i) + expanded_expert_idx[ + i] = expanded_expert_idx[i] % local_num_experts + hidden_states[i] = ep_group.all_to_all(hidden_states[i], 0, 0, + scatter_size_list[i], + gather_size_list[i]) + local_expert_idx_i = ep_group.all_to_all( + expanded_expert_idx[i], 0, 0, scatter_size_list[i], + gather_size_list[i]) + local_expert_idx.append(local_expert_idx_i) + + sorted_local_expert_idx_i, sorted_idx_i = torch.sort( + local_expert_idx[i]) + sorted_local_expert_idx.append(sorted_local_expert_idx_i) + sorted_idx.append(sorted_idx_i) context.after_comm_event.record() + for i in range(num_micro_batchs): + 
ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_ALL_TO_ALL_FINISH) + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx[i], local_num_experts).to(torch.int64) + group_list_type = 0 + hidden_states[i] = hidden_states[i][sorted_idx[i]] + hidden_states[i] = apply_mlp( + hidden_states[i], + self.mlp.experts.w13_weight, + self.mlp.experts.w13_weight_scale, #17 + self.mlp.experts.w2_weight, + self.mlp.experts.w2_weight_scale, + expert_tokens, #16 + group_list_type=group_list_type, + w1_scale_bias=None, + w2_scale_bias=None) + + resorted_idx = torch.argsort(sorted_idx[i]) + hidden_states[i] = hidden_states[i][resorted_idx] + hidden_states[i] = ep_group.all_to_all(hidden_states[i], 0, 0, + gather_size_list[i], + scatter_size_list[i]) + + hidden_states[i] = torch_npu.npu_moe_finalize_routing( + hidden_states[i], + skip1=None, + skip2=None, + bias=None, + scales=topk_weights[i], + expanded_src_to_dst_row=expanded_row_idx[i], + export_for_source_row=topk_ids[i], + ) + if len(original_shapes[i]) == 3: + hidden_states[i] = hidden_states[i].view(original_shapes[i]) + + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.FFN_COM_FINISH], after_comm_event=ms_metadata.ms_events[layer_index][i][ MSEventKey.FFN_AR_FINISH], ) - with set_multistream_context(context, i): - if self.mlp.tp_size > 1: - hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( - hidden_states[i], chunk_hidden_states[i], - padded_num_tokens) + context.before_comm_event.record() with torch.npu.stream(ms_metadata.communicate_stream): + context.before_comm_event.wait() + if (self.tp_size > 1 + and fused_moe_state != FusedMoEState.AllGather): + dist.all_gather(list(chunk_hidden_states[i]), + hidden_states[i], self.tp_group) + hidden_states[i] = torch.cat(chunk_hidden_states[i], dim=0) + if num_tokens[i] < self.tp_size: + hidden_states[i] = hidden_states[i][:num_tokens[i]] + elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + hidden_states[ + i] = dist._functional_collectives.reduce_scatter_tensor( + hidden_states[i], + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + hidden_states[i] = hidden_states[i][:num_tokens[i]] + if self.tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) # last if shared_outputs[i] is not None: - hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[ + i] * self.routed_scaling_factor + shared_outputs[i] hidden_states[i] = hidden_states[i].view( num_tokens[i], hidden_dims[i]) - if isinstance(self.mlp, CustomDeepseekDBOMLP - ) and hidden_states[i].dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states[i] *= 1. 
/ self.routed_scaling_factor context.after_comm_event.record() return hidden_states, residual @@ -767,9 +755,7 @@ def _forward_ms_op_post_attn_layernorm( class CustomDeepseekDBOModel(nn.Module): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -835,6 +821,7 @@ def forward( attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -848,10 +835,12 @@ def forward( residual = intermediate_tensors["residual"] num_normal_layers = (self.first_k_dense_replace - if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() - else self.end_layer - self.start_layer) + if VLLM_ASCEND_ENABLE_DBO and not graph_enable + and self.can_run_ms() else self.end_layer - + self.start_layer) - for i in range(self.start_layer, self.start_layer + num_normal_layers): + moe_start_layer = self.start_layer + num_normal_layers + for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, residual, @@ -859,8 +848,7 @@ def forward( self.start_layer] if kv_caches is not None else None, attn_metadata) - moe_start_layer = self.start_layer + num_normal_layers - if moe_start_layer != self.end_layer: + if moe_start_layer < self.end_layer: # if we enable multistream/dbo, process sparse layers here hidden_states, residual = self._forward_ms_layers( positions=positions, @@ -881,34 +869,18 @@ def forward( def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata - # support mla attention and V1 engine at present - if not self.use_mla or not envs.VLLM_USE_V1: - return False # enable prefill overlap - if attn_metadata is None or attn_metadata.num_prefills == 0: - return False - else: - [token_index, seq_index - ] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return False - # check whether the total tokens exceed the threshold - if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + if attn_metadata is None or attn_metadata.num_prefills == 0 or not attn_metadata.enable_dbo_across_dp: return False return True - def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - kv_caches: Optional[List[torch.Tensor]] = None, - is_prefill: bool = False, - ): + def _forward_ms_layers(self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False): if moe_start_layer == self.end_layer: return hidden_states, residual @@ -970,8 +942,9 @@ def forward( attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, 
intermediate_tensors, - inputs_embeds) + inputs_embeds, graph_enable) return hidden_states diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 979a6099f1..400c7a0acf 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -28,8 +28,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) @@ -40,6 +40,20 @@ from .deepseek_v2 import CustomDeepseekV2DecoderLayer +class CustomDeepSeekShareHead(SharedHead): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + nn.Module.__init__(self) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) + + class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): def __init__( @@ -61,7 +75,10 @@ def __init__( self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) - self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.shared_head = CustomDeepSeekShareHead(config=config, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "shared_head")) self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, model_config, cache_config, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index e96b2e9847..6e215b6b81 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -25,7 +25,7 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch_npu @@ -33,11 +33,11 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_dp_group, get_pp_group, get_tensor_model_parallel_world_size, get_tp_group) -from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -65,7 +65,6 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod @@ -285,7 +284,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group - 
self.ep_group = get_ep_group() + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role == "kv_consumer" self.params_dtype = torch.get_default_dtype() @@ -293,23 +295,25 @@ def forward( self, hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + attn_metadata = forward_context.attn_metadata + # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + if self.enable_multistream_moe: + router_logits = None + else: + router_logits, _ = self.gate(hidden_states) experts_hidden_states = self.experts( hidden_states=hidden_states, @@ -318,6 +322,7 @@ def forward( top_k=CustomDeepseekV2MoE.top_k, enable_force_load_balance=enable_force_load_balance, shared_experts=self.shared_experts, + gate=self.gate if self.enable_multistream_moe else None, ) hidden_states = ( @@ -477,7 +482,8 @@ def forward( hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: + is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model + if self.torchair_graph_enabled and not is_mtp_model: forward_kwargs = {} if envs.VLLM_USE_V1: output_shape = hidden_states.shape @@ -727,13 +733,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.num_dense_layers = self.config.first_k_dense_replace + self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers + self.model = CustomDeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) @@ -755,6 +766,39 @@ def forward( inputs_embeds) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = filter(lambda x: ".module." not in x[0], weights) + # weights = ((name, data) for name, data in weights if ".module." 
not in name) + loaded_params = super().load_weights(weights) + + return loaded_params + + def get_expert_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_map() + + def get_log2phy_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_log2phy_map() + + def get_all_expert_map(self, num_moe_layers): + all_loads = [] + for layer_id in range(num_moe_layers): + load_tensor = self.get_expert_map(3+layer_id) # (num_experts_per_layer,) + all_loads.append(load_tensor) + + return torch.stack(all_loads, dim=0) + + def get_all_moe_loads(self): + all_moe_loads = torch.stack( + [self.model.layers[layer_id + self.num_dense_layers].mlp.experts.moe_load \ + for layer_id in range(self.num_moe_layers)], + dim=0 + ) + return all_moe_loads + + def clear_all_moe_loads(self): + for layer_id in range(self.num_moe_layers): + self.model.layers[layer_id + self.num_dense_layers].mlp.experts.clear_moe_load() class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): pass diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py index fba58b460e..420839cde6 100644 --- a/vllm_ascend/multistream/base.py +++ b/vllm_ascend/multistream/base.py @@ -14,6 +14,8 @@ class MSEventKey(Enum): MOE_SE_COMM_FINISH = 6 MOE_SE_COMP_FINISH = 7 MOE_GATE_FINISH = 8 + MOE_ALL_TO_ALL = 9 + MOE_ALL_TO_ALL_FINISH = 10 @dataclass diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index b521d3f85f..e451f15f26 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -170,6 +170,8 @@ def make_multistream_metadata_ds( MSEventKey.MOE_SE_COMM_FINISH, MSEventKey.MOE_SE_COMP_FINISH, MSEventKey.MOE_GATE_FINISH, + MSEventKey.MOE_ALL_TO_ALL, + MSEventKey.MOE_ALL_TO_ALL_FINISH, ] return MultiStreamMetadata( calculate_stream=torch.npu.current_stream(), diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 430f57b03a..fd32a18abb 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn( seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) - query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] - query_start_loc_post = deepcopy( - attn_metadata.query_start_loc[seq_index:] - ) - attn_metadata.query_start_loc[seq_index] + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] [block_table_pre, block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, seq_index) @@ -223,7 +225,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, prefill=prefill_pre, decode=decode_pre, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) attention_metadata_post = _metadata_cls( num_actual_tokens=attn_metadata.num_actual_tokens - token_index, @@ -240,6 +242,6 @@ def model_input_split_v1_mla_attn( attn_state=attn_state_post, prefill=prefill_post, decode=decode_post, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) return [attention_metadata_pre, attention_metadata_post] 
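A toy walk-through (hypothetical values) of the None-guarded query_start_loc split above: the prefix keeps the first seq_index boundaries unchanged, while the suffix is re-based to start at zero for the second micro-batch.

    import torch
    from copy import deepcopy

    query_start_loc = torch.tensor([0, 3, 7, 12, 18])  # cumulative starts for 4 sequences
    seq_index = 2
    pre = query_start_loc[:seq_index + 1]                                       # tensor([0, 3, 7])
    post = deepcopy(query_start_loc[seq_index:]) - query_start_loc[seq_index]   # tensor([0, 5, 11])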
diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py index 8037c9545b..05600aee7a 100644 --- a/vllm_ascend/ops/attention.py +++ b/vllm_ascend/ops/attention.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Tuple import torch from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -37,7 +37,7 @@ def vanilla_chunked_prefill( scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool = True, -) -> None: +) -> torch.Tensor: num_query_heads = query.shape[1] head_dim = value_cache.shape[3] num_kv_heads = value_cache.shape[2] @@ -138,7 +138,8 @@ def vanilla_chunked_prefill( def vanilla_chunked_prefill_mla( output: torch.Tensor, # (num_tokens, num_heads, v_head_dim) query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim) - kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv) + kv_cache: Tuple[ + torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv) block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq) query_lens: torch.Tensor, # (batch_size) context_lens: torch.Tensor, # (batch_size) @@ -152,22 +153,25 @@ def vanilla_chunked_prefill_mla( alibi_slopes: Optional[torch.Tensor], causal: bool = True) -> None: batch_size = block_tables.size(0) + assert len(kv_cache) > 1 assert query_lens.size(0) == batch_size num_heads = query.size(1) - block_size = kv_cache.size(1) - latent_kv_dim = kv_cache.size(3) - rope_dim + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + block_size = nope_cache.size(1) + latent_kv_dim = nope_cache.size(-1) max_num_blocks_per_seq = block_tables.size(1) batch_size = query_lens.size(0) - kv_cache = kv_cache.squeeze() - # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] - cache_kv_c_pe = kv_cache[block_tables].view( - batch_size, max_num_blocks_per_seq * block_size, - latent_kv_dim + rope_dim)[:, :max_context_len, :] - # get kv_c and k_pe + nope_cache = nope_cache.squeeze() + # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe # cached_kv_c: [batch_size, max_context_len, latent_kv] # cached_k_pe: [batch_size, max_context_len, rope_dim] - cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim] - cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:] + cache_kv_c = nope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + latent_kv_dim)[:, :max_context_len, :] + cache_k_pe = rope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + rope_dim)[:, :max_context_len, :] # get k_rope and v # k_nope: [batch_size, max_context_len, num_heads, nope_dim] # value: [batch_size, max_context_len, num_heads, v_head_dim] @@ -258,8 +262,8 @@ def vanilla_chunked_prefill_mla( attn_output = (attn_output[q_mask].view([-1, num_heads, v_head_dim]).to(output.dtype)) - output = output.view([-1, num_heads, v_head_dim]) - output.copy_(attn_output[:query.size(0) - num_add_query]) + attn_output = attn_output.view_as(output) + output.copy_(attn_output) return attn_output diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 05daf69f79..e0819210d8 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. 
# Adapted from vllm/tests/kernels/test_moe.py +import math import os from typing import Any, Callable, List, Optional, Tuple, Union @@ -26,7 +27,8 @@ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group, get_tp_group +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, @@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, +from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, + get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -118,9 +120,24 @@ def fused_experts_with_mc2( top_k: int, expert_map: torch.Tensor = None, moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None + shared_experts: Optional[Any] = None, + is_torchair: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - global_bs = 0 + quant_mode = 0 + ep_group = get_ep_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + tp_world_size = get_tp_group().world_size + + # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, + # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. 
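# For example (hypothetical sizes): with max_tokens_across_dp = 256,
# tp_world_size = 2 and ep_world_size = 16, global_bs = ceil(256 / 2) * 16 = 2048.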
+ global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + tp_world_size) * ep_world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, @@ -131,27 +148,20 @@ def fused_experts_with_mc2( "global_bs": global_bs, } - rank = torch.distributed.get_rank() - - quant_mode = 0 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) - - tp_size = get_etp_group().world_size - tp_rank = rank % tp_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) @@ -204,20 +214,22 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": 0, + "global_bs": global_bs, } tp_recv_counts = output[5] stage3_kwargs = { "ep_send_counts": ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -847,17 +859,14 @@ def __init__(self, moe: MoEConfig = None): super().__init__(moe=moe) vllm_config = get_current_vllm_config() - self.ep_group = get_ep_group() - self.ep_size = self.ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - self.local_batch_size = self.global_batch_size // self.ep_size self.max_model_len = vllm_config.model_config.max_model_len ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: - device_group = self.ep_group.device_group + device_group = get_ep_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -933,8 +942,7 @@ def apply( if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) + fused_moe_state = get_forward_context().fused_moe_state if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, @@ -945,7 +953,8 @@ def apply( top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, - shared_experts=shared_experts) + shared_experts=shared_experts, + 
is_torchair=self.torchair_graph_enabled) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -1046,8 +1055,14 @@ def __init__( self.log2phy = None self.global_redundant_expert_num = 0 + # TODO: if this is not need for dynamic eplb with redundant expert, remove this + # self.log2phy = torch.full((self.ep_size, self.global_num_experts), + # -1, + # dtype=torch.int32) + ascend_config = get_ascend_config() expert_map_path = ascend_config.expert_map_path + self.dynamic_eplb = ascend_config.dynamic_eplb if expert_map_path and os.path.exists(expert_map_path): # moe expert load balance expert_load_balancer = ExpertLoadBalancer(expert_map_path, @@ -1055,17 +1070,15 @@ def __init__( self.local_num_experts, self.expert_map = \ expert_load_balancer.get_rank_placement_map( self.moe_instance_id, - get_ep_group().rank_in_group) + self.ep_rank) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, - get_ep_group().rank_in_group) + self.moe_instance_id, self.ep_rank) self.global_redundant_expert_num = \ expert_load_balancer.get_global_redundant_expert_num() else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ @@ -1095,6 +1108,10 @@ def __init__( local_num_experts = torch.sum(self.expert_map != -1) \ if self.expert_map is not None else num_experts + self.moe_load = None + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, @@ -1108,7 +1125,6 @@ def __init__( in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size - self.ep_group = get_ep_group() # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -1119,7 +1135,8 @@ def forward(self, is_prefill: bool, enable_force_load_balance: bool = False, top_k: Optional[int] = None, - shared_experts: Optional[Any] = None): + shared_experts: Optional[Any] = None, + gate: Optional[Any] = None): assert self.quant_method is not None if top_k: @@ -1129,8 +1146,21 @@ def forward(self, num_tokens, hidden_size = hidden_states.shape - fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill) + fused_moe_state = get_forward_context().fused_moe_state + # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
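# In other words: when multistream MoE is enabled, the gate GEMM below runs on the
# default stream while the per-token dynamic quantization of hidden_states is issued
# on the secondary "moe_secondary" stream, so the two overlap before the MC2
# fused-experts call consumes both results.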
+ quantized_x_for_share, dynamic_scale_for_share = None, None + from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod + if self.enable_multistream_moe: + assert gate is not None + router_logits, _ = gate(hidden_states) + if isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod + ) and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( + hidden_states) + if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) @@ -1154,21 +1184,20 @@ def forward(self, if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + max_num_tokens_across_dp = get_forward_context( + ).max_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) # Matrix multiply. 
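# (Toy illustration of the DP padding above, with made-up sizes: if this rank holds
# num_tokens = 5 while max_tokens_across_dp = 8, hidden_states and router_logits are
# zero-padded from 5 to 8 rows so every DP rank contributes equally sized tensors to
# the all_gather.)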
- e_hidden_states = self.quant_method.apply( + e_hidden_states, expert_token_num, group_list_type = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -1188,11 +1217,16 @@ def forward(self, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states + if self.dynamic_eplb: + self.moe_load += expert_token_num if group_list_type else \ + torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]]) if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: dist.all_gather(list(chunk_hidden_states), e_hidden_states, @@ -1249,3 +1283,17 @@ def _forward_ms_fused_moe_comp( enable_force_load_balance=enable_force_load_balance) return hidden_states + + def update_map(self,new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.log2phy + + def clear_moe_load(self): + self.moe_load.zero_() + + diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 39a4c1cfe8..f55ab8e0cb 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -80,10 +80,7 @@ def native_rope_deepseek_forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style @@ -198,8 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len +def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -219,9 +215,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -266,11 +260,10 @@ def deepseek_rope_init_func( super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - _set_cos_sin_cache(self, - max_position_embeddings, - dtype=dtype, - device="npu") + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = max_position_embeddings * scaling_factor + _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") RotaryEmbedding.forward_oot = rope_forward_oot diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index d817f9063e..ae87010359 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -47,7 +47,16 @@ # 
Related PR (if no, explain why):
# Future Plan:
# Remove those patch when vllm merged them
-# 2. `vllm.config.ParallelConfig.get_next_dp_init_port`
+# 2. `vllm.v1.engine.core.DPEngineCoreProc._init_data_parallel`
+# Why:
+# There is a bug in ASCEND_RT_VISIBLE_DEVICES usage.
+# How:
+# The ASCEND_RT_VISIBLE_DEVICES related code is dropped.
+# Related PR (if no, explain why):
+# No, this is a bug in vllm-ascend.
+# Future Plan:
+# Remove this patch once the ASCEND_RT_VISIBLE_DEVICES bug is fixed.
+# 3. `vllm.config.ParallelConfig.get_next_dp_init_port`
# Why:
# vllm doesn't support get port from environment.
# How:
@@ -56,7 +65,7 @@
# Need a PR to vllm to support get port from environment.
# Future Plan:
# Remove those patch when vllm merged them
-# 3. `vllm.config.ParallelConfig.ParallelConfig.stateless_init_dp_group`
+# 4. `vllm.config.ParallelConfig.ParallelConfig.stateless_init_dp_group`
# Why:
# vLLM use gloo backend by default to initialize stateless dp process gourp, but we want to use hccl here to
# get better performance
@@ -65,7 +74,19 @@
# Related PR (if no, explain why):
# Need a PR to vllm to support more backend.
# Future Plan:
-# Remove those patch when vllm support more backend.
+# Remove those patches when vllm merges them
+#
+# ** File: platform/patch_common/patch_scheduler.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.core.sched.scheduler.Scheduler.destroy_model_parallel()`
+# Why:
+# vLLM transfers kv block data only when the block is already completely filled. However, this behaviour may cause the decode node
+# to exhibit prefill behaviour. To make the decode node work as expected, we always transfer all data whether or not the block is filled.
+# How:
+# num_computed_tokens always equals the token count of the request during scheduling.
+# Related PR (if no, explain why): https://github.com/vllm-project/vllm/pull/17751 (nixl implementation)
+# Future Plan:
+# No plan; we will maintain this patch until vllm changes this behaviour.
#
# * Worker Patch:
# ===============
@@ -100,18 +121,6 @@
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
-# 2. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and
-# `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace`
-# Why:
-# vLLM `Remove Sampler from Model Code` so vllm-ascend needs adapt to this change.
-# How:
-# Use vLLM 0.8.4 method to patch it.
-# Related PR (if no, explain why):
-# - https://github.com/vllm-project/vllm/pull/15195
-# - https://github.com/vllm-project/vllm-ascend/pull/395
-# Future Plan:
-# Remove it when we identify the reasons clearly.
-#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
index 86515df86e..d971922840 100644
--- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -17,8 +17,6 @@
# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.
-import vllm -import vllm.distributed import vllm.envs as envs from torch.distributed import ProcessGroup from vllm.config import ParallelConfig @@ -26,25 +24,6 @@ stateless_init_torch_distributed_process_group -def ascend_destroy_model_parallel(): - """Set the groups to none and destroy them.""" - from vllm.distributed.parallel_state import _DP, _PP, _TP - if _TP: - _TP.destroy() - _TP = None - - if _PP: - _PP.destroy() - _PP = None - - if _DP: - _DP.destroy() - _DP = None - from vllm_ascend.distributed.parallel_state import \ - destory_ascend_model_parallel - destory_ascend_model_parallel() - - def parallel_config_get_dp_port(self) -> int: """ We might need to initialize process groups in multiple @@ -78,6 +57,5 @@ def stateless_init_dp_group(self) -> "ProcessGroup": return dp_group -vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port ParallelConfig.stateless_init_dp_group = stateless_init_dp_group diff --git a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py index ca87729540..53ce312676 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py @@ -88,20 +88,4 @@ def sampler_output( return filtered_model_outputs, True -def set_include_gpu_probs_tensor(self) -> None: - # Need include_gpu_probs_tensor for MultiSteoWorker - if hasattr(self.model_runner.model, "sampler"): - self.model_runner.model.sampler.include_gpu_probs_tensor = True - self.model_runner.sampler.include_gpu_probs_tensor = True - - -def set_should_modify_greedy_probs_inplace(self) -> None: - if hasattr(self.model_runner.model, "sampler"): - self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( - True) - self.model_runner.sampler.should_modify_greedy_probs_inplace = True - - MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output) -MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor -MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace diff --git a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py index 66e7aa56b2..d271e65bfc 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py @@ -57,11 +57,6 @@ def create_worker( ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - # TODO(Yizhou): A quick fix, must be refactored ASAP - draft_worker_kwargs["vllm_config"].parallel_config.expert_parallel_size = 1 - draft_worker_kwargs[ - "vllm_config"].parallel_config.expert_tensor_parallel_size = 1 - draft_model_config = draft_worker_kwargs["vllm_config"].model_config draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config @@ -72,6 +67,13 @@ def create_worker( proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: + # TODO(Yizhou): A quick fix, must be refactored ASAP + # ngram need not this fix. 
+ draft_worker_kwargs[ + "vllm_config"].parallel_config.expert_parallel_size = 1 + draft_worker_kwargs[ + "vllm_config"].parallel_config.expert_tensor_parallel_size = 1 + draft_tp = draft_parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b9233da05d..08abc08cdd 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -125,17 +125,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config if parallel_config: - # Default value for expert tensor parallel size - parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size - - # NOTE: When enable_expert_parallel is True, we follow vLLM convention: - # ep_size = world_size, which means expert_tensor_parallel_size must be 1 if parallel_config.enable_expert_parallel: parallel_config.expert_tensor_parallel_size = 1 - # NOTE: When enable_expert_parallel is False and param `asceend_config.expert_tensor_parallel_size` - # is configured, use ascend_config - elif ascend_config.expert_tensor_parallel_size > 0: - parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size + else: + parallel_config.expert_tensor_parallel_size = parallel_config.world_size_across_dp # Calculate expert parallel size based on world size parallel_config.expert_parallel_size = ( @@ -177,8 +170,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "PIECEWISE compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") compilation_config.use_inductor = False - compilation_config.splitting_ops.extend( - ["vllm.unified_ascend_attention_with_output"]) + if not compilation_config.full_cuda_graph: + compilation_config.splitting_ops.extend( + ["vllm.unified_ascend_attention_with_output"]) update_aclgraph_sizes(vllm_config) if parallel_config and parallel_config.worker_cls == "auto": diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3567dba355..1b06a4294a 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -34,6 +34,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, VocabParallelEmbedding) from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs @@ -104,6 +106,12 @@ def get_quant_method(self, layer: torch.nn.Module, return AscendUnquantizedFusedMoEMethod() return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping) + elif isinstance(layer, VocabParallelEmbedding): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedEmbeddingMethod() + return AscendEmbeddingMethod(self, prefix, + self.packed_modules_mapping) return None def is_layer_skipped_ascend( @@ -194,6 +202,17 @@ def create_weights( layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) + pergroup_dict = self.quant_method.get_pergroup_param( + input_size_per_partition, output_size_per_partition, params_dtype) + for pergroup_name, pergroup_param in pergroup_dict.items(): + param = torch.nn.Parameter(pergroup_param, requires_grad=False) + set_weight_attrs(param, 
{"output_dim": 0}) + layer.register_parameter(pergroup_name, param) + set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name: + setattr(param, "input_dim", 1) + param.input_dim = 1 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) @@ -305,6 +324,10 @@ def create_weights( param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in param_key or "weight_offset_second" in param_key: + setattr(param, "quant_method", + FusedMoeWeightScaleSupported.GROUP.value) + param.quant_method = FusedMoeWeightScaleSupported.GROUP.value def apply( self, @@ -337,3 +360,20 @@ def apply( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) + + +class AscendEmbeddingMethod(AscendLinearMethod): + """Embedding method for Ascend quantization. + + This class calls AscendQuantizer to search a specific quantization + implementations supported on ascend hardware for Embedding methods. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]) -> None: + self.quantizer = AscendQuantizer.get_quantizer( + quant_config.quant_description, prefix, packed_modules_mapping) + self.quant_method = self.quantizer.build_linear_method() diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index ea1297bf35..d27914139f 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -24,6 +24,8 @@ from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init) +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) from .w8a8 import AscendW8A8LinearMethod from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) @@ -263,6 +265,17 @@ def get_quantizer(cls, f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}") +class W4A8DYNAMICQuantizer(VLLMAscendQuantizer): + + @staticmethod + def build_linear_method(): + return AscendW4A8DynamicLinearMethod() + + @staticmethod + def build_moe_method(): + return AscendW4A8DynamicFusedMoEMethod() + + class W8A8Quantizer(VLLMAscendQuantizer): @staticmethod @@ -282,6 +295,7 @@ def build_moe_method(): SUPPORT_ASCEND_QUANTIZER_TYPE = { + "W4A8_DYNAMIC": W4A8DYNAMICQuantizer, "W8A8": W8A8Quantizer, "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, } diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py new file mode 100644 index 0000000000..227b6b680a --- /dev/null +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -0,0 +1,377 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import numpy as np +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.ops.fused_moe import select_experts +from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all, + fused_experts_with_mc2) + + +class AscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + self.group_size = get_current_vllm_config( + ).quant_config.quant_description.get("group_size", 256) + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, + per_group_scale: torch.Tensor): + k, n = weight.shape + group_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 
1).contiguous(), + ) + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) + + +class AscendW4A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W4A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + try: + device_group = self.ep_group.device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + config = get_current_vllm_config() + group_size = config.quant_config.quant_description.get( + "group_size", 256) + + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // group_size, + dtype=params_dtype) + + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // group_size, + dtype=params_dtype) + + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // group_size, + dtype=params_dtype) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // group_size, + dtype=params_dtype) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts, "Number of 
global experts mismatch"
+
+        # NOTE: currently npu_moe_gating_top_k only supports the `group_count=256` pattern
+        if global_num_experts == 256:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,  # top_k is currently 8
+                bias=e_score_correction_bias,
+                k_group=topk_group,  # fix: 4
+                group_count=num_expert_group,  # fix 8
+                group_select_mode=1,  # 0: max within the group; 1: topk2.sum(fix)
+                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                # out_flag=False,  # todo new api; whether to output the third result
+                # y2_flag=False,  # old api; whether to output the third result
+                routed_scaling_factor=1,
+                eps=float(1e-20))
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                use_grouped_topk=use_grouped_topk,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
+        # this is a naive implementation for expert load balancing so as
+        # to avoid accumulating too many tokens on a single rank.
+        # currently it is only activated when doing profile runs.
+        if enable_force_load_balance:
+            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+        topk_weights = topk_weights.to(x.dtype)
+
+        fused_moe_state = get_forward_context().fused_moe_state
+        if fused_moe_state == FusedMoEState.MC2:
+            return fused_experts_with_mc2(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
+        else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are fed into the fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
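# (Rough sketch of the all2all path below, assuming experts are partitioned
# contiguously across EP ranks: each rank sends tokens routed to remote experts
# to their owning ranks, runs its local experts on whatever arrives, and a
# second all-to-all returns the results before the final weighted combine.)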
+ return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_second, + w2_scale=layer.w2_weight_scale_second, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) + + def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + group_num, k, n = weight.shape + per_group_scale = per_group_scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to( + torch.float32) + scale_fp32_np = scale_fp32.cpu().numpy() + scale_fp32_np.dtype = np.uint32 + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), + dtype=np.uint32) + + sscale_uint64[..., ::2] = scale_fp32_np + + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), + dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( + group_num, quantgroup_num, n) + sscale_uint64_tensor = sscale_uint64_tensor.npu() + return sscale_uint64_tensor, bias + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1) + layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( + 1, 2).contiguous() + + layer.w13_weight_scale_second.data, bias = self.process_scale( + layer.w13_weight, layer.w13_weight_scale.data, + layer.w13_weight_scale_second.data) + param = torch.nn.Parameter(bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", param) + layer.w2_weight_scale_second.data, bias1 = self.process_scale( + layer.w2_weight, layer.w2_weight_scale.data, + layer.w2_weight_scale_second.data) + param = torch.nn.Parameter(bias1, requires_grad=False) + layer.register_parameter("w2_scale_bias", param) + + layer.w13_weight.data = torch_npu.npu_quantize( + layer.w13_weight.data.to(torch.float32), + torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) + layer.w2_weight.data = torch_npu.npu_quantize( + layer.w2_weight.data.to(torch.float32), + torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index db23cb024d..28925034c1 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -20,6 +20,9 @@ import torch import torch_npu +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ + 
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset: torch.Tensor): @@ -37,6 +40,8 @@ class AscendW8A8LinearMethod: def __init__(self) -> None: # aclnn quant matmul requires to transpose matrix B, set to true by default. self.transpose_weight = True + ascend_config = get_ascend_config() + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout @staticmethod def get_weight( @@ -77,6 +82,10 @@ def get_perchannel_param( dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module, @@ -110,6 +119,9 @@ def process_weights_after_loading(self, layer): requires_grad=False).to(layer.aclnn_input_scale.dtype) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 372c29bca7..d9738b9b55 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,19 +15,96 @@ # limitations under the License. # -from typing import Any, Callable, Dict, Optional, Tuple, Union +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu -from vllm.distributed import GroupCoordinator +from vllm.distributed import GroupCoordinator, get_ep_group, get_tp_group +from vllm.forward_context import get_forward_context +import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.fused_moe import select_experts -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, - npu_wait_tensor) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, + dispose_tensor, get_ascend_soc_version, + npu_stream_switch, npu_wait_tensor) + + +def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). 
+ transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + Returns: + hidden_states: output hidden states after MLP. + """ + + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() + if dynamic_scale is None: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + else: + pertoken_scale = dynamic_scale + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + return hidden_states def apply_mlp(hidden_states: torch.Tensor, @@ -37,7 +114,9 @@ def apply_mlp(hidden_states: torch.Tensor, w2_scale: torch.Tensor, group_list: torch.Tensor, dynamic_scale: torch.Tensor = None, - group_list_type: int = 1) -> torch.Tensor: + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -71,17 +150,31 @@ def apply_mlp(hidden_states: torch.Tensor, else: pertoken_scale = dynamic_scale + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], scale=[w1_scale], + bias=bias1, per_token_scale=[pertoken_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) @@ -93,12 +186,13 @@ def apply_mlp(hidden_states: torch.Tensor, x=[hidden_states], weight=[w2], scale=[w2_scale], + bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] return hidden_states @@ -117,11 +211,33 @@ def fused_experts_with_mc2( log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, + is_torchair: bool = False, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if log2phy is not None: + if log2phy: topk_ids = log2phy[topk_ids] - global_bs = 0 - moe_expert_num = len(expert_map) + global_redundant_expert_num + 
quant_mode = 2 + ep_group = get_ep_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + tp_world_size = get_tp_group().world_size + + # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, + # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. + global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + tp_world_size) * ep_world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + if (expert_map is not None): + moe_expert_num = len(expert_map) + global_redundant_expert_num + else: + moe_expert_num = global_redundant_expert_num # hidden_states = hidden_states.bfloat16() kwargs_mc2 = { "x": hidden_states, @@ -130,53 +246,43 @@ def fused_experts_with_mc2( "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": global_bs, - "expert_scales": topk_weights.to(torch.float32), } - rank = torch.distributed.get_rank() - - quant_mode = 2 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) - - world_szie = torch.distributed.get_world_size() - tp_size = world_szie // all_to_all_group_size - tp_rank = rank % tp_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[ - 0:7] + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(hidden_states, topk_weights) - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - npu_wait_tensor(shared_gate_up[0], expand_x) - shared_act = shared_experts.act_fn(shared_gate_up) + npu_wait_tensor(quantized_x_for_share, expand_x) + shared_act_out = shared_experts.act_fn( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] # `expand_x` will be disposed in the `apply_mlp` function - down_out_list = apply_mlp(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) + down_out_list = apply_mlp_decode([expand_x], + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) # moeCombine kwargs_mc2 = { @@ -187,8 +293,7 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": 0, - "expand_scales": expand_scales, + "global_bs": global_bs, } tp_recv_counts = torch.empty(1, dtype=torch.int32, @@ -196,44 +301,47 @@ def fused_experts_with_mc2( stage3_kwargs = { "ep_send_counts": ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": 
all_to_all_group_size, - "ep_rank_id": local_rank, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - + group_list_type = 1 if shared_experts is None: - return hidden_states + return hidden_states, expert_token_nums, group_list_type else: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_act[0], down_out_list) - shared_output, _ = shared_experts.down_proj(shared_act) - return hidden_states, shared_output + npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) + return hidden_states, shared_output, expert_token_nums, group_list_type # currently expert parallelism implemented with all2all # is under-optimized. -def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, -): - if log2phy is not None: +def fused_experts_with_all2all(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None): + if log2phy: topk_ids = log2phy[topk_ids] original_shape = hidden_states.shape if len(original_shape) == 3: @@ -311,7 +419,9 @@ def fused_experts_with_all2all( w2, w2_scale, expert_tokens, #16 - group_list_type=group_list_type) + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) @@ -343,7 +453,7 @@ def fused_experts_with_all2all( ) if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, group_list_type def fused_experts(hidden_states: torch.Tensor, @@ -457,7 +567,7 @@ def fused_experts(hidden_states: torch.Tensor, if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, group_list_type class AscendW8A8DynamicLinearMethod: @@ -466,6 +576,8 @@ class AscendW8A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True + ascend_config = get_ascend_config() + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout @staticmethod def get_weight(input_size: int, output_size: int, @@ -493,6 +605,10 @@ def get_perchannel_param( dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module, @@ -527,8 +643,10 @@ def apply( def 
process_weights_after_loading(self, layer): if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - # cast quantized weight tensors in NZ format (29) for higher inference speed - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() @@ -545,6 +663,7 @@ def __init__(self): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout try: device_group = self.ep_group.device_group @@ -618,6 +737,8 @@ def apply( log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -652,6 +773,16 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) + fused_moe_state = get_forward_context().fused_moe_state + shared_gate_up, shared_dequant_scale = None, None + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
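For reference, the post-load weight handling changed above amounts to a transpose for the aclnn quant matmul plus an optional cast into the FRACTAL_NZ layout imported above; a hedged standalone sketch (helper name is illustrative, not a vllm-ascend API):

import torch
import torch_npu  # Ascend PyTorch adapter

ACL_FORMAT_FRACTAL_NZ = 29  # same constant imported from vllm_ascend.utils above

def prepare_quantized_weight(weight: torch.Tensor,
                             transpose_weight: bool = True,
                             enable_weight_nz_layout: bool = False) -> torch.Tensor:
    if transpose_weight:
        # aclnn quant matmul expects matrix B transposed.
        weight = weight.transpose(0, 1).contiguous()
    if enable_weight_nz_layout:
        # Cast the quantized weight into NZ layout for higher inference speed.
        weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
    return weight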
@@ -660,14 +791,12 @@ def apply( topk_weights = topk_weights.to(x.dtype) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, + w1_scale=layer.w13_weight_scale_fp32, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, @@ -676,7 +805,11 @@ def apply( moe_all_to_all_group_name=self.moe_all_to_all_group_name, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts) + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + **kwargs) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -713,8 +846,16 @@ def process_weights_after_loading(self, layer): 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( + torch.float32) layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( diff --git a/vllm_ascend/soc_info.py b/vllm_ascend/soc_info.py new file mode 100644 index 0000000000..ac1317e8e1 --- /dev/null +++ b/vllm_ascend/soc_info.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +import torch_npu + + +@dataclass +class NPUSocInfo: + is_a3: bool = False + + def __post_init__(self): + torch_npu.npu._lazy_init() + self.soc_version = torch_npu._C._npu_get_soc_version() + if self.soc_version in (250, 251, 252, 253, 254, 255): + self.is_a3 = True diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index eeab287906..f7ca0aba2e 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -20,12 +20,13 @@ import atexit import math from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import torch -import torch_npu # noqa: F401 +import torch_npu import torchair # type: ignore[import] # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event @@ -57,6 +58,9 @@ CUSTOM_OP_ENABLED = None +ACL_FORMAT_ND = 2 +ACL_FORMAT_FRACTAL_NZ = 29 + def try_register_lib(lib_name: str, lib_info: str = ""): import importlib @@ -168,6 +172,27 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: original_sizes, compilation_config.cudagraph_capture_sizes = \ compilation_config.cudagraph_capture_sizes, None + if compilation_config.full_cuda_graph: + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + truncated_sizes = [x for x in original_sizes if x <= max_num_seqs] + compilation_config.init_with_cudagraph_sizes(truncated_sizes) + + warning_message = """\033[91m + 
********************************************************************************** + * WARNING: You have enabled the *full graph* feature. + * This is an early experimental stage and may involve various unknown issues. + * A known problem is that capturing too many batch sizes can lead to OOM + * (Out of Memory) errors or inference hangs. If you encounter such issues, + * consider reducing `gpu_memory_utilization` or manually specifying a smaller + * batch size for graph capture. + * For more details, please refer to: + * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs + **********************************************************************************\033[0m + """ + + logger.warning(warning_message) + return + # Calculate parallel configuration factor num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers parallel_config = vllm_config.parallel_config @@ -278,19 +303,58 @@ def npu_wait_tensor(self: torch.Tensor, return _npu_wait_tensor(self, dependency) if enabled else self -# TODO(zzzzwwjj): move this into forward_context -class FusedMoEState(Enum): - AllGather = 0 - All2All = 1 - MC2 = 2 +class AscendSocVersion(Enum): + A2 = 0 + A3 = 1 + MAX = 2 + + +_ascend_soc_version = None -# TODO(zzzzwwjj): add soc_version to choose branch -def get_fused_moe_state(ep_size: int, with_prefill: bool): - if ep_size == 1: - return FusedMoEState.AllGather - # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. - elif ep_size < 16 or with_prefill: - return FusedMoEState.All2All +def init_ascend_soc_version(): + soc_version = torch_npu.npu.get_soc_version() + global _ascend_soc_version + if 220 <= soc_version <= 225: + _ascend_soc_version = AscendSocVersion.A2 + elif 250 <= soc_version <= 255: + _ascend_soc_version = AscendSocVersion.A3 else: - return FusedMoEState.MC2 + _ascend_soc_version = AscendSocVersion.MAX + + +def get_ascend_soc_version(): + global _ascend_soc_version + assert _ascend_soc_version is not None + return _ascend_soc_version + + +@dataclass +class GraphParams: + events: dict[int, list[torch.npu.ExternalEvent]] + workspaces: dict[int, torch.Tensor] + handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] + attn_params: dict[int, list[tuple]] + + +_graph_params: Optional[GraphParams] = None + + +def set_graph_params(aclgraph_capture_sizes: set[int]): + global _graph_params + if _graph_params is not None: + raise ValueError("Graph parameters have already been set!") + _graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def get_graph_params(): + return _graph_params diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py index 1306b1e160..bfd513d5fe 100644 --- a/vllm_ascend/worker/draft_model_runner.py +++ b/vllm_ascend/worker/draft_model_runner.py @@ -18,7 +18,6 @@ from typing import List, Optional import torch -from vllm.forward_context import set_forward_context from vllm.logger import logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs @@ -27,6 +26,7 @@ ModelRunnerInputBase, ModelRunnerWrapperBase) +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention import AscendMetadata # A flag to enable debug prints for the updated input tensors @@ -51,12 +51,6 @@ class 
TP1DraftModelRunner(ModelRunnerWrapperBase): """ def __init__(self, model_runner: ModelRunnerBase): - if hasattr( - model_runner, - "return_hidden_states") and model_runner.return_hidden_states: - raise ValueError( - "return_hidden_states is not supported for TP1DraftModelRunner." - ) super().__init__(model_runner) self.indices_of_seq_with_bonus_tokens = None @@ -211,6 +205,9 @@ def execute_model( if self.prompt_adapter_config is not None: raise ValueError("TP1DraftModelRunner has no support for " "prompt_adapter_config") + if model_input.inputs_embeds is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "inputs_embeds") if model_input.multi_modal_kwargs: raise ValueError( "TP1DraftModelRunner has no support for multi_modal_kwargs" @@ -264,14 +261,15 @@ def execute_model( spec_step_idx = kwargs.get("spec_step_idx", step) model_execute_kwargs["spec_step_idx"] = spec_step_idx compute_logits_kwargs["spec_step_idx"] = spec_step_idx - with set_forward_context(model_input.attn_metadata, - self.vllm_config): + with set_ascend_forward_context(model_input.attn_metadata, + self.vllm_config): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions hidden_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, @@ -293,6 +291,9 @@ def execute_model( ) outputs.append(output) + if self.return_hidden_states and is_fallback: + output.hidden_states = hidden_states + if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: assert output.sampled_token_ids is not None diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 48c5d4b68f..7846f655d1 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -35,7 +35,6 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group from vllm.distributed.kv_transfer import get_kv_transfer_group -from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import logger from vllm.lora.layers import LoRAMapping @@ -66,6 +65,7 @@ _init_sampling_metadata_from_tensor_dict) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1431,8 +1431,12 @@ def execute_model( model_forward_start.record() if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): + with set_ascend_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + with_prefill=prefill_meta is not None, + in_profile_run=self.in_profile_run): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions if self.torchair_graph_enabled: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 89f30bc43c..68293e0b58 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -17,7 +17,9 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # +import copy import gc +import math import os import time import types @@ -32,13 +34,18 @@ import 
torch._dynamo.cache_size import torch.distributed as dist import torch.nn as nn +import torchair from torch.distributed import ReduceOp +from torchair import patch_for_hcom from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_dp_group, get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE @@ -69,14 +76,21 @@ scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata +from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler -from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is +from vllm_ascend.utils import ProfileExecuteDuration from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer +from vllm_ascend.eplb.eplb_updator import EplbUpdator +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor +from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput @@ -132,6 +146,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config ascend_config = get_ascend_config() @@ -149,12 +164,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs - self.graph_block_tables = np.zeros( - (self.vllm_config.scheduler_config.max_num_seqs, - (self.model_config.max_model_len + self.block_size - 1) // - self.block_size), - dtype=np.int32) - # Model-related. self.num_attn_layers = self.model_config.get_num_layers_by_block_type( vllm_config.parallel_config, LayerBlockType.attention) @@ -208,8 +217,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # Set up speculative decoding. 
self.use_spec_decode = False self.spec_attn_mask = None + self.actual_seq_q_lens = [] + self.spec_token_num = 0 + self.decode_token_per_req = 1 if self.speculative_config: self.use_spec_decode = True + self.spec_token_num = self.speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 self.spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), @@ -222,6 +236,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.device) # type: ignore elif self.speculative_config.method == 'deepseek_mtp': self.drafter = MtpProposer(self.vllm_config, self) + self.decode_token_per_req = 1 + self.spec_token_num + self.actual_seq_q_lens = [ + len for len in + range(self.decode_token_per_req, self.max_num_tokens + + 1, self.decode_token_per_req) + ] else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -243,6 +263,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.query_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -343,15 +369,19 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes + self.use_ring_mla = ascend_config.chunked_prefill_for_mla if ascend_config.torchair_graph_config.graph_batch_sizes_init: self.init_torchair_graph_batch_sizes() - if len(self.torchair_graph_batch_sizes) == 0: - # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code - self.torchair_graph_batch_sizes = [ - self.scheduler_config.max_num_seqs - ] + self.check_torchair_graph_batch_sizes() + + # graph_block_tables shape: [num_request, cell(max_model_len / block_size)] + self.graph_block_tables = np.zeros( + (self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req, + (self.model_config.max_model_len + self.block_size - 1) // + self.block_size), + dtype=np.int32) torch._dynamo.cache_size.config.cache_size_limit += len( self.torchair_graph_batch_sizes) @@ -362,6 +392,21 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank + # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True + self.in_profile_run = False + + # kv role + self.is_kv_producer = False + if vllm_config.kv_transfer_config is not None: + self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer + + #EPLB + self.dynamic_eplb = ascend_config.dynamic_eplb + if self.dynamic_eplb == True: + self.eplb_adaptor = None + self.is_eplb_warmuped = False + self.eplb_updator = EplbUpdator(ascend_config.expert_map_path) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. 
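The speculative-decoding bookkeeping above means each decode request contributes 1 + num_speculative_tokens query tokens; a small standalone illustration of how `decode_token_per_req` and `actual_seq_q_lens` are derived (names mirror the fields above):

def build_decode_seq_q_lens(num_speculative_tokens: int, max_num_tokens: int) -> list[int]:
    # One query token for the verified position plus one per speculative token.
    decode_token_per_req = 1 + num_speculative_tokens
    # Cumulative query lengths for 1, 2, 3, ... decode requests.
    return list(range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))

# e.g. MTP with one speculative token and an 8-token budget -> [2, 4, 6, 8]
print(build_decode_seq_q_lens(num_speculative_tokens=1, max_num_tokens=8))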
@@ -420,33 +465,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: generator.manual_seed(sampling_params.seed) else: generator = None - if vllm_version_is("0.9.1"): - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - else: - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - pooling_params=None, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -569,6 +600,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Append to the end. req_index = None self.input_batch.add_request(req_state, req_index) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, ()) + if spec_token_ids: + req_index = self.input_batch.num_reqs - 1 + start_index = len(req_state.prompt_token_ids) + len( + req_state.output_token_ids) + end_token_index = start_index + len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + self.input_batch.num_tokens[req_index] = end_token_index # Condense the batched states if there are empty indices. 
if removed_req_indices: @@ -578,16 +619,45 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_sampling_metadata() def _get_forward_metadata_across_dp( - self, total_num_scheduled_tokens: int, - with_prefill: bool) -> tuple[int, bool]: + self, num_tokens: int, with_prefill: bool, enable_dbo: bool + ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + if self.dp_size == 1: + return num_tokens, None, with_prefill, enable_dbo + forward_metadata = torch.tensor( - [total_num_scheduled_tokens, with_prefill], + [num_tokens, with_prefill, not enable_dbo], device="cpu", dtype=torch.int32) dist.all_reduce(forward_metadata, op=ReduceOp.MAX, group=get_dp_group().cpu_group) - return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + num_tokens_across_dp = torch.tensor([forward_metadata[0]] * + self.dp_size, + device="cpu", + dtype=torch.int32) + return forward_metadata[0].item(), num_tokens_across_dp, bool( + forward_metadata[1]), not bool(forward_metadata[2]) + + def _check_dbo_is_valid(self, query_lens: torch.Tensor, + attn_state: AscendAttentionState, + num_tokens: int) -> bool: + # do the checks for dp + dbo + if attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + return False + # considering the case that one dp rank may enable dbo while others may not + if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: + return False + # TODO: remove it if token-level microbatch is enabled + [token_index, + seq_index] = compute_split_seq_index(query_lens, attn_state, + num_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + query_lens) or num_tokens < 256: + return False + return True def get_model(self) -> nn.Module: return self.model @@ -776,7 +846,8 @@ def _process_reqs( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, - torch.Tensor, int, torch.Tensor]: + torch.Tensor, int, torch.Tensor, Optional[set[str]], + Optional[set[str]]]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -844,6 +915,7 @@ def _process_reqs( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True) + self.positions[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) positions = self.positions[:num_input_tokens] @@ -872,6 +944,9 @@ def _process_reqs( # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + # SpecDecoding now supports seq_len=1 and seq_len=2 + attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): attn_state = AscendAttentionState.SpecDecoding @@ -881,11 +956,14 @@ def _process_reqs( else: attn_state = AscendAttentionState.PrefillCacheHit - attn_mask = self._make_attention_mask(seq_lens=seq_lens, - query_lens=num_scheduled_tokens, - position=positions, - attn_state=attn_state) - self.attn_mask = attn_mask + # NOTE: when use ring_mla, attn_mask don't need to generate here. 
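`_get_forward_metadata_across_dp` above keeps the data-parallel ranks in lockstep by MAX-reducing their per-step metadata; a condensed sketch of that agreement rule (hypothetical free function, assumes an initialized CPU process group):

import torch
import torch.distributed as dist

def sync_forward_metadata(num_tokens: int, with_prefill: bool, enable_dbo: bool,
                          cpu_group, dp_size: int):
    # MAX-reduce so every rank pads to the largest token count, prefill wins over
    # decode, and a single rank opting out of DBO disables it everywhere.
    meta = torch.tensor([num_tokens, with_prefill, not enable_dbo], dtype=torch.int32)
    dist.all_reduce(meta, op=dist.ReduceOp.MAX, group=cpu_group)
    num_tokens_across_dp = torch.full((dp_size,), int(meta[0]), dtype=torch.int32)
    return int(meta[0]), num_tokens_across_dp, bool(meta[1] > 0), not bool(meta[2] > 0)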
+ if not self.use_ring_mla or attn_state == AscendAttentionState.PrefillNoCache: + attn_mask = self._make_attention_mask( + seq_lens=seq_lens, + query_lens=num_scheduled_tokens, + position=positions, + attn_state=attn_state) + self.attn_mask = attn_mask self.attn_state = attn_state # type: ignore extra_builder_kwargs = {} @@ -896,36 +974,52 @@ def _process_reqs( self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + self.slot_mapping[:total_num_scheduled_tokens].copy_( + self.slot_mapping_cpu[:total_num_scheduled_tokens], + non_blocking=True) # Fill unused with -1. Needed for reshape_and_cache + self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] + # Use host tensor, other wise error: tensor.hostData is null common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, seq_lens=seq_lens) + query_start_loc=query_start_loc, + seq_lens=self.seq_lens_cpu[:num_reqs]) + self.seq_lens_list = self.seq_lens_np.tolist()[:num_input_tokens] with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) + num_tokens_across_dp = None - if self.dp_size > 1: - max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( - total_num_scheduled_tokens, with_prefill) - extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens - extra_builder_kwargs['with_prefill_across_dp'] = with_prefill - - # Add graph_pad_size here + padded_num_tokens = total_num_scheduled_tokens if self.torchair_graph_enabled and not with_prefill: - if self.dp_size > 1: - padded_batch_size = self.select_torchair_padded_batch_size( - max_num_tokens) - else: - padded_batch_size = self.select_torchair_padded_batch_size( - total_num_scheduled_tokens) - graph_pad_size = padded_batch_size - total_num_scheduled_tokens + padded_num_tokens = self.select_torchair_padded_batch_size( + total_num_scheduled_tokens) + (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + padded_num_tokens, with_prefill, enable_dbo) + extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo + + # TODO(zzzzwwjj): this code need to refactor afterwards. 
+ self.with_prefill = with_prefill + # Add num_token_pad_size and num_reqs_pad_size here for torchair graph mode + if self.torchair_graph_enabled and not with_prefill: + num_token_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens + num_reqs_pad_size = ( + padded_num_tokens_across_dp // self.decode_token_per_req - + num_reqs) + assert num_token_pad_size >= 0 and num_reqs_pad_size >= 0 - extra_builder_kwargs['graph_pad_size'] = graph_pad_size + extra_builder_kwargs['num_token_pad_size'] = num_token_pad_size + extra_builder_kwargs['num_reqs_pad_size'] = num_reqs_pad_size + self.num_reqs_pad_size = num_reqs_pad_size + self.extra_builder_kwargs = extra_builder_kwargs if self.vllm_config.model_config.use_mla: attn_metadata = self.attn_metadata_builder.build( # type: ignore @@ -941,6 +1035,7 @@ def _process_reqs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + common_attn_metadata=common_attn_metadata, common_prefix_len=None, **extra_builder_kwargs, ) @@ -956,10 +1051,7 @@ def _process_reqs( # Copy the tensors to the NPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - input_ids = self.input_ids[:num_input_tokens] - # prepare the MRoPE for mllm if using multimodal - num_input_tokens = total_num_scheduled_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order if self.is_multimodal_model: @@ -973,51 +1065,56 @@ def _process_reqs( # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids[:total_num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) else: inputs_embeds = self.model.get_input_embeddings(input_ids) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds) + self.inputs_embeds[:total_num_scheduled_tokens].copy_( + inputs_embeds) inputs_embeds = self.inputs_embeds[:num_input_tokens] input_ids = None else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. + # then the embedding layer is not included in the ACL Graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_batch_size] - positions = self.positions[:padded_batch_size] + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] # Run forward pass - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + # TODO(zzzzwwjj): check param `num_tokens_across_dp` later. 
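The torchair padding above rounds the scheduled token count up to a captured graph batch size and derives the request padding from the per-request decode token count; a worked sketch with illustrative values (helper name is hypothetical):

def torchair_padding(total_tokens: int, num_reqs: int, decode_token_per_req: int,
                     graph_batch_sizes: list[int]) -> tuple[int, int]:
    # Smallest captured batch size that fits the scheduled tokens.
    padded = next(size for size in graph_batch_sizes if total_tokens <= size)
    num_token_pad_size = padded - total_tokens
    num_reqs_pad_size = padded // decode_token_per_req - num_reqs
    return num_token_pad_size, num_reqs_pad_size

# 3 decode requests with MTP (2 tokens each) against captured sizes [2, 4, 8]
print(torchair_padding(total_tokens=6, num_reqs=3, decode_token_per_req=2,
                       graph_batch_sizes=[2, 4, 8]))  # (2, 1)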
+ with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=padded_num_tokens_across_dp, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill): with ProfileExecuteDuration().capture_async("forward"): + self.maybe_setup_kv_connector(scheduler_output) model_kwargs = {} if self.torchair_graph_enabled: model_kwargs["kv_caches"] = self.kv_caches model_kwargs["attn_metadata"] = attn_metadata + if envs_ascend.VLLM_ASCEND_ENABLE_DBO and with_prefill: + model_kwargs["graph_enable"] = False # type: ignore if self.torchair_graph_enabled and not with_prefill: compiled_model = self._get_torchair_lazy_compiled_model( - padded_batch_size) + padded_num_tokens_across_dp) hidden_states = compiled_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, - ) + **model_kwargs) else: assert self.model is not None hidden_states = self.model( @@ -1025,9 +1122,11 @@ def _process_reqs( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, - ) + **model_kwargs) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1052,7 +1151,8 @@ def _process_reqs( sample_indices = spec_decode_metadata.logits_indices return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, sample_indices) + total_num_scheduled_tokens, sample_indices, finished_sending, + finished_recving) def _calc_spec_decode_metadata( self, @@ -1219,16 +1319,27 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: + with ProfileExecuteDuration().capture_async( "prepare input and forward"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOuptut if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + logger.debug( + "skip this step for we receive the data from remote disaggregate prefill node" + ) + # Return empty ModelRunnerOuptut if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + if self.dynamic_eplb: + self.eplb_updator.forward_before() + return self.kv_connector_no_forward(scheduler_output) (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, - sample_indices) = (self._process_reqs(scheduler_output, - intermediate_tensors)) + num_scheduled_tokens, sample_indices, finished_sending, + finished_recving) = (self._process_reqs(scheduler_output, + intermediate_tensors)) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() with ProfileExecuteDuration().capture_async("post process"): logits = self.model.compute_logits(hidden_states[sample_indices], @@ -1319,25 +1430,19 @@ def execute_model( hidden_states, attn_metadata, ) - if vllm_version_is("0.9.1"): - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict={}, - ) - else: - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict={}, - pooler_output=[], - ) + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) durations = ProfileExecuteDuration().pop_captured_sync() if durations: @@ -1349,8 +1454,55 @@ def execute_model( logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + return model_runner_output + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # TODO(zzzzwwjj): Check whether `set_ascend_forward_context` has influence on kv_connector or not. + with set_ascend_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfer(scheduler_output)) + # For the case of no forward caused by receiving remote kv, + # one round of dummy inference is necessary + # to prevent a hang on the collective calls. + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfer( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + def _profile_multimodal(self) -> None: # TODO: handle encoder-decoder models once we support them. # NOTE: Currently model is profiled with a single non-text @@ -1438,15 +1590,34 @@ def _profile_multimodal(self) -> None: def _dummy_run( self, num_tokens: int, - is_compile: bool = False, - with_prefill: bool = True, + skip_attn: bool = True, + with_prefill: bool = False, + is_torchair_compile: bool = False, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + is_profile_run: bool = False, ) -> torch.Tensor: + if self.torchair_graph_enabled and not with_prefill: + num_tokens = self.select_torchair_padded_batch_size(num_tokens) + + # For kv producer, with prefill always true + if self.is_kv_producer: + with_prefill = True + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + num_tokens, with_prefill, False) + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens + num_reqs = math.ceil(num_tokens / self.decode_token_per_req) + if with_prefill: + num_reqs = min(num_tokens, max_num_reqs) + else: + num_reqs = (num_tokens + self.decode_token_per_req - + 1) // self.decode_token_per_req min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs @@ -1454,6 +1625,26 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + # NOTE: If torchair graph mode and not with_prefill, + # we can't skip_attn, it will cause graph recompile. 
+ if self.torchair_graph_enabled and not with_prefill: + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_reqs, num_actual_tokens=1) + elif skip_attn: + attn_metadata = None + else: + attn_metadata = self.attn_metadata_builder.build_dummy_metadata( + num_actual_tokens=num_tokens, + num_reqs=num_reqs, + num_scheduled_tokens=num_scheduled_tokens, + attn_state=attn_state, + ) + + + if not is_torchair_compile and not is_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1483,14 +1674,17 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run): + model_kwargs = {} if self.torchair_graph_enabled and not with_prefill: - attn_metadata = self.attn_metadata_builder.build_dummy( - num_reqs=num_tokens, num_actual_tokens=1) # Only mark static while compiling - if is_compile: + if is_torchair_compile: torch._dynamo.mark_static(input_ids) torch._dynamo.mark_static(positions) torch._dynamo.mark_static( @@ -1505,21 +1699,43 @@ def _dummy_run( torch._dynamo.mark_static(kv[1]) compiled_model = self._get_torchair_lazy_compiled_model( num_tokens) + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + model_kwargs["graph_enable"] = True # type: ignore hidden_states = compiled_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=None, - kv_caches=self.kv_caches, - attn_metadata=attn_metadata, + **model_kwargs, ) else: + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + model_kwargs["graph_enable"] = False # type: ignore hidden_states = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - return hidden_states + inputs_embeds=inputs_embeds, + **model_kwargs) + if self.speculative_config and self.speculative_config.method == "deepseek_mtp": + assert isinstance(self.drafter, MtpProposer) + self.drafter.dummy_run(num_reqs, with_prefill=with_prefill) + if is_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not is_torchair_compile and not is_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + return hidden_states + + @contextmanager + def set_in_profile_run(self): + self.in_profile_run = True + try: + yield + finally: + self.in_profile_run = False def profile_run(self) -> None: # FIXME Profile with multimodal encoder & encoder cache. @@ -1547,7 +1763,10 @@ def profile_run(self) -> None: # TODO: call maybe_profile_with_lora() # Trigger compilation for general shape. 
- hidden_states = self._dummy_run(self.max_num_tokens) + with self.set_in_profile_run(): + hidden_states = self._dummy_run(self.max_num_tokens, + with_prefill=True, + is_profile_run=True) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] @@ -1560,6 +1779,20 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def do_get_expert_load(self) -> tuple: + return self.eplb_updator.get_expert_load() + + def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + return self.eplb_updator.update_expert_load_statistical_period(num_expert_load_gather, num_iterations) + + def eplb_warmup(self): + #EPLB + if self.dynamic_eplb and not self.is_eplb_warmuped: + self.is_eplb_warmuped = True + self.eplb_adaptor = VllmEplbAdaptor(model=self.model) + self.eplb_updator.set_adaptor(self.eplb_adaptor) + self.eplb_updator.warm_up_eplb() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) @@ -1578,9 +1811,9 @@ def load_model(self) -> None: m.consumed_memory / float(2**30)) def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.max_num_reqs: + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: raise ValueError( - f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}" + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" ) compiled_model = self.torchair_compiled_models.get( @@ -1590,9 +1823,6 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): if compiled_model: return compiled_model - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - patch_for_hcom() config = torchair.CompilerConfig() config.experimental_config.frozen_parameter = True @@ -1640,9 +1870,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - import torch_npu kv_caches: Dict[str, torch.Tensor] = {} + def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, @@ -1653,17 +1888,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: block_sizes=[self.cache_config.block_size], ) - kv_cache_sizes = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "NPU.") - kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + if not vllm_version_is("0.9.0"): + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "NPU.") + kv_cache_sizes[ + kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: - tensor_size = kv_cache_sizes[layer_name] + if vllm_version_is("0.9.0"): + tensor_size = kv_cache_config.tensors[layer_name].size + else: + tensor_size = kv_cache_sizes[layer_name] assert 
tensor_size % kv_cache_spec.page_size_bytes == 0
                 num_blocks = tensor_size // kv_cache_spec.page_size_bytes
@@ -1675,6 +1915,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                 # different GPUs, and `kv_cache_config.num_blocks` is set to
                 # the min of all `num_blocks`. Verify it here.
                 assert num_blocks >= kv_cache_config.num_blocks
+                alignment = 2 * 1024 * 1024
                 # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
                 # encounter OOM issue
                 if isinstance(kv_cache_spec, FullAttentionSpec):
@@ -1682,29 +1923,51 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    if self.torchair_graph_enabled:
-                        layer_kv_cache_nope = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.kv_lora_rank, ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        layer_kv_cache_pe = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.qk_rope_head_dim,
-                             ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        kv_caches[layer_name] = (layer_kv_cache_nope,
-                                                 layer_kv_cache_pe)
-                        torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
-                        torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                    if self.model_config.is_deepseek_mla:
+                        # In order to transfer the kv cache through the register_memory api from llmdatadist,
+                        # the memory address should be aligned to 2M. In most cases torch_npu allocates
+                        # 2M-aligned memory, but we found some exceptions during testing, so we manually align
+                        # the memory here. This may consume an extra 2M * 2 * elem_size of memory per layer.
+                        num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
+                        rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                        nope_dim = head_size - rope_dim
+                        nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
+                        nope_allocate_shape_alignment = nope_allocate_shape + alignment
+                        rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
+                        rope_allocate_shape_alignment = rope_allocate_shape + alignment
+                        nope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, nope_dim)
+                        rope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, rope_dim)
+                        nope_cache = torch.zeros(nope_allocate_shape_alignment,
+                                                 dtype=dtype,
+                                                 device=self.device)
+                        rope_cache = torch.zeros(rope_allocate_shape_alignment,
+                                                 dtype=dtype,
+                                                 device=self.device)
+                        nope_cache = align_memory(
+                            nope_cache, alignment)[:nope_allocate_shape].view(
+                                nope_cache_shape)
+                        rope_cache = align_memory(
+                            rope_cache, alignment)[:rope_allocate_shape].view(
+                                rope_cache_shape)
+                        kv_caches[layer_name] = (nope_cache, rope_cache)
                     else:
-                        kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                            dtype=dtype,
-                                                            device=self.device)
-                        torch_npu.npu_format_cast(kv_caches[layer_name], 2)
+                        num_caches = kv_cache_shape[0]
+                        kv_cache_list = []
+                        for i in range(num_caches):
+                            cache_shape = kv_cache_shape[1:]
+                            cache_size = math.prod(cache_shape)
+                            cache_size_aligned = cache_size + alignment
+                            kv_cache = torch.zeros(cache_size_aligned,
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            kv_cache = align_memory(
+                                kv_cache,
+                                alignment)[:cache_size].view(cache_shape)
+                            kv_cache_list.append(kv_cache)
+                        kv_caches[layer_name] = kv_cache_list
+                        # torch_npu.npu_format_cast(kv_caches[layer_name], 2)
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
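Both branches above over-allocate each cache buffer and then re-base it on a 2 MB boundary with `align_memory`, so the memory can later be registered for KV-cache transfer. A self-contained, CPU-runnable sketch of the same alignment trick (the cache shape and dtype are illustrative):

```python
import math

import torch

ALIGNMENT = 2 * 1024 * 1024  # 2 MB boundary expected by the transfer backend


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a view of `tensor` whose first element sits on an `alignment`-byte boundary."""
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


# Illustrative KV-cache shape: (num_blocks, block_size, num_kv_heads, head_dim).
cache_shape = (8, 128, 2, 64)
cache_size = math.prod(cache_shape)

# Over-allocate by ALIGNMENT elements (more than ALIGNMENT bytes for any dtype)
# so that at least `cache_size` elements remain after shifting to the boundary.
flat = torch.zeros(cache_size + ALIGNMENT, dtype=torch.float16)
cache = align_memory(flat, ALIGNMENT)[:cache_size].view(cache_shape)
assert cache.data_ptr() % ALIGNMENT == 0
```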
@@ -1715,6 +1978,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -1773,24 +2039,25 @@ def capture_model(self) -> None: reversed(torchair_graph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - is_compile=True, - with_prefill=False) - self._dummy_run(num_tokens, - is_compile=True, - with_prefill=False) + # NOTE: when in torchair graph and not with_prefill, + # we don't need to set `skip_attn=False` + self._dummy_run(num_tokens, is_torchair_compile=True) + self._dummy_run(num_tokens, is_torchair_compile=True) logger.info("Batchsize %d is compiled successfully: %d/%d.", num_tokens, idx + 1, graph_num) elif self.use_aclgraph: # Trigger ACL graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. + # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode with graph_capture(device=self.device): + skip_attn = not self.vllm_config.compilation_config.full_cuda_graph + # TODO: Make sure passing attn_state to _dummy_run in the future for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn) else: logger.info("Skipping NPU graph capture for eager mode.") return @@ -1885,6 +2152,7 @@ def _generate_mtp_token_ids( cu_num_tokens, token_indices = self.drafter.prepare_inputs( attn_metadata.query_start_loc, num_rejected_tokens, + force_one_token=True, ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] @@ -1916,8 +2184,43 @@ def init_torchair_graph_batch_sizes(self): start_graph_batch_size *= 2 def select_torchair_padded_batch_size(self, batch_size: int): - selected_batch_size = self.max_num_reqs for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size < selected_batch_size: - selected_batch_size = padded_batch_size - return selected_batch_size + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, torchair_graph_batch_sizes is " + f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." 
+        )
+
+    def check_torchair_graph_batch_sizes(self):
+        # Normalize graph_batch_sizes to token counts;
+        # first pad them according to the number of requests.
+        if len(self.torchair_graph_batch_sizes) == 0:
+            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+        else:
+            self.torchair_graph_batch_sizes = sorted(
+                self.torchair_graph_batch_sizes)
+            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
+                self.torchair_graph_batch_sizes.pop()
+                if len(self.torchair_graph_batch_sizes) == 0:
+                    logger.warning(
+                        "torchair_graph_batch_sizes is invalid, resetting it to [1, max_num_reqs]"
+                    )
+                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
+                self.torchair_graph_batch_sizes.append(self.max_num_reqs)
+
+        # NOTE: when expert parallel is enabled, `graph_batch_size` needs to be divisible by `tp_size`
+        tp_size = self.parallel_config.tensor_parallel_size
+        if self.parallel_config.enable_expert_parallel:
+            new_graph_batch_sizes = []
+            for graph_batch_size in self.torchair_graph_batch_sizes:
+                cur_graph_batch_size = (graph_batch_size + tp_size -
+                                        1) // tp_size * tp_size
+                # the resulting size (in tokens) needs to be a multiple of `self.decode_token_per_req`
+                cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
+                if cur_graph_batch_size not in new_graph_batch_sizes and \
+                    cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                    new_graph_batch_sizes.append(cur_graph_batch_size)
+            self.torchair_graph_batch_sizes = new_graph_batch_sizes
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index ba8406fa0a..04a7d617b5 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -1,15 +1,19 @@
 import torch
+import vllm.envs as envs_vllm
 from vllm.attention.layer import Attention
 from vllm.config import (VllmConfig, get_layers_from_vllm_config,
                          set_current_vllm_config)
-from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
+from vllm_ascend.utils import ProfileExecuteDuration
 
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
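Taken together, `check_torchair_graph_batch_sizes` rounds each configured graph batch size up to a multiple of the TP size when expert parallel is enabled and then scales it by `decode_token_per_req`, while `select_torchair_padded_batch_size` later picks the smallest captured size that can hold the incoming batch. A rough standalone sketch of both steps with illustrative numbers (not the exact runner code):

```python
def round_up(value: int, multiple: int) -> int:
    return (value + multiple - 1) // multiple * multiple


def pad_graph_batch_sizes(sizes: list[int], tp_size: int,
                          decode_token_per_req: int,
                          max_num_batched_tokens: int) -> list[int]:
    padded: list[int] = []
    for size in sorted(sizes):
        # Make the size divisible by tp_size, then express it in tokens.
        tokens = round_up(size, tp_size) * decode_token_per_req
        if tokens not in padded and tokens <= max_num_batched_tokens:
            padded.append(tokens)
    return padded


def select_padded(batch_size: int, graph_batch_sizes: list[int]) -> int:
    # Smallest captured graph size that can hold the current batch.
    for candidate in graph_batch_sizes:
        if batch_size <= candidate:
            return candidate
    raise ValueError(f"batch size {batch_size} exceeds {graph_batch_sizes}")


sizes = pad_graph_batch_sizes([1, 8, 16, 32], tp_size=4,
                              decode_token_per_req=2,
                              max_num_batched_tokens=4096)
assert sizes == [8, 16, 32, 64]
assert select_padded(12, sizes) == 16
```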
@@ -61,13 +65,26 @@ def __init__( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size self.runner = runner + # persistent buffers for graph + self.input_ids = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int32, + device=self.runner.device) + self.positions = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int64, + device=self.runner.device) + self.hidden_states = torch.zeros( + (self.runner.max_num_tokens, self.runner.hidden_size), + dtype=self.runner.dtype, + device=self.runner.device) + self.is_mtp_torchair_ready = False @staticmethod def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + force_one_token: bool = False ) -> tuple[torch.Tensor, torch.Tensor]: # cu_target_query_lens: [0, a, a + b, a + b + c] # num_rejected_tokens: [n1, n2, n3] @@ -76,32 +93,39 @@ def prepare_inputs( # token_indices: [0, 1, ..., a - n1 - 1, # a, a + 1, ..., a + b - n2 - 1, # a + b, a + b + 1, ..., a + b + c - n3 - 1] - # [0, a, a + b, a + b + c] -> [a, b, c] query_len_per_req = (cu_target_query_lens[1:] - cu_target_query_lens[:-1]) # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens + if force_one_token: + # enable force_one_token means we only focus on the last token position of each request + # token_indices: [batch_size] + cu_num_tokens = torch.arange(cu_target_query_lens.size(0), + device=cu_target_query_lens.device, + dtype=torch.int32) + relative_index = query_len_per_req - num_rejected_tokens - 1 + token_indices = cu_target_query_lens[:-1] + relative_index + else: + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) - cu_num_tokens = torch.empty_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. - num_tokens = cu_num_tokens[-1].item() - token_indices = torch.empty( - num_tokens, - dtype=torch.int32, - device=cu_num_tokens.device, - ) - - BLOCK_SIZE = 1024 - prepare_input_kernel( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) + BLOCK_SIZE = 1024 + prepare_input_kernel( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) return cu_num_tokens, token_indices def propose( @@ -126,13 +150,12 @@ def propose( batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 - input_ids = torch.empty_like(target_token_ids) # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - input_ids[:-1] = target_token_ids[1:] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. 
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - input_ids[last_token_indices] = next_token_ids + self.input_ids[last_token_indices] = next_token_ids query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() @@ -152,6 +175,23 @@ def propose( # input_batch=self.runner.input_batch, # scheduler_output=self.runner.scheduler_output, # ) + extra_builder_kwargs = self.runner.extra_builder_kwargs + + is_running_torchair = self.runner.torchair_graph_enabled and \ + not self.runner.with_prefill and self.is_mtp_torchair_ready + + if is_running_torchair: + if num_tokens == 1: + self.runner.attn_state = AscendAttentionState.DecodeOnly + num_reqs_pad_size = self.runner.num_reqs_pad_size + extra_builder_kwargs['num_reqs_pad_size'] = num_reqs_pad_size + # Assume num token per request is one + extra_builder_kwargs['num_token_pad_size'] = num_reqs_pad_size + num_input_tokens = self.runner.num_reqs_pad_size + else: + extra_builder_kwargs['num_token_pad_size'] = -1 + extra_builder_kwargs['num_reqs_pad_size'] = 0 + num_input_tokens = num_tokens attn_metadata = self.runner.attn_metadata_builder.build( num_reqs=batch_size, @@ -159,14 +199,52 @@ def propose( max_query_len=max_query_len, common_prefix_len=0, common_attn_metadata=common_attn_metadata, - ) - - with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( - input_ids=input_ids, - positions=target_positions, - previous_hidden_states=target_hidden_states, - ) + is_mtp_model=True, + **extra_builder_kwargs) + + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states + + # Assuming force_one_token is on, so each perfill request query_lens is 1 + if attn_metadata.prefill is not None: + attn_metadata.prefill.query_lens[:] = 1 + + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): + with ProfileExecuteDuration().capture_async('mtp_forward'): + model_kwargs = {} + model_kwargs["attn_metadata"] = attn_metadata + if self.runner.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] + if is_running_torchair: + torch._dynamo.mark_static(self.input_ids) + torch._dynamo.mark_static(self.positions) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + for kv in self.runner.kv_caches: + assert isinstance(kv, + tuple), "kv_cache must be a tuple" + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + hidden_states = self.torchair_compiled_model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. + hidden_states[:num_input_tokens], + inputs_embeds=None, + **model_kwargs) + else: + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. 
+ hidden_states[:num_input_tokens], + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:]) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -202,6 +280,49 @@ def load_model(self) -> None: self.model)) process_weights_after_loading(self.model, draft_model_config, target_device) + if self.runner.torchair_graph_enabled and self.is_mtp_torchair_ready: + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + torch.npu.set_compile_mode(jit_compile=False) + if not self.runner.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + else: + self.torchair_compiled_model = torchair.inference.cache_compile( + self.model.forward, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + ) -> None: + if self.runner.torchair_graph_enabled and not with_prefill: + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) + else: + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) + with set_ascend_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): + self.model(input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + previous_hidden_states=self.hidden_states[:num_tokens], + attn_metadata=attn_metadata) # TODO Using torch instead of triton may result in poor performance diff --git a/vllm_ascend/worker/pooling_model_runner.py b/vllm_ascend/worker/pooling_model_runner.py index e1262fb0a2..5047a0f106 100644 --- a/vllm_ascend/worker/pooling_model_runner.py +++ b/vllm_ascend/worker/pooling_model_runner.py @@ -21,13 +21,13 @@ import torch from vllm.distributed import get_pp_group -from vllm.forward_context import set_forward_context from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.worker.model_runner import (ModelInputForNPU, ModelInputForNPUBuilder, NPUModelRunnerBase) @@ -142,8 +142,8 @@ def execute_model( if model_input.token_types is not None: cross_enc_kwargs["token_type_ids"] = model_input.token_types - with set_forward_context(model_input.attn_metadata, self.vllm_config, - virtual_engine): + with set_ascend_forward_context(model_input.attn_metadata, + self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index e78cc3f1cf..80f7c4a78d 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -49,9 +49,8 @@ from vllm_ascend.ascend_config import 
init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator -from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import init_ascend_soc_version, try_register_lib from vllm_ascend.worker.model_runner import NPUModelRunner from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner @@ -218,6 +217,7 @@ def init_device(self) -> None: else: raise RuntimeError( f"Not support device type: {self.device_config.device}") + init_ascend_soc_version() # Initialize the distributed environment. self._init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, @@ -545,11 +545,6 @@ def _init_worker_distributed_environment( ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - init_ascend_model_parallel( - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, - parallel_config.world_size_across_dp, - ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 6fe84a4580..b062e5cf9e 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -40,9 +40,8 @@ from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator -from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import init_ascend_soc_version, try_register_lib from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -75,6 +74,9 @@ def __init__( is_driver_worker=is_driver_worker) # Try to import mindie_turbo to accelerate vLLM inference. + local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + world_size = self.vllm_config.parallel_config.world_size + self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank try_register_lib( "mindie_turbo", "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." @@ -125,6 +127,7 @@ def init_device(self): info = f"Not support device type: {self.device_config.device}" logger.error(info) raise RuntimeError(info) + init_ascend_soc_version() # Initialize the distributed environment. self._init_worker_distributed_environment() # Set random seed. @@ -192,6 +195,7 @@ def load_model(self) -> None: self.model_runner.load_model() def compile_or_warm_up_model(self) -> None: + self.model_runner.eplb_warmup() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ @@ -201,12 +205,18 @@ def compile_or_warm_up_model(self) -> None: for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + def get_expert_load(self) -> tuple: + return self.model_runner.do_get_expert_load() + def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations) + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -245,22 +255,10 @@ def pin_lora(self, lora_id: int) -> bool: return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - runner = self.model_runner - max_num_tokens = 1 - with_prefill = False - if runner.dp_size > 1: - max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( - max_num_tokens, with_prefill) - if runner.torchair_graph_enabled and not with_prefill: - max_num_tokens = runner.select_torchair_padded_batch_size( - max_num_tokens) - runner._dummy_run(max_num_tokens, - is_compile=False, - with_prefill=with_prefill) + self.model_runner._dummy_run(1) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" - parallel_config = self.vllm_config.parallel_config set_custom_all_reduce( not self.parallel_config.disable_custom_all_reduce) init_distributed_environment(self.parallel_config.world_size, @@ -269,11 +267,6 @@ def _init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) - init_ascend_model_parallel( - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, - parallel_config.world_size_across_dp, - ) ensure_kv_transfer_initialized(self.vllm_config) def _init_profiler(self):
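For reference, the `local_rank_across_dp` value computed in `NPUWorker.__init__` above flattens (data-parallel rank, local rank) into one index, so ranks belonging to different DP replicas never collide. A small worked example with illustrative sizes:

```python
def local_rank_across_dp(local_dp_rank: int, world_size: int, local_rank: int) -> int:
    # Mirrors the expression added in __init__: local_dp_rank * world_size + local_rank.
    return local_dp_rank * world_size + local_rank


# Two data-parallel replicas, each spanning a world of 4 ranks (e.g. TP=4):
assert [local_rank_across_dp(0, 4, r) for r in range(4)] == [0, 1, 2, 3]
assert [local_rank_across_dp(1, 4, r) for r in range(4)] == [4, 5, 6, 7]
```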