Skip to content

Commit c78cc1c

Browse files
authored
Arguments (#190)
* starting Arguments * arguments validator WIP * working on validate_args_pair * arguments validator compiling * starting tests * WIP JSON args mapping * fix argument mapping for json * fix tests, add comment about build_args * tweaking argument errors * missing tuple args * fix tuple tests * using new errors * mapping positional arguments * fix too_long and too_short errors * linting * converting kwarg errors * improve coverage * improve coverage * add always_validate_kwargs * allow *args with zero arguments * fix extra args with args mapping * error on repeat arguments * validator decorator and more coverage * skip positional only args tests for <3.10 * fix for python <3.10 * separate build_validate_generic_mapping * Arguments, take 2 (#203) * WIP moving to lookup item for args * Arguments take 2 * adding var_args and var_kwargs handling * new arguments logic working * fixing tests * Non-default argument follows default arguments * test aliases * test kwargs * fix unexpected_positional_argument * improve coverage * PyArgs use tuple, not list * coverage and tweaks
1 parent 4faedfe commit c78cc1c

27 files changed

+1662
-126
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ node_modules/
2929
package-lock.json
3030
/pytest-speed/
3131
/src/self_schema.py
32+
/worktree/

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ format:
6363

6464
.PHONY: lint-python
6565
lint-python:
66-
flake8 --max-complexity 10 --max-line-length 120 --ignore E203,W503 pydantic_core tests
66+
flake8 --max-line-length 120 pydantic_core tests
6767
$(isort) --check-only --df
6868
$(black) --check --diff
6969

generate_self_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_schema(obj):
4848
return 'any'
4949

5050
origin = get_origin(obj)
51-
assert origin is not None, f'origin cannot be None, obj={obj}'
51+
assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py'
5252
if origin is Union:
5353
return union_schema(obj)
5454
elif obj is Callable or origin is Callable:

pydantic_core/_types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ class Config(TypedDict, total=False):
3434
# settings related to typed_dicts only
3535
typed_dict_extra_behavior: Literal['allow', 'forbid', 'ignore']
3636
typed_dict_full: bool # default: True
37-
typed_dict_populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
3837
# used on typed-dicts and tagged union keys
3938
from_attributes: bool
4039
revalidate_models: bool
40+
# used on typed-dicts and arguments
41+
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
4142
# fields related to string fields only
4243
str_max_length: int
4344
str_min_length: int
@@ -277,6 +278,24 @@ class CallableSchema(TypedDict):
277278
type: Literal['callable']
278279

279280

281+
class ArgumentInfo(TypedDict):
282+
name: str
283+
mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only']
284+
schema: Schema
285+
default: NotRequired[Any]
286+
default_factory: NotRequired[Callable[[], Any]]
287+
alias: NotRequired[Union[str, List[Union[str, int]], List[List[Union[str, int]]]]]
288+
289+
290+
class ArgumentsSchema(TypedDict, total=False):
291+
type: Required[Literal['arguments']]
292+
arguments_schema: Required[List[ArgumentInfo]]
293+
populate_by_name: bool
294+
var_args_schema: Schema
295+
var_kwargs_schema: Schema
296+
ref: str
297+
298+
280299
# pydantic allows types to be defined via a simple string instead of dict with just `type`, e.g.
281300
# 'int' is equivalent to {'type': 'int'}, this only applies to schema types which do not have other required fields
282301
BareType = Literal[
@@ -332,4 +351,5 @@ class CallableSchema(TypedDict):
332351
TimedeltaSchema,
333352
IsInstanceSchema,
334353
CallableSchema,
354+
ArgumentsSchema,
335355
]

src/build_tools.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,20 @@ where
8181
}
8282
}
8383

84+
pub fn schema_or_config_same<'py, T>(
85+
schema: &'py PyDict,
86+
config: Option<&'py PyDict>,
87+
key: &PyString,
88+
) -> PyResult<Option<T>>
89+
where
90+
T: FromPyObject<'py>,
91+
{
92+
schema_or_config(schema, config, key, key)
93+
}
94+
8495
pub fn is_strict(schema: &PyDict, config: Option<&PyDict>) -> PyResult<bool> {
8596
let py = schema.py();
86-
let k = intern!(py, "strict");
87-
Ok(schema_or_config(schema, config, k, k)?.unwrap_or(false))
97+
Ok(schema_or_config_same(schema, config, intern!(py, "strict"))?.unwrap_or(false))
8898
}
8999

90100
// we could perhaps do clever things here to store each schema error, or have different types for the top

