Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def parts(length):
}

def confirm(message: str) -> bool:
    """Ask the user a yes/no question and report whether they agreed.

    When ``-y`` is present anywhere in ``sys.argv`` the prompt is skipped
    and the answer is assumed to be yes. Otherwise the user is prompted on
    stdin; ``Y`` or ``YES`` (case-insensitive, surrounding whitespace
    ignored) counts as agreement and anything else as refusal.

    Args:
        message: The question shown to the user in the prompt.

    Returns:
        True if the user (or the ``-y`` flag) confirmed, False otherwise.
    """
    if '-y' in sys.argv:
        return True
    try:
        answer = input(f'❓ {message} ("Y" if yes): ')
    except EOFError:
        # No interactive stdin (e.g. piped/closed input): treat as "no"
        # rather than aborting the whole script with a traceback.
        return False
    # Strip so that stray spaces around the answer do not turn "y" into "no".
    return answer.strip().upper() in ('Y', 'YES')

Expand Down Expand Up @@ -126,8 +129,10 @@ def printUsage():
print('Usage: python download-model.py <model>')
print()
print('Options:')
print(' <model> The name of the model to download')
print(' --run Run the model after download')
print(' <model> The name of the model to download')
print(' -skip-run Do not run the model after download')
print(' -skip-script Do not create a script to run the model')
print(' -y Skip confirmation prompts')
print()
print('Available models:')
for model in MODELS:
Expand All @@ -144,7 +149,6 @@ def printUsage():
if modelName not in MODELS:
print(f'Model is not supported: {modelName}')
exit(1)
runAfterDownload = sys.argv.count('--run') > 0

model = MODELS[modelName]
(modelPath, tokenizerPath) = download(modelName, model)
Expand All @@ -165,12 +169,15 @@ def printUsage():
print()
print('--- copy end -----')

runFilePath = writeRunFile(modelName, command)
print(f'🌻 Created {runFilePath} script to easy run')
skipRun = sys.argv.count('-skip-run') > 0
skipScript = sys.argv.count('-skip-script') > 0

if (not runAfterDownload):
runAfterDownload = confirm('Do you want to run Distributed Llama?')
if (runAfterDownload):
if (not os.path.isfile('dllama')):
os.system('make dllama')
os.system(command)
if (not skipScript):
runFilePath = writeRunFile(modelName, command)
print(f'🌻 Created {runFilePath} script to easy run')

if (not skipRun):
if (confirm('Do you want to run Distributed Llama?')):
if (not os.path.isfile('dllama')):
os.system('make dllama')
os.system(command)
11 changes: 7 additions & 4 deletions src/nn/vulkan/cast-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ void main() {
const uint workGroupIndex = gl_WorkGroupID.z;

const BatchInfo info = infos[batchIndex];
sharedDim = info.inputSizeX / nWorkGroups;
const uint dimOffset = sharedDim * workGroupIndex;
sharedXOffset = info.inputOffset + dimOffset;
sharedYOffset = info.outputOffset + dimOffset;

const uint slice = info.inputSizeX / nWorkGroups;
const uint rest = info.inputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedXOffset = info.inputOffset + offset;
sharedYOffset = info.outputOffset + offset;
}

