diff --git a/Cargo.lock b/Cargo.lock index 38d1f8f61..256b13e19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3221,6 +3221,7 @@ dependencies = [ "hex-literal", "reth-rlp", "reth-rlp-derive", + "smol_str", ] [[package]] @@ -3846,6 +3847,12 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +[[package]] +name = "smol_str" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7475118a28b7e3a2e157ce0131ba8c5526ea96e90ee601d9f6bb2e286a35ab44" + [[package]] name = "socket2" version = "0.4.7" diff --git a/crates/common/rlp/Cargo.toml b/crates/common/rlp/Cargo.toml index 5f3d46869..a14282d28 100644 --- a/crates/common/rlp/Cargo.toml +++ b/crates/common/rlp/Cargo.toml @@ -11,6 +11,7 @@ arrayvec = { version = "0.7", default-features = false } auto_impl = "1" bytes = { version = "1", default-features = false } ethnum = { version = "1", default-features = false, optional = true } +smol_str = { version = "0.1", default-features = false, optional = true } ethereum-types = { version = "0.13", features = ["codec"], optional = true } reth-rlp-derive = { version = "0.1", path = "../rlp-derive", optional = true } @@ -20,6 +21,7 @@ reth-rlp-test = { path = ".", package = "reth-rlp", features = [ "std", "ethnum", "ethereum-types", + "smol_str" ] } criterion = "0.4.0" hex-literal = "0.3" diff --git a/crates/common/rlp/src/decode.rs b/crates/common/rlp/src/decode.rs index 068ce6a54..0df96aebf 100644 --- a/crates/common/rlp/src/decode.rs +++ b/crates/common/rlp/src/decode.rs @@ -79,6 +79,9 @@ impl core::fmt::Display for DecodeError { } impl Header { + /// Returns the decoded header. + /// + /// Returns an error if the given `buf`'s len is less than the expected payload. pub fn decode(buf: &mut &[u8]) -> Result { if !buf.has_remaining() { return Err(DecodeError::InputTooShort) @@ -352,21 +355,39 @@ where } } +#[cfg(feature = "smol_str")] +impl Decodable for smol_str::SmolStr { + fn decode(from: &mut &[u8]) -> Result { + let h = Header::decode(from)?; + if h.list { + return Err(DecodeError::UnexpectedList) + } + let data = &from[..h.payload_length]; + let s = match core::str::from_utf8(data) { + Ok(s) => Ok(smol_str::SmolStr::from(s)), + Err(_) => Err(DecodeError::Custom("invalid string")), + }; + from.advance(h.payload_length); + s + } +} + #[cfg(test)] mod tests { extern crate alloc; use super::*; + use crate::Encodable; use alloc::vec; use core::fmt::Debug; use ethereum_types::{U128, U256, U512, U64}; use ethnum::AsU256; use hex_literal::hex; - fn check_decode(fixtures: IT) + fn check_decode<'a, T, IT>(fixtures: IT) where T: Decodable + PartialEq + Debug, - IT: IntoIterator, &'static [u8])>, + IT: IntoIterator, &'a [u8])>, { for (expected, mut input) in fixtures { assert_eq!(T::decode(&mut input), expected); @@ -557,4 +578,16 @@ mod tests { (Ok(vec![0xBBCCB5_u64, 0xFFC0B5_u64]), &hex!("C883BBCCB583FFC0B5")[..]), ]) } + + #[cfg(feature = "smol_str")] + #[test] + fn rlp_smol_str() { + use smol_str::SmolStr; + let mut b = BytesMut::new(); + "test smol str".to_string().encode(&mut b); + check_decode::(vec![ + (Ok(SmolStr::new("test smol str")), b.as_ref()), + (Err(DecodeError::UnexpectedList), &hex!("C0")[..]), + ]) + } } diff --git a/crates/common/rlp/src/encode.rs b/crates/common/rlp/src/encode.rs index eea18a567..a13feb00f 100644 --- a/crates/common/rlp/src/encode.rs +++ b/crates/common/rlp/src/encode.rs @@ -175,6 +175,16 @@ impl Encodable for bool { impl_max_encoded_len!(bool, { ::LEN }); +#[cfg(feature = "smol_str")] +impl Encodable for smol_str::SmolStr { + fn encode(&self, out: &mut dyn BufMut) { + self.as_bytes().encode(out); + } + fn length(&self) -> usize { + self.as_bytes().length() + } +} + #[cfg(feature = "ethnum")] mod ethnum_support { use super::*; @@ -506,4 +516,17 @@ mod tests { assert_eq!(encoded_list::(&[]), &hex!("c0")[..]); assert_eq!(encoded_list(&[0xFFCCB5_u64, 0xFFC0B5_u64]), &hex!("c883ffccb583ffc0b5")[..]); } + + #[cfg(feature = "smol_str")] + #[test] + fn rlp_smol_str() { + use smol_str::SmolStr; + assert_eq!(encoded(SmolStr::new(""))[..], hex!("80")[..]); + let mut b = BytesMut::new(); + "test smol str".to_string().encode(&mut b); + assert_eq!(&encoded(SmolStr::new("test smol str"))[..], b.as_ref()); + let mut b = BytesMut::new(); + "abcdefgh".to_string().encode(&mut b); + assert_eq!(&encoded(SmolStr::new("abcdefgh"))[..], b.as_ref()); + } }