Skip to content

Commit 3af10b9

Browse files
committed
Extract listen from the Agent
Signed-off-by: Wiktor Kwapisiewicz <wiktor@metacode.biz>
1 parent 920b7a2 commit 3af10b9

File tree

5 files changed

+141
-74
lines changed

5 files changed

+141
-74
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use tokio::net::UnixListener as Listener;
1818
#[cfg(windows)]
1919
use ssh_agent_lib::agent::NamedPipeListener as Listener;
2020
use ssh_agent_lib::error::AgentError;
21-
use ssh_agent_lib::agent::{Session, Agent};
21+
use ssh_agent_lib::agent::{Session, listen};
2222
use ssh_agent_lib::proto::{Identity, SignRequest};
2323
use ssh_key::{Algorithm, Signature};
2424
@@ -50,7 +50,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
5050
5151
let _ = std::fs::remove_file(socket); // remove the socket if exists
5252
53-
MyAgent.listen(Listener::bind(socket)?).await?;
53+
listen(Listener::bind(socket)?, MyAgent::default()).await?;
5454
Ok(())
5555
}
5656
```

examples/key_storage.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@ use rsa::BigUint;
99
use sha1::Sha1;
1010
#[cfg(windows)]
1111
use ssh_agent_lib::agent::NamedPipeListener as Listener;
12-
use ssh_agent_lib::agent::{ListeningSocket, Session};
12+
use ssh_agent_lib::agent::{listen, Agent, Session};
1313
use ssh_agent_lib::error::AgentError;
1414
use ssh_agent_lib::proto::extension::{QueryResponse, RestrictDestination, SessionBind};
1515
use ssh_agent_lib::proto::{
1616
message, signature, AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained,
1717
Credential, Extension, KeyConstraint, RemoveIdentity, SignRequest, SmartcardKey,
1818
};
19-
use ssh_agent_lib::Agent;
2019
use ssh_key::{
2120
private::{KeypairData, PrivateKey},
2221
public::PublicKey,
@@ -237,11 +236,21 @@ impl KeyStorageAgent {
237236
}
238237
}
239238

240-
impl Agent for KeyStorageAgent {
241-
fn new_session<S>(&mut self, _socket: &S::Stream) -> impl Session
242-
where
243-
S: ListeningSocket + std::fmt::Debug + Send,
244-
{
239+
#[cfg(unix)]
240+
impl Agent<Listener> for KeyStorageAgent {
241+
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
242+
KeyStorage {
243+
identities: Arc::clone(&self.identities),
244+
}
245+
}
246+
}
247+
248+
#[cfg(windows)]
249+
impl Agent<Listener> for KeyStorageAgent {
250+
fn new_session(
251+
&mut self,
252+
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
253+
) -> impl Session {
245254
KeyStorage {
246255
identities: Arc::clone(&self.identities),
247256
}
@@ -263,8 +272,6 @@ async fn main() -> Result<(), AgentError> {
263272
#[cfg(windows)]
264273
std::fs::File::create("server-started")?;
265274

266-
KeyStorageAgent::new()
267-
.listen(Listener::bind(socket)?)
268-
.await?;
275+
listen(Listener::bind(socket)?, KeyStorageAgent::new()).await?;
269276
Ok(())
270277
}

examples/openpgp-card-agent.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
1717
use std::{sync::Arc, time::Duration};
1818

19+
#[cfg(windows)]
20+
use ssh_agent_lib::agent::NamedPipeListener as Listener;
21+
22+
#[cfg(not(windows))]
23+
use tokio::net::UnixListener as Listener;
1924
use card_backend_pcsc::PcscBackend;
2025
use clap::Parser;
2126
use openpgp_card::{
@@ -27,16 +32,16 @@ use retainer::{Cache, CacheExpiration};
2732
use secrecy::{ExposeSecret, SecretString};
2833
use service_binding::Binding;
2934
use ssh_agent_lib::{
30-
agent::Session,
35+
agent::{bind, Session, Agent},
3136
error::AgentError,
3237
proto::{AddSmartcardKeyConstrained, Identity, KeyConstraint, SignRequest, SmartcardKey},
33-
Agent,
3438
};
3539
use ssh_key::{
3640
public::{Ed25519PublicKey, KeyData},
3741
Algorithm, Signature,
3842
};
3943
use testresult::TestResult;
44+
use tokio::net::TcpListener;
4045

4146
struct CardAgent {
4247
pwds: Arc<Cache<String, SecretString>>,
@@ -51,8 +56,30 @@ impl CardAgent {
5156
}
5257
}
5358

54-
impl Agent for CardAgent {
55-
fn new_session(&mut self) -> impl Session {
59+
#[cfg(unix)]
60+
impl Agent<Listener> for CardAgent {
61+
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
62+
CardSession {
63+
pwds: Arc::clone(&self.pwds),
64+
}
65+
}
66+
}
67+
68+
#[cfg(unix)]
69+
impl Agent<TcpListener> for CardAgent {
70+
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
71+
CardSession {
72+
pwds: Arc::clone(&self.pwds),
73+
}
74+
}
75+
}
76+
77+
#[cfg(windows)]
78+
impl Agent<Listener> for CardAgent {
79+
fn new_session(
80+
&mut self,
81+
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
82+
) -> impl Session {
5683
CardSession {
5784
pwds: Arc::clone(&self.pwds),
5885
}
@@ -201,6 +228,6 @@ async fn main() -> TestResult {
201228
env_logger::init();
202229

203230
let args = Args::parse();
204-
CardAgent::new().bind(args.host.try_into()?).await?;
231+
bind(args.host.try_into()?, CardAgent::new()).await?;
205232
Ok(())
206233
}

src/agent.rs

Lines changed: 90 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -249,69 +249,105 @@ where
249249
}
250250
}
251251

252-
/// Type representing an agent listening for incoming connections.
253-
#[async_trait]
254-
pub trait Agent: 'static + Sync + Send + Sized {
255-
/// Create new session object when a new socket is accepted.
256-
fn new_session<S>(&mut self, socket: &S::Stream) -> impl Session
257-
where
258-
S: ListeningSocket + fmt::Debug + Send;
259-
260-
/// Listen on a socket waiting for client connections.
261-
async fn listen<S>(mut self, mut socket: S) -> Result<(), AgentError>
262-
where
263-
S: ListeningSocket + fmt::Debug + Send,
264-
{
265-
log::info!("Listening; socket = {:?}", socket);
266-
loop {
267-
match socket.accept().await {
268-
Ok(socket) => {
269-
let session = self.new_session::<S>(&socket);
270-
tokio::spawn(async move {
271-
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
272-
if let Err(e) = handle_socket::<S>(session, adapter).await {
273-
log::error!("Agent protocol error: {:?}", e);
274-
}
275-
});
276-
}
277-
Err(e) => {
278-
log::error!("Failed to accept socket: {:?}", e);
279-
return Err(AgentError::IO(e));
280-
}
281-
}
282-
}
283-
}
252+
/// Factory of sessions for the given type of sockets.
253+
pub trait Agent<S>: 'static + Send + Sync
254+
where
255+
S: ListeningSocket + fmt::Debug + Send,
256+
{
257+
/// Create a [`Session`] object for a given `socket`.
258+
fn new_session(&mut self, socket: &S::Stream) -> impl Session;
259+
}
284260

285-
/// Bind to a service binding listener.
286-
async fn bind(mut self, listener: service_binding::Listener) -> Result<(), AgentError> {
287-
match listener {
288-
#[cfg(unix)]
289-
service_binding::Listener::Unix(listener) => {
290-
self.listen(UnixListener::from_std(listener)?).await
291-
}
292-
service_binding::Listener::Tcp(listener) => {
293-
self.listen(TcpListener::from_std(listener)?).await
261+
/// Listen for connections on a given socket and use session factory
262+
/// to create new session for each accepted socket.
263+
pub async fn listen<S>(mut socket: S, mut sf: impl Agent<S>) -> Result<(), AgentError>
264+
where
265+
S: ListeningSocket + fmt::Debug + Send,
266+
{
267+
log::info!("Listening; socket = {:?}", socket);
268+
loop {
269+
match socket.accept().await {
270+
Ok(socket) => {
271+
let session = sf.new_session(&socket);
272+
tokio::spawn(async move {
273+
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
274+
if let Err(e) = handle_socket::<S>(session, adapter).await {
275+
log::error!("Agent protocol error: {:?}", e);
276+
}
277+
});
294278
}
295-
#[cfg(windows)]
296-
service_binding::Listener::NamedPipe(pipe) => {
297-
self.listen(NamedPipeListener::bind(pipe)?).await
279+
Err(e) => {
280+
log::error!("Failed to accept socket: {:?}", e);
281+
return Err(AgentError::IO(e));
298282
}
299-
#[cfg(not(windows))]
300-
service_binding::Listener::NamedPipe(_) => Err(AgentError::IO(std::io::Error::other(
301-
"Named pipes supported on Windows only",
302-
))),
303283
}
304284
}
305285
}
306286

307-
impl<T> Agent for T
287+
#[cfg(unix)]
288+
impl<T> Agent<tokio::net::UnixListener> for T
289+
where
290+
T: Default + Send + Sync + Session,
291+
{
292+
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
293+
Self::default()
294+
}
295+
}
296+
297+
impl<T> Agent<tokio::net::TcpListener> for T
308298
where
309-
T: Default + Session,
299+
T: Default + Send + Sync + Session,
310300
{
311-
fn new_session<S>(&mut self, _socket: &S::Stream) -> impl Session
312-
where
313-
S: ListeningSocket + fmt::Debug + Send,
314-
{
301+
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
315302
Self::default()
316303
}
317304
}
305+
306+
#[cfg(windows)]
307+
impl<T> Agent<NamedPipeListener> for T
308+
where
309+
T: Default + Send + Sync + Session,
310+
{
311+
fn new_session(
312+
&mut self,
313+
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
314+
) -> impl Session {
315+
Self::default()
316+
}
317+
}
318+
319+
/// Bind to a service binding listener.
320+
#[cfg(unix)]
321+
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
322+
where
323+
SF: Agent<tokio::net::UnixListener> + Agent<tokio::net::TcpListener>,
324+
{
325+
match listener {
326+
#[cfg(unix)]
327+
service_binding::Listener::Unix(listener) => {
328+
listen(UnixListener::from_std(listener)?, sf).await
329+
}
330+
service_binding::Listener::Tcp(listener) => {
331+
listen(TcpListener::from_std(listener)?, sf).await
332+
}
333+
_ => Err(AgentError::IO(std::io::Error::other(
334+
"Unsupported type of a listener.",
335+
))),
336+
}
337+
}
338+
339+
/// Bind to a service binding listener.
340+
#[cfg(windows)]
341+
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
342+
where
343+
SF: Agent<NamedPipeListener> + Agent<tokio::net::TcpListener>,
344+
{
345+
match listener {
346+
service_binding::Listener::Tcp(listener) => {
347+
listen(TcpListener::from_std(listener)?, sf).await
348+
}
349+
service_binding::Listener::NamedPipe(pipe) => {
350+
listen(NamedPipeListener::bind(pipe)?, sf).await
351+
}
352+
}
353+
}

src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,3 @@ pub mod error;
1515

1616
#[cfg(feature = "agent")]
1717
pub use async_trait::async_trait;
18-
19-
#[cfg(feature = "agent")]
20-
pub use self::agent::Agent;

0 commit comments

Comments
 (0)