rework typed_ir to have flat expressions

This commit is contained in:
NotAFile 2022-04-05 23:31:20 +02:00
parent b71f9f09ae
commit a2cca95dbd
10 changed files with 227 additions and 160 deletions

View File

@ -1,7 +1,7 @@
comb comparator (
a: Logic<8>,
b: Logic<8>
) -> Logic
) -> Logic<1>
{
~reduce_or(a ^ b)
}

View File

@ -1,6 +1,6 @@
comb identity (
a: Logic
) -> Logic
a: Logic<5>
) -> Logic<5>
{
a
}

View File

@ -1,5 +1,5 @@
use std::cell::Cell;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};
use super::parser;
use super::parser::block_expression::BlockExpr;
@ -122,20 +122,32 @@ impl Context {
}
}
fn type_expression(
fn intern_expression(
&self,
exprs: &mut HashMap<typed_ir::ExprId, typed_ir::Expr>,
expr: typed_ir::Expr,
) -> typed_ir::ExprId {
let expr_id = expr.id;
exprs.insert(expr.id, expr);
expr_id
}
fn type_expression(
&mut self,
exprs: &mut HashMap<typed_ir::ExprId, typed_ir::Expr>,
expr: &parser::expression::Expression,
) -> Result<typed_ir::Expr, CompileError> {
) -> Result<typed_ir::ExprId, CompileError> {
use parser::expression::Expression;
let id = typed_ir::ExprId(self.ids.next() as u32);
let t_expr = match expr {
Expression::Path(name) => {
let signal = self.try_get_signal(name)?;
typed_ir::Expr {
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Path(signal.id),
typ: signal.typ,
}
};
self.intern_expression(exprs, this_expr)
}
Expression::Literal(lit) => {
let data = match lit.kind {
@ -144,15 +156,16 @@ impl Context {
self.types.make_const_u32(width, val as u32)
}
};
typed_ir::Expr {
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Literal(data),
typ: self.types.primitives.infer,
}
};
self.intern_expression(exprs, this_expr)
}
Expression::UnOp(op) => {
let a = self.type_expression(&op.a)?;
typed_ir::Expr {
let a = self.type_expression(exprs, &op.a)?;
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Call(typed_ir::Call {
called: self.callables.builtins.bitnot,
@ -160,11 +173,15 @@ impl Context {
genargs: vec![],
}),
typ: self.types.primitives.infer,
}
};
self.intern_expression(exprs, this_expr)
}
Expression::BinOp(op) => {
let (a, b) = (self.type_expression(&op.a)?, self.type_expression(&op.b)?);
typed_ir::Expr {
let (a, b) = (
self.type_expression(exprs, &op.a)?,
self.type_expression(exprs, &op.b)?,
);
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Call(typed_ir::Call {
called: self.callables.builtins.xor,
@ -172,13 +189,14 @@ impl Context {
genargs: vec![],
}),
typ: self.types.primitives.infer,
}
};
self.intern_expression(exprs, this_expr)
}
Expression::Call(call) => {
let args_resolved = call
.args
.iter()
.map(|expr| self.type_expression(expr))
.map(|expr| self.type_expression(exprs, expr))
.collect::<Result<Vec<_>, _>>()?;
let called = self.try_get_callable(call.name.fragment())?;
let called_callable = self.callables.get(called);
@ -187,13 +205,13 @@ impl Context {
received: args_resolved.len(),
expected: called_callable.argcount(),
}));
}
};
let genargs_resolved = called_callable
.genargs
.iter()
.map(|genarg| genarg.1)
.collect();
typed_ir::Expr {
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Call(typed_ir::Call {
called,
@ -201,28 +219,43 @@ impl Context {
genargs: genargs_resolved,
}),
typ: self.types.primitives.infer,
}
};
self.intern_expression(exprs, this_expr)
}
Expression::BlockExpr(block) => match &**block {
BlockExpr::IfElse(_) => todo!(),
BlockExpr::Match(match_) => {
let expr = self.type_expression(&match_.expr)?;
let expr = self.type_expression(exprs, &match_.expr)?;
let arms = match_
.arms
.iter()
.map(|(cond, val)| {
Ok((self.type_expression(cond)?, self.type_expression(val)?))
Ok((
self.type_expression(exprs, cond)?,
self.type_expression(exprs, val)?,
))
})
.collect::<Result<_, _>>()?;
let typed = typed_ir::Match { expr, arms };
typed_ir::Expr {
let this_expr = typed_ir::Expr {
id,
kind: typed_ir::ExprKind::Match(Box::new(typed)),
typ: self.types.primitives.infer,
}
};
self.intern_expression(exprs, this_expr)
}
BlockExpr::Block(block) => {
todo!("expression blocks not representable in typed ir yet")
// TODO: we need to find some way of resolving a name to an expression
todo!("can not convert blocks to typed_ir yet");
for (name, expr) in &block.assignments {
let signal = typed_ir::Signal {
id: typed_ir::DefId(self.ids.next() as u32),
typ: self.types.primitives.infer,
};
// TODO: need to add this signal to the block from here, somehow
self.signals.insert(name.span().to_string(), signal);
}
self.type_expression(exprs, &block.value)?
}
},
Expression::StructInit(_) => todo!("structure initialization"),
@ -247,8 +280,9 @@ impl Context {
let sig_typename = &port.net.typ;
let mut sig_type = self.try_get_type(sig_typename.name.fragment())?;
if let Some(arg) = &sig_typename.generics {
let elab_expr = self.type_expression(arg)?;
let elab_val = self.eval_expression(&elab_expr)?;
let mut exprs = Default::default();
let elab_expr = self.type_expression(&mut exprs, arg)?;
let elab_val = self.eval_expression(exprs.get(&elab_expr).unwrap())?;
sig_type = self
.types
.parameterize(sig_type, &[types::GenericArg::Elab(elab_val)])
@ -294,11 +328,14 @@ impl Context {
}
}
let root_expr = self.type_expression(&comb.expr)?;
let mut exprs = Default::default();
let root_expr = self.type_expression(&mut exprs, &comb.expr)?;
Ok(typed_ir::Body {
signature: callable_id,
signals,
exprs,
expr: root_expr,
})
}

