Created
January 18, 2019 02:34
-
-
Save djg/bd1e5756a256554f4cd1f2b3a672304e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use proc_macro2::{Span, TokenStream}; | |
use quote::quote; | |
use syn::{ | |
fold::Fold, | |
parse::{Parse, ParseStream, Result}, | |
parse_quote, | |
punctuated::Punctuated, | |
Expr, ExprLit, FieldValue, Ident, ItemFn, LitStr, Local, Member, Pat, Stmt, Token, | |
}; | |
pub fn gen(tokens: TokenStream) -> TokenStream { | |
let ast = syn::parse2::<ItemFn>(tokens).unwrap(); | |
let mut cs = RewriteSignals::default(); | |
let ast = cs.fold_item_fn(ast); | |
// println!("ast = {:#?}", ast); | |
let output = quote! { | |
#ast | |
}; | |
output.into() | |
} | |
pub enum SignalAttr { | |
Expr(syn::Expr), | |
NamedExpr { | |
name: syn::Ident, | |
colon: syn::token::Colon, | |
expr: Expr, | |
}, | |
} | |
impl SignalAttr { | |
fn is_keyword(path: Ident) -> Result<Ident> { | |
if path == "shape" | |
|| path == "name" | |
|| path == "reset" | |
|| path == "reset_less" | |
|| path == "min" | |
|| path == "max" | |
{ | |
Ok(path) | |
} else { | |
Err(syn::Error::new( | |
path.span(), | |
"expected `shape`, `name`, `reset`, `reset_less`, `min`, or `max`.", | |
)) | |
} | |
} | |
} | |
impl Parse for SignalAttr { | |
fn parse(input: ParseStream) -> Result<Self> { | |
if input.peek(Ident) && input.peek2(Token![:]) { | |
let name: Ident = input.parse()?; | |
let name = Self::is_keyword(name)?; | |
let colon: Token![:] = input.parse()?; | |
let expr: Expr = input.parse()?; | |
Ok(SignalAttr::NamedExpr { name, colon, expr }) | |
} else { | |
let e: Expr = input.parse()?; | |
Ok(SignalAttr::Expr(e)) | |
} | |
} | |
} | |
pub fn is_signal(e: &Expr) -> bool { | |
match e { | |
Expr::Macro(m) => m.mac.path.is_ident("signal"), | |
_ => false, | |
} | |
} | |
pub fn validate_signal_attrs(attrs: &Punctuated<SignalAttr, Token![,]>) -> Result<()> { | |
use syn::spanned::Spanned; | |
for attr in attrs.iter().skip(1) { | |
match attr { | |
SignalAttr::Expr(e) => Err(syn::Error::new(e.span(), "expected named expression"))?, | |
_ => {} | |
} | |
} | |
Ok(()) | |
} | |
pub struct RewriteSignals { | |
name: Option<Ident>, | |
} | |
impl Default for RewriteSignals { | |
fn default() -> Self { | |
RewriteSignals { name: None } | |
} | |
} | |
impl RewriteSignals { | |
fn rewrite_field(&mut self, fv: FieldValue) -> FieldValue { | |
let FieldValue { | |
attrs, | |
member, | |
expr, | |
.. | |
} = fv; | |
assert!(self.name.is_none()); | |
self.name = match member { | |
Member::Named(ref ident) => Some(ident.clone()), | |
_ => unreachable!(), | |
}; | |
let expr = self.fold_expr(expr); | |
parse_quote! { | |
#(#attrs),* #member : #expr | |
} | |
} | |
fn rewrite_let(&mut self, local: Local) -> Stmt { | |
let Local { pats, ty, init, .. } = local; | |
let pat = &pats[0]; | |
let ty = ty.map(|(colon_token, ty)| quote!(#colon_token #ty)); | |
assert!(self.name.is_none()); | |
self.name = match *pat { | |
Pat::Ident(ref p) => Some(p.ident.clone()), | |
_ => unreachable!(), | |
}; | |
let init = self.fold_expr(*init.unwrap().1); | |
parse_quote! { | |
let #pat #ty = #init; | |
} | |
} | |
fn parse_signal(expr: Expr) -> SignalParams { | |
match expr { | |
Expr::Macro(syn::ExprMacro { mac, .. }) => { | |
parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes") | |
} | |
_ => unreachable!(), | |
} | |
} | |
fn gen_signal(params: &SignalParams) -> Expr { | |
assert!(params.name.is_some()); | |
let name = ¶ms.name; | |
let reset = match params.reset { | |
Some(ref reset) => quote!(#reset), | |
None => quote!(0), | |
}; | |
let reset_less = match params.reset_less { | |
Some(ref reset_less) => quote!(#reset_less), | |
None => quote!(false), | |
}; | |
let min_max = match (¶ms.shape, ¶ms.min, ¶ms.max) { | |
(None, None, None) => quote!(0, 1), | |
(Some((width, _)), _, _) => quote!(0, (1 << #width) - 1), | |
(_, Some(min), None) => quote!(#min as i64, 1), | |
(_, None, Some(max)) => quote!(0, #max as i64), | |
(_, Some(min), Some(max)) => quote!(#min as i64, #max as i64), | |
}; | |
parse_quote! { | |
Value::signal(#name, #min_max, #reset, #reset_less) | |
} | |
} | |
fn is_named_member(member: &Member) -> bool { | |
match member { | |
Member::Named(_) => true, | |
_ => false, | |
} | |
} | |
fn gen_string_lit(ident: &str) -> Expr { | |
ExprLit { | |
attrs: vec![], | |
lit: LitStr::new(ident, Span::call_site()).into(), | |
} | |
.into() | |
} | |
} | |
impl Fold for RewriteSignals { | |
fn fold_expr(&mut self, e: Expr) -> Expr { | |
if !is_signal(&e) { | |
return syn::fold::fold_expr(self, e); | |
} | |
let mut params = Self::parse_signal(e); | |
// set name from variable name if signal! didn't specify one. | |
if params.name.is_none() { | |
params.name = match self.name.take() { | |
Some(name) => Some(Self::gen_string_lit(&name.to_string())), | |
None => Some(Self::gen_string_lit("$signal")), | |
}; | |
} | |
Self::gen_signal(¶ms) | |
} | |
fn fold_field_value(&mut self, f: FieldValue) -> FieldValue { | |
if Self::is_named_member(&f.member) && is_signal(&f.expr) { | |
return self.rewrite_field(f); | |
} | |
syn::fold::fold_field_value(self, f) | |
} | |
fn fold_stmt(&mut self, s: Stmt) -> Stmt { | |
match s { | |
Stmt::Local(s) => { | |
if s.init.as_ref().map_or(false, |(_, init)| is_signal(init)) { | |
return self.rewrite_let(s); | |
} | |
Stmt::Local(syn::fold::fold_local(self, s)) | |
} | |
_ => syn::fold::fold_stmt(self, s), | |
} | |
} | |
} | |
#[derive(Default)] | |
pub struct SignalParams { | |
pub shape: Option<(Expr, Expr)>, | |
pub name: Option<Expr>, | |
pub reset: Option<Expr>, | |
pub reset_less: Option<Expr>, | |
pub min: Option<Expr>, | |
pub max: Option<Expr>, | |
} | |
pub fn parse_signal_attrs(tokens: TokenStream) -> Result<SignalParams> { | |
let parser = Punctuated::<SignalAttr, Token![,]>::parse_terminated; | |
let attrs = syn::parse::Parser::parse2(parser, tokens)?; | |
validate_signal_attrs(&attrs)?; | |
let mut params = SignalParams::default(); | |
for attr in attrs { | |
match attr { | |
SignalAttr::Expr(expr) => { | |
params.shape = Some((expr, syn::parse_str("false").expect("bug"))) | |
} | |
SignalAttr::NamedExpr { name, expr, .. } => { | |
if name == "name" { | |
params.name = Some(expr) | |
} else if name == "reset" { | |
params.reset = Some(expr) | |
} else if name == "reset_less" { | |
params.reset_less = Some(expr) | |
} else if name == "min" { | |
params.min = Some(expr) | |
} else if name == "max" { | |
params.max = Some(expr) | |
} | |
} | |
} | |
} | |
Ok(params) | |
} | |
#[cfg(test)]
mod test {
    use matches::matches;
    use syn::{self, fold::Fold};

    /// The comma-separated `signal!` argument list type under test.
    type SignalAttrs = syn::punctuated::Punctuated<super::SignalAttr, syn::Token![,]>;

    /// Parses `src` as a `signal!` argument list, panicking with "Failed to parse" on error.
    fn parse_attrs(src: &str) -> SignalAttrs {
        syn::parse::Parser::parse_str(SignalAttrs::parse_terminated, src)
            .expect("Failed to parse")
    }

    /// Parses `src` as a `signal!(...)` expression and returns its parsed arguments.
    fn parse_macro_params(src: &str) -> super::SignalParams {
        let parsed: syn::Expr = syn::parse_str(src).unwrap();
        match parsed {
            syn::Expr::Macro(syn::ExprMacro { mac, .. }) => {
                super::parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes")
            }
            _ => panic!("Expected a macro"),
        }
    }

    /// Parses `src` as an expression.
    fn expr(src: &str) -> syn::Expr {
        syn::parse_str(src).unwrap()
    }

    #[test]
    fn test_is_signal() {
        assert!(super::is_signal(&expr("signal!()")));
    }

    #[test]
    fn test_signal_empty_attrs() {
        let attrs = parse_attrs("");
        assert!(attrs.is_empty());
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_expr_attrs() {
        let attrs = parse_attrs("1");
        assert_eq!(attrs.len(), 1);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_named_attrs() {
        let attrs = parse_attrs("name: \"sig\"");
        assert_eq!(attrs.len(), 1);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_multi_attrs() {
        let attrs = parse_attrs("1, name: \"sig\"");
        assert_eq!(attrs.len(), 2);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_all_attrs() {
        let attrs = parse_attrs("1, name: \"sig\", reset: 0, reset_less: true, min: -1, max: 1");
        assert_eq!(attrs.len(), 6);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    #[should_panic]
    fn test_invalid_signal_name_attrs() {
        let attrs = parse_attrs("foo: \"sig\"");
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    #[should_panic]
    fn test_invalid_signal_multi_attrs() {
        let attrs = parse_attrs("foo: \"sig\", 1");
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_parse_empty_signal() {
        let params = parse_macro_params("signal!()");
        assert_eq!(params.shape, None);
        assert_eq!(params.name, None);
        assert_eq!(params.reset, None);
        assert_eq!(params.reset_less, None);
        assert_eq!(params.min, None);
        assert_eq!(params.max, None);
    }

    #[test]
    fn test_parse_simple_signal() {
        let params = parse_macro_params("signal!(1)");
        assert_eq!(params.shape, Some((expr("1"), expr("false"))));
        assert_eq!(params.name, None);
        assert_eq!(params.reset, None);
        assert_eq!(params.reset_less, None);
        assert_eq!(params.min, None);
        assert_eq!(params.max, None);
    }

    #[test]
    fn test_parse_complex_signal() {
        let params = parse_macro_params(
            "signal!(name: \"sig\", reset: 2, reset_less: true, min: 0, max: 10)",
        );
        assert_eq!(params.shape, None);
        assert_eq!(params.name, Some(expr("\"sig\"")));
        assert_eq!(params.reset, Some(expr("2")));
        assert_eq!(params.reset_less, Some(expr("true")));
        assert_eq!(params.min, Some(expr("0")));
        assert_eq!(params.max, Some(expr("10")));
    }

    #[test]
    fn test_let_signal() {
        let stmt: syn::Stmt = syn::parse_str("let o = signal!();").unwrap();
        super::RewriteSignals::default().fold_stmt(stmt);
    }

    #[test]
    fn test_field_signal() {
        let item: syn::ItemFn =
            syn::parse_str("pub fn new() -> Self { Foo { a: signal!() } }").unwrap();
        super::RewriteSignals::default().fold_item_fn(item);
    }

    #[test]
    fn test_expr_signal() {
        let stmt: syn::Stmt = syn::parse_str(
            "let rst = if reset_less { None } else { Some(signal!(name: \"rst\")) };",
        )
        .unwrap();
        super::RewriteSignals::default().fold_stmt(stmt);
    }

    #[test]
    fn test_domain_attribute() {
        use crate::domain::Domain;
        use proc_macro2::Span;
        use syn::Ident;
        let parser = syn::Attribute::parse_outer;

        let t = syn::parse::Parser::parse_str(parser, "#[comb]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(Domain::try_from(&t[0]).expect("Expected Domain"), Domain::Comb);

        let t = syn::parse::Parser::parse_str(parser, "#[sync]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(
            Domain::try_from(&t[0]).expect("Expected Domain"),
            Domain::Sync(Ident::new("sync", Span::call_site()))
        );

        let t = syn::parse::Parser::parse_str(parser, "#[sync(\"por\")]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(
            Domain::try_from(&t[0]).expect("Expected Domain"),
            Domain::Sync(Ident::new("por", Span::call_site()))
        );
    }

    /*
    #[test]
    fn test_stmt_attr() {
        let t: syn::Stmt = syn::parse_str("#[sync] self.v = self.v + 1;").unwrap();
        println!("t = {:#?}", t);
        match t {
            syn::Stmt::Semi(ref e, ..) => {
            },
            _ => panic!("Expected Stmt::Semi")
        }
    }
    */
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment