
Commit a16d2f0

feat: --gpu-segments (#204)
1 parent 68ade91 commit a16d2f0

6 files changed: +82 -29 lines changed


src/app.cpp

Lines changed: 32 additions & 15 deletions
@@ -43,6 +43,9 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
     args.maxSeqLen = 0;
     args.netTurbo = true;
     args.gpuIndex = -1;
+    args.gpuSegmentFrom = -1;
+    args.gpuSegmentTo = -1;
+
     int i = 1;
     if (requireMode && argc > 1) {
         args.mode = argv[1];
@@ -79,15 +82,15 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
 
             for (int s = 0; s < count; s++) {
                 char *v = argv[i + 1 + s];
-                char *sep = std::strstr(v, ":");
-                if (sep == NULL) {
+                char *separator = std::strstr(v, ":");
+                if (separator == NULL) {
                     throw std::runtime_error("Invalid worker address: " + std::string(v));
                 }
-                int hostLen = sep - v;
+                int hostLen = separator - v;
                 args.workerHosts[s] = new char[hostLen + 1];
                 std::memcpy(args.workerHosts[s], v, hostLen);
                 args.workerHosts[s][hostLen] = '\0';
-                args.workerPorts[s] = std::atoi(sep + 1);
+                args.workerPorts[s] = std::atoi(separator + 1);
             }
 
             i += count - 1;
@@ -109,6 +112,12 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
             args.maxSeqLen = (unsigned int)atoi(value);
         } else if (std::strcmp(name, "--gpu-index") == 0) {
             args.gpuIndex = atoi(value);
+        } else if (std::strcmp(name, "--gpu-segments") == 0) {
+            char *separator = std::strstr(value, ":");
+            if (separator == NULL)
+                throw std::runtime_error("GPU segments expected in the format <from>:<to>");
+            args.gpuSegmentFrom = atoi(value);
+            args.gpuSegmentTo = atoi(separator + 1);
         } else if (std::strcmp(name, "--net-turbo") == 0) {
             args.netTurbo = atoi(value) == 1;
         } else {
@@ -128,23 +137,32 @@ AppCliArgs::~AppCliArgs() {
         delete[] workerPorts;
 }
 
-static NnDevice *createDevice(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
+static std::vector<NnExecutorDevice> resolveDevices(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
+    std::vector<NnExecutorDevice> devices;
+
     if (args->gpuIndex >= 0) {
 #if defined(DLLAMA_VULKAN)
-        return new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution);
+        devices.push_back(NnExecutorDevice(
+            new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution),
+            args->gpuSegmentFrom,
+            args->gpuSegmentTo
+        ));
 #else
         throw std::runtime_error("This build does not support GPU");
 #endif
     }
-    return new NnCpuDevice(netConfig, nodeConfig, netExecution);
+
+    if (args->gpuIndex < 0 || (args->gpuSegmentFrom >= 0 && args->gpuSegmentTo >= 0)) {
+        devices.push_back(NnExecutorDevice(new NnCpuDevice(netConfig, nodeConfig, netExecution), -1, -1));
+    }
+    return devices;
 }
 
-RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
+RootLlmInference::RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
     this->header = net->header;
     this->tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];
     this->positionPipe = (float *)execution->pipes[net->positionPipeIndex];
     this->logitsPipe = (float *)execution->pipes[net->logitsPipeIndex];
-    this->device = device;
     this->execution = execution;
     this->executor = executor;
     this->network = network; // May be nullptr!
@@ -245,13 +263,13 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
         configWriter.writeToWorkers(&net.netConfig, net.nodeConfigs);
     }
 
-    std::unique_ptr<NnDevice> device(createDevice(args, &net.netConfig, rootNodeConfig, &execution));
-    NnExecutor executor(&net.netConfig, rootNodeConfig, device.get(), &execution, synchronizer.get(), args->benchmark);
+    std::vector<NnExecutorDevice> devices = resolveDevices(args, &net.netConfig, rootNodeConfig, &execution);
+    NnExecutor executor(&net.netConfig, rootNodeConfig, &devices, &execution, synchronizer.get(), args->benchmark);
 
     NnRootWeightLoader weightLoader(&executor, network, nNodes);
     loadLlmNetWeight(args->modelPath, &net, &weightLoader);
 
-    RootLlmInference inference(&net, device.get(), &execution, &executor, network);
+    RootLlmInference inference(&net, &execution, &executor, network);
 
     if (network != nullptr) {
         network->resetStats();
@@ -290,10 +308,9 @@ void runWorkerApp(AppCliArgs *args) {
 
     NnNetExecution execution(args->nThreads, &netConfig);
 
-    std::unique_ptr<NnDevice> device(createDevice(args, &netConfig, &nodeConfig, &execution));
-
+    std::vector<NnExecutorDevice> devices = resolveDevices(args, &netConfig, &nodeConfig, &execution);
     NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
-    NnExecutor executor(&netConfig, &nodeConfig, device.get(), &execution, &synchronizer, false);
+    NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
 
     NnWorkerWeightReader weightReader(&executor, network);
     weightReader.read();
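
A side note on the new flag's value format: --gpu-segments takes a <from>:<to> pair, split on the first ':' exactly as the branch above does (atoi stops at the separator). Below is a minimal standalone sketch of that parsing; the helper name parseGpuSegments and the main() driver are hypothetical, not part of this commit.

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

// Hypothetical helper mirroring the --gpu-segments branch in AppCliArgs::parse:
// the value must contain ':'; the text before it is <from>, the text after is <to>.
static void parseGpuSegments(const char *value, int *from, int *to) {
    const char *separator = std::strstr(value, ":");
    if (separator == NULL)
        throw std::runtime_error("GPU segments expected in the format <from>:<to>");
    *from = std::atoi(value);        // atoi reads digits up to the ':'
    *to = std::atoi(separator + 1);  // and the remainder after it
}

int main() {
    int from, to;
    parseGpuSegments("0:15", &from, &to);
    std::printf("from=%d to=%d\n", from, to); // prints: from=0 to=15
    return 0;
}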

src/app.hpp

Lines changed: 3 additions & 2 deletions
@@ -31,6 +31,8 @@ class AppCliArgs {
     NnUint maxSeqLen;
     bool netTurbo;
     int gpuIndex;
+    int gpuSegmentFrom;
+    int gpuSegmentTo;
 
     // worker
     NnUint port;
@@ -51,13 +53,12 @@ class RootLlmInference {
     float *tokenPipe;
     float *positionPipe;
     LlmHeader *header;
-    NnDevice *device;
     NnNetExecution *execution;
     NnExecutor *executor;
     NnNetwork *network;
     LlmControlPacket controlPacket;
 public:
-    RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
+    RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
     void setBatchSize(NnUint batchSize);
     void setPosition(NnUint position);
     void setToken(NnUint batchIndex, NnUint token);

src/nn/nn-cpu-test.cpp

Lines changed: 6 additions & 3 deletions
@@ -64,10 +64,13 @@ int main() {
     for (NnUint i = 0; i < DIM; i++)
         rmsNormWeight[i] = 0.5 + i / (float)DIM;
 
-    NnCpuDevice device(&netConfig, &nodeConfig, &execution);
+    NnCpuDevice *device = new NnCpuDevice(&netConfig, &nodeConfig, &execution);
+    std::vector<NnExecutorDevice> devices;
+    devices.push_back(NnExecutorDevice(device, -1, -1));
+
     NnFakeNodeSynchronizer synchronizer;
-    float *rms = (float *)device.buffers[0];
-    NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer, false);
+    float *rms = (float *)device->buffers[0];
+    NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
     executor.loadWeight("rms_norm", 0, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);
 
     execution.setBatchSize(2);

src/nn/nn-executor.cpp

Lines changed: 27 additions & 4 deletions
@@ -33,17 +33,41 @@ void NnNetExecution::setBatchSize(NnUint batchSize) {
     this->batchSize = batchSize;
 }
 
-NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
+NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo) {
+    this->device = std::unique_ptr<NnDevice>(device);
+    this->segmentFrom = segmentFrom;
+    this->segmentTo = segmentTo;
+}
+
+NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
     : segments(nodeConfig->nSegments), steps()
 {
-    NnUint maxNThreads = device->maxNThreads();
+    NnUint maxNThreads = 0;
+    for (NnExecutorDevice &d : *devices) {
+        if (d.device->maxNThreads() > maxNThreads)
+            maxNThreads = d.device->maxNThreads();
+    }
     if (netExecution->nThreads > maxNThreads)
-        throw std::invalid_argument("This device supports max " + std::to_string(maxNThreads) + " threads");
+        throw std::invalid_argument("This configuration supports max " + std::to_string(maxNThreads) + " threads");
+
     this->netExecution = netExecution;
     this->nodeConfig = nodeConfig;
 
     bool useSynchronizer = netConfig->nNodes > 1;
     for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
+        NnDevice *device = nullptr;
+        for (NnExecutorDevice &d : *devices) {
+            if (
+                (d.segmentFrom == -1 && d.segmentTo == -1) ||
+                (segmentIndex >= d.segmentFrom && segmentIndex <= d.segmentTo)
+            ) {
+                device = d.device.get();
+                break;
+            }
+        }
+        if (device == nullptr)
+            throw std::invalid_argument("Cannot locate device for segment " + std::to_string(segmentIndex));
+
         NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
         if (segmentConfig->nOps > 0) {
             NnDeviceSegment *segment = device->createSegment(segmentIndex);
@@ -60,7 +84,6 @@ NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevic
 
     context.nThreads = netExecution->nThreads;
     context.synchronizer = synchronizer;
-    context.device = device;
     context.nSteps = (NnUint)steps.size();
     context.steps = steps.data();
     if (benchmark)
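
To make the routing rule above easier to see in isolation: a device registered with segmentFrom == -1 and segmentTo == -1 matches every segment, otherwise the check is an inclusive <from>..<to> range, and the first matching entry in the devices vector wins (so a GPU listed before the CPU fallback takes priority for its range). The standalone sketch below models just that loop with a simplified struct; the names SegmentRange and pickDevice are hypothetical, while the real code works on NnExecutorDevice.

#include <cstdio>
#include <vector>

// Simplified stand-in for NnExecutorDevice's routing fields (hypothetical).
struct SegmentRange {
    const char *name;
    int segmentFrom; // -1 together with segmentTo == -1 means "all segments"
    int segmentTo;
};

// Mirrors the selection loop in the NnExecutor constructor: the first device
// whose range covers segmentIndex is chosen.
static const char *pickDevice(const std::vector<SegmentRange> &devices, int segmentIndex) {
    for (const SegmentRange &d : devices) {
        if ((d.segmentFrom == -1 && d.segmentTo == -1) ||
            (segmentIndex >= d.segmentFrom && segmentIndex <= d.segmentTo))
            return d.name;
    }
    return nullptr; // the real constructor throws std::invalid_argument here
}

int main() {
    // GPU covers segments 0..1; the CPU entry (-1:-1) picks up the rest.
    std::vector<SegmentRange> devices = {{"gpu", 0, 1}, {"cpu", -1, -1}};
    for (int s = 0; s < 4; s++)
        std::printf("segment %d -> %s\n", s, pickDevice(devices, s)); // gpu, gpu, cpu, cpu
    return 0;
}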

src/nn/nn-executor.hpp

Lines changed: 9 additions & 2 deletions
@@ -51,6 +51,14 @@ enum NnExecutorStepType {
 
 #define N_STEP_TYPES STEP_SYNC_NODES + 1
 
+class NnExecutorDevice {
+public:
+    std::unique_ptr<NnDevice> device;
+    int segmentFrom;
+    int segmentTo;
+    NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo);
+};
+
 typedef struct {
     NnExecutorStepType type;
     NnDeviceSegment *segment;
@@ -63,7 +71,6 @@ typedef struct {
     NnUint nSteps;
     NnExecutorStep *steps;
     NnNodeSynchronizer *synchronizer;
-    NnDevice *device;
     std::atomic_uint currentStepIndex;
     std::atomic_uint doneThreadCount;
     NnUint batchSize;
@@ -86,7 +93,7 @@ class NnExecutor {
     NnExecutorThread *threads;
     NnExecutorContext context;
 public:
-    NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark);
+    NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark);
     ~NnExecutor();
     void loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight);
     void forward();
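
One ownership detail worth calling out from the declaration above: NnExecutorDevice holds its NnDevice behind a std::unique_ptr, so each wrapper owns the raw pointer passed to its constructor and frees the device when the devices vector is destroyed. That is why the updated tests in this commit switch from stack-allocated devices to new-ed ones. The fragment below shows the resulting wiring pattern; it assumes netConfig, nodeConfig, execution, and synchronizer are already set up as in those tests, so it is a sketch rather than a complete program.

std::vector<NnExecutorDevice> devices;
// -1:-1 means the device handles all segments; the wrapper's unique_ptr owns it.
devices.push_back(NnExecutorDevice(new NnCpuDevice(&netConfig, &nodeConfig, &execution), -1, -1));
NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);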

src/nn/nn-vulkan-test.cpp

Lines changed: 5 additions & 3 deletions
@@ -36,11 +36,13 @@ void execute(
     NnNetExecution execution(1, &netConfig);
 
     NnUint gpuIndex = 0;
-    NnVulkanDevice device(gpuIndex, &netConfig, &nodeConfig, &execution);
+    std::vector<NnExecutorDevice> devices;
+    NnVulkanDevice *device = new NnVulkanDevice(gpuIndex, &netConfig, &nodeConfig, &execution);
+    devices.push_back(NnExecutorDevice(device, -1, -1));
     NnFakeNodeSynchronizer synchronizer;
-    NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer, false);
+    NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
 
-    execute(&executor, &execution, &device);
+    execute(&executor, &execution, device);
 }
 
 void testRmsNorm_F32_F32_F32() {
