Skip to content

Commit 2a8f365

Browse files
authored
Fix COPY FROM and add tests (#522)
* Fix COPY FROM and add tests * E * fmt
1 parent 19cb8a3 commit 2a8f365

File tree

3 files changed

+132
-3
lines changed

3 files changed

+132
-3
lines changed

src/client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,7 @@ where
12601260

12611261
// Release server back to the pool if we are in transaction mode.
12621262
// If we are in session mode, we keep the server until the client disconnects.
1263-
if self.transaction_mode {
1263+
if self.transaction_mode && !server.in_copy_mode() {
12641264
self.stats.idle();
12651265

12661266
break;
@@ -1410,7 +1410,7 @@ where
14101410

14111411
// Release server back to the pool if we are in transaction mode.
14121412
// If we are in session mode, we keep the server until the client disconnects.
1413-
if self.transaction_mode {
1413+
if self.transaction_mode && !server.in_copy_mode() {
14141414
break;
14151415
}
14161416
}

src/server.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ pub struct Server {
170170
/// Is there more data for the client to read.
171171
data_available: bool,
172172

173+
/// Is the server in copy-in or copy-out modes
174+
in_copy_mode: bool,
175+
173176
/// Is the server broken? We'll remote it from the pool if so.
174177
bad: bool,
175178

@@ -677,6 +680,7 @@ impl Server {
677680
process_id,
678681
secret_key,
679682
in_transaction: false,
683+
in_copy_mode: false,
680684
data_available: false,
681685
bad: false,
682686
cleanup_state: CleanupState::new(),
@@ -828,8 +832,19 @@ impl Server {
828832
break;
829833
}
830834

835+
// ErrorResponse
836+
'E' => {
837+
if self.in_copy_mode {
838+
self.in_copy_mode = false;
839+
}
840+
}
841+
831842
// CommandComplete
832843
'C' => {
844+
if self.in_copy_mode {
845+
self.in_copy_mode = false;
846+
}
847+
833848
let mut command_tag = String::new();
834849
match message.reader().read_to_string(&mut command_tag) {
835850
Ok(_) => {
@@ -873,10 +888,14 @@ impl Server {
873888
}
874889

875890
// CopyInResponse: copy is starting from client to server.
876-
'G' => break,
891+
'G' => {
892+
self.in_copy_mode = true;
893+
break;
894+
}
877895

878896
// CopyOutResponse: copy is starting from the server to the client.
879897
'H' => {
898+
self.in_copy_mode = true;
880899
self.data_available = true;
881900
break;
882901
}
@@ -1030,6 +1049,10 @@ impl Server {
10301049
self.in_transaction
10311050
}
10321051

1052+
pub fn in_copy_mode(&self) -> bool {
1053+
self.in_copy_mode
1054+
}
1055+
10331056
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
10341057
/// The client is responsible to call `self.recv()` while this method returns true.
10351058
pub fn is_data_available(&self) -> bool {
@@ -1129,6 +1152,10 @@ impl Server {
11291152
self.cleanup_state.reset();
11301153
}
11311154

1155+
if self.in_copy_mode() {
1156+
warn!("Server returned while still in copy-mode");
1157+
}
1158+
11321159
Ok(())
11331160
}
11341161

tests/ruby/copy_spec.rb

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# frozen_string_literal: true
2+
require_relative 'spec_helper'
3+
4+
5+
describe "COPY Handling" do
6+
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
7+
before do
8+
new_configs = processes.pgcat.current_config
9+
10+
# Allow connections in the pool to expire faster
11+
new_configs["general"]["idle_timeout"] = 5
12+
processes.pgcat.update_config(new_configs)
13+
# We need to kill the old process that was using the default configs
14+
processes.pgcat.stop
15+
processes.pgcat.start
16+
processes.pgcat.wait_until_ready
17+
end
18+
19+
before do
20+
processes.all_databases.first.with_connection do |conn|
21+
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
22+
end
23+
end
24+
25+
after do
26+
processes.all_databases.first.with_connection do |conn|
27+
conn.async_exec "DROP TABLE copy_test_table;"
28+
end
29+
end
30+
31+
after do
32+
processes.all_databases.map(&:reset)
33+
processes.pgcat.shutdown
34+
end
35+
36+
describe "COPY FROM" do
37+
context "within transaction" do
38+
it "finishes within alloted time" do
39+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
40+
Timeout.timeout(3) do
41+
conn.async_exec("BEGIN")
42+
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
43+
sleep 0.5
44+
conn.put_copy_data "some,data,to,copy\n"
45+
conn.put_copy_data "more,data,to,copy\n"
46+
end
47+
conn.async_exec("COMMIT")
48+
end
49+
50+
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
51+
expect(res).to eq([
52+
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
53+
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
54+
])
55+
end
56+
end
57+
58+
context "outside transaction" do
59+
it "finishes within alloted time" do
60+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
61+
Timeout.timeout(3) do
62+
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
63+
sleep 0.5
64+
conn.put_copy_data "some,data,to,copy\n"
65+
conn.put_copy_data "more,data,to,copy\n"
66+
end
67+
end
68+
69+
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
70+
expect(res).to eq([
71+
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
72+
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
73+
])
74+
end
75+
end
76+
end
77+
78+
describe "COPY TO" do
79+
before do
80+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
81+
conn.async_exec("BEGIN")
82+
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
83+
conn.put_copy_data "some,data,to,copy\n"
84+
conn.put_copy_data "more,data,to,copy\n"
85+
end
86+
conn.async_exec("COMMIT")
87+
conn.close
88+
end
89+
90+
it "works" do
91+
res = []
92+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
93+
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
94+
while row=conn.get_copy_data
95+
res << row
96+
end
97+
end
98+
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
99+
end
100+
end
101+
102+
end

0 commit comments

Comments
 (0)