diff --git a/Cargo.lock b/Cargo.lock index c05f7027a..d4bebd4c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3337,6 +3337,7 @@ dependencies = [ "reth-primitives", "reth-rlp", "reth-rlp-derive", + "reth-tasks", "reth-tracing", "reth-transaction-pool", "secp256k1", diff --git a/crates/net/network/Cargo.toml b/crates/net/network/Cargo.toml index 2acbcdd38..47989fda9 100644 --- a/crates/net/network/Cargo.toml +++ b/crates/net/network/Cargo.toml @@ -18,6 +18,7 @@ reth-eth-wire = { path = "../eth-wire" } reth-ecies = { path = "../ecies" } reth-rlp = { path = "../../common/rlp" } reth-rlp-derive = { path = "../../common/rlp-derive" } +reth-tasks = { path = "../../tasks" } reth-transaction-pool = { path = "../../transaction-pool" } # async/futures diff --git a/crates/net/network/src/config.rs b/crates/net/network/src/config.rs index 7ddf69877..4c6db50a5 100644 --- a/crates/net/network/src/config.rs +++ b/crates/net/network/src/config.rs @@ -5,6 +5,7 @@ use crate::{ }; use reth_discv4::{Discv4Config, Discv4ConfigBuilder, NodeRecord, DEFAULT_DISCOVERY_PORT}; use reth_primitives::{Chain, ForkId, H256}; +use reth_tasks::TaskExecutor; use secp256k1::SecretKey; use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, @@ -40,6 +41,8 @@ pub struct NetworkConfig { pub block_import: Box, /// The default mode of the network. pub network_mode: NetworkMode, + /// The executor to use for spawning tasks. + pub executor: Option, } // === impl NetworkConfig === @@ -98,6 +101,8 @@ pub struct NetworkConfigBuilder { block_import: Box, /// The default mode of the network. network_mode: NetworkMode, + /// The executor to use for spawning tasks. + executor: Option, } // === impl NetworkConfigBuilder === @@ -119,9 +124,16 @@ impl NetworkConfigBuilder { genesis_hash: Default::default(), block_import: Box::::default(), network_mode: Default::default(), + executor: None, } } + /// Sets the executor to use for spawning tasks. + pub fn executor(mut self, executor: TaskExecutor) -> Self { + self.executor = Some(executor); + self + } + /// Sets a custom config for how sessions are handled. pub fn sessions_config(mut self, config: SessionsConfig) -> Self { self.sessions_config = Some(config); @@ -180,6 +192,7 @@ impl NetworkConfigBuilder { genesis_hash, block_import, network_mode, + executor, } = self; NetworkConfig { client, @@ -199,6 +212,7 @@ impl NetworkConfigBuilder { genesis_hash, block_import, network_mode, + executor, } } } diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index 10e2b4fb3..c5400af97 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -122,6 +122,7 @@ where block_import, network_mode, boot_nodes, + executor, .. } = config; @@ -137,7 +138,7 @@ where // need to retrieve the addr here since provided port could be `0` let local_peer_id = discovery.local_id(); - let sessions = SessionManager::new(secret_key, sessions_config); + let sessions = SessionManager::new(secret_key, sessions_config, executor); let state = NetworkState::new(client, discovery, peers_manger, genesis_hash); let swarm = Swarm::new(incoming, sessions, state); diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 4c0e3d24d..0083300ed 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -33,7 +33,6 @@ use std::{ use tokio::{ net::TcpStream, sync::{mpsc, oneshot}, - task::JoinSet, }; use tokio_stream::wrappers::ReceiverStream; use tracing::{instrument, trace, warn}; @@ -44,6 +43,7 @@ mod handle; use crate::session::config::SessionCounter; pub use config::SessionsConfig; use reth_ecies::util::pk2id; +use reth_tasks::TaskExecutor; /// Internal identifier for active sessions. #[derive(Debug, Clone, Copy, PartialOrd, PartialEq, Eq, Hash)] @@ -68,10 +68,8 @@ pub(crate) struct SessionManager { fork_filter: ForkFilter, /// Size of the command buffer per session. session_command_buffer: usize, - /// All spawned session tasks. - /// - /// Note: If dropped, the session tasks are aborted. - spawned_tasks: JoinSet<()>, + /// The executor for spawned tasks. + executor: Option, /// All pending session that are currently handshaking, exchanging `Hello`s. /// /// Events produced during the authentication phase are reported to this manager. Once the @@ -99,7 +97,11 @@ pub(crate) struct SessionManager { impl SessionManager { /// Creates a new empty [`SessionManager`]. - pub(crate) fn new(secret_key: SecretKey, config: SessionsConfig) -> Self { + pub(crate) fn new( + secret_key: SecretKey, + config: SessionsConfig, + executor: Option, + ) -> Self { let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer); let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer); @@ -121,7 +123,7 @@ impl SessionManager { hello, fork_filter, session_command_buffer: config.session_command_buffer, - spawned_tasks: Default::default(), + executor, pending_sessions: Default::default(), active_sessions: Default::default(), pending_sessions_tx, @@ -139,11 +141,15 @@ impl SessionManager { } /// Spawns the given future onto a new task that is tracked in the `spawned_tasks` [`JoinSet`]. - fn spawn(&mut self, f: F) + fn spawn(&self, f: F) where F: Future + Send + 'static, { - self.spawned_tasks.spawn(async move { f.await }); + if let Some(ref executor) = self.executor { + executor.spawn(async move { f.await }) + } else { + tokio::task::spawn(async move { f.await }); + } } /// Invoked on a received status update diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 44a3935d6..8531f090d 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -73,6 +73,7 @@ impl Stream for TaskManager { } /// A type that can spawn new tokio tasks +#[derive(Debug, Clone)] pub struct TaskExecutor { /// Handle to the tokio runtime this task manager is associated with. ///