Skip to content

Commit 811885f

Browse files
authored
Actually plugins (#421)
* more plugins * clean up * fix tests * fix flakey test
1 parent d5e329f commit 811885f

File tree

11 files changed

+264
-70
lines changed

11 files changed

+264
-70
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ serde_json = "1"
4848

4949
[target.'cfg(not(target_env = "msvc"))'.dependencies]
5050
jemallocator = "0.5.0"
51+

pgcat.toml

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ admin_username = "admin_user"
7777
# Password to access the virtual administrative database
7878
admin_password = "admin_pass"
7979

80-
# Plugins!!
81-
# query_router_plugins = ["pg_table_access", "intercept"]
82-
8380
# pool configs are structured as pool.<pool_name>
8481
# the pool_name is what clients use as database name when connecting.
8582
# For a pool named `sharded_db`, clients access that pool using connection string like
@@ -157,6 +154,45 @@ connect_timeout = 3000
157154
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
158155
# dns_max_ttl = 30
159156

157+
[plugins]
158+
159+
[plugins.query_logger]
160+
enabled = false
161+
162+
[plugins.table_access]
163+
enabled = false
164+
tables = [
165+
"pg_user",
166+
"pg_roles",
167+
"pg_database",
168+
]
169+
170+
[plugins.intercept]
171+
enabled = true
172+
173+
[plugins.intercept.queries.0]
174+
175+
query = "select current_database() as a, current_schemas(false) as b"
176+
schema = [
177+
["a", "text"],
178+
["b", "text"],
179+
]
180+
result = [
181+
["${DATABASE}", "{public}"],
182+
]
183+
184+
[plugins.intercept.queries.1]
185+
186+
query = "select current_database(), current_schema(), current_user"
187+
schema = [
188+
["current_database", "text"],
189+
["current_schema", "text"],
190+
["current_user", "text"],
191+
]
192+
result = [
193+
["${DATABASE}", "public", "${USER}"],
194+
]
195+
160196
# User configs are structured as pool.<pool_name>.users.<user_index>
161197
# This section holds the credentials for users that may connect to this cluster
162198
[pools.sharded_db.users.0]

src/config.rs

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,6 @@ pub struct General {
302302
pub auth_query: Option<String>,
303303
pub auth_query_user: Option<String>,
304304
pub auth_query_password: Option<String>,
305-
306-
pub query_router_plugins: Option<Vec<String>>,
307305
}
308306

309307
impl General {
@@ -404,7 +402,6 @@ impl Default for General {
404402
auth_query_user: None,
405403
auth_query_password: None,
406404
server_lifetime: 1000 * 3600 * 24, // 24 hours,
407-
query_router_plugins: None,
408405
}
409406
}
410407
}
@@ -682,6 +679,55 @@ impl Default for Shard {
682679
}
683680
}
684681

