@@ -18,11 +18,11 @@ use rand::{thread_rng, Rng};
18
18
use std:: collections:: hash_map;
19
19
use std:: collections:: HashMap ;
20
20
use std:: env;
21
- use std:: io:: { Cursor , Read , Write } ;
21
+ use std:: io:: { self , Write } ;
22
22
use std:: path:: PathBuf ;
23
23
use std:: str:: FromStr ;
24
24
use std:: sync:: atomic:: { AtomicBool , Ordering } ;
25
- use std:: sync:: { Arc , Mutex , RwLock } ;
25
+ use std:: sync:: { Arc , Mutex } ;
26
26
use std:: time:: Duration ;
27
27
28
28
macro_rules! expect_event {
@@ -42,25 +42,20 @@ macro_rules! expect_event {
42
42
pub ( crate ) use expect_event;
43
43
44
44
pub ( crate ) struct TestStore {
45
- persisted_bytes : RwLock < HashMap < String , HashMap < String , Arc < RwLock < Vec < u8 > > > > > > ,
45
+ persisted_bytes : Mutex < HashMap < String , HashMap < String , Vec < u8 > > > > ,
46
46
did_persist : Arc < AtomicBool > ,
47
47
}
48
48
49
49
impl TestStore {
50
50
pub fn new ( ) -> Self {
51
- let persisted_bytes = RwLock :: new ( HashMap :: new ( ) ) ;
51
+ let persisted_bytes = Mutex :: new ( HashMap :: new ( ) ) ;
52
52
let did_persist = Arc :: new ( AtomicBool :: new ( false ) ) ;
53
53
Self { persisted_bytes, did_persist }
54
54
}
55
55
56
56
pub fn get_persisted_bytes ( & self , namespace : & str , key : & str ) -> Option < Vec < u8 > > {
57
- if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
58
- if let Some ( inner_ref) = outer_ref. get ( key) {
59
- let locked = inner_ref. read ( ) . unwrap ( ) ;
60
- return Some ( ( * locked) . clone ( ) ) ;
61
- }
62
- }
63
- None
57
+ let persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
58
+ persisted_lock. get ( namespace) . and_then ( |e| e. get ( key) . cloned ( ) )
64
59
}
65
60
66
61
pub fn get_and_clear_did_persist ( & self ) -> bool {
@@ -69,46 +64,45 @@ impl TestStore {
69
64
}
70
65
71
66
impl KVStore for TestStore {
72
- type Reader = TestReader ;
67
+ type Reader = io :: Cursor < Vec < u8 > > ;
73
68
74
- fn read ( & self , namespace : & str , key : & str ) -> std:: io:: Result < Self :: Reader > {
75
- if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
69
+ fn read ( & self , namespace : & str , key : & str ) -> io:: Result < Self :: Reader > {
70
+ let persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
71
+ if let Some ( outer_ref) = persisted_lock. get ( namespace) {
76
72
if let Some ( inner_ref) = outer_ref. get ( key) {
77
- Ok ( TestReader :: new ( Arc :: clone ( inner_ref) ) )
73
+ let bytes = inner_ref. clone ( ) ;
74
+ Ok ( io:: Cursor :: new ( bytes) )
78
75
} else {
79
- let msg = format ! ( "Key not found: {}" , key) ;
80
- Err ( std:: io:: Error :: new ( std:: io:: ErrorKind :: NotFound , msg) )
76
+ Err ( io:: Error :: new ( io:: ErrorKind :: NotFound , "Key not found" ) )
81
77
}
82
78
} else {
83
- let msg = format ! ( "Namespace not found: {}" , namespace) ;
84
- Err ( std:: io:: Error :: new ( std:: io:: ErrorKind :: NotFound , msg) )
79
+ Err ( io:: Error :: new ( io:: ErrorKind :: NotFound , "Namespace not found" ) )
85
80
}
86
81
}
87
82
88
- fn write ( & self , namespace : & str , key : & str , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
89
- let mut guard = self . persisted_bytes . write ( ) . unwrap ( ) ;
90
- let outer_e = guard. entry ( namespace. to_string ( ) ) . or_insert ( HashMap :: new ( ) ) ;
91
- let inner_e = outer_e. entry ( key. to_string ( ) ) . or_insert ( Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ) ;
92
-
93
- let mut guard = inner_e. write ( ) . unwrap ( ) ;
94
- guard. write_all ( buf) ?;
83
+ fn write ( & self , namespace : & str , key : & str , buf : & [ u8 ] ) -> io:: Result < ( ) > {
84
+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
85
+ let outer_e = persisted_lock. entry ( namespace. to_string ( ) ) . or_insert ( HashMap :: new ( ) ) ;
86
+ let mut bytes = Vec :: new ( ) ;
87
+ bytes. write_all ( buf) ?;
88
+ outer_e. insert ( key. to_string ( ) , bytes) ;
95
89
self . did_persist . store ( true , Ordering :: SeqCst ) ;
96
90
Ok ( ( ) )
97
91
}
98
92
99
- fn remove ( & self , namespace : & str , key : & str ) -> std:: io:: Result < ( ) > {
100
- match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
101
- hash_map:: Entry :: Occupied ( mut e) => {
102
- self . did_persist . store ( true , Ordering :: SeqCst ) ;
103
- e. get_mut ( ) . remove ( & key. to_string ( ) ) ;
104
- Ok ( ( ) )
105
- }
106
- hash_map:: Entry :: Vacant ( _) => Ok ( ( ) ) ,
93
+ fn remove ( & self , namespace : & str , key : & str ) -> io:: Result < ( ) > {
94
+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
95
+ if let Some ( outer_ref) = persisted_lock. get_mut ( namespace) {
96
+ outer_ref. remove ( & key. to_string ( ) ) ;
97
+ self . did_persist . store ( true , Ordering :: SeqCst ) ;
107
98
}
99
+
100
+ Ok ( ( ) )
108
101
}
109
102
110
- fn list ( & self , namespace : & str ) -> std:: io:: Result < Vec < String > > {
111
- match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
103
+ fn list ( & self , namespace : & str ) -> io:: Result < Vec < String > > {
104
+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
105
+ match persisted_lock. entry ( namespace. to_string ( ) ) {
112
106
hash_map:: Entry :: Occupied ( e) => Ok ( e. get ( ) . keys ( ) . cloned ( ) . collect ( ) ) ,
113
107
hash_map:: Entry :: Vacant ( _) => Ok ( Vec :: new ( ) ) ,
114
108
}
@@ -139,24 +133,6 @@ impl KVStorePersister for TestStore {
139
133
}
140
134
}
141
135
142
- pub struct TestReader {
143
- entry_ref : Arc < RwLock < Vec < u8 > > > ,
144
- }
145
-
146
- impl TestReader {
147
- pub fn new ( entry_ref : Arc < RwLock < Vec < u8 > > > ) -> Self {
148
- Self { entry_ref }
149
- }
150
- }
151
-
152
- impl Read for TestReader {
153
- fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
154
- let bytes = self . entry_ref . read ( ) . unwrap ( ) . clone ( ) ;
155
- let mut reader = Cursor :: new ( bytes) ;
156
- reader. read ( buf)
157
- }
158
- }
159
-
160
136
// Copied over from upstream LDK
161
137
#[ allow( dead_code) ]
162
138
pub struct TestLogger {
0 commit comments