Skip to content

feat: add basic auth #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- name: Setup Build Environment
run: sudo apt update && sudo apt install -y protobuf-compiler
- name: Install cargo binaries
run: cargo install cargo-sort
run: cargo install cargo-sort --locked
- name: Run Style Check
run: make fmt clippy check-toml

Expand Down
18 changes: 13 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ description = "Rust implementation of HoraeDB client."
readme = "README.md"

[dependencies]
anyhow = "1.0.83"
arrow = "38.0.0"
async-trait = "0.1.72"
base64 = "0.22.1"
dashmap = "5.3.4"
futures = "0.3"
horaedbproto = "1.0.23"
Expand Down
10 changes: 8 additions & 2 deletions examples/read_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use horaedb_client::{
value::Value,
write::{point::PointBuilder, Request as WriteRequest},
},
RpcContext,
Authorization, RpcContext,
};

async fn create_table(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
Expand Down Expand Up @@ -112,7 +112,13 @@ async fn sql_query(client: &Arc<dyn DbClient>, rpc_ctx: &RpcContext) {
#[tokio::main]
async fn main() {
// you should ensure horaedb is running, and grpc port is set to 8831
let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct).build();
let client = Builder::new("127.0.0.1:8831".to_string(), Mode::Direct)
// Set authorization if needed
.authorization(Authorization {
username: "user".to_string(),
password: "pass".to_string(),
})
.build();
let rpc_ctx = RpcContext::default().database("public".to_string());

println!("------------------------------------------------------------------");
Expand Down
6 changes: 6 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ pub struct RpcConfig {
pub connect_timeout: Duration,
}

#[derive(Debug, Clone)]
pub struct Authorization {
pub username: String,
pub password: String,
}

impl Default for RpcConfig {
fn default() -> Self {
Self {
Expand Down
15 changes: 13 additions & 2 deletions src/db_client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;
use crate::{
db_client::{raw::RawImpl, route_based::RouteBasedImpl, DbClient},
rpc_client::RpcClientImplFactory,
RpcConfig,
Authorization, RpcConfig,
};

/// Access mode to HoraeDB server(s).
Expand All @@ -40,6 +40,7 @@ pub struct Builder {
endpoint: String,
default_database: Option<String>,
rpc_config: RpcConfig,
authorization: Option<Authorization>,
}

impl Builder {
Expand All @@ -50,6 +51,7 @@ impl Builder {
endpoint,
rpc_config: RpcConfig::default(),
default_database: None,
authorization: None,
}
}

Expand All @@ -65,8 +67,17 @@ impl Builder {
self
}

#[inline]
pub fn authorization(mut self, authorization: Authorization) -> Self {
self.authorization = Some(authorization);
self
}

pub fn build(self) -> Arc<dyn DbClient> {
let rpc_client_factory = Arc::new(RpcClientImplFactory::new(self.rpc_config));
let rpc_client_factory = Arc::new(RpcClientImplFactory::new(
self.rpc_config,
self.authorization,
));

match self.mode {
Mode::Direct => Arc::new(RouteBasedImpl::new(
Expand Down
6 changes: 6 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ pub enum Error {

#[error("failed to find a database")]
NoDatabase,

#[error(transparent)]
Other {
#[from]
source: anyhow::Error,
},
}

#[derive(Debug)]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ mod util;

#[doc(inline)]
pub use crate::{
config::RpcConfig,
config::{Authorization, RpcConfig},
db_client::{Builder, DbClient, Mode},
errors::{Error, Result},
model::{
Expand Down
42 changes: 36 additions & 6 deletions src/rpc_client/rpc_client_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

use std::{sync::Arc, time::Duration};

use anyhow::Context;
use async_trait::async_trait;
use base64::{prelude::BASE64_STANDARD, Engine};
use horaedbproto::{
common::ResponseHeader,
storage::{
Expand All @@ -24,6 +26,7 @@ use horaedbproto::{
},
};
use tonic::{
metadata::{Ascii, MetadataValue},
transport::{Channel, Endpoint},
Request,
};
Expand All @@ -33,24 +36,28 @@ use crate::{
errors::{Error, Result, ServerError},
rpc_client::{RpcClient, RpcClientFactory, RpcContext},
util::is_ok,
Authorization,
};

struct RpcClientImpl {
channel: Channel,
default_read_timeout: Duration,
default_write_timeout: Duration,
metadata: Option<MetadataValue<Ascii>>,
}

impl RpcClientImpl {
fn new(
channel: Channel,
default_read_timeout: Duration,
default_write_timeout: Duration,
metadata: Option<MetadataValue<Ascii>>,
) -> Self {
Self {
channel,
default_read_timeout,
default_write_timeout,
metadata,
}
}

Expand All @@ -65,19 +72,22 @@ impl RpcClientImpl {
Ok(())
}

fn make_request<T>(ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
fn make_request<T>(&self, ctx: &RpcContext, req: T, default_timeout: Duration) -> Request<T> {
let timeout = ctx.timeout.unwrap_or(default_timeout);
let mut req = Request::new(req);
req.set_timeout(timeout);
if let Some(md) = &self.metadata {
req.metadata_mut().insert("authorization", md.clone());
}
req
}

fn make_query_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
Self::make_request(ctx, req, self.default_read_timeout)
self.make_request(ctx, req, self.default_read_timeout)
}

fn make_write_request<T>(&self, ctx: &RpcContext, req: T) -> Request<T> {
Self::make_request(ctx, req, self.default_write_timeout)
self.make_request(ctx, req, self.default_write_timeout)
}
}

Expand Down Expand Up @@ -119,7 +129,7 @@ impl RpcClient for RpcClientImpl {
let mut client = StorageServiceClient::<Channel>::new(self.channel.clone());

// use the write timeout for the route request.
let route_req = Self::make_request(ctx, req, self.default_write_timeout);
let route_req = self.make_request(ctx, req, self.default_write_timeout);
let resp = client.route(route_req).await.map_err(Error::Rpc)?;
let mut resp = resp.into_inner();

Expand All @@ -133,11 +143,15 @@ impl RpcClient for RpcClientImpl {

pub struct RpcClientImplFactory {
rpc_config: RpcConfig,
authorization: Option<Authorization>,
}

impl RpcClientImplFactory {
pub fn new(rpc_config: RpcConfig) -> Self {
Self { rpc_config }
pub fn new(rpc_config: RpcConfig, authorization: Option<Authorization>) -> Self {
Self {
rpc_config,
authorization,
}
}

#[inline]
Expand Down Expand Up @@ -174,10 +188,26 @@ impl RpcClientFactory for RpcClientImplFactory {
addr: endpoint,
source: Box::new(e),
})?;

let metadata = if let Some(auth) = &self.authorization {
let mut buf = Vec::with_capacity(auth.username.len() + auth.password.len() + 1);
buf.extend_from_slice(auth.username.as_bytes());
buf.push(b':');
buf.extend_from_slice(auth.password.as_bytes());
let auth = BASE64_STANDARD.encode(&buf);
let metadata: MetadataValue<Ascii> = format!("Basic {}", auth)
.parse()
.context("invalid grpc metadata")?;

Some(metadata)
} else {
None
};
Ok(Arc::new(RpcClientImpl::new(
channel,
self.rpc_config.default_sql_query_timeout,
self.rpc_config.default_write_timeout,
metadata,
)))
}
}
Loading