barrier();
Expand Down
8 changes: 4 additions & 4 deletions src/nn/vulkan/cast-forward-f32-q80.comp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ void main() {

const BatchInfo info = infos[batchIndex];

const uint ySlice = info.outputSizeX / nWorkGroups;
const uint yRest = info.outputSizeX % nWorkGroups;
sharedYStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
sharedYEnd = sharedYStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
const uint slice = info.outputSizeX / nWorkGroups;
const uint rest = info.outputSizeX % nWorkGroups;
sharedYStart = workGroupIndex * slice + min(rest, workGroupIndex);
sharedYEnd = sharedYStart + slice + (workGroupIndex < rest ? 1 : 0);
sharedXOffset = info.inputOffset;
sharedYOffset = info.outputOffset;
}
Expand Down
2 changes: 1 addition & 1 deletion src/nn/vulkan/inv-rms-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void main() {
const uint offset = sharedInfo.inputOffset;
const uint slice = colDim / N_THREADS;
const uint rest = colDim % N_THREADS;
const uint start = offset + threadIndex * slice + (threadIndex < rest ? threadIndex : rest);
const uint start = offset + threadIndex * slice + min(rest, threadIndex);
const uint end = start + slice + (threadIndex < rest ? 1 : 0);

for (uint col = 0; col < nColumns; col++) {
Expand Down
10 changes: 7 additions & 3 deletions src/nn/vulkan/merge-add-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ void main() {
const uint workGroupIndex = gl_WorkGroupID.z;

const BatchInfo info = infos[batchIndex];
sharedDim = info.outputSizeX / nWorkGroups;
const uint slice = info.outputSizeX / nWorkGroups;
const uint rest = info.outputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);

sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedOutputSizeX = info.outputSizeX;
sharedParts = info.inputSizeX / info.outputSizeX;
sharedXOffset = info.inputOffset + sharedDim * workGroupIndex;
sharedYOffset = info.outputOffset + sharedDim * workGroupIndex;
sharedXOffset = info.inputOffset + offset;
sharedYOffset = info.outputOffset + offset;
}

barrier();
Expand Down
2 changes: 1 addition & 1 deletion src/nn/vulkan/merge-add-forward-q80-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void main() {
const uint xSlice = xJump / nWorkGroups;
const uint xRest = xJump % nWorkGroups;

sharedXStart = workGroupIndex * xSlice + (workGroupIndex < xRest ? workGroupIndex : xRest);
sharedXStart = workGroupIndex * xSlice + min(xRest, workGroupIndex);
sharedXEnd = sharedXStart + xSlice + (workGroupIndex < xRest ? 1 : 0);
sharedNParts = nParts;
sharedXJump = xJump;
Expand Down
10 changes: 7 additions & 3 deletions src/nn/vulkan/mul-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ void main() {
const uint workGroupIndex = gl_WorkGroupID.z;

const BatchInfo info = infos[batchIndex];
sharedDim = info.inputSizeX / nWorkGroups;
sharedXyOffset = info.inputOffset + sharedDim * workGroupIndex;
sharedMOffset = info.inputSizeX * batchIndex + sharedDim * workGroupIndex;
const uint slice = info.inputSizeX / nWorkGroups;
const uint rest = info.inputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);

sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedXyOffset = info.inputOffset + offset;
sharedMOffset = info.inputSizeX * batchIndex + offset;
}

barrier();
Expand Down
7 changes: 5 additions & 2 deletions src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ void main() {
const uint workGroupIndex = gl_WorkGroupID.z;

const BatchInfo info = infos[batchIndex];
sharedDim = info.inputSizeX / nWorkGroups;
sharedDimOffset = sharedDim * workGroupIndex;
const uint slice = info.inputSizeX / nWorkGroups;
const uint rest = info.inputSizeX % nWorkGroups;

sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedDimOffset = slice * workGroupIndex + min(rest, workGroupIndex);
sharedColDim = info.inputSizeX / nColumns;
sharedXOffset = info.inputOffset + sharedDimOffset;
sharedYOffset = info.outputOffset + sharedDimOffset;
Expand Down
13 changes: 8 additions & 5 deletions src/nn/vulkan/shift-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ void main() {
const uint nWorkGroups = gl_NumWorkGroups.z;
const uint batchIndex = gl_WorkGroupID.y;
const uint workGroupIndex = gl_WorkGroupID.z;

const uint index = uint(indexes[batchIndex]);

BatchInfo info = infos[batchIndex];
sharedDim = info.inputSizeX / nWorkGroups;
const uint dimOffset = sharedDim * workGroupIndex;
sharedXOffset = info.inputOffset + dimOffset;
sharedYOffset = index * info.inputSizeX + dimOffset;
const uint slice = info.inputSizeX / nWorkGroups;
const uint rest = info.inputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);

sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedXOffset = info.inputOffset + offset;
sharedYOffset = index * info.inputSizeX + offset;
}

barrier();
Expand Down
10 changes: 7 additions & 3 deletions src/nn/vulkan/silu-forward-f32-f32.comp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ void main() {
const uint workGroupIndex = gl_WorkGroupID.z;

const BatchInfo info = infos[batchIndex];
sharedDim = info.inputSizeX / nWorkGroups;
sharedXOffset = info.inputOffset + sharedDim * workGroupIndex;
sharedYOffset = info.outputOffset + sharedDim * workGroupIndex;
const uint slice = info.inputSizeX / nWorkGroups;
const uint rest = info.inputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);

sharedDim = slice + (workGroupIndex < rest ? 1 : 0);
sharedXOffset = info.inputOffset + offset;
sharedYOffset = info.outputOffset + offset;
}

barrier();
Expand Down
Loading