Skip to content

Instantly share code, notes, and snippets.

@bonzini
Last active December 3, 2024 15:55
Show Gist options
  • Save bonzini/beae973e906f6076f776eac1d812bac5 to your computer and use it in GitHub Desktop.
Save bonzini/beae973e906f6076f776eac1d812bac5 to your computer and use it in GitHub Desktop.
recursive descent parser in a procedural macro
/// # Definition entry point
///
/// Define a struct with a single field of type `$type`. Include public constants
/// for each element listed in braces.
///
/// The unnamed element at the end, if present, can be used to enlarge the set of
/// valid bits. Bits that are valid but not listed are treated normally for
/// the purpose of arithmetic operations, and are printed with their hexadecimal
/// value.
///
/// The struct implements the following traits: [`BitAnd`](std::ops::BitAnd),
/// [`BitOr`](std::ops::BitOr), [`BitXor`](std::ops::BitXor),
/// [`Not`](std::ops::Not), [`Sub`](std::ops::Sub); [`Debug`](std::fmt::Debug),
/// [`Display`](std::fmt::Display), [`Binary`](std::fmt::Binary), [`Octal`](std::fmt::Octal),
/// [`LowerHex`](std::fmt::LowerHex), [`UpperHex`](std::fmt::UpperHex);
/// [`From`]`<type>`/[`Into`]`<type>` and [`PartialEq`]`<type>` where type is the
/// type specified in the definition; plus derived `Clone`, `Copy`, `PartialEq`
/// and `Eq`.
///
/// `Binary` output is zero-padded to the highest valid bit unless an explicit
/// width is given, and honors the `#` flag (`{:#b}` adds a `0b` prefix).
///
/// ## Example
///
/// ```
/// # use bits::bits;
/// bits! {
///     pub struct Colors(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         WHITE = (1 << 0) | (1 << 1) | (1 << 2),
///     }
/// }
/// ```
///
/// ```
/// # use bits::bits;
/// # bits! { pub struct Colors(u8) { BLACK = 0, RED = 1, GREEN = 1 << 1, BLUE = 1 << 2, } }
///
/// bits! {
///     pub struct Colors8(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         WHITE = (1 << 0) | (1 << 1) | (1 << 2),
///
///         _ = 255,
///     }
/// }
///
/// // The previously defined struct ignores bits not explicitly defined.
/// assert_eq!(Colors::from(255), Colors::RED | Colors::GREEN | Colors::BLUE);
///
/// // Adding "_ = 255" makes it retain other bits as well.
/// assert_eq!(Colors8::from(255).val(), 255);
/// ```
/// # Evaluation entry point
///
/// Return a constant corresponding to the boolean expression `$expr`.
/// Identifiers in the expression correspond to values defined for the
/// type `$type`. Supported operators, from tightest to loosest binding, are
/// `!` (unary), `-`, `&`, `^`, `|`; binary operators are left-associative and
/// parentheses can be used for grouping.
///
/// The `$type as $int_type: $expr` form returns the raw value cast to
/// `$int_type`, which lets a constant of a `bits!` struct be defined in terms
/// of its siblings through `Self`.
///
/// ## Examples
///
/// ```
/// # use bits::bits;
/// # bits! { struct Colors(u8) { RED = 1, GREEN = 2, BLUE = 4 } }
/// let rgb = bits! { Colors: RED | GREEN | BLUE };
/// assert_eq!(rgb, 7);
///
/// bits! {
///     pub struct Colors2(u8) {
///         BLACK = 0,
///         RED = 1,
///         GREEN = 1 << 1,
///         BLUE = 1 << 2,
///         // same as "WHITE = 7",
///         WHITE = bits!(Self as u8: RED | GREEN | BLUE),
///     }
/// }
/// assert_eq!(Colors2::WHITE, 7);
/// ```
#[macro_export]
macro_rules! bits {
    {
        $(#[$struct_meta:meta])*
        $struct_vis:vis struct $struct_name:ident($type:ty) {
            $($const:ident = $val:expr),+
            $(,_ = $mask:expr)?
            $(,)?
        }
    } => {
        $(#[$struct_meta])*
        #[derive(Clone, Copy, PartialEq, Eq)]
        #[repr(transparent)]
        $struct_vis struct $struct_name($type);

        impl $struct_name {
            $( #[allow(dead_code)] pub const $const: $struct_name = $struct_name($val); )+

            // Union of every named constant and of the optional `_` mask.
            const VALID__: $type = $( Self::$const.0 )|+ $(|$mask)?;

            #[allow(dead_code)]
            pub const fn valid_bits() -> Self {
                Self(Self::VALID__)
            }

            #[allow(dead_code)]
            pub const fn valid(val: $type) -> bool {
                (val & !Self::VALID__) == 0
            }

            #[allow(dead_code)]
            pub const fn any_set(self, mask: Self) -> bool {
                (self.0 & mask.0) != 0
            }

            #[allow(dead_code)]
            pub const fn all_set(self, mask: Self) -> bool {
                (self.0 & mask.0) == mask.0
            }

            #[allow(dead_code)]
            pub const fn none_set(self, mask: Self) -> bool {
                (self.0 & mask.0) == 0
            }

            #[allow(dead_code)]
            pub const fn val(self) -> $type {
                self.0
            }

            #[allow(dead_code)]
            pub const fn set(&mut self, rhs: Self) {
                self.0 |= rhs.0;
            }

            #[allow(dead_code)]
            pub const fn clear(&mut self, rhs: Self) {
                self.0 &= !rhs.0;
            }

            #[allow(dead_code)]
            pub const fn toggle(&mut self, rhs: Self) {
                self.0 ^= rhs.0;
            }

            #[allow(dead_code)]
            pub const fn intersection(self, rhs: Self) -> Self {
                $struct_name(self.0 & rhs.0)
            }

            #[allow(dead_code)]
            pub const fn difference(self, rhs: Self) -> Self {
                $struct_name(self.0 & !rhs.0)
            }

            #[allow(dead_code)]
            pub const fn symmetric_difference(self, rhs: Self) -> Self {
                $struct_name(self.0 ^ rhs.0)
            }

            #[allow(dead_code)]
            pub const fn union(self, rhs: Self) -> Self {
                $struct_name(self.0 | rhs.0)
            }

            #[allow(dead_code)]
            pub const fn invert(self) -> Self {
                $struct_name(self.0 ^ Self::VALID__)
            }
        }

        impl ::std::fmt::Binary for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // Without an explicit width, print as many digits as the highest
                // valid bit needs so values of the type line up; an all-zero
                // mask (where ilog2 would panic) still gets one digit.
                let digits = Self::VALID__.checked_ilog2().map_or(1, |b| b as usize + 1);
                if f.alternate() {
                    // Sign-aware zero padding puts the zeros after "0b"; as for
                    // plain integers, the width includes the prefix.
                    let width = f.width().unwrap_or(digits + 2);
                    write!(f, "{:#0width$b}", self.0, width = width)
                } else {
                    let width = f.width().unwrap_or(digits);
                    write!(f, "{:0width$b}", self.0, width = width)
                }
            }
        }

        impl ::std::fmt::LowerHex for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::LowerHex>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::Octal for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::Octal>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::UpperHex for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                <$type as ::std::fmt::UpperHex>::fmt(&self.0, f)
            }
        }

        impl ::std::fmt::Debug for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                write!(f, "{}({})", stringify!($struct_name), self)
            }
        }

        impl ::std::fmt::Display for $struct_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                use ::std::fmt::Display;
                let mut first = true;
                // Bits not yet printed under a single-bit constant name.
                let mut left = self.0;
                $(if Self::$const.0.is_power_of_two() && (self & Self::$const).0 != 0 {
                    if first { first = false } else { Display::fmt(&'|', f)?; }
                    Display::fmt(stringify!($const), f)?;
                    // Clear rather than subtract: two constants may name the
                    // same bit, and subtracting it twice would underflow.
                    left &= !Self::$const.0;
                })+
                if left != 0 {
                    // Valid bits without a single-bit name are printed in hex,
                    // even when no named bit was printed before them.
                    if !first { Display::fmt(&'|', f)?; }
                    write!(f, "{left:#x}")
                } else if first {
                    Display::fmt(&'0', f)
                } else {
                    Ok(())
                }
            }
        }

        impl ::std::cmp::PartialEq<$type> for $struct_name {
            fn eq(&self, rhs: &$type) -> bool {
                self.0 == *rhs
            }
        }

        impl ::std::ops::BitAnd<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitand(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 & rhs.0)
            }
        }

        impl ::std::ops::BitXor<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitxor(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 ^ rhs.0)
            }
        }

        impl ::std::ops::BitOr<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn bitor(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 | rhs.0)
            }
        }

        impl ::std::ops::Sub<$struct_name> for &$struct_name {
            type Output = $struct_name;
            fn sub(self, rhs: $struct_name) -> Self::Output {
                $struct_name(self.0 & !rhs.0)
            }
        }

        impl ::std::ops::Not for &$struct_name {
            type Output = $struct_name;
            fn not(self) -> Self::Output {
                $struct_name(self.0 ^ $struct_name::VALID__)
            }
        }

        impl ::std::ops::BitAnd<$struct_name> for $struct_name {
            type Output = Self;
            fn bitand(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 & rhs.0)
            }
        }

        impl ::std::ops::BitXor<$struct_name> for $struct_name {
            type Output = Self;
            fn bitxor(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 ^ rhs.0)
            }
        }

        impl ::std::ops::BitOr<$struct_name> for $struct_name {
            type Output = Self;
            fn bitor(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 | rhs.0)
            }
        }

        impl ::std::ops::Sub<$struct_name> for $struct_name {
            type Output = Self;
            fn sub(self, rhs: Self) -> Self::Output {
                $struct_name(self.0 & !rhs.0)
            }
        }

        impl ::std::ops::Not for $struct_name {
            type Output = Self;
            fn not(self) -> Self::Output {
                $struct_name(self.0 ^ Self::VALID__)
            }
        }

        impl From<$struct_name> for $type {
            fn from(x: $struct_name) -> $type {
                x.0
            }
        }

        impl From<$type> for $struct_name {
            fn from(x: $type) -> Self {
                $struct_name(x & Self::VALID__)
            }
        }
    };

    { $type:ty: $expr:expr } => {
        ::macros::bits_const_internal! { $type @ ($expr) }
    };

    { $type:ty as $int_type:ty: $expr:expr } => {
        (::macros::bits_const_internal! { $type @ ($expr) }.0) as $int_type
    };
}
// Interrupt mask register value: one named constant per bit (bits 0-10) plus
// two composite masks computed with the evaluation form of `bits!`.
// NOTE(review): the names look like PL011 UART interrupt bits — confirm.
// The listing order matters: Display prints set single-bit names in this order.
bits! {
pub(crate) struct InterruptMask(u32) {
OE = 1 << 10,
BE = 1 << 9,
PE = 1 << 8,
FE = 1 << 7,
RT = 1 << 6,
TX = 1 << 5,
RX = 1 << 4,
DSR = 1 << 3,
DCD = 1 << 2,
CTS = 1 << 1,
RI = 1 << 0,
// Composite of OE | BE | PE | FE (presumably the error interrupts — confirm).
E = bits!(Self as u32: OE | BE | PE | FE),
// Composite of RI | DSR | DCD | CTS (presumably the modem-status interrupts — confirm).
MS = bits!(Self as u32: RI | DSR | DCD | CTS),
}
}
/// Print the `MS` composite mask in alternate binary form.
fn main() {
    let mask = InterruptMask::MS;
    println!("{mask:#b}");
}
extern crate proc_macro;
use proc_macro::{
Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree, TokenTree as TT,
};
/// Parser state for `bits_const_internal!`.
///
/// Every identifier in the parsed expression is emitted as a path to an
/// associated item of the type stored here.
struct BitsConstInternal {
    /// The type tokens that preceded `@`, wrapped in one `Delimiter::None`
    /// group so they can be cloned into the output as a single token tree.
    typ: TokenTree,
}
/// Wrap `ts` in parentheses and return the group as a single token tree.
fn paren(ts: TokenStream) -> TokenTree {
    let group = Group::new(Delimiter::Parenthesis, ts);
    TokenTree::from(group)
}
/// Build an identifier token named `s`, spanned at the macro call site.
fn ident(s: &'static str) -> TokenTree {
    let name = Ident::new(s, Span::call_site());
    name.into()
}
/// Build a one-character punctuation token with `Spacing::Alone`, i.e. not
/// joined to whatever token follows it.
fn punct(ch: char) -> TokenTree {
    let token = Punct::new(ch, Spacing::Alone);
    token.into()
}
impl BitsConstInternal {
    /// Parse a single operand starting at `tok` and append its translation to `out`.
    ///
    /// Accepted operands:
    /// - a parenthesized or invisible-delimited group, whose contents are parsed
    ///   as a complete expression and re-emitted inside parentheses;
    /// - an identifier `X`, emitted as `(<Type>::X)` with `Type` taken from `self.typ`;
    /// - `!` followed by an operand, emitted as `<operand>.invert()`.
    ///
    /// Returns the token following the operand, or `None` at end of input.
    ///
    /// # Errors
    /// Returns a message if `tok` cannot start an operand, if a group is empty
    /// or has trailing tokens, or if `!` is the last token.
    fn parse_primary(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
    ) -> Result<Option<TokenTree>, String> {
        match tok {
            TT::Group(ref g) => {
                if !matches!(g.delimiter(), Delimiter::Parenthesis | Delimiter::None) {
                    return Err("expected parenthesis".into());
                }
                let mut stream = g.stream().into_iter();
                let Some(first_tok) = stream.next() else {
                    return Err("expected operand, found ')'".into());
                };
                let mut output = TokenStream::new();
                // start from the lowest precedence
                if let Some(extra) = self.parse_or(first_tok, &mut stream, &mut output)? {
                    return Err(format!("unexpected token {}", extra));
                }
                out.extend(Some(paren(output)));
                Ok(it.next())
            }
            TT::Ident(_) => {
                // Emit `(<Type>::X)`: the qualified-path form is valid for any
                // type, including ones that cannot be followed directly by `::`.
                let path: TokenStream = [
                    punct('<'),
                    self.typ.clone(),
                    punct('>'),
                    TT::Punct(Punct::new(':', Spacing::Joint)),
                    TT::Punct(Punct::new(':', Spacing::Alone)),
                    tok,
                ]
                .into_iter()
                .collect();
                out.extend(Some(paren(path)));
                Ok(it.next())
            }
            TT::Punct(ref p) if p.as_char() == '!' => {
                let Some(rhs_tok) = it.next() else {
                    return Err("expected operand at end of input".into());
                };
                // `!` binds tighter than every binary operator: parse only a
                // primary, then apply `.invert()` to it.
                let next = self.parse_primary(rhs_tok, it, out)?;
                out.extend([punct('.'), ident("invert"), paren(TokenStream::new())]);
                Ok(next)
            }
            TT::Punct(_) => Err("expected operand".into()),
            TT::Literal(_) => Err("unexpected literal".into()),
        }
    }

