From 5a3022373b75d3f748697c2c45c6707604481048 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Thu, 6 Jul 2023 18:13:00 +0200 Subject: [PATCH] refactor: Sync and Async Code Re-use Attempt to abstract sync and async requests by writing an abstraction over the flow of data that encompasses a request. The requests fill out a `RequestData` instance which then gets used as the source of future transformations. The login chain is a prime example of this. We define a series of steps that are expressed as a series of sequences which will finally produce the expected outcome. The Session now produces a `Sequence` implementation which needs to be driven by a client. Unfortunately, due to lack of Async Trait support, the async implementation is not as efficient as it could be. Attempt was made to test the code on nightly, but ran into this bug https://github.com/rust-lang/rust/pull/113108 --- Cargo.toml | 6 +- examples/user_id.rs | 22 +- examples/user_id_sync.rs | 19 +- go-gpa-server/build.rs | 1 + go-gpa-server/go/lib.go | 13 ++ go-gpa-server/src/lib.rs | 10 + go-srp/build.rs | 1 + src/clientv2/client.rs | 30 +-- src/clientv2/mod.rs | 2 - src/clientv2/request_repeater.rs | 223 ------------------ src/clientv2/session.rs | 372 +++++++++++++++++++------------ src/clientv2/totp.rs | 39 +--- src/domain/human_verification.rs | 2 +- src/http/client.rs | 157 +++++++++++++ src/http/mod.rs | 330 ++------------------------- src/http/proxy.rs | 39 ++++ src/http/request.rs | 112 ++++++++++ src/http/reqwest_client.rs | 101 ++++++--- src/http/response.rs | 92 ++++++++ src/http/sequence.rs | 275 +++++++++++++++++++++++ src/http/ureq_client.rs | 49 ++-- src/lib.rs | 26 ++- src/requests/auth.rs | 139 ++++++------ src/requests/event.rs | 16 +- src/requests/tests.rs | 7 +- src/requests/user.rs | 8 +- tests/session/login.rs | 104 ++++++++- tests/session/utils.rs | 6 +- 28 files changed, 1277 insertions(+), 924 deletions(-) delete mode 100644 src/clientv2/request_repeater.rs create mode 100644 src/http/client.rs create mode 100644 src/http/proxy.rs create mode 100644 src/http/request.rs create mode 100644 src/http/response.rs create mode 100644 src/http/sequence.rs diff --git a/Cargo.toml b/Cargo.toml index 0503179..08d9bbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "proton-api-rs" authors = ["Leander Beernaert "] -version = "0.10.2" +version = "0.11.0" edition = "2021" license = "AGPL-3.0-only" description = "Unofficial implemention of proton REST API in rust" @@ -30,6 +30,7 @@ ureq = {version="2.6", optional=true, features=["socks-proxy", "socks"]} default = [] http-ureq = ["dep:ureq"] http-reqwest = ["dep:reqwest"] +async-traits =[] [dependencies.reqwest] version = "0.11" @@ -40,7 +41,6 @@ optional = true [dev-dependencies] env_logger = "0.10" tokio = {version ="1", features = ["full"]} -httpmock = "0.6" go-gpa-server = {path= "go-gpa-server"} [[example]] @@ -53,5 +53,5 @@ required-features = ["http-ureq"] [[test]] name = "session" -required-features = ["http-ureq"] +required-features = ["http-ureq", "http-reqwest"] diff --git a/examples/user_id.rs b/examples/user_id.rs index bd78dd1..92877bd 100644 --- a/examples/user_id.rs +++ b/examples/user_id.rs @@ -1,4 +1,6 @@ -use proton_api_rs::{http, ping_async}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; +use proton_api_rs::{http, ping}; use proton_api_rs::{Session, SessionType}; pub use tokio; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; @@ -6,7 +8,7 @@ use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; #[tokio::main(worker_threads = 1)] async fn main() { let user_email = std::env::var("PAPI_USER_EMAIL").unwrap(); - let user_password = std::env::var("PAPI_USER_PASSWORD").unwrap(); + let user_password = SecretString::new(std::env::var("PAPI_USER_PASSWORD").unwrap()); let app_version = std::env::var("PAPI_APP_VERSION").unwrap(); let client = http::ClientBuilder::new() @@ -14,15 +16,16 @@ async fn main() { .build::() .unwrap(); - ping_async(&client).await.unwrap(); + ping().do_async(&client).await.unwrap(); - let session = match Session::login_async(&client, &user_email, &user_password, None, None) + let session = match Session::login(&user_email, &user_password, None) + .do_async(&client) .await .unwrap() { SessionType::Authenticated(c) => c, - SessionType::AwaitingTotp(mut t) => { + SessionType::AwaitingTotp(t) => { let mut stdout = tokio::io::stdout(); let mut line_reader = tokio::io::BufReader::new(tokio::io::stdin()).lines(); let session = { @@ -41,13 +44,12 @@ async fn main() { let totp = line.trim_end_matches('\n'); - match t.submit_totp_async(&client, totp).await { + match t.submit_totp(totp).do_async(&client).await { Ok(ac) => { session = Some(ac); break; } - Err((et, e)) => { - t = et; + Err(e) => { eprintln!("Failed to submit totp: {e}"); continue; } @@ -65,8 +67,8 @@ async fn main() { } }; - let user = session.get_user_async(&client).await.unwrap(); + let user = session.get_user().do_async(&client).await.unwrap(); println!("User ID is {}", user.id); - session.logout_async(&client).await.unwrap(); + session.logout().do_async(&client).await.unwrap(); } diff --git a/examples/user_id_sync.rs b/examples/user_id_sync.rs index 7de976f..c256519 100644 --- a/examples/user_id_sync.rs +++ b/examples/user_id_sync.rs @@ -1,4 +1,6 @@ use proton_api_rs::clientv2::{ping, SessionType}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; use proton_api_rs::{http, Session}; use std::io::{BufRead, Write}; @@ -6,7 +8,7 @@ fn main() { env_logger::init(); let user_email = std::env::var("PAPI_USER_EMAIL").unwrap(); - let user_password = std::env::var("PAPI_USER_PASSWORD").unwrap(); + let user_password = SecretString::new(std::env::var("PAPI_USER_PASSWORD").unwrap()); let app_version = std::env::var("PAPI_APP_VERSION").unwrap(); let client = http::ClientBuilder::new() @@ -15,12 +17,12 @@ fn main() { .build::() .unwrap(); - ping(&client).unwrap(); + ping().do_sync(&client).unwrap(); - let login_result = Session::login(&client, &user_email, &user_password, None, None); + let login_result = Session::login(&user_email, &user_password, None).do_sync(&client); let session = match login_result.unwrap() { SessionType::Authenticated(s) => s, - SessionType::AwaitingTotp(mut t) => { + SessionType::AwaitingTotp(t) => { let mut line_reader = std::io::BufReader::new(std::io::stdin()); let session = { let mut session = None; @@ -38,13 +40,12 @@ fn main() { let totp = line.trim_end_matches('\n'); - match t.submit_totp(&client, totp) { + match t.submit_totp(totp).do_sync(&client) { Ok(ac) => { session = Some(ac); break; } - Err((et, e)) => { - t = et; + Err(e) => { eprintln!("Failed to submit totp: {e}"); continue; } @@ -62,8 +63,8 @@ fn main() { } }; - let user = session.get_user(&client).unwrap(); + let user = session.get_user().do_sync(&client).unwrap(); println!("User ID is {}", user.id); - session.logout(&client).unwrap(); + session.logout().do_sync(&client).unwrap(); } diff --git a/go-gpa-server/build.rs b/go-gpa-server/build.rs index 7dae73f..2c8e6b2 100644 --- a/go-gpa-server/build.rs +++ b/go-gpa-server/build.rs @@ -27,6 +27,7 @@ fn target_path_for_go_lib() -> (PathBuf, PathBuf) { fn build_go_lib(lib_path: &Path) { let mut command = Command::new("go"); + #[cfg(any(target_os= "linux",target_os = "android"))] command.env("CGO_LDFLAGS", "-Wl,--build-id=none"); command.arg("build"); command.arg("-ldflags=-buildid="); diff --git a/go-gpa-server/go/lib.go b/go-gpa-server/go/lib.go index ac7176c..e4f55e5 100644 --- a/go-gpa-server/go/lib.go +++ b/go-gpa-server/go/lib.go @@ -9,6 +9,7 @@ typedef const char cchar_t; import "C" import ( "sync" + "time" "unsafe" "github.com/ProtonMail/go-proton-api/server" @@ -104,6 +105,18 @@ func gpaCreateUser(h int, cuser *C.cchar_t, cpassword *C.cchar_t, outUserID **C. return 0 } +//export gpaSetAuthLife +func gpaSetAuthLife(h int, seconds int) int { + srv := alloc.resolve(h) + if srv == nil { + return -1 + } + + srv.SetAuthLife(time.Duration(seconds) * time.Second) + + return 0 +} + //export CStrFree func CStrFree(ptr *C.char) { C.free(unsafe.Pointer(ptr)) diff --git a/go-gpa-server/src/lib.rs b/go-gpa-server/src/lib.rs index c4ab7f5..22ded0b 100644 --- a/go-gpa-server/src/lib.rs +++ b/go-gpa-server/src/lib.rs @@ -68,6 +68,16 @@ impl Server { )) } } + + pub fn set_auth_timeout(&self, duration: std::time::Duration) -> Result<()> { + unsafe { + if go::gpaSetAuthLife(self.0, duration.as_secs() as i64) < 0 { + return Err("Failed to set auth timeout".to_string()); + } + + Ok(()) + } + } } impl Drop for Server { diff --git a/go-srp/build.rs b/go-srp/build.rs index 3b04330..fd1af67 100644 --- a/go-srp/build.rs +++ b/go-srp/build.rs @@ -74,6 +74,7 @@ fn target_path_for_go_lib(platform: Platform) -> (PathBuf, PathBuf) { fn build_go_lib(lib_path: &Path, platform: Platform) { let mut command = Command::new("go"); + #[cfg(any(target_os= "linux",target_os = "android"))] command.env("CGO_LDFLAGS", "-Wl,--build-id=none"); match platform { Platform::Desktop => {} diff --git a/src/clientv2/client.rs b/src/clientv2/client.rs index a72dc9d..883f913 100644 --- a/src/clientv2/client.rs +++ b/src/clientv2/client.rs @@ -1,30 +1,10 @@ -use crate::http; -use crate::http::Request; +use crate::http::{Request, RequestDesc}; use crate::requests::{CaptchaRequest, Ping}; -pub fn ping(client: &T) -> Result<(), http::Error> { - Ping.execute_sync::(client, &http::DefaultRequestFactory {}) +pub fn ping() -> impl Request { + Ping.to_request() } -pub async fn ping_async(client: &T) -> Result<(), http::Error> { - Ping.execute_async::(client, &http::DefaultRequestFactory {}) - .await -} - -pub fn captcha_get( - client: &T, - token: &str, - force_web: bool, -) -> Result { - CaptchaRequest::new(token, force_web).execute_sync(client, &http::DefaultRequestFactory {}) -} - -pub async fn captcha_get_async( - client: &T, - token: &str, - force_web: bool, -) -> Result { - CaptchaRequest::new(token, force_web) - .execute_async(client, &http::DefaultRequestFactory {}) - .await +pub fn captcha_get(token: &str, force_web: bool) -> impl Request { + CaptchaRequest::new(token, force_web).to_request() } diff --git a/src/clientv2/mod.rs b/src/clientv2/mod.rs index 29456b2..1f3e450 100644 --- a/src/clientv2/mod.rs +++ b/src/clientv2/mod.rs @@ -1,9 +1,7 @@ mod client; -mod request_repeater; mod session; mod totp; pub use client::*; -pub use request_repeater::*; pub use session::*; pub use totp::*; diff --git a/src/clientv2/request_repeater.rs b/src/clientv2/request_repeater.rs deleted file mode 100644 index 3a1e6d3..0000000 --- a/src/clientv2/request_repeater.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! Automatic request repeater based on the expectations Proton has for their clients. - -use crate::domain::{SecretString, UserUid}; -use crate::http::{ - ClientAsync, ClientSync, DefaultRequestFactory, Method, Request, RequestData, RequestFactory, -}; -use crate::requests::{AuthRefreshRequest, UserAuth}; -use crate::{http, SessionRefreshData}; -use secrecy::{ExposeSecret, Secret}; - -pub trait OnAuthRefreshed: Send + Sync { - fn on_auth_refreshed(&self, user: &Secret, token: &SecretString); -} - -pub struct RequestRepeater { - user_auth: parking_lot::RwLock, - on_auth_refreshed: Option>, -} - -impl std::fmt::Debug for RequestRepeater { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "RequestRepeater{{user_auth:{:?} on_auth_refreshed:{}}}", - self.user_auth, - if self.on_auth_refreshed.is_some() { - "Some" - } else { - "None" - } - ) - } -} - -impl RequestRepeater { - pub fn new(user_auth: UserAuth, on_auth_refreshed: Option>) -> Self { - Self { - user_auth: parking_lot::RwLock::new(user_auth), - on_auth_refreshed, - } - } - - fn refresh_auth(&self, client: &C) -> http::Result<()> { - let mut borrow = self.user_auth.write(); - match AuthRefreshRequest::new( - borrow.uid.expose_secret(), - borrow.refresh_token.expose_secret(), - ) - .execute_sync(client, &DefaultRequestFactory {}) - { - Ok(s) => { - *borrow = UserAuth::from_auth_refresh_response(&s); - if let Some(cb) = &self.on_auth_refreshed { - cb.on_auth_refreshed(&borrow.uid, &borrow.access_token); - } - Ok(()) - } - Err(e) => Err(e), - } - } - - async fn refresh_auth_async(&self, client: &C) -> http::Result<()> { - // Have to clone here due to async boundaries. - let user_auth = { self.user_auth.read().clone() }; - match AuthRefreshRequest::new( - user_auth.uid.expose_secret(), - user_auth.refresh_token.expose_secret(), - ) - .execute_async(client, &DefaultRequestFactory {}) - .await - { - Ok(s) => { - let mut borrow = self.user_auth.write(); - *borrow = UserAuth::from_auth_refresh_response(&s); - if let Some(cb) = &self.on_auth_refreshed { - cb.on_auth_refreshed(&borrow.uid, &borrow.access_token); - } - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn execute( - &self, - client: &C, - request: R, - ) -> http::Result { - match request.execute_sync(client, self) { - Ok(r) => Ok(r), - Err(original_error) => { - if let http::Error::API(api_err) = &original_error { - if api_err.http_code == 401 { - log::debug!("Account session expired, attempting refresh"); - // Session expired/not authorized, try auth refresh. - if let Err(e) = self.refresh_auth(client) { - log::error!("Failed to refresh account {e}"); - return Err(original_error); - } - - // Execute request again - return request.execute_sync(client, self); - } - } - Err(original_error) - } - } - } - - pub async fn execute_async<'a, C: ClientAsync, R: Request + 'a>( - &'a self, - client: &'a C, - request: R, - ) -> http::Result { - match request.execute_async(client, self).await { - Ok(r) => Ok(r), - Err(original_error) => { - if let http::Error::API(api_err) = &original_error { - log::debug!("Account session expired, attempting refresh"); - if api_err.http_code == 401 { - // Session expired/not authorized, try auth refresh. - if let Err(e) = self.refresh_auth_async(client).await { - log::error!("Failed to refresh account {e}"); - return Err(original_error); - } - - // Execute request again - return request.execute_async(client, self).await; - } - } - Err(original_error) - } - } - } - - pub fn get_refresh_data(&self) -> SessionRefreshData { - let borrow = self.user_auth.read(); - SessionRefreshData { - user_uid: borrow.uid.clone(), - token: borrow.refresh_token.clone(), - } - } -} - -impl RequestFactory for RequestRepeater { - fn new_request(&self, method: Method, url: &str) -> RequestData { - let accessor = self.user_auth.read(); - RequestData::new(method, url) - .header(http::X_PM_UID_HEADER, &accessor.uid.expose_secret().0) - .bearer_token(accessor.access_token.expose_secret()) - } -} - -#[cfg(test)] -mod test { - - #[test] - #[cfg(feature = "http-ureq")] - fn request_repeats_with_401() { - use crate::domain::{EventId, SecretString, UserUid}; - use crate::http::X_PM_UID_HEADER; - use crate::requests::{GetLatestEventRequest, UserAuth}; - use crate::RequestRepeater; - use httpmock::prelude::*; - use secrecy::Secret; - - let server = MockServer::start(); - let url = server.base_url(); - - let client = crate::http::ClientBuilder::new() - .allow_http() - .base_url(&url) - .build::() - .unwrap(); - - let repeater = RequestRepeater::new( - UserAuth { - uid: Secret::new(UserUid("test-uid".to_string())), - access_token: SecretString::new("secret-token".to_string()), - refresh_token: SecretString::new("refresh-token".to_string()), - }, - None, - ); - - let expected_latest_event_id = EventId("My_Event_Id".to_string()); - - let latest_event_first_call = server.mock(|when, then| { - when.method(GET) - .path("/core/v4/events/latest") - .header(X_PM_UID_HEADER, "test-uid"); - then.status(401); - }); - - let latest_event_second_call = server.mock(|when, then| { - when.method(GET) - .path("/core/v4/events/latest") - .header(X_PM_UID_HEADER, "User_UID"); - then.status(200) - .body(format!(r#"{{"EventID":"{}"}}"#, expected_latest_event_id.0)); - }); - - let refresh_mock = server.mock(|when, then| { - when.method(POST).path("/auth/v4/refresh"); - - let response = r#"{ - "UID": "User_UID", - "TokenType": "type", - "AccessToken": "access-token", - "RefreshToken": "refresh-token", - "Scope": "Scope" -}"#; - - then.status(200).body(response); - }); - - let latest_event = repeater.execute(&client, GetLatestEventRequest {}).unwrap(); - assert_eq!(latest_event.event_id, expected_latest_event_id); - - latest_event_first_call.assert(); - refresh_mock.assert(); - latest_event_second_call.assert(); - } -} diff --git a/src/clientv2/session.rs b/src/clientv2/session.rs index 9cc32d0..bfebbbf 100644 --- a/src/clientv2/session.rs +++ b/src/clientv2/session.rs @@ -1,16 +1,25 @@ -use crate::clientv2::request_repeater::RequestRepeater; use crate::clientv2::TotpSession; use crate::domain::{ - Event, EventId, HumanVerification, HumanVerificationLoginData, TwoFactorAuth, User, UserUid, + EventId, HumanVerification, HumanVerificationLoginData, SecretString, TwoFactorAuth, User, + UserUid, +}; +use crate::http; +use crate::http::{ + ClientAsync, ClientRequest, ClientRequestBuilder, ClientSync, FromResponse, Request, + RequestDesc, Sequence, StateProducerSequence, X_PM_UID_HEADER, }; -use crate::http::{DefaultRequestFactory, Request}; use crate::requests::{ AuthInfoRequest, AuthInfoResponse, AuthRefreshRequest, AuthRequest, AuthResponse, - GetEventRequest, GetLatestEventRequest, LogoutRequest, TFAStatus, UserAuth, UserInfoRequest, + GetEventRequest, GetLatestEventRequest, LogoutRequest, TFAStatus, TOTPRequest, UserAuth, + UserInfoRequest, }; -use crate::{http, OnAuthRefreshed}; use go_srp::SRPAuth; -use secrecy::Secret; +use secrecy::{ExposeSecret, Secret}; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; +use std::sync::Arc; #[derive(Debug, thiserror::Error)] pub enum LoginError { @@ -46,174 +55,80 @@ pub enum SessionType { /// users. #[derive(Debug)] pub struct Session { - pub(super) repeater: RequestRepeater, + pub(super) user_auth: Arc>, } impl Session { - fn new(user: UserAuth, on_auth_refreshed_cb: Option>) -> Self { + fn new(user: UserAuth) -> Self { Self { - repeater: RequestRepeater::new(user, on_auth_refreshed_cb), + user_auth: Arc::new(parking_lot::RwLock::new(user)), } } - pub fn login( - client: &T, - username: &str, - password: &str, + pub fn login<'a>( + username: &'a str, + password: &'a SecretString, human_verification: Option, - on_auth_refreshed: Option>, - ) -> Result { - let auth_info_response = - AuthInfoRequest { username }.execute_sync::(client, &DefaultRequestFactory {})?; - - let proof = generate_session_proof(username, password, &auth_info_response)?; - - let auth_response = AuthRequest { + ) -> impl Sequence<'a, Output = SessionType, Error = LoginError> + 'a { + let state = State { username, - client_ephemeral: &proof.client_ephemeral, - client_proof: &proof.client_proof, - srp_session: auth_info_response.srp_session.as_ref(), - human_verification, - } - .execute_sync::(client, &DefaultRequestFactory {}) - .map_err(map_human_verification_err)?; - - validate_server_proof(&proof, &auth_response, on_auth_refreshed) - } - - pub async fn login_async( - client: &T, - username: &str, - password: &str, - human_verification: Option, - on_auth_refreshed: Option>, - ) -> Result { - let auth_info_response = AuthInfoRequest { username } - .execute_async::(client, &DefaultRequestFactory {}) - .await?; - - let proof = generate_session_proof(username, password, &auth_info_response)?; - - let auth_response = AuthRequest { - username, - client_ephemeral: &proof.client_ephemeral, - client_proof: &proof.client_proof, - srp_session: auth_info_response.srp_session.as_ref(), - human_verification, - } - .execute_async::(client, &DefaultRequestFactory {}) - .await - .map_err(map_human_verification_err)?; - - validate_server_proof(&proof, &auth_response, on_auth_refreshed) - } - - pub async fn refresh_async( - client: &T, - user_uid: &UserUid, - token: &str, - on_auth_refreshed: Option>, - ) -> http::Result { - let refresh_response = AuthRefreshRequest::new(user_uid, token) - .execute_async(client, &DefaultRequestFactory {}) - .await?; - let user = UserAuth::from_auth_refresh_response(&refresh_response); - Ok(Session::new(user, on_auth_refreshed)) - } - - pub fn refresh( - client: &T, - user_uid: &UserUid, - token: &str, - on_auth_refreshed: Option>, - ) -> http::Result { - let refresh_response = AuthRefreshRequest::new(user_uid, token) - .execute_sync(client, &DefaultRequestFactory {})?; - let user = UserAuth::from_auth_refresh_response(&refresh_response); - Ok(Session::new(user, on_auth_refreshed)) - } - - pub fn get_user(&self, client: &T) -> Result { - let user = self.repeater.execute(client, UserInfoRequest {})?; - Ok(user.user) - } + password, + hv: human_verification, + }; - pub async fn get_user_async( - &self, - client: &T, - ) -> Result { - let user = self - .repeater - .execute_async(client, UserInfoRequest {}) - .await?; - Ok(user.user) + StateProducerSequence::new(state, login_sequence_1) } - pub fn logout(&self, client: &T) -> Result<(), http::Error> { - LogoutRequest {}.execute_sync::(client, &self.repeater) + pub fn submit_totp(&self, code: &str) -> impl Sequence { + self.wrap_request(TOTPRequest::new(code).to_request()) } - pub async fn logout_async(&self, client: &T) -> Result<(), http::Error> { - LogoutRequest {} - .execute_async::(client, &self.repeater) - .await + pub fn refresh<'a>( + user_uid: &'a UserUid, + token: &'a str, + ) -> impl Sequence<'a, Output = Self, Error = http::Error> + 'a { + AuthRefreshRequest::new(user_uid, token) + .to_request() + .map(|r| { + let user = UserAuth::from_auth_refresh_response(r); + Ok(Session::new(user)) + }) } - pub fn get_latest_event(&self, client: &T) -> http::Result { - let r = self.repeater.execute(client, GetLatestEventRequest {})?; - Ok(r.event_id) + pub fn get_user(&self) -> impl Sequence { + self.wrap_request(UserInfoRequest {}.to_request()) + .map(|r| -> Result { Ok(r.user) }) } - pub async fn get_latest_event_async( - &self, - client: &T, - ) -> http::Result { - let r = self - .repeater - .execute_async(client, GetLatestEventRequest {}) - .await?; - Ok(r.event_id) + pub fn logout(&self) -> impl Sequence { + self.wrap_request(LogoutRequest {}.to_request()) } - pub fn get_event(&self, client: &T, id: &EventId) -> http::Result { - self.repeater.execute(client, GetEventRequest::new(id)) + pub fn get_latest_event(&self) -> impl Request { + self.wrap_request(GetLatestEventRequest {}.to_request()) } - pub async fn get_event_async( - &self, - client: &T, - id: &EventId, - ) -> http::Result { - self.repeater - .execute_async(client, GetEventRequest::new(id)) - .await + pub fn get_event(&self, id: &EventId) -> impl Request { + self.wrap_request(GetEventRequest::new(id).to_request()) } pub fn get_refresh_data(&self) -> SessionRefreshData { - self.repeater.get_refresh_data() + let reader = self.user_auth.read(); + SessionRefreshData { + user_uid: reader.uid.clone(), + token: reader.refresh_token.clone(), + } } -} -fn generate_session_proof( - username: &str, - password: &str, - auth_info_response: &AuthInfoResponse, -) -> Result { - SRPAuth::generate( - username, - password, - auth_info_response.version, - &auth_info_response.salt, - &auth_info_response.modulus, - &auth_info_response.server_ephemeral, - ) - .map_err(LoginError::ServerProof) + #[inline(always)] + fn wrap_request(&self, r: R) -> SessionRequest { + SessionRequest(r, self.user_auth.clone()) + } } fn validate_server_proof( proof: &SRPAuth, - auth_response: &AuthResponse, - on_auth_refreshed: Option>, + auth_response: AuthResponse, ) -> Result { if proof.expected_server_proof != auth_response.server_proof { return Err(LoginError::ServerProof( @@ -221,11 +136,12 @@ fn validate_server_proof( )); } + let tfa_enabled = auth_response.tfa.enabled; let user = UserAuth::from_auth_response(auth_response); - let session = Session::new(user, on_auth_refreshed); + let session = Session::new(user); - match auth_response.tfa.enabled { + match tfa_enabled { TFAStatus::None => Ok(SessionType::Authenticated(session)), TFAStatus::Totp => Ok(SessionType::AwaitingTotp(TotpSession(session))), TFAStatus::FIDO2 => Err(LoginError::Unsupported2FA(TwoFactorAuth::FIDO2)), @@ -233,12 +149,172 @@ fn validate_server_proof( } } -fn map_human_verification_err(e: http::Error) -> LoginError { - if let http::Error::API(e) = &e { +fn map_human_verification_err(e: LoginError) -> LoginError { + if let LoginError::Request(http::Error::API(e)) = &e { if let Ok(hv) = e.try_get_human_verification_details() { return LoginError::HumanVerificationRequired(hv); } } - LoginError::from(e) + e +} + +pub struct SessionRequest(R, Arc>); + +impl SessionRequest { + fn refresh_auth(&self) -> impl Sequence<'_, Output = (), Error = http::Error> + '_ { + let reader = self.1.read(); + AuthRefreshRequest::new( + reader.uid.expose_secret(), + reader.refresh_token.expose_secret(), + ) + .to_request() + .map(|resp| { + let mut writer = self.1.write(); + *writer = UserAuth::from_auth_refresh_response(resp); + Ok(()) + }) + } + + async fn exec_async_impl<'a, C: ClientAsync, F: FromResponse>( + &'a self, + client: &'a C, + ) -> Result { + let v = self.build(client); + match client.execute_async::(v).await { + Ok(r) => Ok(r), + Err(original_error) => { + if let http::Error::API(api_err) = &original_error { + if api_err.http_code == 401 { + log::debug!("Account session expired, attempting refresh"); + // Session expired/not authorized, try auth refresh. + if let Err(e) = self.refresh_auth().do_async(client).await { + log::error!("Failed to refresh account {e}"); + return Err(original_error); + } + + // Execute request again + return client.execute_async::(self.build(client)).await; + } + } + Err(original_error) + } + } + } +} + +impl Request for SessionRequest { + type Response = R::Response; + + fn build(&self, builder: &C) -> C::Request { + let r = self.0.build(builder); + let borrow = self.1.read(); + r.header(X_PM_UID_HEADER, borrow.uid.expose_secret().as_str()) + .bearer_token(borrow.access_token.expose_secret()) + } + + fn exec_sync( + &self, + client: &T, + ) -> Result<::Output, http::Error> { + match client.execute::(self.build(client)) { + Ok(r) => Ok(r), + Err(original_error) => { + if let http::Error::API(api_err) = &original_error { + if api_err.http_code == 401 { + log::debug!("Account session expired, attempting refresh"); + // Session expired/not authorized, try auth refresh. + if let Err(e) = self.refresh_auth().do_sync(client) { + log::error!("Failed to refresh account {e}"); + return Err(original_error); + } + + // Execute request again + return client.execute::(self.build(client)); + } + } + Err(original_error) + } + } + } + + #[cfg(not(feature = "async-traits"))] + fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Pin< + Box< + dyn Future::Output, http::Error>> + 'a, + >, + > { + Box::pin(async move { self.exec_async_impl::(client).await }) + } + + #[cfg(feature = "async-traits")] + async fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Result<::Output, http::Error> { + self.exec_async_impl::(client).await + } +} + +struct State<'a> { + username: &'a str, + password: &'a SecretString, + hv: Option, +} + +struct LoginState<'a> { + username: &'a str, + proof: SRPAuth, + session: String, + hv: Option, +} + +fn generate_login_state( + state: State, + auth_info_response: AuthInfoResponse, +) -> Result { + let proof = SRPAuth::generate( + state.username, + state.password.expose_secret(), + auth_info_response.version, + &auth_info_response.salt, + &auth_info_response.modulus, + &auth_info_response.server_ephemeral, + ) + .map_err(LoginError::ServerProof)?; + + Ok(LoginState { + username: state.username, + proof, + session: auth_info_response.srp_session, + hv: state.hv, + }) +} + +fn login_sequence_2( + login_state: LoginState, +) -> impl Sequence<'_, Output = SessionType, Error = LoginError> + '_ { + AuthRequest { + username: login_state.username, + client_ephemeral: &login_state.proof.client_ephemeral, + client_proof: &login_state.proof.client_proof, + srp_session: &login_state.session, + human_verification: &login_state.hv, + } + .to_request() + .map(move |auth_response| { + validate_server_proof(&login_state.proof, auth_response).map_err(map_human_verification_err) + }) +} + +fn login_sequence_1(st: State) -> impl Sequence<'_, Output = SessionType, Error = LoginError> + '_ { + AuthInfoRequest { + username: st.username, + } + .to_request() + .map(move |auth_info_response| generate_login_state(st, auth_info_response)) + .state(login_sequence_2) } diff --git a/src/clientv2/totp.rs b/src/clientv2/totp.rs index e77567f..187a034 100644 --- a/src/clientv2/totp.rs +++ b/src/clientv2/totp.rs @@ -1,42 +1,19 @@ use crate::clientv2::Session; use crate::http; -use crate::http::Request; -use crate::requests::TOTPRequest; +use crate::http::Sequence; #[derive(Debug)] pub struct TotpSession(pub(super) Session); impl TotpSession { - pub fn submit_totp( - self, - client: &T, - code: &str, - ) -> Result { - match TOTPRequest::new(code).execute_sync(client, &self.0.repeater) { - Err(e) => Err((self, e)), - Ok(_) => Ok(self.0), - } + pub fn submit_totp(&self, code: &str) -> impl Sequence { + let auth = self.0.user_auth.clone(); + self.0 + .submit_totp(code) + .map(move |_| Ok(Session { user_auth: auth })) } - pub async fn submit_totp_async( - self, - client: &T, - code: &str, - ) -> Result { - match TOTPRequest::new(code) - .execute_async(client, &self.0.repeater) - .await - { - Err(e) => Err((self, e)), - Ok(_) => Ok(self.0), - } - } - - pub fn logout(&self, client: &T) -> http::Result<()> { - self.0.logout(client) - } - - pub async fn logout_async(&self, client: &T) -> http::Result<()> { - self.0.logout_async(client).await + pub fn logout(&self) -> impl Sequence { + self.0.logout() } } diff --git a/src/domain/human_verification.rs b/src/domain/human_verification.rs index cc9866c..995e473 100644 --- a/src/domain/human_verification.rs +++ b/src/domain/human_verification.rs @@ -36,7 +36,7 @@ impl std::fmt::Display for HumanVerificationType { } /// Human Verification data required for Login. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct HumanVerificationLoginData { /// Type of human verification where the code originated from. pub hv_type: HumanVerificationType, diff --git a/src/http/client.rs b/src/http/client.rs new file mode 100644 index 0000000..7ca8d98 --- /dev/null +++ b/src/http/client.rs @@ -0,0 +1,157 @@ +use crate::http::{Proxy, RequestData, Result, DEFAULT_APP_VERSION, DEFAULT_HOST_URL}; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; +use std::time::Duration; + +/// Builder for an http client +#[derive(Debug, Clone)] +pub struct ClientBuilder { + pub(super) app_version: String, + pub(super) base_url: String, + pub(super) request_timeout: Option, + pub(super) connect_timeout: Option, + pub(super) user_agent: String, + pub(super) proxy_url: Option, + pub(super) debug: bool, + pub(super) allow_http: bool, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + pub fn new() -> Self { + Self { + app_version: DEFAULT_APP_VERSION.to_string(), + user_agent: "NoClient/0.1.0".to_string(), + base_url: DEFAULT_HOST_URL.to_string(), + request_timeout: None, + connect_timeout: None, + proxy_url: None, + debug: false, + allow_http: false, + } + } + + /// Set the app version for this client e.g.: my-client@1.4.0+beta. + /// Note: The default app version is not guaranteed to be accepted by the proton servers. + pub fn app_version(mut self, version: &str) -> Self { + self.app_version = version.to_string(); + self + } + + /// Set the user agent to be submitted with every request. + pub fn user_agent(mut self, agent: &str) -> Self { + self.user_agent = agent.to_string(); + self + } + + /// Set server's base url. By default the proton API server url is used. + pub fn base_url(mut self, url: &str) -> Self { + self.base_url = url.to_string(); + self + } + + /// Set the full request timeout. By default there is no timeout. + pub fn request_timeout(mut self, duration: Duration) -> Self { + self.request_timeout = Some(duration); + self + } + + /// Set the connection timeout. By default there is no timeout. + pub fn connect_timeout(mut self, duration: Duration) -> Self { + self.connect_timeout = Some(duration); + self + } + + /// Specify proxy URL for the builder. + pub fn with_proxy(mut self, proxy: Proxy) -> Self { + self.proxy_url = Some(proxy); + self + } + + /// Allow http request + pub fn allow_http(mut self) -> Self { + self.allow_http = true; + self + } + + /// Enable request debugging. + pub fn debug(mut self) -> Self { + self.debug = true; + self + } + + pub fn build>( + self, + ) -> std::result::Result { + T::try_from(self) + } +} + +pub trait ClientRequest: Sized { + fn header(self, key: impl AsRef, value: impl AsRef) -> Self; + + fn bearer_token(self, token: impl AsRef) -> Self { + self.header("authorization", format!("Bearer {}", token.as_ref())) + } +} + +pub trait ClientRequestBuilder { + type Request: ClientRequest; + fn new_request(&self, data: &RequestData) -> Self::Request; +} + +/// HTTP Client abstraction Sync. +pub trait ClientSync: ClientRequestBuilder + TryFrom { + fn execute(&self, request: Self::Request) -> Result; +} + +/// HTTP Client abstraction Async. +pub trait ClientAsync: + ClientRequestBuilder + TryFrom +{ + #[cfg(not(feature = "async-traits"))] + fn execute_async( + &self, + request: Self::Request, + ) -> Pin> + '_>>; + + #[cfg(feature = "async-traits")] + async fn execute_async(&self, request: Self::Request) -> Result; +} + +pub trait ResponseBodySync { + type Body: AsRef<[u8]>; + fn get_body(self) -> Result; +} + +pub trait ResponseBodyAsync { + type Body: AsRef<[u8]>; + + #[cfg(not(feature = "async-traits"))] + fn get_body_async(self) -> Pin>>>; + + #[cfg(feature = "async-traits")] + async fn get_body_async(self) -> Result; +} + +pub trait FromResponse { + type Output; + fn from_response_sync(response: T) -> Result; + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: T, + ) -> Pin>>>; + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: T, + ) -> Result; +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 8971d15..36ec496 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,15 +1,7 @@ //! Basic HTTP Protocol abstraction for the Proton API. -use crate::domain::SecretString; use anyhow; -use secrecy::ExposeSecret; -use serde::de::DeserializeOwned; -use serde::Serialize; -use std::collections::HashMap; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::time::Duration; +use std::fmt::Debug; use thiserror::Error; #[cfg(feature = "http-ureq")] @@ -18,6 +10,18 @@ pub mod ureq_client; #[cfg(feature = "http-reqwest")] pub mod reqwest_client; +mod client; +mod proxy; +mod request; +mod response; +mod sequence; + +pub use client::*; +pub use proxy::*; +pub use request::*; +pub use response::*; +pub use sequence::*; + pub(crate) const DEFAULT_HOST_URL: &str = "https://mail.proton.me/api"; pub(crate) const DEFAULT_APP_VERSION: &str = "proton-api-rs"; #[allow(unused)] // it is used by the http implementations @@ -36,52 +40,6 @@ pub enum Method { Patch, } -/// HTTP Request representation. -#[derive(Debug)] -pub struct RequestData { - #[allow(unused)] // Only used by http implementations. - pub(super) method: Method, - #[allow(unused)] // Only used by http implementations. - pub(super) url: String, - pub(super) headers: HashMap, - pub(super) body: Option>, -} - -impl RequestData { - pub fn new(method: Method, url: impl Into) -> Self { - Self { - method, - url: url.into(), - headers: HashMap::new(), - body: None, - } - } - - pub fn header(mut self, key: impl Into, value: impl Into) -> Self { - self.headers.insert(key.into(), value.into()); - self - } - - pub fn bearer_token(self, token: &str) -> Self { - self.header("authorization", format!("Bearer {token}")) - } - - pub fn bytes(mut self, bytes: Vec) -> Self { - self.body = Some(bytes); - self - } - - pub fn json(self, value: impl Serialize) -> Self { - let bytes = serde_json::to_vec(&value).expect("Failed to serialize json"); - self.json_bytes(bytes) - } - - pub fn json_bytes(mut self, bytes: Vec) -> Self { - self.body = Some(bytes); - self.header("Content-Type", "application/json") - } -} - /// Errors that may occur during an HTTP request, mostly related to network. #[derive(Debug, Error)] pub enum Error { @@ -108,265 +66,3 @@ impl From for Error { } pub type Result = std::result::Result; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum ProxyProtocol { - Https, - Socks5, -} - -#[derive(Debug, Clone)] -pub struct ProxyAuth { - pub username: String, - pub password: SecretString, -} - -#[derive(Debug, Clone)] -pub struct Proxy { - pub protocol: ProxyProtocol, - pub auth: Option, - pub url: String, - pub port: u16, -} - -impl Proxy { - pub fn as_url(&self) -> String { - let protocol = match self.protocol { - ProxyProtocol::Https => "https", - ProxyProtocol::Socks5 => "socks5", - }; - - let auth = if let Some(auth) = &self.auth { - format!("{}:{}@", auth.username, auth.password.expose_secret()) - } else { - String::new() - }; - - format!("{protocol}://{auth}{}:{}", self.url, self.port) - } -} - -/// Builder for an http client -#[derive(Debug, Clone)] -pub struct ClientBuilder { - app_version: String, - base_url: String, - request_timeout: Option, - connect_timeout: Option, - user_agent: String, - proxy_url: Option, - debug: bool, - allow_http: bool, -} - -impl Default for ClientBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ClientBuilder { - pub fn new() -> Self { - Self { - app_version: DEFAULT_APP_VERSION.to_string(), - user_agent: "NoClient/0.1.0".to_string(), - base_url: DEFAULT_HOST_URL.to_string(), - request_timeout: None, - connect_timeout: None, - proxy_url: None, - debug: false, - allow_http: false, - } - } - - /// Set the app version for this client e.g.: my-client@1.4.0+beta. - /// Note: The default app version is not guaranteed to be accepted by the proton servers. - pub fn app_version(mut self, version: &str) -> Self { - self.app_version = version.to_string(); - self - } - - /// Set the user agent to be submitted with every request. - pub fn user_agent(mut self, agent: &str) -> Self { - self.user_agent = agent.to_string(); - self - } - - /// Set server's base url. By default the proton API server url is used. - pub fn base_url(mut self, url: &str) -> Self { - self.base_url = url.to_string(); - self - } - - /// Set the full request timeout. By default there is no timeout. - pub fn request_timeout(mut self, duration: Duration) -> Self { - self.request_timeout = Some(duration); - self - } - - /// Set the connection timeout. By default there is no timeout. - pub fn connect_timeout(mut self, duration: Duration) -> Self { - self.connect_timeout = Some(duration); - self - } - - /// Specify proxy URL for the builder. - pub fn with_proxy(mut self, proxy: Proxy) -> Self { - self.proxy_url = Some(proxy); - self - } - - /// Allow http request - pub fn allow_http(mut self) -> Self { - self.allow_http = true; - self - } - - /// Enable request debugging. - pub fn debug(mut self) -> Self { - self.debug = true; - self - } - - pub fn build>( - self, - ) -> std::result::Result { - T::try_from(self) - } -} - -/// Abstraction for request creation, this can enable wrapping of request creations to add -/// session token or other headers. -pub trait RequestFactory { - fn new_request(&self, method: Method, url: &str) -> RequestData; -} - -/// Default request factory, creates basic requests. -#[derive(Copy, Clone)] -pub struct DefaultRequestFactory {} - -impl RequestFactory for DefaultRequestFactory { - fn new_request(&self, method: Method, url: &str) -> RequestData { - RequestData::new(method, url) - } -} - -pub trait ResponseBodySync { - type Body: AsRef<[u8]>; - fn get_body(self) -> Result; -} - -pub trait ResponseBodyAsync { - type Body: AsRef<[u8]>; - fn get_body_async(self) -> Pin>>>; -} - -pub trait FromResponse { - type Output; - fn from_response_sync(response: T) -> Result; - - fn from_response_async( - response: T, - ) -> Pin>>>; -} - -#[derive(Copy, Clone)] -pub struct NoResponse {} - -impl FromResponse for NoResponse { - type Output = (); - - fn from_response_sync(_: T) -> Result { - Ok(()) - } - - fn from_response_async( - _: T, - ) -> Pin>>> { - Box::pin(async { Ok(()) }) - } -} - -pub struct JsonResponse(PhantomData); - -impl FromResponse for JsonResponse { - type Output = T; - - fn from_response_sync(response: R) -> Result { - let body = response.get_body()?; - let r = serde_json::from_slice(body.as_ref())?; - Ok(r) - } - - fn from_response_async( - response: R, - ) -> Pin>>> { - Box::pin(async move { - let body = response.get_body_async().await?; - let r = serde_json::from_slice(body.as_ref())?; - Ok(r) - }) - } -} - -#[derive(Copy, Clone)] -pub struct StringResponse {} - -impl FromResponse for StringResponse { - type Output = String; - - fn from_response_sync(response: R) -> Result { - let body = response.get_body()?; - Ok(String::from_utf8_lossy(body.as_ref()).to_string()) - } - - fn from_response_async( - response: R, - ) -> Pin>>> { - Box::pin(async move { - let body = response.get_body_async().await?; - Ok(String::from_utf8_lossy(body.as_ref()).to_string()) - }) - } -} - -pub trait Request { - type Output: Sized; - type Response: FromResponse; - - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData; - - fn execute_sync( - &self, - client: &T, - factory: &dyn RequestFactory, - ) -> Result { - client.execute(self, factory) - } - - fn execute_async( - &self, - client: &T, - factory: &dyn RequestFactory, - ) -> Pin>>> { - client.execute_async(self, factory) - } -} - -/// HTTP Client abstraction Sync. -pub trait ClientSync: TryFrom { - fn execute( - &self, - request: &R, - factory: &dyn RequestFactory, - ) -> Result; -} - -/// HTTP Client abstraction Async. -pub trait ClientAsync: TryFrom { - fn execute_async( - &self, - request: &R, - factory: &dyn RequestFactory, - ) -> Pin>>>; -} diff --git a/src/http/proxy.rs b/src/http/proxy.rs new file mode 100644 index 0000000..1804d31 --- /dev/null +++ b/src/http/proxy.rs @@ -0,0 +1,39 @@ +use crate::domain::SecretString; +use secrecy::ExposeSecret; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ProxyProtocol { + Https, + Socks5, +} + +#[derive(Debug, Clone)] +pub struct ProxyAuth { + pub username: String, + pub password: SecretString, +} + +#[derive(Debug, Clone)] +pub struct Proxy { + pub protocol: ProxyProtocol, + pub auth: Option, + pub url: String, + pub port: u16, +} + +impl Proxy { + pub fn as_url(&self) -> String { + let protocol = match self.protocol { + ProxyProtocol::Https => "https", + ProxyProtocol::Socks5 => "socks5", + }; + + let auth = if let Some(auth) = &self.auth { + format!("{}:{}@", auth.username, auth.password.expose_secret()) + } else { + String::new() + }; + + format!("{protocol}://{auth}{}:{}", self.url, self.port) + } +} diff --git a/src/http/request.rs b/src/http/request.rs new file mode 100644 index 0000000..7226f16 --- /dev/null +++ b/src/http/request.rs @@ -0,0 +1,112 @@ +use crate::http::{ClientAsync, ClientRequestBuilder, ClientSync, Error, FromResponse, Method}; +use bytes::Bytes; +use serde::Serialize; +use std::collections::HashMap; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +use std::marker::PhantomData; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +/// HTTP Request representation. +#[derive(Debug)] +pub struct RequestData { + #[allow(unused)] // Only used by http implementations. + pub(super) method: Method, + #[allow(unused)] // Only used by http implementations. + pub(super) url: String, + pub(super) headers: HashMap, + pub(super) body: Option, +} + +impl RequestData { + pub fn new(method: Method, url: impl Into) -> Self { + Self { + method, + url: url.into(), + headers: HashMap::new(), + body: None, + } + } + + pub fn header(mut self, key: impl Into, value: impl Into) -> Self { + self.headers.insert(key.into(), value.into()); + self + } + + pub fn bearer_token(self, token: &str) -> Self { + self.header("authorization", format!("Bearer {token}")) + } + + pub fn bytes(mut self, bytes: impl Into) -> Self { + self.body = Some(bytes.into()); + self + } + + pub fn json(self, value: impl Serialize) -> Self { + let bytes = serde_json::to_vec(&value).expect("Failed to serialize json"); + self.json_bytes(bytes) + } + + pub fn json_bytes(mut self, bytes: impl Into) -> Self { + self.body = Some(bytes.into()); + self.header("Content-Type", "application/json") + } +} + +pub trait RequestDesc { + type Output: Sized; + type Response: FromResponse; + + fn build(&self) -> RequestData; + + fn to_request(&self) -> RequestWrapper { + let data = self.build(); + RequestWrapper(data, PhantomData) + } +} + +pub struct RequestWrapper(RequestData, PhantomData); + +impl Request for RequestWrapper { + type Response = F; + + fn build(&self, builder: &C) -> C::Request { + builder.new_request(&self.0) + } +} + +#[cfg(not(feature = "async-traits"))] +type RequestFuture<'a, F> = + Pin::Output, Error>> + 'a>>; + +pub trait Request { + type Response: FromResponse; + + fn build(&self, builder: &C) -> C::Request; + + fn exec_sync( + &self, + client: &T, + ) -> Result<::Output, Error> { + client.execute::(self.build(client)) + } + + #[cfg(not(feature = "async-traits"))] + fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> RequestFuture<'a, Self::Response> { + let v = self.build(client); + Box::pin(async move { client.execute_async::(v).await }) + } + + #[cfg(feature = "async-traits")] + async fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Result<::Output, Error> { + let v = self.build(client); + client.execute_async::(v).await + } +} diff --git a/src/http/reqwest_client.rs b/src/http/reqwest_client.rs index fb6d150..7383b2b 100644 --- a/src/http/reqwest_client.rs +++ b/src/http/reqwest_client.rs @@ -1,11 +1,14 @@ use crate::http::{ - ClientAsync, ClientBuilder, Error, FromResponse, Method, Request, RequestFactory, - ResponseBodyAsync, X_PM_APP_VERSION_HEADER, + ClientAsync, ClientBuilder, ClientRequest, ClientRequestBuilder, Error, FromResponse, Method, + RequestData, ResponseBodyAsync, X_PM_APP_VERSION_HEADER, }; use crate::requests::APIError; use bytes::Bytes; use reqwest; + +#[cfg(not(feature = "async-traits"))] use std::future::Future; +#[cfg(not(feature = "async-traits"))] use std::pin::Pin; #[derive(Debug)] @@ -43,7 +46,7 @@ impl TryFrom for ReqwestClient { builder = builder .min_tls_version(Version::TLS_1_2) - .https_only(true) + .https_only(!value.allow_http) .cookie_store(true) .user_agent(value.user_agent) .default_headers(header_map); @@ -87,28 +90,39 @@ impl From for Error { struct ReqwestResponse(reqwest::Response); +pub struct ReqwestRequest(reqwest::RequestBuilder); + +impl ClientRequest for ReqwestRequest { + fn header(self, key: impl AsRef, value: impl AsRef) -> Self { + Self(self.0.header(key.as_ref(), value.as_ref())) + } +} + impl ResponseBodyAsync for ReqwestResponse { type Body = Bytes; + #[cfg(not(feature = "async-traits"))] fn get_body_async(self) -> Pin>>> { Box::pin(async { let bytes = self.0.bytes().await?; Ok(bytes) }) } + + #[cfg(feature = "async-traits")] + async fn get_body_async(self) -> crate::http::Result { + let bytes = self.0.bytes().await?; + Ok(bytes) + } } -impl ClientAsync for ReqwestClient { - fn execute_async( - &self, - r: &R, - factory: &dyn RequestFactory, - ) -> Pin>>> { - let request = r.build_request(factory); +impl ClientRequestBuilder for ReqwestClient { + type Request = ReqwestRequest; - let final_url = format!("{}/{}", self.base_url, request.url); + fn new_request(&self, data: &RequestData) -> Self::Request { + let final_url = format!("{}/{}", self.base_url, data.url); - let mut rrequest = match request.method { + let mut request = match data.method { Method::Delete => self.client.delete(&final_url), Method::Get => self.client.get(&final_url), Method::Put => self.client.put(&final_url), @@ -117,32 +131,57 @@ impl ClientAsync for ReqwestClient { }; // Set headers. - for (header, value) in &request.headers { - rrequest = rrequest.header(header, value); + for (header, value) in &data.headers { + request = request.header(header, value); } - if let Some(body) = &request.body { - rrequest = rrequest.body(body.to_vec()) + if let Some(body) = &data.body { + request = request.body(body.clone()) } - Box::pin(async move { - let response = rrequest.send().await?; + ReqwestRequest(request) + } +} - let status = response.status().as_u16(); +impl ReqwestClient { + pub async fn direct_exec( + &self, + r: ReqwestRequest, + ) -> crate::http::Result { + let response = r.0.send().await?; + + let status = response.status().as_u16(); + + if status >= 400 { + let body = response + .bytes() + .await + .map_err(|_| Error::API(APIError::new(status)))?; + + return Err(Error::API(APIError::with_status_and_body( + status, + body.as_ref(), + ))); + } - if status >= 400 { - let body = response - .bytes() - .await - .map_err(|_| Error::API(APIError::new(status)))?; + R::from_response_async(ReqwestResponse(response)).await + } +} - return Err(Error::API(APIError::with_status_and_body( - status, - body.as_ref(), - ))); - } +impl ClientAsync for ReqwestClient { + #[cfg(not(feature = "async-traits"))] + fn execute_async( + &self, + r: Self::Request, + ) -> Pin> + '_>> { + Box::pin(async move { self.direct_exec::(r).await }) + } - R::Response::from_response_async(ReqwestResponse(response)).await - }) + #[cfg(feature = "async-traits")] + async fn execute_async( + &self, + request: Self::Request, + ) -> crate::http::Result { + self.direct_exec::(request).await } } diff --git a/src/http/response.rs b/src/http/response.rs new file mode 100644 index 0000000..ecc9650 --- /dev/null +++ b/src/http/response.rs @@ -0,0 +1,92 @@ +use crate::http::{FromResponse, ResponseBodyAsync, ResponseBodySync, Result}; +use serde::de::DeserializeOwned; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +use std::marker::PhantomData; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +#[derive(Copy, Clone)] +pub struct NoResponse {} + +impl FromResponse for NoResponse { + type Output = (); + + fn from_response_sync(_: T) -> Result { + Ok(()) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + _: T, + ) -> Pin>>> { + Box::pin(async { Ok(()) }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async(_: T) -> Result { + Ok(()) + } +} + +pub struct JsonResponse(PhantomData); + +impl FromResponse for JsonResponse { + type Output = T; + + fn from_response_sync(response: R) -> Result { + let body = response.get_body()?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: R, + ) -> Pin>>> { + Box::pin(async move { + let body = response.get_body_async().await?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: R, + ) -> Result { + let body = response.get_body_async().await?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + } +} + +#[derive(Copy, Clone)] +pub struct StringResponse {} + +impl FromResponse for StringResponse { + type Output = String; + + fn from_response_sync(response: R) -> Result { + let body = response.get_body()?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: R, + ) -> Pin>>> { + Box::pin(async move { + let body = response.get_body_async().await?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: R, + ) -> Result { + let body = response.get_body_async().await?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + } +} diff --git a/src/http/sequence.rs b/src/http/sequence.rs new file mode 100644 index 0000000..f441ac7 --- /dev/null +++ b/src/http/sequence.rs @@ -0,0 +1,275 @@ +use crate::http::{ClientAsync, ClientSync, Error, FromResponse, Request}; +use std::fmt::Debug; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +#[cfg(not(feature = "async-traits"))] +type SequenceFuture<'a, O, E> = Pin> + 'a>>; + +/// Trait which can be use to link a sequence of request operations. +pub trait Sequence<'a> { + type Output: 'a; + type Error: From + Debug; + + fn do_sync(self, client: &T) -> Result; + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> SequenceFuture<'a, Self::Output, Self::Error>; + + #[cfg(feature = "async-traits")] + async fn do_async(self, client: &'a T) -> Result; + + fn map Result>(self, f: F) -> MapSequence + where + Self: Sized, + E: From + From + Debug, + { + MapSequence { c: self, f } + } + + fn state(self, f: F) -> StateSequence + where + Self: Sized, + SS: Sequence<'a>, + F: FnOnce(Self::Output) -> SS, + >::Error: From + From + Debug, + { + StateSequence { seq: self, f } + } + + fn chain(self, f: F) -> SequenceChain + where + SS: Sequence<'a>, + F: FnOnce(Self::Output) -> Result, + E: From + Debug, + >::Error: From + From + Debug, + Self: Sized, + { + SequenceChain { s: self, f } + } +} + +impl<'a, R: Request + 'a> Sequence<'a> for R +where + ::Output: 'a, +{ + type Output = ::Output; + type Error = Error; + + fn do_sync(self, client: &T) -> Result { + self.exec_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { self.exec_async(client).await }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result<>::Output, >::Error> { + self.exec_async(client).await + } +} + +#[doc(hidden)] +pub struct MapSequence { + c: C, + f: F, +} + +impl<'a, C, O, E, F> Sequence<'a> for MapSequence +where + O: 'a, + C: Sequence<'a> + 'a, + F: FnOnce(C::Output) -> Result + 'a, + E: From + Debug + From, +{ + type Output = O; + type Error = E; + + fn do_sync(self, client: &T) -> Result { + let v = self.c.do_sync(client)?; + let r = (self.f)(v)?; + Ok(r) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let v = self.c.do_async(client).await?; + let r = (self.f)(v)?; + Ok(r) + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let v = self.c.do_async(client).await?; + let r = (self.f)(v)?; + Ok(r) + } +} + +#[doc(hidden)] +pub struct StateSequence { + seq: S, + f: F, +} + +impl<'a, S, SS, F> Sequence<'a> for StateSequence +where + S: Sequence<'a> + 'a, + SS: Sequence<'a>, + >::Error: From<>::Error> + From + Debug, + F: FnOnce(S::Output) -> SS + 'a, +{ + type Output = SS::Output; + type Error = SS::Error; + + fn do_sync(self, client: &T) -> Result { + let state = self.seq.do_sync(client)?; + let ss = (self.f)(state); + ss.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let state = self.seq.do_async(client).await?; + let ss = (self.f)(state); + ss.do_async(client).await + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let state = self.seq.do_async(client).await?; + let ss = (self.f)(state); + ss.do_async(client).await + } +} + +#[doc(hidden)] +pub struct StateProducerSequence { + s: S, + f: F, +} + +impl StateProducerSequence { + pub fn new(s: S, f: F) -> Self { + Self { s, f } + } +} + +impl<'a, Seq, S, F> Sequence<'a> for StateProducerSequence +where + Seq: Sequence<'a>, + F: FnOnce(S) -> Seq, +{ + type Output = Seq::Output; + type Error = Seq::Error; + + fn do_sync(self, client: &T) -> Result { + let seq = (self.f)(self.s); + seq.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + let seq = (self.f)(self.s); + seq.do_async(client) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let seq = (self.f)(self.s); + seq.do_async(client).await + } +} + +#[doc(hidden)] +pub struct SequenceChain { + s: S, + f: F, +} + +impl<'a, SS, S, E, F> Sequence<'a> for SequenceChain +where + SS: Sequence<'a>, + S: Sequence<'a> + 'a, + F: FnOnce(S::Output) -> Result + 'a, + E: From + Debug, + >::Error: From + From + Debug, +{ + type Output = SS::Output; + type Error = SS::Error; + + fn do_sync(self, client: &T) -> Result { + let v = self.s.do_sync(client)?; + let ss = (self.f)(v)?; + ss.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let v = self.s.do_async(client).await?; + let ss = (self.f)(v)?; + ss.do_async(client).await + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let v = self.s.do_async(client).await?; + let ss = (self.f)(v)?; + ss.do_async(client).await + } +} diff --git a/src/http/ureq_client.rs b/src/http/ureq_client.rs index 6d5873e..48efc8b 100644 --- a/src/http/ureq_client.rs +++ b/src/http/ureq_client.rs @@ -1,7 +1,10 @@ //! UReq HTTP client implementation. -use crate::http::{ClientBuilder, ClientSync, Error, FromResponse, Method, ResponseBodySync}; -use crate::http::{Request, RequestFactory, X_PM_APP_VERSION_HEADER}; +use crate::http::X_PM_APP_VERSION_HEADER; +use crate::http::{ + ClientBuilder, ClientRequest, ClientRequestBuilder, ClientSync, Error, FromResponse, Method, + RequestData, ResponseBodySync, +}; use crate::requests::APIError; use log::debug; use std::io; @@ -116,13 +119,22 @@ impl ResponseBodySync for UReqDebugResponse { } } -impl ClientSync for UReqClient { - fn execute( - &self, - r: &R, - factory: &dyn RequestFactory, - ) -> Result { - let request = r.build_request(factory); +pub struct UReqRequest { + request: ureq::Request, + body: Option, +} + +impl ClientRequest for UReqRequest { + fn header(mut self, key: impl AsRef, value: impl AsRef) -> Self { + self.request = self.request.set(key.as_ref(), value.as_ref()); + self + } +} + +impl ClientRequestBuilder for UReqClient { + type Request = UReqRequest; + + fn new_request(&self, request: &RequestData) -> Self::Request { let final_url = format!("{}/{}", self.base_url, request.url); let mut ureq_request = match request.method { Method::Delete => self.agent.delete(&final_url), @@ -140,16 +152,25 @@ impl ClientSync for UReqClient { ureq_request = ureq_request.set(header, value); } - let ureq_response = if let Some(body) = &request.body { - ureq_request.send_bytes(body)? + Self::Request { + request: ureq_request, + body: request.body.clone(), + } + } +} + +impl ClientSync for UReqClient { + fn execute(&self, request: Self::Request) -> Result { + let ureq_response = if let Some(body) = request.body { + request.request.send_bytes(body.as_ref())? } else { - ureq_request.call()? + request.request.call()? }; if !self.debug { - R::Response::from_response_sync(UReqResponse(ureq_response)) + R::from_response_sync(UReqResponse(ureq_response)) } else { - R::Response::from_response_sync(UReqDebugResponse(ureq_response)) + R::from_response_sync(UReqDebugResponse(ureq_response)) } } } diff --git a/src/lib.rs b/src/lib.rs index 7c2d9b6..2097069 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "async-traits", allow(incomplete_features))] +#![cfg_attr(feature = "async-traits", feature(async_fn_in_trait))] // Enable clippy if our Cargo.toml file asked us to do so. #![cfg_attr(feature = "clippy", feature(plugin))] #![cfg_attr(feature = "clippy", plugin(clippy))] @@ -38,7 +40,8 @@ //! //! Login into a new session async: //! ``` -//! use proton_api_rs::{http, Session, SessionType}; +//! use proton_api_rs::{http, Session, SessionType, http::Sequence}; +//! use proton_api_rs::domain::SecretString; //! async fn example() { //! let client = http::ClientBuilder::new() //! .user_agent("MyUserAgent/0.0.0") @@ -46,25 +49,26 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = match Session::login_async(&client, "my_address@proton.me", "my_proton_password", None, None).await.unwrap(){ +//! let session = match Session::login(&"my_address@proton.me", &SecretString::new("my_proton_password".into()), None).do_async(&client).await.unwrap(){ //! // Session is authenticated, no 2FA verifications necessary. //! SessionType::Authenticated(c) => c, //! // Session needs 2FA TOTP auth. //! SessionType::AwaitingTotp(t) => { -//! t.submit_totp_async(&client, "000000").await.unwrap() +//! t.submit_totp("000000").do_async(&client).await.unwrap() //! } //! }; //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout_async(&client).await.unwrap(); +//! session.logout().do_async(&client).await.unwrap(); //! } //! ``` //! //! Login into a new session sync: //! ``` -//! use proton_api_rs::{Session, http, SessionType}; +//! use proton_api_rs::{Session, http, SessionType, http::Sequence}; +//! use proton_api_rs::domain::SecretString; //! fn example() { //! let client = http::ClientBuilder::new() //! .user_agent("MyUserAgent/0.0.0") @@ -72,25 +76,25 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = match Session::login(&client, "my_address@proton.me", "my_proton_password", None, None).unwrap(){ +//! let session = match Session::login("my_address@proton.me", &SecretString::new("my_proton_password".into()), None).do_sync(&client).unwrap(){ //! // Session is authenticated, no 2FA verifications necessary. //! SessionType::Authenticated(c) => c, //! // Session needs 2FA TOTP auth. //! SessionType::AwaitingTotp(t) => { -//! t.submit_totp(&client, "000000").unwrap() +//! t.submit_totp("000000").do_sync(&client).unwrap() //! } //! }; //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout(&client).unwrap(); +//! session.logout().do_sync(&client).unwrap(); //! } //! ``` //! //! Login using a previous sessions token. //! ``` -//! use proton_api_rs::{http, Session, SessionType}; +//! use proton_api_rs::{http, Session, SessionType, http::Sequence}; //! use proton_api_rs::domain::UserUid; //! //! async fn example() { @@ -102,12 +106,12 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = Session::refresh_async(&client, &user_uid, &user_refresh_token, None).await.unwrap(); +//! let session = Session::refresh(&user_uid, &user_refresh_token).do_async(&client).await.unwrap(); //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout_async(&client).await.unwrap(); +//! session.logout().do_async(&client).await.unwrap(); //! } //! ``` diff --git a/src/requests/auth.rs b/src/requests/auth.rs index 8db3e8c..060c234 100644 --- a/src/requests/auth.rs +++ b/src/requests/auth.rs @@ -1,8 +1,6 @@ use crate::domain::{HumanVerificationLoginData, SecretString, UserUid}; use crate::http; -use crate::http::{ - RequestData, RequestFactory, X_PM_HUMAN_VERIFICATION_TOKEN, X_PM_HUMAN_VERIFICATION_TOKEN_TYPE, -}; +use crate::http::{RequestData, X_PM_HUMAN_VERIFICATION_TOKEN, X_PM_HUMAN_VERIFICATION_TOKEN_TYPE}; use secrecy::Secret; use serde::{Deserialize, Serialize}; use serde_repr::Deserialize_repr; @@ -15,27 +13,25 @@ pub struct AuthInfoRequest<'a> { pub username: &'a str, } -impl<'a> http::Request for AuthInfoRequest<'a> { - type Output = AuthInfoResponse<'a>; +impl<'a> http::RequestDesc for AuthInfoRequest<'a> { + type Output = AuthInfoResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/info") - .json(self) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/info").json(self) } } #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthInfoResponse<'a> { +pub struct AuthInfoResponse { pub version: i64, - pub modulus: Cow<'a, str>, - pub server_ephemeral: Cow<'a, str>, - pub salt: Cow<'a, str>, + pub modulus: String, + pub server_ephemeral: String, + pub salt: String, #[serde(rename = "SRPSession")] - pub srp_session: Cow<'a, str>, + pub srp_session: String, } #[doc(hidden)] @@ -48,17 +44,15 @@ pub struct AuthRequest<'a> { #[serde(rename = "SRPSession")] pub srp_session: &'a str, #[serde(skip)] - pub human_verification: Option, + pub human_verification: &'a Option, } -impl<'a> http::Request for AuthRequest<'a> { - type Output = AuthResponse<'a>; +impl<'a> http::RequestDesc for AuthRequest<'a> { + type Output = AuthResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - let mut request = factory - .new_request(http::Method::Post, "auth/v4") - .json(self); + fn build(&self) -> RequestData { + let mut request = RequestData::new(http::Method::Post, "auth/v4").json(self); if let Some(hv) = &self.human_verification { // repeat submission with x-pm-human-verification-token and x-pm-human-verification-token-type @@ -74,18 +68,18 @@ impl<'a> http::Request for AuthRequest<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthResponse<'a> { +pub struct AuthResponse { #[serde(rename = "UserID")] - pub user_id: Cow<'a, str>, + pub user_id: String, #[serde(rename = "UID")] - pub uid: Cow<'a, str>, - pub token_type: Option>, - pub access_token: Cow<'a, str>, - pub refresh_token: Cow<'a, str>, - pub server_proof: Cow<'a, str>, - pub scope: Cow<'a, str>, + pub uid: String, + pub token_type: Option, + pub access_token: String, + pub refresh_token: String, + pub server_proof: String, + pub scope: String, #[serde(rename = "2FA")] - pub tfa: TFAInfo<'a>, + pub tfa: TFAInfo, pub password_mode: PasswordMode, } @@ -110,10 +104,10 @@ pub enum TFAStatus { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct TFAInfo<'a> { +pub struct TFAInfo { pub enabled: TFAStatus, #[serde(rename = "FIDO2")] - pub fido2_info: FIDO2Info<'a>, + pub fido2_info: FIDO2Info, } #[doc(hidden)] @@ -129,9 +123,9 @@ pub struct FIDOKey<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct FIDO2Info<'a> { +pub struct FIDO2Info { pub authentication_options: serde_json::Value, - pub registered_keys: Option>>, + pub registered_keys: Option, } #[doc(hidden)] @@ -177,17 +171,15 @@ impl<'a> TOTPRequest<'a> { } } -impl<'a> http::Request for TOTPRequest<'a> { +impl<'a> http::RequestDesc for TOTPRequest<'a> { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/2fa") - .json(TFAAuth { - two_factor_code: self.code, - fido2: FIDO2Auth::empty(), - }) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/2fa").json(TFAAuth { + two_factor_code: self.code, + fido2: FIDO2Auth::empty(), + }) } } @@ -200,19 +192,19 @@ pub struct UserAuth { } impl UserAuth { - pub fn from_auth_response(auth: &AuthResponse) -> Self { + pub fn from_auth_response(auth: AuthResponse) -> Self { Self { - uid: Secret::new(UserUid(auth.uid.to_string())), - access_token: SecretString::new(auth.access_token.to_string()), - refresh_token: SecretString::new(auth.refresh_token.to_string()), + uid: Secret::new(UserUid(auth.uid)), + access_token: SecretString::new(auth.access_token), + refresh_token: SecretString::new(auth.refresh_token), } } - pub fn from_auth_refresh_response(auth: &AuthRefreshResponse) -> Self { + pub fn from_auth_refresh_response(auth: AuthRefreshResponse) -> Self { Self { - uid: Secret::new(UserUid(auth.uid.to_string())), - access_token: SecretString::new(auth.access_token.to_string()), - refresh_token: SecretString::new(auth.refresh_token.to_string()), + uid: Secret::new(UserUid(auth.uid)), + access_token: SecretString::new(auth.access_token), + refresh_token: SecretString::new(auth.refresh_token), } } } @@ -233,13 +225,13 @@ pub struct AuthRefresh<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthRefreshResponse<'a> { +pub struct AuthRefreshResponse { #[serde(rename = "UID")] - pub uid: Cow<'a, str>, - pub token_type: Cow<'a, str>, - pub access_token: Cow<'a, str>, - pub refresh_token: Cow<'a, str>, - pub scope: Cow<'a, str>, + pub uid: String, + pub token_type: Option, + pub access_token: String, + pub refresh_token: String, + pub scope: String, } pub struct AuthRefreshRequest<'a> { @@ -253,31 +245,29 @@ impl<'a> AuthRefreshRequest<'a> { } } -impl<'a> http::Request for AuthRefreshRequest<'a> { - type Output = AuthRefreshResponse<'a>; +impl<'a> http::RequestDesc for AuthRefreshRequest<'a> { + type Output = AuthRefreshResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/refresh") - .json(AuthRefresh { - uid: &self.uid.0, - refresh_token: self.token, - grant_type: "refresh_token", - response_type: "token", - redirect_uri: "https://protonmail.ch/", - }) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/refresh").json(AuthRefresh { + uid: &self.uid.0, + refresh_token: self.token, + grant_type: "refresh_token", + response_type: "token", + redirect_uri: "https://protonmail.ch/", + }) } } pub struct LogoutRequest {} -impl http::Request for LogoutRequest { +impl http::RequestDesc for LogoutRequest { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request(http::Method::Delete, "auth/v4") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Delete, "auth/v4") } } @@ -292,16 +282,17 @@ impl<'a> CaptchaRequest<'a> { } } -impl<'a> http::Request for CaptchaRequest<'a> { +impl<'a> http::RequestDesc for CaptchaRequest<'a> { type Output = String; type Response = http::StringResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { + fn build(&self) -> RequestData { let url = if self.force_web { format!("core/v4/captcha?ForceWebMessaging=1&Token={}", self.token) } else { format!("core/v4/captcha?Token={}", self.token) }; - factory.new_request(http::Method::Get, &url) + + RequestData::new(http::Method::Get, url) } } diff --git a/src/requests/event.rs b/src/requests/event.rs index 7a61f6e..d281f8a 100644 --- a/src/requests/event.rs +++ b/src/requests/event.rs @@ -1,5 +1,5 @@ use crate::http; -use crate::http::{RequestData, RequestFactory}; +use crate::http::RequestData; use serde::Deserialize; #[doc(hidden)] @@ -11,12 +11,12 @@ pub struct LatestEventResponse { pub struct GetLatestEventRequest; -impl http::Request for GetLatestEventRequest { +impl http::RequestDesc for GetLatestEventRequest { type Output = LatestEventResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request(http::Method::Get, "core/v4/events/latest") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "core/v4/events/latest") } } @@ -30,14 +30,14 @@ impl<'a> GetEventRequest<'a> { } } -impl<'a> http::Request for GetEventRequest<'a> { +impl<'a> http::RequestDesc for GetEventRequest<'a> { type Output = crate::domain::Event; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request( + fn build(&self) -> RequestData { + RequestData::new( http::Method::Get, - &format!("core/v4/events/{}", self.event_id), + format!("core/v4/events/{}", self.event_id), ) } } diff --git a/src/requests/tests.rs b/src/requests/tests.rs index 031bae5..806a3ca 100644 --- a/src/requests/tests.rs +++ b/src/requests/tests.rs @@ -1,12 +1,13 @@ use crate::http; +use crate::http::RequestData; pub struct Ping; -impl http::Request for Ping { +impl http::RequestDesc for Ping { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn http::RequestFactory) -> http::RequestData { - factory.new_request(http::Method::Get, "tests/ping") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "tests/ping") } } diff --git a/src/requests/user.rs b/src/requests/user.rs index fa22295..45a9101 100644 --- a/src/requests/user.rs +++ b/src/requests/user.rs @@ -1,6 +1,6 @@ use crate::domain::User; use crate::http; -use crate::http::{JsonResponse, RequestFactory}; +use crate::http::{JsonResponse, RequestData}; use serde::Deserialize; #[derive(Deserialize)] @@ -11,11 +11,11 @@ pub struct UserInfoResponse { pub struct UserInfoRequest {} -impl http::Request for UserInfoRequest { +impl http::RequestDesc for UserInfoRequest { type Output = UserInfoResponse; type Response = JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> http::RequestData { - factory.new_request(http::Method::Get, "core/v4/users") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "core/v4/users") } } diff --git a/tests/session/login.rs b/tests/session/login.rs index 5b37338..92adcce 100644 --- a/tests/session/login.rs +++ b/tests/session/login.rs @@ -1,38 +1,126 @@ -use crate::utils::create_session_and_server; +use crate::utils::{create_session_and_server, ClientASync, ClientSync}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; use proton_api_rs::{http, LoginError, Session, SessionType}; +use secrecy::{ExposeSecret, Secret}; +use tokio; const DEFAULT_USER_EMAIL: &str = "foo@bar.com"; const DEFAULT_USER_PASSWORD: &str = "12345"; #[test] fn session_login() { - let (client, server) = create_session_and_server(); + let (client, server) = create_session_and_server::(); + let (user_id, _) = server .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) .expect("failed to create default user"); let auth_result = Session::login( - &client, DEFAULT_USER_EMAIL, - DEFAULT_USER_PASSWORD, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), None, + ) + .do_sync(&client) + .expect("Failed to login"); + + assert!(matches!(auth_result, SessionType::Authenticated(_))); + + if let SessionType::Authenticated(s) = auth_result { + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + s.logout().do_sync(&client).expect("Failed to logout") + } +} + +#[test] +fn session_login_auto_refresh() { + let (client, server) = create_session_and_server::(); + + let (user_id, _) = server + .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) + .expect("failed to create default user"); + let auth_result = Session::login( + DEFAULT_USER_EMAIL, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), + None, + ) + .do_sync(&client) + .expect("Failed to login"); + + assert!(matches!(auth_result, SessionType::Authenticated(_))); + + if let SessionType::Authenticated(s) = auth_result { + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + let rs = s.get_refresh_data(); + server + .set_auth_timeout(std::time::Duration::from_secs(1)) + .expect("Failed to set timeout"); + std::thread::sleep(std::time::Duration::from_secs(1)); + + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + let rs_post_refresh = s.get_refresh_data(); + + assert_eq!( + rs.user_uid.expose_secret(), + rs_post_refresh.user_uid.expose_secret() + ); + + assert_ne!( + rs.token.expose_secret(), + rs_post_refresh.token.expose_secret() + ); + + s.logout().do_sync(&client).expect("Failed to logout") + } +} + +#[tokio::test()] +async fn session_login_async() { + let (client, server) = create_session_and_server::(); + + let (user_id, _) = server + .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) + .expect("failed to create default user"); + let auth_result = Session::login( + DEFAULT_USER_EMAIL, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), None, ) + .do_async(&client) + .await .expect("Failed to login"); assert!(matches!(auth_result, SessionType::Authenticated(_))); if let SessionType::Authenticated(s) = auth_result { - let user = s.get_user(&client).expect("Failed to get user"); + let user = s + .get_user() + .do_async(&client) + .await + .expect("Failed to get user"); assert_eq!(user.id.as_ref(), user_id.as_ref()); - s.logout(&client).expect("Failed to logout") + s.logout() + .do_async(&client) + .await + .expect("Failed to logout") } } #[test] fn session_login_invalid_user() { - let (client, _server) = create_session_and_server(); - let auth_result = Session::login(&client, "bar", DEFAULT_USER_PASSWORD, None, None); + let (client, _server) = create_session_and_server::(); + let auth_result = Session::login( + "bar", + &SecretString::new(DEFAULT_USER_PASSWORD.into()), + None, + ) + .do_sync(&client); assert!(matches!( auth_result, diff --git a/tests/session/utils.rs b/tests/session/utils.rs index 5263622..f9c521b 100644 --- a/tests/session/utils.rs +++ b/tests/session/utils.rs @@ -4,11 +4,13 @@ use proton_api_rs::http; use proton_api_rs::http::ClientBuilder; use std::sync::OnceLock; -type Client = http::ureq_client::UReqClient; +pub type ClientSync = http::ureq_client::UReqClient; +pub type ClientASync = http::reqwest_client::ReqwestClient; static LOG_CELL: OnceLock<()> = OnceLock::new(); -pub fn create_session_and_server() -> (Client, Server) { +pub fn create_session_and_server>( +) -> (Client, Server) { let debug = if let Ok(v) = std::env::var("RUST_LOG") { if v.eq_ignore_ascii_case("debug") { true