Skip to content
This repository was archived by the owner on Jul 6, 2024. It is now read-only.

Commit 2832de3

Browse files
committed
refactor: Rewrite session policies as request repeater
The behavior is expected for all proton clients, there's no real benefit in leaving this an optional configuration parameter.
1 parent 57e526b commit 2832de3

File tree

8 files changed

+211
-298
lines changed

8 files changed

+211
-298
lines changed

examples/user_id.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use proton_api_rs::SessionType;
2-
use proton_api_rs::{http, ping_async, DefaultSession};
1+
use proton_api_rs::{http, ping_async};
2+
use proton_api_rs::{Session, SessionType};
33
pub use tokio;
44
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
55

@@ -16,7 +16,7 @@ async fn main() {
1616

1717
ping_async(&client).await.unwrap();
1818

19-
let session = match DefaultSession::login_async(&client, &user_email, &user_password)
19+
let session = match Session::login_async(&client, &user_email, &user_password, None)
2020
.await
2121
.unwrap()
2222
{

examples/user_id_sync.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use proton_api_rs::clientv2::{ping, SessionType};
2-
use proton_api_rs::{http, DefaultSession};
2+
use proton_api_rs::{http, Session};
33
use std::io::{BufRead, Write};
44

55
fn main() {
@@ -16,7 +16,7 @@ fn main() {
1616

1717
ping(&client).unwrap();
1818

19-
let session = match DefaultSession::login(&client, &user_email, &user_password).unwrap() {
19+
let session = match Session::login(&client, &user_email, &user_password, None).unwrap() {
2020
SessionType::Authenticated(s) => s,
2121
SessionType::AwaitingTotp(mut t) => {
2222
let mut line_reader = std::io::BufReader::new(std::io::stdin());

src/clientv2/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
mod client;
2+
mod request_repeater;
23
mod session;
3-
mod session_policies;
44
mod totp;
55

66
pub use client::*;
7+
pub use request_repeater::*;
78
pub use session::*;
8-
pub use session_policies::*;
99
pub use totp::*;

src/clientv2/request_repeater.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
//! Automatic request repeater based on the expectations Proton has for their clients.
2+
3+
use crate::domain::UserUid;
4+
use crate::http::{
5+
ClientAsync, ClientSync, DefaultRequestFactory, Method, Request, RequestData, RequestFactory,
6+
};
7+
use crate::requests::{AuthRefreshRequest, UserAuth};
8+
use crate::{http, SessionRefreshData};
9+
use secrecy::ExposeSecret;
10+
11+
pub type OnAuthRefreshedCallback = Box<dyn Fn(&UserUid, &str)>;
12+
13+
pub struct RequestRepeater {
14+
user_auth: parking_lot::RwLock<UserAuth>,
15+
on_auth_refreshed: Option<OnAuthRefreshedCallback>,
16+
}
17+
18+
impl std::fmt::Debug for RequestRepeater {
19+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20+
write!(
21+
f,
22+
"RequestRepeater{{user_auth:{:?} on_auth_refreshed:{}}}",
23+
self.user_auth,
24+
if self.on_auth_refreshed.is_some() {
25+
"Some"
26+
} else {
27+
"None"
28+
}
29+
)
30+
}
31+
}
32+
33+
impl RequestRepeater {
34+
pub fn new(user_auth: UserAuth, on_auth_refreshed: Option<OnAuthRefreshedCallback>) -> Self {
35+
Self {
36+
user_auth: parking_lot::RwLock::new(user_auth),
37+
on_auth_refreshed,
38+
}
39+
}
40+
41+
fn refresh_auth<C: ClientSync>(&self, client: &C) -> http::Result<()> {
42+
let borrow = self.user_auth.read();
43+
match AuthRefreshRequest::new(
44+
borrow.uid.expose_secret(),
45+
borrow.refresh_token.expose_secret(),
46+
)
47+
.execute_sync(client, &DefaultRequestFactory {})
48+
{
49+
Ok(s) => {
50+
let mut borrow = self.user_auth.write();
51+
*borrow = UserAuth::from_auth_refresh_response(&s);
52+
if let Some(cb) = &self.on_auth_refreshed {
53+
(cb)(
54+
borrow.uid.expose_secret(),
55+
borrow.access_token.expose_secret(),
56+
);
57+
}
58+
Ok(())
59+
}
60+
Err(e) => Err(e),
61+
}
62+
}
63+
64+
async fn refresh_auth_async<C: ClientAsync>(&self, client: &C) -> http::Result<()> {
65+
// Have to clone here due to async boundaries.
66+
let user_auth = self.user_auth.read().clone();
67+
match AuthRefreshRequest::new(
68+
user_auth.uid.expose_secret(),
69+
user_auth.refresh_token.expose_secret(),
70+
)
71+
.execute_async(client, &DefaultRequestFactory {})
72+
.await
73+
{
74+
Ok(s) => {
75+
let mut borrow = self.user_auth.write();
76+
*borrow = UserAuth::from_auth_refresh_response(&s);
77+
if let Some(cb) = &self.on_auth_refreshed {
78+
(cb)(
79+
borrow.uid.expose_secret(),
80+
borrow.access_token.expose_secret(),
81+
);
82+
}
83+
Ok(())
84+
}
85+
Err(e) => Err(e),
86+
}
87+
}
88+
89+
pub fn execute<C: ClientSync, R: Request>(
90+
&self,
91+
client: &C,
92+
request: R,
93+
) -> http::Result<R::Output> {
94+
match request.execute_sync(client, self) {
95+
Ok(r) => Ok(r),
96+
Err(original_error) => {
97+
if let http::Error::API(api_err) = &original_error {
98+
if api_err.http_code == 401 {
99+
log::debug!("Account session expired, attempting refresh");
100+
// Session expired/not authorized, try auth refresh.
101+
if let Err(e) = self.refresh_auth(client) {
102+
log::error!("Failed to refresh account {e}");
103+
return Err(original_error);
104+
}
105+
106+
// Execute request again
107+
return request.execute_sync(client, self);
108+
}
109+
}
110+
Err(original_error)
111+
}
112+
}
113+
}
114+
115+
pub async fn execute_async<'a, C: ClientAsync, R: Request + 'a>(
116+
&'a self,
117+
client: &'a C,
118+
request: R,
119+
) -> http::Result<R::Output> {
120+
match request.execute_async(client, self).await {
121+
Ok(r) => Ok(r),
122+
Err(original_error) => {
123+
if let http::Error::API(api_err) = &original_error {
124+
log::debug!("Account session expired, attempting refresh");
125+
if api_err.http_code == 401 {
126+
// Session expired/not authorized, try auth refresh.
127+
if let Err(e) = self.refresh_auth_async(client).await {
128+
log::error!("Failed to refresh account {e}");
129+
return Err(original_error);
130+
}
131+
132+
// Execute request again
133+
return request.execute_async(client, self).await;
134+
}
135+
}
136+
Err(original_error)
137+
}
138+
}
139+
}
140+
141+
pub fn get_refresh_data(&self) -> SessionRefreshData {
142+
let borrow = self.user_auth.read();
143+
SessionRefreshData {
144+
user_uid: borrow.uid.clone(),
145+
token: borrow.refresh_token.clone(),
146+
}
147+
}
148+
}
149+
150+
impl RequestFactory for RequestRepeater {
151+
fn new_request(&self, method: Method, url: &str) -> RequestData {
152+
let accessor = self.user_auth.read();
153+
RequestData::new(method, url)
154+
.header(http::X_PM_UID_HEADER, &accessor.uid.expose_secret().0)
155+
.bearer_token(accessor.access_token.expose_secret())
156+
}
157+
}

0 commit comments

Comments
 (0)