feat(rlp): support deriving optional fields (#1321)

This commit is contained in:
Roman Krasiuk
2023-02-14 00:25:50 +02:00
committed by GitHub
parent 5997103078
commit 2c0557d991
5 changed files with 181 additions and 104 deletions

View File

@ -1,17 +1,35 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, Result};
use crate::utils::has_attribute;
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct};
pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> TokenStream {
let body = if let syn::Data::Struct(s) = &ast.data {
s
} else {
panic!("#[derive(RlpDecodable)] is only defined for structs.");
};
pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpDecodable")?;
let fields = body.fields.iter().enumerate();
let supports_trailing_opt = attributes_include(&ast.attrs, "trailing");
let mut encountered_opt_item = false;
let mut stmts = Vec::with_capacity(body.fields.len());
for (i, field) in fields {
let is_opt = is_optional(field);
if is_opt {
if !supports_trailing_opt {
return Err(Error::new_spanned(field, "Optional fields are disabled. Add `#[rlp(trailing)]` attribute to the struct in order to enable"))
}
encountered_opt_item = true;
} else if encountered_opt_item && !attributes_include(&field.attrs, "default") {
return Err(Error::new_spanned(
field,
"All subsequent fields must be either optional or default.",
))
}
stmts.push(decodable_field(i, field, is_opt));
}
let stmts: Vec<_> =
body.fields.iter().enumerate().map(|(i, field)| decodable_field(i, field)).collect();
let name = &ast.ident;
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
@ -45,20 +63,16 @@ pub(crate) fn impl_decodable(ast: &syn::DeriveInput) -> TokenStream {
}
};
quote! {
Ok(quote! {
const _: () = {
extern crate reth_rlp;
#impl_block
};
}
})
}
pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> TokenStream {
let body = if let syn::Data::Struct(s) = &ast.data {
s
} else {
panic!("#[derive(RlpEncodableWrapper)] is only defined for structs.");
};
pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpEncodableWrapper")?;
assert_eq!(
body.fields.iter().count(),
@ -77,25 +91,28 @@ pub(crate) fn impl_decodable_wrapper(ast: &syn::DeriveInput) -> TokenStream {
}
};
quote! {
Ok(quote! {
const _: () = {
extern crate reth_rlp;
#impl_block
};
}
})
}
fn decodable_field(index: usize, field: &syn::Field) -> TokenStream {
let id = if let Some(ident) = &field.ident {
quote! { #ident }
} else {
let index = syn::Index::from(index);
quote! { #index }
};
fn decodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream {
let ident = field_ident(index, field);
if has_attribute(field, "default") {
quote! { #id: Default::default(), }
if attributes_include(&field.attrs, "default") {
quote! { #ident: Default::default(), }
} else if is_opt {
quote! {
#ident: if started_len - b.len() < rlp_head.payload_length {
Some(reth_rlp::Decodable::decode(b)?)
} else {
None
},
}
} else {
quote! { #id: reth_rlp::Decodable::decode(b)?, }
quote! { #ident: reth_rlp::Decodable::decode(b)?, }
}
}

View File

@ -1,22 +1,37 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, Result};
use crate::utils::has_attribute;
use crate::utils::{attributes_include, field_ident, is_optional, parse_struct};
pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream {
let body = if let syn::Data::Struct(s) = &ast.data {
s
} else {
panic!("#[derive(RlpEncodable)] is only defined for structs.");
};
pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpEncodable")?;
let (length_stmts, stmts): (Vec<_>, Vec<_>) = body
let fields = body
.fields
.iter()
.enumerate()
.filter(|(_, field)| !has_attribute(field, "skip"))
.map(|(i, field)| (encodable_length(i, field), encodable_field(i, field)))
.unzip();
.filter(|(_, field)| !attributes_include(&field.attrs, "skip"));
let supports_trailing_opt = attributes_include(&ast.attrs, "trailing");
let mut encountered_opt_item = false;
let mut length_stmts = Vec::with_capacity(body.fields.len());
let mut stmts = Vec::with_capacity(body.fields.len());
for (i, field) in fields {
let is_opt = is_optional(field);
if is_opt {
if !supports_trailing_opt {
return Err(Error::new_spanned(field, "Optional fields are disabled. Add `#[rlp(trailing)]` attribute to the struct in order to enable"))
}
encountered_opt_item = true;
} else if encountered_opt_item {
return Err(Error::new_spanned(field, "All subsequent fields must be optional."))
}
length_stmts.push(encodable_length(i, field, is_opt));
stmts.push(encodable_field(i, field, is_opt));
}
let name = &ast.ident;
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
@ -46,20 +61,16 @@ pub(crate) fn impl_encodable(ast: &syn::DeriveInput) -> TokenStream {
}
};
quote! {
Ok(quote! {
const _: () = {
extern crate reth_rlp;
#impl_block
};
}
})
}
pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> TokenStream {
let body = if let syn::Data::Struct(s) = &ast.data {
s
} else {
panic!("#[derive(RlpEncodableWrapper)] is only defined for structs.");
};
pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpEncodableWrapper")?;
let ident = {
let fields: Vec<_> = body.fields.iter().collect();
@ -85,26 +96,22 @@ pub(crate) fn impl_encodable_wrapper(ast: &syn::DeriveInput) -> TokenStream {
}
};
quote! {
Ok(quote! {
const _: () = {
extern crate reth_rlp;
#impl_block
};
}
})
}
pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> TokenStream {
let body = if let syn::Data::Struct(s) = &ast.data {
s
} else {
panic!("#[derive(RlpMaxEncodedLen)] is only defined for structs.");
};
pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> Result<TokenStream> {
let body = parse_struct(ast, "RlpMaxEncodedLen")?;
let stmts: Vec<_> = body
.fields
.iter()
.enumerate()
.filter(|(_, field)| !has_attribute(field, "skip"))
.filter(|(_, field)| !attributes_include(&field.attrs, "skip"))
.map(|(index, field)| encodable_max_length(index, field))
.collect();
let name = &ast.ident;
@ -116,27 +123,22 @@ pub(crate) fn impl_max_encoded_len(ast: &syn::DeriveInput) -> TokenStream {
}
};
quote! {
Ok(quote! {
const _: () = {
extern crate reth_rlp;
#impl_block
};
}
})
}
fn field_ident(index: usize, field: &syn::Field) -> TokenStream {
if let Some(ident) = &field.ident {
quote! { #ident }
} else {
let index = syn::Index::from(index);
quote! { #index }
}
}
fn encodable_length(index: usize, field: &syn::Field) -> TokenStream {
fn encodable_length(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream {
let ident = field_ident(index, field);
quote! { rlp_head.payload_length += reth_rlp::Encodable::length(&self.#ident); }
if is_opt {
quote! { rlp_head.payload_length += &self.#ident.as_ref().map(|val| reth_rlp::Encodable::length(val)).unwrap_or_default(); }
} else {
quote! { rlp_head.payload_length += reth_rlp::Encodable::length(&self.#ident); }
}
}
fn encodable_max_length(index: usize, field: &syn::Field) -> TokenStream {
@ -149,10 +151,12 @@ fn encodable_max_length(index: usize, field: &syn::Field) -> TokenStream {
}
}
fn encodable_field(index: usize, field: &syn::Field) -> TokenStream {
fn encodable_field(index: usize, field: &syn::Field, is_opt: bool) -> TokenStream {
let ident = field_ident(index, field);
let id = quote! { self.#ident };
quote! { reth_rlp::Encodable::encode(&#id, out); }
if is_opt {
quote! { self.#ident.as_ref().map(|val| reth_rlp::Encodable::encode(val, out)); }
} else {
quote! { reth_rlp::Encodable::encode(&self.#ident, out); }
}
}

View File

@ -28,34 +28,28 @@ use proc_macro::TokenStream;
/// Derives `Encodable` for the type which encodes the all fields as list: `<rlp-header, fields...>`
#[proc_macro_derive(RlpEncodable, attributes(rlp))]
pub fn encodable(input: TokenStream) -> TokenStream {
let ast = match syn::parse(input) {
Ok(ast) => ast,
Err(err) => return err.to_compile_error().into(),
};
let gen = impl_encodable(&ast);
gen.into()
syn::parse(input)
.and_then(|ast| impl_encodable(&ast))
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
/// Derives `Encodable` for the type which encodes the fields as-is, without a header: `<fields...>`
#[proc_macro_derive(RlpEncodableWrapper, attributes(rlp))]
pub fn encodable_wrapper(input: TokenStream) -> TokenStream {
let ast = match syn::parse(input) {
Ok(ast) => ast,
Err(err) => return err.to_compile_error().into(),
};
let gen = impl_encodable_wrapper(&ast);
gen.into()
syn::parse(input)
.and_then(|ast| impl_encodable_wrapper(&ast))
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
/// Derives `MaxEncodedLen` for types of constant size.
#[proc_macro_derive(RlpMaxEncodedLen, attributes(rlp))]
pub fn max_encoded_len(input: TokenStream) -> TokenStream {
let ast = match syn::parse(input) {
Ok(ast) => ast,
Err(err) => return err.to_compile_error().into(),
};
let gen = impl_max_encoded_len(&ast);
gen.into()
syn::parse(input)
.and_then(|ast| impl_max_encoded_len(&ast))
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
/// Derives `Decodable` for the type whose implementation expects an rlp-list input: `<rlp-header,
@ -64,12 +58,10 @@ pub fn max_encoded_len(input: TokenStream) -> TokenStream {
/// This is the inverse of `RlpEncodable`.
#[proc_macro_derive(RlpDecodable, attributes(rlp))]
pub fn decodable(input: TokenStream) -> TokenStream {
let ast = match syn::parse(input) {
Ok(ast) => ast,
Err(err) => return err.to_compile_error().into(),
};
let gen = impl_decodable(&ast);
gen.into()
syn::parse(input)
.and_then(|ast| impl_decodable(&ast))
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
/// Derives `Decodable` for the type whose implementation expects only the individual fields
@ -78,7 +70,8 @@ pub fn decodable(input: TokenStream) -> TokenStream {
/// This is the inverse of `RlpEncodableWrapper`.
#[proc_macro_derive(RlpDecodableWrapper, attributes(rlp))]
pub fn decodable_wrapper(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
let gen = impl_decodable_wrapper(&ast);
gen.into()
syn::parse(input)
.and_then(|ast| impl_decodable_wrapper(&ast))
.unwrap_or_else(|err| err.to_compile_error())
.into()
}

View File

@ -1,7 +1,23 @@
use syn::{Field, Meta, NestedMeta};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, DataStruct, Error, Field, Meta, NestedMeta, Result, Type, TypePath};
pub(crate) fn has_attribute(field: &Field, attr_name: &str) -> bool {
field.attrs.iter().any(|attr| {
pub(crate) fn parse_struct<'a>(
ast: &'a syn::DeriveInput,
derive_attr: &str,
) -> Result<&'a DataStruct> {
if let syn::Data::Struct(s) = &ast.data {
Ok(s)
} else {
Err(Error::new_spanned(
ast,
format!("#[derive({derive_attr})] is only defined for structs."),
))
}
}
pub(crate) fn attributes_include(attrs: &[Attribute], attr_name: &str) -> bool {
attrs.iter().any(|attr| {
if attr.path.is_ident("rlp") {
if let Ok(Meta::List(meta)) = attr.parse_meta() {
if let Some(NestedMeta::Meta(meta)) = meta.nested.first() {
@ -14,3 +30,23 @@ pub(crate) fn has_attribute(field: &Field, attr_name: &str) -> bool {
false
})
}
pub(crate) fn is_optional(field: &Field) -> bool {
if let Type::Path(TypePath { qself, path }) = &field.ty {
qself.is_none() &&
path.leading_colon.is_none() &&
path.segments.len() == 1 &&
path.segments.first().unwrap().ident == "Option"
} else {
false
}
}
pub(crate) fn field_ident(index: usize, field: &syn::Field) -> TokenStream {
if let Some(ident) = &field.ident {
quote! { #ident }
} else {
let index = syn::Index::from(index);
quote! { #index }
}
}