682+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
683+
pub struct Plugins {
684+
pub intercept: Option<Intercept>,
685+
pub table_access: Option<TableAccess>,
686+
pub query_logger: Option<QueryLogger>,
687+
}
688+
689+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
690+
pub struct Intercept {
691+
pub enabled: bool,
692+
pub queries: BTreeMap<String, Query>,
693+
}
694+
695+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
696+
pub struct TableAccess {
697+
pub enabled: bool,
698+
pub tables: Vec<String>,
699+
}
700+
701+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
702+
pub struct QueryLogger {
703+
pub enabled: bool,
704+
}
705+
706+
impl Intercept {
707+
pub fn substitute(&mut self, db: &str, user: &str) {
708+
for (_, query) in self.queries.iter_mut() {
709+
query.substitute(db, user);
710+
}
711+
}
712+
}
713+
714+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
715+
pub struct Query {
716+
pub query: String,
717+
pub schema: Vec<Vec<String>>,
718+
pub result: Vec<Vec<String>>,
719+
}
720+
721+
impl Query {
722+
pub fn substitute(&mut self, db: &str, user: &str) {
723+
for col in self.result.iter_mut() {
724+
for i in 0..col.len() {
725+
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
726+
}
727+
}
728+
}
729+
}
730+
685731
/// Configuration wrapper.
686732
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
687733
pub struct Config {
@@ -700,6 +746,7 @@ pub struct Config {
700746
pub path: String,
701747

702748
pub general: General,
749+
pub plugins: Option<Plugins>,
703750
pub pools: HashMap<String, Pool>,
704751
}
705752

@@ -737,6 +784,7 @@ impl Default for Config {
737784
path: Self::default_path(),
738785
general: General::default(),
739786
pools: HashMap::default(),
787+
plugins: None,
740788
}
741789
}
742790
}
@@ -1128,25 +1176,26 @@ pub async fn parse(path: &str) -> Result<(), Error> {
11281176

11291177
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
11301178
let old_config = get_config();
1179+
11311180
match parse(&old_config.path).await {
11321181
Ok(()) => (),
11331182
Err(err) => {
11341183
error!("Config reload error: {:?}", err);
11351184
return Err(Error::BadConfig);
11361185
}
11371186
};
1187+
11381188
let new_config = get_config();
1189+
11391190
match CachedResolver::from_config().await {
11401191
Ok(_) => (),
11411192
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
11421193
};
11431194

1144-
if old_config.pools != new_config.pools {
1145-
info!("Pool configuration changed");
1195+
if old_config != new_config {
1196+
info!("Config changed, reloading");
11461197
ConnectionPool::from_config(client_server_map).await?;
11471198
Ok(true)
1148-
} else if old_config != new_config {
1149-
Ok(true)
11501199
} else {
11511200
Ok(false)
11521201
}

src/plugins/intercept.rs

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,41 @@ use serde_json::{json, Value};
1111
use sqlparser::ast::Statement;
1212
use std::collections::HashMap;
1313

14-
use log::debug;
14+
use log::{debug, info};
1515
use std::sync::Arc;
1616

1717
use crate::{
18+
config::Intercept as InterceptConfig,
1819
errors::Error,
1920
messages::{command_complete, data_row_nullable, row_description, DataType},
2021
plugins::{Plugin, PluginOutput},
2122
pool::{PoolIdentifier, PoolMap},
2223
query_router::QueryRouter,
2324
};
2425

25-
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> =
26+
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
2627
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
2728

28-
/// Configure the intercept plugin.
29-
pub fn configure(pools: &PoolMap) {
29+
/// Check if the interceptor plugin has been enabled.
30+
pub fn enabled() -> bool {
31+
!CONFIG.load().is_empty()
32+
}
33+
34+
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
3035
let mut config = HashMap::new();
3136
for (identifier, _) in pools.iter() {
32-
// TODO: make this configurable from a text config.
33-
let value = fool_datagrip(&identifier.db, &identifier.user);
34-
config.insert(identifier.clone(), value);
37+
let mut intercept_config = intercept_config.clone();
38+
intercept_config.substitute(&identifier.db, &identifier.user);
39+
config.insert(identifier.clone(), intercept_config);
3540
}
3641

3742
CONFIG.store(Arc::new(config));
43+
44+
info!("Intercepting {} queries", intercept_config.queries.len());
45+
}
46+
47+
pub fn disable() {
48+
CONFIG.store(Arc::new(HashMap::new()));
3849
}
3950

4051
// TODO: use these structs for deserialization
@@ -78,19 +89,19 @@ impl Plugin for Intercept {
7889
// Normalization
7990
let q = q.to_string().to_ascii_lowercase();
8091

81-
for target in query_map.as_array().unwrap().iter() {
82-
if target["query"].as_str().unwrap() == q {
83-
debug!("Query matched: {}", q);
92+
for (_, target) in query_map.queries.iter() {
93+
if target.query.as_str() == q {
94+
debug!("Intercepting query: {}", q);
8495

85-
let rd = target["schema"]
86-
.as_array()
87-
.unwrap()
96+
let rd = target
97+
.schema
8898
.iter()
8999
.map(|row| {
90-
let row = row.as_object().unwrap();
100+
let name = &row[0];
101+
let data_type = &row[1];
91102
(
92-
row["name"].as_str().unwrap(),
93-
match row["data_type"].as_str().unwrap() {
103+
name.as_str(),
104+
match data_type.as_str() {
94105
"text" => DataType::Text,
95106
"anyarray" => DataType::AnyArray,
96107
"oid" => DataType::Oid,
@@ -104,13 +115,11 @@ impl Plugin for Intercept {
104115

105116
result.put(row_description(&rd));
106117

107-
target["result"].as_array().unwrap().iter().for_each(|row| {
118+
target.result.iter().for_each(|row| {
108119
let row = row
109-
.as_array()
110-
.unwrap()
111120
.iter()
112121
.map(|s| {
113-
let s = s.as_str().unwrap().to_string();
122+
let s = s.as_str().to_string();
114123

115124
if s == "" {
116125
None
@@ -141,6 +150,7 @@ impl Plugin for Intercept {
141150

142151
/// Make IntelliJ SQL plugin believe it's talking to an actual database
143152
/// instead of PgCat.
153+
#[allow(dead_code)]
144154
fn fool_datagrip(database: &str, user: &str) -> Value {
145155
json!([
146156
{

src/plugins/mod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//!
1010
1111
pub mod intercept;
12+
pub mod query_logger;
1213
pub mod table_access;
1314

1415
use crate::{errors::Error, query_router::QueryRouter};
@@ -17,6 +18,7 @@ use bytes::BytesMut;
1718
use sqlparser::ast::Statement;
1819

1920
pub use intercept::Intercept;
21+
pub use query_logger::QueryLogger;
2022
pub use table_access::TableAccess;
2123

2224
#[derive(Clone, Debug, PartialEq)]
@@ -29,12 +31,13 @@ pub enum PluginOutput {
2931

3032
#[async_trait]
3133
pub trait Plugin {
32-
// Custom output is allowed because we want to extend this system
33-
// to rewriting queries some day. So an output of a plugin could be
34-
// a rewritten AST.
34+
// Run before the query is sent to the server.
3535
async fn run(
3636
&mut self,
3737
query_router: &QueryRouter,
3838
ast: &Vec<Statement>,
3939
) -> Result<PluginOutput, Error>;
40+
41+
// TODO: run after the result is returned
42+
// async fn callback(&mut self, query_router: &QueryRouter);
4043
}

src/plugins/query_logger.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//! Log all queries to stdout (or somewhere else, why not).
2+
3+
use crate::{
4+
errors::Error,
5+
plugins::{Plugin, PluginOutput},
6+
query_router::QueryRouter,
7+
};
8+
use arc_swap::ArcSwap;
9+
use async_trait::async_trait;
10+
use log::info;
11+
use once_cell::sync::Lazy;
12+
use sqlparser::ast::Statement;
13+
use std::sync::Arc;
14+
15+
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
16+
17+
pub struct QueryLogger;
18+
19+
pub fn setup() {
20+
ENABLED.store(Arc::new(true));
21+
22+
info!("Logging queries to stdout");
23+
}
24+
25+
pub fn disable() {
26+
ENABLED.store(Arc::new(false));
27+
}
28+
29+
pub fn enabled() -> bool {
30+
**ENABLED.load()
31+
}
32+
33+
#[async_trait]
34+
impl Plugin for QueryLogger {
35+
async fn run(
36+
&mut self,
37+
_query_router: &QueryRouter,
38+
ast: &Vec<Statement>,
39+
) -> Result<PluginOutput, Error> {
40+
let query = ast
41+
.iter()
42+
.map(|q| q.to_string())
43+
.collect::<Vec<String>>()
44+
.join("; ");
45+
info!("{}", query);
46+
47+
Ok(PluginOutput::Allow)
48+
}
49+
}

0 commit comments

Comments
 (0)