Skip to content

Commit 755888e

Browse files
jeremymanningclaude
andcommitted
Fix test failures and improve type annotations for mypy compliance
- Fix test_notebook_magic.py: Handle widgets module import when ipywidgets unavailable - Add Optional type annotations for function parameters with None defaults - Fix Path object assignment issues in config.py by using separate variables - Add null checks for optional config fields (module_loads, environment_variables) - Add type annotations for dictionaries (futures, active_jobs) - Add runtime check for LocalExecutor context manager usage - Change cloud provider return types to Dict[str, Any] for mixed data types - Install types-PyYAML and types-paramiko for better type checking - Resolve Magics class redefinition conflicts in notebook_magic.py All tests pass (279/279) and code formatting is clean. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 214e7d3 commit 755888e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+27908
-41
lines changed

clustrix/config.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,24 @@ def __post_init__(self):
8888

8989
def save_to_file(self, config_path: str) -> None:
9090
"""Save this configuration instance to a file."""
91-
config_path = Path(config_path)
91+
config_path_obj = Path(config_path)
9292
config_data = asdict(self)
9393

94-
with open(config_path, "w") as f:
95-
if config_path.suffix.lower() in [".yml", ".yaml"]:
94+
with open(config_path_obj, "w") as f:
95+
if config_path_obj.suffix.lower() in [".yml", ".yaml"]:
9696
yaml.dump(config_data, f, default_flow_style=False)
9797
else:
9898
json.dump(config_data, f, indent=2)
9999

100100
@classmethod
101101
def load_from_file(cls, config_path: str) -> "ClusterConfig":
102102
"""Load configuration from a file and return a new instance."""
103-
config_path = Path(config_path)
104-
if not config_path.exists():
103+
config_path_obj = Path(config_path)
104+
if not config_path_obj.exists():
105105
raise FileNotFoundError(f"Configuration file not found: {config_path}")
106106

107-
with open(config_path, "r") as f:
108-
if config_path.suffix.lower() in [".yml", ".yaml"]:
107+
with open(config_path_obj, "r") as f:
108+
if config_path_obj.suffix.lower() in [".yml", ".yaml"]:
109109
config_data = yaml.safe_load(f)
110110
else:
111111
config_data = json.load(f)
@@ -143,12 +143,12 @@ def load_config(config_path: str) -> None:
143143
"""
144144
global _config
145145

146-
config_path = Path(config_path)
147-
if not config_path.exists():
146+
config_path_obj = Path(config_path)
147+
if not config_path_obj.exists():
148148
raise FileNotFoundError(f"Configuration file not found: {config_path}")
149149

150-
with open(config_path, "r") as f:
151-
if config_path.suffix.lower() in [".yml", ".yaml"]:
150+
with open(config_path_obj, "r") as f:
151+
if config_path_obj.suffix.lower() in [".yml", ".yaml"]:
152152
config_data = yaml.safe_load(f)
153153
else:
154154
config_data = json.load(f)
@@ -163,11 +163,11 @@ def save_config(config_path: str) -> None:
163163
Args:
164164
config_path: Path where to save configuration
165165
"""
166-
config_path = Path(config_path)
166+
config_path_obj = Path(config_path)
167167
config_data = asdict(_config)
168168

169-
with open(config_path, "w") as f:
170-
if config_path.suffix.lower() in [".yml", ".yaml"]:
169+
with open(config_path_obj, "w") as f:
170+
if config_path_obj.suffix.lower() in [".yml", ".yaml"]:
171171
yaml.dump(config_data, f, default_flow_style=False)
172172
else:
173173
json.dump(config_data, f, indent=2)

