use crate::archparam;
/// A microkernel that produces one `MR x NR` tile of the output matrix C.
///
/// NOTE(review): the full contract (buffer layout, alpha/beta semantics) is
/// fixed by the implementations and callers elsewhere in the crate; the notes
/// below record only what this trait and its test harness show.
pub(crate) trait GemmKernel {
    /// Scalar element type the kernel operates on.
    type Elem: Element;

    /// Type-level tile height; `MR` defaults to its `VALUE`.
    type MRTy: ConstNum;
    /// Type-level tile width; `NR` defaults to its `VALUE`.
    type NRTy: ConstNum;

    /// Tile height: number of rows of C written per `kernel` call.
    const MR: usize = Self::MRTy::VALUE;
    /// Tile width: number of columns of C written per `kernel` call.
    const NR: usize = Self::NRTy::VALUE;

    /// Alignment handed to the allocator for this kernel's buffers.
    /// NOTE(review): presumably a byte alignment — confirm against `Alloc::new`.
    fn align_to() -> usize;

    /// NOTE(review): presumably reports whether this kernel must always be
    /// driven through a masked (partial-tile) path — confirm with callers.
    fn always_masked() -> bool;

    /// Blocking parameter `nc`; defaults to `archparam::S_NC`.
    #[inline(always)]
    fn nc() -> usize {
        archparam::S_NC
    }

    /// Blocking parameter `kc`; defaults to `archparam::S_KC`.
    #[inline(always)]
    fn kc() -> usize {
        archparam::S_KC
    }

    /// Blocking parameter `mc`; defaults to `archparam::S_MC`.
    #[inline(always)]
    fn mc() -> usize {
        archparam::S_MC
    }

    /// Multiply a `MR x k` block of A (read through `a`) by a `k x NR` block
    /// of B (read through `b`). The resulting tile is stored into C through
    /// `c`, with row stride `rsc` and column stride `csc`.
    ///
    /// NOTE(review): presumably computes `C <- alpha * A * B + beta * C`. The
    /// test harness only exercises `alpha = 1` and `beta = 0`.
    ///
    /// # Safety
    /// NOTE(review): callers presumably must guarantee the following. `a` must
    /// point to `MR * k` readable elements and `b` to `NR * k` readable
    /// elements. Every `c + i * rsc + j * csc` with `i < MR` and `j < NR` must
    /// be writable. Confirm with the implementations.
    unsafe fn kernel(
        k: usize,
        alpha: Self::Elem,
        a: *const Self::Elem,
        b: *const Self::Elem,
        beta: Self::Elem,
        c: *mut Self::Elem,
        rsc: isize,
        csc: isize,
    );
}
/// A scalar type the GEMM code can operate on.
///
/// The constants and in-place arithmetic the generic code needs are exposed
/// as plain trait methods.
pub(crate) trait Element: Copy + Send + Sync {
    /// The additive identity.
    fn zero() -> Self;
    /// The multiplicative identity.
    fn one() -> Self;
    /// A non-zero value that the kernel tests accumulate to build distinct
    /// matrix entries.
    fn test_value() -> Self;

    /// Returns `true` when `self` compares equal to zero.
    fn is_zero(&self) -> bool;

    /// In-place addition: `self = self + rhs`.
    fn add_assign(&mut self, rhs: Self);
    /// In-place multiplication: `self = self * rhs`.
    fn mul_assign(&mut self, rhs: Self);
}
// Single-precision real elements: ordinary IEEE `f32` arithmetic.
impl Element for f32 {
    fn zero() -> Self {
        0.0
    }

    fn one() -> Self {
        1.0
    }

    // Value the kernel tests add repeatedly to build distinct entries.
    fn test_value() -> Self {
        1.0
    }

    fn is_zero(&self) -> bool {
        0.0 == *self
    }

    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }

    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
// Double-precision real elements: ordinary IEEE `f64` arithmetic.
impl Element for f64 {
    fn zero() -> Self {
        0.0
    }

