Skip to content

Commit c4df151

Browse files
committed
experimental swa flag
1 parent 499283c commit c4df151

File tree

3 files changed

+11 additions, -1 deletion (12 lines changed)

expose.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ struct load_model_inputs
7070
const int quant_k = 0;
7171
const int quant_v = 0;
7272
const bool check_slowness = false;
73+
const bool swa_support = false;
7374
const bool quiet = false;
7475
const int debugmode = 0;
7576
};

gpttype_adapter.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1927,7 +1927,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
19271927
kcpp_data->use_smartcontext = inputs.use_smartcontext;
19281928
kcpp_data->use_contextshift = inputs.use_contextshift;
19291929
kcpp_data->use_fastforward = inputs.use_fastforward;
1930-
kcpp_data->swa_full = (inputs.use_fastforward || inputs.use_contextshift)?true:false;
1930+
kcpp_data->swa_full = !inputs.swa_support;//(inputs.use_fastforward || inputs.use_contextshift)?true:false;
1931+
if(!kcpp_data->swa_full)
1932+
{
1933+
printf("\n!!!!!!!!!!!!!!!!!!!\nExperimental FLAG - SWA SUPPORT IS ENABLED!\n!!!!!!!!!!!!!!!!!!!\n");
1934+
}
19311935
debugmode = inputs.debugmode;
19321936
draft_ctx = nullptr;
19331937
guidance_ctx = nullptr;

koboldcpp.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,7 @@ class load_model_inputs(ctypes.Structure):
192192
("quant_k", ctypes.c_int),
193193
("quant_v", ctypes.c_int),
194194
("check_slowness", ctypes.c_bool),
195+
("swa_support", ctypes.c_bool),
195196
("quiet", ctypes.c_bool),
196197
("debugmode", ctypes.c_int)]
197198

@@ -1248,6 +1249,7 @@ def load_model(model_filename):
12481249
inputs.override_kv = args.overridekv.encode("UTF-8") if args.overridekv else "".encode("UTF-8")
12491250
inputs.override_tensors = args.overridetensors.encode("UTF-8") if args.overridetensors else "".encode("UTF-8")
12501251
inputs.check_slowness = (not args.highpriority and os.name == 'nt' and 'Intel' in platform.processor())
1252+
inputs.swa_support = args.experiment_swa
12511253
inputs = set_backend_props(inputs)
12521254
ret = handle.load_model(inputs)
12531255
return ret
@@ -6907,6 +6909,9 @@ def range_checker(arg: str):
69076909
admingroup.add_argument("--adminpassword", metavar=('[password]'), help="Require a password to access admin functions. You are strongly advised to use one for publically accessible instances!", default=None)
69086910
admingroup.add_argument("--admindir", metavar=('[directory]'), help="Specify a directory to look for .kcpps configs in, which can be used to swap models.", default="")
69096911

6912+
experimentgroup = parser.add_argument_group('Experimental Commands, can change or break any time!')
6913+
experimentgroup.add_argument("--experiment_swa", help="Enables SWA mode. There are no safety checks.", action='store_true')
6914+
69106915
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
69116916
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
69126917
deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')

0 commit comments

Comments
 (0)