Skip to content

[FOLLOWUP] Use base test to avoid patch everywhere #1634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions tests/e2e/singlecard/ops/test_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@

Run `pytest tests/ops/test_fused_moe.py`.
"""
# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa

import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
Expand Down
8 changes: 2 additions & 6 deletions tests/ut/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@

from vllm_ascend.utils import adapt_patch

# fused moe ops test will hit the infer_schema error, so we need to add the
# patch here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa


class TestBase(unittest.TestCase):
    """Common base class for vllm-ascend unit tests.

    Applies the vllm-ascend compatibility patches once per test-case
    instance, so individual test modules no longer need to import the
    patch modules themselves.
    """

    def __init__(self, *args, **kwargs):
        # Apply the patches in __init__ rather than setUp: subclasses
        # frequently override setUp without calling super().setUp(), and
        # __init__ guarantees the patches are in place before any fixture
        # or test method runs.
        # NOTE(review): the earlier revision called adapt_patch(True);
        # confirm the default argument of adapt_patch() is the intended
        # patch set here.
        adapt_patch()
        super().__init__(*args, **kwargs)
4 changes: 2 additions & 2 deletions tests/ut/distributed/kv_transfer/test_simple_buffer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
import zlib
from unittest.mock import MagicMock

import torch

from tests.ut.base import TestBase
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
int32_hash)

Expand All @@ -17,7 +17,7 @@ def __init__(self):
self.deallocate_buffer = MagicMock()


class TestSimpleBuffer(unittest.TestCase):
class TestSimpleBuffer(TestBase):

def setUp(self):
self.pipe = MockSimplePipe()
Expand Down
4 changes: 2 additions & 2 deletions tests/ut/distributed/kv_transfer/test_simple_connector.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import unittest
from unittest.mock import MagicMock, patch

import torch
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from tests.ut.base import TestBase
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_connector import \
SimpleConnector
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe


class TestSimpleConnector(unittest.TestCase):
class TestSimpleConnector(TestBase):

def setUp(self):
self.mock_pipe = MagicMock(spec=SimplePipe)
Expand Down
4 changes: 2 additions & 2 deletions tests/ut/distributed/kv_transfer/test_simple_pipe.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import unittest
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe


class TestSimplePipe(unittest.TestCase):
class TestSimplePipe(TestBase):

@classmethod
def _create_mock_config(self):
Expand Down
11 changes: 5 additions & 6 deletions tests/ut/ops/test_rotary_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
import unittest
from unittest.mock import MagicMock, patch

import torch
Expand All @@ -12,7 +11,7 @@
yarn_get_mscale)


class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
class TestCustomRotaryEmbeddingEnabled(TestBase):

def setUp(self):
# Common setup for tests
Expand Down Expand Up @@ -67,7 +66,7 @@ def test_custom_rotary_embedding_enabled(self):
self.assertFalse(result)


class TestRopeForwardOot(unittest.TestCase):
class TestRopeForwardOot(TestBase):

def setUp(self):
# Common setup for tests
Expand Down Expand Up @@ -262,7 +261,7 @@ def test_native_rope_deepseek_forward_non_neox_style(
assert k_pe.shape == key.shape


class TestRotateHalf(unittest.TestCase):
class TestRotateHalf(TestBase):

def test_rotate_half_even_dim(self):
# Test with even dimension
Expand All @@ -272,7 +271,7 @@ def test_rotate_half_even_dim(self):
self.assertTrue(torch.allclose(result, expected))


class TestYarnFindCorrectionDim(unittest.TestCase):
class TestYarnFindCorrectionDim(TestBase):

def test_basic_case(self):
# Test with standard values
Expand All @@ -293,7 +292,7 @@ def test_basic_case(self):
self.assertTrue(torch.allclose(result, expected))


class TestYarnGetMscale(unittest.TestCase):
class TestYarnGetMscale(TestBase):

def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
Expand Down
3 changes: 1 addition & 2 deletions tests/ut/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import math
import os
import unittest
from threading import Lock
from unittest import mock

Expand Down Expand Up @@ -302,7 +301,7 @@ def test_torchair_cache_dir(self):
"Delete kv cache bytes cache dir failed")


class TestProfileExecuteDuration(unittest.TestCase):
class TestProfileExecuteDuration(TestBase):

def setUp(self):
utils.ProfileExecuteDuration._instance = None
Expand Down
5 changes: 2 additions & 3 deletions tests/ut/worker/test_input_batch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import unittest

import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import MultiGroupBlockTable

from tests.ut.base import TestBase
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch


Expand All @@ -24,7 +23,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
)


class TestInputBatch(unittest.TestCase):
class TestInputBatch(TestBase):

def setUp(self):
self.max_num_reqs = 10
Expand Down
3 changes: 2 additions & 1 deletion tests/ut/worker/test_pooling_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import unittest

Check failure on line 1 in tests/ut/worker/test_pooling_model_runner.py

View workflow job for this annotation

GitHub Actions / lint (3.10)

Ruff (F401)

tests/ut/worker/test_pooling_model_runner.py:1:8: F401 `unittest` imported but unused
from unittest.mock import MagicMock, patch

import torch
Expand All @@ -7,11 +7,12 @@
from vllm.pooling_params import PoolingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

from tests.ut.base import TestBase
from vllm_ascend.worker.pooling_model_runner import (
ModelInputForNPUWithPoolingMetadata, NPUPoolingModelRunner)


class TestPoolingModelRunner(unittest.TestCase):
class TestPoolingModelRunner(TestBase):
"""Unit tests for the NPUPoolingModelRunner class."""

def _create_model_runner(self, model: str, *args,
Expand Down
Loading