From b890c0594d41dbf7bb030c597a09ddd4c38aae19 Mon Sep 17 00:00:00 2001 From: NotAFile Date: Sun, 20 Feb 2022 17:49:28 +0100 Subject: [PATCH] add first steps toward generic type inference --- src/frontend.rs | 73 ++++++++++++++++++++++++++++++++++++++-- src/frontend/callable.rs | 22 ++++++++---- src/frontend/typed_ir.rs | 13 ++++++- src/frontend/types.rs | 34 ++++++++++++++++++- src/main.rs | 8 ++++- 5 files changed, 137 insertions(+), 13 deletions(-) diff --git a/src/frontend.rs b/src/frontend.rs index 07d0ce5..170ca2c 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -133,7 +133,11 @@ impl Context { match &expr.kind { typed_ir::ExprKind::Literal(lit) => Ok(lit.clone()), typed_ir::ExprKind::Path(_) => todo!("evaluate path"), - typed_ir::ExprKind::Call { called, args } => todo!("evaluate call"), + typed_ir::ExprKind::Call { + called, + args, + genargs, + } => todo!("evaluate call"), } } @@ -174,6 +178,7 @@ impl Context { kind: typed_ir::ExprKind::Call { called: self.callables.builtins.bitnot, args: vec![a], + genargs: vec![], }, typ: self.types.primitives.infer, } @@ -185,6 +190,7 @@ impl Context { kind: typed_ir::ExprKind::Call { called: self.callables.builtins.xor, args: vec![a, b], + genargs: vec![], }, typ: self.types.primitives.infer, } @@ -203,11 +209,17 @@ impl Context { expected: called_callable.argcount(), })); } + let genargs_resolved = called_callable + .genargs + .iter() + .map(|genarg| genarg.1) + .collect(); typed_ir::Expr { id, kind: typed_ir::ExprKind::Call { called, args: args_resolved, + genargs: genargs_resolved, }, typ: self.types.primitives.infer, } @@ -270,6 +282,44 @@ impl Context { ))) } + pub fn infer_expr_types(&self, expr: &typed_ir::Expr) -> typed_ir::Expr { + if self.types.is_fully_typed(expr.typ) { + // there is nothing more to infer + return expr.clone(); + } + match &expr.kind { + typed_ir::ExprKind::Literal(_) => todo!(), + typed_ir::ExprKind::Path(_) => todo!(), + typed_ir::ExprKind::Call { + called, + args, + genargs, + } => { + let callee_def = self.callables.get(*called); + if self.types.is_fully_typed(callee_def.ret_type) { + expr.clone().with_type(callee_def.ret_type) + } else { + let args_typed: Vec<_> = + args.iter().map(|ex| self.infer_expr_types(ex)).collect(); + let param_types: Vec<_> = callee_def.args.iter().map(|param| param.1).collect(); + let mut genargs = callee_def.genargs.clone(); + let inferred_args: Vec<_> = param_types + .iter() + .zip(args_typed) + .map(|(param, arg)| {}) + .collect(); + expr.clone().with_type(callee_def.ret_type) + } + } + } + } + + pub fn infer_types(&self, mut block: typed_ir::Block) -> typed_ir::Block { + let new_root = self.infer_expr_types(&block.expr); + block.expr = new_root; + block + } + pub fn pretty_typed_block( &self, w: &mut dyn std::fmt::Write, @@ -292,7 +342,11 @@ impl Context { let expr_pretty = match &expr.kind { typed_ir::ExprKind::Literal(_) => todo!(), typed_ir::ExprKind::Path(path) => format!("sig_{}", path.0), - typed_ir::ExprKind::Call { called, args } => { + typed_ir::ExprKind::Call { + called, + args, + genargs, + } => { let args = args .iter() .map(|arg| { @@ -301,7 +355,20 @@ impl Context { }) .collect::, std::fmt::Error>>()?; let callable = self.callables.get(*called); - format!("{}({})", callable.name(), args.join(", ")) + let genargs = genargs + .iter() + .map(|param| { + let mut type_str = String::new(); + self.types.pretty_type(&mut type_str, *param)?; + Ok(type_str) + }) + .collect::, std::fmt::Error>>()?; + format!( + "{}<{}>({})", + callable.name(), + genargs.join(", "), + args.join(", ") + ) } }; let mut type_pretty = String::new(); diff --git a/src/frontend/callable.rs b/src/frontend/callable.rs index 68be24f..0764153 100644 --- a/src/frontend/callable.rs +++ b/src/frontend/callable.rs @@ -1,4 +1,4 @@ -use super::types::{GenericArg, Type, TypingContext}; +use super::types::{Type, TypingContext}; #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Eq, Ord)] pub struct CallableId(pub usize); @@ -6,7 +6,8 @@ pub struct CallableId(pub usize); pub struct Callable { pub name: String, pub args: Vec<(Option, Type)>, - pub ret_type: Option, + pub genargs: Vec<(Option, Type)>, + pub ret_type: Type, } impl<'ty> Callable { @@ -38,22 +39,29 @@ impl CallableContext { reduce_or: CallableId(2), }; let logic1 = typectx.make_logic_size(1); + let logic_tvar0 = typectx.make_typevar(0, 0); Self { callables: vec![ Callable { name: "builtin::xor".to_string(), - args: vec![], - ret_type: Some(typectx.primitives.logic), + args: vec![ + (Some("a".to_string()), logic_tvar0), + (Some("b".to_string()), logic_tvar0), + ], + genargs: vec![(Some("T".to_string()), typectx.primitives.logic)], + ret_type: logic_tvar0, }, Callable { name: "builtin::bitnot".to_string(), - args: vec![], - ret_type: Some(typectx.primitives.logic), + args: vec![(Some("a".to_string()), logic_tvar0)], + genargs: vec![(Some("T".to_string()), typectx.primitives.logic)], + ret_type: logic_tvar0, }, Callable { name: "builtin::reduce_or".to_string(), args: vec![(Some("a".to_string()), typectx.primitives.logic)], - ret_type: Some(logic1), + genargs: vec![], + ret_type: logic1, }, ], builtins, diff --git a/src/frontend/typed_ir.rs b/src/frontend/typed_ir.rs index 77b0d88..3f7683f 100644 --- a/src/frontend/typed_ir.rs +++ b/src/frontend/typed_ir.rs @@ -32,7 +32,11 @@ pub struct Expr { pub enum ExprKind { Literal(ElabData), Path(DefId), - Call { called: CallableId, args: Vec }, + Call { + called: CallableId, + args: Vec, + genargs: Vec, + }, } #[derive(Debug, Clone)] @@ -47,3 +51,10 @@ pub struct Block { pub signals: Vec, pub expr: Expr, } + +impl Expr { + pub fn with_type(mut self, typ: Type) -> Self { + self.typ = typ; + self + } +} diff --git a/src/frontend/types.rs b/src/frontend/types.rs index 367e031..9a794a2 100644 --- a/src/frontend/types.rs +++ b/src/frontend/types.rs @@ -30,6 +30,9 @@ enum TypeKind { Callable(FnSig), /// A type that was not given and needs to be inferred Infer, + /// A reference to a type variable as DeBruijn index + /// (scope, param) + TypeVar(u32, u32), } #[derive(Debug, Clone)] @@ -131,6 +134,12 @@ impl TypingContext { .unwrap() } + pub fn make_typevar(&mut self, dbi: u32, tvar: u32) -> Type { + self.add(TypeStruct { + kind: TypeKind::TypeVar(dbi, tvar), + }) + } + pub fn parameterize(&mut self, typ: Type, params: &[GenericArg]) -> Option { // TODO: return proper error type here match &self.get(typ).kind { @@ -159,7 +168,29 @@ impl TypingContext { TypeKind::UInt(_) => todo!(), TypeKind::Callable(_sig) => todo!("callable generic params"), // need to know what the type is to parameterize it - TypeKind::Infer => None, + TypeKind::Infer | &TypeKind::TypeVar(_, _) => None, + } + } + + /// return whether the type has no unfilled parameters + pub fn is_fully_typed(&self, typ: Type) -> bool { + match &self.get(typ).kind { + TypeKind::ElabType(_) => todo!(), + TypeKind::Logic(data) => { + if let ElabValue::Concrete(_) = data.value { + true + } else { + false + } + } + TypeKind::UInt(_) => todo!(), + TypeKind::Callable(_) => todo!(), + TypeKind::Infer => false, + TypeKind::TypeVar(dbi, _tvar) => { + // if the DeBruijn index is 0, there is no further information to gain + // from a surrounding scope + *dbi != 0 + } } } @@ -185,6 +216,7 @@ impl TypingContext { TypeKind::Infer => write!(w, "?"), TypeKind::UInt(_) => todo!("print uint"), TypeKind::Callable(_sig) => todo!("print callable"), + TypeKind::TypeVar(_, tvar) => write!(w, "T{}", tvar), } } } diff --git a/src/main.rs b/src/main.rs index 523f106..01cdd85 100644 --- a/src/main.rs +++ b/src/main.rs @@ -67,8 +67,14 @@ fn main() { frontendcontext .pretty_typed_block(&mut pretty_block, &block) .unwrap(); + println!("{}", &pretty_block); + let typed_inferred = frontendcontext.infer_types(block); + let mut pretty_block = String::new(); + frontendcontext + .pretty_typed_block(&mut pretty_block, &typed_inferred) + .unwrap(); + println!("{}", &pretty_block); } - println!("{}", &pretty_block); } /* match lowered {