Skip to content

Commit dc508b2

Browse files
authored
Merge pull request #6 from bnicenboim/dev
max length
2 parents aad2ef1 + 4fc110c commit dc508b2

File tree

11 files changed

+38
-145
lines changed

11 files changed

+38
-145
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Authors@R: c(
1010
person("Chris", "Emmerly", role = "ctb"),
1111
person("Giovanni", "Cassani", role = "ctb"))
1212
Description: Access to word predictability using large language (transformer) models.
13-
URL: <https://bruno.nicenboim.me/pangoling>, <https://github.com/bnicenboim/pangoling>
13+
URL: https://bruno.nicenboim.me/pangoling, https://github.com/bnicenboim/pangoling
1414
BugReports: https://github.com/bnicenboim/pangoling/issues
1515
License: MIT + file LICENSE
1616
Encoding: UTF-8

R/tr_utils.R

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,28 +220,27 @@ create_tensor_lst <- function(texts,
220220
!is.null(tkzr$special_tokens_map$eos_token)) {
221221
tkzr$pad_token <- tkzr$eos_token
222222
}
223-
max_length <- tkzr$model_max_length
224223
# If I run the following line, some models such as
225224
# 'flax-community/gpt-2-spanish' give a weird error of
226225
# 'GPT2TokenizerFast' object has no attribute 'is_fast'
227226
# max_length <- tkzr$model_max_length
228227
# thus the ugly hack
229-
## max_length <- chr_match(utils::capture.output(tkzr),
230-
## pattern = "model_max_len=([0-9]*)") |>
231-
## c() |>
232-
## (\(x) x[[2]])()
233-
if (is.null(max_length) || is.na(max_length) || max_length < 1) {
234-
message_verbose("Unknown maximum length of input. This might cause a problem for long inputs exceeding the maximum length.")
235-
max_length <- Inf
236-
}
228+
# max_length <- chr_match(utils::capture.output(tkzr),
229+
# pattern = "model_max_len=([0-9]*)") |>
230+
# c() |>
231+
# (\(x) x[[2]])()
232+
# if (is.null(max_length) || is.na(max_length) || max_length < 1) {
233+
# message_verbose("Unknown maximum length of input. This might cause a problem for long inputs exceeding the maximum length.")
234+
# max_length <- Inf
235+
# }
237236
lapply(texts, function(text) {
238237
tensor <- encode(text,
239238
tkzr,
240239
add_special_tokens = add_special_tokens,
241240
stride = as.integer(stride),
242-
truncation = is.finite(max_length),
243-
return_overflowing_tokens = is.finite(max_length),
244-
padding = is.finite(max_length)
241+
truncation = TRUE, #is.finite(max_length),
242+
return_overflowing_tokens = TRUE, #is.finite(max_length),
243+
padding = TRUE #is.finite(max_length)
245244
)
246245
tensor
247246
})

R/zzz.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ torch <- NULL
3030
lang_model <<- memoise::memoise(lang_model)
3131
transformer_vocab <<- memoise::memoise(transformer_vocab)
3232

33+
# avoid notes:
34+
utils::globalVariables(c("mask_n"))
35+
3336
invisible()
3437
}
3538

36-
## avoid notes:
37-
utils::globalVariables(c("mask_n"))

README.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ df_sent
7070

7171
> Nicenboim B (2023). _pangoling: Access to
7272
> language model predictions in R_. R package
73-
> version 0.0.0.9000,
73+
> version `r packageVersion("pangoling")`,
7474
> <https://github.com/bnicenboim/pangoling>.
7575
7676
## Code of conduct

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ df_sent
122122
## How to cite
123123

124124
> Nicenboim B (2023). *pangoling: Access to language model predictions
125-
> in R*. R package version 0.0.0.9000,
125+
> in R*. R package version 0.0.0.9001,
126126
> <https://github.com/bnicenboim/pangoling>.
127127
128128
## Code of conduct

man/pangoling-package.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# This file is part of the standard setup for testthat.
2+
# It is recommended that you do not modify it.
3+
#
4+
# Where should you do additional test configuration?
5+
# Learn more about the roles of various files in:
6+
# * https://r-pkgs.org/tests.html
7+
# * https://testthat.r-lib.org/reference/test_package.html#special-files
8+
19
library(testthat)
210
library(pangoling)
311

tests/testthat/_snaps/tr_causal.md

Lines changed: 0 additions & 66 deletions
This file was deleted.