src/errors/kinds.rs

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub enum ErrorKind {
3131
Missing,
3232
#[strum(message = "Extra values are not permitted")]
3333
ExtraForbidden,
34-
#[strum(message = "Model keys must be strings")]
34+
#[strum(message = "Keys must be strings")]
3535
InvalidKey,
3636
#[strum(message = "Error extracting attribute: {error}")]
3737
GetAttributeError {
@@ -71,13 +71,19 @@ pub enum ErrorKind {
7171
},
7272
// ---------------------
7373
// generic length errors - used for everything with a length except strings and bytes which need custom messages
74-
#[strum(message = "Input must have at least {min_length} items")]
74+
#[strum(
75+
message = "Input must have at least {min_length} item{expected_plural}, got {input_length} item{input_plural}"
76+
)]
7577
TooShort {
7678
min_length: usize,
79+
input_length: usize,
7780
},
78-
#[strum(message = "Input must have at most {max_length} items")]
81+
#[strum(
82+
message = "Input must have at most {max_length} item{expected_plural}, got {input_length} item{input_plural}"
83+
)]
7984
TooLong {
8085
max_length: usize,
86+
input_length: usize,
8187
},
8288
// ---------------------
8389
// string errors
@@ -297,6 +303,20 @@ pub enum ErrorKind {
297303
UnionTagNotFound {
298304
discriminator: String,
299305
},
306+
// ---------------------
307+
// argument errors
308+
#[strum(message = "Arguments must be a tuple of (positional arguments, keyword arguments) or a plain dict")]
309+
ArgumentsType,
310+
#[strum(message = "Unexpected keyword argument")]
311+
UnexpectedKeywordArgument,
312+
#[strum(message = "Missing required keyword argument")]
313+
MissingKeywordArgument,
314+
#[strum(message = "Unexpected positional argument")]
315+
UnexpectedPositionalArgument,
316+
#[strum(message = "Missing required positional argument")]
317+
MissingPositionalArgument,
318+
#[strum(message = "Got multiple values for argument")]
319+
MultipleArgumentValues,
300320
}
301321

302322
macro_rules! render {
@@ -331,6 +351,14 @@ macro_rules! py_dict {
331351
}};
332352
}
333353

