diff --git a/mergekit/common.py b/mergekit/common.py
index e1ddb76c..4245a699 100644
--- a/mergekit/common.py
+++ b/mergekit/common.py
@@ -381,3 +381,42 @@ def get_auto_cls(arch_name: str) -> AutoClassProtocol:
         )
         auto_cls = transformers.AutoModelForCausalLM
     return auto_cls
+
+
+def get_torch_accelerator_module(accelerator_name: Optional[str] = None):
+    if accelerator_name is not None:
+        accelerator_type = torch.device(accelerator_name).type
+        return getattr(torch, accelerator_type)
+    else:
+        return (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
+
+def get_torch_accelerator_count(accelerator_name: Optional[str] = None):
+    if accelerator_name is not None:
+        accelerator = torch.device(accelerator_name)
+        # an explicit device index (e.g. "cuda:0") selects exactly one device
+        if accelerator.index is not None:
+            return 1
+        torch_accelerator_module = getattr(torch, accelerator.type)
+    else:
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+    return torch_accelerator_module.device_count()
+
+
+def get_torch_accelerator_type(accelerator_name: Optional[str] = None):
+    if accelerator_name is not None:
+        return torch.device(accelerator_name).type
+    else:
+        return (
+            torch.accelerator.current_accelerator().type
+            if hasattr(torch, "accelerator")
+            else "cuda"
+        )
diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py
index 09b205e9..e5f9cd4b 100644
--- a/mergekit/evo/actors.py
+++ b/mergekit/evo/actors.py
@@ -26,6 +26,7 @@ import transformers
 import transformers
 
 from mergekit.architecture import arch_info_for_config
+from mergekit.common import get_torch_accelerator_module, get_torch_accelerator_type
 from mergekit.config import MergeConfiguration
 from mergekit.evo.config import EvolMergeConfiguration
 from mergekit.evo.genome import InvalidGenotypeError, ModelGenome
@@ -90,7 +91,10 @@ def evaluate_genotype(
         genotype: torch.Tensor,
     ) -> dict:
         gc.collect()
-        torch.cuda.empty_cache()
+        torch_accelerator_module = get_torch_accelerator_module(
+            self.merge_options.device
+        )
+        torch_accelerator_module.empty_cache()
         LOG.info("Merging model")
         merged_path = merge_model(
             genotype, self.genome, self.model_storage_path, self.merge_options
@@ -190,7 +194,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                 **model_kwargs,
             )
             .bfloat16()
-            .cuda()
+            .to(self.merge_options.device)
             .eval()
             .requires_grad_(False)
         )
@@ -226,16 +230,17 @@ def _maybe_init_model(self, config: MergeConfiguration):
                 max_model_len = 8192
                 LOG.warning(f"Clipping sequence length to {max_model_len}")
 
+            accelerator_type = get_torch_accelerator_type(self.merge_options.device)
             mem_util = (
-                0.7 if self.merge_options.cuda else 0.9
-            )  # reduce memory usage if we're also using cuda for the merge
+                0.7 if accelerator_type in ["cuda", "xpu"] else 0.9
+            )  # reduce memory usage if we're also using the accelerator for the merge
             self.model = lm_eval.models.vllm_causallms.VLLM(
                 pretrained=tempdir,
                 batch_size=self.batch_size or "auto",
                 max_model_len=max_model_len,
                 gpu_memory_utilization=mem_util,
                 dtype="bfloat16",
-                device="cuda",
+                device=self.merge_options.device,
                 trust_remote_code=self.merge_options.trust_remote_code,
             )
         else:
@@ -292,10 +297,19 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
             ".up_proj.": (".gate_up_proj.", 1),
         }
 
+        accelerator_type = get_torch_accelerator_type(self.merge_options.device)
         executor = Executor(
             tasks,
-            math_device="cuda" if self.merge_options.cuda else "cpu",
-            storage_device="cuda" if self.merge_options.cuda else "cpu",
+            math_device=(
+                self.merge_options.device
+                if accelerator_type in ["cuda", "xpu"]
+                else "cpu"
+            ),
+            storage_device=(
+                self.merge_options.device
+                if accelerator_type in ["cuda", "xpu"]
+                else "cpu"
+            ),
         )
         for tensor_task, value in executor.run(quiet=True):
             assert isinstance(tensor_task, ReturnTensor)
diff --git a/mergekit/evo/strategy.py b/mergekit/evo/strategy.py
index b5f982f1..1ac9c0d9 100644
--- a/mergekit/evo/strategy.py
+++ b/mergekit/evo/strategy.py
@@ -15,6 +15,7 @@ import torch
 import torch
 import transformers
 
+from mergekit.common import get_torch_accelerator_count
 from mergekit.evo.actors import InMemoryMergeEvaluator, OnDiskMergeEvaluator
 from mergekit.evo.config import EvolMergeConfiguration
 from mergekit.evo.genome import ModelGenome
@@ -37,7 +38,9 @@ def __init__(
         self.config = config
         self.genome = genome
         self.merge_options = merge_options
-        self.num_gpus = num_gpus or torch.cuda.device_count()
+        self.num_gpus = num_gpus or get_torch_accelerator_count(
+            self.merge_options.device
+        )
         self.batch_size = batch_size
         self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path)
         self.model_storage_path = model_storage_path
@@ -118,7 +121,9 @@ def __init__(
         self.genome = genome
         self.merge_options = merge_options
         self.vllm = vllm
-        self.num_gpus = num_gpus or torch.cuda.device_count()
+        self.num_gpus = num_gpus or get_torch_accelerator_count(
+            self.merge_options.device
+        )
         self.input_queue = []
         self.batch_size = batch_size
         self.task_manager = task_manager
diff --git a/mergekit/graph.py b/mergekit/graph.py
index bfde9c0b..cbdded1e 100644
--- a/mergekit/graph.py
+++ b/mergekit/graph.py
@@ -529,7 +529,7 @@ def _move_tensors(
         self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
     ) -> Any:
         if non_blocking is None:
-            non_blocking = device.type == "cuda"
+            non_blocking = device.type in ["cuda", "xpu"]
         if isinstance(value, torch.Tensor):
             if value.device == device:
                 return value
diff --git a/mergekit/merge.py b/mergekit/merge.py
index b4535de1..e978939b 100644
--- a/mergekit/merge.py
+++ b/mergekit/merge.py
@@ -77,8 +77,8 @@ def run_merge(
     else:
         exec = Executor(
             targets=targets,
-            math_device="cuda" if options.cuda else "cpu",
-            storage_device="cuda" if options.low_cpu_memory else "cpu",
+            math_device=options.device,
+            storage_device=options.device if options.low_cpu_memory else "cpu",
         )
 
     tokenizer = None
diff --git a/mergekit/multigpu_executor.py b/mergekit/multigpu_executor.py
index 179b1cae..ba965a8e 100644
--- a/mergekit/multigpu_executor.py
+++ b/mergekit/multigpu_executor.py
@@ -21,6 +21,11 @@ import torch
 import torch
 import tqdm
 
+from .common import (
+    get_torch_accelerator_count,
+    get_torch_accelerator_module,
+    get_torch_accelerator_type,
+)
 from .graph import (
     Executor,
     Task,
@@ -75,9 +80,10 @@ def __init__(
         self.results: Dict[TaskHandle, Any] = {}
         self.storage_device = storage_device
 
+        self.accelerator_type = get_torch_accelerator_type()
         if num_gpus is None:
-            num_gpus = torch.cuda.device_count()
-        LOG.info(f"Using {num_gpus} GPUs for parallel execution")
+            num_gpus = get_torch_accelerator_count()
+        LOG.info(f"Using {num_gpus} {self.accelerator_type} device(s) for parallel execution")
 
         self.universe = TaskUniverse(targets)
         self.targets = set([self.universe.get_handle(t) for t in targets])
@@ -309,12 +315,14 @@ def _assign_islands_to_gpus(
                 continue
 
             # don't need to sort, inner executor will handle
             island_tasks = [TaskHandle(self.universe, idx) for idx in island]
-            # assign to GPU with fewest tasks (load balancing)
+            # assign to accelerator with fewest tasks (load balancing)
             device_idx = min(
                 range(num_gpus),
-                key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])),
+                key=lambda i: len(
+                    assignments.get(torch.device(f"{self.accelerator_type}:{i}"), [])
+                ),
             )
-            device = torch.device(f"cuda:{device_idx}")
+            device = torch.device(f"{self.accelerator_type}:{device_idx}")
             assignments[device] = assignments.get(device, []) + island_tasks
         return assignments
@@ -339,9 +347,17 @@ def _device_worker(
             quiet: Whether to suppress progress bar output
         """
         LOG.debug(f"Device {device} starting")
+        torch_accelerator_module = get_torch_accelerator_module(self.accelerator_type)
         with torch.device(device):
-            stream = torch.cuda.Stream(device=device)
-            with torch.cuda.stream(stream):
+            # XPU has no torch.cuda stream API, so use the device-generic torch.Stream
+            stream = (
+                torch.Stream(device=device)
+                if self.accelerator_type == "xpu"
+                else torch.cuda.Stream(device=device)
+            )
+            with (
+                stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream)
+            ):
                 exec = Executor(
                     targets=task_list,
                     math_device=device,
@@ -358,5 +374,5 @@ def _device_worker(
                 ):
                     result = None
                 self.task_completion_queue.put((task_handle._index, result))
-        torch.cuda.synchronize(device=device)
+        torch_accelerator_module.synchronize(device=device)
         LOG.debug(f"Device {device} done")
diff --git a/mergekit/options.py b/mergekit/options.py
index ed9469b5..5b3907c0 100644
--- a/mergekit/options.py
+++ b/mergekit/options.py
@@ -21,6 +21,7 @@ class MergeOptions(BaseModel, frozen=True):
     lora_merge_cache: Optional[str] = None
     lora_merge_dtype: Optional[str] = None
     cuda: bool = False
+    device: Optional[str] = "cpu"
     low_cpu_memory: bool = False
     out_shard_size: int = parse_kmb("5B")
     copy_tokenizer: bool = True
@@ -62,13 +63,36 @@ def handle_gpu_rich(cls, value):
             value["multi_gpu"] = True
         return value
 
+    @model_validator(mode="before")
+    def handle_device_setting(cls, value):
+        if not isinstance(value, dict):
+            return value
+
+        # Set device to "cuda" if cuda is True and device is still at default
+        if value.get("cuda") and value.get("device") in (None, "cpu"):
+            value["device"] = "cuda"
+
+        if value.get("device") is None:
+            value["device"] = "cpu"
+
+        # Detect device automatically if `device` is set to "auto"
+        if value.get("device") == "auto":
+            if torch.cuda.is_available():
+                value["device"] = "cuda"
+            elif hasattr(torch, "xpu") and torch.xpu.is_available():
+                value["device"] = "xpu"
+            else:
+                value["device"] = "cpu"
+        return value
+
 
 OPTION_HELP = {
     "allow_crimes": "Allow mixing architectures",
     "lora_merge_cache": "Path to store merged LORA models",
     "lora_merge_dtype": "Override dtype when applying LoRAs",
     "cuda": "Perform matrix arithmetic on GPU",
-    "low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
+    "device": "Perform matrix arithmetic on specified device",
+    "low_cpu_memory": "Store results and intermediate values on accelerator. Useful if VRAM > RAM",
     "out_shard_size": "Number of parameters per output shard [default: 5B]",
     "copy_tokenizer": "Copy a tokenizer to the output",
     "clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer",
@@ -79,7 +103,7 @@
     "write_model_card": "Output README.md containing details of the merge",
     "safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
     "quiet": "Suppress progress bars and other non-essential output",
-    "read_to_gpu": "Read model weights directly to GPU",
+    "read_to_gpu": "Read model weights directly to accelerator",
     "multi_gpu": "Use multi-gpu parallel graph execution engine",
     "num_threads": "Number of threads to use for parallel CPU operations",
     "verbosity": "Verbose logging (repeat for more verbosity)",
@@ -96,6 +120,7 @@
     "safe_serialization": "Output Settings",
     "lazy_unpickle": "Performance",
     "cuda": "Performance",
+    "device": "Performance",
     "low_cpu_memory": "Performance",
     "read_to_gpu": "Performance",
     "multi_gpu": "Performance",
@@ -127,8 +152,12 @@ def wrapper(*args, **kwargs):
             if field_name in kwargs:
                 arg_dict[field_name] = kwargs.pop(field_name)
 
-        kwargs["merge_options"] = MergeOptions(**arg_dict)
-        f(*args, **kwargs)
+        try:
+            kwargs["merge_options"] = MergeOptions(**arg_dict)
+        except Exception:
+            print(f"Error creating MergeOptions with args: {arg_dict}")
+            raise
+        return f(*args, **kwargs)
 
     for field_name, info in reversed(MergeOptions.model_fields.items()):
         origin = typing.get_origin(info.annotation)
diff --git a/mergekit/plan.py b/mergekit/plan.py
index 973bc69c..830b169a 100644
--- a/mergekit/plan.py
+++ b/mergekit/plan.py
@@ -204,7 +204,7 @@ def plan_tensor(
         gather_tensors = GatherTensors(
             weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
             dtype=self.config.dtype,
-            device="cuda" if self.options.read_to_gpu else None,
+            device=self.options.device if self.options.read_to_gpu else None,
         )
 
         tensor_input_task = gather_tensors
diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py
index 6aa97467..574e12e0 100644
--- a/mergekit/scripts/extract_lora.py
+++ b/mergekit/scripts/extract_lora.py
@@ -151,8 +151,10 @@ def main(
     else:
         executor = Executor(
             tasks,
-            math_device="cuda" if merge_options.cuda else "cpu",
-            storage_device="cuda" if merge_options.low_cpu_memory else "cpu",
+            math_device=merge_options.device,
+            storage_device=(
+                merge_options.device if merge_options.low_cpu_memory else "cpu"
+            ),
         )
 
     module_real_ranks = {}
diff --git a/mergekit/scripts/merge_raw_pytorch.py b/mergekit/scripts/merge_raw_pytorch.py
index 9461c6d6..37a4f92b 100644
--- a/mergekit/scripts/merge_raw_pytorch.py
+++ b/mergekit/scripts/merge_raw_pytorch.py
@@ -248,9 +248,9 @@ def main(
 
     executor = Executor(
         tasks,
-        math_device="cuda" if merge_options.cuda else "cpu",
+        math_device=merge_options.device,
         storage_device=(
-            "cuda" if (merge_options.cuda and merge_options.low_cpu_memory) else "cpu"
+            merge_options.device if merge_options.low_cpu_memory else "cpu"
         ),
     )
     executor.execute()
diff --git a/mergekit/scripts/multimerge.py b/mergekit/scripts/multimerge.py
index 99fd198a..c87348f0 100644
--- a/mergekit/scripts/multimerge.py
+++ b/mergekit/scripts/multimerge.py
@@ -117,7 +117,7 @@ def main(
     executor = Executor(
         tasks, math_device="cpu", storage_device="cpu"
-    )  # inner executors will handle cuda
+    )  # inner executors will handle accelerator placement
     executor.execute(desc="Merging models")
diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py
index 791a0640..04811109 100644
--- a/mergekit/scripts/tokensurgeon.py
+++ b/mergekit/scripts/tokensurgeon.py
@@ -88,7 +88,7 @@ def main(
     cache = LoaderCache()
     cache.setup(options=merge_options)
 
-    device = "cuda" if merge_options.cuda else "cpu"
+    device = merge_options.device
 
     arch_info, donor_cfg = validate_architecture(model, donor, merge_options)
     embed_info, lm_head_info = get_embedding_info(model, merge_options)
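
Reviewer sketch: how the new mergekit.common helpers are expected to resolve devices once this patch is applied. torch.device parsing is purely syntactic, so the checks below hold even on a CPU-only machine; the no-argument path consults torch.accelerator on newer PyTorch builds (and assumes an accelerator is actually present), falling back to torch.cuda otherwise.

    import torch

    from mergekit.common import (
        get_torch_accelerator_count,
        get_torch_accelerator_module,
        get_torch_accelerator_type,
    )

    # Explicit device strings bypass autodetection entirely.
    assert get_torch_accelerator_type("cuda:1") == "cuda"
    assert get_torch_accelerator_count("cuda:1") == 1  # an index selects one device
    assert get_torch_accelerator_module("cuda") is torch.cuda
    # Without an argument, the helpers report the current accelerator via
    # torch.accelerator when it exists, else default to torch.cuda.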
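A companion sketch for the new device option, mirroring the handle_device_setting validator above; the cuda=True mapping assumes the device field was left at its "cpu" default.

    import torch

    from mergekit.options import MergeOptions

    # "auto" resolves to the first available accelerator, else CPU.
    opts = MergeOptions(device="auto")
    expected = (
        "cuda"
        if torch.cuda.is_available()
        else "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    )
    assert opts.device == expected

    # The legacy --cuda flag still works: it upgrades the default "cpu" device.
    assert MergeOptions(cuda=True).device == "cuda"
    assert MergeOptions().device == "cpu"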