Skip to content

Commit 4cda2b4

Browse files
authored
Function call (#218)
* capacity to call functions * capacity to call functions * renaming to new-class and call-function * cleanup coverage * more tests and fix coverage * add typing tests * rename call-function -> call
1 parent 3811283 commit 4cda2b4

20 files changed

+370
-99
lines changed

pydantic_core/_types.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class LiteralSchema(TypedDict):
109109
ref: NotRequired[str]
110110

111111

112-
class ModelClassSchema(TypedDict):
113-
type: Literal['model-class']
112+
class NewClassSchema(TypedDict):
113+
type: Literal['new-class']
114114
class_type: type
115115
schema: Schema
116116
strict: NotRequired[bool]
@@ -279,13 +279,13 @@ class CallableSchema(TypedDict):
279279
type: Literal['callable']
280280

281281

282-
class Parameter(TypedDict):
283-
name: str
284-
mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only']
285-
schema: Schema
286-
default: NotRequired[Any]
287-
default_factory: NotRequired[Callable[[], Any]]
288-
alias: NotRequired[Union[str, List[Union[str, int]], List[List[Union[str, int]]]]]
282+
class Parameter(TypedDict, total=False):
283+
name: Required[str]
284+
mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] # default positional_or_keyword
285+
schema: Required[Schema]
286+
default: Any
287+
default_factory: Callable[[], Any]
288+
alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
289289

290290

291291
class ArgumentsSchema(TypedDict, total=False):
@@ -297,6 +297,14 @@ class ArgumentsSchema(TypedDict, total=False):
297297
ref: str
298298

299299

300+
class CallSchema(TypedDict):
301+
type: Literal['call']
302+
function: Callable[..., Any]
303+
arguments_schema: Schema
304+
return_schema: NotRequired[Schema]
305+
ref: NotRequired[str]
306+
307+
300308
# pydantic allows types to be defined via a simple string instead of dict with just `type`, e.g.
301309
# 'int' is equivalent to {'type': 'int'}, this only applies to schema types which do not have other required fields
302310
BareType = Literal[
@@ -335,7 +343,7 @@ class ArgumentsSchema(TypedDict, total=False):
335343
ListSchema,
336344
LiteralSchema,
337345
TypedDictSchema,
338-
ModelClassSchema,
346+
NewClassSchema,
339347
NoneSchema,
340348
NullableSchema,
341349
RecursiveReferenceSchema,
@@ -353,4 +361,5 @@ class ArgumentsSchema(TypedDict, total=False):
353361
IsInstanceSchema,
354362
CallableSchema,
355363
ArgumentsSchema,
364+
CallSchema,
356365
]

src/errors/line_error.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,6 @@ impl<'a> ValLineError<'a> {
114114
self.kind = kind;
115115
self
116116
}
117-
118-
pub fn into_new<'b>(self, py: Python) -> ValLineError<'b> {
119-
ValLineError {
120-
kind: self.kind,
121-
location: self.location,
122-
input_value: self.input_value.to_object(py).into(),
123-
}
124-
}
125117
}
126118

127119
#[cfg_attr(debug_assertions, derive(Debug))]

src/errors/validation_exception.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,11 @@ impl ValidationError {
4141
}
4242

4343
// used to convert a validation error back to ValError for wrap functions
44-
impl<'a> From<ValidationError> for ValError<'a> {
45-
fn from(val_error: ValidationError) -> Self {
46-
val_error
47-
.line_errors
44+
impl<'a> IntoPy<ValError<'a>> for ValidationError {
45+
fn into_py(self, py: Python) -> ValError<'a> {
46+
self.line_errors
4847
.into_iter()
49-
.map(|e| e.into())
48+
.map(|e| e.into_py(py))
5049
.collect::<Vec<_>>()
5150
.into()
5251
}
@@ -130,12 +129,12 @@ impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
130129
}
131130

132131
/// opposite of above, used to extract line errors from a validation error for wrap functions
133-
impl<'a> From<PyLineError> for ValLineError<'a> {
134-
fn from(py_line_error: PyLineError) -> Self {
135-
Self {
136-
kind: py_line_error.kind,
137-
location: py_line_error.location,
138-
input_value: py_line_error.input_value.into(),
132+
impl<'a> IntoPy<ValLineError<'a>> for PyLineError {
133+
fn into_py(self, _py: Python) -> ValLineError<'a> {
134+
ValLineError {
135+
kind: self.kind,
136+
location: self.location,
137+
input_value: self.input_value.into(),
139138
}
140139
}
141140
}

src/input/return_enums.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ impl<'a> IntoPy<PyObject> for EitherString<'a> {
246246
}
247247

248248
pub fn py_string_str(py_str: &PyString) -> ValResult<&str> {
249-
py_str.to_str().map_err(|_| {
250-
ValError::new_custom_input(ErrorKind::StrUnicode, InputValue::PyObject(py_str.into_py(py_str.py())))
251-
})
249+
py_str
250+
.to_str()
251+
.map_err(|_| ValError::new_custom_input(ErrorKind::StrUnicode, InputValue::PyAny(py_str as &PyAny)))
252252
}
253253

254254
#[cfg_attr(debug_assertions, derive(Debug))]

src/validators/arguments.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ impl BuildValidator for ArgumentsValidator {
5353
let arg: &PyDict = arg.cast_as()?;
5454

5555
let name: String = arg.get_as_req(intern!(py, "name"))?;
56-
let mode: &str = arg.get_as_req(intern!(py, "mode"))?;
56+
let mode = arg
57+
.get_as::<&str>(intern!(py, "mode"))?
58+
.unwrap_or("positional_or_keyword");
5759
let positional = mode == "positional_only" || mode == "positional_or_keyword";
5860
if positional {
5961
positional_params_count = arg_index + 1;

src/validators/call.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use pyo3::exceptions::PyTypeError;
2+
use pyo3::intern;
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyTuple};
5+
6+
use crate::build_tools::SchemaDict;
7+
use crate::errors::ValResult;
8+
use crate::input::Input;
9+
use crate::recursion_guard::RecursionGuard;
10+
11+
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
12+
13+
#[derive(Debug, Clone)]
14+
pub struct CallValidator {
15+
function: PyObject,
16+
arguments_validator: Box<CombinedValidator>,
17+
return_validator: Option<Box<CombinedValidator>>,
18+
name: String,
19+
}
20+
21+
impl BuildValidator for CallValidator {
22+
const EXPECTED_TYPE: &'static str = "call";
23+
24+
fn build(
25+
schema: &PyDict,
26+
config: Option<&PyDict>,
27+
build_context: &mut BuildContext,
28+
) -> PyResult<CombinedValidator> {
29+
let py = schema.py();
30+
31+
let arguments_schema: &PyAny = schema.get_as_req(intern!(py, "arguments_schema"))?;
32+
let arguments_validator = Box::new(build_validator(arguments_schema, config, build_context)?);
33+
34+
let return_schema = schema.get_item(intern!(py, "return_schema"));
35+
let return_validator = match return_schema {
36+
Some(return_schema) => Some(Box::new(build_validator(return_schema, config, build_context)?)),
37+
None => None,
38+
};
39+
let function: &PyAny = schema.get_as_req(intern!(py, "function"))?;
40+
let function_name: &str = function.getattr(intern!(py, "__name__"))?.extract()?;
41+
let name = format!("{}[{}]", Self::EXPECTED_TYPE, function_name);
42+
43+
Ok(Self {
44+
function: function.to_object(py),
45+
arguments_validator,
46+
return_validator,
47+
name,
48+
}
49+
.into())
50+
}
51+
}
52+
53+
impl Validator for CallValidator {
54+
fn validate<'s, 'data>(
55+
&'s self,
56+
py: Python<'data>,
57+
input: &'data impl Input<'data>,
58+
extra: &Extra,
59+
slots: &'data [CombinedValidator],
60+
recursion_guard: &'s mut RecursionGuard,
61+
) -> ValResult<'data, PyObject> {
62+
let args = self
63+
.arguments_validator
64+
.validate(py, input, extra, slots, recursion_guard)
65+
.map_err(|e| e.with_outer_location("arguments".into()))?;
66+
67+
let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) {
68+
self.function.call(py, args, Some(kwargs))?
69+
} else if let Ok(kwargs) = args.cast_as::<PyDict>(py) {
70+
self.function.call(py, (), Some(kwargs))?
71+
} else {
72+
let msg = "Arguments validator should return a tuple of (args, kwargs) or a dict of kwargs";
73+
return Err(PyTypeError::new_err(msg).into());
74+
};
75+
76+
if let Some(return_validator) = &self.return_validator {
77+
return_validator
78+
.validate(py, return_value.into_ref(py), extra, slots, recursion_guard)
79+
.map_err(|e| e.with_outer_location("return-value".into()))
80+
} else {
81+
Ok(return_value.to_object(py))
82+
}
83+
}
84+
85+
fn get_name(&self) -> &str {
86+
&self.name
87+
}
88+
}

src/validators/function.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,9 @@ impl Validator for FunctionBeforeValidator {
8787
.func
8888
.call(py, (input.to_object(py),), kwargs)
8989
.map_err(|e| convert_err(py, e, input))?;
90-
// maybe there's some way to get the PyAny here and explicitly tell rust it should have lifespan 'a?
91-
let new_input: &PyAny = value.as_ref(py);
92-
match self.validator.validate(py, new_input, extra, slots, recursion_guard) {
93-
Ok(v) => Ok(v),
94-
Err(ValError::InternalErr(err)) => Err(ValError::InternalErr(err)),
95-
Err(ValError::LineErrors(line_errors)) => {
96-
// we have to be explicit about clone line errors to a new lifetime since new_input doesn't have
97-
// the 'data lifetime
98-
Err(ValError::LineErrors(
99-
line_errors
100-
.into_iter()
101-
.map(|line_error| line_error.into_new(py))
102-
.collect(),
103-
))
104-
}
105-
}
90+
91+
self.validator
92+
.validate(py, value.into_ref(py), extra, slots, recursion_guard)
10693
}
10794

10895
fn get_name(&self) -> &str {
@@ -308,7 +295,7 @@ fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> Val
308295
if let Ok(pydantic_value_error) = err.value(py).extract::<PydanticValueError>() {
309296
pydantic_value_error.into_val_error(input)
310297
} else if let Ok(validation_error) = err.value(py).extract::<ValidationError>() {
311-
validation_error.into()
298+
validation_error.into_py(py)
312299
} else {
313300
py_err_string!(err.value(py), ValueError, input)
314301
}

src/validators/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod any;
1717
mod arguments;
1818
mod bool;
1919
mod bytes;
20+
mod call;
2021
mod callable;
2122
mod date;
2223
mod datetime;
@@ -28,7 +29,7 @@ mod int;
2829
mod is_instance;
2930
mod list;
3031
mod literal;
31-
mod model_class;
32+
mod new_class;
3233
mod none;
3334
mod nullable;
3435
mod recursive;
@@ -314,7 +315,7 @@ pub fn build_validator<'a>(
314315
// nullables
315316
nullable::NullableValidator,
316317
// model classes
317-
model_class::ModelClassValidator,
318+
new_class::NewClassValidator,
318319
// strings
319320
string::StrValidator,
320321
// integers
@@ -335,6 +336,8 @@ pub fn build_validator<'a>(
335336
none::NoneValidator,
336337
// functions - before, after, plain & wrap
337338
function::FunctionBuilder,
339+
// function call - validation around a function call
340+
call::CallValidator,
338341
// recursive (self-referencing) models
339342
recursive::RecursiveRefValidator,
340343
// literals
@@ -408,7 +411,7 @@ pub enum CombinedValidator {
408411
// nullables
409412
Nullable(nullable::NullableValidator),
410413
// model classes
411-
ModelClass(model_class::ModelClassValidator),
414+
ModelClass(new_class::NewClassValidator),
412415
// strings
413416
Str(string::StrValidator),
414417
StrConstrained(string::StrConstrainedValidator),
@@ -436,6 +439,8 @@ pub enum CombinedValidator {
436439
FunctionAfter(function::FunctionAfterValidator),
437440
FunctionPlain(function::FunctionPlainValidator),
438441
FunctionWrap(function::FunctionWrapValidator),
442+
// function call - validation around a function call
443+
FunctionCall(call::CallValidator),
439444
// recursive (self-referencing) models
440445
Recursive(recursive::RecursiveContainerValidator),
441446
RecursiveRef(recursive::RecursiveRefValidator),

src/validators/model_class.rs renamed to src/validators/new_class.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::recursion_guard::RecursionGuard;
1616
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1717

1818
#[derive(Debug, Clone)]
19-
pub struct ModelClassValidator {
19+
pub struct NewClassValidator {
2020
strict: bool,
2121
revalidate: bool,
2222
validator: Box<CombinedValidator>,
@@ -25,8 +25,8 @@ pub struct ModelClassValidator {
2525
expect_fields_set: bool,
2626
}
2727

28-
impl BuildValidator for ModelClassValidator {
29-
const EXPECTED_TYPE: &'static str = "model-class";
28+
impl BuildValidator for NewClassValidator {
29+
const EXPECTED_TYPE: &'static str = "new-class";
3030

3131
fn build(
3232
schema: &PyDict,
@@ -59,7 +59,7 @@ impl BuildValidator for ModelClassValidator {
5959
}
6060
}
6161

62-
impl Validator for ModelClassValidator {
62+
impl Validator for NewClassValidator {
6363
fn validate<'s, 'data>(
6464
&'s self,
6565
py: Python<'data>,
@@ -106,7 +106,7 @@ impl Validator for ModelClassValidator {
106106
}
107107
}
108108

109-
impl ModelClassValidator {
109+
impl NewClassValidator {
110110
fn create_class(&self, py: Python, model_dict: &PyAny, fields_set: Option<&PyAny>) -> PyResult<PyObject> {
111111
// based on the following but with the second argument of new_func set to an empty tuple as required
112112
// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77

tests/benchmarks/complete_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def wrap_function(input_value, *, validator, **kwargs):
1010
return f'Input {validator(input_value)} Changed'
1111

1212
return {
13-
'type': 'model-class',
13+
'type': 'new-class',
1414
'class_type': MyModel,
1515
'config': {'strict': strict},
1616
'schema': {

0 commit comments

Comments
 (0)