Skip to content

GPU support - Design feedback requested #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@ num = "0.4"
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
wgpu = { version = "25.0.2", optional = true }
pollster = { version = "0.4.0", optional = true }
bytemuck = { version = "1.23.0", optional = true }
lazy_static = { version = "1.5.0", optional = true }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }

[features]
# NOTE(review): shipping "gpu" as a default feature pulls wgpu, pollster,
# bytemuck and lazy_static into every downstream build — confirm this is
# intended before merge (design-feedback PR).
default = ["gpu"]
serde = ["dep:serde", "dep:typetag"]
ndarray-bindings = ["dep:ndarray"]
datasets = ["dep:rand_distr", "std_rand", "serde"]
std_rand = ["rand/std_rng", "rand/std"]
# used by wasm32-unknown-unknown for in-browser usage
js = ["getrandom/js"]
# "dep:" prefixes keep the optional crates from becoming implicit public
# features, matching the style of the "serde" feature above.
gpu = ["dep:wgpu", "dep:pollster", "dep:bytemuck", "dep:lazy_static"]

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", optional = true }
Expand Down
37 changes: 37 additions & 0 deletions src/error/gpu_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

use std::fmt;

/// Errors produced by the GPU backend.
#[derive(Debug)]
pub enum GpuError {
    /// No suitable GPU adapter could be created; carries the wgpu error text.
    NoAdapter(String),
    /// The adapter does not report compute-shader support.
    NoShader,
    /// Device creation on the adapter failed; carries the wgpu error text.
    NoDevice(String),
    /// Requested workgroup size is not one of the supported power-of-two sizes.
    InvalidWorkgroupSize,
    /// A mutex guarding shared GPU state could not be locked.
    MutexLock(String),
    /// A value could not be converted into a `GpuWorker`.
    WorkerConversion,
    /// The worker has no params buffer to update.
    ParamsBufferNotFound,
    /// Catch-all wrapper for other error messages (e.g. I/O via `From`).
    Generic(String)
}

impl std::error::Error for GpuError {}

/// Human-readable message for each GPU error variant.
impl fmt::Display for GpuError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoAdapter(err) => write!(f, "Unable to create GPU adapter, error: {}", err),
            // Fixed user-facing typo: "computer shaders" -> "compute shaders".
            Self::NoShader => write!(f, "GPU adapter does not support compute shaders."),
            Self::NoDevice(err) => write!(f, "Unable to create device on GPU, error: {}", err),
            Self::InvalidWorkgroupSize => write!(f, "Workgroup size must be 64, 128, 256, 512 or 1024"),
            Self::MutexLock(msg) => write!(f, "Unable to lock mutex: {}", msg),
            Self::WorkerConversion => write!(f, "Unable to convert into GpuWorker"),
            Self::ParamsBufferNotFound => write!(f, "Unable to update params buffer, as there doesn't appear to be a params buffer in this worker!"),
            Self::Generic(msg) => write!(f, "{}", msg),
        }
    }
}

impl From<std::io::Error> for GpuError {
fn from(err: std::io::Error) -> Self {
GpuError::Generic(err.to_string())
}
}

6 changes: 6 additions & 0 deletions src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
use std::error::Error;
use std::fmt;

#[cfg(feature = "gpu")]
pub use self::gpu_error::GpuError;

#[cfg(feature = "gpu")]
mod gpu_error;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down
46 changes: 46 additions & 0 deletions src/gpu/adapter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@


use wgpu::util::DeviceExt;
use crate::error::GpuError;

/// Shared handle to a GPU: the wgpu adapter plus the device and queue
/// created from it.
#[derive(Clone)]
pub struct GpuAdapter {
    // Kept alive alongside the device; not read after construction in this file.
    adapter: wgpu::Adapter,
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    // Device limit `max_compute_invocations_per_workgroup`, cached at creation.
    pub max_workgroup_size: u32
}