tests/testthat/_snaps/tr_masked.md

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/testthat/test-tr_causal.R

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ test_that("gpt2 get prob work", {
2525
cont <-
2626
causal_next_tokens_tbl("The apple doesn't fall far from the")
2727
expect_equal(sum(exp(cont$lp)),1,tolerance = .0001)
28-
expect_snapshot(cont)
2928
expect_equal(cont[1]$token, "Ġtree")
3029
prov_words <- strsplit(prov, " ")[[1]]
3130
sent2_words <- strsplit(sent2, " ")[[1]]
@@ -36,7 +35,6 @@ test_that("gpt2 get prob work", {
3635
expect_equal(names(lp_sent2), sent2_words)
3736
lp_sent3 <- causal_lp(x = sent3_words)
3837
expect_equal(names(lp_sent3), sent3_words)
39-
expect_snapshot(lp_prov)
4038
expect_equal(cont$lp[1], unname(lp_prov[[8]]), tolerance = .0001)
4139
lp_prov_mat <- causal_lp_mats(x = prov_words)
4240
mat <- lp_prov_mat[[1]]
@@ -56,16 +54,13 @@ test_that("gpt2 get prob work", {
5654
expect_equal(rownames(lp_prov_mat[[1]]), transformer_vocab())
5755
expect_equal(sum(exp(mat[, 2])), 1, tolerance = .0001) # sums to one
5856

59-
lp_prov2 <-
60-
causal_lp(x = strsplit(paste0(prov, "."), " ")[[1]])
61-
expect_snapshot(lp_prov2)
6257
# regex
63-
lp_prov3 <-
58+
lp_prov2 <-
6459
causal_lp(
6560
x = strsplit(paste0(prov, "."), " ")[[1]],
6661
ignore_regex = "[[:punct:]]"
6762
)
68-
expect_equal(unname(lp_prov), unname(lp_prov3), tolerance = 0.001)
63+
expect_equal(unname(lp_prov), unname(lp_prov2), tolerance = 0.001)
6964

7065
##
7166
sent <- "This is it, is it?"
@@ -102,15 +97,14 @@ test_that("can handle extra parameters", {
10297
word_1_prob <- causal_next_tokens_tbl("<|endoftext|>")
10398
prob1 <- word_1_prob[token == "This"]$lp
10499
names(prob1) <- "This"
105-
expect_snapshot(probs)
106-
expect_equal(probs[1], prob1)
100+
expect_equal(probs[1], prob1, tolerance = 0.0001)
107101

108102
probs_F <- causal_lp(x = c("This", "is", "it"), add_special_tokens = FALSE)
109103
expect_true(is.na(probs_F[1]))
110104
word_2_prob <- causal_next_tokens_tbl("This")
111105
prob2 <- word_2_prob[token == "Ġis"]$lp
112106
names(prob2) <- "is"
113-
expect_equal(probs_F[2], prob2)
107+
expect_equal(probs_F[2], prob2, tolerance = .0001)
114108
})
115109

116110

@@ -129,18 +123,16 @@ if (0) {
129123
})
130124
}
131125

132-
test_that("other models using get prob work", {
126+
test_that("other models using get prob don't fail", {
133127
skip_if_no_python_stuff()
134128
tokenize("El bebé de cigüeña.", model = "flax-community/gpt-2-spanish")
135129

136-
expect_snapshot(
137-
causal_lp(x = c("El", "bebé", "de", "cigüeña."), model = "flax-community/gpt-2-spanish")
138-
)
130+
expect_no_error(causal_lp(x = c("El", "bebé", "de", "cigüeña."),
131+
model = "flax-community/gpt-2-spanish"))
139132

140-
lp_provd <-
133+
expect_no_error(
141134
causal_lp(
142135
x = strsplit(paste0(prov, "."), " ")[[1]],
143136
model = "distilgpt2"
144-
)
145-
expect_snapshot(lp_provd)
137+
))
146138
})

tests/testthat/test-tr_masked.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@ test_that("bert masked works", {
1717
masked_tokens_tbl("The apple doesn't fall far from the [MASK].",
1818
model = "google/bert_uncased_L-2_H-128_A-2"
1919
)
20-
21-
expect_snapshot(mask_1)
20+
expect_equal(colnames(mask_1),c("masked_sentence", "token", "lp", "mask_n"))
21+
expect_equal(sum(exp(mask_1$lp)),1, tolerance = 0.0001)
2222
mask_2 <-
2323
masked_tokens_tbl("The apple doesn't fall far from [MASK] [MASK].",
2424
model = "google/bert_uncased_L-2_H-128_A-2"
2525
)
26-
expect_snapshot(mask_2)
2726
mask_2_ <-
2827
masked_tokens_tbl(
2928
"[CLS] The apple doesn't fall far from [MASK] [MASK]. [SEP]",

0 commit comments

Comments
 (0)