Skip to content

Commit bd36287

Browse files
authored
Merge pull request #60 from wiktor-k/wiktor/add-socket-to-new-session
Expose socket info in `new_session`
2 parents f07a436 + 8c3fb5b commit bd36287

File tree

6 files changed

+183
-99
lines changed

6 files changed

+183
-99
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ use tokio::net::UnixListener as Listener;
1818
#[cfg(windows)]
1919
use ssh_agent_lib::agent::NamedPipeListener as Listener;
2020
use ssh_agent_lib::error::AgentError;
21-
use ssh_agent_lib::agent::{Session, Agent};
21+
use ssh_agent_lib::agent::{Session, listen};
2222
use ssh_agent_lib::proto::{Identity, SignRequest};
2323
use ssh_key::{Algorithm, Signature};
2424
25-
#[derive(Default)]
25+
#[derive(Default, Clone)]
2626
struct MyAgent;
2727
2828
#[ssh_agent_lib::async_trait]
@@ -50,7 +50,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
5050
5151
let _ = std::fs::remove_file(socket); // remove the socket if exists
5252
53-
MyAgent.listen(Listener::bind(socket)?).await?;
53+
listen(Listener::bind(socket)?, MyAgent::default()).await?;
5454
Ok(())
5555
}
5656
```

examples/agent-socket-info.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//! This example shows how to access the underlying socket info.
2+
//! The socket info can be used to implement fine-grained access controls based on UID/GID.
3+
//!
4+
//! Run the example with: `cargo run --example agent-socket-info -- -H unix:///tmp/sock`
5+
//! Then inspect the socket info with: `SSH_AUTH_SOCK=/tmp/sock ssh-add -L` which should display
6+
//! something like this:
7+
//!
8+
//! ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA unix: addr: (unnamed) cred: UCred { pid: Some(68463), uid: 1000, gid: 1000 }
9+
10+
use clap::Parser;
11+
use service_binding::Binding;
12+
use ssh_agent_lib::{
13+
agent::{bind, Agent, Session},
14+
error::AgentError,
15+
proto::Identity,
16+
};
17+
use ssh_key::public::KeyData;
18+
use testresult::TestResult;
19+
20+
#[derive(Debug, Default)]
21+
struct AgentSocketInfo {
22+
comment: String,
23+
}
24+
25+
#[ssh_agent_lib::async_trait]
26+
impl Session for AgentSocketInfo {
27+
async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28+
Ok(vec![Identity {
29+
// this is just a dummy key, the comment is important
30+
pubkey: KeyData::Ed25519(ssh_key::public::Ed25519PublicKey([0; 32])),
31+
comment: self.comment.clone(),
32+
}])
33+
}
34+
}
35+
36+
#[cfg(unix)]
37+
impl Agent<tokio::net::UnixListener> for AgentSocketInfo {
38+
fn new_session(&mut self, socket: &tokio::net::UnixStream) -> impl Session {
39+
Self {
40+
comment: format!(
41+
"unix: addr: {:?} cred: {:?}",
42+
socket.peer_addr().unwrap(),
43+
socket.peer_cred().unwrap()
44+
),
45+
}
46+
}
47+
}
48+
49+
impl Agent<tokio::net::TcpListener> for AgentSocketInfo {
50+
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
51+
Self {
52+
comment: "tcp".into(),
53+
}
54+
}
55+
}
56+
57+
#[cfg(windows)]
58+
impl Agent<ssh_agent_lib::agent::NamedPipeListener> for AgentSocketInfo {
59+
fn new_session(
60+
&mut self,
61+
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
62+
) -> impl Session {
63+
Self {
64+
comment: "pipe".into(),
65+
}
66+
}
67+
}
68+
69+
#[derive(Debug, Parser)]
70+
struct Args {
71+
#[clap(short = 'H', long)]
72+
host: Binding,
73+
}
74+
75+
#[tokio::main]
76+
async fn main() -> TestResult {
77+
env_logger::init();
78+
79+
let args = Args::parse();
80+
bind(args.host.try_into()?, AgentSocketInfo::default()).await?;
81+
Ok(())
82+
}

examples/key_storage.rs

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@ use rsa::BigUint;
99
use sha1::Sha1;
1010
#[cfg(windows)]
1111
use ssh_agent_lib::agent::NamedPipeListener as Listener;
12-
use ssh_agent_lib::agent::Session;
12+
use ssh_agent_lib::agent::{listen, Session};
1313
use ssh_agent_lib::error::AgentError;
1414
use ssh_agent_lib::proto::extension::{QueryResponse, RestrictDestination, SessionBind};
1515
use ssh_agent_lib::proto::{
1616
message, signature, AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained,
1717
Credential, Extension, KeyConstraint, RemoveIdentity, SignRequest, SmartcardKey,
1818
};
19-
use ssh_agent_lib::Agent;
2019
use ssh_key::{
2120
private::{KeypairData, PrivateKey},
2221
public::PublicKey,
@@ -32,6 +31,7 @@ struct Identity {
3231
comment: String,
3332
}
3433

34+
#[derive(Default, Clone)]
3535
struct KeyStorage {
3636
identities: Arc<Mutex<Vec<Identity>>>,
3737
}
@@ -225,26 +225,6 @@ impl Session for KeyStorage {
225225
}
226226
}
227227

