use super::HashMap;
use crate::frontend::FunctionBuilder;
use alloc::vec::Vec;
use core::convert::TryFrom;
use cranelift_codegen::ir::condcodes::IntCC;
use cranelift_codegen::ir::*;
/// Key type for switch cases. It is `u128` so that a case value for any
/// integer type up to 128 bits wide can be stored without loss.
type EntryIndex = u128;

/// Builder that lowers a multi-way branch on an integer value into Cranelift IR.
///
/// Register each case with `set_entry`, then call `emit` to generate the
/// comparisons, branches and jump tables that dispatch to the registered blocks.
#[derive(Debug, Default)]
pub struct Switch {
    /// Case value -> block to branch to when the switched-on value equals it.
    cases: HashMap<EntryIndex, Block>,
}
impl Switch {
    /// Creates a switch with no cases registered.
    pub fn new() -> Self {
        Self {
            cases: HashMap::new(),
        }
    }

    /// Registers `block` as the branch target for the case value `index`.
    ///
    /// # Panics
    ///
    /// Panics if a block has already been registered for `index`.
    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
        let prev = self.cases.insert(index, block);
        assert!(
            prev.is_none(),
            "Tried to set the same entry {} twice",
            index
        );
    }

    /// Returns every registered case, keyed by case value.
    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
        &self.cases
    }

    /// Consumes the switch and groups its cases into runs of consecutive case
    /// values, returned in ascending order of `first_index`. Within a run,
    /// `blocks[i]` is the target for case value `first_index + i`.
    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
        log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
        // Map keys are unique, so there are no ties and an unstable sort gives
        // exactly the same order as a stable one.
        cases.sort_unstable_by_key(|&(index, _)| index);

        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
        let mut last_index: Option<EntryIndex> = None;
        for (index, block) in cases {
            // `last + 1` cannot overflow: keys are sorted and unique, so
            // `last < index <= u128::MAX`.
            let starts_new_range = match last_index {
                None => true,
                Some(last) => index > last + 1,
            };
            if starts_new_range {
                contiguous_case_ranges.push(ContiguousCaseRange::new(index));
            }
            contiguous_case_ranges
                .last_mut()
                .expect("the first case always opens a range")
                .blocks
                .push(block);
            last_index = Some(index);
        }
        log::trace!(
            "build_contiguous_case_ranges after: {:#?}",
            contiguous_case_ranges
        );
        contiguous_case_ranges
    }

    /// Emits a balanced binary search over `contiguous_case_ranges`, starting
    /// in the builder's current block.
    ///
    /// Each level splits the ranges in half and branches on
    /// `val >= first_index` of the upper half. Once three or fewer ranges
    /// remain, `build_search_branches` emits a linear chain of tests.
    ///
    /// Returns `(first_index, jt_block, case_blocks)` for every jump-table block
    /// that still needs its `br_table` emitted by `build_jump_tables`.
    fn build_search_tree(
        bx: &mut FunctionBuilder,
        val: Value,
        otherwise: Block,
        contiguous_case_ranges: Vec<ContiguousCaseRange>,
    ) -> Vec<(EntryIndex, Block, Vec<Block>)> {
        let mut cases_and_jt_blocks = Vec::new();

        // Explicit work stack instead of recursion. The root entry carries
        // `None`, meaning "keep emitting into the current block".
        let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
        stack.push((None, contiguous_case_ranges));

        while let Some((block, contiguous_case_ranges)) = stack.pop() {
            if let Some(block) = block {
                bx.switch_to_block(block);
            }

            if contiguous_case_ranges.len() <= 3 {
                Self::build_search_branches(
                    bx,
                    val,
                    otherwise,
                    contiguous_case_ranges,
                    &mut cases_and_jt_blocks,
                );
            } else {
                let split_point = contiguous_case_ranges.len() / 2;
                let mut left = contiguous_case_ranges;
                let right = left.split_off(split_point);

                let left_block = bx.create_block();
                let right_block = bx.create_block();

                let first_index = right[0].first_index;
                let should_take_right_side =
                    icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
                bx.ins().brnz(should_take_right_side, right_block, &[]);
                bx.ins().jump(left_block, &[]);

                // Both children have exactly one predecessor (this block), so
                // they can be sealed immediately.
                bx.seal_block(left_block);
                bx.seal_block(right_block);

                stack.push((Some(left_block), left));
                stack.push((Some(right_block), right));
            }
        }

        cases_and_jt_blocks
    }

    /// Emits a linear chain of tests for a small number of ranges, visiting
    /// them from the highest `first_index` down, and finally jumps to
    /// `otherwise`.
    ///
    /// - A single-value range branches straight to its block: `brz` for case 0,
    ///   otherwise `icmp eq` + `brnz`.
    /// - A multi-value range branches on `val >= first_index` to a fresh jump-table
    ///   block, which is pushed onto `cases_and_jt_blocks`. Because higher ranges
    ///   were already tested, values past the end of this range reach the
    ///   jump table's default target.
    /// - A multi-value range starting at 0 must be the last range, so it is
    ///   entered with an unconditional jump and the function returns early.
    fn build_search_branches(
        bx: &mut FunctionBuilder,
        val: Value,
        otherwise: Block,
        contiguous_case_ranges: Vec<ContiguousCaseRange>,
        cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
    ) {
        let mut was_branch = false;
        // After a conditional branch, terminate the block with a jump into a
        // fresh (immediately sealed) block and continue emitting there.
        let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| {
            if was_branch {
                let block = bx.create_block();
                bx.ins().jump(block, &[]);
                bx.seal_block(block);
                bx.switch_to_block(block);
            }
        };

        for ContiguousCaseRange {
            first_index,
            blocks,
        } in contiguous_case_ranges.into_iter().rev()
        {
            match (blocks.len(), first_index) {
                (1, 0) => {
                    ins_fallthrough_jump(was_branch, bx);
                    bx.ins().brz(val, blocks[0], &[]);
                }
                (1, _) => {
                    ins_fallthrough_jump(was_branch, bx);
                    let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, first_index);
                    bx.ins().brnz(is_good_val, blocks[0], &[]);
                }
                (_, 0) => {
                    // Every remaining value is >= 0, so no test is needed.
                    let jt_block = bx.create_block();
                    bx.ins().jump(jt_block, &[]);
                    bx.seal_block(jt_block);
                    cases_and_jt_blocks.push((first_index, jt_block, blocks));
                    return;
                }
                (_, _) => {
                    ins_fallthrough_jump(was_branch, bx);
                    let jt_block = bx.create_block();
                    let is_good_val =
                        icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
                    bx.ins().brnz(is_good_val, jt_block, &[]);
                    bx.seal_block(jt_block);
                    cases_and_jt_blocks.push((first_index, jt_block, blocks));
                }
            }
            was_branch = true;
        }

        bx.ins().jump(otherwise, &[]);
    }

    /// Fills in each jump-table block recorded by `build_search_branches`.
    ///
    /// Each block rebases `val` so that `first_index` maps to table index 0 and
    /// brings the index to `i32`. Wider values are first range-checked against
    /// `u32::MAX`, with larger values branching to `otherwise`. The block then
    /// ends in a `br_table` whose default target is `otherwise`.
    ///
    /// # Panics
    ///
    /// Panics if a range has more than `u32::MAX` entries.
    fn build_jump_tables(
        bx: &mut FunctionBuilder,
        val: Value,
        otherwise: Block,
        cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
    ) {
        for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
            assert!(
                u32::try_from(blocks.len()).is_ok(),
                "Jump tables bigger than 2^32-1 are not yet supported"
            );

            let mut jt_data = JumpTableData::new();
            for block in blocks {
                jt_data.push_entry(block);
            }
            let jump_table = bx.create_jump_table(jt_data);

            bx.switch_to_block(jt_block);
            let discr = if first_index == 0 {
                val
            } else {
                Self::isub_imm_u128(bx, val, first_index)
            };

            let discr = match bx.func.dfg.value_type(discr).bits() {
                bits if bits > 32 => {
                    let new_block = bx.create_block();
                    let bigger_than_u32 =
                        bx.ins()
                            .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
                    bx.ins().brnz(bigger_than_u32, otherwise, &[]);
                    bx.ins().jump(new_block, &[]);
                    bx.seal_block(new_block);
                    bx.switch_to_block(new_block);
                    bx.ins().ireduce(types::I32, discr)
                }
                bits if bits < 32 => bx.ins().uextend(types::I32, discr),
                _ => discr,
            };

            bx.ins().br_table(discr, otherwise, jump_table);
        }
    }

    /// Emits `x - y` for an integer `x` and a `u128` constant `y`.
    ///
    /// `iadd_imm` sign-extends its 64-bit immediate to the width of `x`. For
    /// operands of at most 64 bits, negating `y` reinterpreted as `i64` is exact
    /// modulo the type width. For `i128` it is only exact when `y` fits in
    /// `i64`: a constant in `i64::MAX + 1 ..= u64::MAX` would be sign-extended
    /// incorrectly. Such constants, and anything wider, are materialised as a
    /// full 128-bit value and subtracted with `isub`.
    fn isub_imm_u128(bx: &mut FunctionBuilder, x: Value, y: u128) -> Value {
        let imm = if bx.func.dfg.value_type(x) == types::I128 {
            i64::try_from(y).ok()
        } else {
            u64::try_from(y).ok().map(|y| y as i64)
        };
        match imm {
            Some(imm) => bx.ins().iadd_imm(x, imm.wrapping_neg()),
            None => {
                let (lsb, msb) = (y as u64, (y >> 64) as u64);
                let lsb = bx.ins().iconst(types::I64, lsb as i64);
                let msb = bx.ins().iconst(types::I64, msb as i64);
                let y = bx.ins().iconcat(lsb, msb);
                bx.ins().isub(x, y)
            }
        }
    }

    /// Lowers the switch on `val` into `bx`, starting in the builder's current
    /// block. Values with no registered case branch to `otherwise`.
    ///
    /// Only the blocks created internally are sealed here. The case blocks and
    /// `otherwise` are left for the caller to seal.
    ///
    /// # Panics
    ///
    /// Panics if the largest case value does not fit in the type of `val`.
    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
        let max = self.cases.keys().max().copied().unwrap_or(0);
        let val_ty = bx.func.dfg.value_type(val);
        let val_ty_max = val_ty.bounds(false).1;
        if max > val_ty_max {
            panic!(
                "The index type {} does not fit the maximum switch entry of {}",
                val_ty, max
            );
        }

        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
        let cases_and_jt_blocks =
            Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
        Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
    }
}
/// Emits the integer comparison `x <cond> y`, where `y` is a `u128` constant,
/// and returns the resulting boolean value.
///
/// `icmp_imm` sign-extends its 64-bit immediate to the width of `x`.
/// - For operands of at most 64 bits, the immediate is emitted exactly as before:
///   any `u64` constant reinterpreted as `i64`.
/// - For `i128` operands, a constant in `i64::MAX + 1 ..= u64::MAX` would be
///   sign-extended into a huge 128-bit value. So for `i128` the immediate form
///   is used only when `y` fits in `i64`.
/// - Any other constant is materialised as a full 128-bit value and compared
///   with `icmp`.
fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
    let imm = if bx.func.dfg.value_type(x) == types::I128 {
        i64::try_from(y).ok()
    } else {
        u64::try_from(y).ok().map(|y| y as i64)
    };
    if let Some(imm) = imm {
        bx.ins().icmp_imm(cond, x, imm)
    } else {
        let (lsb, msb) = (y as u64, (y >> 64) as u64);
        let lsb = bx.ins().iconst(types::I64, lsb as i64);
        let msb = bx.ins().iconst(types::I64, msb as i64);
        let index = bx.ins().iconcat(lsb, msb);
        bx.ins().icmp(cond, x, index)
    }
}
/// A run of consecutive case values starting at `first_index`, together with
/// the block each of them dispatches to.
#[derive(Debug)]
struct ContiguousCaseRange {
    /// Case value that dispatches to `blocks[0]`.
    first_index: EntryIndex,
    /// `blocks[i]` is the target for case value `first_index + i`.
    blocks: Vec<Block>,
}

