# dwww Home | Show directory contents | Find package

# Opt-in switch: the slow brms tests below only run when the environment
# variable RunAllinsightTests is set to exactly "yes".
.runThisTest <- identical(Sys.getenv("RunAllinsightTests"), "yes")

if (.runThisTest && requiet("brms")) {
  # Model fixtures ------------------------------------------------------------
  # Pre-fitted brmsfit objects retrieved with insight::download_model().
  # NOTE(review): m1 and m7 are the only fixtures wrapped in
  # suppressWarnings(); presumably they warn when loaded -- TODO confirm why.
  m1 <- suppressWarnings(insight::download_model("brms_mixed_6"))
  m7 <- suppressWarnings(insight::download_model("brms_mixed_8"))

  m2 <- insight::download_model("brms_mv_4")
  m3 <- insight::download_model("brms_2")
  m4 <- insight::download_model("brms_zi_3")
  m5 <- insight::download_model("brms_mv_5")
  m6 <- insight::download_model("brms_corr_re1")
  m8 <- insight::download_model("brms_ordinal_1")

  # Tests -------------------------------------------------------------------
  test_that("get_predicted.brmsfit: ordinal dv", {
    skip_if_not_installed("bayestestR")

    # Same ordinal model, default interval method vs. HDI intervals.
    pred1 <- get_predicted(m8, ci = 0.95)
    pred2 <- get_predicted(m8, ci_method = "hdi", ci = 0.95)
    expect_s3_class(pred1, "get_predicted")
    expect_s3_class(pred1, "data.frame")
    expect_true(all(c("Row", "Response") %in% colnames(pred1)))

    # ci_method changes intervals but not se or predicted
    pred1 <- data.frame(pred1)
    pred2 <- data.frame(pred2)
    expect_equal(pred1$Row, pred2$Row)
    expect_equal(pred1$Response, pred2$Response)
    expect_equal(pred1$Predicted, pred2$Predicted)
    expect_equal(pred1$SE, pred2$SE)
    # most CI bounds are different: at most 10% may coincide
    expect_lte(mean(pred1$CI_low == pred2$CI_low), 0.1)
    expect_lte(mean(pred1$CI_high == pred2$CI_high), 0.1)

    # Compare to manual predictions. posterior_epred() is expensive, so the
    # draws are computed once and summarised twice. The [, , 1] slice is
    # presumably the first outcome category, matching the first 32 rows of
    # the long-format predictions -- TODO confirm.
    pred3 <- get_predicted(m8, centrality_function = stats::median, ci = 0.95)
    epred_first <- rstantools::posterior_epred(m8)[, , 1]
    expect_equal(pred3$Predicted[1:32], apply(epred_first, 2, stats::median))
    expect_equal(pred1$Predicted[1:32], apply(epred_first, 2, mean))
  })

  test_that("find_statistic", {
    # Bayesian fits carry no test statistic, so every model yields NULL.
    for (fit in list(m1, m2, m3, m4, m5)) {
      expect_null(find_statistic(fit))
    }
  })

  test_that("n_parameters", {
    # All parameters (fixed + random) versus fixed effects only.
    n_total <- n_parameters(m1)
    n_fixed <- n_parameters(m1, effects = "fixed")
    expect_equal(n_total, 65)
    expect_equal(n_fixed, 5)
  })

  test_that("model_info", {
    info_trials <- model_info(m3)
    expect_true(info_trials$is_trial)

    # multivariate model: inspect the info of the first response
    info_first <- model_info(m5)[[1]]
    expect_true(info_first$is_zero_inflated)
    expect_true(info_first$is_bayesian)
  })

  test_that("clean_names", {
    expect_identical(clean_names(m1), c("count", "Age", "Base", "Trt", "patient"))
    expect_identical(
      clean_names(m2),
      c("Sepal.Length", "Sepal.Width", "Petal.Length", "Species")
    )
    expect_identical(clean_names(m3), c("r", "n", "treat", "c2"))
    expect_identical(clean_names(m4), c("count", "child", "camper", "persons"))
    expect_identical(
      clean_names(m5),
      c("count", "count2", "child", "camper", "persons", "livebait")
    )
  })


  test_that("find_predictors", {
    # m1: fixed predictors, plus the `patient` grouping factor when all
    # effects are requested
    m1_fixed <- c("Age", "Base", "Trt")
    expect_identical(find_predictors(m1), list(conditional = m1_fixed))
    expect_identical(find_predictors(m1, flatten = TRUE), m1_fixed)
    expect_identical(
      find_predictors(m1, effects = "all", component = "all"),
      list(conditional = m1_fixed, random = "patient")
    )
    expect_identical(
      find_predictors(m1, effects = "all", component = "all", flatten = TRUE),
      c(m1_fixed, "patient")
    )

    # m2: multivariate model, one list entry per response
    expect_identical(
      find_predictors(m2),
      list(
        SepalLength = list(
          conditional = c("Petal.Length", "Sepal.Width", "Species")
        ),
        SepalWidth = list(conditional = "Species")
      )
    )
    expect_identical(
      find_predictors(m2, flatten = TRUE),
      c("Petal.Length", "Sepal.Width", "Species")
    )

    expect_identical(find_predictors(m3), list(conditional = c("treat", "c2")))

    # m4: same predictors in the conditional and zero-inflated parts
    m4_fixed <- c("child", "camper")
    expect_identical(
      find_predictors(m4),
      list(conditional = m4_fixed, zero_inflated = m4_fixed)
    )
    expect_identical(
      find_predictors(m4, effects = "random"),
      list(random = "persons", zero_inflated_random = "persons")
    )
    expect_identical(find_predictors(m4, flatten = TRUE), m4_fixed)

    # m5: multivariate zero-inflated model
    expect_identical(
      find_predictors(m5),
      list(
        count = list(conditional = c("child", "camper"), zero_inflated = "camper"),
        count2 = list(conditional = c("child", "livebait"), zero_inflated = "child")
      )
    )
  })

  test_that("find_response", {
    iris_responses <- c(SepalLength = "Sepal.Length", SepalWidth = "Sepal.Width")
    # m1-m3 give the same answer whether or not responses are combined
    for (combine_flag in c(TRUE, FALSE)) {
      expect_equal(find_response(m1, combine = combine_flag), "count")
      expect_equal(find_response(m2, combine = combine_flag), iris_responses)
      expect_equal(find_response(m3, combine = combine_flag), c("r", "n"))
    }
    expect_equal(find_response(m4, combine = FALSE), "count")
    expect_equal(
      find_response(m5, combine = TRUE),
      c(count = "count", count2 = "count2")
    )
  })

  test_that("get_response", {
    # univariate responses come back as vectors
    expect_length(get_response(m1), 236)

    # multivariate / trials responses come back with one column per variable
    resp_iris <- get_response(m2)
    expect_equal(ncol(resp_iris), 2)
    expect_equal(colnames(resp_iris), c("Sepal.Length", "Sepal.Width"))

    resp_trials <- get_response(m3)
    expect_equal(ncol(resp_trials), 2)
    expect_equal(colnames(resp_trials), c("r", "n"))

    expect_length(get_response(m4), 250)
    expect_equal(colnames(get_response(m5)), c("count", "count2"))
  })

  test_that("find_variables", {
    m1_fixed <- c("Age", "Base", "Trt")
    expect_identical(
      find_variables(m1),
      list(response = "count", conditional = m1_fixed, random = "patient")
    )

    # m6 models sigma as well, including its own random effect
    expect_identical(
      find_variables(m6),
      list(
        response = "y",
        conditional = "x",
        random = "id",
        sigma = "x",
        sigma_random = "id"
      )
    )
    expect_identical(
      find_variables(m1, effects = "fixed"),
      list(response = "count", conditional = m1_fixed)
    )
    expect_null(find_variables(m1, component = "zi"))

    expect_identical(
      find_variables(m2),
      list(
        response = c(SepalLength = "Sepal.Length", SepalWidth = "Sepal.Width"),
        SepalLength = list(
          conditional = c("Petal.Length", "Sepal.Width", "Species")
        ),
        SepalWidth = list(conditional = "Species")
      )
    )
    expect_identical(
      find_variables(m2, flatten = TRUE),
      c("Sepal.Length", "Sepal.Width", "Petal.Length", "Species")
    )

    expect_identical(
      find_variables(m3),
      list(response = c("r", "n"), conditional = c("treat", "c2"))
    )

    m4_fixed <- c("child", "camper")
    expect_identical(
      find_variables(m4),
      list(
        response = "count",
        conditional = m4_fixed,
        random = "persons",
        zero_inflated = m4_fixed,
        zero_inflated_random = "persons"
      )
    )
    expect_identical(
      find_variables(m4, flatten = TRUE),
      c("count", m4_fixed, "persons")
    )
  })

  test_that("n_obs", {
    fits <- list(m1, m2, m3, m4, m5)
    expected_n <- c(236, 150, 10, 250, 250)
    for (k in seq_along(fits)) {
      expect_equal(n_obs(fits[[k]]), expected_n[k])
    }
  })


  test_that("find_random", {
    # both responses of m5 share the same `persons` grouping in both parts
    persons_terms <- list(random = "persons", zero_inflated_random = "persons")
    expect_equal(
      find_random(m5),
      list(count = persons_terms, count2 = persons_terms)
    )
    expect_equal(find_random(m5, flatten = TRUE), "persons")
    expect_equal(find_random(m6, flatten = TRUE), "id")
  })


  test_that("get_random", {
    # the random-effect data is the `persons` column of the model data
    persons_col <- get_data(m4)[, "persons", drop = FALSE]
    expect_equal(get_random(m4), persons_col)
  })


  test_that("get_data", {
    m6_data <- get_data(m6)
    # 200 rows and 3 columns
    expect_equal(dim(m6_data), c(200, 3))
  })


  # Parameter names returned by find_parameters(), grouped by effect type and
  # model component (and by response for multivariate models).
  test_that("find_parameters", {
    # m1: fixed effects plus 59 patient-level intercepts and their SD
    expect_equal(
      find_parameters(m1),
      list(
        conditional = c(
          "b_Intercept",
          "b_Age",
          "b_Base",
          "b_Trt1",
          "b_Base:Trt1"
        ),
        random = c(sprintf("r_patient[%i,Intercept]", 1:59), "sd_patient__Intercept")
      )
    )

    # m2: multivariate model; the result is flagged with an `is_mv` attribute
    expect_equal(
      find_parameters(m2),
      structure(
        list(
          SepalLength = list(
            conditional = c(
              "b_SepalLength_Intercept",
              "b_SepalLength_Petal.Length",
              "b_SepalLength_Sepal.Width",
              "b_SepalLength_Speciesversicolor",
              "b_SepalLength_Speciesvirginica"
            ),
            sigma = "sigma_SepalLength"
          ),
          SepalWidth = list(
            conditional = c(
              "b_SepalWidth_Intercept",
              "b_SepalWidth_Speciesversicolor",
              "b_SepalWidth_Speciesvirginica"
            ),
            sigma = "sigma_SepalWidth"
          )
        ),
        "is_mv" = "1"
      )
    )

    # m4: zero-inflated model with random intercepts in both components
    expect_equal(
      find_parameters(m4),
      list(
        conditional = c("b_Intercept", "b_child", "b_camper"),
        random = c(sprintf("r_persons[%i,Intercept]", 1:4), "sd_persons__Intercept"),
        zero_inflated = c("b_zi_Intercept", "b_zi_child", "b_zi_camper"),
        zero_inflated_random = c(sprintf("r_persons__zi[%i,Intercept]", 1:4), "sd_persons__zi_Intercept")
      )
    )

    # m5: multivariate zero-inflated model, one sub-list per response
    expect_equal(
      find_parameters(m5, effects = "all"),
      structure(
        list(
          count = list(
            conditional = c("b_count_Intercept", "b_count_child", "b_count_camper"),
            random = c(sprintf("r_persons__count[%i,Intercept]", 1:4), "sd_persons__count_Intercept"),
            zero_inflated = c("b_zi_count_Intercept", "b_zi_count_camper"),
            zero_inflated_random = c(sprintf("r_persons__zi_count[%i,Intercept]", 1:4), "sd_persons__zi_count_Intercept")
          ),
          count2 = list(
            conditional = c(
              "b_count2_Intercept",
              "b_count2_child",
              "b_count2_livebait"
            ),
            random = c(sprintf("r_persons__count2[%i,Intercept]", 1:4), "sd_persons__count2_Intercept"),
            zero_inflated = c("b_zi_count2_Intercept", "b_zi_count2_child"),
            zero_inflated_random = c(sprintf("r_persons__zi_count2[%i,Intercept]", 1:4), "sd_persons__zi_count2_Intercept")
          )
        ),
        "is_mv" = "1"
      )
    )
  })

  # Column names of the posterior draws returned by get_parameters(). This
  # test was previously mislabelled "find_paramaters", which duplicated the
  # description of the find_parameters() test and hid which function failed.
  test_that("get_parameters", {
    # m4: by default only fixed-effect columns are returned
    expect_equal(
      colnames(get_parameters(m4)),
      c(
        "b_Intercept", "b_child", "b_camper",
        "b_zi_Intercept", "b_zi_child", "b_zi_camper"
      )
    )
    expect_equal(
      colnames(get_parameters(m4, component = "zi")),
      c("b_zi_Intercept", "b_zi_child", "b_zi_camper")
    )
    # effects = "all": each component's random-effect columns follow its
    # fixed-effect columns
    expect_equal(
      colnames(get_parameters(m4, effects = "all")),
      c(
        "b_Intercept", "b_child", "b_camper",
        sprintf("r_persons[%i,Intercept]", 1:4), "sd_persons__Intercept",
        "b_zi_Intercept", "b_zi_child", "b_zi_camper",
        sprintf("r_persons__zi[%i,Intercept]", 1:4), "sd_persons__zi_Intercept"
      )
    )
    expect_equal(
      colnames(get_parameters(m4, effects = "random", component = "conditional")),
      c(sprintf("r_persons[%i,Intercept]", 1:4), "sd_persons__Intercept")
    )

    # m5: multivariate model, columns ordered response by response
    expect_equal(
      colnames(get_parameters(m5, effects = "random", component = "conditional")),
      c(
        sprintf("r_persons__count[%i,Intercept]", 1:4),
        "sd_persons__count_Intercept",
        sprintf("r_persons__count2[%i,Intercept]", 1:4),
        "sd_persons__count2_Intercept"
      )
    )
    expect_equal(
      colnames(get_parameters(m5, effects = "all", component = "all")),
      c(
        "b_count_Intercept", "b_count_child", "b_count_camper",
        sprintf("r_persons__count[%i,Intercept]", 1:4),
        "sd_persons__count_Intercept",
        "b_zi_count_Intercept", "b_zi_count_camper",
        sprintf("r_persons__zi_count[%i,Intercept]", 1:4),
        "sd_persons__zi_count_Intercept",
        "b_count2_Intercept", "b_count2_child", "b_count2_livebait",
        sprintf("r_persons__count2[%i,Intercept]", 1:4),
        "sd_persons__count2_Intercept",
        "b_zi_count2_Intercept", "b_zi_count2_child",
        sprintf("r_persons__zi_count2[%i,Intercept]", 1:4),
        "sd_persons__zi_count2_Intercept"
      )
    )
  })

  test_that("linkfun", {
    # univariate models: some link function must be found
    for (fit in list(m1, m3, m4)) {
      expect_false(is.null(link_function(fit)))
    }
    # multivariate models (two responses each): one entry per response
    for (fit in list(m2, m5)) {
      expect_length(link_function(fit), 2)
    }
  })

  # Mirrors the "linkfun" test. Univariate models must yield a non-NULL inverse
  # link. The two-response models m2 and m5 must yield two entries. Previously
  # the final check repeated m2, so m5 was never exercised.
  test_that("linkinv", {
    expect_false(is.null(link_inverse(m1)))
    expect_length(link_inverse(m2), 2)
    expect_false(is.null(link_inverse(m3)))
    expect_false(is.null(link_inverse(m4)))
    expect_length(link_inverse(m5), 2)
  })

  test_that("is_multivariate", {
    # only the iris (m2) and count/count2 (m5) models have several responses
    for (fit in list(m2, m5)) {
      expect_true(is_multivariate(fit))
    }
    for (fit in list(m1, m3, m4)) {
      expect_false(is_multivariate(fit))
    }
  })

  test_that("find_terms", {
    # one entry per response of the multivariate iris model
    expected_terms <- list(
      SepalLength = list(
        response = "Sepal.Length",
        conditional = c("Petal.Length", "Sepal.Width", "Species")
      ),
      SepalWidth = list(response = "Sepal.Width", conditional = "Species")
    )
    expect_equal(find_terms(m2), expected_terms)
  })

  test_that("find_algorithm", {
    # sampler settings the m1 fixture was fitted with
    sampler_settings <- list(
      algorithm = "sampling",
      chains = 1,
      iterations = 500,
      warmup = 250
    )
    expect_equal(find_algorithm(m1), sampler_settings)
  })


  test_that("get_priors", {
    # m7: random intercept and Age slope by patient, hence two SD priors and
    # an LKJ prior on their correlation
    priors_m7 <- data.frame(
      Parameter = c(
        "b_Intercept", "b_Age", "b_Base", "b_Trt1", "b_Base:Trt1",
        "sd_patient__Intercept", "sd_patient__Age",
        "cor_patient__Intercept__Age"
      ),
      Distribution = c(rep("student_t", 5), "cauchy", "cauchy", "lkj"),
      Location = c(1.4, 0, 0, 0, 0, NA, NA, 1),
      Scale = c(2.5, rep(10, 4), NA, NA, NA),
      df = c(3, rep(5, 4), NA, NA, NA),
      stringsAsFactors = FALSE
    )
    # m3: informative intercept prior, flat (uniform) priors elsewhere
    priors_m3 <- data.frame(
      Parameter = c("b_Intercept", "b_treat1", "b_c2", "b_treat1:c2"),
      Distribution = c("student_t", rep("uniform", 3)),
      Location = c(0, NA, NA, NA),
      Scale = c(2.5, NA, NA, NA),
      df = c(3, NA, NA, NA),
      stringsAsFactors = FALSE
    )
    expect_equal(get_priors(m7), priors_m7, ignore_attr = TRUE)
    expect_equal(get_priors(m3), priors_m3, ignore_attr = TRUE)
  })

  test_that("Issue #645", {
    # apparently BH is required to fit these brms models
    skip_if_not_installed("BH")
    # sink() writing permission fail on some Windows CI machines
    skip_on_os("windows")

    # Ordinal model with a non-linear parameter `gearnl` whose only term is a
    # random intercept for `gear`. `gear` must still be detected as a
    # predictor and kept in the model data. Sampler console output is
    # captured and discarded.
    sampler_log <- suppressMessages(suppressWarnings(capture.output(
      mod <- brm(
        formula = bf(
          cyl ~ 1 + mpg + drat + gearnl,
          gearnl ~ 0 + (1 | gear),
          nl = TRUE
        ),
        data = mtcars,
        family = cumulative(probit),
        silent = 2
      )
    )))

    found_predictors <- find_predictors(mod, flatten = TRUE)
    model_data <- get_data(mod)
    expect_true("gear" %in% found_predictors)
    expect_true("gear" %in% colnames(model_data))
  })

  test_that("clean_parameters", {
    # m4: zero-inflated model with a `persons` random intercept in both the
    # conditional and zero-inflated components. Each component has 3 fixed
    # effects, 4 group-level intercepts and 1 SD term (8 rows each).
    m4_group <- c(rep("Intercept: persons", 4), "SD/Cor: persons")
    m4_cleaned <- c(
      "(Intercept)", "child", "camper",
      sprintf("persons.%i", 1:4), "(Intercept)"
    )
    expect_equal(
      clean_parameters(m4),
      structure(
        list(
          Parameter = c(
            "b_Intercept", "b_child", "b_camper",
            sprintf("r_persons[%i,Intercept]", 1:4),
            "sd_persons__Intercept",
            "b_zi_Intercept", "b_zi_child", "b_zi_camper",
            sprintf("r_persons__zi[%i,Intercept]", 1:4),
            "sd_persons__zi_Intercept"
          ),
          Effects = rep(c("fixed", "random", "fixed", "random"), c(3, 5, 3, 5)),
          Component = rep(c("conditional", "zero_inflated"), each = 8),
          Group = c(rep("", 3), m4_group, rep("", 3), m4_group),
          Cleaned_Parameter = rep(m4_cleaned, 2)
        ),
        class = c("clean_parameters", "data.frame"),
        row.names = c(NA, -16L)
      ),
      ignore_attr = TRUE
    )

    # m5: multivariate zero-inflated model (responses count, count2). Rows are
    # grouped as: fixed conditional, random conditional, fixed zi, random zi.
    m5_group <- c(
      rep("Intercept: persons", 4), "SD/Cor: persons",
      rep("Intercept: persons2", 4), "SD/Cor: persons"
    )
    expect_equal(
      clean_parameters(m5),
      structure(
        list(
          Parameter = c(
            "b_count_Intercept", "b_count_child", "b_count_camper",
            "b_count2_Intercept", "b_count2_child", "b_count2_livebait",
            sprintf("r_persons__count[%i,Intercept]", 1:4),
            "sd_persons__count_Intercept",
            sprintf("r_persons__count2[%i,Intercept]", 1:4),
            "sd_persons__count2_Intercept",
            "b_zi_count_Intercept", "b_zi_count_camper",
            "b_zi_count2_Intercept", "b_zi_count2_child",
            sprintf("r_persons__zi_count[%i,Intercept]", 1:4),
            "sd_persons__zi_count_Intercept",
            sprintf("r_persons__zi_count2[%i,Intercept]", 1:4),
            "sd_persons__zi_count2_Intercept"
          ),
          Effects = rep(c("fixed", "random", "fixed", "random"), c(6, 10, 4, 10)),
          Component = rep(c("conditional", "zero_inflated"), c(16, 14)),
          Group = c(rep("", 6), m5_group, rep("", 4), m5_group),
          Response = rep(
            rep(c("count", "count2"), 4),
            c(3, 3, 5, 5, 2, 2, 5, 5)
          ),
          Cleaned_Parameter = c(
            "(Intercept)", "child", "camper",
            "(Intercept)", "child", "livebait",
            sprintf("persons.%i", 1:4), "count_Intercept",
            sprintf("persons2.%i", 1:4), "count2_Intercept",
            "(Intercept)", "camper",
            "(Intercept)", "child",
            sprintf("persons.%i", 1:4), "zi_count_Intercept",
            sprintf("persons2.%i", 1:4), "zi_count2_Intercept"
          )
        ),
        class = c("clean_parameters", "data.frame"),
        row.names = c(NA, -30L)
      ),
      ignore_attr = TRUE
    )
  })
}

# Generated by dwww version 1.15 on Sat May 18 06:07:09 CEST 2024.