extend mergekit to make it work on xpu #580

Merged · 13 commits · Jun 6, 2025
40 changes: 40 additions & 0 deletions mergekit/common.py
@@ -381,3 +381,43 @@ def get_auto_cls(arch_name: str) -> AutoClassProtocol:
)
auto_cls = transformers.AutoModelForCausalLM
return auto_cls


def get_torch_accelerator_module(accelerator_name: Optional[str] = None):
if accelerator_name is not None:
accelerator_type = torch.device(accelerator_name).type
return getattr(torch, accelerator_type)
else:
return (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)


def get_torch_accelerator_count(accelerator_name: Optional[str] = None):
torch_accelerator_module = torch.cuda
if accelerator_name is not None:
accelerator = torch.device(accelerator_name)
        # an explicit device index in `accelerator_name` pins us to exactly one device
if accelerator.index != None:
return 1
torch_accelerator_module = getattr(torch, accelerator.type)
else:
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
return torch_accelerator_module.device_count()


def get_torch_accelerator_type(accelerator_name: Optional[str] = None):
if accelerator_name is not None:
return torch.device(accelerator_name).type
else:
return (
torch.accelerator.current_accelerator().type
if hasattr(torch, "accelerator")
else "cuda"
)
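
Taken together, these helpers give the rest of the codebase a backend-agnostic view of whatever accelerator torch can see. A minimal usage sketch (outputs are illustrative; the xpu calls assume an XPU-enabled torch build):

from mergekit.common import (
    get_torch_accelerator_count,
    get_torch_accelerator_module,
    get_torch_accelerator_type,
)

# With no argument, detection goes through torch.accelerator on newer
# torch releases and falls back to torch.cuda on older ones.
print(get_torch_accelerator_type())    # e.g. "cuda" or "xpu"
print(get_torch_accelerator_module())  # e.g. <module 'torch.xpu'>
print(get_torch_accelerator_count())   # number of visible devices

# An explicit device string short-circuits detection, and a device
# index pins the count to exactly one device.
print(get_torch_accelerator_type("xpu:0"))   # "xpu"
print(get_torch_accelerator_count("xpu:0"))  # 1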
28 changes: 21 additions & 7 deletions mergekit/evo/actors.py
@@ -26,6 +26,7 @@


from mergekit.architecture import arch_info_for_config
from mergekit.common import get_torch_accelerator_module, get_torch_accelerator_type
from mergekit.config import MergeConfiguration
from mergekit.evo.config import EvolMergeConfiguration
from mergekit.evo.genome import InvalidGenotypeError, ModelGenome
@@ -90,7 +91,10 @@ def evaluate_genotype(
genotype: torch.Tensor,
) -> dict:
gc.collect()
torch.cuda.empty_cache()
torch_accelerator_module = get_torch_accelerator_module(
self.merge_options.device
)
torch_accelerator_module.empty_cache()
LOG.info("Merging model")
merged_path = merge_model(
genotype, self.genome, self.model_storage_path, self.merge_options
@@ -190,7 +194,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
**model_kwargs,
)
.bfloat16()
.cuda()
.to(self.merge_options.device)
.eval()
.requires_grad_(False)
)
@@ -226,16 +230,17 @@ def _maybe_init_model(self, config: MergeConfiguration):
max_model_len = 8192
LOG.warning(f"Clipping sequence length to {max_model_len}")

accelerator_type = get_torch_accelerator_type(self.merge_options.device)
mem_util = (
0.7 if self.merge_options.cuda else 0.9
) # reduce memory usage if we're also using cuda for the merge
0.7 if accelerator_type in ["cuda", "xpu"] else 0.9
) # reduce memory usage if we're also using accelerator for the merge
self.model = lm_eval.models.vllm_causallms.VLLM(
pretrained=tempdir,
batch_size=self.batch_size or "auto",
max_model_len=max_model_len,
gpu_memory_utilization=mem_util,
dtype="bfloat16",
device="cuda",
device=self.merge_options.device,
trust_remote_code=self.merge_options.trust_remote_code,
)
else:
@@ -292,10 +297,19 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
".up_proj.": (".gate_up_proj.", 1),
}

accelerator_type = get_torch_accelerator_type(self.merge_options.device)
executor = Executor(
tasks,
math_device="cuda" if self.merge_options.cuda else "cpu",
storage_device="cuda" if self.merge_options.cuda else "cpu",
math_device=(
self.merge_options.device
if accelerator_type in ["cuda", "xpu"]
else "cpu"
),
storage_device=(
self.merge_options.device
if accelerator_type in ["cuda", "xpu"]
else "cpu"
),
)
for tensor_task, value in executor.run(quiet=True):
assert isinstance(tensor_task, ReturnTensor)
9 changes: 7 additions & 2 deletions mergekit/evo/strategy.py
@@ -15,6 +15,7 @@
import torch
import transformers

from mergekit.common import get_torch_accelerator_count
from mergekit.evo.actors import InMemoryMergeEvaluator, OnDiskMergeEvaluator
from mergekit.evo.config import EvolMergeConfiguration
from mergekit.evo.genome import ModelGenome
@@ -37,7 +38,9 @@ def __init__(
self.config = config
self.genome = genome
self.merge_options = merge_options
self.num_gpus = num_gpus or torch.cuda.device_count()
self.num_gpus = num_gpus or get_torch_accelerator_count(
self.merge_options.device
)
self.batch_size = batch_size
self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path)
self.model_storage_path = model_storage_path
@@ -118,7 +121,9 @@ def __init__(
self.genome = genome
self.merge_options = merge_options
self.vllm = vllm
self.num_gpus = num_gpus or torch.cuda.device_count()
self.num_gpus = num_gpus or get_torch_accelerator_count(
self.merge_options.device
)
self.input_queue = []
self.batch_size = batch_size
self.task_manager = task_manager
2 changes: 1 addition & 1 deletion mergekit/graph.py
@@ -529,7 +529,7 @@ def _move_tensors(
self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
) -> Any:
if non_blocking is None:
non_blocking = device.type == "cuda"
non_blocking = device.type in ["cuda", "xpu"]
Review comment (Collaborator): 👍
if isinstance(value, torch.Tensor):
if value.device == device:
return value
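The new default can be exercised on its own; this small sketch mirrors the updated lines (the move helper is hypothetical, not mergekit API):

import torch

def move(value: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Async copies only pay off on stream-backed backends, which now
    # include xpu alongside cuda.
    non_blocking = device.type in ("cuda", "xpu")
    return value.to(device=device, non_blocking=non_blocking)

x = torch.randn(4, 4)
print(move(x, torch.device("cpu")).device)  # cpu transfers stay blocking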
4 changes: 2 additions & 2 deletions mergekit/merge.py
@@ -77,8 +77,8 @@ def run_merge(
else:
exec = Executor(
targets=targets,
math_device="cuda" if options.cuda else "cpu",
storage_device="cuda" if options.low_cpu_memory else "cpu",
math_device=options.device,
storage_device=options.device if options.low_cpu_memory else "cpu",
)

tokenizer = None
31 changes: 23 additions & 8 deletions mergekit/multigpu_executor.py
@@ -21,6 +21,11 @@
import torch
import tqdm

from .common import (
get_torch_accelerator_count,
get_torch_accelerator_module,
get_torch_accelerator_type,
)
from .graph import (
Executor,
Task,
@@ -75,9 +80,10 @@ def __init__(
self.results: Dict[TaskHandle, Any] = {}
self.storage_device = storage_device

self.accelerator_type = get_torch_accelerator_type()
if num_gpus is None:
num_gpus = torch.cuda.device_count()
LOG.info(f"Using {num_gpus} GPUs for parallel execution")
num_gpus = get_torch_accelerator_count()
LOG.info(f"Using {num_gpus} {self.accelerator_type} for parallel execution")

self.universe = TaskUniverse(targets)
self.targets = set([self.universe.get_handle(t) for t in targets])
@@ -309,12 +315,14 @@ def _assign_islands_to_gpus(
continue
# don't need to sort, inner executor will handle
island_tasks = [TaskHandle(self.universe, idx) for idx in island]
# assign to GPU with fewest tasks (load balancing)
# assign to accelerator with fewest tasks (load balancing)
device_idx = min(
range(num_gpus),
key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])),
key=lambda i: len(
assignments.get(torch.device(f"{self.accelerator_type}:{i}"), [])
),
)
device = torch.device(f"cuda:{device_idx}")
device = torch.device(f"{self.accelerator_type}:{device_idx}")
assignments[device] = assignments.get(device, []) + island_tasks
return assignments
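
For intuition about the min-by-load selection, here is a toy sketch with hypothetical island names; torch.device objects are plain descriptors, so this runs even without XPU hardware:

import torch

accelerator_type = "xpu"  # assumption: pretend two XPU devices are visible
num_gpus = 2
assignments = {torch.device("xpu:0"): ["island_a", "island_b"]}

# Same rule as above: pick the device whose task list is currently shortest.
device_idx = min(
    range(num_gpus),
    key=lambda i: len(assignments.get(torch.device(f"{accelerator_type}:{i}"), [])),
)
print(torch.device(f"{accelerator_type}:{device_idx}"))  # xpu:1, still empty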

@@ -339,9 +347,16 @@ def _device_worker(
quiet: Whether to suppress progress bar output
"""
LOG.debug(f"Device {device} starting")
torch_accelerator_module = get_torch_accelerator_module(self.accelerator_type)
with torch.device(device):
stream = torch.cuda.Stream(device=device)
with torch.cuda.stream(stream):
stream = (
torch.Stream(device=device)
                if self.accelerator_type == "xpu"
                else torch.cuda.Stream(device=device)
            )

Review thread on torch.Stream(device=device):

Collaborator: Maybe we should just use torch.Stream in all cases?

yao-matrix (Contributor, author) on May 22, 2025: We can. The only concern is that torch.Stream is only available from torch 2.5, while mergekit's torch dependency is >= 2.0 now, so a user who installs torch 2.4 would crash. Please let me know your insights.

Collaborator: Okay, fair! This is good for now then. I'll look at it if I bump the minimum torch version in the future.

with (
stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream)
):
exec = Executor(
targets=task_list,
math_device=device,
@@ -358,5 +373,5 @@
):
result = None
self.task_completion_queue.put((task_handle._index, result))
torch.cuda.synchronize(device=device)
torch_accelerator_module.synchronize(device=device)
LOG.debug(f"Device {device} done")
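
Following up on the review thread above: once the minimum torch version is bumped past 2.5, the branch could collapse into a single version guard. A hedged sketch, assuming the packaging distribution is available; this helper is not part of the PR:

import torch
from packaging.version import Version

def make_stream(device: torch.device):
    # torch.Stream became a device-generic API in torch 2.5; earlier
    # releases only ship the CUDA-specific torch.cuda.Stream.
    if Version(torch.__version__.split("+")[0]) >= Version("2.5"):
        return torch.Stream(device=device)
    return torch.cuda.Stream(device=device)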
37 changes: 33 additions & 4 deletions mergekit/options.py
@@ -21,6 +21,7 @@ class MergeOptions(BaseModel, frozen=True):
lora_merge_cache: Optional[str] = None
lora_merge_dtype: Optional[str] = None
cuda: bool = False
device: Optional[str] = "cpu"
low_cpu_memory: bool = False
out_shard_size: int = parse_kmb("5B")
copy_tokenizer: bool = True
@@ -62,14 +63,37 @@ def handle_gpu_rich(cls, value):
value["multi_gpu"] = True
return value

@model_validator(mode="before")
def handle_device_setting(cls, value):
if not isinstance(value, dict):
return value

# Set device to "cuda" if cuda is True and device is still at default
if value.get("cuda"):
value["device"] = "cuda"

if value.get("device") is None:
value["device"] = "cpu"

# Detect device automatically if `device` is set to "auto"
if value.get("device") == "auto":
if torch.cuda.is_available():
value["device"] = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
value["device"] = "xpu"
else:
value["device"] = "cpu"
return value
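
The validator's precedence (legacy cuda flag first, then an explicit device, then "auto" probing, then the cpu default) can be checked with a quick sketch; the "auto" result depends on what hardware torch detects:

from mergekit.options import MergeOptions

print(MergeOptions(cuda=True).device)      # "cuda": the legacy flag wins
print(MergeOptions(device="auto").device)  # "cuda", "xpu", or "cpu"
print(MergeOptions().device)               # "cpu" default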


OPTION_HELP = {
"allow_crimes": "Allow mixing architectures",
"transformers_cache": "Override storage path for downloaded models",
"lora_merge_cache": "Path to store merged LORA models",
"lora_merge_dtype": "Override dtype when applying LoRAs",
"cuda": "Perform matrix arithmetic on GPU",
"low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
"device": "Perform matrix arithmetic on specified device",
"low_cpu_memory": "Store results and intermediate values on accelerator. Useful if VRAM > RAM",
"out_shard_size": "Number of parameters per output shard [default: 5B]",
"copy_tokenizer": "Copy a tokenizer to the output",
"clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer",
@@ -79,7 +103,7 @@ def handle_gpu_rich(cls, value):
"write_model_card": "Output README.md containing details of the merge",
"safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
"quiet": "Suppress progress bars and other non-essential output",
"read_to_gpu": "Read model weights directly to GPU",
"read_to_gpu": "Read model weights directly to accelerator",
"multi_gpu": "Use multi-gpu parallel graph execution engine",
"num_threads": "Number of threads to use for parallel CPU operations",
"verbosity": "Verbose logging (repeat for more verbosity)",
@@ -96,6 +120,7 @@
"safe_serialization": "Output Settings",
"lazy_unpickle": "Performance",
"cuda": "Performance",
"device": "Performance",
"low_cpu_memory": "Performance",
"read_to_gpu": "Performance",
"multi_gpu": "Performance",
@@ -127,8 +152,12 @@ def wrapper(*args, **kwargs):
if field_name in kwargs:
arg_dict[field_name] = kwargs.pop(field_name)

kwargs["merge_options"] = MergeOptions(**arg_dict)
f(*args, **kwargs)
try:
kwargs["merge_options"] = MergeOptions(**arg_dict)
except Exception as e:
print(f"Error creating MergeOptions with args: {arg_dict}")
raise
return f(*args, **kwargs)

for field_name, info in reversed(MergeOptions.model_fields.items()):
origin = typing.get_origin(info.annotation)
2 changes: 1 addition & 1 deletion mergekit/plan.py
@@ -204,7 +204,7 @@ def plan_tensor(
gather_tensors = GatherTensors(
weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
dtype=self.config.dtype,
device="cuda" if self.options.read_to_gpu else None,
device=self.options.device if self.options.read_to_gpu else None,
)

tensor_input_task = gather_tensors
6 changes: 4 additions & 2 deletions mergekit/scripts/extract_lora.py
@@ -151,8 +151,10 @@ def main(
else:
executor = Executor(
tasks,
math_device="cuda" if merge_options.cuda else "cpu",
storage_device="cuda" if merge_options.low_cpu_memory else "cpu",
math_device=merge_options.device,
storage_device=(
merge_options.device if merge_options.low_cpu_memory else "cpu"
),
)

module_real_ranks = {}
4 changes: 2 additions & 2 deletions mergekit/scripts/merge_raw_pytorch.py
@@ -248,9 +248,9 @@ def main(

executor = Executor(
tasks,
math_device="cuda" if merge_options.cuda else "cpu",
math_device=merge_options.device,
storage_device=(
"cuda" if (merge_options.cuda and merge_options.low_cpu_memory) else "cpu"
merge_options.device if merge_options.low_cpu_memory else "cpu"
),
)
executor.execute()
2 changes: 1 addition & 1 deletion mergekit/scripts/multimerge.py
@@ -117,7 +117,7 @@ def main(

executor = Executor(
tasks, math_device="cpu", storage_device="cpu"
) # inner executors will handle cuda
) # inner executors will handle accelerator
executor.execute(desc="Merging models")


2 changes: 1 addition & 1 deletion mergekit/scripts/tokensurgeon.py
@@ -88,7 +88,7 @@ def main(
cache = LoaderCache()
cache.setup(options=merge_options)

device = "cuda" if merge_options.cuda else "cpu"
device = merge_options.device

arch_info, donor_cfg = validate_architecture(model, donor, merge_options)
embed_info, lm_head_info = get_embedding_info(model, merge_options)