354+
fn plural_s(value: &usize) -> &'static str {
355+
if *value == 1 {
356+
""
357+
} else {
358+
"s"
359+
}
360+
}
361+
334362
impl ErrorKind {
335363
pub fn kind(&self) -> String {
336364
match self {
@@ -348,8 +376,22 @@ impl ErrorKind {
348376
Self::GreaterThanEqual { ge } => render!(self, ge),
349377
Self::LessThan { lt } => render!(self, lt),
350378
Self::LessThanEqual { le } => render!(self, le),
351-
Self::TooShort { min_length } => to_string_render!(self, min_length),
352-
Self::TooLong { max_length } => to_string_render!(self, max_length),
379+
Self::TooShort {
380+
min_length,
381+
input_length,
382+
} => {
383+
let expected_plural = plural_s(min_length);
384+
let input_plural = plural_s(input_length);
385+
to_string_render!(self, min_length, input_length, expected_plural, input_plural)
386+
}
387+
Self::TooLong {
388+
max_length,
389+
input_length,
390+
} => {
391+
let expected_plural = plural_s(max_length);
392+
let input_plural = plural_s(input_length);
393+
to_string_render!(self, max_length, input_length, expected_plural, input_plural)
394+
}
353395
Self::StrTooShort { min_length } => to_string_render!(self, min_length),
354396
Self::StrTooLong { max_length } => to_string_render!(self, max_length),
355397
Self::StrPatternMismatch { pattern } => render!(self, pattern),
@@ -398,8 +440,14 @@ impl ErrorKind {
398440
Self::GreaterThanEqual { ge } => py_dict!(py, ge),
399441
Self::LessThan { lt } => py_dict!(py, lt),
400442
Self::LessThanEqual { le } => py_dict!(py, le),
401-
Self::TooShort { min_length } => py_dict!(py, min_length),
402-
Self::TooLong { max_length } => py_dict!(py, max_length),
443+
Self::TooShort {
444+
min_length,
445+
input_length,
446+
} => py_dict!(py, min_length, input_length),
447+
Self::TooLong {
448+
max_length,
449+
input_length,
450+
} => py_dict!(py, max_length, input_length),
403451
Self::StrTooShort { min_length } => py_dict!(py, min_length),
404452
Self::StrTooLong { max_length } => py_dict!(py, max_length),
405453
Self::StrPatternMismatch { pattern } => py_dict!(py, pattern),

src/input/input_abstract.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::input::datetime::EitherTime;
88

99
use super::datetime::{EitherDate, EitherDateTime, EitherTimedelta};
1010
use super::return_enums::{EitherBytes, EitherString};
11-
use super::{GenericListLike, GenericMapping};
11+
use super::{GenericArguments, GenericListLike, GenericMapping};
1212

1313
/// all types have three methods: `validate_*`, `strict_*`, `lax_*`
1414
/// the convention is to either implement:
@@ -42,6 +42,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4242
false
4343
}
4444

45+
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>>;
46+
4547
fn validate_str(&'a self, strict: bool) -> ValResult<EitherString<'a>> {
4648
if strict {
4749
self.strict_str()

src/input/input_json.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use super::datetime::{
55
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
66
};
77
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
8-
use super::{EitherBytes, EitherString, EitherTimedelta, GenericListLike, GenericMapping, Input, JsonInput};
8+
use super::{
9+
EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericListLike, GenericMapping, Input, JsonArgs,
10+
JsonInput,
11+
};
912

1013
impl<'a> Input<'a> for JsonInput {
1114
/// This is required by since JSON object keys are always strings, I don't think it can be called
@@ -26,6 +29,30 @@ impl<'a> Input<'a> for JsonInput {
2629
matches!(self, JsonInput::Null)
2730
}
2831

32+
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
33+
match self {
34+
JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()),
35+
JsonInput::Array(array) => {
36+
if array.len() != 2 {
37+
Err(ValError::new(ErrorKind::ArgumentsType, self))
38+
} else {
39+
let args = match unsafe { array.get_unchecked(0) } {
40+
JsonInput::Null => None,
41+
JsonInput::Array(args) => Some(args.as_slice()),
42+
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
43+
};
44+
let kwargs = match unsafe { array.get_unchecked(1) } {
45+
JsonInput::Null => None,
46+
JsonInput::Object(kwargs) => Some(kwargs),
47+
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
48+
};
49+
Ok(JsonArgs::new(args, kwargs).into())
50+
}
51+
}
52+
_ => Err(ValError::new(ErrorKind::ArgumentsType, self)),
53+
}
54+
}
55+
2956
fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
3057
match self {
3158
JsonInput::String(s) => Ok(s.as_str().into()),
@@ -245,6 +272,11 @@ impl<'a> Input<'a> for String {
245272
false
246273
}
247274

275+
#[cfg_attr(has_no_coverage, no_coverage)]
276+
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
277+
Err(ValError::new(ErrorKind::ArgumentsType, self))
278+
}
279+
248280
fn validate_str(&'a self, _strict: bool) -> ValResult<EitherString<'a>> {
249281
Ok(self.as_str().into())
250282
}

src/input/input_python.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ use super::datetime::{
1616
EitherTime,
1717
};
1818
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
19-
use super::{repr_string, EitherBytes, EitherString, EitherTimedelta, GenericListLike, GenericMapping, Input};
19+
use super::{
20+
repr_string, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericListLike, GenericMapping, Input,
21+
PyArgs,
22+
};
2023

2124
impl<'a> Input<'a> for PyAny {
2225
fn as_loc_item(&self) -> LocItem {
@@ -60,6 +63,32 @@ impl<'a> Input<'a> for PyAny {
6063
self.is_callable()
6164
}
6265

66+
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
67+
if let Ok(kwargs) = self.cast_as::<PyDict>() {
68+
Ok(PyArgs::new(None, Some(kwargs)).into())
69+
} else if let Ok((args, kwargs)) = self.extract::<(&PyAny, &PyAny)>() {
70+
let args = if let Ok(tuple) = args.cast_as::<PyTuple>() {
71+
Some(tuple)
72+
} else if args.is_none() {
73+
None
74+
} else if let Ok(list) = args.cast_as::<PyList>() {
75+
Some(PyTuple::new(self.py(), list.iter().collect::<Vec<_>>()))
76+
} else {
77+
return Err(ValError::new(ErrorKind::ArgumentsType, self));
78+
};
79+
let kwargs = if let Ok(dict) = kwargs.cast_as::<PyDict>() {
80+
Some(dict)
81+
} else if kwargs.is_none() {
82+
None
83+
} else {
84+
return Err(ValError::new(ErrorKind::ArgumentsType, self));
85+
};
86+
Ok(PyArgs::new(args, kwargs).into())
87+
} else {
88+
Err(ValError::new(ErrorKind::ArgumentsType, self))
89+
}
90+
}
91+
6392
fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
6493
if let Ok(py_str) = self.cast_as::<PyString>() {
6594
Ok(py_str.into())

src/input/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ mod shared;
1111
pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1212
pub use input_abstract::Input;
1313
pub use parse_json::{JsonInput, JsonObject};
14-
pub use return_enums::{EitherBytes, EitherString, GenericListLike, GenericMapping};
14+
pub use return_enums::{
15+
EitherBytes, EitherString, GenericArguments, GenericListLike, GenericMapping, JsonArgs, PyArgs,
16+
};
1517

1618
pub fn repr_string(v: &PyAny) -> PyResult<String> {
1719
v.repr()?.extract()

0 commit comments

Comments
 (0)