library(sparklyr)
library(testthat)

# Local Spark connection that will hold a copy of the test data.
sc <- spark_connect(master = "local")

# Eight rows: `x` runs from 0 to 7; the first row has weight 4, the next two
# have weight 3 and the remaining five have weight 2.
df <- data.frame(
  x = seq_len(8) - 1,
  weight = rep(c(4, 3, 2), times = c(1, 2, 5))
)

# Mirror the local data frame in Spark (`overwrite = TRUE` is passed so a
# previous copy does not get in the way).
sdf <- copy_to(sc, df, overwrite = TRUE)

# Number of rows drawn by every weighted sample in this script.
sample_size <- 2

# Map a sample to a single number by reading its `x` values as base-8 digits,
# least-significant digit first (the first row contributes x * 8^0, the second
# x * 8^1, and so on).
#
# The digit positions are derived from the sample itself rather than from the
# global `sample_size`, so a sample with a different number of rows is encoded
# correctly instead of being silently recycled against the wrong powers of 8.
#
# `sample` is anything with an `x` element reachable through `$` (a data frame
# or a list). Returns a numeric scalar; an empty `x` yields 0.
to_oct <- function(sample) {
  digits <- sample$x
  sum(8^(seq_along(digits) - 1) * digits)
}

# Number of paired draws (one local, one from Spark) to tally.
num_iters <- 1000

# Threshold the final p-value is compared against.
alpha <- 0.05

# Code of the sample (6, 7); used below as the length of the tally vectors.
max_possible_outcome <- to_oct(list(x = c(6, 7)))

# One zero-initialised counter per outcome code 1..max_possible_outcome.
actual_dist <- numeric(max_possible_outcome)
expected_dist <- numeric(max_possible_outcome)

# For every iteration, draw one weighted sample without replacement from the
# local data frame and one from the Spark data frame, encode each draw with
# to_oct(), and count it in the matching tally vector.
#
# seq_len() is used so that a `num_iters` of 0 runs no iterations (seq(0)
# would count down over c(1, 0)).
for (iter in seq_len(num_iters)) {
  # ensure both dplyr::slice_sample and sdf_weighted_sample are
  # using a different PRNG seed within each iteration
  seed <- iter * 97

  # Reference draw: seed R's generator right before sampling locally.
  set.seed(seed)
  expected_outcome <- df %>%
    dplyr::slice_sample(
      n = sample_size,
      weight_by = weight,
      replace = FALSE
    ) %>%
    to_oct()
  expected_dist[[expected_outcome]] <- expected_dist[[expected_outcome]] + 1

  # Draw under test: the same seed is handed to the Spark sampler explicitly.
  actual_outcome <- sdf %>%
    sdf_weighted_sample(
      k = sample_size,
      weight_col = "weight",
      replacement = FALSE,
      seed = seed
    ) %>%
    collect() %>%
    to_oct()
  actual_dist[[actual_outcome]] <- actual_dist[[actual_outcome]] + 1
}

# Compare the outcome counts from Spark against the counts from the local
# reference sampler with a two-sample Kolmogorov-Smirnov test.
res <- ks.test(x = actual_dist, y = expected_dist)

# Close the Spark connection opened at the top of the script. This is done
# before the expectation so the session is released even when the expectation
# fails and stops the script.
spark_disconnect(sc)

# The two samplers are considered consistent when the test does not reject
# at level `alpha`.
expect_gte(res$p.value, alpha)