
Commit be027c4

add stream.all_reduce API and ProcessGroupStream

1 parent de436f0 · commit be027c4

18 files changed: +691 −14 lines
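In short, this commit inserts a stream-aware base class, ProcessGroupStream, between ProcessGroup and ProcessGroupNCCL, and adds AllReduce overloads that carry a sync_op flag and a use_calc_stream flag, so a collective can run either on the dedicated communication stream or directly on the calculation stream. Below is a minimal caller-side sketch of the C++ surface added here; the helper function, the group instance, and the tensor vectors are illustrative assumptions, not part of the diff.

    #include <vector>

    #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

    namespace dist = paddle::distributed;

    // Hypothetical helper: run one allreduce through the new overloads.
    void AllReduceSketch(dist::ProcessGroupNCCL& pg,
                         std::vector<phi::DenseTensor>& ins,
                         std::vector<phi::DenseTensor>& outs) {
      dist::AllreduceOptions opts;  // reduce_op defaults to SUM

      // Launch on the dedicated communication stream; Wait() makes the
      // calculation stream wait on the recorded control events.
      auto task = pg.AllReduce(ins, outs, opts,
                               /*sync_op=*/false,
                               /*use_calc_stream=*/false);
      task->Wait();

      // Launch directly on the calculation stream; per the Wait() change in
      // ProcessGroupNCCL.cc below, Wait() only logs a warning and returns.
      pg.AllReduce(ins, outs, opts,
                   /*sync_op=*/true,
                   /*use_calc_stream=*/true);
    }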

paddle/fluid/distributed/collective/CMakeLists.txt

Lines changed: 11 additions & 2 deletions
@@ -2,10 +2,14 @@ cc_library(
   processgroup
   SRCS ProcessGroup.cc
   DEPS dense_tensor)
+cc_library(
+  processgroup_stream
+  SRCS ProcessGroupStream.cc
+  DEPS dense_tensor)
 cc_library(
   eager_reducer
   SRCS reducer.cc
-  DEPS eager_api processgroup phi_api string_helper)
+  DEPS eager_api processgroup processgroup_stream phi_api string_helper)
 
 if(WITH_DISTRIBUTE)
   cc_library(
@@ -18,7 +22,12 @@ if(WITH_NCCL OR WITH_RCCL)
   cc_library(
     processgroup_nccl
     SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
-    DEPS processgroup place enforce collective_helper device_context
+    DEPS processgroup
+         processgroup_stream
+         place
+         enforce
+         collective_helper
+         device_context
          dense_tensor)
   if(WITH_DISTRIBUTE AND WITH_PSCORE)
     if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)

paddle/fluid/distributed/collective/ProcessGroup.cc

Lines changed: 7 additions & 1 deletion
@@ -18,10 +18,16 @@ namespace paddle {
 namespace distributed {
 
 ProcessGroup::Task::Task(int rank,
-                         const std::vector<phi::DenseTensor>& inputTensors,
+                         const std::vector<phi::DenseTensor>& inputs,
                          CommType comm_type)
     : rank_(rank), comm_type_(comm_type) {}
 
+ProcessGroup::Task::Task(int rank,
+                         const std::vector<phi::DenseTensor>& inputs,
+                         CommType comm_type,
+                         bool sync_op)
+    : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
+
 ProcessGroup::Task::~Task() = default;
 
 bool ProcessGroup::Task::IsCompleted() {

paddle/fluid/distributed/collective/ProcessGroup.h

Lines changed: 23 additions & 4 deletions
@@ -55,19 +55,27 @@ class ProcessGroup {
   class Task {
    public:
     Task(int rank,
-         const std::vector<phi::DenseTensor>& inputTensors,
-         CommType opType = CommType::UNKNOWN);
+         const std::vector<phi::DenseTensor>& inputs,
+         CommType comm_type);
+    Task(int rank,
+         const std::vector<phi::DenseTensor>& inputs,
+         CommType comm_type,
+         bool sync_op);
 
     virtual ~Task();
     virtual bool IsCompleted();
     virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
     virtual void Synchronize();
+    bool IsSync() const { return sync_op_; }
 
    protected:
     const int rank_;
-    CommType comm_type_;
+    CommType comm_type_{CommType::UNKNOWN};
     std::mutex mutex_;
-    bool is_completed_ = false;
+    bool is_completed_{false};
+
+   private:
+    bool sync_op_{true};
   };
 
   explicit ProcessGroup(int rank,
@@ -82,6 +90,7 @@ class ProcessGroup {
 
   virtual const std::string GetBackendName() const = 0;
 
+  // TODO(liyurui): This API will be moved later
   virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
       std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
@@ -90,6 +99,16 @@ class ProcessGroup {
         "ProcessGroup%s does not support allreduce", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
+      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
+      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
+      const AllreduceOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support allreduce with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
       std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
       std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

Lines changed: 130 additions & 3 deletions
@@ -55,7 +55,20 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
     int rank,
     CommType CommType,
     const std::vector<phi::DenseTensor>& inputs)
-    : Task(rank, inputs, CommType), places_(places) {
+    : TaskStream(rank, inputs, CommType), places_(places) {
+  control_events_.resize(places.size());
+  ncclComms_.resize(places.size());
+}
+
+ProcessGroupNCCL::NCCLTask::NCCLTask(
+    const std::vector<Place>& places,
+    int rank,
+    CommType comm_type,
+    const std::vector<phi::DenseTensor>& inputs,
+    bool sync_op,
+    bool use_calc_stream)
+    : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
+      places_(places) {
   control_events_.resize(places.size());
   ncclComms_.resize(places.size());
 }
@@ -116,6 +129,13 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
 
 // TODO(sheniang03): Add timeout for wait, now timeout unused
 bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
+  // Warning here when use calc stream but also invoke waiting explicitly.
+  if (UseCalcStream()) {
+    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
+               "useless.";
+    return true;
+  }
+
   SynchronizeStreams();
   if (FLAGS_nccl_blocking_wait) {
     // NOTE(shenliang03): It will block host for sync
@@ -146,7 +166,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                    int size,
                                    const platform::Place& place,
                                    int gid)
-    : ProcessGroup(rank, size, place, gid), store_(store) {
+    : ProcessGroupStream(rank, size, place, gid), store_(store) {
   platform::SetDeviceId(place_.device);
 }
 
@@ -223,6 +243,81 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
   places_to_ctx_.emplace(places_key, std::move(dev_ctx));
 }
 
+template <typename Fn>
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
+    std::vector<phi::DenseTensor>& inputs,
+    std::vector<phi::DenseTensor>& outputs,
+    Fn fn,
+    CommType comm_type,
+    bool sync_op,
+    bool use_calc_stream) {
+  const auto& places = GetPlaceList(inputs);
+  const auto& key = GetKeyFromPlaces(places);
+
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
+      CreateNCCLManagerCache(key, places);
+    }
+  }
+
+  auto& nccl_comms = places_to_ncclcomm_[key];
+
+  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+
+  auto task = std::make_shared<ProcessGroupNCCL::NCCLTask>(
+      places, rank_, comm_type, inputs, sync_op, use_calc_stream);
+
+  platform::CUDADeviceGuard cuda_guard;
+
+  {
+    platform::NCCLGroupGuard nccl_guard;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      cuda_guard.SetDevice(places[i]);
+
+      gpuStream_t nccl_stream;
+      if (use_calc_stream) {
+        nccl_stream =
+            static_cast<phi::GPUContext*>(
+                platform::DeviceContextPool::Instance().Get(places[i]))
+                ->stream();
+      } else {
+        nccl_stream = places_to_ctx_[key][i]->stream();
+      }
+
+      fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
+    }
+  }
+
+  if (FLAGS_use_stream_safe_cuda_allocator) {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      cuda_guard.SetDevice(places[i]);
+
+      gpuStream_t nccl_stream;
+      if (use_calc_stream) {
+        nccl_stream =
+            static_cast<phi::GPUContext*>(
+                platform::DeviceContextPool::Instance().Get(places[i]))
+                ->stream();
+      } else {
+        nccl_stream = places_to_ctx_[key][i]->stream();
+      }
+
+      memory::RecordStream(inputs[i].Holder(), nccl_stream);
+    }
+  }
+
+  // Adding stream event dependency only when use comm stream
+  if (!use_calc_stream) {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      cuda_guard.SetDevice(places[i]);
+      task->control_events_[i].Record(*places_to_ctx_[key][i]);
+    }
+  }
+
+  return task;
+}
+
 template <typename Fn>
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
     std::vector<phi::DenseTensor>& inputs,
@@ -386,6 +481,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
       CommType::ALLREDUCE);
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const AllreduceOptions& opts,
+    bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](const phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          ncclComm_t comm,
+          const gpuStream_t& stream) {
+        return platform::dynload::ncclAllReduce(
+            input.data(),
+            output.data(),
+            input.numel(),
+            platform::ToNCCLDataType(input.type()),
+            ToNCCLRedType(opts.reduce_op),
+            comm,
+            stream);
+      },
+      CommType::ALLREDUCE,
+      sync_op,
+      use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
@@ -432,7 +558,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
         new paddle::experimental::DefaultAllocator(place));
     barrierTensors.emplace_back(allocator.get(), meta);
   }
-  auto task = ProcessGroupNCCL::AllReduce(barrierTensors, barrierTensors);
+  auto task = ProcessGroupNCCL::AllReduce(
+      barrierTensors, barrierTensors, AllreduceOptions());
   auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
   nccl_task->barrierTensors_ = std::move(barrierTensors);
   return task;
paddle/fluid/distributed/collective/ProcessGroupNCCL.h

Lines changed: 27 additions & 3 deletions
@@ -21,7 +21,7 @@
 #include <unordered_map>
 #include <vector>
 
-#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
 #include "paddle/fluid/distributed/store/store.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -46,16 +46,23 @@ namespace distributed {
 
 using Place = paddle::platform::Place;
 
-class ProcessGroupNCCL : public ProcessGroup {
+class ProcessGroupNCCL : public ProcessGroupStream {
  public:
-  class NCCLTask : public ProcessGroup::Task,
+  class NCCLTask : public ProcessGroupStream::TaskStream,
                    public std::enable_shared_from_this<NCCLTask> {
    public:
     NCCLTask(const std::vector<Place>& places,
             int rank,
             CommType CommType,
             const std::vector<phi::DenseTensor>& inputs);
 
+    NCCLTask(const std::vector<Place>& places,
+            int rank,
+            CommType comm_type,
+            const std::vector<phi::DenseTensor>& inputs,
+            bool is_sync,
+            bool use_calc_stream);
+
     bool IsCompleted();
 
     void SynchronizeStreams();
@@ -89,6 +96,14 @@ class ProcessGroupNCCL : public ProcessGroup {
     return std::string(NCCL_BACKEND_NAME);
   }
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const AllreduceOptions& options,
+      bool sync_op,
+      bool use_calc_stream) override;
+
+  // TODO(liyurui): This API will be moved later
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
@@ -194,6 +209,15 @@ class ProcessGroupNCCL : public ProcessGroup {
       Fn fn,
       CommType op_type);
 
+  template <typename Fn>
+  std::shared_ptr<ProcessGroupStream::Task> Collective(
+      std::vector<phi::DenseTensor>& inputs,   // NOLINT
+      std::vector<phi::DenseTensor>& outputs,  // NOLINT
+      Fn fn,
+      CommType comm_type,
+      bool sync_op,
+      bool use_calc_stream);
+
   template <typename Fn>
   void Collective(const phi::DenseTensor*,
                   phi::DenseTensor*,
paddle/fluid/distributed/collective/ProcessGroupStream.cc

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
+
+namespace paddle {
+namespace distributed {
+
+ProcessGroupStream::ProcessGroupStream(int rank,
+                                       int size,
+                                       const platform::Place& place,
+                                       int gid)
+    : ProcessGroup(rank, size, place, gid) {}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
+    std::vector<phi::DenseTensor>& input_tensors,   // NOLINT
+    std::vector<phi::DenseTensor>& output_tensors,  // NOLINT
+    const AllreduceOptions& options,
+    bool sync_op) {
+  return AllReduce(input_tensors,
+                   output_tensors,
+                   options,
+                   sync_op,
+                   /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
+    std::vector<phi::DenseTensor>& input_tensors,   // NOLINT
+    std::vector<phi::DenseTensor>& output_tensors,  // NOLINT
+    const AllreduceOptions& options,
+    bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do allreduce", GetBackendName()));
+}
+
+}  // namespace distributed
+}  // namespace paddle
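A design note that follows from the two definitions above (not spelled out in the diff): the four-argument AllReduce is what a frontend such as stream.all_reduce would call, and it forwards to the five-argument overload with use_calc_stream fixed to false; the five-argument overload is the single point a backend like ProcessGroupNCCL overrides, with the base implementation throwing for backends that do not. A minimal illustration, with the group reference and tensor vectors assumed:

    #include <vector>

    #include "paddle/fluid/distributed/collective/ProcessGroupStream.h"

    namespace dist = paddle::distributed;

    // Minimal illustration; the group reference and tensors come from elsewhere.
    void LaunchBoth(dist::ProcessGroupStream& pg,
                    std::vector<phi::DenseTensor>& ins,
                    std::vector<phi::DenseTensor>& outs) {
      dist::AllreduceOptions opts;

      // Goes through the four-argument forwarder above (use_calc_stream = false),
      // then virtual dispatch lands in the backend's five-argument override.
      auto task = pg.AllReduce(ins, outs, opts, /*sync_op=*/false);
      task->Wait();

      // Calls the five-argument overload directly, on the calculation stream.
      pg.AllReduce(ins, outs, opts, /*sync_op=*/true, /*use_calc_stream=*/true);
    }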
