chore: improve CompactZstd macro (#13277)

This commit is contained in:
Arsenii Kulikov
2024-12-11 15:58:12 +04:00
committed by GitHub
parent f2141925b0
commit 394f973acd
7 changed files with 98 additions and 52 deletions

1
Cargo.lock generated
View File

@ -8872,7 +8872,6 @@ dependencies = [
"alloy-primitives", "alloy-primitives",
"arbitrary", "arbitrary",
"assert_matches", "assert_matches",
"bytes",
"derive_more", "derive_more",
"modular-bitfield", "modular-bitfield",
"proptest", "proptest",

View File

@ -13,8 +13,6 @@ use reth_primitives_traits::receipt::ReceiptExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::TxType; use crate::TxType;
#[cfg(feature = "reth-codec")]
use reth_zstd_compressors::{RECEIPT_COMPRESSOR, RECEIPT_DECOMPRESSOR};
/// Retrieves gas spent by transactions as a vector of tuples (transaction index, gas used). /// Retrieves gas spent by transactions as a vector of tuples (transaction index, gas used).
pub use reth_primitives_traits::receipt::gas_spent_by_transactions; pub use reth_primitives_traits::receipt::gas_spent_by_transactions;
@ -25,6 +23,10 @@ pub use reth_primitives_traits::receipt::gas_spent_by_transactions;
)] )]
#[cfg_attr(any(test, feature = "reth-codec"), derive(reth_codecs::CompactZstd))] #[cfg_attr(any(test, feature = "reth-codec"), derive(reth_codecs::CompactZstd))]
#[cfg_attr(any(test, feature = "reth-codec"), reth_codecs::add_arbitrary_tests)] #[cfg_attr(any(test, feature = "reth-codec"), reth_codecs::add_arbitrary_tests)]
#[cfg_attr(any(test, feature = "reth-codec"), reth_zstd(
compressor = reth_zstd_compressors::RECEIPT_COMPRESSOR,
decompressor = reth_zstd_compressors::RECEIPT_DECOMPRESSOR
))]
#[rlp(trailing)] #[rlp(trailing)]
pub struct Receipt { pub struct Receipt {
/// Receipt type. /// Receipt type.

View File

@ -15,7 +15,6 @@ workspace = true
reth-codecs.workspace = true reth-codecs.workspace = true
alloy-primitives.workspace = true alloy-primitives.workspace = true
bytes.workspace = true
derive_more.workspace = true derive_more.workspace = true
modular-bitfield.workspace = true modular-bitfield.workspace = true
serde.workspace = true serde.workspace = true

View File

@ -1,6 +1,7 @@
//! Code generator for the `Compact` trait. //! Code generator for the `Compact` trait.
use super::*; use super::*;
use crate::ZstdConfig;
use convert_case::{Case, Casing}; use convert_case::{Case, Casing};
use syn::{Attribute, LitStr}; use syn::{Attribute, LitStr};
@ -10,20 +11,20 @@ pub fn generate_from_to(
attrs: &[Attribute], attrs: &[Attribute],
has_lifetime: bool, has_lifetime: bool,
fields: &FieldList, fields: &FieldList,
is_zstd: bool, zstd: Option<ZstdConfig>,
) -> TokenStream2 { ) -> TokenStream2 {
let flags = format_ident!("{ident}Flags"); let flags = format_ident!("{ident}Flags");
let to_compact = generate_to_compact(fields, ident, is_zstd); let reth_codecs = parse_reth_codecs_path(attrs).unwrap();
let from_compact = generate_from_compact(fields, ident, is_zstd);
let to_compact = generate_to_compact(fields, ident, zstd.clone(), &reth_codecs);
let from_compact = generate_from_compact(fields, ident, zstd);
let snake_case_ident = ident.to_string().to_case(Case::Snake); let snake_case_ident = ident.to_string().to_case(Case::Snake);
let fuzz = format_ident!("fuzz_test_{snake_case_ident}"); let fuzz = format_ident!("fuzz_test_{snake_case_ident}");
let test = format_ident!("fuzz_{snake_case_ident}"); let test = format_ident!("fuzz_{snake_case_ident}");
let reth_codecs = parse_reth_codecs_path(attrs).unwrap();
let lifetime = if has_lifetime { let lifetime = if has_lifetime {
quote! { 'a } quote! { 'a }
} else { } else {
@ -77,7 +78,7 @@ pub fn generate_from_to(
#fuzz_tests #fuzz_tests
#impl_compact { #impl_compact {
fn to_compact<B>(&self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]> { fn to_compact<B>(&self, buf: &mut B) -> usize where B: #reth_codecs::__private::bytes::BufMut + AsMut<[u8]> {
let mut flags = #flags::default(); let mut flags = #flags::default();
let mut total_length = 0; let mut total_length = 0;
#(#to_compact)* #(#to_compact)*
@ -92,7 +93,11 @@ pub fn generate_from_to(
} }
/// Generates code to implement the `Compact` trait method `to_compact`. /// Generates code to implement the `Compact` trait method `to_compact`.
fn generate_from_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> TokenStream2 { fn generate_from_compact(
fields: &FieldList,
ident: &Ident,
zstd: Option<ZstdConfig>,
) -> TokenStream2 {
let mut lines = vec![]; let mut lines = vec![];
let mut known_types = let mut known_types =
vec!["B256", "Address", "Bloom", "Vec", "TxHash", "BlockHash", "FixedBytes"]; vec!["B256", "Address", "Bloom", "Vec", "TxHash", "BlockHash", "FixedBytes"];
@ -147,38 +152,41 @@ fn generate_from_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> To
// If the type has compression support, then check the `__zstd` flag. Otherwise, use the default // If the type has compression support, then check the `__zstd` flag. Otherwise, use the default
// code branch. However, even if it's a type with compression support, not all values are // code branch. However, even if it's a type with compression support, not all values are
// to be compressed (thus the zstd flag). Ideally only the bigger ones. // to be compressed (thus the zstd flag). Ideally only the bigger ones.
is_zstd if let Some(zstd) = zstd {
.then(|| { let decompressor = zstd.decompressor;
let decompressor = format_ident!("{}_DECOMPRESSOR", ident.to_string().to_uppercase()); quote! {
quote! { if flags.__zstd() != 0 {
if flags.__zstd() != 0 { #decompressor.with(|decompressor| {
#decompressor.with(|decompressor| { let decompressor = &mut decompressor.borrow_mut();
let decompressor = &mut decompressor.borrow_mut(); let decompressed = decompressor.decompress(buf);
let decompressed = decompressor.decompress(buf); let mut original_buf = buf;
let mut original_buf = buf;
let mut buf: &[u8] = decompressed; let mut buf: &[u8] = decompressed;
#(#lines)*
(obj, original_buf)
})
} else {
#(#lines)* #(#lines)*
(obj, buf) (obj, original_buf)
} })
} } else {
})
.unwrap_or_else(|| {
quote! {
#(#lines)* #(#lines)*
(obj, buf) (obj, buf)
} }
}) }
} else {
quote! {
#(#lines)*
(obj, buf)
}
}
} }
/// Generates code to implement the `Compact` trait method `from_compact`. /// Generates code to implement the `Compact` trait method `from_compact`.
fn generate_to_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> Vec<TokenStream2> { fn generate_to_compact(
fields: &FieldList,
ident: &Ident,
zstd: Option<ZstdConfig>,
reth_codecs: &syn::Path,
) -> Vec<TokenStream2> {
let mut lines = vec![quote! { let mut lines = vec![quote! {
let mut buffer = bytes::BytesMut::new(); let mut buffer = #reth_codecs::__private::bytes::BytesMut::new();
}]; }];
let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_))); let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_)));
@ -198,7 +206,7 @@ fn generate_to_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> Vec<
// Just because a type supports compression, doesn't mean all its values are to be compressed. // Just because a type supports compression, doesn't mean all its values are to be compressed.
// We skip the smaller ones, and thus require a flag` __zstd` to specify if this value is // We skip the smaller ones, and thus require a flag` __zstd` to specify if this value is
// compressed or not. // compressed or not.
if is_zstd { if zstd.is_some() {
lines.push(quote! { lines.push(quote! {
let mut zstd = buffer.len() > 7; let mut zstd = buffer.len() > 7;
if zstd { if zstd {
@ -214,9 +222,8 @@ fn generate_to_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> Vec<
buf.put_slice(&flags); buf.put_slice(&flags);
}); });
if is_zstd { if let Some(zstd) = zstd {
let compressor = format_ident!("{}_COMPRESSOR", ident.to_string().to_uppercase()); let compressor = zstd.compressor;
lines.push(quote! { lines.push(quote! {
if zstd { if zstd {
#compressor.with(|compressor| { #compressor.with(|compressor| {

View File

@ -1,7 +1,7 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2}; use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Generics}; use syn::{Data, DeriveInput, Generics};
mod generator; mod generator;
use generator::*; use generator::*;
@ -15,6 +15,8 @@ use flags::*;
mod structs; mod structs;
use structs::*; use structs::*;
use crate::ZstdConfig;
// Helper Alias type // Helper Alias type
type IsCompact = bool; type IsCompact = bool;
// Helper Alias type // Helper Alias type
@ -40,16 +42,16 @@ pub enum FieldTypes {
} }
/// Derives the `Compact` trait and its from/to implementations. /// Derives the `Compact` trait and its from/to implementations.
pub fn derive(input: TokenStream, is_zstd: bool) -> TokenStream { pub fn derive(input: DeriveInput, zstd: Option<ZstdConfig>) -> TokenStream {
let mut output = quote! {}; let mut output = quote! {};
let DeriveInput { ident, data, generics, attrs, .. } = parse_macro_input!(input); let DeriveInput { ident, data, generics, attrs, .. } = input;
let has_lifetime = has_lifetime(&generics); let has_lifetime = has_lifetime(&generics);
let fields = get_fields(&data); let fields = get_fields(&data);
output.extend(generate_flag_struct(&ident, &attrs, has_lifetime, &fields, is_zstd)); output.extend(generate_flag_struct(&ident, &attrs, has_lifetime, &fields, zstd.is_some()));
output.extend(generate_from_to(&ident, &attrs, has_lifetime, &fields, is_zstd)); output.extend(generate_from_to(&ident, &attrs, has_lifetime, &fields, zstd));
output.into() output.into()
} }
@ -236,7 +238,7 @@ mod tests {
let DeriveInput { ident, data, attrs, .. } = parse2(f_struct).unwrap(); let DeriveInput { ident, data, attrs, .. } = parse2(f_struct).unwrap();
let fields = get_fields(&data); let fields = get_fields(&data);
output.extend(generate_flag_struct(&ident, &attrs, false, &fields, false)); output.extend(generate_flag_struct(&ident, &attrs, false, &fields, false));
output.extend(generate_from_to(&ident, &attrs, false, &fields, false)); output.extend(generate_from_to(&ident, &attrs, false, &fields, None));
// Expected output in a TokenStream format. Commas matter! // Expected output in a TokenStream format. Commas matter!
let should_output = quote! { let should_output = quote! {
@ -298,10 +300,10 @@ mod tests {
fuzz_test_test_struct(TestStruct::default()) fuzz_test_test_struct(TestStruct::default())
} }
impl reth_codecs::Compact for TestStruct { impl reth_codecs::Compact for TestStruct {
fn to_compact<B>(&self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]> { fn to_compact<B>(&self, buf: &mut B) -> usize where B: reth_codecs::__private::bytes::BufMut + AsMut<[u8]> {
let mut flags = TestStructFlags::default(); let mut flags = TestStructFlags::default();
let mut total_length = 0; let mut total_length = 0;
let mut buffer = bytes::BytesMut::new(); let mut buffer = reth_codecs::__private::bytes::BytesMut::new();
let f_u64_len = self.f_u64.to_compact(&mut buffer); let f_u64_len = self.f_u64.to_compact(&mut buffer);
flags.set_f_u64_len(f_u64_len as u8); flags.set_f_u64_len(f_u64_len as u8);
let f_u256_len = self.f_u256.to_compact(&mut buffer); let f_u256_len = self.f_u256.to_compact(&mut buffer);

View File

@ -20,6 +20,12 @@ use syn::{
mod arbitrary; mod arbitrary;
mod compact; mod compact;
#[derive(Clone)]
pub(crate) struct ZstdConfig {
compressor: syn::Path,
decompressor: syn::Path,
}
/// Derives the `Compact` trait for custom structs, optimizing serialization with a possible /// Derives the `Compact` trait for custom structs, optimizing serialization with a possible
/// bitflag struct. /// bitflag struct.
/// ///
@ -51,15 +57,46 @@ mod compact;
/// efficient decoding. /// efficient decoding.
#[proc_macro_derive(Compact, attributes(maybe_zero, reth_codecs))] #[proc_macro_derive(Compact, attributes(maybe_zero, reth_codecs))]
pub fn derive(input: TokenStream) -> TokenStream { pub fn derive(input: TokenStream) -> TokenStream {
let is_zstd = false; compact::derive(parse_macro_input!(input as DeriveInput), None)
compact::derive(input, is_zstd)
} }
/// Adds `zstd` compression to derived [`Compact`]. /// Adds `zstd` compression to derived [`Compact`].
#[proc_macro_derive(CompactZstd, attributes(maybe_zero, reth_codecs))] #[proc_macro_derive(CompactZstd, attributes(maybe_zero, reth_codecs, reth_zstd))]
pub fn derive_zstd(input: TokenStream) -> TokenStream { pub fn derive_zstd(input: TokenStream) -> TokenStream {
let is_zstd = true; let input = parse_macro_input!(input as DeriveInput);
compact::derive(input, is_zstd)
let mut compressor = None;
let mut decompressor = None;
for attr in &input.attrs {
if attr.path().is_ident("reth_zstd") {
if let Err(err) = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("compressor") {
let value = meta.value()?;
let path: syn::Path = value.parse()?;
compressor = Some(path);
} else if meta.path.is_ident("decompressor") {
let value = meta.value()?;
let path: syn::Path = value.parse()?;
decompressor = Some(path);
} else {
return Err(meta.error("unsupported attribute"))
}
Ok(())
}) {
return err.to_compile_error().into()
}
}
}
let (Some(compressor), Some(decompressor)) = (compressor, decompressor) else {
return quote! {
compile_error!("missing compressor or decompressor attribute");
}
.into()
};
compact::derive(input, Some(ZstdConfig { compressor, decompressor }))
} }
/// Generates tests for given type. /// Generates tests for given type.

View File

@ -1,3 +1,3 @@
pub use modular_bitfield; pub use modular_bitfield;
pub use bytes::Buf; pub use bytes::{self, Buf};