
Commit b8d7aa6

cyyevergante authored and committed
Fix Optional type annotation (huggingface#36841)
* Fix annotation
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
1 parent 15b8768 commit b8d7aa6

21 files changed (+65, -57 lines)
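Every change in this commit follows the same pattern: a parameter or field whose default is `None` but whose annotation claims a non-Optional type (e.g. `int = None`) is rewritten as `Optional[...] = None`, so the annotation matches the actual contract. PEP 484 no longer endorses treating a bare `None` default as implicitly Optional, and strict type checkers (for example mypy with implicit Optional disabled, its default in recent releases) flag the old form. A minimal before/after sketch, using a hypothetical helper that is not part of the diff:

from typing import Optional

# Before (implicit Optional): the annotation says `int`, but the default is None,
# so the declared type and the actual default disagree.
#     def truncate(text: str, max_len: int = None) -> str: ...

# After: Optional[int] states the "may be None" contract explicitly.
def truncate(text: str, max_len: Optional[int] = None) -> str:
    # `truncate` is a hypothetical example, not code from this commit.
    if max_len is None:
        return text
    return text[:max_len]

print(truncate("hello world"))     # hello world
print(truncate("hello world", 5))  # hello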

src/transformers/agents/agents.py

Lines changed: 2 additions & 2 deletions
@@ -217,7 +217,7 @@ def tools(self) -> Dict[str, Tool]:
        """Get all tools currently in the toolbox"""
        return self._tools

-    def show_tool_descriptions(self, tool_description_template: str = None) -> str:
+    def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
        """
        Returns the description of all tools in the toolbox

@@ -891,7 +891,7 @@ def direct_run(self, task: str):

        return final_answer

-    def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
+    def planning_step(self, task, is_first_step: bool = False, iteration: Optional[int] = None):
        """
        Used periodically by the agent to plan the next steps to reach the objective.

src/transformers/audio_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1125,7 +1125,7 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int =
    return frames


-def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
+def stft(frames: np.array, windowing_function: np.array, fft_window_size: Optional[int] = None):
    """
    Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
    as `torch.stft`.

src/transformers/cache_utils.py

Lines changed: 7 additions & 7 deletions
@@ -1183,8 +1183,8 @@ class StaticCache(Cache):
    def __init__(
        self,
        config: PretrainedConfig,
-        batch_size: int = None,
-        max_cache_len: int = None,
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
        device: torch.device = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,

@@ -1367,8 +1367,8 @@ class SlidingWindowCache(StaticCache):
    def __init__(
        self,
        config: PretrainedConfig,
-        batch_size: int = None,
-        max_cache_len: int = None,
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
        device: torch.device = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,

@@ -1674,8 +1674,8 @@ class HybridCache(Cache):
    def __init__(
        self,
        config: PretrainedConfig,
-        batch_size: int = None,
-        max_cache_len: int = None,
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
        device: Union[torch.device, str] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,

@@ -1877,7 +1877,7 @@ class MambaCache:
    def __init__(
        self,
        config: PretrainedConfig,
-        batch_size: int = None,
+        batch_size: Optional[int] = None,
        dtype: torch.dtype = torch.float16,
        device: Optional[Union[torch.device, str]] = None,
        max_batch_size: Optional[int] = None,

src/transformers/data/processors/squad.py

Lines changed: 3 additions & 2 deletions
@@ -16,6 +16,7 @@
import os
from functools import partial
from multiprocessing import Pool, cpu_count
+from typing import Optional

import numpy as np
from tqdm import tqdm

@@ -800,8 +801,8 @@ def __init__(
        start_position,
        end_position,
        is_impossible,
-        qas_id: str = None,
-        encoding: BatchEncoding = None,
+        qas_id: Optional[str] = None,
+        encoding: Optional[BatchEncoding] = None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask

src/transformers/generation/candidate_generator.py

Lines changed: 2 additions & 2 deletions
@@ -914,9 +914,9 @@ class PromptLookupCandidateGenerator(CandidateGenerator):

    def __init__(
        self,
-        eos_token_id: torch.Tensor = None,
+        eos_token_id: Optional[torch.Tensor] = None,
        num_output_tokens: int = 10,
-        max_matching_ngram_size: int = None,
+        max_matching_ngram_size: Optional[int] = None,
        max_length: int = 20,
    ):
        self.num_output_tokens = num_output_tokens

src/transformers/generation/flax_utils.py

Lines changed: 5 additions & 3 deletions
@@ -171,8 +171,8 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode
    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
-        decoder_start_token_id: int = None,
-        bos_token_id: int = None,
+        decoder_start_token_id: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ) -> jnp.ndarray:
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:

@@ -183,7 +183,9 @@ def _prepare_decoder_input_ids_for_generation(
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)

-    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
+    def _get_decoder_start_token_id(
+        self, decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None
+    ) -> int:
        # retrieve decoder_start_token_id for encoder-decoder models
        # fall back to bos_token_id if necessary
        decoder_start_token_id = (

src/transformers/generation/tf_utils.py

Lines changed: 5 additions & 3 deletions
@@ -1077,8 +1077,8 @@ def _prepare_decoder_input_ids_for_generation(
        batch_size: int,
        model_input_name: str,
        model_kwargs: Dict[str, tf.Tensor],
-        decoder_start_token_id: int = None,
-        bos_token_id: int = None,
+        decoder_start_token_id: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
    ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,

@@ -1111,7 +1111,9 @@ def _prepare_decoder_input_ids_for_generation(

        return decoder_input_ids, model_kwargs

-    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
+    def _get_decoder_start_token_id(
+        self, decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None
+    ) -> int:
        # retrieve decoder_start_token_id for encoder-decoder models
        # fall back to bos_token_id if necessary
        decoder_start_token_id = (

src/transformers/generation/utils.py

Lines changed: 7 additions & 7 deletions
@@ -157,7 +157,7 @@ class GenerateDecoderOnlyOutput(ModelOutput):
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

-    sequences: torch.LongTensor = None
+    sequences: torch.LongTensor
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

@@ -202,7 +202,7 @@ class GenerateEncoderDecoderOutput(ModelOutput):
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

-    sequences: torch.LongTensor = None
+    sequences: torch.LongTensor
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

@@ -247,7 +247,7 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

-    sequences: torch.LongTensor = None
+    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None

@@ -301,7 +301,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

-    sequences: torch.LongTensor = None
+    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None

@@ -699,7 +699,7 @@ def _prepare_decoder_input_ids_for_generation(
        model_input_name: str,
        model_kwargs: Dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
-        device: torch.device = None,
+        device: Optional[torch.device] = None,
    ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,

@@ -923,7 +923,7 @@ def _get_logits_processor(
        encoder_input_ids: torch.LongTensor,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        logits_processor: Optional[LogitsProcessorList],
-        device: str = None,
+        device: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,

@@ -4833,7 +4833,7 @@ def _ranking_fast(
    return selected_idx


-def _split(data, full_batch_size: int, split_size: int = None):
+def _split(data, full_batch_size: int, split_size: int):
    """
    Takes care of three cases:
    1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
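Two changes in this file differ from the plain Optional rewrites: the `sequences` field of each `Generate*Output` class drops its `= None` default, so the bare `torch.LongTensor` annotation now reflects a field that is always provided, and `_split` takes `split_size` as a required argument instead of defaulting it to `None`. A minimal sketch of the required-vs-optional dataclass field distinction, with hypothetical field types standing in for the real ones:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeGenerateOutput:
    # Required field: no default, so callers must always supply it,
    # and the annotation does not need Optional.
    sequences: List[List[int]]
    # Genuinely optional fields keep Optional[...] with a None default.
    scores: Optional[List[float]] = None

out = FakeGenerateOutput(sequences=[[1, 2, 3]])
print(out.sequences, out.scores)  # [[1, 2, 3]] None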

src/transformers/generation/watermarking.py

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ class BayesianDetectorConfig(PretrainedConfig):
            Prior probability P(w) that a text is watermarked.
    """

-    def __init__(self, watermarking_depth: int = None, base_rate: float = 0.5, **kwargs):
+    def __init__(self, watermarking_depth: Optional[int] = None, base_rate: float = 0.5, **kwargs):
        self.watermarking_depth = watermarking_depth
        self.base_rate = base_rate
        # These can be set later to store information about this detector.

src/transformers/hf_argparser.py

Lines changed: 3 additions & 3 deletions
@@ -63,11 +63,11 @@ def make_choice_type_function(choices: list) -> Callable[[str], Any]:

def HfArg(
    *,
-    aliases: Union[str, list[str]] = None,
-    help: str = None,
+    aliases: Optional[Union[str, list[str]]] = None,
+    help: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
-    metadata: dict = None,
+    metadata: Optional[dict] = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