    /// Parse a left-associative chain `operand (ch operand)*`.
    ///
    /// Operands are parsed by `f`, the parser for the next-tighter precedence
    /// level. The first operand is written straight to `out`. Each later operand
    /// is translated separately and appended as `.method(rhs)`, so
    /// `a ch b ch c` becomes `a.method(b).method(c)`.
    ///
    /// Returns the first token after a complete operand that is not `ch`, or
    /// `None` at end of input.
    ///
    /// # Errors
    /// Propagates errors from `f`, and fails if `ch` is the last token.
    fn parse_binop<F>(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
        ch: char,
        f: F,
        method: &'static str,
    ) -> Result<Option<TokenTree>, String>
    where
        F: Fn(
            &Self,
            TokenTree,
            &mut dyn Iterator<Item = TokenTree>,
            &mut TokenStream,
        ) -> Result<Option<TokenTree>, String>,
    {
        let mut next = f(self, tok, it, out)?;
        while let Some(TT::Punct(p)) = &next {
            if p.as_char() != ch {
                break;
            }
            let Some(rhs_tok) = it.next() else {
                return Err("expected operand at end of input".into());
            };
            let mut rhs = TokenStream::new();
            next = f(self, rhs_tok, it, &mut rhs)?;
            out.extend([punct('.'), ident(method), paren(rhs)]);
        }
        Ok(next)
    }

