use crate::{
traits::{AtLeast32BitUnsigned, SaturatedConversion},
Perbill,
};
use core::ops::Sub;
/// A piecewise linear function on `[0, 1] -> [0, 1]`, given as a list of `(x, y)` points.
///
/// The curve is evaluated by `calculate_for_fraction_times_denominator`. That method linearly
/// interpolates between neighbouring points. Before the first point it takes the first point's
/// ordinate; at or beyond the last point it takes the last point's ordinate.
#[derive(PartialEq, Eq, sp_core::RuntimeDebug)]
pub struct PiecewiseLinear<'a> {
/// Points of the curve as `(abscissa, ordinate)` pairs.
///
/// Must be in order from the lowest abscissas to the highest.
pub points: &'a [(Perbill, Perbill)],
/// The maximum value that can be returned.
///
/// NOTE(review): `calculate_for_fraction_times_denominator` never reads this field. It does
/// not clamp its result to it. Confirm whether callers rely on it being accurate.
pub maximum: Perbill,
}
/// Hand-written `scale_info` metadata for `PiecewiseLinear`.
///
/// Only the `'static` instantiation is described. It is registered as a composite type:
/// - ident `PiecewiseLinear`, module path `sp_runtime::curve`;
/// - no type parameters;
/// - two named fields, `points` and `maximum`.
///
/// NOTE(review): presumably this is written by hand rather than derived because of the
/// borrowed `points` slice. Confirm before replacing it with a derive.
impl scale_info::TypeInfo for PiecewiseLinear<'static> {
type Identity = Self;
fn type_info() -> ::scale_info::Type {
scale_info::Type::builder()
.path(scale_info::Path::new("PiecewiseLinear", "sp_runtime::curve"))
// The type has no generic type parameters to report.
.type_params(crate::Vec::new())
.docs(&["Piecewise Linear function in [0, 1] -> [0, 1]."])
.composite(
scale_info::build::Fields::named()
.field(|f| {
// The field type is the static slice of `(Perbill, Perbill)` pairs.
// The `type_name` string below is emitted into the metadata verbatim.
f.ty::<&'static[(Perbill, Perbill)]>()
.name("points")
.type_name("&'static[(Perbill, Perbill)]")
.docs(&["Array of points. Must be in order from the lowest abscissas to the highest."])
})
.field(|f| {
f.ty::<Perbill>()
.name("maximum")
.type_name("Perbill")
.docs(&["The maximum value that can be returned."])
}),
)
}
}
/// Absolute difference `|a - b|` of two ordered values.
///
/// Always subtracts the smaller operand from the larger, so it never underflows, even for
/// unsigned integers.
fn abs_sub<N: Ord + Sub<Output = N>>(a: N, b: N) -> N {
	if a >= b {
		a - b
	} else {
		b - a
	}
}
impl<'a> PiecewiseLinear<'a> {
	/// Evaluate the curve at the fraction `n / d` and scale the result back up by `d`.
	///
	/// The result is `f(n / d) * d`, up to `Perbill` rounding, where:
	/// - `n` is first clamped to `d`;
	/// - between two points, `f` is the linear interpolation between them;
	/// - before the first point's abscissa, `f` is the first point's ordinate;
	/// - when `n` is not below any point's scaled abscissa (with sorted `points`: at or beyond
	///   the last one), `f` is the last point's ordinate.
	///
	/// An empty curve yields zero. The final addition or subtraction of the interpolated offset
	/// saturates. `self.maximum` is not consulted.
	pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N
	where
		N: AtLeast32BitUnsigned + Clone,
	{
		if self.points.is_empty() {
			return N::zero()
		}
		let n = n.min(d.clone());

		// Index of the first point whose scaled abscissa lies strictly to the right of `n`.
		let upper = match self.points.iter().position(|point| n < point.0 * d.clone()) {
			Some(0) => return self.points.first().map_or_else(Perbill::zero, |point| point.1) * d,
			Some(index) => index,
			None => return self.points.last().map_or_else(Perbill::zero, |point| point.1) * d,
		};
		let (left, right) = (self.points[upper - 1], self.points[upper]);

		let left_x = left.0 * d.clone();
		let (left_y, right_y) = (left.1.deconstruct(), right.1.deconstruct());
		let run = right.0.deconstruct().saturating_sub(left.0.deconstruct());
		let rise = abs_sub(right_y, left_y);

		// Magnitude of the change in y between `left` and `n`, scaled by `d`.
		let offset =
			multiply_by_rational_saturating(abs_sub(n.clone(), left_x.clone()), rise, run);
		let base = left.1 * d;

		// Move upward when stepping right on a rising segment; otherwise move downward.
		if (n > left_x) == (right_y > left_y) {
			base.saturating_add(offset)
		} else {
			base.saturating_sub(offset)
		}
	}
}
/// Compute `value * p / q`, rounded down, saturating at the maximum of `N`.
///
/// A zero `q` is treated as one. To keep the product inside `N`, `value` is split as
/// `quotient * q + remainder`, where `remainder < q`. The two parts are then scaled separately:
/// - `quotient * p`, which saturates;
/// - `remainder * p / q`, which is evaluated in `u64` where it cannot overflow.
fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
where
	N: AtLeast32BitUnsigned + Clone,
{
	let q = q.max(1);
	let q_as_n: N = q.into();
	let p_as_n: N = p.into();

	let quotient = value.clone() / q_as_n.clone();
	let remainder = value % q_as_n;

	// `quotient * p` is the only term that can exceed the range of `N`.
	let whole = quotient.saturating_mul(p_as_n);

	// The fractional term cannot overflow, for these reasons:
	// - `remainder < q <= u32::MAX`, so narrowing it to `u32` is exact;
	// - `remainder * p` therefore fits in a `u64`;
	// - dividing by `q` leaves a value below `p`, which fits back into `N`.
	let fraction = u64::from(remainder.saturated_into::<u32>()) * u64::from(p) / u64::from(q);

	whole.saturating_add(fraction.saturated_into::<N>())
}
#[test]
fn test_multiply_by_rational_saturating() {
	const STEPS: u32 = 100;

	// Map `k` in `0..=STEPS` to `k / STEPS` of the way across the full `u64` range.
	fn spread_u64(k: u32) -> u64 {
		(k as u128 * u64::MAX as u128 / STEPS as u128).try_into().unwrap()
	}

	// Map `k` in `0..=STEPS` to `k / STEPS` of the way across the full `u32` range.
	fn spread_u32(k: u32) -> u32 {
		(k as u64 * u32::MAX as u64 / STEPS as u64).try_into().unwrap()
	}

	for value in (0..=STEPS).map(spread_u64) {
		for p in (0..=STEPS).map(spread_u32) {
			for q in (1..=STEPS).map(spread_u32) {
				// Reference result computed exactly in `u128`, then saturated to `u64`.
				let exact: u64 =
					(value as u128 * p as u128 / q as u128).try_into().unwrap_or(u64::MAX);
				assert_eq!(multiply_by_rational_saturating(value, p, q), exact);
			}
		}
	}
}
#[test]
fn test_calculate_for_fraction_times_denominator() {
	// The curve rises from 0.5 to 1 over [0, 0.5], then falls from 1 to 0 over [0.5, 1].
	let curve = PiecewiseLinear {
		points: &[
			(Perbill::from_parts(0_000_000_000), Perbill::from_parts(0_500_000_000)),
			(Perbill::from_parts(0_500_000_000), Perbill::from_parts(1_000_000_000)),
			(Perbill::from_parts(1_000_000_000), Perbill::from_parts(0_000_000_000)),
		],
		maximum: Perbill::from_parts(1_000_000_000),
	};

	// Closed form of `curve` evaluated at `n / d`, scaled by `d`.
	fn closed_form(n: u64, d: u64) -> u64 {
		let midpoint = Perbill::from_parts(0_500_000_000) * d;
		if n > midpoint {
			(2 * d as u128 - 2 * n as u128).try_into().unwrap()
		} else {
			n + d / 2
		}
	}

	const STEPS: u32 = 100;
	// Map `k` in `0..=STEPS` to `k / STEPS` of the way across the full `u64` range.
	let spread = |k: u32| -> u64 {
		(k as u128 * u64::MAX as u128 / STEPS as u128).try_into().unwrap()
	};

	for d_step in 0..=STEPS {
		let d = spread(d_step);
		for n_step in 0..=d_step {
			let n = spread(n_step);
			let actual = curve.calculate_for_fraction_times_denominator(n, d);
			assert!(abs_sub(actual, closed_form(n, d)) <= 1);
		}
	}
}