diff --git a/Cargo.toml b/Cargo.toml index 6bb0676..87a33cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ tokio = { version = "1.15", features = ["full"], optional = true } futures-timer = "3.0" thiserror = "1.0" robotstxt = "0.3" +reqwest-middleware = "0.1.6" [[example]] name = "reddit" diff --git a/src/domain.rs b/src/domain.rs index f26f204..800eda5 100644 --- a/src/domain.rs +++ b/src/domain.rs @@ -6,6 +6,7 @@ use std::task::{Context, Poll}; use anyhow::Result; use futures::stream::Stream; use futures::{Future, FutureExt}; +use reqwest::header::HeaderValue; use crate::error::{CrawlError, DisallowReason}; use crate::requests::{ @@ -190,7 +191,7 @@ type CrawlRequest = Pin>>>>; type RobotsTxtRequest = Pin>>>; pub struct AllowedDomain { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, /// Futures that eventually return a http response that is passed to the /// scraper in_progress_crawl_requests: Vec>, @@ -332,7 +333,7 @@ where { // respect robots.txt let mut fut = Box::pin(get_response( - &pin.client, + pin.client.clone(), req, pin.skip_non_successful_responses, )); @@ -381,14 +382,14 @@ where pub struct AllowListConfig { pub delay: Option, pub respect_robots_txt: bool, - pub client: reqwest::Client, + pub client: reqwest_middleware::ClientWithMiddleware, pub skip_non_successful_responses: bool, pub max_depth: usize, pub max_requests: usize, } pub struct BlockList { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, /// list of domains that are blocked blocked_domains: HashSet, /// Futures that eventually return a http response that is passed to the @@ -416,7 +417,7 @@ pub struct BlockList { impl BlockList { pub fn new( blocked_domains: HashSet, - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, respect_robots_txt: bool, skip_non_successful_responses: bool, max_depth: usize, @@ -534,7 +535,7 @@ where if let Some(robots) = pin.robots_map.get(host) { if 
robots.is_not_disallowed(&req.request) { let fut = Box::pin(get_response( - &pin.client, + pin.client.clone(), req, pin.skip_non_successful_responses, )); @@ -572,7 +573,7 @@ where } } else { let fut = Box::pin(get_response( - &pin.client, + pin.client.clone(), req, pin.skip_non_successful_responses, )); @@ -603,7 +604,7 @@ where } fn get_response( - client: &reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, request: QueuedRequest, skip_non_successful_responses: bool, ) -> CrawlRequest @@ -618,9 +619,8 @@ where let request_url = request.url().clone(); let skip_http_error_response = skip_non_successful_responses; - let request = client.execute(request); - Box::pin(async move { + let request = client.execute(request); let mut resp = request.await?; if !resp.status().is_success() && skip_http_error_response { diff --git a/src/lib.rs b/src/lib.rs index c6145d0..c493ca8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,6 +147,7 @@ use anyhow::Result; use futures::stream::Stream; use futures::FutureExt; use reqwest::IntoUrl; +use reqwest::header::HeaderValue; use std::collections::{HashMap, HashSet, VecDeque}; use std::fmt; use std::future::Future; @@ -269,7 +270,7 @@ pub struct Crawler { in_progress_crawl_requests: Vec>, queued_results: VecDeque>, /// The client that issues all the requests - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, /// used to track the depth of submitted requests current_depth: usize, /// Either a list that only allows a set of domains or disallows a set of @@ -289,7 +290,9 @@ pub struct Crawler { impl Crawler { /// Create a new crawler following the config pub fn new(config: CrawlerConfig) -> Self { - let client = config.client.unwrap_or_default(); + let client = config + .client + .unwrap_or(reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build()); let list = if config.allowed_domains.is_empty() { let block_list = BlockList::new( @@ -367,7 +370,7 @@ where /// URLs. 
pub fn crawl(&mut self, fun: TCrawlFunction) where - TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture, + TCrawlFunction: FnOnce(&reqwest_middleware::ClientWithMiddleware) -> TCrawlFuture, TCrawlFuture: Future)>> + 'static, { let depth = self.current_depth + 1; @@ -375,8 +378,9 @@ where let fut = Box::pin(async move { let (mut resp, state) = fut.await?; let (status, url, headers) = response_info(&mut resp); + let text = resp.text().await?; - + Ok(Response { depth, // Note: There is no way to determine the original url since only the response is @@ -397,7 +401,7 @@ where /// returned once finished. pub fn complete(&mut self, fun: TCrawlFunction) where - TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture, + TCrawlFunction: FnOnce(&reqwest_middleware::ClientWithMiddleware) -> TCrawlFuture, TCrawlFuture: Future>> + 'static, { let fut = (fun)(&self.client); @@ -415,16 +419,20 @@ where } /// This queues in a whole request with no state attached - pub fn request(&mut self, req: reqwest::RequestBuilder) { + pub fn request(&mut self, req: reqwest_middleware::RequestBuilder) { self.queue_request(req, None) } /// This queues in a whole request with a state attached - pub fn request_with_state(&mut self, req: reqwest::RequestBuilder, state: T::State) { + pub fn request_with_state(&mut self, req: reqwest_middleware::RequestBuilder, state: T::State) { self.queue_request(req, Some(state)) } - fn queue_request(&mut self, request: reqwest::RequestBuilder, state: Option) { + fn queue_request( + &mut self, + request: reqwest_middleware::RequestBuilder, + state: Option, + ) { let req = QueuedRequestBuilder { request, state, @@ -436,8 +444,8 @@ where } } - /// Returns the client that performs all requests - pub fn client(&self) -> &reqwest::Client { + /// The client that performs all requests + pub fn client(&self) -> &reqwest_middleware::ClientWithMiddleware { &self.client } @@ -573,7 +581,7 @@ pub struct CrawlerConfig { // /// Delay a request // request_delay: Option, /// 
The client that will be used to send the requests - client: Option, + client: Option, } impl Default for CrawlerConfig { @@ -608,18 +616,21 @@ impl CrawlerConfig { self } - pub fn set_client(mut self, client: reqwest::Client) -> Self { + pub fn set_client(mut self, client: reqwest_middleware::ClientWithMiddleware) -> Self { self.client = Some(client); self } - /// *NOTE* [`reqwest::Client`] already uses Arc under the hood, so - /// it's preferable to just `clone` it and pass via [`Self::set_client()`] + /// *NOTE* [`reqwest_middleware::ClientWithMiddleware`] already uses Arc under the hood, so + /// it's preferable to just `clone` it and pass via [`Self::set_client`] #[deprecated( since = "0.2.0", note = "You do not have to wrap the Client it in a `Arc` to reuse it, because it already uses an `Arc` internally. Users should use `set_client` instead." )] - pub fn with_shared_client(mut self, client: std::sync::Arc) -> Self { + pub fn with_shared_client( + mut self, + client: std::sync::Arc, + ) -> Self { self.client = Some(client.as_ref().clone()); self } diff --git a/src/requests.rs b/src/requests.rs index 41f72ea..d8a1855 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -59,6 +59,7 @@ impl RequestQueue { /// Set a delay to be applied between requests. pub fn set_delay(&mut self, mut delay: RequestDelay) -> Option { if let Some((_, d)) = self.delay.as_mut() { + //TODO: take a look. decide if this swap is necessary std::mem::swap(&mut delay, d); Some(delay) } else { @@ -162,7 +163,5 @@ impl RequestDelay { /// This helps callers avoid accidentally moving the [Response](reqwest::Response) /// when reading its sub-fields. 
pub(crate) fn response_info(resp: &mut reqwest::Response) -> (StatusCode, Url, HeaderMap) { - let mut headers = HeaderMap::new(); - std::mem::swap(&mut headers, resp.headers_mut()); - (resp.status(), resp.url().clone(), headers) + (resp.status(), resp.url().clone(), resp.headers_mut().clone()) } diff --git a/src/robots.rs b/src/robots.rs index 68f6642..9835d4b 100644 --- a/src/robots.rs +++ b/src/robots.rs @@ -273,7 +273,7 @@ Disallow: /", let data = handler.finish(); assert_eq!(data.groups.len(), 1); - let client = reqwest::Client::new(); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build(); let request = client .request(reqwest::Method::GET, "https://old.reddit.com/r/rust") .build() @@ -287,7 +287,7 @@ Disallow: /", parse_robotstxt("", &mut handler); let data = handler.finish(); - let client = reqwest::Client::new(); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build(); let request = client .request(reqwest::Method::GET, "https://old.reddit.com/r/rust") .build() @@ -305,7 +305,7 @@ Disallow: /r/rust", ); let data = handler.finish(); - let client = reqwest::Client::new(); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build(); let request = client .request(reqwest::Method::GET, "https://old.reddit.com/r/crust") .build()