Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions backends/cuda-ref/ceed-cuda-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@
#include <string.h>
#include "ceed-cuda-ref.h"


//------------------------------------------------------------------------------
// Check if host/device sync is needed
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync host to device
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
//------------------------------------------------------------------------------
// Sync arrays
//------------------------------------------------------------------------------
static inline int CeedVectorSync_Cuda(const CeedVector vec,
CeedMemType mem_type) {
static int CeedVectorSyncArray_Cuda(const CeedVector vec,
CeedMemType mem_type) {
int ierr;
// Check whether device/host sync is needed
bool need_sync = false;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync);
CeedChkBackend(ierr);
if (!need_sync)
return CEED_ERROR_SUCCESS;

switch (mem_type) {
case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec);
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec);
Expand Down Expand Up @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec,
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Check if is any array of given type
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Set array from host
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

// Sync array to requested mem_type
bool need_sync = false;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
if (need_sync) {
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
}
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool need_sync = false, has_array_of_type = true;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type);
CeedChkBackend(ierr);
if (need_sync) {
// Sync array to requested mem_type
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
}
// Sync array to requested mem_type
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
(int (*)())(CeedVectorSetValue_Cuda));
CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
CeedVectorSyncArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
CeedVectorGetArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
Expand Down
76 changes: 39 additions & 37 deletions backends/hip-ref/ceed-hip-ref-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@
#include <string.h>
#include "ceed-hip-ref.h"


//------------------------------------------------------------------------------
// Check if host/device sync is needed
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync host to device
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
//------------------------------------------------------------------------------
// Sync arrays
//------------------------------------------------------------------------------
static inline int CeedVectorSync_Hip(const CeedVector vec,
CeedMemType mem_type) {
static int CeedVectorSyncArray_Hip(const CeedVector vec,
CeedMemType mem_type) {
int ierr;
// Check whether device/host sync is needed
bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync);
CeedChkBackend(ierr);
if (!need_sync)
return CEED_ERROR_SUCCESS;

switch (mem_type) {
case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
Expand Down Expand Up @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync array of given type
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Set array from host
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -363,11 +372,7 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

// Sync array to requested mem_type
bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
if (need_sync) {
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
}
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -398,13 +403,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
CeedChkBackend(ierr);
if (need_sync) {
// Sync array to requested mem_type
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
}
// Sync array to requested mem_type
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -758,6 +758,8 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
(int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
CeedVectorSyncArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
CeedVectorGetArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
Expand Down
1 change: 1 addition & 0 deletions interface/ceed.c
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ int CeedInit(const char *resource, Ceed *ceed) {
CEED_FTABLE_ENTRY(CeedVector, SetArray),
CEED_FTABLE_ENTRY(CeedVector, TakeArray),
CEED_FTABLE_ENTRY(CeedVector, SetValue),
CEED_FTABLE_ENTRY(CeedVector, SyncArray),
CEED_FTABLE_ENTRY(CeedVector, GetArray),
CEED_FTABLE_ENTRY(CeedVector, GetArrayRead),
CEED_FTABLE_ENTRY(CeedVector, GetArrayWrite),
Expand Down