|
4 | 4 | import difflib
|
5 | 5 | import json
|
6 | 6 | import pathlib
|
| 7 | +import re |
7 | 8 | from json import dumps, loads
|
8 | 9 | from typing import TYPE_CHECKING, cast
|
9 | 10 | from urllib.parse import urlparse
|
10 | 11 |
|
11 | 12 | import yaml
|
12 | 13 | from django.urls import Resolver404, resolve
|
13 | 14 | from django.utils.functional import cached_property
|
14 |
| -from openapi_spec_validator import openapi_v2_spec_validator, openapi_v3_spec_validator |
| 15 | +from openapi_spec_validator import openapi_v2_spec_validator, openapi_v30_spec_validator, openapi_v31_spec_validator |
15 | 16 | from prance.util.resolver import RefResolver
|
16 | 17 | from rest_framework.schemas.generators import BaseSchemaGenerator, EndpointEnumerator
|
17 | 18 | from rest_framework.settings import api_settings
|
@@ -98,7 +99,14 @@ def normalize_schema_paths(self, schema: dict) -> dict[str, dict]:
|
98 | 99 | @staticmethod
|
99 | 100 | def validate_schema(schema: dict):
|
100 | 101 | if "openapi" in schema:
|
101 |
| - validator = openapi_v3_spec_validator |
| 102 | + openapi_version_pattern = re.compile(r"^(\d)\.(\d+)") |
| 103 | + result = openapi_version_pattern.findall(schema["openapi"]) |
| 104 | + if result: |
| 105 | + major, minor = result[0] |
| 106 | + if (major, minor) == ("3", "0"): |
| 107 | + validator = openapi_v30_spec_validator |
| 108 | + elif (major, minor) == ("3", "1"): |
| 109 | + validator = openapi_v31_spec_validator |
102 | 110 | else:
|
103 | 111 | validator = openapi_v2_spec_validator
|
104 | 112 | validator.validate(schema)
|
|
0 commit comments