1
1
import csv
2
2
from datetime import datetime
3
- from typing import List , Tuple
3
+ from typing import List , Optional
4
4
5
5
from vnpy .trader .engine import BaseEngine , MainEngine , EventEngine
6
6
from vnpy .trader .constant import Interval , Exchange
7
- from vnpy .trader .object import BarData , HistoryRequest
7
+ from vnpy .trader .object import BarData , TickData , ContractData , HistoryRequest
8
8
from vnpy .trader .database import BaseDatabase , get_database , BarOverview , DB_TZ
9
9
from vnpy .trader .datafeed import BaseDatafeed , get_datafeed
10
10
from vnpy .trader .utility import ZoneInfo
@@ -19,7 +19,7 @@ def __init__(
19
19
self ,
20
20
main_engine : MainEngine ,
21
21
event_engine : EventEngine ,
22
- ):
22
+ ) -> None :
23
23
""""""
24
24
super ().__init__ (main_engine , event_engine , APP_NAME )
25
25
@@ -42,29 +42,29 @@ def import_data_from_csv(
42
42
turnover_head : str ,
43
43
open_interest_head : str ,
44
44
datetime_format : str
45
- ) -> Tuple :
45
+ ) -> tuple :
46
46
""""""
47
47
with open (file_path , "rt" ) as f :
48
- buf = [line .replace ("\0 " , "" ) for line in f ]
48
+ buf : list = [line .replace ("\0 " , "" ) for line in f ]
49
49
50
- reader = csv .DictReader (buf , delimiter = "," )
50
+ reader : csv . DictReader = csv .DictReader (buf , delimiter = "," )
51
51
52
- bars = []
53
- start = None
54
- count = 0
52
+ bars : List [ BarData ] = []
53
+ start : datetime = None
54
+ count : int = 0
55
55
tz = ZoneInfo (tz_name )
56
56
57
57
for item in reader :
58
58
if datetime_format :
59
- dt = datetime .strptime (item [datetime_head ], datetime_format )
59
+ dt : datetime = datetime .strptime (item [datetime_head ], datetime_format )
60
60
else :
61
- dt = datetime .fromisoformat (item [datetime_head ])
61
+ dt : datetime = datetime .fromisoformat (item [datetime_head ])
62
62
dt = dt .replace (tzinfo = tz )
63
63
64
64
turnover = item .get (turnover_head , 0 )
65
65
open_interest = item .get (open_interest_head , 0 )
66
66
67
- bar = BarData (
67
+ bar : BarData = BarData (
68
68
symbol = symbol ,
69
69
exchange = exchange ,
70
70
datetime = dt ,
@@ -86,7 +86,7 @@ def import_data_from_csv(
86
86
if not start :
87
87
start = bar .datetime
88
88
89
- end = bar .datetime
89
+ end : datetime = bar .datetime
90
90
91
91
# insert into database
92
92
self .database .save_bar_data (bars )
@@ -103,9 +103,9 @@ def output_data_to_csv(
103
103
end : datetime
104
104
) -> bool :
105
105
""""""
106
- bars = self .load_bar_data (symbol , exchange , interval , start , end )
106
+ bars : List [ BarData ] = self .load_bar_data (symbol , exchange , interval , start , end )
107
107
108
- fieldnames = [
108
+ fieldnames : list = [
109
109
"symbol" ,
110
110
"exchange" ,
111
111
"datetime" ,
@@ -120,11 +120,11 @@ def output_data_to_csv(
120
120
121
121
try :
122
122
with open (file_path , "w" ) as f :
123
- writer = csv .DictWriter (f , fieldnames = fieldnames , lineterminator = "\n " )
123
+ writer : csv . DictWriter = csv .DictWriter (f , fieldnames = fieldnames , lineterminator = "\n " )
124
124
writer .writeheader ()
125
125
126
126
for bar in bars :
127
- d = {
127
+ d : dict = {
128
128
"symbol" : bar .symbol ,
129
129
"exchange" : bar .exchange .value ,
130
130
"datetime" : bar .datetime .strftime ("%Y-%m-%d %H:%M:%S" ),
@@ -155,7 +155,7 @@ def load_bar_data(
155
155
end : datetime
156
156
) -> List [BarData ]:
157
157
""""""
158
- bars = self .database .load_bar_data (
158
+ bars : List [ BarData ] = self .database .load_bar_data (
159
159
symbol ,
160
160
exchange ,
161
161
interval ,
@@ -172,7 +172,7 @@ def delete_bar_data(
172
172
interval : Interval
173
173
) -> int :
174
174
""""""
175
- count = self .database .delete_bar_data (
175
+ count : int = self .database .delete_bar_data (
176
176
symbol ,
177
177
exchange ,
178
178
interval
@@ -190,25 +190,25 @@ def download_bar_data(
190
190
"""
191
191
Query bar data from datafeed.
192
192
"""
193
- req = HistoryRequest (
193
+ req : HistoryRequest = HistoryRequest (
194
194
symbol = symbol ,
195
195
exchange = exchange ,
196
196
interval = Interval (interval ),
197
197
start = start ,
198
198
end = datetime .now (DB_TZ )
199
199
)
200
200
201
- vt_symbol = f"{ symbol } .{ exchange .value } "
202
- contract = self .main_engine .get_contract (vt_symbol )
201
+ vt_symbol : str = f"{ symbol } .{ exchange .value } "
202
+ contract : Optional [ ContractData ] = self .main_engine .get_contract (vt_symbol )
203
203
204
204
# If history data provided in gateway, then query
205
205
if contract and contract .history_data :
206
- data = self .main_engine .query_history (
206
+ data : List [ BarData ] = self .main_engine .query_history (
207
207
req , contract .gateway_name
208
208
)
209
209
# Otherwise use datafeed to query data
210
210
else :
211
- data = self .datafeed .query_bar_history (req )
211
+ data : List [ BarData ] = self .datafeed .query_bar_history (req )
212
212
213
213
if data :
214
214
self .database .save_bar_data (data )
@@ -225,14 +225,14 @@ def download_tick_data(
225
225
"""
226
226
Query tick data from datafeed.
227
227
"""
228
- req = HistoryRequest (
228
+ req : HistoryRequest = HistoryRequest (
229
229
symbol = symbol ,
230
230
exchange = exchange ,
231
231
start = start ,
232
232
end = datetime .now (DB_TZ )
233
233
)
234
234
235
- data = self .datafeed .query_tick_history (req )
235
+ data : List [ TickData ] = self .datafeed .query_tick_history (req )
236
236
237
237
if data :
238
238
self .database .save_tick_data (data )
0 commit comments