Skip to content

Commit 2ad20c3

Browse files
authored
Merge pull request #523 from madsmtm/restrict-this-declare-class
Restrict `this` parameters to `Self`-like types in `declare_class!`
2 parents 53cbfc6 + 8c6b293 commit 2ad20c3

File tree

14 files changed

+1309
-1461
lines changed

14 files changed

+1309
-1461
lines changed

crates/objc2/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
8181
* Allow cloning `Id<AnyObject>`.
8282
* **BREAKING**: Restrict message sending to `&mut` references to things that
8383
implement `IsAllowedMutable`.
84+
* Disallow the ability to use non-`Self`-like types as the receiver in
85+
`declare_class!`.
8486

8587
### Removed
8688
* **BREAKING**: Removed `ProtocolType` implementation for `NSObject`.

crates/objc2/src/__macro_helpers/declare_class.rs

Lines changed: 247 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
1+
#[cfg(all(debug_assertions, feature = "verify"))]
2+
use alloc::vec::Vec;
3+
use core::marker::PhantomData;
4+
#[cfg(all(debug_assertions, feature = "verify"))]
5+
use std::collections::HashSet;
6+
7+
#[cfg(all(debug_assertions, feature = "verify"))]
8+
use crate::runtime::{AnyProtocol, MethodDescription};
9+
110
use objc2_encode::Encoding;
211

12+
use crate::declare::{ClassBuilder, IvarType};
313
use crate::encode::Encode;
414
use crate::rc::{Allocated, Id};
15+
use crate::runtime::{AnyClass, MethodImplementation, Sel};
516
use crate::runtime::{AnyObject, MessageReceiver};
6-
use crate::{ClassType, Message};
17+
use crate::{ClassType, Message, ProtocolType};
718

819
use super::{CopyOrMutCopy, Init, MaybeUnwrap, New, Other};
920
use crate::mutability;
@@ -52,7 +63,7 @@ where
5263
// restrict it here to only be when the selector is `init`.
5364
//
5465
// Additionally, the receiver and return type must have the same generic
55-
// generic parameter `T`.
66+
// parameter `T`.
5667
impl<Ret, T> MessageRecieveId<Allocated<T>, Ret> for Init
5768
where
5869
T: Message,
@@ -190,3 +201,237 @@ where
190201
{
191202
// Noop
192203
}
204+
205+
#[derive(Debug)]
206+
pub struct ClassBuilderHelper<T: ?Sized> {
207+
builder: ClassBuilder,
208+
p: PhantomData<T>,
209+
}
210+
211+
#[track_caller]
212+
fn failed_declaring_class(name: &str) -> ! {
213+
panic!("could not create new class {name}. Perhaps a class with that name already exists?")
214+
}
215+
216+
impl<T: ?Sized + ClassType> ClassBuilderHelper<T> {
217+
#[inline]
218+
#[track_caller]
219+
#[allow(clippy::new_without_default)]
220+
pub fn new() -> Self
221+
where
222+
T::Super: ClassType,
223+
{
224+
let builder = match ClassBuilder::new(T::NAME, <T::Super as ClassType>::class()) {
225+
Some(builder) => builder,
226+
None => failed_declaring_class(T::NAME),
227+
};
228+
229+
Self {
230+
builder,
231+
p: PhantomData,
232+
}
233+
}
234+
235+
#[inline]
236+
pub fn add_protocol_methods<P>(&mut self) -> ClassProtocolMethodsBuilder<'_, T>
237+
where
238+
P: ?Sized + ProtocolType,
239+
{
240+
let protocol = P::protocol();
241+
242+
if let Some(protocol) = protocol {
243+
self.builder.add_protocol(protocol);
244+
}
245+
246+
#[cfg(all(debug_assertions, feature = "verify"))]
247+
{
248+
ClassProtocolMethodsBuilder {
249+
builder: self,
250+
protocol,
251+
required_instance_methods: protocol
252+
.map(|p| p.method_descriptions(true))
253+
.unwrap_or_default(),
254+
optional_instance_methods: protocol
255+
.map(|p| p.method_descriptions(false))
256+
.unwrap_or_default(),
257+
registered_instance_methods: HashSet::new(),
258+
required_class_methods: protocol
259+
.map(|p| p.class_method_descriptions(true))
260+
.unwrap_or_default(),
261+
optional_class_methods: protocol
262+
.map(|p| p.class_method_descriptions(false))
263+
.unwrap_or_default(),
264+
registered_class_methods: HashSet::new(),
265+
}
266+
}
267+
268+
#[cfg(not(all(debug_assertions, feature = "verify")))]
269+
{
270+
ClassProtocolMethodsBuilder { builder: self }
271+
}
272+
}
273+
274+
// Addition: This restricts to callee `T`
275+
#[inline]
276+
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
277+
where
278+
F: MethodImplementation<Callee = T>,
279+
{
280+
// SAFETY: Checked by caller
281+
unsafe { self.builder.add_method(sel, func) }
282+
}
283+
284+
#[inline]
285+
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
286+
where
287+
F: MethodImplementation<Callee = AnyClass>,
288+
{
289+
// SAFETY: Checked by caller
290+
unsafe { self.builder.add_class_method(sel, func) }
291+
}
292+
293+
#[inline]
294+
pub fn add_static_ivar<I: IvarType>(&mut self) {
295+
self.builder.add_static_ivar::<I>()
296+
}
297+
298+
#[inline]
299+
pub fn register(self) -> &'static AnyClass {
300+
self.builder.register()
301+
}
302+
}
303+
304+
/// Helper for ensuring that:
305+
/// - Only methods on the protocol are overriden.
306+
/// - TODO: The methods have the correct signature.
307+
/// - All required methods are overridden.
308+
#[derive(Debug)]
309+
pub struct ClassProtocolMethodsBuilder<'a, T: ?Sized> {
310+
builder: &'a mut ClassBuilderHelper<T>,
311+
#[cfg(all(debug_assertions, feature = "verify"))]
312+
protocol: Option<&'static AnyProtocol>,
313+
#[cfg(all(debug_assertions, feature = "verify"))]
314+
required_instance_methods: Vec<MethodDescription>,
315+
#[cfg(all(debug_assertions, feature = "verify"))]
316+
optional_instance_methods: Vec<MethodDescription>,
317+
#[cfg(all(debug_assertions, feature = "verify"))]
318+
registered_instance_methods: HashSet<Sel>,
319+
#[cfg(all(debug_assertions, feature = "verify"))]
320+
required_class_methods: Vec<MethodDescription>,
321+
#[cfg(all(debug_assertions, feature = "verify"))]
322+
optional_class_methods: Vec<MethodDescription>,
323+
#[cfg(all(debug_assertions, feature = "verify"))]
324+
registered_class_methods: HashSet<Sel>,
325+
}
326+
327+
impl<T: ?Sized + ClassType> ClassProtocolMethodsBuilder<'_, T> {
328+
// Addition: This restricts to callee `T`
329+
#[inline]
330+
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
331+
where
332+
F: MethodImplementation<Callee = T>,
333+
{
334+
#[cfg(all(debug_assertions, feature = "verify"))]
335+
if let Some(protocol) = self.protocol {
336+
let _types = self
337+
.required_instance_methods
338+
.iter()
339+
.chain(&self.optional_instance_methods)
340+
.find(|desc| desc.sel == sel)
341+
.map(|desc| desc.types)
342+
.unwrap_or_else(|| {
343+
panic!(
344+
"failed overriding protocol method -[{protocol} {sel}]: method not found"
345+
)
346+
});
347+
}
348+
349+
// SAFETY: Checked by caller
350+
unsafe { self.builder.add_method(sel, func) };
351+
352+
#[cfg(all(debug_assertions, feature = "verify"))]
353+
if !self.registered_instance_methods.insert(sel) {
354+
unreachable!("already added")
355+
}
356+
}
357+
358+
#[inline]
359+
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
360+
where
361+
F: MethodImplementation<Callee = AnyClass>,
362+
{
363+
#[cfg(all(debug_assertions, feature = "verify"))]
364+
if let Some(protocol) = self.protocol {
365+
let _types = self
366+
.required_class_methods
367+
.iter()
368+
.chain(&self.optional_class_methods)
369+
.find(|desc| desc.sel == sel)
370+
.map(|desc| desc.types)
371+
.unwrap_or_else(|| {
372+
panic!(
373+
"failed overriding protocol method +[{protocol} {sel}]: method not found"
374+
)
375+
});
376+
}
377+
378+
// SAFETY: Checked by caller
379+
unsafe { self.builder.add_class_method(sel, func) };
380+
381+
#[cfg(all(debug_assertions, feature = "verify"))]
382+
if !self.registered_class_methods.insert(sel) {
383+
unreachable!("already added")
384+
}
385+
}
386+
387+
#[cfg(all(debug_assertions, feature = "verify"))]
388+
pub fn finish(self) {
389+
let superclass = self.builder.builder.superclass();
390+
391+
if let Some(protocol) = self.protocol {
392+
for desc in &self.required_instance_methods {
393+
if self.registered_instance_methods.contains(&desc.sel) {
394+
continue;
395+
}
396+
397+
// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
398+
if superclass
399+
.and_then(|superclass| superclass.instance_method(desc.sel))
400+
.is_some()
401+
{
402+
continue;
403+
}
404+
405+
panic!(
406+
"must implement required protocol method -[{protocol} {}]",
407+
desc.sel
408+
)
409+
}
410+
}
411+
412+
if let Some(protocol) = self.protocol {
413+
for desc in &self.required_class_methods {
414+
if self.registered_class_methods.contains(&desc.sel) {
415+
continue;
416+
}
417+
418+
// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
419+
if superclass
420+
.and_then(|superclass| superclass.class_method(desc.sel))
421+
.is_some()
422+
{
423+
continue;
424+
}
425+
426+
panic!(
427+
"must implement required protocol method +[{protocol} {}]",
428+
desc.sel
429+
);
430+
}
431+
}
432+
}
433+
434+
#[inline]
435+
#[cfg(not(all(debug_assertions, feature = "verify")))]
436+
pub fn finish(self) {}
437+
}

0 commit comments

Comments
 (0)