Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Reqwest Middleware #8

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ categories = ["web-programming"]
[dependencies]
html5ever = "0.25"
scraper = "0.12"
reqwest = "0.11"
reqwest = "0.11.10"
Zizico2 marked this conversation as resolved.
Show resolved Hide resolved
futures = "0.3"
rand = "0.8"
anyhow = "1.0"
tokio = { version = "1.15", features = ["full"], optional = true }
futures-timer = "3.0"
thiserror = "1.0"
robotstxt = "0.3"
reqwest-middleware = "0.1.6"

[[example]]
name = "reddit"
Expand Down
20 changes: 10 additions & 10 deletions src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::task::{Context, Poll};
use anyhow::Result;
use futures::stream::Stream;
use futures::{Future, FutureExt};
use reqwest::header::HeaderValue;

use crate::error::{CrawlError, DisallowReason};
use crate::requests::{
Expand Down Expand Up @@ -190,7 +191,7 @@ type CrawlRequest<T> = Pin<Box<dyn Future<Output = Result<Response<T>>>>>;
type RobotsTxtRequest = Pin<Box<dyn Future<Output = Result<RobotsData>>>>;

pub struct AllowedDomain<T> {
client: reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,
/// Futures that eventually return a http response that is passed to the
/// scraper
in_progress_crawl_requests: Vec<CrawlRequest<T>>,
Expand Down Expand Up @@ -332,7 +333,7 @@ where
{
// respect robots.txt
let mut fut = Box::pin(get_response(
&pin.client,
pin.client.clone(),
req,
pin.skip_non_successful_responses,
));
Expand Down Expand Up @@ -381,14 +382,14 @@ where
pub struct AllowListConfig {
pub delay: Option<RequestDelay>,
pub respect_robots_txt: bool,
pub client: reqwest::Client,
pub client: reqwest_middleware::ClientWithMiddleware,
pub skip_non_successful_responses: bool,
pub max_depth: usize,
pub max_requests: usize,
}

pub struct BlockList<T> {
client: reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,
/// list of domains that are blocked
blocked_domains: HashSet<String>,
/// Futures that eventually return a http response that is passed to the
Expand Down Expand Up @@ -416,7 +417,7 @@ pub struct BlockList<T> {
impl<T> BlockList<T> {
pub fn new(
blocked_domains: HashSet<String>,
client: reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,
respect_robots_txt: bool,
skip_non_successful_responses: bool,
max_depth: usize,
Expand Down Expand Up @@ -534,7 +535,7 @@ where
if let Some(robots) = pin.robots_map.get(host) {
if robots.is_not_disallowed(&req.request) {
let fut = Box::pin(get_response(
&pin.client,
pin.client.clone(),
req,
pin.skip_non_successful_responses,
));
Expand Down Expand Up @@ -572,7 +573,7 @@ where
}
} else {
let fut = Box::pin(get_response(
&pin.client,
pin.client.clone(),
req,
pin.skip_non_successful_responses,
));
Expand Down Expand Up @@ -603,7 +604,7 @@ where
}

fn get_response<T>(
client: &reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,
request: QueuedRequest<T>,
skip_non_successful_responses: bool,
) -> CrawlRequest<T>
Expand All @@ -618,9 +619,8 @@ where
let request_url = request.url().clone();
let skip_http_error_response = skip_non_successful_responses;

let request = client.execute(request);

Box::pin(async move {
let request = client.execute(request);
let mut resp = request.await?;

if !resp.status().is_success() && skip_http_error_response {
Expand Down
37 changes: 24 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ use anyhow::Result;
use futures::stream::Stream;
use futures::FutureExt;
use reqwest::IntoUrl;
use reqwest::header::HeaderValue;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use std::future::Future;
Expand Down Expand Up @@ -269,7 +270,7 @@ pub struct Crawler<T: Scraper> {
in_progress_crawl_requests: Vec<CrawlRequest<T::State>>,
queued_results: VecDeque<CrawlResult<T>>,
/// The client that issues all the requests
client: reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,
/// used to track the depth of submitted requests
current_depth: usize,
/// Either a list that only allows a set of domains or disallows a set of
Expand All @@ -289,7 +290,9 @@ pub struct Crawler<T: Scraper> {
impl<T: Scraper> Crawler<T> {
/// Create a new crawler following the config
pub fn new(config: CrawlerConfig) -> Self {
let client = config.client.unwrap_or_default();
let client = config
.client
.unwrap_or(reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build());

let list = if config.allowed_domains.is_empty() {
let block_list = BlockList::new(
Expand Down Expand Up @@ -364,16 +367,17 @@ where
/// the scraper again
pub fn crawl<TCrawlFunction, TCrawlFuture>(&mut self, fun: TCrawlFunction)
where
TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture,
TCrawlFunction: FnOnce(&reqwest_middleware::ClientWithMiddleware) -> TCrawlFuture,
TCrawlFuture: Future<Output = Result<(reqwest::Response, Option<T::State>)>> + 'static,
{
let depth = self.current_depth + 1;
let fut = (fun)(&self.client);
let fut = Box::pin(async move {
let (mut resp, state) = fut.await?;
let (status, url, headers) = response_info(&mut resp);

let text = resp.text().await?;

Ok(Response {
depth,
// Note: There is no way to determine the original url since only the response is
Expand All @@ -394,7 +398,7 @@ where
/// returned once finished.
pub fn complete<TCrawlFunction, TCrawlFuture>(&mut self, fun: TCrawlFunction)
where
TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture,
TCrawlFunction: FnOnce(&reqwest_middleware::ClientWithMiddleware) -> TCrawlFuture,
TCrawlFuture: Future<Output = Result<Option<T::Output>>> + 'static,
{
let fut = (fun)(&self.client);
Expand All @@ -412,16 +416,20 @@ where
}

/// This queues in a whole request with no state attached
pub fn request(&mut self, req: reqwest::RequestBuilder) {
pub fn request(&mut self, req: reqwest_middleware::RequestBuilder) {
self.queue_request(req, None)
}

/// This queues in a whole request with a state attached
pub fn request_with_state(&mut self, req: reqwest::RequestBuilder, state: T::State) {
pub fn request_with_state(&mut self, req: reqwest_middleware::RequestBuilder, state: T::State) {
self.queue_request(req, Some(state))
}

fn queue_request(&mut self, request: reqwest::RequestBuilder, state: Option<T::State>) {
fn queue_request(
&mut self,
request: reqwest_middleware::RequestBuilder,
state: Option<T::State>,
) {
let req = QueuedRequestBuilder {
request,
state,
Expand All @@ -434,7 +442,7 @@ where
}

/// The client that performs all requests
pub fn client(&self) -> &reqwest::Client {
pub fn client(&self) -> &reqwest_middleware::ClientWithMiddleware {
&self.client
}

Expand Down Expand Up @@ -570,7 +578,7 @@ pub struct CrawlerConfig {
// /// Delay a request
// request_delay: Option<RequestDelay>,
/// The client that will be used to send the requests
client: Option<reqwest::Client>,
client: Option<reqwest_middleware::ClientWithMiddleware>,
}

impl Default for CrawlerConfig {
Expand Down Expand Up @@ -605,18 +613,21 @@ impl CrawlerConfig {
self
}

pub fn set_client(mut self, client: reqwest::Client) -> Self {
pub fn set_client(mut self, client: reqwest_middleware::ClientWithMiddleware) -> Self {
self.client = Some(client);
self
}

/// *NOTE* [`reqwest::Client`] already uses Arc under the hood, so
/// *NOTE* [`reqwest_middleware::ClientWithMiddleware`] already uses Arc under the hood, so
/// it's preferable to just `clone` it and pass via [`Self::set_client`]
#[deprecated(
since = "0.2.0",
note = "You do not have to wrap the Client it in a `Arc` to reuse it, because it already uses an `Arc` internally. Users should use `set_client` instead."
)]
pub fn with_shared_client(mut self, client: std::sync::Arc<reqwest::Client>) -> Self {
pub fn with_shared_client(
mut self,
client: std::sync::Arc<reqwest_middleware::ClientWithMiddleware>,
) -> Self {
self.client = Some(client.as_ref().clone());
self
}
Expand Down
7 changes: 3 additions & 4 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct QueuedRequest<T> {

/// A request that is waiting to be built
pub struct QueuedRequestBuilder<T> {
pub request: reqwest::RequestBuilder,
pub request: reqwest_middleware::RequestBuilder,
pub state: Option<T>,
pub depth: usize,
}
Expand Down Expand Up @@ -47,6 +47,7 @@ impl<T> RequestQueue<T> {
/// Set a delay to be applied between requests
pub fn set_delay(&mut self, mut delay: RequestDelay) -> Option<RequestDelay> {
if let Some((_, d)) = self.delay.as_mut() {
// TODO: take a look and decide whether this swap is necessary
std::mem::swap(&mut delay, d);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask why this swap? Just curious.

Some(delay)
} else {
Expand Down Expand Up @@ -141,7 +142,5 @@ impl RequestDelay {
}

pub(crate) fn response_info(resp: &mut reqwest::Response) -> (StatusCode, Url, HeaderMap) {
let mut headers = HeaderMap::new();
std::mem::swap(&mut headers, resp.headers_mut());
(resp.status(), resp.url().clone(), headers)
(resp.status(), resp.url().clone(), resp.headers_mut().clone())
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this swap just for performance reasons? It needs to return the result of `headers_mut()` directly, since it will be mutated and isn't just a view "stuck in time".

}
6 changes: 3 additions & 3 deletions src/robots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ Disallow: /",
let data = handler.finish();
assert_eq!(data.groups.len(), 1);

let client = reqwest::Client::new();
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
let request = client
.request(reqwest::Method::GET, "https://old.reddit.com/r/rust")
.build()
Expand All @@ -276,7 +276,7 @@ Disallow: /",
parse_robotstxt("", &mut handler);
let data = handler.finish();

let client = reqwest::Client::new();
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
let request = client
.request(reqwest::Method::GET, "https://old.reddit.com/r/rust")
.build()
Expand All @@ -294,7 +294,7 @@ Disallow: /r/rust",
);
let data = handler.finish();

let client = reqwest::Client::new();
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
let request = client
.request(reqwest::Method::GET, "https://old.reddit.com/r/crust")
.build()
Expand Down