fix: validate headers in full block downloader (#4034)

This commit is contained in:
Matthias Seitz
2023-08-02 18:36:48 +02:00
committed by GitHub
parent d595834d20
commit 94dfeb3ade
4 changed files with 122 additions and 62 deletions

View File

@ -1,5 +1,6 @@
use super::headers::client::HeadersRequest;
use crate::{
consensus::ConsensusError,
consensus::{Consensus, ConsensusError},
p2p::{
bodies::client::{BodiesClient, SingleBodyRequest},
error::PeerRequestResult,
@ -16,22 +17,28 @@ use std::{
fmt::Debug,
future::Future,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tracing::debug;
use super::headers::client::HeadersRequest;
/// A Client that can fetch full blocks from the network.
#[derive(Debug, Clone)]
pub struct FullBlockClient<Client> {
client: Client,
consensus: Arc<dyn Consensus>,
}
impl<Client> FullBlockClient<Client> {
/// Creates a new instance of `FullBlockClient`.
pub fn new(client: Client) -> Self {
Self { client }
pub fn new(client: Client, consensus: Arc<dyn Consensus>) -> Self {
Self { client, consensus }
}
/// Returns a client with Test consensus
#[cfg(feature = "test-utils")]
pub fn test_client(client: Client) -> Self {
Self::new(client, Arc::new(crate::test_utils::TestConsensus::default()))
}
}
@ -95,6 +102,7 @@ where
headers: None,
pending_headers: VecDeque::new(),
bodies: HashMap::new(),
consensus: Arc::clone(&self.consensus),
}
}
}
@ -186,7 +194,7 @@ where
if let Some(header) = maybe_header {
if header.hash() != this.hash {
debug!(target: "downloaders", expected=?this.hash, received=?header.hash, "Received wrong header");
// received bad header
// received a different header than requested
this.client.report_bad_message(peer)
} else {
this.header = Some(header);
@ -352,6 +360,8 @@ where
{
/// The client used to fetch headers and bodies.
client: Client,
/// The consensus instance used to validate the blocks.
consensus: Arc<dyn Consensus>,
/// The block hash to start fetching from (inclusive).
start_hash: H256,
/// How many blocks to fetch: `len([start_hash, ..]) == count`
@ -381,6 +391,8 @@ where
}
/// Inserts a block body, matching it with the `next_header`.
///
/// Note: this assumes the response matches the next header in the queue.
fn insert_body(&mut self, body_response: BodyResponse) {
if let Some(header) = self.pending_headers.pop_front() {
self.bodies.insert(header, body_response);
@ -388,8 +400,8 @@ where
}
/// Inserts multiple block bodies.
fn insert_bodies(&mut self, bodies: Vec<BodyResponse>) {
for body in bodies {
fn insert_bodies(&mut self, bodies: impl IntoIterator<Item = BodyResponse>) {
for body in bodies.into_iter() {
self.insert_body(body);
}
}
@ -461,6 +473,46 @@ where
Some(response)
}
fn on_headers_response(&mut self, headers: WithPeerId<Vec<Header>>) {
let (peer, mut headers_falling) =
headers.map(|h| h.into_iter().map(|h| h.seal_slow()).collect::<Vec<_>>()).split();
// fill in the response if it's the correct length
if headers_falling.len() == self.count as usize {
// sort headers from highest to lowest block number
headers_falling.sort_unstable_by_key(|h| Reverse(h.number));
// check the starting hash
if headers_falling[0].hash() != self.start_hash {
// received a different header than requested
self.client.report_bad_message(peer);
} else {
let headers_rising = headers_falling.iter().rev().cloned().collect::<Vec<_>>();
// ensure the downloaded headers are valid
if let Err(err) = self.consensus.validate_header_range(&headers_rising) {
debug!(target: "downloaders", %err, ?self.start_hash, "Received bad header response");
self.client.report_bad_message(peer);
return
}
// get the bodies request so it can be polled later
let hashes = headers_falling.iter().map(|h| h.hash()).collect::<Vec<_>>();
// populate the pending headers
self.pending_headers = headers_falling.clone().into();
// set the actual request if it hasn't been started yet
if !self.has_bodies_request_started() {
// request the bodies for the downloaded headers
self.request.bodies = Some(self.client.get_block_bodies(hashes));
}
// set the headers response
self.headers = Some(headers_falling);
}
}
}
/// Returns whether or not a bodies request has been started, returning false if there is no
/// pending request.
fn has_bodies_request_started(&self) -> bool {
@ -500,39 +552,7 @@ where
RangeResponseResult::Header(res) => {
match res {
Ok(headers) => {
let (peer, mut headers) = headers
.map(|h| {
h.iter().map(|h| h.clone().seal_slow()).collect::<Vec<_>>()
})
.split();
// fill in the response if it's the correct length
if headers.len() == this.count as usize {
// sort headers from highest to lowest block number
headers.sort_unstable_by_key(|h| Reverse(h.number));
// check the starting hash
if headers[0].hash() != this.start_hash {
// received bad response
this.client.report_bad_message(peer);
} else {
// get the bodies request so it can be polled later
let hashes =
headers.iter().map(|h| h.hash()).collect::<Vec<_>>();
// populate the pending headers
this.pending_headers = headers.clone().into();
// set the actual request if it hasn't been started yet
if !this.has_bodies_request_started() {
this.request.bodies =
Some(this.client.get_block_bodies(hashes));
}
// set the headers response
this.headers = Some(headers);
}
}
this.on_headers_response(headers);
}
Err(err) => {
debug!(target: "downloaders", %err, ?this.start_hash, "Header range download failed");
@ -561,10 +581,9 @@ where
// first insert the received bodies
this.insert_bodies(
new_bodies
.iter()
.map(|resp| WithPeerId::new(peer, resp.clone()))
.map(BodyResponse::PendingValidation)
.collect::<Vec<_>>(),
.into_iter()
.map(|resp| WithPeerId::new(peer, resp))
.map(BodyResponse::PendingValidation),
);
if !this.is_bodies_complete() {
@ -723,7 +742,7 @@ mod tests {
let header = SealedHeader::default();
let body = BlockBody::default();
client.insert(header.clone(), body.clone());
let client = FullBlockClient::new(client);
let client = FullBlockClient::test_client(client);
let received = client.get_full_block(header.hash()).await;
assert_eq!(received, SealedBlock::new(header, body));
@ -735,7 +754,7 @@ mod tests {
let header = SealedHeader::default();
let body = BlockBody::default();
client.insert(header.clone(), body.clone());
let client = FullBlockClient::new(client);
let client = FullBlockClient::test_client(client);
let received = client.get_full_block_range(header.hash(), 1).await;
let received = received.first().expect("response should include a block");
@ -754,7 +773,7 @@ mod tests {
header = header.header.seal_slow();
client.insert(header.clone(), body.clone());
}
let client = FullBlockClient::new(client);
let client = FullBlockClient::test_client(client);
let received = client.get_full_block_range(header.hash(), 1).await;
let received = received.first().expect("response should include a block");
@ -780,7 +799,7 @@ mod tests {
header = header.header.seal_slow();
client.insert(header.clone(), body.clone());
}
let client = FullBlockClient::new(client);
let client = FullBlockClient::test_client(client);
let future = client.get_full_block_range(header.hash(), 1);
let mut stream = FullBlockRangeStream::from(future);
@ -826,7 +845,7 @@ mod tests {
header = header.header.seal_slow();
client.insert(header.clone(), body.clone());
}
let client = FullBlockClient::new(client);
let client = FullBlockClient::test_client(client);
let received = client.get_full_block_range(header.hash(), 1).await;
let received = received.first().expect("response should include a block");