clustrix/cost_providers/aws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def estimate_batch_cost(
260260

261261
def get_region_pricing_comparison(
262262
self, instance_type: str
263-
) -> Dict[str, Dict[str, float]]:
263+
) -> Dict[str, Dict[str, Any]]:
264264
"""Compare pricing across AWS regions (simplified)."""
265265
# Regional pricing multipliers (approximate)
266266
region_multipliers = {

clustrix/cost_providers/azure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def estimate_batch_cost(
269269

270270
def get_region_pricing_comparison(
271271
self, instance_type: str
272-
) -> Dict[str, Dict[str, float]]:
272+
) -> Dict[str, Dict[str, Any]]:
273273
"""Compare pricing across Azure regions (simplified)."""
274274
# Regional pricing multipliers (approximate)
275275
region_multipliers = {

clustrix/cost_providers/gcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def estimate_sustained_use_discount(self, hours_per_month: float) -> Dict[str, A
259259

260260
def get_region_pricing_comparison(
261261
self, instance_type: str
262-
) -> Dict[str, Dict[str, float]]:
262+
) -> Dict[str, Dict[str, Any]]:
263263
"""Compare pricing across GCP regions (simplified)."""
264264
# Regional pricing multipliers (approximate)
265265
region_multipliers = {

clustrix/cost_providers/lambda_cloud.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
from datetime import datetime
7-
from typing import Dict, List, Any
7+
from typing import Dict, List, Any, Optional
88

99
from ..cost_monitoring import BaseCostMonitor, ResourceUsage, CostEstimate
1010

@@ -128,7 +128,7 @@ def get_pricing_info(self) -> Dict[str, float]:
128128
return self.pricing.copy()
129129

130130
def get_instance_recommendations(
131-
self, resource_usage: ResourceUsage, current_instance: str = None
131+
self, resource_usage: ResourceUsage, current_instance: Optional[str] = None
132132
) -> List[str]:
133133
"""Get instance type recommendations based on current usage."""
134134
recommendations = []
@@ -242,7 +242,7 @@ def get_performance_metrics(self) -> Dict[str, Any]:
242242

243243
def estimate_monthly_cost(
244244
self, instance_type: str, hours_per_day: float = 8
245-
) -> Dict[str, float]:
245+
) -> Dict[str, Any]:
246246
"""Estimate monthly costs for different usage patterns."""
247247
instance_type = instance_type.lower().replace("-", "_")
248248
hourly_rate = self.pricing.get(instance_type, self.pricing["default"])

clustrix/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, config: ClusterConfig):
2020
self.config = config
2121
self.ssh_client = None
2222
self.sftp_client = None
23-
self.active_jobs = {}
23+
self.active_jobs: Dict[str, Any] = {}
2424

2525
# Connection will be established on-demand
2626

clustrix/local_executor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, max_workers: Optional[int] = None, use_threads: bool = False)
2424
max_workers: Maximum number of worker processes/threads
2525
use_threads: If True, use ThreadPoolExecutor, else ProcessPoolExecutor
2626
"""
27-
self.max_workers = max_workers or os.cpu_count()
27+
self.max_workers = max_workers or os.cpu_count() or 4
2828
self.use_threads = use_threads
2929
self._executor = None
3030

@@ -115,10 +115,13 @@ def _execute_parallel_chunks(
115115
timeout: Optional[float],
116116
) -> List[Any]:
117117
"""Execute chunks in parallel using the executor."""
118-
futures = {}
118+
futures: Dict[Any, int] = {}
119119
results = [None] * len(work_chunks)
120120

121121
# Submit all tasks
122+
if self._executor is None:
123+
raise RuntimeError("LocalExecutor must be used as a context manager")
124+
122125
for i, chunk in enumerate(work_chunks):
123126
args = chunk.get("args", ())
124127
kwargs = chunk.get("kwargs", {})
@@ -183,7 +186,7 @@ def execute_loop_parallel(
183186
loop_var: str,
184187
iterable: Union[range, List, tuple],
185188
func_args: tuple = (),
186-
func_kwargs: dict = None,
189+
func_kwargs: Optional[Dict[Any, Any]] = None,
187190
chunk_size: Optional[int] = None,
188191
) -> List[Any]:
189192
"""
@@ -419,7 +422,7 @@ def create_local_executor(
419422
use_threads: Optional[bool] = None,
420423
func: Optional[Callable] = None,
421424
args: tuple = (),
422-
kwargs: dict = None,
425+
kwargs: Optional[Dict[Any, Any]] = None,
423426
) -> LocalExecutor:
424427
"""
425428
Create a LocalExecutor with appropriate settings.

clustrix/loop_analysis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def to_dict(self) -> Dict[str, Any]:
7474
class SafeRangeEvaluator(ast.NodeVisitor):
7575
"""Safely evaluate range expressions without using eval()."""
7676

77-
def __init__(self, local_vars: Dict[str, Any] = None):
77+
def __init__(self, local_vars: Optional[Dict[str, Any]] = None):
7878
self.local_vars = local_vars or {}
7979
self.result = None
8080
self.safe = True
@@ -192,8 +192,8 @@ def has_dependencies(self) -> bool:
192192
class LoopDetector(ast.NodeVisitor):
193193
"""Enhanced loop detection with dependency analysis."""
194194

195-
def __init__(self, local_vars: Dict[str, Any] = None):
196-
self.loops = []
195+
def __init__(self, local_vars: Optional[Dict[str, Any]] = None):
196+
self.loops: List[LoopInfo] = []
197197
self.current_level = 0
198198
self.local_vars = local_vars or {}
199199

@@ -299,7 +299,7 @@ def _analyze_while_loop(self, node) -> Optional[LoopInfo]:
299299

300300

301301
def detect_loops_in_function(
302-
func: Callable, args: tuple = (), kwargs: dict = None
302+
func: Callable, args: tuple = (), kwargs: Optional[Dict[Any, Any]] = None
303303
) -> List[LoopInfo]:
304304
"""
305305
Detect and analyze loops in a function.
@@ -320,7 +320,7 @@ def detect_loops_in_function(
320320
tree = ast.parse(source)
321321

322322
# Build local variables context
323-
local_vars = {}
323+
local_vars: Dict[str, Any] = {}
324324

325325
# Add function arguments to context
326326
try:
@@ -353,7 +353,7 @@ def detect_loops_in_function(
353353
def find_parallelizable_loops(
354354
func: Callable,
355355
args: tuple = (),
356-
kwargs: dict = None,
356+
kwargs: Optional[Dict[Any, Any]] = None,
357357
max_nesting_level: int = 1,
358358
) -> List[LoopInfo]:
359359
"""

clustrix/notebook_magic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
IPYTHON_AVAILABLE = False
2323

2424
# Create placeholder classes for non-notebook environments
25-
class Magics:
25+
class Magics: # type: ignore
2626
pass
2727

2828
def magics_class(cls):

clustrix/utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def setup_remote_environment(
311311
ssh_client,
312312
work_dir: str,
313313
requirements: Dict[str, str],
314-
config: ClusterConfig = None,
314+
config: Optional[ClusterConfig] = None,
315315
):
316316
"""Setup environment on remote cluster via SSH."""
317317

@@ -394,14 +394,17 @@ def _create_slurm_script(
394394
script_lines.append(f"#SBATCH --partition={job_config['partition']}")
395395

396396
# Add environment setup
397-
for module in config.module_loads:
398-
script_lines.append(f"module load {module}")
397+
if config.module_loads:
398+
for module in config.module_loads:
399+
script_lines.append(f"module load {module}")
399400

400-
for var, value in config.environment_variables.items():
401-
script_lines.append(f"export {var}={value}")
401+
if config.environment_variables:
402+
for var, value in config.environment_variables.items():
403+
script_lines.append(f"export {var}={value}")
402404

403-
for cmd in config.pre_execution_commands:
404-
script_lines.append(cmd)
405+
if config.pre_execution_commands:
406+
for cmd in config.pre_execution_commands:
407+
script_lines.append(cmd)
405408

406409
# Add execution commands
407410
script_lines.extend(

0 commit comments

Comments
 (0)