Skip to content

Commit 8d67718

Browse files
authored
Improve jwcrypto (#13715)
1 parent 9234985 commit 8d67718

File tree

5 files changed

+220
-94
lines changed

5 files changed

+220
-94
lines changed

stubs/jwcrypto/jwcrypto/common.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from _typeshed import Incomplete
2-
from collections.abc import Iterator, MutableMapping
1+
from collections.abc import Callable, Iterator, MutableMapping
32
from typing import Any, NamedTuple
43

4+
from jwcrypto.jwe import JWE
5+
from jwcrypto.jws import JWS
6+
57
def base64url_encode(payload: str | bytes) -> str: ...
68
def base64url_decode(payload: str) -> bytes: ...
79
def json_encode(string: str | bytes) -> str: ...
@@ -36,11 +38,11 @@ class JWSEHeaderParameter(NamedTuple):
3638
description: str
3739
mustprotect: bool
3840
supported: bool
39-
check_fn: Incomplete | None
41+
check_fn: Callable[[JWS | JWE], bool] | None
4042

4143
class JWSEHeaderRegistry(MutableMapping[str, JWSEHeaderParameter]):
42-
def __init__(self, init_registry: Incomplete | None = None) -> None: ...
43-
def check_header(self, h: str, value) -> bool: ...
44+
def __init__(self, init_registry: dict[str, JWSEHeaderParameter] | None = None) -> None: ...
45+
def check_header(self, h: str, value: JWS | JWE) -> bool: ...
4446
def __getitem__(self, key: str) -> JWSEHeaderParameter: ...
4547
def __iter__(self) -> Iterator[str]: ...
4648
def __delitem__(self, key: str) -> None: ...

stubs/jwcrypto/jwcrypto/jwe.pyi

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from _typeshed import Incomplete
22
from collections.abc import Mapping, Sequence
3+
from typing import Any
4+
from typing_extensions import Self
35

46
from jwcrypto import common
5-
from jwcrypto.common import JWException, JWSEHeaderParameter
7+
from jwcrypto.common import JWException, JWSEHeaderParameter, JWSEHeaderRegistry
68
from jwcrypto.jwk import JWK, JWKSet
79

810
default_max_compressed_size: int
@@ -18,34 +20,34 @@ InvalidJWEKeyType = common.InvalidJWEKeyType
1820
InvalidJWEOperation = common.InvalidJWEOperation
1921

2022
class JWE:
21-
objects: Incomplete
22-
plaintext: Incomplete
23-
header_registry: Incomplete
23+
objects: dict[str, Any]
24+
plaintext: bytes | None
25+
header_registry: JWSEHeaderRegistry
2426
cek: Incomplete
25-
decryptlog: Incomplete
27+
decryptlog: list[str] | None
2628
def __init__(
2729
self,
28-
plaintext: bytes | None = None,
30+
plaintext: str | bytes | None = None,
2931
protected: str | None = None,
3032
unprotected: str | None = None,
3133
aad: bytes | None = None,
32-
algs: Incomplete | None = None,
34+
algs: list[str] | None = None,
3335
recipient: str | None = None,
34-
header: Incomplete | None = None,
35-
header_registry: Incomplete | None = None,
36+
header: str | None = None,
37+
header_registry: Mapping[str, JWSEHeaderParameter] | None = None,
3638
) -> None: ...
3739
@property
38-
def allowed_algs(self): ...
40+
def allowed_algs(self) -> list[str]: ...
3941
@allowed_algs.setter
40-
def allowed_algs(self, algs) -> None: ...
41-
def add_recipient(self, key, header: Incomplete | None = None) -> None: ...
42-
def serialize(self, compact: bool = False): ...
42+
def allowed_algs(self, algs: list[str]) -> None: ...
43+
def add_recipient(self, key: JWK, header: dict[str, Any] | str | None = None) -> None: ...
44+
def serialize(self, compact: bool = False) -> str: ...
4345
def decrypt(self, key: JWK | JWKSet) -> None: ...
4446
def deserialize(self, raw_jwe: str | bytes, key: JWK | JWKSet | None = None) -> None: ...
4547
@property
46-
def payload(self): ...
48+
def payload(self) -> bytes: ...
4749
@property
4850
def jose_header(self) -> dict[Incomplete, Incomplete]: ...
4951
@classmethod
50-
def from_jose_token(cls, token: str | bytes) -> JWE: ...
52+
def from_jose_token(cls, token: str | bytes) -> Self: ...
5153
def __eq__(self, other: object) -> bool: ...

stubs/jwcrypto/jwcrypto/jwk.pyi

Lines changed: 163 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from _typeshed import Incomplete
2-
from collections.abc import Sequence
1+
from collections.abc import Callable, Sequence
32
from enum import Enum
4-
from typing import Any, NamedTuple
3+
from typing import Any, Literal, NamedTuple, TypeVar, overload
4+
from typing_extensions import Self, deprecated
55

6+
from cryptography.hazmat.primitives import hashes
7+
from cryptography.hazmat.primitives.asymmetric import ec, rsa
68
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey as Ed448PrivateKey, Ed448PublicKey as Ed448PublicKey
79
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
810
Ed25519PrivateKey as Ed25519PrivateKey,
@@ -15,6 +17,8 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import (
1517
)
1618
from jwcrypto.common import JWException
1719

20+
_T = TypeVar("_T")
21+
1822
class UnimplementedOKPCurveKey:
1923
@classmethod
2024
def generate(cls) -> None: ...
@@ -24,9 +28,25 @@ class UnimplementedOKPCurveKey:
2428
def from_private_bytes(cls, *args) -> None: ...
2529

2630
ImplementedOkpCurves: Sequence[str]
27-
priv_bytes: Incomplete
31+
priv_bytes: Callable[[bytes], X25519PrivateKey] | None
32+
33+
class _Ed25519_CURVE(NamedTuple):
34+
pubkey: UnimplementedOKPCurveKey
35+
privkey: UnimplementedOKPCurveKey
36+
37+
class _Ed448_CURVE(NamedTuple):
38+
pubkey: UnimplementedOKPCurveKey
39+
privkey: UnimplementedOKPCurveKey
2840

29-
JWKTypesRegistry: Incomplete
41+
class _X25519_CURVE(NamedTuple):
42+
pubkey: UnimplementedOKPCurveKey
43+
privkey: UnimplementedOKPCurveKey
44+
45+
class _X448_CURVE(NamedTuple):
46+
pubkey: UnimplementedOKPCurveKey
47+
privkey: UnimplementedOKPCurveKey
48+
49+
JWKTypesRegistry: dict[str, str]
3050

3151
class ParmType(Enum):
3252
name = "A string with a name" # pyright: ignore[reportAssignmentType]
@@ -35,81 +55,170 @@ class ParmType(Enum):
3555
unsupported = "Unsupported Parameter"
3656

3757
class JWKParameter(NamedTuple):
38-
description: Incomplete
39-
public: Incomplete
40-
required: Incomplete
41-
type: Incomplete
42-
43-
JWKValuesRegistry: Incomplete
44-
JWKParamsRegistry: Incomplete
45-
JWKEllipticCurveRegistry: Incomplete
46-
JWKUseRegistry: Incomplete
47-
JWKOperationsRegistry: Incomplete
48-
JWKpycaCurveMap: Incomplete
49-
IANANamedInformationHashAlgorithmRegistry: Incomplete
58+
description: str
59+
public: bool
60+
required: bool | None
61+
type: ParmType | None
62+
63+
JWKValuesRegistry: dict[str, dict[str, JWKParameter]]
64+
JWKParamsRegistry: dict[str, JWKParameter]
65+
JWKEllipticCurveRegistry: dict[str, str]
66+
JWKUseRegistry: dict[str, str]
67+
JWKOperationsRegistry: dict[str, str]
68+
JWKpycaCurveMap: dict[str, str]
69+
IANANamedInformationHashAlgorithmRegistry: dict[
70+
str,
71+
hashes.SHA256
72+
| hashes.SHA384
73+
| hashes.SHA512
74+
| hashes.SHA3_224
75+
| hashes.SHA3_256
76+
| hashes.SHA3_384
77+
| hashes.SHA3_512
78+
| hashes.BLAKE2s
79+
| hashes.BLAKE2b
80+
| None,
81+
]
5082

5183
class InvalidJWKType(JWException):
52-
value: Incomplete
53-
def __init__(self, value: Incomplete | None = None) -> None: ...
84+
value: str | None
85+
def __init__(self, value: str | None = None) -> None: ...
5486

5587
class InvalidJWKUsage(JWException):
56-
value: Incomplete
57-
use: Incomplete
58-
def __init__(self, use, value) -> None: ...
88+
value: str
89+
use: str
90+
def __init__(self, use: str, value: str) -> None: ...
5991

6092
class InvalidJWKOperation(JWException):
61-
op: Incomplete
62-
values: Incomplete
63-
def __init__(self, operation, values) -> None: ...
93+
op: str
94+
values: Sequence[str]
95+
def __init__(self, operation: str, values: Sequence[str]) -> None: ...
6496

6597
class InvalidJWKValue(JWException): ...
6698

6799
class JWK(dict[str, Any]):
68100
def __init__(self, **kwargs) -> None: ...
69101
@classmethod
70-
def generate(cls, **kwargs): ...
102+
def generate(cls, **kwargs) -> Self: ...
71103
def generate_key(self, **params) -> None: ...
72104
def import_key(self, **kwargs) -> None: ...
73105
@classmethod
74-
def from_json(cls, key): ...
75-
def export(self, private_key: bool = True, as_dict: bool = False): ...
76-
def export_public(self, as_dict: bool = False): ...
77-
def export_private(self, as_dict: bool = False): ...
78-
def export_symmetric(self, as_dict: bool = False): ...
79-
def public(self): ...
106+
def from_json(cls, key) -> Self: ...
107+
@overload
108+
def export(self, private_key: bool = True, as_dict: Literal[False] = False) -> str: ...
109+
@overload
110+
def export(self, private_key: bool, as_dict: Literal[True]) -> dict[str, Any]: ...
111+
@overload
112+
def export(self, *, as_dict: Literal[True]) -> dict[str, Any]: ...
113+
@overload
114+
def export_public(self, as_dict: Literal[False] = False) -> str: ...
115+
@overload
116+
def export_public(self, as_dict: Literal[True]) -> dict[str, Any]: ...
117+
@overload
118+
def export_public(self, as_dict: bool = False) -> str | dict[str, Any]: ...
119+
@overload
120+
def export_private(self, as_dict: Literal[False] = False) -> str: ...
121+
@overload
122+
def export_private(self, as_dict: Literal[True]) -> dict[str, Any]: ...
123+
@overload
124+
def export_private(self, as_dict: bool = False) -> str | dict[str, Any]: ...
125+
@overload
126+
def export_symmetric(self, as_dict: Literal[False] = False) -> str: ...
127+
@overload
128+
def export_symmetric(self, as_dict: Literal[True]) -> dict[str, Any]: ...
129+
@overload
130+
def export_symmetric(self, as_dict: bool = False) -> str | dict[str, Any]: ...
131+
def public(self) -> Self: ...
80132
@property
81133
def has_public(self) -> bool: ...
82134
@property
83135
def has_private(self) -> bool: ...
84136
@property
85137
def is_symmetric(self) -> bool: ...
86138
@property
87-
def key_type(self): ...
139+
@deprecated("")
140+
def key_type(self) -> str | None: ...
88141
@property
89-
def key_id(self): ...
142+
@deprecated("")
143+
def key_id(self) -> str | None: ...
90144
@property
91-
def key_curve(self): ...
92-
def get_curve(self, arg): ...
93-
def get_op_key(self, operation: Incomplete | None = None, arg: Incomplete | None = None): ...
94-
def import_from_pyca(self, key) -> None: ...
95-
def import_from_pem(self, data, password: Incomplete | None = None, kid: Incomplete | None = None) -> None: ...
96-
def export_to_pem(self, private_key: bool = False, password: bool = False): ...
145+
@deprecated("")
146+
def key_curve(self) -> str | None: ...
147+
@deprecated("")
148+
def get_curve(
149+
self, arg: str
150+
) -> (
151+
ec.SECP256R1
152+
| ec.SECP384R1
153+
| ec.SECP521R1
154+
| ec.SECP256K1
155+
| ec.BrainpoolP256R1
156+
| ec.BrainpoolP384R1
157+
| ec.BrainpoolP512R1
158+
| _Ed25519_CURVE
159+
| _Ed448_CURVE
160+
| _X25519_CURVE
161+
| _X448_CURVE
162+
): ...
163+
def get_op_key(
164+
self, operation: str | None = None, arg: str | None = None
165+
) -> str | rsa.RSAPrivateKey | rsa.RSAPublicKey | ec.EllipticCurvePrivateKey | ec.EllipticCurvePublicKey | None: ...
166+
def import_from_pyca(
167+
self,
168+
key: (
169+
rsa.RSAPrivateKey
170+
| rsa.RSAPublicKey
171+
| ec.EllipticCurvePrivateKey
172+
| ec.EllipticCurvePublicKey
173+
| Ed25519PrivateKey
174+
| Ed448PrivateKey
175+
| X25519PrivateKey
176+
| Ed25519PublicKey
177+
| Ed448PublicKey
178+
| X25519PublicKey
179+
),
180+
) -> None: ...
181+
def import_from_pem(self, data: bytes, password: bytes | None = None, kid: str | None = None) -> None: ...
182+
def export_to_pem(self, private_key: bool = False, password: bool = False) -> bytes: ...
97183
@classmethod
98-
def from_pyca(cls, key): ...
184+
def from_pyca(
185+
cls,
186+
key: (
187+
rsa.RSAPrivateKey
188+
| rsa.RSAPublicKey
189+
| ec.EllipticCurvePrivateKey
190+
| ec.EllipticCurvePublicKey
191+
| Ed25519PrivateKey
192+
| Ed448PrivateKey
193+
| X25519PrivateKey
194+
| Ed25519PublicKey
195+
| Ed448PublicKey
196+
| X25519PublicKey
197+
),
198+
) -> Self: ...
99199
@classmethod
100-
def from_pem(cls, data, password: Incomplete | None = None): ...
101-
def thumbprint(self, hashalg=...): ...
102-
def thumbprint_uri(self, hname: str = "sha-256"): ...
200+
def from_pem(cls, data: bytes, password: bytes | None = None) -> Self: ...
201+
def thumbprint(self, hashalg: hashes.HashAlgorithm = ...) -> str: ...
202+
def thumbprint_uri(self, hname: str = "sha-256") -> str: ...
103203
@classmethod
104-
def from_password(cls, password): ...
105-
def setdefault(self, key: str, default: Incomplete | None = None): ...
204+
def from_password(cls, password: str) -> Self: ...
205+
def setdefault(self, key: str, default: _T | None = None) -> _T: ...
106206

107-
class JWKSet(dict[str, Any]):
108-
def add(self, elem) -> None: ...
109-
def export(self, private_keys: bool = True, as_dict: bool = False): ...
110-
def import_keyset(self, keyset) -> None: ...
207+
class JWKSet(dict[Literal["keys"], set[JWK]]):
208+
@overload
209+
def __setitem__(self, key: Literal["keys"], val: JWK) -> None: ...
210+
@overload
211+
def __setitem__(self, key: str, val: Any) -> None: ...
212+
def add(self, elem: JWK) -> None: ...
213+
@overload
214+
def export(self, private_keys: bool = True, as_dict: Literal[False] = False) -> str: ...
215+
@overload
216+
def export(self, private_keys: bool, as_dict: Literal[True]) -> dict[str, Any]: ...
217+
@overload
218+
def export(self, *, as_dict: Literal[True]) -> dict[str, Any]: ...
219+
def import_keyset(self, keyset: str | bytes) -> None: ...
111220
@classmethod
112-
def from_json(cls, keyset): ...
113-
def get_key(self, kid): ...
114-
def get_keys(self, kid): ...
115-
def setdefault(self, key: str, default: Incomplete | None = None): ...
221+
def from_json(cls, keyset: str | bytes) -> Self: ...
222+
def get_key(self, kid: str) -> JWK | None: ...
223+
def get_keys(self, kid: str) -> set[JWK]: ...
224+
def setdefault(self, key: str, default: _T | None = None) -> _T: ...

0 commit comments

Comments
 (0)