diff --git a/generalimport/fake_module.py b/generalimport/fake_module.py index ef3f620..37a911f 100644 --- a/generalimport/fake_module.py +++ b/generalimport/fake_module.py @@ -1,5 +1,7 @@ from typing import Optional import sys +from functools import partialmethod +from generalimport import MissingOptionalDependency from generalimport.exception import MissingOptionalDependency, MissingDependencyException @@ -43,6 +45,22 @@ def __getattr__(self, item): fakemodule.error_func() return FakeModule(spec=self.__spec__, trigger=item) + def __mro_entries__(self, *a, **k): + """ + This prevents the creation of subclasses from triggering `generalimport`. + + The classes so generated will trigger generalimport as soon as they're instantiated. + """ + class FakeBaseClass: + + def __new__(fake_cls, *args, **kwargs): + self.error_func("__new__") + + def __init__(fake_self, *args, **kwargs): + self.error_func("__init__") + + return (FakeBaseClass, ) + # Binary __ilshift__ = __invert__ = __irshift__ = __ixor__ = __lshift__ = __rlshift__ = __rrshift__ = __rshift__ = error_func diff --git a/generalimport/test/test_usage/test_object.py b/generalimport/test/test_usage/test_object.py index 7f5f7d0..f8ff715 100644 --- a/generalimport/test/test_usage/test_object.py +++ b/generalimport/test/test_usage/test_object.py @@ -1,26 +1,48 @@ -import sys -from unittest import skip - -import generalimport as gi from generalimport import * from generalimport.test.funcs import ImportTestCase class Test(ImportTestCase): - def test_init_subclass(self): - """ This one is caught by __call__. """ + def test_subclass_class_returning_self(self): generalimport("fakepackage") import fakepackage + class SubClass(fakepackage.BaseClass): + pass + + foo = SubClass.bar # Won't error if SubClass is a FakeModule + with self.assertRaises(MissingDependencyException): - class X(fakepackage): - pass + foo *= 2 + + def test_subclass_module(self): + generalimport("fakepackage") + import fakepackage + class X(fakepackage): + pass + with self.assertRaises(MissingDependencyException): + X() + def test_subclass_class(self): + generalimport("fakepackage") + import fakepackage + class SubClass(fakepackage.BaseClass): + def __init__(self): + raise ValueError("'generalimport' should fail earlier with MissingDependencyException") + self.assertRaises(MissingDependencyException, SubClass) + def test_subclass_class_direct_new_call(self): + generalimport("fakepackage") + import fakepackage + class SubClass(fakepackage.BaseClass): + def __init__(self): + raise ValueError("'generalimport' should fail earlier with MissingDependencyException") + with self.assertRaises(MissingDependencyException): + SubClass.__new__(SubClass)