
Commit 3bc311e

TensorFlow Datasets Team authored and copybara-github committed
Add some extra flags to TokenTextEncoder for increased flexibility.
PiperOrigin-RevId: 258428299
1 parent b9c4488 · commit 3bc311e

1 file changed: +13 -3 lines changed

tensorflow_datasets/core/features/text/text_encoder.py

Lines changed: 13 additions & 3 deletions
@@ -230,7 +230,9 @@ def __init__(self,
                oov_buckets=1,
                oov_token="UNK",
                lowercase=False,
-               tokenizer=None):
+               tokenizer=None,
+               strip_vocab=True,
+               decode_token_separator=" "):
     """Constructs a TokenTextEncoder.
 
     To load from a file saved with `TokenTextEncoder.save_to_file`, use
@@ -244,8 +246,14 @@ def __init__(self,
       lowercase: `bool`, whether to make all text and tokens lowercase.
       tokenizer: `Tokenizer`, responsible for converting incoming text into a
         list of tokens.
+      strip_vocab: `bool`, whether to strip whitespace from the beginning and
+        end of elements of `vocab_list`.
+      decode_token_separator: `str`, the string used to separate tokens when
+        decoding.
     """
-    self._vocab_list = [tf.compat.as_text(el).strip() for el in vocab_list]
+    self._vocab_list = [tf.compat.as_text(el) for el in vocab_list]
+    if strip_vocab:
+      self._vocab_list = [el.strip() for el in self._vocab_list]
     self._lowercase = lowercase
     if self._lowercase:
       self._vocab_list = [t.lower() for t in self._vocab_list]
@@ -261,6 +269,8 @@ def __init__(self,
     self._tokenizer = (tokenizer or Tokenizer(reserved_tokens=reserved_tokens))
     self._user_defined_tokenizer = tokenizer
 
+    self._decode_token_separator = decode_token_separator
+
   def encode(self, s):
     s = tf.compat.as_text(s)
     if self.lowercase:
@@ -286,7 +296,7 @@ def decode(self, ids):
         tokens.append(self._vocab_list[int_id])
       else:
         tokens.append(self._oov_token)
-    return " ".join(tokens)
+    return self._decode_token_separator.join(tokens)
 
   @property
   def vocab_size(self):
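For context, here is a minimal sketch of how the two new flags could be used, assuming the class is exposed as `tfds.features.text.TokenTextEncoder` (as in TFDS at the time of this commit); the vocabulary and input string are made up for illustration:

import tensorflow_datasets as tfds

# Toy vocabulary; the entries deliberately carry stray whitespace.
vocab = [" hello", "world "]

# strip_vocab=True (the default) trims that whitespace, so the entries match
# the tokens "hello" and "world" produced by the default Tokenizer.
# decode_token_separator replaces the hard-coded " " previously used by decode().
encoder = tfds.features.text.TokenTextEncoder(
    vocab_list=vocab,
    strip_vocab=True,
    decode_token_separator="_")

ids = encoder.encode("hello world")
print(encoder.decode(ids))  # -> "hello_world"

Passing strip_vocab=False instead keeps vocabulary entries exactly as given, which matters for vocabularies where leading or trailing whitespace is significant; the default True preserves the old always-strip behavior.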
