Last active
December 3, 2024 15:55
-
-
Save bonzini/beae973e906f6076f776eac1d812bac5 to your computer and use it in GitHub Desktop.
recursive descent parser in a procedural macro
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// # Definition entry point
///
/// Define a struct with a single field of type `$type`.  Include public
/// constants for each element listed in braces.
///
/// The unnamed element at the end, if present, can be used to enlarge the set of
/// valid bits.  Bits that are valid but not listed are treated normally for
/// the purpose of arithmetic operations, and are printed with their hexadecimal
/// value.
///
/// The struct implements the following traits: [`BitAnd`](std::ops::BitAnd),
/// [`BitOr`](std::ops::BitOr), [`BitXor`](std::ops::BitXor),
/// [`Not`](std::ops::Not), [`Sub`](std::ops::Sub); [`Debug`](std::fmt::Debug),
/// [`Display`](std::fmt::Display), [`Binary`](std::fmt::Binary), [`Octal`](std::fmt::Octal),
/// [`LowerHex`](std::fmt::LowerHex), [`UpperHex`](std::fmt::UpperHex);
/// [`From`]`<type>`/[`Into`]`<type>` where type is the type specified in the
/// definition.
///
/// ## Example
///
/// ```
/// # use bits::bits;
/// bits! {
///     pub struct Colors(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         WHITE = (1 << 0) | (1 << 1) | (1 << 2),
///     }
/// }
/// ```
///
/// ```
/// # use bits::bits;
/// # bits! { pub struct Colors(u8) { BLACK = 0, RED = 1, GREEN = 1 << 1, BLUE = 1 << 2, } }
///
/// bits! {
///     pub struct Colors8(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         WHITE = (1 << 0) | (1 << 1) | (1 << 2),
///
///         _ = 255,
///     }
/// }
///
/// // The previously defined struct ignores bits not explicitly defined.
/// assert_eq!(Colors::from(255), Colors::RED | Colors::GREEN | Colors::BLUE);
///
/// // Adding "_ = 255" makes it retain other bits as well.
/// assert_eq!(Colors8::from(255).val(), 255);
/// ```
///
/// # Evaluation entry point
///
/// Return a constant corresponding to the boolean expression `$expr`.
/// Identifiers in the expression correspond to values defined for the
/// type `$type`.  Supported operators are `!` (unary), `-`, `&`, `^`, `|`.
///
/// ## Examples
///
/// ```
/// # use bits::bits;
/// # bits! { struct Colors(u8) { RED = 1, GREEN = 2, BLUE = 4, WHITE = 7 } }
/// let rgb = bits! { Colors: RED | GREEN | BLUE };
/// assert_eq!(rgb, Colors::WHITE);
/// ```
///
/// ```
/// # use bits::bits;
/// bits! {
///     pub struct Colors(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         // same as "WHITE = 7",
///         WHITE = bits!(Self as u8: RED | GREEN | BLUE),
///     }
/// }
/// ```
#[macro_export]
macro_rules! bits {
    // Definition entry point: generate the struct, its constants and trait impls.
    {
        $(#[$struct_meta:meta])*
        $struct_vis:vis struct $struct_name:ident($type:ty) {
            $($const:ident = $val:expr),+
            $(,_ = $mask:expr)?
            $(,)?
        }
    } => {
        $(#[$struct_meta])*
        #[derive(Clone, Copy, PartialEq, Eq)]
        #[repr(transparent)]
        $struct_vis struct $struct_name($type);

        impl $struct_name {
            $( #[allow(dead_code)] pub const $const: $struct_name = $struct_name($val); )+

            // Union of every listed constant and of the optional `_ = mask` entry.
            const VALID__: $type = $( Self::$const.0 )|+ $(|$mask)?;

            /// Return a value with every valid bit set.
            #[allow(dead_code)]
            pub const fn valid_bits() -> Self {
                Self(Self::VALID__)
            }

            /// Return true if `val` has no bit set outside the valid ones.
            #[allow(dead_code)]
            pub const fn valid(val: $type) -> bool {
                (val & !Self::VALID__) == 0
            }

            /// Return true if at least one bit of `mask` is set in `self`.
            #[allow(dead_code)]
            pub const fn any_set(self, mask: Self) -> bool {
                (self.0 & mask.0) != 0
            }

            /// Return true if every bit of `mask` is set in `self`.
            #[allow(dead_code)]
            pub const fn all_set(self, mask: Self) -> bool {
                (self.0 & mask.0) == mask.0
            }

            /// Return true if no bit of `mask` is set in `self`.
            #[allow(dead_code)]
            pub const fn none_set(self, mask: Self) -> bool {
                (self.0 & mask.0) == 0
            }

            /// Return the underlying integer value.
            #[allow(dead_code)]
            pub const fn val(self) -> $type {
                self.0
            }

            /// Set in `self` the bits that are set in `rhs`.
            #[allow(dead_code)]
            pub const fn set(&mut self, rhs: Self) {
                self.0 |= rhs.0;
            }

            /// Clear in `self` the bits that are set in `rhs`.
            #[allow(dead_code)]
            pub const fn clear(&mut self, rhs: Self) {
                self.0 &= !rhs.0;
            }

            /// Flip in `self` the bits that are set in `rhs`.
            #[allow(dead_code)]
            pub const fn toggle(&mut self, rhs: Self) {
                self.0 ^= rhs.0;
            }

            /// Bits set in both `self` and `rhs`.
            #[allow(dead_code)]
            pub const fn intersection(self, rhs: Self) -> Self {
                $struct_name(self.0 & rhs.0)
            }

            /// Bits set in `self` but not in `rhs`.
            #[allow(dead_code)]
            pub const fn difference(self, rhs: Self) -> Self {
                $struct_name(self.0 & !rhs.0)
            }

            /// Bits set in exactly one of `self` and `rhs`.
            #[allow(dead_code)]
            pub const fn symmetric_difference(self, rhs: Self) -> Self {
                $struct_name(self.0 ^ rhs.0)
            }

            /// Bits set in `self`, in `rhs`, or in both.
            #[allow(dead_code)]
            pub const fn union(self, rhs: Self) -> Self {
                $struct_name(self.0 | rhs.0)
            }

            /// Flip every valid bit of `self`.
            #[allow(dead_code)]
            pub const fn invert(self) -> Self {
                $struct_name(self.0 ^ Self::VALID__)
            }
        }

        impl ::std::fmt::Binary for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // If no width, use the highest valid bit.  `checked_ilog2`
                // avoids the panic that `ilog2` raises when the mask is zero
                // (e.g. a struct whose only constant is 0).
                let width = f.width().unwrap_or_else(|| {
                    Self::VALID__.checked_ilog2().map_or(1, |bit| bit as usize + 1)
                });
                write!(f, "{:0>width$.precision$b}", self.0,
                       width = width,
                       precision = f.precision().unwrap_or(width))
            }
        }

        impl ::std::fmt::LowerHex for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::LowerHex>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::Octal for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::Octal>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::UpperHex for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::UpperHex>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::Debug for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                write!(f, "{}({})", stringify!($struct_name), self)
            }
        }

        impl ::std::fmt::Display for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                use ::std::fmt::Display;

                let mut first = true;
                // Bits that have not been printed by name yet.
                let mut left = self.0;
                // Only single-bit constants are printed by name, in
                // declaration order, separated by '|'.
                $(if Self::$const.0.is_power_of_two() && (self.0 & Self::$const.0) != 0 {
                    if first { first = false } else { Display::fmt(&'|', f)?; }
                    Display::fmt(stringify!($const), f)?;
                    // Masking (rather than subtracting) stays correct even if
                    // two constants name the same bit.
                    left &= !Self::$const.0;
                })+

                if left != 0 {
                    // Remaining bits have no single-bit name: print them in
                    // hexadecimal, even when no named bit preceded them.
                    if !first {
                        Display::fmt(&'|', f)?;
                    }
                    write!(f, "{left:#x}")
                } else if first {
                    Display::fmt(&'0', f)
                } else {
                    Ok(())
                }
            }
        }

        impl ::std::cmp::PartialEq<$type> for $struct_name {
            fn eq(&self, rhs: &$type) -> bool {
                self.0 == *rhs
            }
        }

        impl ::std::ops::BitAnd<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitand(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 & rhs.0)
            }
        }

        impl ::std::ops::BitXor<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitxor(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 ^ rhs.0)
            }
        }

        impl ::std::ops::BitOr<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitor(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 | rhs.0)
            }
        }

        // Subtraction is set difference: bits of `self` that are not in `rhs`.
        impl ::std::ops::Sub<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn sub(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 & !rhs.0)
            }
        }

        // Negation only flips the valid bits.
        impl ::std::ops::Not for &$struct_name {
            type Output = $struct_name;
            fn not(self) -> Self::Output {
                $struct_name(self.0 ^ $struct_name::VALID__)
            }
        }

        impl ::std::ops::BitAnd<$struct_name> for $struct_name {
            type Output = Self;
            fn bitand(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 & rhs.0)
            }
        }

        impl ::std::ops::BitXor<$struct_name> for $struct_name {
            type Output = Self;
            fn bitxor(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 ^ rhs.0)
            }
        }

        impl ::std::ops::BitOr<$struct_name> for $struct_name {
            type Output = Self;
            fn bitor(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 | rhs.0)
            }
        }

        impl ::std::ops::Sub<$struct_name> for $struct_name {
            type Output = Self;
            fn sub(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 & !rhs.0)
            }
        }

        impl ::std::ops::Not for $struct_name {
            type Output = Self;
            fn not(self) -> Self::Output {
                $struct_name(self.0 ^ Self::VALID__)
            }
        }

        impl From<$struct_name> for $type {
            fn from(x: $struct_name) -> $type {
                x.0
            }
        }

        // Conversion from the integer type silently drops invalid bits.
        impl From<$type> for $struct_name {
            fn from(x: $type) -> Self {
                $struct_name(x & Self::VALID__)
            }
        }
    };

    // Evaluation entry point: `bits!(Type: EXPR)` yields a value of `Type`.
    { $type:ty: $expr:expr } => {
        ::macros::bits_const_internal! { $type @ ($expr) }
    };

    // Evaluation entry point: `bits!(Type as int: EXPR)` yields the raw integer.
    { $type:ty as $int_type:ty: $expr:expr } => {
        (::macros::bits_const_internal! { $type @ ($expr) }.0) as $int_type
    };
}
bits! { | |
pub(crate) struct InterruptMask(u32) { | |
OE = 1 << 10, | |
BE = 1 << 9, | |
PE = 1 << 8, | |
FE = 1 << 7, | |
RT = 1 << 6, | |
TX = 1 << 5, | |
RX = 1 << 4, | |
DSR = 1 << 3, | |
DCD = 1 << 2, | |
CTS = 1 << 1, | |
RI = 1 << 0, | |
E = bits!(Self as u32: OE | BE | PE | FE), | |
MS = bits!(Self as u32: RI | DSR | DCD | CTS), | |
} | |
} | |
fn main() { | |
println!("{:#b}", InterruptMask::MS); | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate proc_macro; | |
use proc_macro::{ | |
Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree, TokenTree as TT, | |
}; | |
/// State of the expression parser behind `bits_const_internal!`.
struct BitsConstInternal {
    /// Token tree naming the type whose associated constants the identifiers
    /// in the expression refer to; it is emitted before `::IDENT`.
    typ: TokenTree,
}
/// Wrap `ts` in a parenthesized group.
fn paren(ts: TokenStream) -> TokenTree {
    Group::new(Delimiter::Parenthesis, ts).into()
}
/// Build an identifier token named `s`, spanned at the macro call site.
fn ident(s: &'static str) -> TokenTree {
    Ident::new(s, Span::call_site()).into()
}
/// Build a stand-alone punctuation token for `ch`.
fn punct(ch: char) -> TokenTree {
    Punct::new(ch, Spacing::Alone).into()
}
impl BitsConstInternal { | |
fn parse_primary( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
) -> Result<Option<TokenTree>, String> { | |
let next = match tok { | |
TT::Group(ref g) => { | |
if g.delimiter() != Delimiter::Parenthesis && g.delimiter() != Delimiter::None { | |
return Err("expected parenthesis")?; | |
} | |
let mut stream = g.stream().into_iter(); | |
let Some(first_tok) = stream.next() else { | |
return Err("expected operand, found ')'")?; | |
}; | |
let mut output = TokenStream::new(); | |
// start from the lowest precedence | |
let next = self.parse_or(first_tok, &mut stream, &mut output)?; | |
if let Some(tok) = next { | |
Err(format!("unexpected token {}", tok))?; | |
} | |
out.extend(Some(paren(output))); | |
it.next() | |
} | |
TT::Ident(_) => { | |
let mut output = TokenStream::new(); | |
output.extend([ | |
self.typ.clone(), | |
TT::Punct(Punct::new(':', Spacing::Joint)), | |
TT::Punct(Punct::new(':', Spacing::Joint)), | |
tok, | |
]); | |
out.extend(Some(paren(output))); | |
it.next() | |
} | |
TT::Punct(ref p) => { | |
if p.as_char() != '!' { | |
return Err("expected operand")?; | |
} | |
let Some(rhs_tok) = it.next() else { | |
return Err("expected operand at end of input")?; | |
}; | |
let next = self.parse_primary(rhs_tok, it, out)?; | |
out.extend([punct('.'), ident("invert"), paren(TokenStream::new())]); | |
next | |
} | |
_ => Err("unexpected literal")?, | |
}; | |
Ok(next) | |
} | |
fn parse_binop< | |
F: Fn( | |
&Self, | |
TokenTree, | |
&mut dyn Iterator<Item = TokenTree>, | |
&mut TokenStream, | |
) -> Result<Option<TokenTree>, String>, | |
>( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
ch: char, | |
f: F, | |
method: &'static str, | |
) -> Result<Option<TokenTree>, String> { | |
let mut next = f(self, tok, it, out)?; | |
while next.is_some() { | |
let op = next.as_ref().unwrap(); | |
let TT::Punct(ref p) = op else { break }; | |
if p.as_char() != ch { | |
break; | |
} | |
let Some(rhs_tok) = it.next() else { | |
return Err("expected operand at end of input")?; | |
}; | |
let mut rhs = TokenStream::new(); | |
next = f(self, rhs_tok, it, &mut rhs)?; | |
out.extend([punct('.'), ident(method), paren(rhs)]); | |
} | |
Ok(next) | |
} | |
pub fn parse_sub( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
) -> Result<Option<TokenTree>, String> { | |
self.parse_binop(tok, it, out, '-', Self::parse_primary, "difference") | |
} | |
fn parse_and( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
) -> Result<Option<TokenTree>, String> { | |
self.parse_binop(tok, it, out, '&', Self::parse_sub, "intersection") | |
} | |
fn parse_xor( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
) -> Result<Option<TokenTree>, String> { | |
self.parse_binop(tok, it, out, '^', Self::parse_and, "symmetric_difference") | |
} | |
pub fn parse_or( | |
&self, | |
tok: TokenTree, | |
it: &mut dyn Iterator<Item = TokenTree>, | |
out: &mut TokenStream, | |
) -> Result<Option<TokenTree>, String> { | |
self.parse_binop(tok, it, out, '|', Self::parse_xor, "union") | |
} | |
fn parse(it: &mut dyn Iterator<Item = TokenTree>) -> Result<TokenStream, String> { | |
let mut typ = TokenStream::new(); | |
let next = loop { | |
match it.next() { | |
None => break None, | |
Some(TT::Punct(ref p)) if p.as_char() == '@' => break it.next(), | |
Some(x) => typ.extend(Some(x)), | |
} | |
}; | |
let Some(tok) = next else { | |
Err("expected expression, do not call this macro directly")? | |
}; | |
let TT::Group(ref _group) = tok else { | |
Err("expected parenthesis, do not call this macro directly")? | |
}; | |
let mut out = TokenStream::new(); | |
let state = Self { | |
typ: TT::Group(Group::new(Delimiter::None, typ)), | |
}; | |
let next = state.parse_primary(tok, it, &mut out)?; | |
if let Some(tok) = next { | |
Err(format!("unexpected token {}", tok))?; | |
} | |
Ok(out) | |
} | |
} | |
#[proc_macro] | |
pub fn bits_const_internal(ts: TokenStream) -> TokenStream { | |
let mut it = ts.into_iter(); | |
BitsConstInternal::parse(&mut it).unwrap() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment