diff --git a/training/utils/dist_checkpoint_utils.py b/training/utils/dist_checkpoint_utils.py
index 702d632..5ab45a7 100644
--- a/training/utils/dist_checkpoint_utils.py
+++ b/training/utils/dist_checkpoint_utils.py
@@ -22,43 +22,44 @@ def load_checkpoint(pipe, args):
     try:
         with open(os.path.join(checkpoint_step_path, 'meta.json')) as f:
            meta = json.load(f)
-    except:
-        print('failed to load meta.')
+    except FileNotFoundError:
+        print(f"Checkpoint metadata file not found at {os.path.join(checkpoint_step_path, 'meta.json')}")
+        return  # Or handle appropriately
+    except Exception as e:
+        print(f"Failed to load meta.json: {e}")
+        # Decide if you want to return or raise
 
     pipe.global_step = latest_step
 
+    model_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt')
     try:
         pipe.model.model.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt'
-                ), map_location=torch.device('cpu')
-            )
+            torch.load(model_path, map_location=torch.device('cpu'))
         )
-    except:
-        print('failed to load model params.')
+    except FileNotFoundError:
+        print(f"Model checkpoint file not found: {model_path}")
+    except Exception as e:
+        print(f"Failed to load model params from {model_path}: {e}")
 
+    optimizer_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt')
     try:
         pipe.optimizer.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt'
-                ), map_location=torch.device('cpu')
-            )
+            torch.load(optimizer_path, map_location=torch.device('cpu'))
         )
-    except:
-        print('failed to load optim states.')
+    except FileNotFoundError:
+        print(f"Optimizer checkpoint file not found: {optimizer_path}")
+    except Exception as e:
+        print(f"Failed to load optim states from {optimizer_path}: {e}")
 
+    scheduler_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt')
     try:
         pipe.scheduler.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt'
-                )
-            )
+            torch.load(scheduler_path)
         )
-    except:
-        print('failed to load scheduler states.')
+    except FileNotFoundError:
+        print(f"Scheduler checkpoint file not found: {scheduler_path}")
+    except Exception as e:
+        print(f"Failed to load scheduler states from {scheduler_path}: {e}")
 
 
 def save_checkpoint(pipe, args) -> str:
@@ -109,29 +110,28 @@ def save_stream_dataloader_state_dict(dataloader, pipe, args):
 
     latest_step = pipe.global_step
     checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")
-    os.system(f"mkdir -p {checkpoint_step_path}")
+    os.makedirs(checkpoint_step_path, exist_ok=True)
 
