@@ -43,6 +43,9 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
43
43
args.maxSeqLen = 0 ;
44
44
args.netTurbo = true ;
45
45
args.gpuIndex = -1 ;
46
+ args.gpuSegmentFrom = -1 ;
47
+ args.gpuSegmentTo = -1 ;
48
+
46
49
int i = 1 ;
47
50
if (requireMode && argc > 1 ) {
48
51
args.mode = argv[1 ];
@@ -79,15 +82,15 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
79
82
80
83
for (int s = 0 ; s < count; s++) {
81
84
char *v = argv[i + 1 + s];
82
- char *sep = std::strstr (v, " :" );
83
- if (sep == NULL ) {
85
+ char *separator = std::strstr (v, " :" );
86
+ if (separator == NULL ) {
84
87
throw std::runtime_error (" Invalid worker address: " + std::string (v));
85
88
}
86
- int hostLen = sep - v;
89
+ int hostLen = separator - v;
87
90
args.workerHosts [s] = new char [hostLen + 1 ];
88
91
std::memcpy (args.workerHosts [s], v, hostLen);
89
92
args.workerHosts [s][hostLen] = ' \0 ' ;
90
- args.workerPorts [s] = std::atoi (sep + 1 );
93
+ args.workerPorts [s] = std::atoi (separator + 1 );
91
94
}
92
95
93
96
i += count - 1 ;
@@ -109,6 +112,12 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
109
112
args.maxSeqLen = (unsigned int )atoi (value);
110
113
} else if (std::strcmp (name, " --gpu-index" ) == 0 ) {
111
114
args.gpuIndex = atoi (value);
115
+ } else if (std::strcmp (name, " --gpu-segments" ) == 0 ) {
116
+ char *separator = std::strstr (value, " :" );
117
+ if (separator == NULL )
118
+ throw std::runtime_error (" GPU segments expected in the format <from>:<to>" );
119
+ args.gpuSegmentFrom = atoi (value);
120
+ args.gpuSegmentTo = atoi (separator + 1 );
112
121
} else if (std::strcmp (name, " --net-turbo" ) == 0 ) {
113
122
args.netTurbo = atoi (value) == 1 ;
114
123
} else {
@@ -128,23 +137,32 @@ AppCliArgs::~AppCliArgs() {
128
137
delete[] workerPorts;
129
138
}
130
139
131
- static NnDevice *createDevice (AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
140
+ static std::vector<NnExecutorDevice> resolveDevices (AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
141
+ std::vector<NnExecutorDevice> devices;
142
+
132
143
if (args->gpuIndex >= 0 ) {
133
144
#if defined(DLLAMA_VULKAN)
134
- return new NnVulkanDevice (args->gpuIndex , netConfig, nodeConfig, netExecution);
145
+ devices.push_back (NnExecutorDevice (
146
+ new NnVulkanDevice (args->gpuIndex , netConfig, nodeConfig, netExecution),
147
+ args->gpuSegmentFrom ,
148
+ args->gpuSegmentTo
149
+ ));
135
150
#else
136
151
throw std::runtime_error (" This build does not support GPU" );
137
152
#endif
138
153
}
139
- return new NnCpuDevice (netConfig, nodeConfig, netExecution);
154
+
155
+ if (args->gpuIndex < 0 || (args->gpuSegmentFrom >= 0 && args->gpuSegmentTo >= 0 )) {
156
+ devices.push_back (NnExecutorDevice (new NnCpuDevice (netConfig, nodeConfig, netExecution), -1 , -1 ));
157
+ }
158
+ return devices;
140
159
}
141
160
142
- RootLlmInference::RootLlmInference (LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
161
+ RootLlmInference::RootLlmInference (LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
143
162
this ->header = net->header ;
144
163
this ->tokenPipe = (float *)execution->pipes [net->tokenPipeIndex ];
145
164
this ->positionPipe = (float *)execution->pipes [net->positionPipeIndex ];
146
165
this ->logitsPipe = (float *)execution->pipes [net->logitsPipeIndex ];
147
- this ->device = device;
148
166
this ->execution = execution;
149
167
this ->executor = executor;
150
168
this ->network = network; // May be nullptr!
@@ -245,13 +263,13 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
245
263
configWriter.writeToWorkers (&net.netConfig , net.nodeConfigs );
246
264
}
247
265
248
- std::unique_ptr<NnDevice> device ( createDevice ( args, &net.netConfig , rootNodeConfig, &execution) );
249
- NnExecutor executor (&net.netConfig , rootNodeConfig, device. get () , &execution, synchronizer.get (), args->benchmark );
266
+ std::vector<NnExecutorDevice> devices = resolveDevices ( args, &net.netConfig , rootNodeConfig, &execution);
267
+ NnExecutor executor (&net.netConfig , rootNodeConfig, &devices , &execution, synchronizer.get (), args->benchmark );
250
268
251
269
NnRootWeightLoader weightLoader (&executor, network, nNodes);
252
270
loadLlmNetWeight (args->modelPath , &net, &weightLoader);
253
271
254
- RootLlmInference inference (&net, device. get (), &execution, &executor, network);
272
+ RootLlmInference inference (&net, &execution, &executor, network);
255
273
256
274
if (network != nullptr ) {
257
275
network->resetStats ();
@@ -290,10 +308,9 @@ void runWorkerApp(AppCliArgs *args) {
290
308
291
309
NnNetExecution execution (args->nThreads , &netConfig);
292
310
293
- std::unique_ptr<NnDevice> device (createDevice (args, &netConfig, &nodeConfig, &execution));
294
-
311
+ std::vector<NnExecutorDevice> devices = resolveDevices (args, &netConfig, &nodeConfig, &execution);
295
312
NnNetworkNodeSynchronizer synchronizer (network, &execution, &netConfig, &nodeConfig);
296
- NnExecutor executor (&netConfig, &nodeConfig, device. get () , &execution, &synchronizer, false );
313
+ NnExecutor executor (&netConfig, &nodeConfig, &devices , &execution, &synchronizer, false );
297
314
298
315
NnWorkerWeightReader weightReader (&executor, network);
299
316
weightReader.read ();
0 commit comments