228-
struct KeyStorageAgent {
229-
identities: Arc<Mutex<Vec<Identity>>>,
230-
}
231-
232-
impl KeyStorageAgent {
233-
fn new() -> Self {
234-
Self {
235-
identities: Arc::new(Mutex::new(vec![])),
236-
}
237-
}
238-
}
239-
240-
impl Agent for KeyStorageAgent {
241-
fn new_session(&mut self) -> impl Session {
242-
KeyStorage {
243-
identities: Arc::clone(&self.identities),
244-
}
245-
}
246-
}
247-
248228
#[tokio::main]
249229
async fn main() -> Result<(), AgentError> {
250230
env_logger::init();
@@ -260,8 +240,6 @@ async fn main() -> Result<(), AgentError> {
260240
#[cfg(windows)]
261241
std::fs::File::create("server-started")?;
262242

263-
KeyStorageAgent::new()
264-
.listen(Listener::bind(socket)?)
265-
.await?;
243+
listen(Listener::bind(socket)?, KeyStorage::default()).await?;
266244
Ok(())
267245
}

examples/openpgp-card-agent.rs

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,29 @@ use retainer::{Cache, CacheExpiration};
2727
use secrecy::{ExposeSecret, SecretString};
2828
use service_binding::Binding;
2929
use ssh_agent_lib::{
30-
agent::Session,
30+
agent::{bind, Session},
3131
error::AgentError,
3232
proto::{AddSmartcardKeyConstrained, Identity, KeyConstraint, SignRequest, SmartcardKey},
33-
Agent,
3433
};
3534
use ssh_key::{
3635
public::{Ed25519PublicKey, KeyData},
3736
Algorithm, Signature,
3837
};
3938
use testresult::TestResult;
4039

41-
struct CardAgent {
40+
#[derive(Clone)]
41+
struct CardSession {
4242
pwds: Arc<Cache<String, SecretString>>,
4343
}
4444

45-
impl CardAgent {
45+
impl CardSession {
4646
pub fn new() -> Self {
4747
let pwds: Arc<Cache<String, SecretString>> = Arc::new(Default::default());
4848
let clone = Arc::clone(&pwds);
4949
tokio::spawn(async move { clone.monitor(4, 0.25, Duration::from_secs(3)).await });
5050
Self { pwds }
5151
}
52-
}
53-
54-
impl Agent for CardAgent {
55-
fn new_session(&mut self) -> impl Session {
56-
CardSession {
57-
pwds: Arc::clone(&self.pwds),
58-
}
59-
}
60-
}
61-
62-
struct CardSession {
63-
pwds: Arc<Cache<String, SecretString>>,
64-
}
6552

66-
impl CardSession {
6753
async fn handle_sign(
6854
&self,
6955
request: SignRequest,
@@ -201,6 +187,6 @@ async fn main() -> TestResult {
201187
env_logger::init();
202188

203189
let args = Args::parse();
204-
CardAgent::new().bind(args.host.try_into()?).await?;
190+
bind(args.host.try_into()?, CardSession::new()).await?;
205191
Ok(())
206192
}

src/agent.rs

Lines changed: 90 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -249,64 +249,105 @@ where
249249
}
250250
}
251251

252-
/// Type representing an agent listening for incoming connections.
253-
#[async_trait]
254-
pub trait Agent: 'static + Sync + Send + Sized {
255-
/// Create new session object when a new socket is accepted.
256-
fn new_session(&mut self) -> impl Session;
257-
258-
/// Listen on a socket waiting for client connections.
259-
async fn listen<S>(mut self, mut socket: S) -> Result<(), AgentError>
260-
where
261-
S: ListeningSocket + fmt::Debug + Send,
262-
{
263-
log::info!("Listening; socket = {:?}", socket);
264-
loop {
265-
match socket.accept().await {
266-
Ok(socket) => {
267-
let session = self.new_session();
268-
tokio::spawn(async move {
269-
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
270-
if let Err(e) = handle_socket::<S>(session, adapter).await {
271-
log::error!("Agent protocol error: {:?}", e);
272-
}
273-
});
274-
}
275-
Err(e) => {
276-
log::error!("Failed to accept socket: {:?}", e);
277-
return Err(AgentError::IO(e));
278-
}
252+
/// Factory of sessions for the given type of sockets.
253+
pub trait Agent<S>: 'static + Send + Sync
254+
where
255+
S: ListeningSocket + fmt::Debug + Send,
256+
{
257+
/// Create a [`Session`] object for a given `socket`.
258+
fn new_session(&mut self, socket: &S::Stream) -> impl Session;
259+
}
260+
261+
/// Listen for connections on a given socket and use session factory
262+
/// to create new session for each accepted socket.
263+
pub async fn listen<S>(mut socket: S, mut sf: impl Agent<S>) -> Result<(), AgentError>
264+
where
265+
S: ListeningSocket + fmt::Debug + Send,
266+
{
267+
log::info!("Listening; socket = {:?}", socket);
268+
loop {
269+
match socket.accept().await {
270+
Ok(socket) => {
271+
let session = sf.new_session(&socket);
272+
tokio::spawn(async move {
273+
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
274+
if let Err(e) = handle_socket::<S>(session, adapter).await {
275+
log::error!("Agent protocol error: {:?}", e);
276+
}
277+
});
278+
}
279+
Err(e) => {
280+
log::error!("Failed to accept socket: {:?}", e);
281+
return Err(AgentError::IO(e));
279282
}
280283
}
281284
}
285+
}
282286

283-
/// Bind to a service binding listener.
284-
async fn bind(mut self, listener: service_binding::Listener) -> Result<(), AgentError> {
285-
match listener {
286-
#[cfg(unix)]
287-
service_binding::Listener::Unix(listener) => {
288-
self.listen(UnixListener::from_std(listener)?).await
289-
}
290-
service_binding::Listener::Tcp(listener) => {
291-
self.listen(TcpListener::from_std(listener)?).await
292-
}
293-
#[cfg(windows)]
294-
service_binding::Listener::NamedPipe(pipe) => {
295-
self.listen(NamedPipeListener::bind(pipe)?).await
296-
}
297-
#[cfg(not(windows))]
298-
service_binding::Listener::NamedPipe(_) => Err(AgentError::IO(std::io::Error::other(
299-
"Named pipes supported on Windows only",
300-
))),
287+
#[cfg(unix)]
288+
impl<T> Agent<tokio::net::UnixListener> for T
289+
where
290+
T: Clone + Send + Sync + Session,
291+
{
292+
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
293+
Self::clone(self)
294+
}
295+
}
296+
297+
impl<T> Agent<tokio::net::TcpListener> for T
298+
where
299+
T: Clone + Send + Sync + Session,
300+
{
301+
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
302+
Self::clone(self)
303+
}
304+
}
305+
306+
#[cfg(windows)]
307+
impl<T> Agent<NamedPipeListener> for T
308+
where
309+
T: Clone + Send + Sync + Session,
310+
{
311+
fn new_session(
312+
&mut self,
313+
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
314+
) -> impl Session {
315+
Self::clone(self)
316+
}
317+
}
318+
319+
/// Bind to a service binding listener.
320+
#[cfg(unix)]
321+
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
322+
where
323+
SF: Agent<tokio::net::UnixListener> + Agent<tokio::net::TcpListener>,
324+
{
325+
match listener {
326+
#[cfg(unix)]
327+
service_binding::Listener::Unix(listener) => {
328+
listen(UnixListener::from_std(listener)?, sf).await
329+
}
330+
service_binding::Listener::Tcp(listener) => {
331+
listen(TcpListener::from_std(listener)?, sf).await
301332
}
333+
_ => Err(AgentError::IO(std::io::Error::other(
334+
"Unsupported type of a listener.",
335+
))),
302336
}
303337
}
304338

305-
impl<T> Agent for T
339+
/// Bind to a service binding listener.
340+
#[cfg(windows)]
341+
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
306342
where
307-
T: Default + Session,
343+
SF: Agent<NamedPipeListener> + Agent<tokio::net::TcpListener>,
308344
{
309-
fn new_session(&mut self) -> impl Session {
310-
Self::default()
345+
match listener {
346+
service_binding::Listener::Tcp(listener) => {
347+
listen(TcpListener::from_std(listener)?, sf).await
348+
}
349+
service_binding::Listener::NamedPipe(pipe) => {
350+
listen(NamedPipeListener::bind(pipe)?, sf).await
351+
}
311352
}
312353
}

src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,3 @@ pub mod error;
1515

1616
#[cfg(feature = "agent")]
1717
pub use async_trait::async_trait;
18-
19-
#[cfg(feature = "agent")]
20-
pub use self::agent::Agent;

0 commit comments

Comments
 (0)