11
11
#include < folly/coro/Collect.h>
12
12
#include < algorithm>
13
13
#include " common/time/Time.h"
14
+ #ifdef FBGEMM_USE_GPU
14
15
#include " kv_db_cuda_utils.h"
16
+ #endif
15
17
#include " torch/csrc/autograd/record_function_ops.h"
16
18
#ifdef FBGEMM_FBCODE
17
19
#include < folly/stop_watch.h>
@@ -411,6 +413,7 @@ void EmbeddingKVDB::get_cuda(
411
413
const at::Tensor& indices,
412
414
const at::Tensor& weights,
413
415
const at::Tensor& count) {
416
+ #ifdef FBGEMM_USE_GPU
414
417
auto rec = torch::autograd::profiler::record_function_enter_new (
415
418
" ## EmbeddingKVDB::get_cuda ##" );
416
419
check_tensor_type_consistency (indices, weights);
@@ -424,6 +427,7 @@ void EmbeddingKVDB::get_cuda(
424
427
functor,
425
428
0 ));
426
429
rec->record .end ();
430
+ #endif
427
431
}
428
432
429
433
void EmbeddingKVDB::set_cuda (
@@ -432,6 +436,7 @@ void EmbeddingKVDB::set_cuda(
432
436
const at::Tensor& count,
433
437
const int64_t timestep,
434
438
const bool is_bwd) {
439
+ #ifdef FBGEMM_USE_GPU
435
440
auto rec = torch::autograd::profiler::record_function_enter_new (
436
441
" ## EmbeddingKVDB::set_cuda ##" );
437
442
check_tensor_type_consistency (indices, weights);
@@ -447,14 +452,15 @@ void EmbeddingKVDB::set_cuda(
447
452
functor,
448
453
0 ));
449
454
rec->record .end ();
455
+ #endif
450
456
}
451
457
452
458
void EmbeddingKVDB::stream_cuda (
453
459
const at::Tensor& indices,
454
460
const at::Tensor& weights,
455
461
const at::Tensor& count,
456
462
bool blocking_tensor_copy) {
457
- #ifdef FBGEMM_FBCODE
463
+ #ifdef FBGEMM_USE_GPU
458
464
auto rec = torch::autograd::profiler::record_function_enter_new (
459
465
" ## EmbeddingKVDB::stream_cuda ##" );
460
466
check_tensor_type_consistency (indices, weights);
@@ -472,7 +478,7 @@ void EmbeddingKVDB::stream_cuda(
472
478
}
473
479
474
480
void EmbeddingKVDB::stream_sync_cuda () {
475
- #ifdef FBGEMM_FBCODE
481
+ #ifdef FBGEMM_USE_GPU
476
482
auto rec = torch::autograd::profiler::record_function_enter_new (
477
483
" ## EmbeddingKVDB::stream_sync_cuda ##" );
478
484
// take reference to self to avoid lifetime issues.
0 commit comments