Skip to content

Commit be3ae3c

Browse files
committed
A new Type variant
1 parent 2c6c3b5 commit be3ae3c

File tree

10 files changed

+298
-55
lines changed

10 files changed

+298
-55
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 177 additions & 15 deletions
Large diffs are not rendered by default.

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
591591
}
592592
_ => {
593593
let known_instance = new_positive
594-
.into_instance()
594+
.into_nominal_instance()
595595
.and_then(|instance| instance.class().known(db));
596596

597597
if known_instance == Some(KnownClass::Object) {
@@ -705,7 +705,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
705705
let contains_bool = || {
706706
self.positive
707707
.iter()
708-
.filter_map(|ty| ty.into_instance())
708+
.filter_map(|ty| ty.into_nominal_instance())
709709
.filter_map(|instance| instance.class().known(db))
710710
.any(KnownClass::is_bool)
711711
};

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ impl<'db> ClassType<'db> {
216216
}
217217
}
218218

219+
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
220+
self.class_literal(db).0.is_protocol(db)
221+
}
222+
219223
pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
220224
&self.class(db).name
221225
}
@@ -1144,6 +1148,7 @@ impl<'db> ClassLiteralType<'db> {
11441148
Parameters::new([Parameter::positional_or_keyword(Name::new_static("other"))
11451149
// TODO: could be `Self`.
11461150
.with_annotated_type(Type::instance(
1151+
db,
11471152
self.apply_optional_specialization(db, specialization),
11481153
))]),
11491154
Some(KnownClass::Bool.to_instance(db)),
@@ -2152,7 +2157,7 @@ impl<'db> KnownClass {
21522157
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> {
21532158
self.to_class_literal(db)
21542159
.into_class_type()
2155-
.map(Type::instance)
2160+
.map(|class| Type::instance(db, class))
21562161
.unwrap_or_else(Type::unknown)
21572162
}
21582163

crates/red_knot_python_semantic/src/types/class_base.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ impl<'db> ClassBase<'db> {
106106
| Type::SubclassOf(_)
107107
| Type::TypeVar(_)
108108
| Type::BoundSuper(_)
109+
| Type::ProtocolInstance(_)
109110
| Type::AlwaysFalsy
110111
| Type::AlwaysTruthy => None,
111112
Type::KnownInstance(known_instance) => match known_instance {

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::types::{
1717
use crate::Db;
1818
use rustc_hash::FxHashMap;
1919

20+
use super::instance::Protocol;
2021
use super::CallableType;
2122

2223
impl<'db> Type<'db> {
@@ -78,9 +79,28 @@ impl Display for DisplayRepresentation<'_> {
7879
(_, Some(KnownClass::NoneType)) => f.write_str("None"),
7980
(_, Some(KnownClass::NoDefaultType)) => f.write_str("NoDefault"),
8081
(ClassType::NonGeneric(class), _) => f.write_str(&class.class(self.db).name),
81-
(ClassType::Generic(alias), _) => write!(f, "{}", alias.display(self.db)),
82+
(ClassType::Generic(alias), _) => alias.display(self.db).fmt(f),
8283
}
8384
}
85+
Type::ProtocolInstance(protocol) => match protocol.inner() {
86+
Protocol::FromClass(ClassType::NonGeneric(class)) => {
87+
f.write_str(&class.class(self.db).name)
88+
}
89+
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
90+
Protocol::Synthesized(synthetic) => {
91+
f.write_str("<Protocol with members ")?;
92+
let member_list = synthetic.members(self.db);
93+
let num_members = member_list.len();
94+
for (i, member) in member_list.iter().enumerate() {
95+
let is_last = i == num_members - 1;
96+
write!(f, "'{member}'")?;
97+
if !is_last {
98+
f.write_str(", ")?;
99+
}
100+
}
101+
f.write_char('>')
102+
}
103+
},
84104
Type::PropertyInstance(_) => f.write_str("property"),
85105
Type::ModuleLiteral(module) => {
86106
write!(f, "<module '{}'>", module.module(self.db).name())

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2544,6 +2544,7 @@ impl<'db> TypeInferenceBuilder<'db> {
25442544
Type::Dynamic(..) | Type::Never => true,
25452545

25462546
Type::NominalInstance(..)
2547+
| Type::ProtocolInstance(_)
25472548
| Type::BooleanLiteral(..)
25482549
| Type::IntLiteral(..)
25492550
| Type::StringLiteral(..)
@@ -5011,6 +5012,7 @@ impl<'db> TypeInferenceBuilder<'db> {
50115012
| Type::GenericAlias(_)
50125013
| Type::SubclassOf(_)
50135014
| Type::NominalInstance(_)
5015+
| Type::ProtocolInstance(_)
50145016
| Type::KnownInstance(_)
50155017
| Type::PropertyInstance(_)
50165018
| Type::Union(_)
@@ -5291,6 +5293,7 @@ impl<'db> TypeInferenceBuilder<'db> {
52915293
| Type::GenericAlias(_)
52925294
| Type::SubclassOf(_)
52935295
| Type::NominalInstance(_)
5296+
| Type::ProtocolInstance(_)
52945297
| Type::KnownInstance(_)
52955298
| Type::PropertyInstance(_)
52965299
| Type::Intersection(_)
@@ -5316,6 +5319,7 @@ impl<'db> TypeInferenceBuilder<'db> {
53165319
| Type::GenericAlias(_)
53175320
| Type::SubclassOf(_)
53185321
| Type::NominalInstance(_)
5322+
| Type::ProtocolInstance(_)
53195323
| Type::KnownInstance(_)
53205324
| Type::PropertyInstance(_)
53215325
| Type::Intersection(_)

crates/red_knot_python_semantic/src/types/instance.rs

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ use super::{ClassType, KnownClass, SubclassOfType, Type};
66
use crate::{Db, FxOrderSet};
77

88
impl<'db> Type<'db> {
9-
pub(crate) const fn instance(class: ClassType<'db>) -> Self {
10-
Self::NominalInstance(NominalInstanceType { class })
9+
pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self {
10+
if class.is_protocol(db) {
11+
Self::ProtocolInstance(ProtocolInstanceType(Protocol::FromClass(class)))
12+
} else {
13+
Self::NominalInstance(NominalInstanceType { class })
14+
}
1115
}
1216

13-
pub(crate) const fn into_instance(self) -> Option<NominalInstanceType<'db>> {
17+
pub(crate) const fn into_nominal_instance(self) -> Option<NominalInstanceType<'db>> {
1418
match self {
1519
Type::NominalInstance(instance_type) => Some(instance_type),
1620
_ => None,
@@ -95,30 +99,26 @@ impl<'db> From<NominalInstanceType<'db>> for Type<'db> {
9599
}
96100
}
97101

98-
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, salsa::Supertype)]
99-
pub enum ProtocolInstanceType<'db> {
100-
FromClass(ClassType<'db>),
101-
Synthesized(SynthesizedProtocolType<'db>),
102-
}
102+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, salsa::Update)]
103+
pub struct ProtocolInstanceType<'db>(
104+
// Keep the inner field here private,
105+
// so that the only way of constructing `ProtocolInstanceType` instances
106+
// is through the `Type::instance` constructor function.
107+
Protocol<'db>,
108+
);
103109

104-
#[salsa::tracked]
105110
impl<'db> ProtocolInstanceType<'db> {
106-
#[salsa::tracked(return_ref)]
107-
fn protocol_members(self, db: &'db dyn Db) -> FxOrderSet<Name> {
108-
match self {
109-
Self::FromClass(class) => class
110-
.class_literal(db)
111-
.0
112-
.into_protocol_class(db)
113-
.expect("Protocol class literal should be a protocol class")
114-
.protocol_members(db),
115-
Self::Synthesized(synthesized) => synthesized.members(db),
116-
}
111+
pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db FxOrderSet<Name> {
112+
self.0.protocol_members(db)
113+
}
114+
115+
pub(super) fn inner(self) -> Protocol<'db> {
116+
self.0
117117
}
118118

119119
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
120-
match self {
121-
Self::FromClass(class) => SubclassOfType::from(db, class),
120+
match self.0 {
121+
Protocol::FromClass(class) => SubclassOfType::from(db, class),
122122

123123
// TODO: we can and should do better here.
124124
//
@@ -133,19 +133,34 @@ impl<'db> ProtocolInstanceType<'db> {
133133
// reveal_type(type(x)) # mypy: "type[def (builtins.int) -> builtins.str]"
134134
// reveal_type(type(x).__call__) # mypy: "def (*args: Any, **kwds: Any) -> Any"
135135
// ```
136-
Self::Synthesized(_) => KnownClass::Type.to_instance(db),
136+
Protocol::Synthesized(_) => KnownClass::Type.to_instance(db),
137137
}
138138
}
139139

140-
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
141-
match self {
142-
Self::FromClass(_) => {
143-
Self::Synthesized(SynthesizedProtocolType::new(db, self.protocol_members(db)))
144-
}
145-
Self::Synthesized(_) => self,
140+
pub(super) fn normalized(self, db: &'db dyn Db) -> Type<'db> {
141+
let members = self.protocol_members(db);
142+
let object = KnownClass::Object.to_instance(db);
143+
if members
144+
.iter()
145+
.all(|member| !object.member(db, member).symbol.is_unbound())
146+
{
147+
return object;
148+
}
149+
match self.0 {
150+
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
151+
SynthesizedProtocolType::new(db, self.protocol_members(db)),
152+
))),
153+
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
146154
}
147155
}
148156

157+
/// TODO: should iterate over the types of the members
158+
/// and check if any of them contain `Todo` types
159+
#[expect(clippy::unused_self)]
160+
pub(super) fn contains_todo(self) -> bool {
161+
false
162+
}
163+
149164
/// TODO: should not be considered fully static if any members do not have fully static types
150165
#[expect(clippy::unused_self)]
151166
pub(super) fn is_fully_static(self) -> bool {
@@ -181,7 +196,35 @@ impl<'db> ProtocolInstanceType<'db> {
181196
}
182197
}
183198

199+
/// Private inner enum to represent the two kinds of protocol types.
200+
/// This is not exposed publicly, so that the only way of constructing `Protocol` instances
201+
/// is through the [`Type::instance`] constructor function.
202+
#[derive(
203+
Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, salsa::Supertype, PartialOrd, Ord,
204+
)]
205+
pub(super) enum Protocol<'db> {
206+
FromClass(ClassType<'db>),
207+
Synthesized(SynthesizedProtocolType<'db>),
208+
}
209+
210+
#[salsa::tracked]
211+
impl<'db> Protocol<'db> {
212+
#[salsa::tracked(return_ref)]
213+
fn protocol_members(self, db: &'db dyn Db) -> FxOrderSet<Name> {
214+
match self {
215+
Self::FromClass(class) => class
216+
.class_literal(db)
217+
.0
218+
.into_protocol_class(db)
219+
.expect("Protocol class literal should be a protocol class")
220+
.protocol_members(db),
221+
Self::Synthesized(synthesized) => synthesized.members(db).clone(),
222+
}
223+
}
224+
}
225+
184226
#[salsa::interned(debug)]
185-
pub struct SynthesizedProtocolType<'db> {
186-
members: FxOrderSet<Name>,
227+
pub(super) struct SynthesizedProtocolType<'db> {
228+
#[return_ref]
229+
pub(super) members: FxOrderSet<Name>,
187230
}

crates/red_knot_python_semantic/src/types/narrow.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl KnownConstraintFunction {
159159
/// union types are not yet supported. Returns `None` if the `classinfo` argument has a wrong type.
160160
fn generate_constraint<'db>(self, db: &'db dyn Db, classinfo: Type<'db>) -> Option<Type<'db>> {
161161
let constraint_fn = |class| match self {
162-
KnownConstraintFunction::IsInstance => Type::instance(class),
162+
KnownConstraintFunction::IsInstance => Type::instance(db, class),
163163
KnownConstraintFunction::IsSubclass => SubclassOfType::from(db, class),
164164
};
165165

@@ -684,7 +684,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
684684
let symbol = self.expect_expr_name_symbol(id);
685685
constraints.insert(
686686
symbol,
687-
Type::instance(rhs_class.unknown_specialization(self.db)),
687+
Type::instance(self.db, rhs_class.unknown_specialization(self.db)),
688688
);
689689
}
690690
}

crates/red_knot_python_semantic/src/types/subclass_of.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ impl<'db> SubclassOfType<'db> {
9494
}
9595
}
9696

97-
pub(crate) fn to_instance(self) -> Type<'db> {
97+
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> {
9898
match self.subclass_of {
99-
SubclassOfInner::Class(class) => Type::instance(class),
99+
SubclassOfInner::Class(class) => Type::instance(db, class),
100100
SubclassOfInner::Dynamic(dynamic_type) => Type::Dynamic(dynamic_type),
101101
}
102102
}

crates/red_knot_python_semantic/src/types/type_ordering.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,21 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
126126

127127
(Type::SubclassOf(_), _) => Ordering::Less,
128128
(_, Type::SubclassOf(_)) => Ordering::Greater,
129+
129130
(Type::NominalInstance(left), Type::NominalInstance(right)) => {
130131
left.class().cmp(&right.class())
131132
}
132-
133133
(Type::NominalInstance(_), _) => Ordering::Less,
134134
(_, Type::NominalInstance(_)) => Ordering::Greater,
135135

136+
(Type::ProtocolInstance(left_proto), Type::ProtocolInstance(right_proto)) => {
137+
debug_assert_eq!(*left, left_proto.normalized(db));
138+
debug_assert_eq!(*right, right_proto.normalized(db));
139+
left_proto.cmp(right_proto)
140+
}
141+
(Type::ProtocolInstance(_), _) => Ordering::Less,
142+
(_, Type::ProtocolInstance(_)) => Ordering::Greater,
143+
136144
(Type::TypeVar(left), Type::TypeVar(right)) => left.cmp(right),
137145
(Type::TypeVar(_), _) => Ordering::Less,
138146
(_, Type::TypeVar(_)) => Ordering::Greater,

0 commit comments

Comments
 (0)