    fn one() -> Self {
        1.0
    }

    // Value the kernel tests add repeatedly to build distinct entries.
    fn test_value() -> Self {
        1.0
    }

    fn is_zero(&self) -> bool {
        0.0 == *self
    }

    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }

    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
/// Dispatch hook: an implementor is handed a concrete [`GemmKernel`] whose
/// element type is `T`.
///
/// NOTE(review): presumably used to choose a kernel (e.g. by CPU feature)
/// and pass it to generic code — confirm with callers.
pub(crate) trait GemmSelect<T> {
    /// Consume `self` together with the chosen `kernel`.
    fn select<K>(self, kernel: K)
    where
        T: Element,
        K: GemmKernel<Elem = T>;
}
/// Single-precision complex number, stored as `[re, im]`.
// Lower-case on purpose, to read like the primitive `f32`.
#[allow(non_camel_case_types)]
#[cfg(feature = "cgemm")]
pub(crate) type c32 = [f32; 2];

/// Double-precision complex number, stored as `[re, im]`.
// Lower-case on purpose, to read like the primitive `f64`.
#[allow(non_camel_case_types)]
#[cfg(feature = "cgemm")]
pub(crate) type c64 = [f64; 2];
// Single-precision complex elements: index 0 is the real part, index 1 the
// imaginary part.
#[cfg(feature = "cgemm")]
impl Element for c32 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    // Has a non-zero imaginary part, so the kernel tests also cover the
    // imaginary lane.
    fn test_value() -> Self {
        [1.0, 0.5]
    }

    fn is_zero(&self) -> bool {
        let [re, im] = *self;
        re == 0.0 && im == 0.0
    }

    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        let [re, im] = self;
        *re += y[0];
        *im += y[1];
    }

    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let product = c32_mul(*self, rhs);
        *self = product;
    }
}
// Double-precision complex elements: index 0 is the real part, index 1 the
// imaginary part.
#[cfg(feature = "cgemm")]
impl Element for c64 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    // Has a non-zero imaginary part, so the kernel tests also cover the
    // imaginary lane.
    fn test_value() -> Self {
        [1.0, 0.5]
    }

    fn is_zero(&self) -> bool {
        let [re, im] = *self;
        re == 0.0 && im == 0.0
    }

    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        let [re, im] = self;
        *re += y[0];
        *im += y[1];
    }

    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let product = c64_mul(*self, rhs);
        *self = product;
    }
}
/// Complex product of two `[re, im]` pairs:
/// `(a + bi)(c + di) = (ac - bd) + (bc + ad)i`.
#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c32_mul(x: c32, y: c32) -> c32 {
    let re = x[0] * y[0] - x[1] * y[1];
    let im = x[1] * y[0] + x[0] * y[1];
    [re, im]
}
/// Complex product of two `[re, im]` pairs:
/// `(a + bi)(c + di) = (ac - bd) + (bc + ad)i`.
#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c64_mul(x: c64, y: c64) -> c64 {
    let re = x[0] * y[0] - x[1] * y[1];
    let im = x[1] * y[0] + x[0] * y[1];
    [re, im]
}
/// A type-level unsigned integer, used to express kernel tile sizes
/// (`GemmKernel::MRTy` / `NRTy`) as types.
pub(crate) trait ConstNum {
    /// The integer this type stands for.
    const VALUE: usize;
}

// Type-level 2. Only this one is gated behind the `cgemm` feature.
#[cfg(feature = "cgemm")]
pub(crate) struct U2;
#[cfg(feature = "cgemm")]
impl ConstNum for U2 {
    const VALUE: usize = 2;
}

// Type-level 4 (always available).
pub(crate) struct U4;
impl ConstNum for U4 {
    const VALUE: usize = 4;
}

// Type-level 8 (always available).
pub(crate) struct U8;
impl ConstNum for U8 {
    const VALUE: usize = 8;
}
#[cfg(test)]
pub(crate) mod test {
    use std::fmt;
    use super::GemmKernel;
    use super::Element;
    use crate::aligned_alloc::Alloc;

