
Commit 322169a

Update models API
1 parent 1807f98 commit 322169a

6 files changed: +11 -11 lines changed
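
The commit threads caller-supplied keyword arguments from each model wrapper's constructor through to BaseLM instead of dropping them. A minimal sketch of the pattern, assuming a BaseLM that simply stores whatever extra options it receives (the stored fields below are illustrative, not the repository's actual BaseLM):

# Sketch only: BaseLM's real constructor may accept different options;
# the **kwargs forwarding is the part taken from this commit.
class BaseLM:
    def __init__(self, name=None, **kwargs):
        self.name = name
        self.extra_options = kwargs  # assumed: keep unrecognized options for base-class use


class FlanT5(BaseLM):
    def __init__(self, model_name="google/flan-t5-base", temp=0.1, device='cuda',
                 max_length=None, use_bf16=False, **kwargs):
        # After this commit, unknown keyword arguments reach BaseLM.__init__
        # rather than being silently discarded by the subclass.
        super(FlanT5, self).__init__(name=model_name, **kwargs)
        self.__device = device
        self.__max_length = 512 if max_length is None else max_length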

ext/flan_t5.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 class FlanT5(BaseLM):
 
     def __init__(self, model_name="google/flan-t5-base", temp=0.1, device='cuda', max_length=None, use_bf16=False, **kwargs):
-        super(FlanT5, self).__init__(name=model_name)
+        super(FlanT5, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 512 if max_length is None else max_length
         self.__model = T5ForConditionalGeneration.from_pretrained(

ext/gemma.py

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@
 class Gemma(BaseLM):
 
     def __init__(self, model_name="google/gemma-7b-it", temp=0.1, device='cuda',
-                 max_length=None, api_token=None, use_bf16=False):
-        super(Gemma, self).__init__(name=model_name)
+                 max_length=None, api_token=None, use_bf16=False, **kwargs):
+        super(Gemma, self).__init__(name=model_name, **kwargs)
         self.__device = device
         self.__max_length = 1024 if max_length is None else max_length
         self.__model = AutoModelForCausalLM.from_pretrained(

ext/llama32.py

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@
 class Llama32(BaseLM):
 
     def __init__(self, model_name="meta-llama/Llama-3.2-3B-Instruct", api_token=None,
-                 temp=0.1, device='cuda', max_length=256, use_bf16=False):
-        super(Llama32, self).__init__(name=model_name)
+                 temp=0.1, device='cuda', max_length=256, use_bf16=False, **kwargs):
+        super(Llama32, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")

ext/microsoft_phi_2.py

Lines changed: 2 additions & 2 deletions

@@ -11,8 +11,8 @@ class MicrosoftPhi2(BaseLM):
     """ https://huggingface.co/microsoft/phi-2
     """
 
-    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False):
-        super(MicrosoftPhi2, self).__init__(model_name)
+    def __init__(self, model_name="microsoft/phi-2", device='cuda', max_length=None, use_bf16=False, **kwargs):
+        super(MicrosoftPhi2, self).__init__(model_name, **kwargs)
 
         # Default parameters.
         kwargs = {"device_map": device, "trust_remote_code": True}

ext/mistral.py

Lines changed: 2 additions & 2 deletions

@@ -7,9 +7,9 @@
 class Mistral(BaseLM):
 
     def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.1", temp=0.1, device='cuda', max_length=None,
-                 use_bf16=False):
+                 use_bf16=False, **kwargs):
         assert(isinstance(max_length, int) or max_length is None)
-        super(Mistral, self).__init__(name=model_name)
+        super(Mistral, self).__init__(name=model_name, **kwargs)
 
         if use_bf16:
             print("Warning: Experimental mode with bf-16!")

ext/openai_gpt.py

Lines changed: 2 additions & 2 deletions

@@ -5,9 +5,9 @@
 class OpenAIGPT(BaseLM):
 
     def __init__(self, api_key, model_name="gpt-4-1106-preview", temp=0.1, max_tokens=None, assistant_prompt=None,
-                 freq_penalty=0.0, kwargs=None):
+                 freq_penalty=0.0, **kwargs):
         assert(isinstance(assistant_prompt, str) or assistant_prompt is None)
-        super(OpenAIGPT, self).__init__(name=model_name)
+        super(OpenAIGPT, self).__init__(name=model_name, **kwargs)
 
         # dynamic import of the OpenAI library.
         OpenAI = auto_import("openai._client.OpenAI", is_class=False)