Add GPU acceleration for the tabular model using PyTorch MPS (Metal Performance Shaders) on macOS #62

@radurogojanumai

Description

Currently, mostlyai-engine does not leverage GPU acceleration on macOS. While CUDA-based acceleration is available on Linux, macOS users are limited to CPU-bound execution. To improve performance for macOS users, we should add support for PyTorch’s Metal Performance Shaders (MPS) backend, which enables GPU acceleration on Apple Silicon (e.g., M1, M2, M3 and later chips).

Proposed Solution

  1. Enable PyTorch MPS backend detection
  • Check whether the MPS backend is usable at runtime (torch.backends.mps.is_available()); torch.backends.mps.is_built() only reports whether PyTorch was compiled with MPS support.

  • If MPS is available, ensure that models and all input tensors are moved to the MPS device.

  2. Update Installation & Dependencies
  • Ensure torch>=1.13 is installed. MPS support first shipped as a prototype in PyTorch 1.12; requiring 1.13 or newer avoids that earliest release.

  • Add documentation for macOS users on installing PyTorch with MPS support.

  3. Modify Training & Inference Pipelines
  • Adapt existing PyTorch calls to dynamically select the best available backend (mps, cuda, or cpu).

  • Ensure compatibility with QLoRA and bitsandbytes, falling back to CPU where MPS does not support the required operations.

  4. Performance Benchmarking & Validation
  • Compare training/inference speeds using MPS vs. CPU.

  • Identify any limitations or unsupported operations within MPS that may require fallbacks.
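The detection, device-placement and fallback steps in items 1 and 3 could look roughly like the sketch below. It is not mostlyai-engine code: pick_backend and select_device are hypothetical helper names, and the preference order (CUDA, then MPS, then CPU) is an assumption; a given machine typically exposes at most one of the two GPU backends. PYTORCH_ENABLE_MPS_FALLBACK is an existing PyTorch environment variable that lets operators MPS does not implement run on the CPU instead of raising an error. It should be set before torch is imported, so this module would need to be imported first.

```python
import os

# PyTorch reads this variable to let operators that MPS does not implement
# fall back to the CPU instead of raising an error. It should be set before
# torch is imported; setdefault keeps any value the user already exported.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


def pick_backend(cuda_available: bool, mps_available: bool) -> str:
    """Return the best available backend name: cuda, then mps, then cpu."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def select_device():
    """Return a torch.device for the best available backend.

    torch is imported lazily so that pick_backend can be used (and tested)
    without PyTorch installed.
    """
    import torch

    # torch.backends.mps exists from PyTorch 1.12 onwards; getattr guards
    # against older builds that lack it.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_backend(torch.cuda.is_available(), mps_available))
```

A training or generation loop would then call device = select_device() once and apply .to(device) to the model and to every input batch, so Linux + CUDA, Darwin + MPS and CPU-only machines all run the same code path.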

Questions

  1. Should we introduce an [mps] extra so macOS users can explicitly install MPS-related dependencies?
    Answer: We want to keep the set of extras simple (e.g., a single [gpu] extra covering both Linux + CUDA and Darwin + MPS).
  2. How well does bitsandbytes integrate with Darwin?
    Answer: We'll make sure the required bitsandbytes version ships the necessary macOS wheels.

Acceptance Criteria

  • Mac users get MPS acceleration via PyTorch automatically, without having to modify any code.
  • Performance improvements over CPU-only execution are verified.
  • No breaking changes for Linux users.
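To check the performance criterion, a timing harness along these lines could be used. median_runtime, speedup and compare_matmul_cpu_vs_mps are hypothetical helpers, not part of mostlyai-engine. Because MPS kernels run asynchronously, each timed GPU call is followed by torch.mps.synchronize() (available in PyTorch 2.0 and later); without it the measurement would mostly capture kernel launch time. The matmul comparison is only a smoke test; a real validation would time the engine's own training and generation on the same dataset with the CPU and MPS devices.

```python
import statistics
import time
from typing import Callable, Optional


def median_runtime(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 3,
    repeats: int = 10,
) -> float:
    """Return the median wall-clock time of fn() in seconds.

    GPU backends such as MPS and CUDA execute asynchronously, so pass a
    synchronize function (e.g. torch.mps.synchronize) to make each
    measurement include the queued work, not just the kernel launch.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    # Warm-up runs absorb one-off costs such as kernel compilation and allocation.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_seconds: float, candidate_seconds: float) -> float:
    """How many times faster the candidate is than the baseline (>1 means faster)."""
    if candidate_seconds <= 0:
        raise ValueError("candidate_seconds must be positive")
    return baseline_seconds / candidate_seconds


def compare_matmul_cpu_vs_mps(n: int = 2048) -> float:
    """Toy benchmark: speedup of an n x n matmul on MPS over CPU."""
    import torch  # imported lazily; requires PyTorch 2.0+ for torch.mps.synchronize

    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS is not available on this machine")
    a_cpu, b_cpu = torch.randn(n, n), torch.randn(n, n)
    a_mps, b_mps = a_cpu.to("mps"), b_cpu.to("mps")
    cpu_seconds = median_runtime(lambda: a_cpu @ b_cpu)
    mps_seconds = median_runtime(lambda: a_mps @ b_mps, sync=torch.mps.synchronize)
    return speedup(cpu_seconds, mps_seconds)
```

The same median_runtime call can wrap a single training step of the tabular model on each device; reporting the median over several repeats keeps one slow outlier from skewing the CPU vs. MPS comparison.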