View File

@ -4,13 +4,18 @@ use super::{make_pubid, CompileError, Context};
use crate::rtlil;
use crate::rtlil::RtlilWrite;
fn wire_for_expr(expr: typed_ir::ExprId) -> rtlil::SigSpec {
rtlil::SigSpec::Wire(format!("$_expr_{}", expr.0))
}
fn lower_expression(
ctx: &Context,
module: &mut rtlil::Module,
body: &typed_ir::Body,
expr: &typed_ir::Expr,
) -> Result<rtlil::SigSpec, CompileError> {
let expr_width = ctx.types.get_width(expr.typ).expect("signal needs width");
let expr_wire_name = format!("$_sig_{}", expr.id.0);
let expr_wire_name = format!("$_expr_{}", expr.id.0);
let expr_wire = rtlil::Wire::new(expr_wire_name.clone(), expr_width, None);
module.add_wire(expr_wire);
match &expr.kind {
@ -19,15 +24,21 @@ fn lower_expression(
let args_resolved = call
.args
.iter()
.map(|expr| lower_expression(ctx, module, expr))
.collect::<Result<Vec<_>, _>>()?;
.map(|expr| wire_for_expr(*expr))
.collect::<Vec<_>>();
let args: Vec<_> = call
.args
.iter()
.map(|expr_id| body.exprs.get(expr_id).unwrap())
.collect();
let callable = ctx.callables.get(call.called);
let cell_id = module.make_genid(callable.name());
if call.called == ctx.callables.builtins.xor {
let a_width = ctx.types.get_width(call.args[0].typ).unwrap();
let b_width = ctx.types.get_width(call.args[1].typ).unwrap();
let a_width = ctx.types.get_width(args[0].typ).unwrap();
let b_width = ctx.types.get_width(args[1].typ).unwrap();
let y_width = ctx.types.get_width(expr.typ).unwrap();
let mut cell = rtlil::Cell::new(&cell_id, "$xor");
cell.add_param("\\A_SIGNED", "0");
@ -40,7 +51,7 @@ fn lower_expression(
cell.add_connection("\\Y", &rtlil::SigSpec::Wire(expr_wire_name.clone()));
module.add_cell(cell);
} else if call.called == ctx.callables.builtins.reduce_or {
let a_width = ctx.types.get_width(call.args[0].typ).unwrap();
let a_width = ctx.types.get_width(args[0].typ).unwrap();
let y_width = ctx.types.get_width(expr.typ).unwrap();
let mut cell = rtlil::Cell::new(&cell_id, "$reduce_or");
cell.add_param("\\A_SIGNED", "0");
@ -50,7 +61,7 @@ fn lower_expression(
cell.add_connection("\\Y", &rtlil::SigSpec::Wire(expr_wire_name.clone()));
module.add_cell(cell);
} else if call.called == ctx.callables.builtins.bitnot {
let a_width = ctx.types.get_width(call.args[0].typ).unwrap();
let a_width = ctx.types.get_width(args[0].typ).unwrap();
let y_width = ctx.types.get_width(expr.typ).unwrap();
let mut cell = rtlil::Cell::new(&cell_id, "$not");
cell.add_param("\\A_SIGNED", "0");
@ -75,11 +86,11 @@ fn lower_expression(
.iter()
.map(|(pat, val)| {
Ok((
lower_expression(ctx, module, pat)?,
wire_for_expr(*pat),
rtlil::CaseRule {
assign: vec![(
rtlil::SigSpec::Wire(expr_wire_name.clone()),
lower_expression(ctx, module, val)?,
wire_for_expr(*val),
)],
switches: vec![],
},
@ -88,7 +99,7 @@ fn lower_expression(
.collect::<Result<Vec<_>, CompileError>>()
.unwrap();
let root_switch = rtlil::SwitchRule {
signal: lower_expression(ctx, module, &match_.expr)?,
signal: wire_for_expr(match_.expr),
cases,
};
let root_case = rtlil::CaseRule {
@ -125,11 +136,15 @@ fn lower_comb(
module.add_wire(rtlil::Wire::new(
ret_id.clone(),
ctx.types
.get_width(block.expr.typ)
.get_width(block.exprs.get(&block.expr).unwrap().typ)
.expect("signal has no size"),
Some(rtlil::PortOption::Output(block.signals.len() as i32)),
));
let out_sig = lower_expression(ctx, module, &block.expr)?;
for (_, expr) in &block.exprs {
let expr_wire = lower_expression(ctx, module, block, &expr)?;
module.add_connection(&wire_for_expr(expr.id), &expr_wire);
}
let out_sig = wire_for_expr(block.expr);
module.add_connection(&rtlil::SigSpec::Wire(ret_id), &out_sig);
Ok(())
}

View File

@ -5,9 +5,9 @@ impl Context {
pub fn pretty_typed_block(
&self,
w: &mut dyn std::fmt::Write,
block: &typed_ir::Body,
body: &typed_ir::Body,
) -> std::fmt::Result {
let callsig = self.callables.get(block.signature);
let callsig = self.callables.get(body.signature);
{
// TODO: ugly copy paste job
let args = callsig
@ -36,12 +36,15 @@ impl Context {
args.join(", ")
)?;
}
for sig in &block.signals {
for sig in &body.signals {
let mut typ_pretty = String::new();
self.types.pretty_type(&mut typ_pretty, sig.typ)?;
writeln!(w, "sig_{}: {}", sig.id.0, typ_pretty)?
}
self.pretty_typed_expr(w, &block.expr)?;
for (_, expr) in &body.exprs {
self.pretty_typed_expr(w, &expr)?;
}
writeln!(w, "return _{}", body.expr.0)?;
Ok(())
}
@ -61,10 +64,7 @@ impl Context {
let args = call
.args
.iter()
.map(|arg| {
self.pretty_typed_expr(w, arg)?;
Ok(format!("_{}", arg.id.0))
})
.map(|arg| Ok(format!("_{}", arg.0)))
.collect::<Result<Vec<_>, std::fmt::Error>>()?;
let callable = self.callables.get(call.called);
let genargs = call
@ -84,21 +84,12 @@ impl Context {
)
}
typed_ir::ExprKind::Match(match_) => {
self.pretty_typed_expr(w, &match_.expr)?;
let arms = match_
.arms
.iter()
.map(|(pat, val)| {
self.pretty_typed_expr(w, pat)?;
self.pretty_typed_expr(w, val)?;
Ok(format!(" _{} => _{}", pat.id.0, val.id.0))
})
.map(|(pat, val)| Ok(format!(" _{} => _{}", pat.0, val.0)))
.collect::<Result<Vec<_>, _>>()?;
format!(
"match (_{}) {{\n{}\n}}",
&match_.expr.id.0,
arms.join(",\n")
)
format!("match (_{}) {{\n{}\n}}", &match_.expr.0, arms.join(",\n"))
}
};
let mut type_pretty = String::new();

View File

@ -1,87 +1,104 @@
use super::typed_ir;
use super::typed_ir::{Expr, ExprId, ExprKind};
use super::types;
use super::Context;
use super::{CompileError, CompileErrorKind, Context};
use std::collections::HashMap;
impl Context {
pub fn infer_types(&mut self, mut block: typed_ir::Body) -> typed_ir::Body {
let new_root = self.infer_expr_types(&block.expr);
block.expr = new_root;
block
pub fn infer_types(&mut self, block: typed_ir::Body) -> typed_ir::Body {
// TODO: ugly ugly hack
let try_1 = self
.infer_body_types(&block)
.expect("could not infer types");
let try_2 = self
.infer_body_types(&try_1)
.expect("could not infer types");
self.infer_body_types(&try_2)
.expect("could not infer types")
}
pub fn infer_expr_types(&mut self, expr: &typed_ir::Expr) -> typed_ir::Expr {
pub fn infer_body_types(
&mut self,
body: &typed_ir::Body,
) -> Result<typed_ir::Body, CompileError> {
let mut new_exprs = HashMap::new();
for (expr_id, expr) in &body.exprs {
if self.types.is_fully_typed(expr.typ) {
// there is nothing more to infer
return expr.clone();
new_exprs.insert(*expr_id, expr.clone());
continue;
}
match &expr.kind {
typed_ir::ExprKind::Literal(lit) => expr.clone().with_type(lit.typ),
// we can not see beyond this expression right now
typed_ir::ExprKind::Path(_) => expr.clone(),
typed_ir::ExprKind::Call(call) => {
let args_typed: Vec<_> = call
.args
.iter()
.map(|ex| self.infer_expr_types(ex))
.collect();
let callee_def = self.callables.get(call.called);
ExprKind::Literal(lit) => {
// TODO: don't try to overwrite the type of a literal
let infres = self.types.infer_type(expr.typ, lit.typ);
new_exprs.insert(*expr_id, expr.clone().with_type(lit.typ));
}
ExprKind::Path(_) => {
new_exprs.insert(*expr_id, expr.clone());
}
ExprKind::Call(call) => {
let called_def = self.callables.get(call.called);
let param_types = called_def.args.iter().map(|param| param.1);
let param_types: Vec<_> = callee_def.args.iter().map(|param| param.1).collect();
let mut genargs: Vec<_> = called_def.genargs.iter().map(|a| a.1).collect();
let inferred_args: Vec<_> = param_types
.iter()
.zip(&args_typed)
.map(|(param, arg)| self.types.infer_type(*param, arg.typ))
.zip(&call.args)
.map(|(param, arg)| {
self.types
.infer_type(param, body.exprs.get(arg).unwrap().typ)
})
.collect();
let mut genargs: Vec<_> = callee_def.genargs.iter().map(|a| a.1).collect();
let mut new_type = callee_def.ret_type;
if !genargs.is_empty() {
// need to infer generic arguments
for inf_res in inferred_args {
match inf_res {
types::InferenceResult::First(_) => todo!(),
types::InferenceResult::Second(_) => todo!(),
types::InferenceResult::TypeVar(dbi, tvar, typ) => {
assert_eq!(dbi, 0);
// TODO: type check argument instead of just using it
genargs[tvar as usize] = typ;
}
types::InferenceResult::Incompatible => todo!(),
types::InferenceResult::Ambigous => todo!(),
_ => todo!(),
}
}
// TODO: HACKY HACKY HACK
new_type = genargs[0];
}
let mut new_expr = expr.clone();
new_expr.typ = new_type;
new_expr.kind = typed_ir::ExprKind::Call(typed_ir::Call {
called: call.called,
args: args_typed,
genargs,
});
new_expr
let ret_type = match self.types.infer_type(expr.typ, called_def.ret_type) {
types::InferenceResult::TypeVar(dbi, tvar, typ) => {
assert_eq!(dbi, 0);
genargs[tvar as usize]
}
typed_ir::ExprKind::Match(match_) => {
let new_arms: Vec<_> = match_
.arms
.iter()
.map(|(pat, val)| (self.infer_expr_types(pat), self.infer_expr_types(val)))
.collect();
// TODO: hacky hacky hacky
let res_type = new_arms.first().unwrap().1.typ;
let new_match = typed_ir::Match {
expr: self.infer_expr_types(&match_.expr),
arms: new_arms,
types::InferenceResult::First(typ) => typ,
x => todo!("{x:?}"),
};
let mut new_expr = expr.clone().with_type(res_type);
new_expr.kind = typed_ir::ExprKind::Match(Box::new(new_match));
new_expr
let new_expr = typed_ir::Expr {
kind: typed_ir::ExprKind::Call(typed_ir::Call {
genargs,
..call.clone()
}),
typ: ret_type,
..expr.clone()
};
new_exprs.insert(*expr_id, new_expr);
}
ExprKind::Match(match_) => {
// TODO: hacky hacky hacky
let res_type = body.exprs.get(&match_.arms.first().unwrap().1).unwrap().typ;
let new_expr = expr.clone().with_type(res_type);
new_exprs.insert(*expr_id, new_expr);
}
}
}
/*
for (expr_id, expr) in &new_exprs {
if !self.types.is_fully_typed(expr.typ) {
return Err(CompileError::new(CompileErrorKind::TodoError("fail".to_owned())))
}
}
*/
let new_body = typed_ir::Body {
exprs: new_exprs,
..body.clone()
};
Ok(new_body)
}
}

View File

@ -1,11 +1,12 @@
use super::callable::CallableId;
use super::types::{ElabData, Type};
use std::collections::HashMap;
use std::fmt::Debug;
/// ID of a definition (e.g. variable, block, function)
#[derive(Clone, Copy)]
pub struct DefId(pub u32);
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub struct ExprId(pub u32);
// more compact Debug impl
@ -31,14 +32,14 @@ pub struct Expr {
#[derive(Debug, Clone)]
pub struct Call {
pub called: CallableId,
pub args: Vec<Expr>,
pub args: Vec<ExprId>,
pub genargs: Vec<Type>,
}
#[derive(Debug, Clone)]
pub struct Match {
pub expr: Expr,
pub arms: Vec<(Expr, Expr)>,
pub expr: ExprId,
pub arms: Vec<(ExprId, ExprId)>,
}
#[derive(Debug, Clone)]
@ -60,7 +61,8 @@ pub struct Signal {
pub struct Body {
pub signature: CallableId,
pub signals: Vec<Signal>,
pub expr: Expr,
pub exprs: HashMap<ExprId, Expr>,
pub expr: ExprId,
}
impl Expr {

View File

@ -5,7 +5,7 @@ use std::fmt::Debug;
/// easier
pub type Type = InternedType;
#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct InternedType(usize);
impl Debug for InternedType {
@ -14,12 +14,12 @@ impl Debug for InternedType {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeStruct {
kind: TypeKind,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum TypeKind {
/// Elaboration-time types
ElabType(ElabKind),
@ -34,23 +34,23 @@ enum TypeKind {
TypeVar(u32, u32),
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Adt {
Struct(Struct),
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Struct {
members: Vec<Type>,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ElabData {
pub typ: Type,
value: ElabValue,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum ElabValue {
/// the value is not given and has to be inferred
Infer,
@ -58,14 +58,14 @@ enum ElabValue {
Concrete(ElabValueData),
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum ElabValueData {
U32(u32),
Bytes(Vec<u8>),
}
/// Types that are only valid during Elaboration
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum ElabKind {
/// general, unsized number type
Num,
@ -85,6 +85,8 @@ pub enum InferenceResult {
Second(Type),
/// A typevar was inferred
TypeVar(u32, u32, Type),
/// The types are equivalent
Equivalent,
/// The types were incompatible
Incompatible,
/// Neither of the types were complete
@ -244,7 +246,7 @@ impl TypingContext {
fn is_fully_typed_kind(&self, kind: &TypeKind) -> bool {
match kind {
TypeKind::ElabType(_) => todo!(),
TypeKind::ElabType(_) => true,
TypeKind::Logic(data) => {
matches!(data.value, ElabValue::Concrete(_))
}

View File

@ -15,8 +15,8 @@ use crate::parser::{
/// a block that is a single expression
#[derive(Debug, Clone)]
pub struct ExpressionBlock<'a> {
assignments: Vec<(Token<'a>, Expression<'a>)>,
value: Expression<'a>,
pub assignments: Vec<(Token<'a>, Expression<'a>)>,
pub value: Expression<'a>,
}
/// an expression that contains a block

View File

@ -8,11 +8,11 @@ use super::{
use nom::{
branch::alt,
bytes::complete::{is_not, tag, take_until},
character::complete::{anychar, digit1, line_ending},
combinator::{consumed, map, recognize},
character::complete::{alpha1, anychar, digit1, line_ending},
combinator::{consumed, map, not, peek, recognize},
error::ParseError,
multi::many0,
sequence::tuple,
sequence::{terminated, tuple},
InputTake,
};
use std::fmt;
@ -227,6 +227,7 @@ fn lex_punctuation(input: Span) -> IResult<Span, Token> {
fn lex_keywords(input: Span) -> IResult<Span, Token> {
map(
terminated(
consumed(alt((
map(tag("module"), |_| TokenKind::Module),
map(tag("assign"), |_| TokenKind::Assign),
@ -237,6 +238,8 @@ fn lex_keywords(input: Span) -> IResult<Span, Token> {
map(tag("let"), |_| TokenKind::Let),
map(tag("struct"), |_| TokenKind::Struct),
))),
peek(not(alt((alpha1, tag("_"))))),
),
|(span, kind)| Token::new(span, kind),
)(input)
}