impl ContiguousCaseRange {
    /// Opens an empty range beginning at `first_index`; the caller appends the
    /// target blocks afterwards.
    fn new(first_index: EntryIndex) -> Self {
        let blocks: Vec<Block> = Vec::new();
        Self {
            first_index,
            blocks,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::frontend::FunctionBuilderContext;
use alloc::string::ToString;
use cranelift_codegen::ir::Function;
macro_rules! setup {
($default:expr, [$($index:expr,)*]) => {{
let mut func = Function::new();
let mut func_ctx = FunctionBuilderContext::new();
{
let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
let block = bx.create_block();
bx.switch_to_block(block);
let val = bx.ins().iconst(types::I8, 0);
let mut switch = Switch::new();
$(
let block = bx.create_block();
switch.set_entry($index, block);
)*
switch.emit(&mut bx, val, Block::with_number($default).unwrap());
}
func
.to_string()
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string()
}};
}
#[test]
fn switch_zero() {
let func = setup!(0, [0,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
brz v0, block1 ; v0 = 0
jump block0"
);
}
#[test]
fn switch_single() {
let func = setup!(0, [1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 1 ; v0 = 0
brnz v1, block1
jump block0"
);
}
#[test]
fn switch_bool() {
let func = setup!(0, [0, 1,]);
assert_eq!(
func,
" jt0 = jump_table [block1, block2]
block0:
v0 = iconst.i8 0
jump block3
block3:
v1 = uextend.i32 v0 ; v0 = 0
br_table v1, block0, jt0"
);
}
#[test]
fn switch_two_gap() {
let func = setup!(0, [0, 2,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 2 ; v0 = 0
brnz v1, block2
jump block3
block3:
brz.i8 v0, block1 ; v0 = 0
jump block0"
);
}
#[test]
fn switch_many() {
let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
assert_eq!(
func,
" jt0 = jump_table [block1, block2]
jt1 = jump_table [block5, block6, block7]
block0:
v0 = iconst.i8 0
v1 = icmp_imm uge v0, 7 ; v0 = 0
brnz v1, block9
jump block8
block9:
v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0
brnz v2, block10
jump block11
block11:
v3 = icmp_imm.i8 eq v0, 7 ; v0 = 0
brnz v3, block4
jump block0
block8:
v4 = icmp_imm.i8 eq v0, 5 ; v0 = 0
brnz v4, block3
jump block12
block12:
v5 = uextend.i32 v0 ; v0 = 0
br_table v5, block0, jt0
block10:
v6 = iadd_imm.i8 v0, -10 ; v0 = 0
v7 = uextend.i32 v6
br_table v7, block0, jt1"
);
}
#[test]
fn switch_min_index_value() {
let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 128 ; v0 = 0
brnz v1, block1
jump block3
block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2
jump block0"
);
}
#[test]
fn switch_max_index_value() {
let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 127 ; v0 = 0
brnz v1, block1
jump block3
block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2
jump block0"
)
}
#[test]
fn switch_optimal_codegen() {
let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
assert_eq!(
func,
" jt0 = jump_table [block2, block3]
block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 255 ; v0 = 0
brnz v1, block1
jump block4
block4:
v2 = uextend.i32 v0 ; v0 = 0
br_table v2, block0, jt0"
);
}
#[test]
#[should_panic(
expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
)]
fn switch_rejects_small_inputs() {
setup!(1, [0x4100_0000_00bf_d470,]);
}
#[test]
fn switch_seal_generated_blocks() {
let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];
for case in cases {
for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
eprintln!("Testing {:?} with keys: {:?}", typ, case);
do_case(case, *typ);
}
}
fn do_case(keys: &[u128], typ: Type) {
let mut func = Function::new();
let mut builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);
let root_block = builder.create_block();
let default_block = builder.create_block();
let mut switch = Switch::new();
let case_blocks = keys
.iter()
.map(|key| {
let block = builder.create_block();
switch.set_entry(*key, block);
block
})
.collect::<Vec<_>>();
builder.seal_block(root_block);
builder.switch_to_block(root_block);
let val = builder.ins().iconst(typ, 1);
switch.emit(&mut builder, val, default_block);
for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
builder.seal_block(block);
builder.switch_to_block(block);
builder.ins().return_(&[]);
}
builder.finalize(); }
}
#[test]
fn switch_64bit() {
let mut func = Function::new();
let mut func_ctx = FunctionBuilderContext::new();
{
let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
let block0 = bx.create_block();
bx.switch_to_block(block0);
let val = bx.ins().iconst(types::I64, 0);
let mut switch = Switch::new();
let block1 = bx.create_block();
switch.set_entry(1, block1);
let block2 = bx.create_block();
switch.set_entry(0, block2);
let block3 = bx.create_block();
switch.emit(&mut bx, val, block3);
}
let func = func
.to_string()
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string();
assert_eq!(
func,
" jt0 = jump_table [block2, block1]
block0:
v0 = iconst.i64 0
jump block4
block4:
v1 = icmp_imm.i64 ugt v0, 0xffff_ffff ; v0 = 0
brnz v1, block3
jump block5
block5:
v2 = ireduce.i32 v0 ; v0 = 0
br_table v2, block3, jt0"
);
}
#[test]
fn switch_128bit() {
let mut func = Function::new();
let mut func_ctx = FunctionBuilderContext::new();
{
let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
let block0 = bx.create_block();
bx.switch_to_block(block0);
let val = bx.ins().iconst(types::I128, 0);
let mut switch = Switch::new();
let block1 = bx.create_block();
switch.set_entry(1, block1);
let block2 = bx.create_block();
switch.set_entry(0, block2);
let block3 = bx.create_block();
switch.emit(&mut bx, val, block3);
}
let func = func
.to_string()
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string();
assert_eq!(
func,
" jt0 = jump_table [block2, block1]
block0:
v0 = iconst.i128 0
jump block4
block4:
v1 = icmp_imm.i128 ugt v0, 0xffff_ffff ; v0 = 0
brnz v1, block3
jump block5
block5:
v2 = ireduce.i32 v0 ; v0 = 0
br_table v2, block3, jt0"
);
}
}