fix(discv5): fix bug flip byte lookup tgt (#7764)

Co-authored-by: DaniPopes <57450786+DaniPopes@users.noreply.github.com>
This commit is contained in:
Emilia Hane
2024-04-20 22:39:55 +02:00
committed by GitHub
parent 3750edd905
commit d81cf8aa5c

View File

@ -22,6 +22,7 @@ use discv5::ListenConfig;
use enr::{discv4_id_to_discv5_id, EnrCombinedKeyWrapper};
use futures::future::join_all;
use itertools::Itertools;
use rand::{Rng, RngCore};
use reth_primitives::{bytes::Bytes, ForkId, NodeRecord, PeerId};
use secp256k1::SecretKey;
use tokio::{sync::mpsc, task};
@ -541,19 +542,24 @@ pub fn get_lookup_target(
let mut target = local_node_id.raw();
// make sure target has a 'log2distance'-long suffix that differs from local node id
if kbucket_index != 0 {
let suffix_bit_offset = MAX_KBUCKET_INDEX.saturating_sub(kbucket_index);
let suffix_byte_offset = suffix_bit_offset / 8;
// todo: flip the precise bit
// let rel_suffix_bit_offset = suffix_bit_offset % 8;
target[suffix_byte_offset] = !target[suffix_byte_offset];
let bit_offset = MAX_KBUCKET_INDEX.saturating_sub(kbucket_index);
let (byte, bit) = (bit_offset / 8, bit_offset % 8);
// Flip the target bit.
target[byte] ^= 1 << (7 - bit);
if suffix_byte_offset != 31 {
for b in target.iter_mut().take(31).skip(suffix_byte_offset + 1) {
*b = rand::random::<u8>();
}
}
// Randomize the bits after the target.
let mut rng = rand::thread_rng();
// Randomize remaining bits in the byte we modified.
if bit < 7 {
// Compute the mask of the bits that need to be randomized.
let bits_to_randomize = 0xff >> (bit + 1);
// Clear.
target[byte] &= !bits_to_randomize;
// Randomize.
target[byte] |= rng.gen::<u8>() & bits_to_randomize;
}
// Randomize remaining bytes.
rng.fill_bytes(&mut target[byte + 1..]);
target.into()
}
@ -595,13 +601,11 @@ pub async fn lookup(
#[cfg(test)]
mod tests {
use super::*;
use ::enr::{CombinedKey, EnrKey};
use rand::Rng;
use secp256k1::rand::thread_rng;
use tracing::trace;
use super::*;
fn discv5_noop() -> Discv5 {
let sk = CombinedKey::generate_secp256k1();
Discv5 {
@ -786,11 +790,7 @@ mod tests {
pub fn log2_distance<U>(&self, other: &Key<U>) -> Option<u64> {
let xor_dist = self.distance(other);
let log_dist = (256 - xor_dist.0.leading_zeros() as u64);
if log_dist == 0 {
None
} else {
Some(log_dist)
}
(log_dist != 0).then_some(log_dist)
}
}
@ -807,15 +807,7 @@ mod tests {
#[test]
fn select_lookup_target() {
// bucket index ceiled to the next multiple of 4
const fn expected_bucket_index(kbucket_index: usize) -> u64 {
let log2distance = kbucket_index + 1;
let log2distance = log2distance / 8;
((log2distance + 1) * 8) as u64
}
let bucket_index = rand::thread_rng().gen_range(0..=MAX_KBUCKET_INDEX);
for bucket_index in 0..=MAX_KBUCKET_INDEX {
let sk = CombinedKey::generate_secp256k1();
let local_node_id = discv5::enr::NodeId::from(sk.public());
let target = get_lookup_target(bucket_index, local_node_id);
@ -823,14 +815,7 @@ mod tests {
let local_node_id = sigp::Key::from(local_node_id);
let target = sigp::Key::from(target);
if bucket_index == 0 {
// log2distance undef (inf)
assert!(local_node_id.log2_distance(&target).is_none())
} else {
assert_eq!(
expected_bucket_index(bucket_index),
local_node_id.log2_distance(&target).unwrap()
);
assert_eq!(local_node_id.log2_distance(&target), Some(bucket_index as u64 + 1));
}
}
}