impl GpuAdapter {
    /// Initializes a GPU adapter, device and queue, blocking on the async
    /// wgpu requests via `pollster`.
    ///
    /// # Errors
    /// - `GpuError::NoAdapter` if no adapter could be created.
    /// - `GpuError::NoShader` if the adapter lacks compute-shader support.
    /// - `GpuError::NoDevice` if device creation fails.
    pub fn new() -> Result<Self, GpuError> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions::default()))
            .map_err(|e| GpuError::NoAdapter(e.to_string()))?;
        // NOTE(review): dropped the unconditional `println!` of adapter info;
        // library code should not write to stdout (use a log facade if needed).

        // Refuse adapters that cannot run compute shaders before paying for
        // device creation.
        let downlevel_capabilities = adapter.get_downlevel_capabilities();
        if !downlevel_capabilities.flags.contains(wgpu::DownlevelFlags::COMPUTE_SHADERS) {
            return Err(GpuError::NoShader);
        }

        // Create the device and its command queue with conservative
        // (downlevel-default) limits.
        let (device, queue) = pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::downlevel_defaults(),
            memory_hints: wgpu::MemoryHints::MemoryUsage,
            trace: wgpu::Trace::Off,
        }))
        .map_err(|e| GpuError::NoDevice(e.to_string()))?;

        // Cache the per-workgroup invocation limit for later dispatch sizing.
        let max_workgroup_size: u32 = device.limits().max_compute_invocations_per_workgroup;

        Ok(Self { adapter, device, queue, max_workgroup_size })
    }
}


5 changes: 5 additions & 0 deletions src/gpu/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@



pub enum GpuAlgorithm {

93 changes: 93 additions & 0 deletions src/gpu/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

use wgpu::util::DeviceExt;
use crate::numbers::basenum::Number;
use super::{GpuAdapter, GpuParams, GpuMatrix};

/// Logical buffer slots a GPU worker may allocate.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub enum GpuBuffer {
    Samples,
    Targets,
    Weights,
    TempStorage,
    Params,
    Download
}

impl GpuBuffer {
    /// Whether this slot is exposed to the shader via the bind group.
    /// The download (readback) buffer is CPU-side only.
    pub fn included_in_bind_group(&self) -> bool {
        !matches!(self, GpuBuffer::Download)
    }

    /// Whether the shader may only read this slot. Weights and temp
    /// storage are the two writable buffers.
    pub fn is_read_only(&self) -> bool {
        !matches!(self, GpuBuffer::Weights | GpuBuffer::TempStorage)
    }
}


/// Uploads the flattened sample matrix into a read-only GPU storage buffer.
pub fn create_samples(adapter: &GpuAdapter, matrix: &GpuMatrix) -> wgpu::Buffer {
    adapter.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Samples"),
        contents: bytemuck::cast_slice(&matrix.data),
        usage: wgpu::BufferUsages::STORAGE,
    })
}

/// Uploads classification targets as `u32` values into a GPU storage buffer.
///
/// Takes `&[T]` instead of `&Vec<T>` (deref coercion keeps existing call
/// sites working).
///
/// NOTE(review): targets whose `to_u32()` fails are silently dropped by
/// `filter_map`, which would desync the buffer length from the sample
/// count — confirm whether an error should be raised instead.
pub fn create_targets<T>(adapter: &GpuAdapter, targets: &[T]) -> wgpu::Buffer
where T: Number + Ord
{
    let u_targets = targets.iter().filter_map(|&val| val.to_u32()).collect::<Vec<u32>>();
    adapter.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Targets"),
        contents: bytemuck::cast_slice(&u_targets),
        usage: wgpu::BufferUsages::STORAGE,
    })
}

/// Allocates a zero-initialized weights buffer with one f32 per feature.
/// COPY_SRC is set so results can later be copied into the download buffer.
pub fn create_weights(adapter: &GpuAdapter, num_features: usize) -> wgpu::Buffer {
    let zeros = vec![0.0f32; num_features];
    adapter.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Weights"),
        contents: bytemuck::cast_slice(&zeros),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    })
}


/// Allocates an uninitialized scratch buffer of `buffer_size` bytes for
/// intermediate shader results; copyable out via COPY_SRC.
pub fn create_temp_storage(adapter: &GpuAdapter, buffer_size: u64) -> wgpu::Buffer {
    adapter.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Temp Storage"),
        size: buffer_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    })
}

