Skip to content

Instantly share code, notes, and snippets.

@mohe2015
Last active July 29, 2023 20:45
Show Gist options
  • Save mohe2015/84725cc7b60040afe2d44fe03a9d5eef to your computer and use it in GitHub Desktop.
Save mohe2015/84725cc7b60040afe2d44fe03a9d5eef to your computer and use it in GitHub Desktop.
doctests proc macro
[package]
name = "doctests_proc_macro"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
# Build as a procedural-macro crate so `#[proc_macro_attribute] pub fn doctests` can be exported.
proc-macro = true
[dependencies]
# "full" enables the item-level syntax tree types the macro parses and builds
# (ItemMod, ItemFn, File, Signature, ...). "extra-traits" adds Debug/Eq/Hash impls
# for syn types. NOTE(review): it is not obviously required by the macro code; confirm before dropping.
syn = { version = "2", features = [ "full", "extra-traits" ] }
# Quasi-quoting of the output token stream (`quote!`).
quote = "1"
# Span/Ident/TokenStream types usable outside the compiler-provided `proc_macro` crate.
proc-macro2 = "1"
# Pretty-prints the copied test bodies before they are written into doc attributes.
prettyplease = "0.2"
# Provides `collect_tuple`, `partition_result` and `Either` used by the macro.
itertools = "0.11"
extern crate self as my_crate;
use doctests_proc_macro::doctests;
pub use outer::inner::pauli_x;
// Example use of `#[doctests]`. The statements of every function in `tests`
// are copied into the rustdoc of `inner::pauli_x`, directly after its
// `# Examples` doc line. The macro expects exactly these two inline modules, in
// this order. The documented function must be the last item of the second one.
#[doctests]
mod outer {
// Functions whose bodies become rustdoc examples. Only the statements inside
// each body are copied. Plain `//` comments like this one are not part of the
// token stream, so they never appear in the generated docs.
#[cfg(test)]
mod tests {
#[test]
pub fn test_pauli_x() {
// NOTE(review): paths are spelled `qukit::...`, presumably because this body is
// pasted verbatim into a doctest, which compiles as an external crate. Confirm.
use matrixcompare::assert_matrix_eq;
use nalgebra::Vector4;
use qukit::complex::r;
use qukit::gates::controlled::controlled;
use qukit::gates::pauli_x::pauli_x;
use qukit::state_vector::StateVector;
let mut state_vector = StateVector::new(2);
// A fresh 2-qubit state vector is expected to equal the first basis vector.
let expected = Vector4::new(r(1.), r(0.), r(0.), r(0.));
assert_matrix_eq!(expected, *state_vector, comp = exact);
state_vector.apply_gate(&controlled(&pauli_x()));
// NOTE(review): the expected vector is identical to the one before the gate.
// Presumably the control qubit is 0, so X is never applied, and this check may
// not actually exercise `pauli_x`. Confirm the intended test case.
let expected = Vector4::new(r(1.), r(0.), r(0.), r(0.));
assert_matrix_eq!(expected, *state_vector, comp = exact);
}
}
// Module holding the documented function. The macro only edits its LAST item.
// The earlier items (here only `use` declarations) are left untouched.
pub mod inner {
use nalgebra::Matrix2;
use nalgebra_sparse::CsrMatrix;
use num_complex::Complex;
use qukit::complex::r;
use crate::value::Value;
// The `# Examples` doc line below is where `#[doctests]` inserts the code blocks.
/// quantum equivalent of NOT Gate
///
/// # Examples
pub fn pauli_x<V: Value>() -> CsrMatrix<Complex<V>> {
// Row-major 2x2 matrix [[0, 1], [1, 0]], converted to a sparse CSR matrix.
CsrMatrix::from(&Matrix2::new(
r(V::zero()),
r(V::one()),
r(V::one()),
r(V::zero()),
))
}
}
}
use core::panic;
use itertools::{Either, Itertools};
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{
spanned::Spanned, Attribute, Expr, ExprLit, File, Item, ItemFn, ItemMod, Lit, LitStr,
MetaNameValue, Path, Signature,
};
#[proc_macro_attribute]
pub fn doctests(
_attr: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input = proc_macro2::TokenStream::from(input);
let mut outer_mod_item = match syn::parse2::<ItemMod>(input.clone()) {
Ok(syntax_tree) => syntax_tree,
Err(err) => {
let err = err.to_compile_error();
return proc_macro::TokenStream::from(quote! {
#input
#err
});
}
};
let items = &mut outer_mod_item.content.as_mut().unwrap().1;
// TODO FIXME unwraps
let (doctests_mod, function_mod): (&mut Item, &mut Item) =
items.iter_mut().collect_tuple().unwrap();
let Item::Mod(doctests_mod) = doctests_mod else {
let err = syn::Error::new(doctests_mod.span(), "expected mod").to_compile_error();
return proc_macro::TokenStream::from(quote! {
#input
#err
});
};
let Item::Mod(function_mod) = function_mod else {
let err: proc_macro2::TokenStream =
syn::Error::new(function_mod.span(), "expected mod").to_compile_error();
return proc_macro::TokenStream::from(quote! {
#input
#err
});
};
let doctest_items: &mut Vec<Item> = doctests_mod.content.as_mut().unwrap().1.as_mut();
let function_items: &mut Vec<Item> = function_mod.content.as_mut().unwrap().1.as_mut();
// TODO FIXME check the things before are only imports
let Some((function_item, _function_imports)) = function_items.split_last_mut() else {
panic!()
};
let (doctests, errs): (Vec<Attribute>, proc_macro2::TokenStream) = doctest_items
.iter()
.flat_map(|item| match item {
syn::Item::Fn(item_fn) => {
// inspired by https://github.com/immunant/c2rust/blob/master/c2rust-ast-printer/src/pprust.rs
let code = item_fn.block.stmts.clone();
let ident = syn::Ident::new("main", proc_macro2::Span::call_site());
let generics = syn::Generics {
lt_token: None,
params: Default::default(),
gt_token: None,
where_clause: None,
};
let fakefile = File {
shebang: None,
attrs: Vec::new(),
items: vec![Item::Fn(ItemFn {
attrs: Vec::new(),
vis: syn::Visibility::Inherited,
sig: Signature {
constness: None,
asyncness: None,
unsafety: None,
abi: None,
fn_token: Default::default(),
ident,
generics,
paren_token: Default::default(),
inputs: Default::default(),
variadic: None,
output: syn::ReturnType::Default,
},
block: Box::new(syn::Block {
brace_token: Default::default(),
stmts: code,
}),
})],
};
let unparse = prettyplease::unparse(&fakefile);
let unparse: String = unparse
.trim()
.trim_start_matches("fn main() {")
.trim_end_matches("}")
.trim()
.split_inclusive("\n")
.map(|line| line.trim_start_matches(" "))
.collect();
let attr1 = Attribute {
bracket_token: Default::default(),
style: syn::AttrStyle::Outer,
pound_token: Default::default(),
meta: syn::Meta::NameValue(MetaNameValue {
path: Path::from(Ident::new("doc", Span::call_site())),
eq_token: Default::default(),
value: Expr::Lit(ExprLit {
attrs: vec![],
lit: Lit::Str(LitStr::new("```", Span::call_site())),
}),
}),
};
let attr2 = Attribute {
bracket_token: Default::default(),
style: syn::AttrStyle::Outer,
pound_token: Default::default(),
meta: syn::Meta::NameValue(MetaNameValue {
path: Path::from(Ident::new("doc", Span::call_site())),
eq_token: Default::default(),
value: Expr::Lit(ExprLit {
attrs: vec![],
lit: Lit::Str(LitStr::new(&unparse, Span::call_site())),
}),
}),
};
let attr3 = Attribute {
bracket_token: Default::default(),
style: syn::AttrStyle::Outer,
pound_token: Default::default(),
meta: syn::Meta::NameValue(MetaNameValue {
path: Path::from(Ident::new("doc", Span::call_site())),
eq_token: Default::default(),
value: Expr::Lit(ExprLit {
attrs: vec![],
lit: Lit::Str(LitStr::new("```", Span::call_site())),
}),
}),
};
vec![Ok(attr1), Ok(attr2), Ok(attr3)]
}
_ => {
let err = syn::Error::new(item.span(), "expected function").to_compile_error();
return vec![Err(quote! {
#err
})];
}
})
.partition_result();
let Item::Fn(item_fn) = function_item else {
let err = syn::Error::new(function_item.span(), "expected function").to_compile_error();
return proc_macro::TokenStream::from(quote! {
#input
#err
#errs
});
};
let attrs = item_fn.attrs.clone();
let mut found_examples = false;
let attrs: Vec<_> = attrs
.into_iter()
.flat_map(|attr| match &attr.meta.require_name_value() {
Ok(MetaNameValue {
path,
value:
Expr::Lit(ExprLit {
lit: Lit::Str(comment),
..
}),
..
}) if path.is_ident("doc") && comment.value().contains("# Examples") => {
found_examples = true;
Either::Left(std::iter::once(attr).chain(doctests.iter().cloned()))
}
_ => Either::Right(std::iter::once(attr)),
})
.collect();
if !found_examples {
let err = syn::Error::new(item_fn.span(), "expected a `# Examples` doc comment")
.to_compile_error();
return proc_macro::TokenStream::from(quote! {
#input
#err
#errs
});
}
item_fn.attrs = attrs;
let output = quote! {
#outer_mod_item
#errs
};
proc_macro::TokenStream::from(output)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment