|
4 | 4 |
|
5 | 5 | import concurrent.futures
|
6 | 6 | from collections.abc import Callable, Mapping, MutableMapping
|
7 |
| -from dataclasses import dataclass, field |
| 7 | +from dataclasses import dataclass |
8 | 8 | from datetime import timedelta
|
9 | 9 | from typing import (
|
10 | 10 | Any,
|
@@ -301,38 +301,38 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
|
301 | 301 | headers: Optional[Mapping[str, str]]
|
302 | 302 | output_type: Optional[Type[OutputT]] = None
|
303 | 303 |
|
304 |
| - _operation_name: str = field(init=False, repr=False) |
305 |
| - _input_type: Optional[Type[InputT]] = field(init=False, repr=False) |
306 |
| - |
307 | 304 | def __post_init__(self) -> None:
|
308 | 305 | if isinstance(self.operation, nexusrpc.Operation):
|
309 |
| - self._operation_name = self.operation.name |
310 |
| - self._input_type = self.operation.input_type |
311 | 306 | self.output_type = self.operation.output_type
|
312 |
| - elif isinstance(self.operation, str): |
313 |
| - self._operation_name = self.operation |
314 |
| - self._input_type = None |
315 | 307 | elif callable(self.operation):
|
316 | 308 | _, op = temporalio.nexus._util.get_operation_factory(self.operation)
|
317 | 309 | if isinstance(op, nexusrpc.Operation):
|
318 |
| - self._operation_name = op.name |
319 |
| - self._input_type = op.input_type |
320 | 310 | self.output_type = op.output_type
|
321 | 311 | else:
|
322 | 312 | raise ValueError(
|
323 | 313 | f"Operation callable is not a Nexus operation: {self.operation}"
|
324 | 314 | )
|
| 315 | + elif isinstance(self.operation, str): |
| 316 | + pass |
325 | 317 | else:
|
326 | 318 | raise ValueError(f"Operation is not a Nexus operation: {self.operation}")
|
327 | 319 |
|
328 | 320 | @property
|
329 | 321 | def operation_name(self) -> str:
|
330 |
| - return self._operation_name |
331 |
| - |
332 |
| - # TODO(nexus-preview) contravariant type in output |
333 |
| - @property |
334 |
| - def input_type(self) -> Optional[Type[InputT]]: |
335 |
| - return self._input_type |
| 322 | + if isinstance(self.operation, nexusrpc.Operation): |
| 323 | + return self.operation.name |
| 324 | + elif isinstance(self.operation, str): |
| 325 | + return self.operation |
| 326 | + elif callable(self.operation): |
| 327 | + _, op = temporalio.nexus._util.get_operation_factory(self.operation) |
| 328 | + if isinstance(op, nexusrpc.Operation): |
| 329 | + return op.name |
| 330 | + else: |
| 331 | + raise ValueError( |
| 332 | + f"Operation callable is not a Nexus operation: {self.operation}" |
| 333 | + ) |
| 334 | + else: |
| 335 | + raise ValueError(f"Operation is not a Nexus operation: {self.operation}") |
336 | 336 |
|
337 | 337 |
|
338 | 338 | @dataclass
|
|
0 commit comments