diff --git a/src/exercises/concurrency/link-checker.md b/src/exercises/concurrency/link-checker.md
index 71e42191..a5ba6ab4 100644
--- a/src/exercises/concurrency/link-checker.md
+++ b/src/exercises/concurrency/link-checker.md
@@ -57,12 +57,13 @@ Your `src/main.rs` file should look something like this:
 ```rust,compile_fail
 {{#include link-checker.rs:setup}}
 
-{{#include link-checker.rs:extract_links}}
+{{#include link-checker.rs:visit_page}}
 
 fn main() {
+    let client = Client::new();
     let start_url = Url::parse("https://www.google.org").unwrap();
-    let response = get(start_url).unwrap();
-    match extract_links(response) {
+    let crawl_command = CrawlCommand{ url: start_url, extract_links: true };
+    match visit_page(&client, &crawl_command) {
         Ok(links) => println!("Links: {links:#?}"),
         Err(err) => println!("Could not extract links: {err:#}"),
     }
diff --git a/src/exercises/concurrency/link-checker.rs b/src/exercises/concurrency/link-checker.rs
index 114187ed..17ef60cf 100644
--- a/src/exercises/concurrency/link-checker.rs
+++ b/src/exercises/concurrency/link-checker.rs
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::{sync::Arc, sync::Mutex, sync::mpsc, thread};
+
 // ANCHOR: setup
-use reqwest::blocking::{get, Response};
-use reqwest::Url;
+use reqwest::{blocking::Client, Url};
 use scraper::{Html, Selector};
 use thiserror::Error;
 
@@ -22,68 +23,159 @@ use thiserror::Error;
 enum Error {
     #[error("request error: {0}")]
     ReqwestError(#[from] reqwest::Error),
+    #[error("bad http response: {0}")]
+    BadResponse(String),
 }
 // ANCHOR_END: setup
 
-// ANCHOR: extract_links
-fn extract_links(response: Response) -> Result<Vec<Url>, Error> {
-    let base_url = response.url().to_owned();
-    let document = response.text()?;
-    let html = Html::parse_document(&document);
-    let selector = Selector::parse("a").unwrap();
+// ANCHOR: visit_page
+#[derive(Debug)]
+struct CrawlCommand {
+    url: Url,
+    extract_links: bool,
+}
 
-    let mut valid_urls = Vec::new();
-    for element in html.select(&selector) {
-        if let Some(href) = element.value().attr("href") {
-            match base_url.join(href) {
-                Ok(url) => valid_urls.push(url),
-                Err(err) => {
-                    println!("On {base_url}: could not parse {href:?}: {err} (ignored)",);
+fn visit_page(client: &Client, command: &CrawlCommand) -> Result<Vec<Url>, Error> {
+    println!("Checking {:#}", command.url);
+    let response = client.get(command.url.clone()).send()?;
+    if !response.status().is_success() {
+        return Err(Error::BadResponse(response.status().to_string()));
+    }
+
+    let mut link_urls = Vec::new();
+    if !command.extract_links {
+        return Ok(link_urls);
+    }
+
+    let base_url = response.url().to_owned();
+    let body_text = response.text()?;
+    let document = Html::parse_document(&body_text);
+
+    let selector = Selector::parse("a").unwrap();
+    let href_values = document
+        .select(&selector)
+        .filter_map(|element| element.value().attr("href"));
+    for href in href_values {
+        match base_url.join(href) {
+            Ok(link_url) => {
+                link_urls.push(link_url);
+            }
+            Err(err) => {
+                println!("On {base_url:#}: ignored unparsable {href:?}: {err}");
+            }
+        }
+    }
+    Ok(link_urls)
+}
+// ANCHOR_END: visit_page
+
+struct CrawlState {
+    domain: String,
+    visited_pages: std::collections::HashSet<String>,
+}
+
+impl CrawlState {
+    fn new(start_url: &Url) -> CrawlState {
+        let mut visited_pages = std::collections::HashSet::new();
+        visited_pages.insert(start_url.as_str().to_string());
+        CrawlState {
+            domain: start_url.domain().unwrap().to_string(),
+            visited_pages,
+        }
+    }
+
+    /// Determine whether links within the given page should be extracted.
+    fn should_extract_links(&self, url: &Url) -> bool {
+        let Some(url_domain) = url.domain() else {
+            return false;
+        };
+        url_domain == self.domain
+    }
+
+    /// Mark the given page as visited, returning true if it had already
+    /// been visited.
+    fn mark_visited(&mut self, url: &Url) -> bool {
+        self.visited_pages.insert(url.as_str().to_string())
+    }
+}
+
+type CrawlResult = Result<Vec<Url>, (Url, Error)>;
+fn spawn_crawler_threads(
+    command_receiver: mpsc::Receiver<CrawlCommand>,
+    result_sender: mpsc::Sender<CrawlResult>,
+    thread_count: u32,
+) {
+    let command_receiver = Arc::new(Mutex::new(command_receiver));
+
+    for _ in 0..thread_count {
+        let result_sender = result_sender.clone();
+        let command_receiver = command_receiver.clone();
+        thread::spawn(move || {
+            let client = Client::new();
+            loop {
+                let command_result = {
+                    let receiver_guard = command_receiver.lock().unwrap();
+                    receiver_guard.recv()
+                };
+                let Ok(crawl_command) = command_result else {
+                    // The sender got dropped. No more commands coming in.
+                    break;
+                };
+                let crawl_result = match visit_page(&client, &crawl_command) {
+                    Ok(link_urls) => Ok(link_urls),
+                    Err(error) => Err((crawl_command.url, error)),
+                };
+                result_sender.send(crawl_result).unwrap();
+            }
+        });
+    }
+}
+
+fn control_crawl(
+    start_url: Url,
+    command_sender: mpsc::Sender<CrawlCommand>,
+    result_receiver: mpsc::Receiver<CrawlResult>,
+) -> Vec<Url> {
+    let mut crawl_state = CrawlState::new(&start_url);
+    let start_command = CrawlCommand { url: start_url, extract_links: true };
+    command_sender.send(start_command).unwrap();
+    let mut pending_urls = 1;
+
+    let mut bad_urls = Vec::new();
+    while pending_urls > 0 {
+        let crawl_result = result_receiver.recv().unwrap();
+        pending_urls -= 1;
+
+        match crawl_result {
+            Ok(link_urls) => {
+                for url in link_urls {
+                    if crawl_state.mark_visited(&url) {
+                        let extract_links = crawl_state.should_extract_links(&url);
+                        let crawl_command = CrawlCommand { url, extract_links };
+                        command_sender.send(crawl_command).unwrap();
+                        pending_urls += 1;
+                    }
                 }
             }
-        }
-    }
-
-    Ok(valid_urls)
-}
-// ANCHOR_END: extract_links
-
-fn check_links(url: Url) -> Result<Vec<Url>, Error> {
-    println!("Checking {url}");
-
-    let response = get(url.to_owned())?;
-
-    if !response.status().is_success() {
-        return Ok(vec![url.to_owned()]);
-    }
-
-    let links = extract_links(response)?;
-    for link in &links {
-        println!("{link}, {:?}", link.domain());
-    }
-
-    let mut failed_links = Vec::new();
-    for link in links {
-        if link.domain() != url.domain() {
-            println!("Checking external link: {link}");
-            let response = get(link.clone())?;
-            if !response.status().is_success() {
-                println!("Error on {url}: {link} failed: {}", response.status());
-                failed_links.push(link);
+            Err((url, error)) => {
+                bad_urls.push(url);
+                println!("Got crawling error: {:#}", error);
+                continue;
             }
-        } else {
-            println!("Checking link in same domain: {link}");
-            failed_links.extend(check_links(link)?)
         }
     }
+    bad_urls
+}
 
-    Ok(failed_links)
+fn check_links(start_url: Url) -> Vec<Url> {
+    let (result_sender, result_receiver) = mpsc::channel::<CrawlResult>();
+    let (command_sender, command_receiver) = mpsc::channel::<CrawlCommand>();
+    spawn_crawler_threads(command_receiver, result_sender, 16);
+    control_crawl(start_url, command_sender, result_receiver)
 }
 
 fn main() {
-    let start_url = Url::parse("https://www.google.org").unwrap();
-    match check_links(start_url) {
-        Ok(links) => println!("Links: {links:#?}"),
-        Err(err) => println!("Could not extract links: {err:#}"),
-    }
+    let start_url = reqwest::Url::parse("https://www.google.org").unwrap();
+    let bad_urls = check_links(start_url);
+    println!("Bad URLs: {:#?}", bad_urls);
 }
diff --git a/src/exercises/concurrency/solutions-morning.md b/src/exercises/concurrency/solutions-morning.md
index ba74e3c8..40defe91 100644
--- a/src/exercises/concurrency/solutions-morning.md
+++ b/src/exercises/concurrency/solutions-morning.md
@@ -8,3 +8,10 @@
 {{#include dining-philosophers.rs}}
 ```
 
+## Link Checker
+
+([back to exercise](link-checker.md))
+
+```rust,compile_fail
+{{#include link-checker.rs}}
+```
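
Not part of the patch itself: a minimal, self-contained sketch of the worker pattern that `spawn_crawler_threads` relies on, namely a single `mpsc::Receiver` wrapped in `Arc<Mutex<...>>` so a fixed pool of threads can pull work from one queue and report back over a second channel. The `u32` jobs, the pool size of 4, and the doubling step are placeholders for illustration only and are not taken from the exercise code.

```rust
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

fn main() {
    // One queue of jobs, one shared receiver: the same shape as
    // spawn_crawler_threads, with u32 "jobs" standing in for CrawlCommand.
    let (job_sender, job_receiver) = mpsc::channel::<u32>();
    let job_receiver = Arc::new(Mutex::new(job_receiver));
    let (result_sender, result_receiver) = mpsc::channel::<u32>();

    for _ in 0..4 {
        let job_receiver = Arc::clone(&job_receiver);
        let result_sender = result_sender.clone();
        thread::spawn(move || loop {
            // Lock only long enough to take one job off the queue.
            let job = { job_receiver.lock().unwrap().recv() };
            let Ok(job) = job else {
                break; // all job senders dropped: no more work coming in
            };
            result_sender.send(job * 2).unwrap();
        });
    }
    drop(result_sender); // keep only the workers' clones alive

    for job in 0..10 {
        job_sender.send(job).unwrap();
    }
    drop(job_sender); // lets the workers exit their loops

    let mut results: Vec<u32> = result_receiver.iter().collect();
    results.sort();
    println!("{results:?}");
}
```

The same design choice appears in the patch: the receiver guard is scoped so the mutex is held only while dequeuing a command, not while the worker does its (potentially slow) HTTP request, so the workers do not serialize each other's real work.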