-    torch.save(
-        dataloader.dataset.state_dict(),
-        os.path.join(
-            checkpoint_step_path, f'dataset_state_dict.pt'
+    dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+    try:
+        torch.save(
+            dataloader.dataset.state_dict(),
+            dataset_state_dict_path
         )
-    )
+    except Exception as e:
+        print(f"Failed to save dataset state_dict to {dataset_state_dict_path}: {e}")
 
 
 def load_stream_dataloader_state_dict(dataloader, pipe, args):
 
     latest_step = pipe.global_step
-    checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")
+    # checkpoint_step_path is already defined in load_checkpoint, but if this function can be called independently:
+    checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")
+    dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
     try:
-        state_dict = torch.load(
-            os.path.join(
-                checkpoint_step_path, f'dataset_state_dict.pt'
-            )
-        )
-
-        dataloader.data.load_state_dict(state_dict)
-
+        state_dict = torch.load(dataset_state_dict_path)
+        dataloader.dataset.load_state_dict(state_dict)  # Corrected: dataloader.dataset.load_state_dict
+    except FileNotFoundError:
+        print(f"Dataset state_dict file not found: {dataset_state_dict_path}")
     except Exception as e:
-
-        print('failed to load dataset state_dict.')
\ No newline at end of file
+        print(f"Failed to load dataset state_dict from {dataset_state_dict_path}: {e}")
\ No newline at end of file
diff --git a/training/utils/logging_utils.py b/training/utils/logging_utils.py
index fc7420d..3c5dfed 100644
--- a/training/utils/logging_utils.py
+++ b/training/utils/logging_utils.py
@@ -24,8 +24,9 @@ def init_train_logger(args):
     if train_log_backend == 'print':
         pass
     elif train_log_backend == 'loguru':
-        os.system("mkdir -p logs")
-        loguru.logger.add("logs/file_{time}.log")
+        log_file_path = getattr(args, 'log_file_path', "logs/file_{time}.log")
+        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+        loguru.logger.add(log_file_path)
     elif train_log_backend == 'wandb':
         assert _has_wandb
@@ -34,11 +35,13 @@ def init_train_logger(args):
             import re
             args.project_name = "test-" + \
                 re.sub('[^a-zA-Z0-9 \n\.]', '_', args.task_name)
-
-        wandb.init(
-            project=args.project_name,
-            config=args,
-        )
+        try:
+            wandb.init(
+                project=args.project_name,
+                config=args,
+            )
+        except Exception as e:
+            print(f"Error initializing wandb: {e}")
     else:
         raise Exception('Unknown logging backend.')
@@ -52,6 +55,4 @@ def train_log(x, *args, **kargs):
     elif train_log_backend == 'wandb':
         wandb.log(x, *args, **kargs)
     else:
-        raise Exception('Unknown logging backend.')
-
-
\ No newline at end of file
+        raise Exception('Unknown logging backend.')
\ No newline at end of file
diff --git a/training/utils/test_dist_checkpoint_utils.py b/training/utils/test_dist_checkpoint_utils.py
new file mode 100644
index 0000000..48d653a
--- /dev/null
+++ b/training/utils/test_dist_checkpoint_utils.py
@@ -0,0 +1,230 @@
+import unittest
+from unittest.mock import patch, MagicMock, mock_open, call
+import os
+import shutil
+import json
+import torch  # Keep torch for torch.device and potentially other utilities
+import sys
+
+# Import the functions to be tested
+from training.utils.dist_checkpoint_utils import (
+    load_checkpoint,
+    save_checkpoint,
+    save_stream_dataloader_state_dict,
+    load_stream_dataloader_state_dict
+)
+
+# Mock get_pipeline_parallel_rank as it's used to construct file names
+@patch('training.utils.dist_checkpoint_utils.get_pipeline_parallel_rank', MagicMock(return_value=0))
+class TestDistCheckpointUtils(unittest.TestCase):
+
+    def setUp(self):
+        self.test_checkpoint_dir = "test_checkpoints"
+        # Ensure a clean state for each test
+        if os.path.exists(self.test_checkpoint_dir):
+            shutil.rmtree(self.test_checkpoint_dir)
+        os.makedirs(self.test_checkpoint_dir, exist_ok=True)
+
+        self.args = MagicMock()
+        self.args.checkpoint_path = self.test_checkpoint_dir
+
+        self.pipe = MagicMock()
+        self.pipe.global_step = 10
+        self.pipe.model = MagicMock()
+        self.pipe.model.model = MagicMock()  # Mocking model.model attribute for state_dict
+        self.pipe.optimizer = MagicMock()
+        self.pipe.optimizer.state_dict = MagicMock(return_value={"opt_param": 1})
+        self.pipe.scheduler = MagicMock()
+        self.pipe.scheduler.state_dict = MagicMock(return_value={"sched_param": 1})
+        self.pipe.model.model.state_dict = MagicMock(return_value={"model_param": 1})
+
+        # Suppress print statements from the module
+        self.original_stdout = sys.stdout
+        sys.stdout = MagicMock()
+
+    def tearDown(self):
+        if os.path.exists(self.test_checkpoint_dir):
+            shutil.rmtree(self.test_checkpoint_dir)
+        sys.stdout = self.original_stdout  # Restore stdout
+
+    def _create_dummy_checkpoint_files(self, step, create_meta=True, create_model=True, create_optimizer=True, create_scheduler=True):
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
+        os.makedirs(checkpoint_step_path, exist_ok=True)
+
+        if create_meta:
+            with open(os.path.join(checkpoint_step_path, 'meta.json'), 'w') as f:
+                json.dump({'step': step}, f)
+
+        if create_model:
+            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt'))
+        if create_optimizer:
+            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_optimizer.pt'))
+        if create_scheduler:
+            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_scheduler.pt'))
+
+        with open(os.path.join(self.test_checkpoint_dir, 'latest'), 'w') as f:
+            f.write(str(step))
+        return checkpoint_step_path
+
+    @patch('torch.load')
+    def test_load_checkpoint_success(self, mock_torch_load):
+        step = 10
+        self._create_dummy_checkpoint_files(step)
+        mock_torch_load.return_value = {"dummy_state": "value"}
+
+        load_checkpoint(self.pipe, self.args)
+
+        self.assertEqual(self.pipe.global_step, step)
+        self.pipe.model.model.load_state_dict.assert_called_once_with({"dummy_state": "value"})
+        self.pipe.optimizer.load_state_dict.assert_called_once_with({"dummy_state": "value"})
+        self.pipe.scheduler.load_state_dict.assert_called_once_with({"dummy_state": "value"})
+
+    def test_load_checkpoint_no_latest_file(self):
+        # No 'latest' file
+        load_checkpoint(self.pipe, self.args)
+        self.assertTrue(any("no checkpoint available, skipping" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    def test_load_checkpoint_meta_file_not_found(self):
+        step = 5
+        self._create_dummy_checkpoint_files(step, create_meta=False)
+
+        load_checkpoint(self.pipe, self.args)
+        expected_msg = f"Checkpoint metadata file not found at {os.path.join(self.test_checkpoint_dir, f'checkpoint_{step}', 'meta.json')}"
+        self.assertTrue(any(expected_msg in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('torch.load', side_effect=FileNotFoundError("File not found"))
+    def test_load_checkpoint_model_file_not_found(self, mock_torch_load):
+        step = 15
+        self._create_dummy_checkpoint_files(step, create_model=False)  # Model file won't be there but mock matters more
+
+        # We need meta.json to proceed to loading attempts
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
+        with open(os.path.join(checkpoint_step_path, 'meta.json'), 'w') as f:
+            json.dump({'step': step}, f)
+        with open(os.path.join(self.test_checkpoint_dir, 'latest'), 'w') as f:
+            f.write(str(step))
+
+        load_checkpoint(self.pipe, self.args)
+        model_path = os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt')
+        self.assertTrue(any(f"Model checkpoint file not found: {model_path}" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('torch.load', side_effect=RuntimeError("Torch load error"))
+    def test_load_checkpoint_torch_load_generic_error(self, mock_torch_load):
+        step = 20
+        self._create_dummy_checkpoint_files(step)  # All files are there
+
+        load_checkpoint(self.pipe, self.args)
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
+        model_path = os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt')
+        self.assertTrue(any(f"Failed to load model params from {model_path}: Torch load error" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('torch.save')
+    @patch('os.makedirs')
+    def test_save_checkpoint_directory_creation_and_save(self, mock_os_makedirs, mock_torch_save):
+        self.pipe.global_step = 25
+
+        save_checkpoint(self.pipe, self.args)
+
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_25")
+        mock_os_makedirs.assert_called_once_with(checkpoint_step_path, exist_ok=True)
+
+        self.assertEqual(mock_torch_save.call_count, 3)  # model, optimizer, scheduler
+
+        # Check meta.json content
+        meta_path = os.path.join(checkpoint_step_path, 'meta.json')
+        self.assertTrue(os.path.exists(meta_path))
+        with open(meta_path, 'r') as f:
+            meta = json.load(f)
+        self.assertEqual(meta['step'], 25)
+
+        # Check latest file content
+        latest_path = os.path.join(self.test_checkpoint_dir, 'latest')
+        self.assertTrue(os.path.exists(latest_path))
+        with open(latest_path, 'r') as f:
+            latest_step_str = f.read()
+        self.assertEqual(latest_step_str, "25")
+
+    @patch('torch.save')
+    @patch('os.makedirs')
+    def test_save_stream_dataloader_state_dict_creation_and_save(self, mock_os_makedirs, mock_torch_save):
+        self.pipe.global_step = 30
+        mock_dataloader = MagicMock()
+        mock_dataloader.dataset.state_dict.return_value = {"dataset_state": "some_state"}
+
+        save_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)
+
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_30")
+        mock_os_makedirs.assert_called_once_with(checkpoint_step_path, exist_ok=True)
+
+        dataset_state_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+        mock_torch_save.assert_called_once_with({"dataset_state": "some_state"}, dataset_state_path)
+
+    @patch('torch.save', side_effect=Exception("Failed to save dataset"))
+    @patch('os.makedirs')
+    def test_save_stream_dataloader_state_dict_save_error(self, mock_os_makedirs, mock_torch_save):
+        self.pipe.global_step = 35
+        mock_dataloader = MagicMock()
+        mock_dataloader.dataset.state_dict.return_value = {"dataset_state": "some_state"}
+
+        save_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)
+
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_35")
+        dataset_state_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+        self.assertTrue(any(f"Failed to save dataset state_dict to {dataset_state_path}: Failed to save dataset" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('torch.load')
+    def test_load_stream_dataloader_state_dict_success(self, mock_torch_load):
+        self.pipe.global_step = 40
+        # We need to ensure the checkpoint directory for this step exists for path construction
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
+        os.makedirs(checkpoint_step_path, exist_ok=True)
+        # Dummy file for torch.load to be "successful"
+        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+        with open(dataset_state_dict_path, 'w') as f: f.write("dummy data")
+
+        mock_dataloader = MagicMock()
+        mock_dataloader.dataset = MagicMock()  # Ensure dataset attribute exists
+
+        mock_torch_load.return_value = {"loaded_state": "value"}
+
+        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)
+
+        mock_torch_load.assert_called_once_with(dataset_state_dict_path)
+        mock_dataloader.dataset.load_state_dict.assert_called_once_with({"loaded_state": "value"})
+
+    @patch('torch.load', side_effect=FileNotFoundError("Dataset state file not found"))
+    def test_load_stream_dataloader_state_dict_file_not_found(self, mock_torch_load):
+        self.pipe.global_step = 45
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
+        # No need to create the dummy file, as we are testing FileNotFoundError
+
+        mock_dataloader = MagicMock()
+        mock_dataloader.dataset = MagicMock()
+
+        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)
+
+        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+        self.assertTrue(any(f"Dataset state_dict file not found: {dataset_state_dict_path}" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('torch.load', side_effect=RuntimeError("Torch load error for dataset"))
+    def test_load_stream_dataloader_state_dict_generic_error(self, mock_torch_load):
+        self.pipe.global_step = 50
+        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
+        os.makedirs(checkpoint_step_path, exist_ok=True)
+        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+        with open(dataset_state_dict_path, 'w') as f: f.write("dummy data")  # File exists
+
+        mock_dataloader = MagicMock()
+        mock_dataloader.dataset = MagicMock()
+
+        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)
+
+        self.assertTrue(any(f"Failed to load dataset state_dict from {dataset_state_dict_path}: Torch load error for dataset" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/training/utils/test_logging_utils.py b/training/utils/test_logging_utils.py
new file mode 100644
index 0000000..c0fdb32
--- /dev/null
+++ b/training/utils/test_logging_utils.py
@@ -0,0 +1,187 @@
+import unittest
+from unittest.mock import patch, MagicMock, mock_open
+import os
+import shutil
+import sys
+
+# Import the functions to be tested
+from training.utils.logging_utils import init_train_logger, train_log, _has_wandb, _has_loguru
+
+# Mock a simple args class
+class MockArgs:
+    def __init__(self, train_log_backend=None, project_name=None, task_name=None, log_file_path=None):
+        self.train_log_backend = train_log_backend
+        self.project_name = project_name
+        self.task_name = task_name
+        if log_file_path is not None:
+            self.log_file_path = log_file_path
+
+class TestLoggingUtils(unittest.TestCase):
+
+    def setUp(self):
+        # This setup can be used to clean up created directories if any,
+        # but for mocking, it's often not strictly necessary.
+        self.test_logs_dir = "test_logs_dir"
+        # Ensure no leftover directories from previous failed tests
+        if os.path.exists(self.test_logs_dir):
+            shutil.rmtree(self.test_logs_dir)
+        if os.path.exists("logs"):
+            shutil.rmtree("logs")
+
+        # It's good practice to store original values and restore them
+        self.original_stdout = sys.stdout
+        self.original_stderr = sys.stderr
+        sys.stdout = MagicMock()  # Suppress print statements
+        sys.stderr = MagicMock()
+
+    def tearDown(self):
+        # Clean up any created directories
+        if os.path.exists(self.test_logs_dir):
+            shutil.rmtree(self.test_logs_dir)
+        if os.path.exists("logs"):
+            shutil.rmtree("logs")
+        sys.stdout = self.original_stdout  # Restore stdout
+        sys.stderr = self.original_stderr  # Restore stderr
+
+    @patch('training.utils.logging_utils.wandb')
+    def test_wandb_init_error_handling(self, mock_wandb):
+        if not _has_wandb:
+            self.skipTest("wandb is not installed, skipping test")
+
+        mock_wandb.init.side_effect = Exception("Wandb init failed")
+        args = MockArgs(train_log_backend='wandb', project_name='test_project', task_name='test_task')
+
+        init_train_logger(args)
+
+        mock_wandb.init.assert_called_once_with(project='test_project', config=args)
+        # Check if the error message was printed (to our mocked sys.stdout or a logger if that's how it's handled)
+        # Here we assume it prints to stdout/stderr. If it uses logging module, that needs to be mocked.
+        self.assertTrue(any("Error initializing wandb: Wandb init failed" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))
+
+    @patch('training.utils.logging_utils.loguru.logger')
+    @patch('training.utils.logging_utils.os.makedirs')
+    def test_loguru_configurable_log_file_path(self, mock_makedirs, mock_loguru_logger):
+        if not _has_loguru:
+            self.skipTest("loguru is not installed, skipping test")
+
+        custom_path = os.path.join(self.test_logs_dir, "custom_log_file_{time}.log")
+        args = MockArgs(train_log_backend='loguru', log_file_path=custom_path)
+
+        init_train_logger(args)
+
+        expected_dir = os.path.dirname(custom_path)
+        mock_makedirs.assert_called_once_with(expected_dir, exist_ok=True)
+        mock_loguru_logger.add.assert_called_once_with(custom_path)
+
+    @patch('training.utils.logging_utils.loguru.logger')
+    @patch('training.utils.logging_utils.os.makedirs')
+    def test_loguru_default_log_file_path(self, mock_makedirs, mock_loguru_logger):
+        if not _has_loguru:
+            self.skipTest("loguru is not installed, skipping test")
+
+        args = MockArgs(train_log_backend='loguru')  # No log_file_path provided
+
+        init_train_logger(args)
+
+        # Default path is "logs/file_{time}.log"
+        # The directory part is "logs"
+        mock_makedirs.assert_called_once_with("logs", exist_ok=True)
+        mock_loguru_logger.add.assert_called_once_with("logs/file_{time}.log")
+
+    @patch('training.utils.logging_utils.loguru.logger')
+    @patch('training.utils.logging_utils.os.makedirs')
+    def test_loguru_directory_creation(self, mock_makedirs, mock_loguru_logger):
+        if not _has_loguru:
+            self.skipTest("loguru is not installed, skipping test")
+
+        custom_log_dir = "my_custom_logs"
+        custom_path = os.path.join(custom_log_dir, "another_log_{time}.log")
+        args = MockArgs(train_log_backend='loguru', log_file_path=custom_path)
+
+        init_train_logger(args)
+
+        mock_makedirs.assert_called_once_with(custom_log_dir, exist_ok=True)
+        mock_loguru_logger.add.assert_called_once_with(custom_path)
+
+        # Clean up the directory if it was actually created by mistake (mock should prevent this)
+        if os.path.exists(custom_log_dir):
+            shutil.rmtree(custom_log_dir)
+
+    @patch('training.utils.logging_utils.wandb')
+    def test_train_log_wandb_called(self, mock_wandb):
+        if not _has_wandb:
+            self.skipTest("wandb is not installed, skipping test")
+
+        args = MockArgs(train_log_backend='wandb', project_name='test_p', task_name='test_t')
+        init_train_logger(args)  # Initialize backend
+
+        log_data = {"accuracy": 0.95}
+        train_log(log_data, step=100)
+
+        mock_wandb.log.assert_called_once_with(log_data, step=100)
+
+    @patch('training.utils.logging_utils.loguru.logger')
+    @patch('training.utils.logging_utils.os.makedirs')  # Keep makedirs mocked
+    def test_train_log_loguru_called(self, mock_makedirs, mock_loguru_logger):
+        if not _has_loguru:
+            self.skipTest("loguru is not installed, skipping test")
+
+        args = MockArgs(train_log_backend='loguru')
+        init_train_logger(args)  # Initialize backend
+
+        log_message = "Epoch 5 completed."
+        train_log(log_message)
+
+        mock_loguru_logger.info.assert_called_once_with(log_message)
+
+    @patch('builtins.print')  # Mock print for 'print' backend
+    def test_train_log_print_called(self, mock_print):
+        args = MockArgs(train_log_backend='print')
+        init_train_logger(args)  # Initialize backend
+
+        log_message = "Test print message."
+        train_log(log_message)
+
+        mock_print.assert_called_once_with(log_message)
+
+    def test_unknown_backend_init(self):
+        args = MockArgs(train_log_backend='unknown_backend')
+        with self.assertRaisesRegex(Exception, 'Unknown logging backend.'):
+            init_train_logger(args)
+
+    def test_unknown_backend_log(self):
+        # First init with a valid backend, then try to force an unknown one for train_log
+        # This is a bit artificial as init_train_logger sets a global.
+        # A better way would be to directly set train_log_backend global for test if possible.
+        args = MockArgs(train_log_backend='print')
+        init_train_logger(args)
+
+        # Directly manipulate the global for testing this edge case
+        import training.utils.logging_utils as lu
+        original_backend = lu.train_log_backend
+        lu.train_log_backend = 'another_unknown_backend'
+
+        with self.assertRaisesRegex(Exception, 'Unknown logging backend.'):
+            train_log("test message")
+
+        lu.train_log_backend = original_backend  # Restore
+
+
+if __name__ == '__main__':
+    # This allows running the tests directly from the script
+    # For integration with a test runner, this might not be necessary
+    # or might be handled differently.
+    loader = unittest.TestLoader()
+    suite = loader.loadTestsFromTestCase(TestLoggingUtils)
+    runner = unittest.TextTestRunner()
+    runner.run(suite)
+
+    # If wandb or loguru are not installed, their tests will be skipped.
+    # We can print a notice.
+    if not _has_wandb:
+        print("\nSkipped wandb tests as wandb is not installed.")
+    if not _has_loguru:
+        print("\nSkipped loguru tests as loguru is not installed.")
\ No newline at end of file