/// Uploads shader parameter words into a uniform buffer that can be
/// rewritten later (COPY_DST).
///
/// Takes `&[u32]` instead of `&Vec<u32>` (deref coercion keeps existing
/// call sites working) and drops the redundant `&data.as_slice()`
/// double-reference.
///
/// NOTE(review): this buffer has UNIFORM usage, but
/// `GpuLayout::create_bind_group_layout` declares every binding as a
/// storage buffer — confirm the Params slot is bound consistently.
pub fn create_params(adapter: &GpuAdapter, data: &[u32]) -> wgpu::Buffer {
    adapter.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Params"),
        contents: bytemuck::cast_slice(data),
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    })
}

/// Allocates a CPU-mappable readback buffer (COPY_DST + MAP_READ) used to
/// pull results off the GPU; never part of the bind group.
pub fn create_download(adapter: &GpuAdapter, buffer_size: u64) -> wgpu::Buffer {
    adapter.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Download"),
        size: buffer_size,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    })
}




83 changes: 83 additions & 0 deletions src/gpu/layout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

use super::{GpuAdapter, GpuBuffer};

/// Families of buffer layouts shared by related algorithm types; each
/// variant fixes the set and order of `GpuBuffer` slots.
#[derive(Default, Copy, Clone, Eq, PartialEq, Hash)]
pub enum GpuLayout {
    #[default]
    Supervised, // Samples + Targets + Weights + TempStorage + Params
    //Clustering, // Samples + Centroids + Assignments + Params
    //Decomposition, // Samples + Vectors + Values + TempStorage + Params
}

/// The wgpu layout objects derived from a `GpuLayout`: the bind-group
/// layout plus the pipeline layout wrapping it.
#[derive(Clone)]
pub struct GpuResourceLayout {
    pub bind_group_layout: wgpu::BindGroupLayout,
    pub pipeline_layout: wgpu::PipelineLayout,
}

impl GpuLayout {
    /// Builds the wgpu bind-group layout and the pipeline layout wrapping
    /// it for this buffer layout.
    pub fn create_resource_layout(&self, adapter: &GpuAdapter) -> GpuResourceLayout {

        // Bind group layout (adapter is already a reference; no extra `&`).
        let bind_group_layout = self.create_bind_group_layout(adapter);

        // Pipeline layout with the single bind group and no push constants.
        let pipeline_layout = adapter.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        GpuResourceLayout {
            bind_group_layout,
            pipeline_layout
        }
    }

    /// Ordered buffer slots for this layout; the order defines the shader
    /// binding order (minus slots excluded from the bind group).
    pub fn get_buffer_templates(&self) -> Vec<GpuBuffer> {
        match self {
            GpuLayout::Supervised => vec![GpuBuffer::Samples, GpuBuffer::Targets, GpuBuffer::Weights, GpuBuffer::TempStorage, GpuBuffer::Params],
        }
    }

    /// Position of `buffer` within this layout's template list, if present.
    pub fn get_buffer_index(&self, buffer: GpuBuffer) -> Option<usize> {
        self.get_buffer_templates().iter().position(|&buf| buf == buffer)
    }

    /// Declares one COMPUTE-visible storage binding per template that is
    /// included in the bind group, numbered consecutively from 0.
    ///
    /// NOTE(review): every binding is declared as a storage buffer, but
    /// `create_params` allocates the Params buffer with UNIFORM usage —
    /// confirm the two agree before dispatch.
    fn create_bind_group_layout(&self, adapter: &GpuAdapter) -> wgpu::BindGroupLayout {

        let layout_entries: Vec<wgpu::BindGroupLayoutEntry> = self
            .get_buffer_templates()
            .iter()
            .filter(|template| template.included_in_bind_group())
            .enumerate()
            .map(|(i, template)| wgpu::BindGroupLayoutEntry {
                binding: i as u32,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: template.is_read_only() },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            })
            .collect();

        adapter.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &layout_entries
        })
    }
}


52 changes: 52 additions & 0 deletions src/gpu/matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