    /// Allocate `n` elements of `K::Elem`, aligned to `K::align_to()` and
    /// initialised to `elt`.
    pub(crate) fn aligned_alloc<K>(elt: K::Elem, n: usize) -> Alloc<K::Elem>
    where
        K: GemmKernel,
        K::Elem: Copy,
    {
        // SAFETY: NOTE(review): presumably `Alloc::new` yields uninitialised
        // storage. It is filled with `elt` by `init_with` before anything can
        // read it. Confirm against the `aligned_alloc` module's contract.
        unsafe { Alloc::new(n, K::align_to()).init_with(elt) }
    }

    /// Check kernel `K` against the identities `A * I == A` and `I * B == B`.
    ///
    /// `name` identifies the kernel in assertion failure messages.
    ///
    /// Each non-identity operand is filled with pairwise distinct values. The
    /// n-th entry is `test_value()` added n times, because `Element` has no
    /// integer conversion. As a result, a kernel that reads or writes the wrong
    /// element produces a mismatch.
    pub(crate) fn test_a_kernel<K, T>(name: &str)
    where
        K: GemmKernel<Elem = T>,
        T: Element + fmt::Debug + PartialEq,
    {
        // Shared inner dimension `k` of both products. It is named differently
        // from the type parameter `K` so the two are not confused.
        const DEPTH: usize = 16;
        let mr = K::MR;
        let nr = K::NR;

        // 1. A * I == A, with C stored column-major (rsc = 1, csc = mr).
        //    B gets ones at b[i * nr + i], truncated if nr > DEPTH.
        let mut a = aligned_alloc::<K>(T::zero(), mr * DEPTH);
        let mut b = aligned_alloc::<K>(T::zero(), nr * DEPTH);

        let mut count = 1;
        for i in 0..mr {
            for j in 0..DEPTH {
                for _ in 0..count {
                    a[i * DEPTH + j].add_assign(T::test_value());
                }
                count += 1;
            }
        }

        for i in 0..Ord::min(DEPTH, nr) {
            b[i + i * nr] = T::one();
        }

        let mut c = vec![T::zero(); mr * nr];
        // SAFETY: `a` holds mr * DEPTH elements, `b` holds nr * DEPTH and `c`
        // holds mr * nr. With rsc = 1 and csc = mr, every offset i + j * mr for
        // i < mr, j < nr lies inside `c`.
        // NOTE(review): this assumes the kernel reads at most MR * k elements
        // of `a` and NR * k of `b`.
        unsafe {
            K::kernel(DEPTH, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), 1, mr as isize);
        }
        let common_len = Ord::min(a.len(), c.len());
        assert_eq!(&a[..common_len], &c[..common_len], "{}: A * I != A", name);

        // 2. I * B == B, with C stored row-major (rsc = nr, csc = 1).
        //    A gets ones at a[i * mr + i], truncated if mr > DEPTH.
        let mut a = aligned_alloc::<K>(T::zero(), mr * DEPTH);
        let mut b = aligned_alloc::<K>(T::zero(), nr * DEPTH);

        for i in 0..Ord::min(DEPTH, mr) {
            a[i + i * mr] = T::one();
        }

        let mut count = 1;
        for i in 0..DEPTH {
            for j in 0..nr {
                for _ in 0..count {
                    b[i * nr + j].add_assign(T::test_value());
                }
                count += 1;
            }
        }

        let mut c = vec![T::zero(); mr * nr];
        // SAFETY: the buffer sizes are the same as in part 1. With rsc = nr and
        // csc = 1, every offset i * nr + j for i < mr, j < nr lies inside `c`.
        unsafe {
            K::kernel(DEPTH, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), nr as isize, 1);
        }
        let common_len = Ord::min(b.len(), c.len());
        assert_eq!(&b[..common_len], &c[..common_len], "{}: I * B != B", name);
    }
}