Skip to content

Instantly share code, notes, and snippets.

@djg
Created January 18, 2019 02:34
Show Gist options
  • Save djg/bd1e5756a256554f4cd1f2b3a672304e to your computer and use it in GitHub Desktop.
Save djg/bd1e5756a256554f4cd1f2b3a672304e to your computer and use it in GitHub Desktop.
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
fold::Fold,
parse::{Parse, ParseStream, Result},
parse_quote,
punctuated::Punctuated,
Expr, ExprLit, FieldValue, Ident, ItemFn, LitStr, Local, Member, Pat, Stmt, Token,
};
/// Runs the `signal!` rewriting pass over a function item.
///
/// `tokens` is parsed as a `syn::ItemFn`. Every `signal!(...)` invocation inside it is
/// replaced by a `Value::signal(...)` call (see `RewriteSignals`), and the rewritten
/// function is returned as tokens.
///
/// If `tokens` is not a function item, the parse error is returned as a
/// `compile_error!` invocation, so the user gets a spanned diagnostic rather than a
/// proc-macro panic.
pub fn gen(tokens: TokenStream) -> TokenStream {
    let ast = match syn::parse2::<ItemFn>(tokens) {
        Ok(ast) => ast,
        Err(err) => return err.to_compile_error(),
    };
    let ast = RewriteSignals::default().fold_item_fn(ast);
    quote!(#ast)
}
pub enum SignalAttr {
Expr(syn::Expr),
NamedExpr {
name: syn::Ident,
colon: syn::token::Colon,
expr: Expr,
},
}
impl SignalAttr {
    /// Checks that `path` is one of the keywords a named `signal!` argument may use.
    ///
    /// Returns the identifier unchanged when it is accepted. Otherwise returns a
    /// `syn::Error` spanned at the identifier that lists the accepted keywords.
    fn is_keyword(path: Ident) -> Result<Ident> {
        const KEYWORDS: [&str; 6] = ["shape", "name", "reset", "reset_less", "min", "max"];
        if KEYWORDS.iter().any(|kw| path == *kw) {
            return Ok(path);
        }
        Err(syn::Error::new(
            path.span(),
            "expected `shape`, `name`, `reset`, `reset_less`, `min`, or `max`.",
        ))
    }
}
impl Parse for SignalAttr {
    /// Parses a single `signal!` argument.
    ///
    /// An identifier immediately followed by `:` starts a `keyword: expr` pair, and the
    /// keyword must pass `is_keyword`. Anything else is parsed as a bare expression.
    fn parse(input: ParseStream) -> Result<Self> {
        let is_named = input.peek(Ident) && input.peek2(Token![:]);
        if !is_named {
            return input.parse().map(SignalAttr::Expr);
        }
        let name = Self::is_keyword(input.parse()?)?;
        let colon = input.parse()?;
        let expr = input.parse()?;
        Ok(SignalAttr::NamedExpr { name, colon, expr })
    }
}
/// Returns `true` when `e` is a macro invocation whose path is the identifier `signal`
/// (i.e. `signal!(...)`). Every other kind of expression yields `false`.
pub fn is_signal(e: &Expr) -> bool {
    if let Expr::Macro(invocation) = e {
        invocation.mac.path.is_ident("signal")
    } else {
        false
    }
}
pub fn validate_signal_attrs(attrs: &Punctuated<SignalAttr, Token![,]>) -> Result<()> {
use syn::spanned::Spanned;
for attr in attrs.iter().skip(1) {
match attr {
SignalAttr::Expr(e) => Err(syn::Error::new(e.span(), "expected named expression"))?,
_ => {}
}
}
Ok(())
}
/// AST folder that replaces `signal!(...)` invocations with `Value::signal(...)` calls.
///
/// `name` temporarily holds the identifier of the `let` binding or struct field being
/// rewritten, so the generated signal can default to that name. It starts out empty.
#[derive(Default)]
pub struct RewriteSignals {
    name: Option<Ident>,
}
impl RewriteSignals {
fn rewrite_field(&mut self, fv: FieldValue) -> FieldValue {
let FieldValue {
attrs,
member,
expr,
..
} = fv;
assert!(self.name.is_none());
self.name = match member {
Member::Named(ref ident) => Some(ident.clone()),
_ => unreachable!(),
};
let expr = self.fold_expr(expr);
parse_quote! {
#(#attrs),* #member : #expr
}
}
fn rewrite_let(&mut self, local: Local) -> Stmt {
let Local { pats, ty, init, .. } = local;
let pat = &pats[0];
let ty = ty.map(|(colon_token, ty)| quote!(#colon_token #ty));
assert!(self.name.is_none());
self.name = match *pat {
Pat::Ident(ref p) => Some(p.ident.clone()),
_ => unreachable!(),
};
let init = self.fold_expr(*init.unwrap().1);
parse_quote! {
let #pat #ty = #init;
}
}
fn parse_signal(expr: Expr) -> SignalParams {
match expr {
Expr::Macro(syn::ExprMacro { mac, .. }) => {
parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes")
}
_ => unreachable!(),
}
}
fn gen_signal(params: &SignalParams) -> Expr {
assert!(params.name.is_some());
let name = &params.name;
let reset = match params.reset {
Some(ref reset) => quote!(#reset),
None => quote!(0),
};
let reset_less = match params.reset_less {
Some(ref reset_less) => quote!(#reset_less),
None => quote!(false),
};
let min_max = match (&params.shape, &params.min, &params.max) {
(None, None, None) => quote!(0, 1),
(Some((width, _)), _, _) => quote!(0, (1 << #width) - 1),
(_, Some(min), None) => quote!(#min as i64, 1),
(_, None, Some(max)) => quote!(0, #max as i64),
(_, Some(min), Some(max)) => quote!(#min as i64, #max as i64),
};
parse_quote! {
Value::signal(#name, #min_max, #reset, #reset_less)
}
}
fn is_named_member(member: &Member) -> bool {
match member {
Member::Named(_) => true,
_ => false,
}
}
fn gen_string_lit(ident: &str) -> Expr {
ExprLit {
attrs: vec![],
lit: LitStr::new(ident, Span::call_site()).into(),
}
.into()
}
}
impl Fold for RewriteSignals {
    /// Replaces a `signal!(...)` expression with a `Value::signal(...)` call. Any other
    /// expression is folded recursively, so nested `signal!`s are rewritten too.
    ///
    /// A pending binding or field name in `self.name` is always consumed here. It is used
    /// as the signal name only when the macro did not supply `name:` itself. Signals with
    /// neither get the name `"$signal"`.
    fn fold_expr(&mut self, e: Expr) -> Expr {
        if !is_signal(&e) {
            return syn::fold::fold_expr(self, e);
        }
        let mut params = Self::parse_signal(e);
        // Take the name unconditionally. If the macro names the signal explicitly, the
        // pending name would otherwise leak into the next `rewrite_let`/`rewrite_field`
        // and trip their `assert!(self.name.is_none())`.
        let binding = self.name.take();
        if params.name.is_none() {
            let name = binding.map_or_else(|| "$signal".to_string(), |ident| ident.to_string());
            params.name = Some(Self::gen_string_lit(&name));
        }
        Self::gen_signal(&params)
    }

    /// Routes `field: signal!(...)` in struct literals through `rewrite_field`, so the
    /// field name becomes the default signal name. All other fields are folded normally.
    fn fold_field_value(&mut self, f: FieldValue) -> FieldValue {
        if Self::is_named_member(&f.member) && is_signal(&f.expr) {
            return self.rewrite_field(f);
        }
        syn::fold::fold_field_value(self, f)
    }

    /// Routes `let x = signal!(...);` through `rewrite_let`, so the binding name becomes
    /// the default signal name. Every other statement is folded normally.
    fn fold_stmt(&mut self, s: Stmt) -> Stmt {
        match s {
            Stmt::Local(local) => {
                if local.init.as_ref().map_or(false, |(_, init)| is_signal(init)) {
                    return self.rewrite_let(local);
                }
                Stmt::Local(syn::fold::fold_local(self, local))
            }
            _ => syn::fold::fold_stmt(self, s),
        }
    }
}
/// Arguments collected from one `signal!(...)` invocation by `parse_signal_attrs`.
///
/// Every field is optional. `RewriteSignals::gen_signal` substitutes defaults for the
/// fields left unset.
#[derive(Default)]
pub struct SignalParams {
    /// Width expression paired with a second flag expression. `parse_signal_attrs` always
    /// sets the flag to `false` (presumably signedness; TODO confirm).
    pub shape: Option<(Expr, Expr)>,
    /// Signal name expression. `RewriteSignals` fills it with a string literal when unset.
    pub name: Option<Expr>,
    /// Expression passed as the `reset` argument of `Value::signal`.
    pub reset: Option<Expr>,
    /// Expression passed as the `reset_less` argument of `Value::signal`.
    pub reset_less: Option<Expr>,
    /// Lower-bound expression (cast to `i64` in the generated call).
    pub min: Option<Expr>,
    /// Upper-bound expression (cast to `i64` in the generated call).
    pub max: Option<Expr>,
}
pub fn parse_signal_attrs(tokens: TokenStream) -> Result<SignalParams> {
let parser = Punctuated::<SignalAttr, Token![,]>::parse_terminated;
let attrs = syn::parse::Parser::parse2(parser, tokens)?;
validate_signal_attrs(&attrs)?;
let mut params = SignalParams::default();
for attr in attrs {
match attr {
SignalAttr::Expr(expr) => {
params.shape = Some((expr, syn::parse_str("false").expect("bug")))
}
SignalAttr::NamedExpr { name, expr, .. } => {
if name == "name" {
params.name = Some(expr)
} else if name == "reset" {
params.reset = Some(expr)
} else if name == "reset_less" {
params.reset_less = Some(expr)
} else if name == "min" {
params.min = Some(expr)
} else if name == "max" {
params.max = Some(expr)
}
}
}
}
Ok(params)
}
#[cfg(test)]
mod test {
    use syn::{self, fold::Fold};

    #[test]
    fn test_is_signal() {
        let t: syn::Expr = syn::parse_str("signal!()").unwrap();
        assert!(super::is_signal(&t));
    }

    #[test]
    fn test_signal_empty_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs = syn::parse::Parser::parse_str(parser, "").expect("Failed to parse");
        assert!(attrs.is_empty());
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_expr_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs = syn::parse::Parser::parse_str(parser, "1").expect("Failed to parse");
        assert_eq!(attrs.len(), 1);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_named_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs =
            syn::parse::Parser::parse_str(parser, "name: \"sig\"").expect("Failed to parse");
        assert_eq!(attrs.len(), 1);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_multi_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs =
            syn::parse::Parser::parse_str(parser, "1, name: \"sig\"").expect("Failed to parse");
        assert_eq!(attrs.len(), 2);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_signal_all_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs = syn::parse::Parser::parse_str(
            parser,
            "1, name: \"sig\", reset: 0, reset_less: true, min: -1, max: 1",
        )
        .expect("Failed to parse");
        assert_eq!(attrs.len(), 6);
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    #[should_panic]
    fn test_invalid_signal_name_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs = syn::parse::Parser::parse_str(parser, "foo: \"sig\"").expect("Failed to parse");
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    #[should_panic]
    fn test_invalid_signal_multi_attrs() {
        let parser =
            syn::punctuated::Punctuated::<super::SignalAttr, syn::Token![,]>::parse_terminated;
        let attrs =
            syn::parse::Parser::parse_str(parser, "foo: \"sig\", 1").expect("Failed to parse");
        super::validate_signal_attrs(&attrs).expect("Invalid attributes");
    }

    #[test]
    fn test_parse_empty_signal() {
        let t: syn::Expr = syn::parse_str("signal!()").unwrap();
        match t {
            syn::Expr::Macro(syn::ExprMacro { mac, .. }) => {
                let params =
                    super::parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes");
                assert_eq!(params.shape, None);
                assert_eq!(params.name, None);
                assert_eq!(params.reset, None);
                assert_eq!(params.reset_less, None);
                assert_eq!(params.min, None);
                assert_eq!(params.max, None);
            }
            _ => panic!("Expected a macro"),
        }
    }

    #[test]
    fn test_parse_simple_signal() {
        use syn::Expr;
        let t: Expr = syn::parse_str("signal!(1)").unwrap();
        match t {
            Expr::Macro(syn::ExprMacro { mac, .. }) => {
                let params =
                    super::parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes");
                assert_eq!(
                    params.shape,
                    Some((
                        syn::parse_str("1").unwrap(),
                        syn::parse_str("false").unwrap()
                    ))
                );
                assert_eq!(params.name, None);
                assert_eq!(params.reset, None);
                assert_eq!(params.reset_less, None);
                assert_eq!(params.min, None);
                assert_eq!(params.max, None);
            }
            _ => panic!("Expected a macro"),
        }
    }

    #[test]
    fn test_parse_complex_signal() {
        use syn::Expr;
        let t: Expr =
            syn::parse_str("signal!(name: \"sig\", reset: 2, reset_less: true, min: 0, max: 10)")
                .unwrap();
        match t {
            Expr::Macro(syn::ExprMacro { mac, .. }) => {
                let params =
                    super::parse_signal_attrs(mac.tts).expect("Failed to parse signal attributes");
                assert_eq!(params.shape, None);
                assert_eq!(params.name, Some(syn::parse_str("\"sig\"").unwrap()));
                assert_eq!(params.reset, Some(syn::parse_str("2").unwrap()));
                assert_eq!(params.reset_less, Some(syn::parse_str("true").unwrap()));
                assert_eq!(params.min, Some(syn::parse_str("0").unwrap()));
                assert_eq!(params.max, Some(syn::parse_str("10").unwrap()));
            }
            _ => panic!("Expected a macro"),
        }
    }

    /// `let o = signal!();` must become a `Value::signal(...)` call named after `o`.
    #[test]
    fn test_let_signal() {
        let t: syn::Stmt = syn::parse_str("let o = signal!();").unwrap();
        let mut cs = super::RewriteSignals::default();
        match cs.fold_stmt(t) {
            syn::Stmt::Local(local) => {
                let (_, init) = local.init.expect("Expected an initializer");
                match *init {
                    syn::Expr::Call(ref call) => {
                        // Arguments: name, min, max, reset, reset_less.
                        assert_eq!(call.args.len(), 5);
                        assert_eq!(
                            call.args[0],
                            syn::parse_str::<syn::Expr>("\"o\"").unwrap()
                        );
                    }
                    _ => panic!("Expected a Value::signal call"),
                }
            }
            _ => panic!("Expected a let statement"),
        }
    }

    #[test]
    fn test_field_signal() {
        let t: syn::ItemFn =
            syn::parse_str("pub fn new() -> Self { Foo { a: signal!() } }").unwrap();
        let mut cs = super::RewriteSignals::default();
        cs.fold_item_fn(t);
    }

    #[test]
    fn test_expr_signal() {
        let t: syn::Stmt = syn::parse_str(
            "let rst = if reset_less { None } else { Some(signal!(name: \"rst\")) };",
        )
        .unwrap();
        let mut cs = super::RewriteSignals::default();
        cs.fold_stmt(t);
    }

    #[test]
    fn test_domain_attribute() {
        use crate::domain::Domain;
        use proc_macro2::Span;
        use syn::Ident;
        let parser = syn::Attribute::parse_outer;
        let t = syn::parse::Parser::parse_str(parser, "#[comb]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(Domain::try_from(&t[0]).expect("Expected Domain"), Domain::Comb);
        let t = syn::parse::Parser::parse_str(parser, "#[sync]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(
            Domain::try_from(&t[0]).expect("Expected Domain"),
            Domain::Sync(Ident::new("sync", Span::call_site()))
        );
        let t = syn::parse::Parser::parse_str(parser, "#[sync(\"por\")]").unwrap();
        assert_eq!(t.len(), 1);
        assert_eq!(
            Domain::try_from(&t[0]).expect("Expected Domain"),
            Domain::Sync(Ident::new("por", Span::call_site()))
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment