Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rvar indexing and casting improvements #247

Merged
merged 15 commits into from
Jul 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/covr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ jobs:

- uses: actions/checkout@v2

- uses: r-lib/actions/setup-r@v1
- uses: r-lib/actions/setup-r@v2

- uses: r-lib/actions/setup-pandoc@v1
- uses: r-lib/actions/setup-pandoc@v2

- name: Query dependencies
run: |
Expand All @@ -50,4 +50,4 @@ jobs:

- name: Test coverage
run: covr::codecov()
shell: Rscript {0}
shell: Rscript {0}
4 changes: 2 additions & 2 deletions .github/workflows/rcmdcheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ jobs:

- uses: actions/checkout@v2

- uses: r-lib/actions/setup-r@v1
- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}

- uses: r-lib/actions/setup-pandoc@v1
- uses: r-lib/actions/setup-pandoc@v2

- name: Query dependencies
run: |
Expand Down
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: posterior
Title: Tools for Working with Posterior Distributions
Version: 1.2.2
Version: 1.2.2.9000
Date: 2022-06-09
Authors@R: c(person("Paul-Christian", "Bürkner", email = "[email protected]", role = c("aut", "cre")),
person("Jonah", "Gabry", email = "[email protected]", role = c("aut")),
Expand All @@ -22,9 +22,10 @@ Description: Provides useful tools for both users and developers of packages
(d) Provide lightweight implementations of state of the art posterior
inference diagnostics. References: Vehtari et al. (2021)
<doi:10.1214/20-BA1221>.
Depends:
Depends:
R (>= 3.2.0)
Imports:
methods,
abind,
checkmate,
rlang (>= 0.4.7),
Expand Down Expand Up @@ -52,5 +53,5 @@ LazyData: false
URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/
BugReports: https://github.com/stan-dev/posterior/issues
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2
RoxygenNote: 7.2.1
VignetteBuilder: knitr
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,15 @@ export(chain_ids)
export(default_convergence_measures)
export(default_mcse_measures)
export(default_summary_measures)
export(diag)
export(draw_ids)
export(draws_array)
export(draws_df)
export(draws_list)
export(draws_matrix)
export(draws_of)
export(draws_rvars)
export(drop)
export(ess_basic)
export(ess_bulk)
export(ess_mean)
Expand Down Expand Up @@ -396,11 +398,17 @@ export(variables)
export(variance)
export(weight_draws)
export(z_scale)
exportMethods(diag)
exportMethods(drop)
import(checkmate)
import(stats)
importFrom(abind,abind)
importFrom(distributional,cdf)
importFrom(distributional,variance)
importFrom(methods,callNextMethod)
importFrom(methods,setGeneric)
importFrom(methods,setMethod)
importFrom(methods,setOldClass)
importFrom(pillar,format_glimpse)
importFrom(pillar,new_pillar_shaft_simple)
importFrom(pillar,pillar_shaft)
Expand Down
17 changes: 17 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# posterior 1.2.2.9000

### Enhancements

* Add an implementation of `drop()` for `rvar`s.

### Bug Fixes

* Support remaining modes of `diag()` for `rvar`s (#246).
* Better parsing for named indices in `as_draws_rvars()`, including nested use
of `[`, like `x[y[1],2]` (#243).
* Allow 0-length `rvar`s with `ndraws() > 1` (#242).
* Ensure 0-length `rvar`s can be cast to `draws` formats (#242).
* Don't treat length-1 `rvar`s with more than 1 dimension as scalars when
casting to other formats (#248).


# posterior 1.2.2

### Enhancements
Expand Down
2 changes: 1 addition & 1 deletion R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ check_draws_object <- function(x) {
# use the 'unique' naming strategy of tibble
# @param nvariables number of variables
default_variables <- function(nvariables) {
  # Build the placeholder names "...1", "...2", ..., one per variable.
  # sprintf() is vectorised over seq_len(nvariables), so requesting zero
  # variables returns character(0) instead of a single stray "..." name.
  sprintf("...%s", seq_len(nvariables))
}

# validate draws vectors per variable
Expand Down
16 changes: 12 additions & 4 deletions R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,21 @@ as_draws_rvars.rvar <- function(x, ...) {
#' @rdname draws_rvars
#' @export
as_draws_rvars.draws_matrix <- function(x, ...) {
.variables <- variables(x)
if (ndraws(x) == 0) {
return(empty_draws_rvars(variables(x)))
return(empty_draws_rvars(.variables))
}

# split x[y,z] names into base name and indices
vars_indices <- strsplit(variables(x), "(\\[|\\])")
vars <- sapply(vars_indices, `[[`, 1)
#
# ----- base name -> vars_indices[[i]][[2]]
# ||||| lazy-matched (.*? not .*) so that indices match as much as they can
# |||||
# ||||| ---- optional indices -> vars_indices[[i]][[3]]
# ||||| ||||
matches <- regexec("^(.*?)(?:\\[(.*)\\])?$", .variables)
vars_indices <- regmatches(.variables, matches)
vars <- vapply(vars_indices, `[[`, i = 2, character(1))

# pull out each var into its own rvar
var_names <- unique(vars)
Expand All @@ -78,7 +86,7 @@ as_draws_rvars.draws_matrix <- function(x, ...) {

# first, pull out the list of indices into a data frame
# where each column is an index variable
indices <- sapply(vars_indices[var_i], `[[`, 2)
indices <- vapply(vars_indices[var_i], `[[`, i = 3, character(1))
indices <- as.data.frame(do.call(rbind, strsplit(indices, ",")),
stringsAsFactors = FALSE)
unique_indices <- vector("list", length(indices))
Expand Down
39 changes: 24 additions & 15 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = 1L, with_c

#' @importFrom vctrs new_vctr
new_rvar <- function(x = double(), .nchains = 1L) {
if (length(x) == 0) {
if (is.null(x)) {
x <- double()
}

Expand All @@ -126,6 +126,9 @@ new_rvar <- function(x = double(), .nchains = 1L) {
)
}

# Declare the S3 class vector c("rvar", "vctrs_vctr", "list") to the S4 system
# via setOldClass(), so that "rvar" can be named in S4 method signatures
# elsewhere in the package.
#' @importFrom methods setOldClass
setOldClass(c("rvar", "vctrs_vctr", "list"))


# manipulating raw draws array --------------------------------------------

Expand Down Expand Up @@ -288,7 +291,7 @@ anyDuplicated.rvar <- function(x, incomparables = FALSE, MARGIN = 1, ...) {
# then return the corresponding margin for draws_of(x)
check_rvar_margin <- function(x, MARGIN) {
  # MARGIN has to refer to one of the dimensions reported by dim(x)
  n_dims <- length(dim(x))
  if (!(1 <= MARGIN && MARGIN <= n_dims)) {
    stop_no_call("MARGIN = ", MARGIN, " is invalid for dim = ", paste0(dim(x), collapse = ","))
  }
  # valid margins are shifted by one position before being returned
  # (presumably because draws_of(x) carries one extra leading dimension
  # -- confirm against draws_of())
  MARGIN + 1
}
Expand Down Expand Up @@ -325,13 +328,16 @@ all.equal.rvar <- function(target, current, ...) {
check_rvar_yank_index = function(x, i, ...) {
index <- dots_list(i, ..., .preserve_empty = TRUE, .ignore_empty = "none")

if (any(lengths(index)) > 1) {
index_lengths <- lengths(index)
if (any(index_lengths == 0)) {
stop_no_call("Cannot select zero elements with `[[` in an rvar.")
} else if (any(index_lengths > 1)) {
stop_no_call("Cannot select more than one element per index with `[[` in an rvar.")
} else if (any(sapply(index, function(x) is_missing(x) || is.na(x)))) {
} else if (any(vapply(index, function(x) is_missing(x) || is.na(x), logical(1)))) {
stop_no_call("Missing indices not allowed with `[[` in an rvar.")
} else if (any(sapply(index, is.logical))) {
} else if (any(vapply(index, is.logical, logical(1)))) {
stop_no_call("Logical indices not allowed with `[[` in an rvar.")
} else if (any(sapply(index, function(x) x < 0))) {
} else if (any(vapply(index, function(x) x < 0, logical(1)))) {
stop_no_call("subscript out of bounds")
}

Expand Down Expand Up @@ -550,8 +556,8 @@ broadcast_draws <- function(x, .ndraws, keep_constants = FALSE) {
flatten_array = function(x, x_name = NULL) {
# determine new dimension names in the form x,y,z
# start with numeric names
dimname_lists = lapply(dim(x), seq_len)
.dimnames = dimnames(x)
dimname_lists <- lapply(dim(x), seq_len)
.dimnames <- dimnames(x)
if (!is.null(.dimnames)) {
# where character names are provided, use those instead of the numeric names
dimname_lists = lapply(seq_along(dimname_lists), function(i) .dimnames[[i]] %||% dimname_lists[[i]])
Expand All @@ -560,18 +566,20 @@ flatten_array = function(x, x_name = NULL) {
dimname_grid <- expand.grid(dimname_lists)
new_names <- apply(dimname_grid, 1, paste0, collapse = ",")

dim(x) <- prod(dim(x))
.length <- length(x)
old_dim <- dim(x)
dim(x) <- .length

# update variable names
if (is.null(x_name)) {
# no base name for x provided, just use index names
names(x) <- new_names
} else if (length(x) > 1) {
} else if (.length == 1 && (isTRUE(old_dim == 1) || length(old_dim) == 0)) {
# scalar, use the provided base name
names(x) <- x_name
} else if (.length >= 1) {
# rename the variables with their indices in brackets
names(x) <- paste0(x_name, "[", new_names %||% seq_along(x), "]")
} else {
# just one variable, use the provided base name
names(x) <- x_name
}

x
Expand Down Expand Up @@ -634,10 +642,11 @@ drop_chain_dim <- function(x) {
#' @noRd
cleanup_draw_dims <- function(x) {
if (length(x) == 0) {
# canonical NULL rvar is 1 draw of nothing
# canonical NULL rvar is at least 1 draw of nothing
# this ensures that (e.g.) extending a null rvar
# with x[1] = something works.
dim(x) <- c(1, 0)
ndraws <- max(NROW(x), 1)
dim(x) <- c(ndraws, 0)
}
else if (length(dim(x)) <= 1) {
# 1d vectors get treated as a single variable
Expand Down
2 changes: 1 addition & 1 deletion R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ as_rvar <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
if (!is_rvar(out)) {
out <- vec_cast(out, new_rvar())
}
if (!length(out)) {
if (is.null(out)) {
out <- rvar()
}

Expand Down
49 changes: 49 additions & 0 deletions R/rvar-dim.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,52 @@ names.rvar <- function(x) {
dimnames(draws_of(x))[2] <- list(value)
x
}

# Create an S4 generic named "drop" from the existing drop() function, so
# that S4 methods can be registered for it with setMethod().
#' @importFrom methods setGeneric
#' @export
setGeneric("drop")

#' Drop redundant dimensions
#'
#' Delete the dimensions of an [`rvar`] which are of size one. See [`base::drop()`].
#'
#' @param x (rvar) an [`rvar`].
#'
#' @return
#' An [`rvar`] with the same length as `x`, but where any entry equal to `1`
#' in `dim(x)` has been removed. The exception is if `dim(x) == 1`, in which
#' case `dim(drop(x)) == 1` as well (this is because [`rvar`]s, unlike [`numeric`]s,
#' never have `NULL` dimensions).
#'
#' @examples
#' # Sigma is a 3x3 covariance matrix
#' Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma
#' Sigma
#'
#' Sigma[1, ]
#'
#' drop(Sigma[1, ])
#'
#' # equivalently ...
#' Sigma[1, , drop = TRUE]
#'
#' @importFrom methods setMethod
#' @export
setMethod("drop", signature(x = "rvar"), function(x) {
  old_dim <- dim(x)

  # Nothing is dropped when at most one dimension is present: removing the
  # only remaining dimension would lose the names.
  if (length(old_dim) <= 1) {
    return(x)
  }

  # Flag the dimensions whose extent differs from 1; only those survive.
  is_kept <- old_dim != 1

  # Grab the surviving dimnames before the dimensions are reassigned.
  kept_dimnames <- dimnames(x)[is_kept]
  dim(x) <- old_dim[is_kept]

  # If none of the surviving dimnames carries a name, store NULL names
  # instead of empty strings; this keeps comparisons / tests simple.
  if (all(names(kept_dimnames) == "")) {
    names(kept_dimnames) <- NULL
  }
  dimnames(x) <- kept_dimnames

  x
})
64 changes: 64 additions & 0 deletions R/rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,70 @@ chol.rvar <- function(x, ...) {
new_rvar(result, .nchains = nchains(x))
}

# Create an S4 generic named "diag" from the existing diag() function, so
# that S4 methods can be registered for it with setMethod().
#' @importFrom methods setGeneric
#' @export
setGeneric("diag")

#' Matrix diagonals (including for random variables)
#'
#' Extract the diagonal of a matrix or construct a matrix, including random
#' matrices (2-dimensional [`rvar`]s). Makes [`base::diag()`] generic.
#'
#' @inheritParams base::diag
#' @param x (numeric,rvar) a matrix, vector, 1D array, missing, or a 1- or
#' 2-dimensional [`rvar`].
#'
#' @details
#' Makes [`base::diag()`] into a generic function. See that function's documentation
#' for usage with [`numeric`]s and for usage of [`diag<-`], which is also supported
#' by [`rvar`].
#'
#' @return
#'
#' For [`rvar`]s, has two modes:
#'
#' 1. `x` is a matrix-like [`rvar`]: it returns the diagonal as a vector-like [`rvar`]
#' 2. `x` is a vector-like [`rvar`]: it returns a matrix-like [`rvar`] with `x` as
#' the diagonal and zero for off-diagonal entries.
#'
#' @seealso [`base::diag()`]
#'
#' @examples
#'
#' # Sigma is a 3x3 covariance matrix
#' Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma
#' Sigma
#'
#' diag(Sigma)
#'
#' diag(Sigma) <- 1:3
#' Sigma
#'
#' diag(as_rvar(1:3))
#'
#' @importFrom methods setMethod callNextMethod
#' @export
setMethod("diag", signature(x = "rvar"), function(x = 1, nrow, ncol, names = TRUE) {
  if (length(dim(x)) <= 1) {
    # Vector-like x: build the matrix by hand, because the base
    # implementation of diag() does not work on rvars in this case.
    n_row <- if (missing(nrow)) length(x) else nrow
    n_col <- if (missing(ncol)) n_row else ncol

    # start from an all-zero n_row x n_col matrix converted with as_rvar()
    result <- as_rvar(matrix(rep(0, n_row * n_col), nrow = n_row, ncol = n_col))

    # the diagonal has as many entries as the shorter side of the matrix;
    # x is recycled / truncated to that length
    n_diag <- min(n_row, n_col)
    diag_index <- seq_len(n_diag)
    result[cbind(diag_index, diag_index)] <- rep_len(x, n_diag)

    result
  } else {
    # x has more than one dimension: the base implementation of diag()
    # works on rvars here, so hand over to the next method
    callNextMethod()
  }
})


# transpose and permutation -----------------------------------------------

#' @export
Expand Down
Loading