/// Returns `true` if `arr` is a usable set of multinomial weights.
///
/// Every element must be non-NaN and either non-negative (when `incl_zero`
/// is `true`) or strictly positive (when `incl_zero` is `false`), and the
/// elements must not sum to zero. In particular an empty slice is rejected.
///
/// NOTE(review): infinite elements are not rejected by this check.
pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool {
    // NaN fails both comparisons below, so it is rejected without a separate
    // `is_nan` test.
    let admissible = |w: f64| if incl_zero { w >= 0.0 } else { w > 0.0 };
    // Stops at the first inadmissible weight; otherwise accumulates the
    // weights left to right starting from 0.0.
    let total = arr
        .iter()
        .try_fold(0.0, |acc, &w| if admissible(w) { Some(acc + w) } else { None });
    matches!(total, Some(t) if t != 0.0)
}
#[cfg(test)]
pub mod test {
    use super::is_valid_multinomial;
    use crate::consts::ACC;
    use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF};

    /// Generates standard test helpers for a distribution type `$dist` whose
    /// `new` constructor is invoked with an argument of type `$arg`.
    ///
    /// Generated helpers:
    /// - `try_create(arg)`: constructs the distribution, asserting success.
    /// - `bad_create_case(arg)`: asserts that construction fails.
    /// - `get_value(arg, eval)`: constructs the distribution and applies `eval`.
    /// - `test_case(arg, expected, eval)`: relative comparison with `ACC`.
    /// - `test_case_special(arg, expected, acc, eval)`: absolute comparison
    ///   with tolerance `acc`.
    /// - `test_none(arg, eval)`: asserts that `eval` yields `None`.
    ///
    /// Gated behind the `nightly` feature; calling `new` through the
    /// `call_once`/`call` methods presumably needs the unstable `fn_traits`
    /// feature — TODO confirm.
    ///
    /// NOTE(review): `ACC`, `assert_relative_eq!` and `assert_abs_diff_eq!`
    /// are written unqualified, so they presumably must be in scope at the
    /// invocation site — verify against callers.
    #[cfg(feature = "nightly")]
    #[macro_export]
    macro_rules! testing_boiler {
        ($arg:ty, $dist:ty) => {
            fn try_create(arg: $arg) -> $dist {
                let n = <$dist>::new.call_once(arg);
                assert!(n.is_ok());
                n.unwrap()
            }

            fn bad_create_case(arg: $arg) {
                let n = <$dist>::new.call(arg);
                assert!(n.is_err());
            }

            fn get_value<F, T>(arg: $arg, eval: F) -> T
            where
                F: Fn($dist) -> T,
            {
                let n = try_create(arg);
                eval(n)
            }

            fn test_case<F, T>(arg: $arg, expected: T, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::RelativeEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_relative_eq!(expected, x, max_relative = ACC);
            }

            fn test_case_special<F, T>(arg: $arg, expected: T, acc: f64, eval: F)
            where
                F: Fn($dist) -> T,
                T: ::core::fmt::Debug + ::approx::AbsDiffEq<Epsilon = f64>,
            {
                let x = get_value(arg, eval);
                assert_abs_diff_eq!(expected, x, epsilon = acc);
            }

            // Not every module that invokes this macro checks for `None`
            // results, so this helper may legitimately go unused.
            #[allow(dead_code)]
            fn test_none<F, T>(arg: $arg, eval: F)
            where
                F: Fn($dist) -> Option<T>,
                T: ::core::cmp::PartialEq + ::core::fmt::Debug,
            {
                let x = get_value(arg, eval);
                assert_eq!(None, x);
            }
        };
    }

    /// Integrates `dist.pdf` with the trapezoidal rule, starting at `x_min`
    /// and stepping by `step` until `x_max` is reached or passed.
    ///
    /// At every step it checks three things:
    /// - the density is non-negative (NaN fails);
    /// - `ln_pdf` agrees with `pdf().ln()` to within `1e-10`;
    /// - the running integral agrees with `cdf(x)` to within `1e-3`.
    ///
    /// Finally it checks that the total integral lies in `(0.99, 1.001]`.
    ///
    /// The running integral starts from zero at `x_min`. The cdf comparison
    /// therefore only passes when the distribution has (almost) no mass
    /// below `x_min`.
    ///
    /// # Panics
    /// Panics if `step` is not positive, or if any check above fails.
    fn check_integrate_pdf_is_cdf<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
        step: f64,
    ) {
        // A non-positive step would never reach `x_max` when `x_min < x_max`.
        assert!(step > 0.0, "integration step must be positive, got {}", step);
        let mut prev_x = x_min;
        let mut prev_density = dist.pdf(x_min);
        let mut sum = 0.0;
        loop {
            let x = prev_x + step;
            let density = dist.pdf(x);
            assert!(density >= 0.0, "pdf({}) = {} is not non-negative", x, density);
            let ln_density = dist.ln_pdf(x);
            assert_almost_eq!(density.ln(), ln_density, 1e-10);
            // Trapezoidal rule over [prev_x, x].
            sum += (prev_density + density) * step / 2.0;
            let cdf = dist.cdf(x);
            if (sum - cdf).abs() > 1e-3 {
                panic!(
                    "Integral of pdf doesn't equal cdf!\n\
                     Integration from {} by {} to {} = {}\n\
                     cdf = {}",
                    x_min, step, x, sum, cdf
                );
            }
            if x >= x_max {
                break;
            }
            prev_x = x;
            prev_density = density;
        }
        assert!(sum > 0.99, "integral of pdf = {} is too small", sum);
        assert!(sum <= 1.001, "integral of pdf = {} exceeds 1", sum);
    }

    /// Sums `dist.pmf(i)` for `i` in `0..x_max + 3`.
    ///
    /// At every step it checks that each probability lies in `[0, 1]`. It
    /// also checks that the running sum matches `dist.cdf(i)` to within
    /// `1e-10`.
    ///
    /// The running sum must exceed `0.99` once `i == x_max`, and the final
    /// sum must lie in `(0.99, 1 + 1e-10]`.
    ///
    /// # Panics
    /// Panics if any check above fails.
    fn check_sum_pmf_is_cdf<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(dist: &D, x_max: u64) {
        let mut sum = 0.0;
        for i in 0..x_max + 3 {
            let prob = dist.pmf(i);
            // `contains` is false for NaN, so NaN probabilities are rejected.
            assert!(
                (0.0..=1.0).contains(&prob),
                "pmf({}) = {} is not a probability",
                i,
                prob
            );
            sum += prob;
            if i == x_max {
                assert!(
                    sum > 0.99,
                    "cumulative pmf up to {} = {} is too small",
                    x_max,
                    sum
                );
            }
            assert_almost_eq!(sum, dist.cdf(i), 1e-10);
        }
        assert!(sum > 0.99, "total pmf = {} is too small", sum);
        assert!(sum <= 1.0 + 1e-10, "total pmf = {} exceeds 1", sum);
    }

    /// Runs generic sanity checks on a continuous distribution.
    ///
    /// It checks that:
    /// - `pdf` is `0` and `ln_pdf` is `-inf` at both infinities;
    /// - `cdf(-inf) == 0` and `cdf(inf) == 1`;
    /// - the trapezoidal integral of the pdf tracks the cdf over
    ///   `[x_min, x_max]`, using 100000 steps.
    ///
    /// `x_min` should have (almost) no mass below it, and `x_max` should
    /// leave at most about 1% of the mass above it.
    pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
    ) {
        assert_eq!(dist.pdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.pdf(f64::INFINITY), 0.0);
        assert_eq!(dist.ln_pdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.ln_pdf(f64::INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.cdf(f64::INFINITY), 1.0);
        check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0);
    }

    /// Runs generic sanity checks on a discrete distribution.
    ///
    /// It verifies that the partial sums of the pmf match the cdf, and that
    /// the distribution has essentially all of its mass in `0..=x_max`.
    pub fn check_discrete_distribution<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(
        dist: &D,
        x_max: u64,
    ) {
        check_sum_pmf_is_cdf(dist, x_max);
    }

    #[test]
    fn test_is_valid_multinomial() {
        let invalid = [1.0, f64::NAN, 3.0];
        assert!(!is_valid_multinomial(&invalid, true));
        let invalid2 = [-2.0, 5.0, 1.0, 6.2];
        assert!(!is_valid_multinomial(&invalid2, true));
        let invalid3 = [0.0, 0.0, 0.0];
        assert!(!is_valid_multinomial(&invalid3, true));
        let valid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(is_valid_multinomial(&valid, true));
        // An empty slice sums to zero and is therefore rejected.
        let empty: [f64; 0] = [];
        assert!(!is_valid_multinomial(&empty, true));
    }

    #[test]
    fn test_is_valid_multinomial_no_zero() {
        let invalid = [5.2, 0.0, 1e-15, 1000000.12];
        assert!(!is_valid_multinomial(&invalid, false));
        let invalid2 = [1.0, f64::NAN, 3.0];
        assert!(!is_valid_multinomial(&invalid2, false));
        let valid = [5.2, 1e-15, 1000000.12];
        assert!(is_valid_multinomial(&valid, false));
        let empty: [f64; 0] = [];
        assert!(!is_valid_multinomial(&empty, false));
    }
}