diff --git a/crates/storage/codecs/derive/src/arbitrary.rs b/crates/storage/codecs/derive/src/arbitrary.rs index 8176d5d52..4feae63c4 100644 --- a/crates/storage/codecs/derive/src/arbitrary.rs +++ b/crates/storage/codecs/derive/src/arbitrary.rs @@ -1,16 +1,17 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; -use quote::{format_ident, quote}; -use syn::DeriveInput; +use proc_macro2::{Ident, TokenStream as TokenStream2}; +use quote::{quote, ToTokens}; /// If `compact` or `rlp` is passed to `derive_arbitrary`, this function will generate the /// corresponding proptest roundtrip tests. /// /// It accepts an optional integer number for the number of proptest cases. Otherwise, it will set /// it at 1000. -pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream2 { - let type_ident = ast.ident.clone(); - +pub fn maybe_generate_tests( + args: TokenStream, + type_ident: &impl ToTokens, + mod_tests: &Ident, +) -> TokenStream2 { // Same as proptest let mut default_cases = 256; @@ -25,7 +26,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream { let mut buf = vec![]; let len = field.clone().to_compact(&mut buf); - let (decoded, _) = super::#type_ident::from_compact(&buf, len); + let (decoded, _): (super::#type_ident, _) = Compact::from_compact(&buf, len); assert!(field == decoded, "maybe_generate_tests::compact"); } }); @@ -36,7 +37,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream let mut buf = vec![]; let len = field.encode(&mut buf); let mut b = &mut buf.as_slice(); - let decoded = super::#type_ident::decode(b).unwrap(); + let decoded: super::#type_ident = Decodable::decode(b).unwrap(); assert_eq!(field, decoded, "maybe_generate_tests::rlp"); // ensure buffer is fully consumed by decode assert!(b.is_empty(), "buffer was not consumed entirely"); @@ -53,7 +54,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream let mut raw = [0u8; 1024]; rand::thread_rng().fill_bytes(&mut raw); let mut unstructured = arbitrary::Unstructured::new(&raw[..]); - let val = ::arbitrary(&mut unstructured); + let val: Result = arbitrary::Arbitrary::arbitrary(&mut unstructured); if val.is_err() { // this can be flaky sometimes due to not enough data for iterator based types like Vec return @@ -69,7 +70,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream let mut b = Vec::with_capacity(decode_buf.len()); header.encode(&mut b); b.extend_from_slice(decode_buf); - let res = super::#type_ident::decode(&mut b.as_ref()); + let res: Result = Decodable::decode(&mut b.as_ref()); assert!(res.is_err(), "malformed header was decoded"); } }); @@ -80,8 +81,6 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream let mut tests = TokenStream2::default(); if !roundtrips.is_empty() { - let mod_tests = format_ident!("{}Tests", ast.ident); - tests = quote! { #[allow(non_snake_case)] #[cfg(test)] diff --git a/crates/storage/codecs/derive/src/lib.rs b/crates/storage/codecs/derive/src/lib.rs index 71af6619c..edff471f2 100644 --- a/crates/storage/codecs/derive/src/lib.rs +++ b/crates/storage/codecs/derive/src/lib.rs @@ -11,7 +11,11 @@ use proc_macro::{TokenStream, TokenTree}; use quote::{format_ident, quote}; -use syn::{parse_macro_input, DeriveInput}; +use syn::{ + bracketed, + parse::{Parse, ParseStream}, + parse_macro_input, DeriveInput, Result, Token, +}; mod arbitrary; mod compact; @@ -85,7 +89,8 @@ pub fn reth_codec(args: TokenStream, input: TokenStream) -> TokenStream { pub fn derive_arbitrary(args: TokenStream, input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); - let tests = arbitrary::maybe_generate_tests(args, &ast); + let tests = + arbitrary::maybe_generate_tests(args, &ast.ident, &format_ident!("{}Tests", ast.ident)); // Avoid duplicate names let arb_import = format_ident!("{}Arbitrary", ast.ident); @@ -106,10 +111,51 @@ pub fn derive_arbitrary(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn add_arbitrary_tests(args: TokenStream, input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); - let tests = arbitrary::maybe_generate_tests(args, &ast); + + let tests = + arbitrary::maybe_generate_tests(args, &ast.ident, &format_ident!("{}Tests", ast.ident)); quote! { #ast #tests } .into() } + +struct GenerateTestsInput { + args: TokenStream, + ty: syn::Type, + mod_name: syn::Ident, +} + +impl Parse for GenerateTestsInput { + fn parse(input: ParseStream<'_>) -> Result { + input.parse::()?; + + let args; + bracketed!(args in input); + + let args = args.parse::()?; + let ty = input.parse()?; + + input.parse::()?; + let mod_name = input.parse()?; + + Ok(Self { args: args.into(), ty, mod_name }) + } +} + +/// Generates tests for given type based on passed parameters. +/// +/// See `arbitrary::maybe_generate_tests` for more information. +/// +/// Examples: +/// * `generate_tests!(#[rlp] MyType, MyTypeTests)`: will generate rlp roundtrip tests for `MyType` +/// in a module named `MyTypeTests`. +/// * `generate_tests!(#[compact, 10] MyType, MyTypeTests)`: will generate compact roundtrip tests +/// for `MyType` limited to 10 cases. +#[proc_macro] +pub fn generate_tests(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as GenerateTestsInput); + + arbitrary::maybe_generate_tests(input.args, &input.ty, &input.mod_name).into() +}