    /// Parse `-` (set difference), the tightest-binding binary operator.
    /// Operands are primaries; emits `.difference(rhs)`.
    pub fn parse_sub(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
    ) -> Result<Option<TokenTree>, String> {
        self.parse_binop(tok, it, out, '-', Self::parse_primary, "difference")
    }

    /// Parse `&`, whose operands are `-` chains; emits `.intersection(rhs)`.
    fn parse_and(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
    ) -> Result<Option<TokenTree>, String> {
        self.parse_binop(tok, it, out, '&', Self::parse_sub, "intersection")
    }

    /// Parse `^`, whose operands are `&` chains; emits `.symmetric_difference(rhs)`.
    fn parse_xor(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
    ) -> Result<Option<TokenTree>, String> {
        self.parse_binop(tok, it, out, '^', Self::parse_and, "symmetric_difference")
    }

    /// Parse `|`, the loosest-binding operator, whose operands are `^` chains;
    /// emits `.union(rhs)`. This is the entry point for a full expression.
    pub fn parse_or(
        &self,
        tok: TokenTree,
        it: &mut dyn Iterator<Item = TokenTree>,
        out: &mut TokenStream,
    ) -> Result<Option<TokenTree>, String> {
        self.parse_binop(tok, it, out, '|', Self::parse_xor, "union")
    }

    /// Parse the complete macro input `<type tokens> @ ( <expr> )`.
    ///
    /// All tokens before the first `@` form the type. The token after `@` must
    /// be a group, which is parsed as the expression. Returns the generated
    /// expression tokens.
    ///
    /// # Errors
    /// Fails if `@` or the group is missing, if the expression is malformed, or
    /// if any token follows the group.
    fn parse(it: &mut dyn Iterator<Item = TokenTree>) -> Result<TokenStream, String> {
        let mut typ = TokenStream::new();
        let next = loop {
            match it.next() {
                None => break None,
                Some(TT::Punct(ref p)) if p.as_char() == '@' => break it.next(),
                Some(x) => typ.extend(Some(x)),
            }
        };
        let Some(tok) = next else {
            return Err("expected expression, do not call this macro directly".into());
        };
        if !matches!(tok, TT::Group(_)) {
            return Err("expected parenthesis, do not call this macro directly".into());
        }
        let state = Self {
            typ: TT::Group(Group::new(Delimiter::None, typ)),
        };
        let mut out = TokenStream::new();
        if let Some(extra) = state.parse_primary(tok, it, &mut out)? {
            return Err(format!("unexpected token {}", extra));
        }
        Ok(out)
    }
}
#[proc_macro]
pub fn bits_const_internal(ts: TokenStream) -> TokenStream {
let mut it = ts.into_iter();
BitsConstInternal::parse(&mut it).unwrap()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment