14
14
import sys
15
15
import threading
16
16
import types
17
+ import warnings
17
18
from contextlib import ExitStack , contextmanager
18
19
from typing import (
19
20
Any ,
29
30
Set ,
30
31
Tuple ,
31
32
TypeVar ,
33
+ no_type_check ,
32
34
)
33
35
34
36
from typing_extensions import ParamSpec
@@ -107,6 +109,33 @@ def restrict_built_in(name: str, orig: Any, *args, **kwargs):
107
109
)
108
110
)
109
111
112
+ # Need to unwrap params for isinstance and issubclass. We have
113
+ # chosen to do it this way instead of customize __instancecheck__
114
+ # and __subclasscheck__ because we may have proxied the second
115
+ # parameter which does not have a way to override. It is unfortunate
116
+ # we have to change these globals for everybody.
117
+ def unwrap_second_param (orig : Any , a : Any , b : Any ) -> Any :
118
+ a = RestrictionContext .unwrap_if_proxied (a )
119
+ b = RestrictionContext .unwrap_if_proxied (b )
120
+ return orig (a , b )
121
+
122
+ thread_local_is_inst = _get_thread_local_builtin ("isinstance" )
123
+ self .restricted_builtins .append (
124
+ (
125
+ "isinstance" ,
126
+ thread_local_is_inst ,
127
+ functools .partial (unwrap_second_param , thread_local_is_inst .orig ),
128
+ )
129
+ )
130
+ thread_local_is_sub = _get_thread_local_builtin ("issubclass" )
131
+ self .restricted_builtins .append (
132
+ (
133
+ "issubclass" ,
134
+ thread_local_is_sub ,
135
+ functools .partial (unwrap_second_param , thread_local_is_sub .orig ),
136
+ )
137
+ )
138
+
110
139
@contextmanager
111
140
def applied (self ) -> Iterator [None ]:
112
141
"""Context manager to apply this restrictive import.
@@ -153,17 +182,21 @@ def _import(
153
182
fromlist : Sequence [str ] = (),
154
183
level : int = 0 ,
155
184
) -> types .ModuleType :
185
+ # We have to resolve the full name, it can be relative at different
186
+ # levels
187
+ full_name = _resolve_module_name (name , globals , level )
188
+
156
189
# Check module restrictions and passthrough modules
157
- if name not in sys .modules :
190
+ if full_name not in sys .modules :
158
191
# Make sure not an entirely invalid module
159
- self ._assert_valid_module (name )
192
+ self ._assert_valid_module (full_name )
160
193
161
194
# Check if passthrough
162
- passthrough_mod = self ._maybe_passthrough_module (name )
195
+ passthrough_mod = self ._maybe_passthrough_module (full_name )
163
196
if passthrough_mod :
164
197
# Load all parents. Usually Python does this for us, but not on
165
198
# passthrough.
166
- parent , _ , child = name .rpartition ("." )
199
+ parent , _ , child = full_name .rpartition ("." )
167
200
if parent and parent not in sys .modules :
168
201
_trace (
169
202
"Importing parent module %s before passing through %s" ,
@@ -174,17 +207,17 @@ def _import(
174
207
# Set the passthrough on the parent
175
208
setattr (sys .modules [parent ], child , passthrough_mod )
176
209
# Set the passthrough on sys.modules and on the parent
177
- sys .modules [name ] = passthrough_mod
210
+ sys .modules [full_name ] = passthrough_mod
178
211
# Put it on the parent
179
212
if parent :
180
- setattr (sys .modules [parent ], child , sys .modules [name ])
213
+ setattr (sys .modules [parent ], child , sys .modules [full_name ])
181
214
182
215
# If the module is __temporal_main__ and not already in sys.modules,
183
216
# we load it from whatever file __main__ was originally in
184
- if name == "__temporal_main__" :
217
+ if full_name == "__temporal_main__" :
185
218
orig_mod = _thread_local_sys_modules .orig ["__main__" ]
186
219
new_spec = importlib .util .spec_from_file_location (
187
- name , orig_mod .__file__
220
+ full_name , orig_mod .__file__
188
221
)
189
222
if not new_spec :
190
223
raise ImportError (
@@ -195,7 +228,7 @@ def _import(
195
228
f"Spec for __main__ file at { orig_mod .__file__ } has no loader"
196
229
)
197
230
new_mod = importlib .util .module_from_spec (new_spec )
198
- sys .modules [name ] = new_mod
231
+ sys .modules [full_name ] = new_mod
199
232
new_spec .loader .exec_module (new_mod )
200
233
201
234
mod = importlib .__import__ (name , globals , locals , fromlist , level )
@@ -219,10 +252,20 @@ def _assert_valid_module(self, name: str) -> None:
219
252
raise RestrictedWorkflowAccessError (name )
220
253
221
254
def _maybe_passthrough_module (self , name : str ) -> Optional [types .ModuleType ]:
222
- if not self .restrictions .passthrough_modules .match_access (
223
- self .restriction_context , * name .split ("." )
255
+ # If imports not passed through and name not in passthrough modules,
256
+ # check parents
257
+ if (
258
+ not temporalio .workflow .unsafe .is_imports_passed_through ()
259
+ and name not in self .restrictions .passthrough_modules
224
260
):
225
- return None
261
+ end_dot = - 1
262
+ while True :
263
+ end_dot = name .find ("." , end_dot + 1 )
264
+ if end_dot == - 1 :
265
+ return None
266
+ elif name [:end_dot ] in self .restrictions .passthrough_modules :
267
+ break
268
+ # Do the pass through
226
269
with self ._unapplied ():
227
270
_trace ("Passing module %s through from host" , name )
228
271
global _trace_depth
@@ -409,3 +452,50 @@ def _get_thread_local_builtin(name: str) -> _ThreadLocalCallable:
409
452
ret = _ThreadLocalCallable (getattr (builtins , name ))
410
453
_thread_local_builtins [name ] = ret
411
454
return ret
455
+
456
+
457
+ def _resolve_module_name (
458
+ name : str , globals : Optional [Mapping [str , object ]], level : int
459
+ ) -> str :
460
+ if level == 0 :
461
+ return name
462
+ # Calc the package from globals
463
+ package = _calc___package__ (globals or {})
464
+ # Logic taken from importlib._resolve_name
465
+ bits = package .rsplit ("." , level - 1 )
466
+ if len (bits ) < level :
467
+ raise ImportError ("Attempted relative import beyond top-level package" )
468
+ base = bits [0 ]
469
+ return f"{ base } .{ name } " if name else base
470
+
471
+
472
+ # Copied from importlib._calc__package__
473
+ @no_type_check
474
+ def _calc___package__ (globals : Mapping [str , object ]) -> str :
475
+ """Calculate what __package__ should be.
476
+ __package__ is not guaranteed to be defined or could be set to None
477
+ to represent that its proper value is unknown.
478
+ """
479
+ package = globals .get ("__package__" )
480
+ spec = globals .get ("__spec__" )
481
+ if package is not None :
482
+ if spec is not None and package != spec .parent :
483
+ warnings .warn (
484
+ "__package__ != __spec__.parent " f"({ package !r} != { spec .parent !r} )" ,
485
+ DeprecationWarning ,
486
+ stacklevel = 3 ,
487
+ )
488
+ return package
489
+ elif spec is not None :
490
+ return spec .parent
491
+ else :
492
+ warnings .warn (
493
+ "can't resolve package from __spec__ or __package__, "
494
+ "falling back on __name__ and __path__" ,
495
+ ImportWarning ,
496
+ stacklevel = 3 ,
497
+ )
498
+ package = globals ["__name__" ]
499
+ if "__path__" not in globals :
500
+ package = package .rpartition ("." )[0 ]
501
+ return package
0 commit comments