diff --git a/src/main.rs b/src/main.rs index 987a75b..4e39adf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use std::{io, sync::{Arc, atomic::{AtomicU32, Ordering, AtomicBool, AtomicUsize}}, time::Duration}; -use futures_util::{StreamExt, future}; +use futures_util::{StreamExt, future::{self, join_all}}; use tui::{backend::CrosstermBackend, Terminal}; use tokio::{sync::{mpsc, Mutex}, task::JoinHandle}; use tokio::join; @@ -21,7 +21,7 @@ pub struct State { num_tasks_errored: Arc, num_bytes_downloaded: Arc, shutting_down: Arc, - url: String, + urls: Vec, log_send: mpsc::Sender, } @@ -35,6 +35,7 @@ fn main() { } async fn async_main() { + let urls: Vec = std::env::args().skip(1).collect(); // Some simple CLI args requirements... let url = match std::env::args().nth(1) { Some(url) => url, @@ -53,7 +54,7 @@ async fn async_main() { num_tasks_errored: Default::default(), num_bytes_downloaded: Default::default(), shutting_down: Default::default(), - url, + urls, log_send }; @@ -154,8 +155,20 @@ async fn download_task_manager(state: State) -> anyhow::Result<()> { async fn download_task(state: State, load_id: usize) -> anyhow::Result { + // one inner task for each url. + let urls = state.urls.clone(); + let inner_tasks = urls.iter().map(|url| download_task_inner(&state, load_id, url.to_string())); + let results = join_all(inner_tasks).await; + for result in results { + result? + } + + Ok::(TaskKind::DownloadTask) +} + +async fn download_task_inner(state: &State, load_id: usize, url: String) -> anyhow::Result<()> { let log_send = state.log_send.clone(); - match reqwest::get(&state.url).await { + match reqwest::get(&url).await { Ok(response) => { state.num_connections_open.fetch_add(1, Ordering::SeqCst); let mut stream = response.bytes_stream(); @@ -165,28 +178,27 @@ async fn download_task(state: State, load_id: usize) -> anyhow::Result state.num_bytes_downloaded.fetch_add(bytes.len().try_into()?, Ordering::Relaxed); }, Some(Err(e)) => { - log_send.send(format!("Task id {} errored while reading: {}", load_id, e)).await?; + log_send.send(format!("Task id {} errored while reading from {}: {}", load_id, &url, e)).await?; // do this second just in case the await fails. if the await fails or we otherwise exit, we'll add this *outside* state.num_tasks_errored.fetch_add(1, Ordering::SeqCst); break; }, None => { state.num_tasks_errored.fetch_add(1, Ordering::SeqCst); - log_send.send(format!("Task id {} ran out of data.", load_id)).await?; + log_send.send(format!("Task id {} ran out of data from {}.", load_id, &url)).await?; break; } } } state.num_connections_open.fetch_sub(1, Ordering::SeqCst); - log_send.send(format!("Task id {} exiting normally.", load_id)).await?; }, Err(e) => { state.num_tasks_errored.fetch_add(1, Ordering::SeqCst); - log_send.send(format!("Task id {} couldn't connect: {}", load_id, e)).await?; + log_send.send(format!("Task id {} couldn't connect to {}: {}", load_id, &url, e)).await?; tokio::time::sleep(Duration::from_secs(1)).await; }, } - Ok::(TaskKind::DownloadTask) + Ok(()) } enum TaskKind {