diff --git a/train.py b/train.py index de57850510..4c68f9af16 100644 --- a/train.py +++ b/train.py @@ -69,7 +69,7 @@ # DDP settings backend = 'nccl' # 'nccl', 'gloo', etc. # system -device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks +device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu' dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler compile = True # use PyTorch 2.0 to compile the model to be faster # -----------------------------------------------------------------------------