use super::STATION;

/// Dense f32 matrix stored as a flat buffer for GPU upload.
#[derive(Default)]
pub struct GpuMatrix {
    pub rows: usize,
    pub cols: usize,
    // Row-major: element (r, c) lives at data[r * cols + c]
    // (the `From` impls flatten the outer Vec row by row).
    pub data: Vec<f32>
}

impl GpuMatrix {
    /// Chooses a compute workgroup size: the smallest of 64/128/256/512
    /// that covers `cols` (falling back to 1024), clamped to the device's
    /// `max_workgroup_size` when the global adapter is available.
    pub fn get_workgroup_size(&self) -> usize {
        let mut size = [64usize, 128, 256, 512]
            .iter()
            .copied()
            .find(|&candidate| self.cols <= candidate)
            .unwrap_or(1024);

        // Respect the hardware limit if a GPU adapter has been initialized.
        if let Ok(adapter) = STATION.get_adapter() {
            size = size.min(adapter.max_workgroup_size as usize);
        }

        size
    }
}

/// Flattens a row-major nested vector into a `GpuMatrix`.
///
/// Fix: an empty outer vector previously panicked on `data[0]`; it now
/// yields a 0x0 matrix. `cols` is taken from the first row — assumes all
/// rows are equal length (TODO confirm; ragged input silently corrupts
/// the flat layout).
impl From<Vec<Vec<f32>>> for GpuMatrix {
    fn from(data: Vec<Vec<f32>>) -> Self {
        Self {
            rows: data.len(),
            cols: data.first().map_or(0, |row| row.len()),
            data: data.into_iter().flatten().collect()
        }
    }
}

/// Borrowing variant of the conversion above.
///
/// Fixes: an empty outer vector previously panicked on `data[0]`, and
/// `data.clone()` deep-copied the entire nested vector — iterating by
/// reference and copying only the f32 elements avoids that allocation.
impl From<&Vec<Vec<f32>>> for GpuMatrix {
    fn from(data: &Vec<Vec<f32>>) -> Self {
        Self {
            rows: data.len(),
            cols: data.first().map_or(0, |row| row.len()),
            data: data.iter().flatten().copied().collect()
        }
    }
}



38 changes: 38 additions & 0 deletions src/gpu/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

use crate::error::GpuError;

use lazy_static::lazy_static;

lazy_static! {
    // Process-wide GPU entry point, initialized lazily on first access.
    // NOTE(review): std::sync::LazyLock could replace the lazy_static
    // dependency on recent toolchains — confirm the crate's MSRV first.
    pub static ref STATION: GpuStation = GpuStation::new();
}

pub use self::adapter::GpuAdapter;
pub use self::buffer::GpuBuffer;
pub use self::layout::{GpuLayout, GpuResourceLayout};
pub use self::matrix::GpuMatrix;
pub use self::params::GpuParams;
pub use self::station::{GpuStation, GpuWorkgroup};
pub use self::worker::GpuWorker;

mod adapter;
pub mod buffer;
mod layout;
mod matrix;
pub mod models;
mod params;
mod station;
mod worker;

/// Algorithms with a GPU (WGSL) implementation.
///
/// NOTE(review): `src/gpu/algorithm.rs` also begins a `GpuAlgorithm`
/// enum — one of the two definitions looks vestigial; confirm which is
/// canonical before merge.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuAlgorithm {
    LogisticRegressionGradientDescentBinaryClassification,
    LogisticRegressionGradientDescentMultiClassification,
}
/// Implemented by models that can offload work to the GPU backend.
pub trait GpuModule {
    /// Derives dispatch parameters from the input matrix and class count.
    fn get_params(&self, matrix: &GpuMatrix, num_classes: usize) -> Result<GpuParams, GpuError>;
    /// Produces the WGSL shader source for this model and input shape.
    fn get_wgsl_code(&self, matrix: &GpuMatrix, params: &GpuParams) -> String;
    /// Serializes `params` into the u32 words uploaded to the params buffer.
    fn get_params_buffer_data(&self, params: &GpuParams) -> Vec<u32>;
}


Loading