diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 701a8692..777418de 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -428,6 +428,16 @@ mod tests { bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } + #[cfg(feature = "rand")] + #[test] + fn test_sample() { + use rand::distr::Distribution; + + test_almost(vector![1., 2.], 1., 1e-15, |dd| { + dd.sample(&mut ::rand::rng()).sum() + }); + } + #[test] fn test_mean() { let mean = |dd: Dirichlet<_>| dd.mean().unwrap();