fix(sync): return control from HeaderStage back to pipeline (#609)

* Headerstage now gives control back to pipeline after commit_threshold blocks have been synced

* Updated tests

* Only take required number of headers from the stream

* Simplified stage complete check
This commit is contained in:
Sanket Shanbhag
2022-12-29 14:27:49 +05:30
committed by GitHub
parent 76e76bb651
commit 28669d4aa8

View File

@ -2,7 +2,7 @@ use crate::{
db::Transaction, metrics::HeaderMetrics, DatabaseIntegrityError, ExecInput, ExecOutput, Stage,
StageError, StageId, UnwindInput, UnwindOutput,
};
use futures_util::StreamExt;
use futures_util::{StreamExt, TryStreamExt};
use reth_db::{
cursor::{DbCursorRO, DbCursorRW},
database::Database,
@ -73,63 +73,66 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient, S: Statu
tx: &mut Transaction<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let stage_progress = input.stage_progress.unwrap_or_default();
self.update_head::<DB>(tx, stage_progress).await?;
let current_progress = input.stage_progress.unwrap_or_default();
self.update_head::<DB>(tx, current_progress).await?;
// Lookup the head and tip of the sync range
let (head, tip) = self.get_head_and_tip(tx, stage_progress).await?;
let (head, tip) = self.get_head_and_tip(tx, current_progress).await?;
debug!(target: "sync::stages::headers", ?tip, head = ?head.hash(), "Commencing sync");
let mut current_progress = stage_progress;
let mut stream =
self.downloader.stream(head.clone(), tip).chunks(self.commit_threshold as usize);
// The stage relies on the downloader to return the headers
// in descending order starting from the tip down to
// the local head (latest block in db)
while let Some(headers) = stream.next().await {
match headers.into_iter().collect::<Result<Vec<_>, _>>() {
Ok(res) => {
info!(target: "sync::stages::headers", len = res.len(), "Received headers");
self.metrics.headers_counter.increment(res.len() as u64);
// The downloader returns the headers in descending order starting from the tip
// down to the local head (latest block in db)
let downloaded_headers: Result<Vec<SealedHeader>, DownloadError> = self
.downloader
.stream(head.clone(), tip)
.take(self.commit_threshold as usize) // Only stream [self.commit_threshold] headers
.try_collect()
.await;
// Perform basic response validation
self.validate_header_response(&res)?;
let write_progress =
self.write_headers::<DB>(tx, res).await?.unwrap_or_default();
current_progress = current_progress.max(write_progress);
match downloaded_headers {
Ok(res) => {
info!(target: "sync::stages::headers", len = res.len(), "Received headers");
self.metrics.headers_counter.increment(res.len() as u64);
// Perform basic response validation
self.validate_header_response(&res)?;
// Write the headers to db
self.write_headers::<DB>(tx, res).await?.unwrap_or_default();
if self.is_stage_done(tx, current_progress).await? {
// Update total difficulty values after we have reached fork choice
debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty");
self.write_td::<DB>(tx, &head)?;
let stage_progress = current_progress.max(
tx.cursor::<tables::CanonicalHeaders>()?
.last()?
.map(|(num, _)| num)
.unwrap_or_default(),
);
Ok(ExecOutput { stage_progress, done: true })
} else {
Ok(ExecOutput { stage_progress: current_progress, done: false })
}
Err(e) => {
self.metrics.update_headers_error_metrics(&e);
match e {
DownloadError::Timeout => {
warn!(target: "sync::stages::headers", "No response for header request");
return Err(StageError::Recoverable(DownloadError::Timeout.into()))
}
DownloadError::HeaderValidation { hash, error } => {
error!(target: "sync::stages::headers", ?error, ?hash, "Validation error");
return Err(StageError::Validation { block: stage_progress, error })
}
error => {
error!(target: "sync::stages::headers", ?error, "Unexpected error");
return Err(StageError::Recoverable(error.into()))
}
}
Err(e) => {
self.metrics.update_headers_error_metrics(&e);
match e {
DownloadError::Timeout => {
warn!(target: "sync::stages::headers", "No response for header request");
return Err(StageError::Recoverable(DownloadError::Timeout.into()))
}
DownloadError::HeaderValidation { hash, error } => {
error!(target: "sync::stages::headers", ?error, ?hash, "Validation error");
return Err(StageError::Validation { block: current_progress, error })
}
error => {
error!(target: "sync::stages::headers", ?error, "Unexpected error");
return Err(StageError::Recoverable(error.into()))
}
}
}
}
// Write total difficulty values after all headers have been inserted
debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty");
self.write_td::<DB>(tx, &head)?;
let stage_progress = current_progress.max(
tx.cursor::<tables::CanonicalHeaders>()?
.last()?
.map(|(num, _)| num)
.unwrap_or_default(),
);
Ok(ExecOutput { stage_progress, done: true })
}
/// Unwind the stage.
@ -166,6 +169,19 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient, S: StatusUpdater>
Ok(())
}
async fn is_stage_done<DB: Database>(
&self,
tx: &Transaction<'_, DB>,
stage_progress: u64,
) -> Result<bool, StageError> {
let mut header_cursor = tx.cursor::<tables::CanonicalHeaders>()?;
let (head_num, _) = header_cursor
.seek_exact(stage_progress)?
.ok_or(DatabaseIntegrityError::CanonicalHeader { number: stage_progress })?;
// Check if the next entry is congruent
Ok(header_cursor.next()?.map(|(next_num, _)| head_num + 1 == next_num).unwrap_or_default())
}
/// Get the head and tip of the range we need to sync
async fn get_head_and_tip<DB: Database>(
&self,
@ -207,6 +223,7 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient, S: StatusUpdater>
None => self.next_fork_choice_state(&head.hash()).await.head_block_hash,
_ => return Err(StageError::StageProgress(stage_progress)),
};
Ok((head, tip))
}
@ -261,7 +278,6 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient, S: StatusUpdater>
cursor_header.insert(key, header)?;
cursor_canonical.insert(key.number(), key.hash())?;
}
Ok(latest)
}
@ -388,11 +404,7 @@ mod tests {
runner.consensus.update_tip(tip.hash());
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done: true, stage_progress })
if stage_progress == tip.number
);
assert_matches!(result, Ok(ExecOutput { done: true, stage_progress }) if stage_progress == tip.number);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
@ -507,7 +519,7 @@ mod tests {
client: self.client.clone(),
downloader: self.downloader.clone(),
network_handle: self.network_handle.clone(),
commit_threshold: 100,
commit_threshold: 500,
metrics: HeaderMetrics::default(),
}
}