Skip to content

Commit bb410f4

Browse files
authored
Merge pull request #284 from alan-turing-institute/fix_single_logical_bug
Fix_single_logical_bug
2 parents c727639 + 866aa95 commit bb410f4

File tree

4 files changed

+33
-15
lines changed

4 files changed

+33
-15
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: distr6
22
Title: The Complete R6 Probability Distributions Interface
3-
Version: 1.6.13
3+
Version: 1.6.14
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# distr6 1.6.14
2+
3+
* Fix bug when extracting a single distribution with a logical vector from `MatDist`
4+
15
# distr6 1.6.13
26

37
* Fix reordering bug when extracting vector distributions

R/SDistribution_Matdist.R

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Matdist <- R6Class("Matdist",
7272
support = Set$new(1, class = "numeric")^"n",
7373
type = Reals$new()^"n"
7474
)
75+
private$.ndists <- nrow(gprm(self, "pdf"))
7576
invisible(self)
7677
},
7778

@@ -80,7 +81,7 @@ Matdist <- R6Class("Matdist",
8081
#' @param n `(integer(1))` \cr
8182
#' Ignored.
8283
strprint = function(n = 2) {
83-
"Matdist()"
84+
sprintf("Matdist(%s)", private$.ndists)
8485
},
8586

8687
# stats
@@ -128,7 +129,7 @@ Matdist <- R6Class("Matdist",
128129
"*" %=% gprm(self, c("x", "pdf"))
129130
mean <- self$mean()
130131

131-
vnapply(seq(nrow(pdf)), function(i) {
132+
vnapply(seq_len(private$.ndists), function(i) {
132133
if (mean[[i]] == Inf) {
133134
Inf
134135
} else {
@@ -149,7 +150,7 @@ Matdist <- R6Class("Matdist",
149150
mean <- self$mean()
150151
sd <- self$stdev()
151152

152-
vnapply(seq(nrow(pdf)), function(i) {
153+
vnapply(seq_len(private$.ndists), function(i) {
153154
if (mean[[i]] == Inf) {
154155
Inf
155156
} else {
@@ -171,7 +172,7 @@ Matdist <- R6Class("Matdist",
171172
mean <- self$mean()
172173
sd <- self$stdev()
173174

174-
kurt <- vnapply(seq(nrow(pdf)), function(i) {
175+
kurt <- vnapply(seq_len(private$.ndists), function(i) {
175176
if (mean[[i]] == Inf) {
176177
Inf
177178
} else {
@@ -209,8 +210,8 @@ Matdist <- R6Class("Matdist",
209210
if (length(t) == 1) {
210211
mgf <- apply(pdf, 1, function(.y) sum(exp(x * t) * .y))
211212
} else {
212-
stopifnot(length(z) == nrow(pdf))
213-
mgf <- vnapply(seq(nrow(pdf)),
213+
stopifnot(length(z) == private$.ndists)
214+
mgf <- vnapply(seq_len(private$.ndists),
214215
function(i) sum(exp(x * t[[i]]) * pdf[i, ]))
215216
}
216217

@@ -228,8 +229,8 @@ Matdist <- R6Class("Matdist",
228229
if (length(t) == 1) {
229230
cf <- apply(pdf, 1, function(.y) sum(exp(x * t * 1i) * .y))
230231
} else {
231-
stopifnot(length(z) == nrow(pdf))
232-
cf <- vnapply(seq(nrow(pdf)),
232+
stopifnot(length(z) == private$.ndists)
233+
cf <- vnapply(seq_len(private$.ndists),
233234
function(i) sum(exp(x * t[[i]] * 1i) * pdf[i, ]))
234235
}
235236

@@ -247,8 +248,8 @@ Matdist <- R6Class("Matdist",
247248
if (length(z) == 1) {
248249
pgf <- apply(pdf, 1, function(.y) sum((z^x) * .y))
249250
} else {
250-
stopifnot(length(z) == nrow(pdf))
251-
pgf <- vnapply(seq(nrow(pdf)),
251+
stopifnot(length(z) == private$.ndists)
252+
pgf <- vnapply(seq_len(private$.ndists),
252253
function(i) sum((z[[i]]^x) * pdf[i, ]))
253254
}
254255

@@ -271,7 +272,7 @@ Matdist <- R6Class("Matdist",
271272
.pdf = function(x, log = FALSE) {
272273
"pdf, data" %=% gprm(self, c("pdf", "x"))
273274
out <- t(C_Vec_WeightedDiscretePdf(
274-
x, matrix(data, ncol(pdf), nrow(pdf)),
275+
x, matrix(data, ncol(pdf), private$.ndists),
275276
t(pdf)
276277
))
277278
if (log) {
@@ -306,7 +307,8 @@ Matdist <- R6Class("Matdist",
306307

307308
# traits
308309
.traits = list(valueSupport = "discrete", variateForm = "univariate"),
309-
.improper = FALSE
310+
.improper = FALSE,
311+
.ndists = 0
310312
)
311313
)
312314

@@ -392,7 +394,12 @@ c.Matdist <- function(...) {
392394
#' m[1:2]
393395
#' @export
394396
"[.Matdist" <- function(md, i) {
395-
if (length(i) == 1) {
397+
if (is.logical(i)) {
398+
i <- which(i)
399+
}
400+
if (length(i) == 0) {
401+
stop("Can't create an empty distribution.")
402+
} else if (length(i) == 1) {
396403
pdf <- gprm(md, "pdf")[i, ]
397404
dstr("WeightedDiscrete", x = as.numeric(names(pdf)), pdf = pdf,
398405
decorators = md$decorators)

tests/testthat/test-sdistribution-Matdist.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ test_that("c.Matdist", {
8686
expect_true(all(r >= 1))
8787
})
8888

89-
test_that("c.Matdist", {
89+
test_that("[.Matdist", {
9090
set.seed(1)
9191
m <- as.Distribution(
9292
t(apply(matrix(runif(200), 20, 10, FALSE,
@@ -95,9 +95,16 @@ test_that("c.Matdist", {
9595
fun = "pdf"
9696
)
9797

98+
expect_equal(m$strprint(), "Matdist(20)")
99+
100+
expect_error(m[logical(20)], "empty")
101+
98102
m1 <- m[1]
99103
m12 <- m[1:2]
104+
100105
expect_distribution(m1, "WeightedDiscrete")
106+
expect_distribution(m[!logical(20)], "Matdist")
107+
expect_distribution(m[c(TRUE, logical(19))], "WeightedDiscrete")
101108
expect_distribution(m12, "Matdist")
102109

103110
expect_equal(unname(m$cdf(0:25)[, 1]), unname(m1$cdf(0:25)))

0 commit comments

Comments
 (0)