Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 93f2de8

Browse files
ATen/Dlpack functions -> aten/utils*h
1 parent 0258d2d commit 93f2de8

File tree

4 files changed

+88
-40
lines changed

4 files changed

+88
-40
lines changed

include/tc/aten/aten_compiler.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
#include <ATen/ATen.h>
2323
#include <ATen/DLConvertor.h>
2424

25+
#include "tc/aten/utils.h"
2526
#include "tc/core/cuda/cuda.h"
2627
#include "tc/core/cuda/cuda_tc_executor.h"
2728
#include "tc/core/execution_engine.h"
2829
#include "tc/lang/parser.h"
2930

3031
namespace tc {
31-
3232
/// This provides the basic interface for writing ATen style tensor operations
3333
/// based on Tensor Comprehensions.
3434

@@ -74,13 +74,4 @@ class ATenCompilationUnit {
7474
private:
7575
std::unique_ptr<ExecutionEngine<CudaTcExecutor>> executionEngine_;
7676
};
77-
78-
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
79-
toDlpackTensors(const std::vector<at::Tensor>& tensors);
80-
81-
std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
82-
toConstDlpackTensors(const std::vector<at::Tensor>& tensors);
83-
84-
void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors);
85-
8677
} // namespace tc

include/tc/aten/utils-inl.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <string>
19+
#include <vector>
20+
21+
#include <ATen/ATen.h>
22+
#include <ATen/DLConvertor.h>
23+
namespace tc {
24+
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
25+
toDlpackTensors(const std::vector<at::Tensor>& tensors) {
26+
std::vector<DLTensor*> dlTensors;
27+
std::vector<DLManagedTensor*> dlMTensors;
28+
for (auto tensor : tensors) {
29+
auto dlMTensor = at::toDLPack(tensor);
30+
dlTensors.push_back(&(dlMTensor->dl_tensor));
31+
dlMTensors.push_back(dlMTensor);
32+
}
33+
return make_pair(dlTensors, dlMTensors);
34+
}
35+
36+
std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
37+
toConstDlpackTensors(const std::vector<at::Tensor>& tensors) {
38+
std::vector<const DLTensor*> dlTensors;
39+
std::vector<DLManagedTensor*> dlMTensors;
40+
for (auto tensor : tensors) {
41+
auto dlMTensor = at::toDLPack(tensor);
42+
dlTensors.push_back(&(dlMTensor->dl_tensor));
43+
dlMTensors.push_back(dlMTensor);
44+
}
45+
return make_pair(dlTensors, dlMTensors);
46+
}
47+
48+
void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors) {
49+
for (auto& tensor : tensors) {
50+
tensor->deleter(tensor);
51+
}
52+
}
53+
} // namespace tc

include/tc/aten/utils.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <string>
19+
#include <vector>
20+
21+
#include <ATen/ATen.h>
22+
#include <ATen/DLConvertor.h>
23+
24+
namespace tc {
25+
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
26+
toDlpackTensors(const std::vector<at::Tensor>& tensors);
27+
28+
std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
29+
toConstDlpackTensors(const std::vector<at::Tensor>& tensors);
30+
31+
void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors);
32+
} // namespace tc
33+
34+
#include "tc/aten/utils-inl.h"

src/aten/aten_compiler.cc

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@ void ATenCompilationUnit::define(const std::string& language) {
3232
executionEngine_->define(language);
3333
}
3434

35-
void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors) {
36-
for (auto& tensor : tensors) {
37-
tensor->deleter(tensor);
38-
}
39-
}
40-
4135
namespace {
4236

4337
// given the tensor shape and DLType, allocate storage for the tensor output
@@ -73,30 +67,6 @@ void prepareOutputs(
7367

7468
} // namespace
7569

76-
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
77-
toDlpackTensors(const std::vector<at::Tensor>& tensors) {
78-
std::vector<DLTensor*> dlTensors;
79-
std::vector<DLManagedTensor*> dlMTensors;
80-
for (auto tensor : tensors) {
81-
auto dlMTensor = at::toDLPack(tensor);
82-
dlTensors.push_back(&(dlMTensor->dl_tensor));
83-
dlMTensors.push_back(dlMTensor);
84-
}
85-
return make_pair(dlTensors, dlMTensors);
86-
}
87-
88-
std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
89-
toConstDlpackTensors(const std::vector<at::Tensor>& tensors) {
90-
std::vector<const DLTensor*> dlTensors;
91-
std::vector<DLManagedTensor*> dlMTensors;
92-
for (auto tensor : tensors) {
93-
auto dlMTensor = at::toDLPack(tensor);
94-
dlTensors.push_back(&(dlMTensor->dl_tensor));
95-
dlMTensors.push_back(dlMTensor);
96-
}
97-
return make_pair(dlTensors, dlMTensors);
98-
}
99-
10070
size_t ATenCompilationUnit::compile(
10171
const std::string& name,
10272
const std::vector<at::Tensor>& inputs,

0 commit comments

Comments
 (0)