from typing import List, Optional, Dict, Union

import botocore.exceptions
- import pydantic.version
from pydantic import BaseModel
from pydantic import Field

from recce import get_version
from recce.git import current_branch
from recce.models.types import Run, Check
+ from recce.pull_request import fetch_pr_metadata, PullRequestInfo
+ from recce.util.pydantic_model import pydantic_model_json_dump, pydantic_model_dump

logger = logging.getLogger('uvicorn')

@@ -40,26 +41,6 @@ def check_s3_bucket(bucket_name: str):
    return True, None


- def pydantic_model_json_dump(model: BaseModel):
-     pydantic_version = pydantic.version.VERSION
-     pydantic_major = pydantic_version.split(".")[0]
-
-     if pydantic_major == "1":
-         return model.json(exclude_none=True)
-     else:
-         return model.model_dump_json(exclude_none=True)
-
-
- def pydantic_model_dump(model: BaseModel):
-     pydantic_version = pydantic.version.VERSION
-     pydantic_major = pydantic_version.split(".")[0]
-
-     if pydantic_major == "1":
-         return model.dict()
-     else:
-         return model.model_dump()
-
-
class GitRepoInfo(BaseModel):
    branch: Optional[str] = None

@@ -75,18 +56,6 @@ def to_dict(self):
        return pydantic_model_dump(self)


- class PullRequestInfo(BaseModel):
-     id: Optional[Union[int, str]] = None
-     title: Optional[str] = None
-     url: Optional[str] = None
-     branch: Optional[str] = None
-     base_branch: Optional[str] = None
-     repository: Optional[str] = None
-
-     def to_dict(self):
-         return pydantic_model_dump(self)
-
-
class RecceStateMetadata(BaseModel):
    schema_version: str = 'v0'
    recce_version: str = Field(default_factory=lambda: get_version())
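Neither the dump helpers nor the PullRequestInfo model is actually gone: per the new imports in the first hunk, they now live in recce.util.pydantic_model and recce.pull_request. A minimal usage sketch under that assumption (field values are placeholders):

from recce.pull_request import PullRequestInfo
from recce.util.pydantic_model import pydantic_model_dump, pydantic_model_json_dump

# PullRequestInfo keeps the same optional fields as the class removed above; values are illustrative.
pr = PullRequestInfo(id=123, title='Example PR', repository='org/repo',
                     branch='feature/x', base_branch='main')

# The helpers choose .dict()/.json() on pydantic v1 and .model_dump()/.model_dump_json() on v2,
# matching the implementations removed in the earlier hunk.
as_dict = pydantic_model_dump(pr)
as_json = pydantic_model_json_dump(pr)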
@@ -167,6 +136,13 @@ def __init__(self,
        self.hint_message = None
        self.state: RecceState | None = None
        self.state_lock = threading.Lock()
+         self.pr_info = None
+
+         if self.cloud_mode:
+             if self.cloud_options.get('token'):
+                 self.pr_info = fetch_pr_metadata(github_token=self.cloud_options.get('token'))
+             else:
+                 raise Exception('No GitHub token is provided to access the pull request information.')

        # Load the state
        self.load()
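The constructor now resolves the pull-request metadata once and caches it on self.pr_info, so the cloud helpers below no longer call fetch_pr_metadata themselves. A small sketch of the same lookup outside the class; the cloud_options dict and its values are placeholders, only the 'token' key comes from the diff:

import os

from recce.pull_request import fetch_pr_metadata

# Placeholder options dict mirroring the 'token' key used in __init__.
cloud_options = {'token': os.environ.get('GITHUB_TOKEN')}

pr_info = fetch_pr_metadata(github_token=cloud_options.get('token'))
if pr_info.id is None or pr_info.repository is None:
    # The cloud load/export paths below raise in this case, so fail early here as well.
    raise Exception('Cannot get the pull request information from GitHub.')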
@@ -230,6 +206,44 @@ def refresh(self):
        new_state = self.load(refresh=True)
        return new_state

+     def info(self):
+         if self.state is None:
+             self.error_message = 'No state is loaded.'
+             return None
+
+         state_info = {
+             'mode': 'cloud' if self.cloud_mode else 'local',
+             'source': None,
+         }
+         if self.cloud_mode:
+             if self.cloud_options.get('host', '').startswith('s3://'):
+                 state_info['source'] = self.cloud_options.get('host')
+             else:
+                 state_info['source'] = 'Recce Cloud'
+             state_info['pull_request'] = self.pr_info
+         else:
+             state_info['source'] = self.state_file
+         return state_info
+
+     def purge(self) -> bool:
+         if self.cloud_mode is True:
+             # self.error_message = 'Purging the state is not supported in cloud mode.'
+             # return False
+             if self.cloud_options.get('host', '').startswith('s3://'):
+                 return self._purge_state_from_s3_bucket()
+             else:
+                 return self._purge_state_from_cloud()
+         else:
+             if self.state_file is not None:
+                 try:
+                     os.remove(self.state_file)
+                 except Exception as e:
+                     self.error_message = f'Failed to remove the state file: {e}'
+                     return False
+             else:
+                 self.error_message = 'No state file is provided. Skip removing the state file.'
+                 return False
+
    def _get_presigned_url(self, pr_info: PullRequestInfo, artifact_name: str, method: str = 'upload') -> str:
        import requests
        # Step 1: Get the token
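A hedged usage sketch for the two new public methods; 'loader' is a hypothetical, already-constructed instance of this state-loader class:

# 'loader' is hypothetical; construct it however this class is normally instantiated.
summary = loader.info()
if summary is not None:
    print(summary['mode'], summary['source'])      # 'cloud' or 'local', plus where the state lives
    if summary['mode'] == 'cloud':
        print(summary.get('pull_request'))         # the cached PullRequestInfo

if not loader.purge():
    print(loader.error_message)                    # set on the failure paths of purge()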
@@ -254,23 +268,21 @@ def _load_state_from_file(self, file_path: Optional[str] = None) -> RecceState:
        return RecceState.from_file(file_path) if file_path else None

    def _load_state_from_cloud(self) -> RecceState:
-         from recce.pull_request import fetch_pr_metadata
-         pr_info = fetch_pr_metadata(github_token=self.cloud_options.get('token'))
-         if (pr_info.id is None) or (pr_info.repository is None):
+         if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
            raise Exception('Cannot get the pull request information from GitHub.')

        if self.cloud_options.get('host', '').startswith('s3://'):
            logger.debug('Fetching state from AWS S3 bucket...')
-             return self._load_state_from_s3_bucket(pr_info)
+             return self._load_state_from_s3_bucket()
        else:
            logger.debug('Fetching state from Recce Cloud...')
-             return self._load_state_from_recce_cloud(pr_info)
+             return self._load_state_from_recce_cloud()

-     def _load_state_from_recce_cloud(self, pr_info) -> Union[RecceState, None]:
+     def _load_state_from_recce_cloud(self) -> Union[RecceState, None]:
        import tempfile
        import requests

-         presigned_url = self._get_presigned_url(pr_info, RECCE_STATE_COMPRESSED_FILE, method='download')
+         presigned_url = self._get_presigned_url(self.pr_info, RECCE_STATE_COMPRESSED_FILE, method='download')

        with tempfile.NamedTemporaryFile() as tmp:
            response = requests.get(presigned_url)
@@ -284,12 +296,12 @@ def _load_state_from_recce_cloud(self, pr_info) -> Union[RecceState, None]:
                f.write(response.content)
            return RecceState.from_file(tmp.name, compressed=True)

-     def _load_state_from_s3_bucket(self, pr_info) -> Union[RecceState, None]:
+     def _load_state_from_s3_bucket(self) -> Union[RecceState, None]:
        import boto3
        import tempfile
        s3_client = boto3.client('s3')
        s3_bucket_name = self.cloud_options.get('host').replace('s3://', '')
-         s3_bucket_key = f'github/{pr_info.repository}/pulls/{pr_info.id}/{RECCE_STATE_COMPRESSED_FILE}'
+         s3_bucket_key = f'github/{self.pr_info.repository}/pulls/{self.pr_info.id}/{RECCE_STATE_COMPRESSED_FILE}'

        rc, error_message = check_s3_bucket(s3_bucket_name)
        if rc is False:
@@ -308,23 +320,21 @@ def _load_state_from_s3_bucket(self, pr_info) -> Union[RecceState, None]:
            return RecceState.from_file(tmp.name, compressed=True)

    def _export_state_to_cloud(self) -> Union[str, None]:
-         from recce.pull_request import fetch_pr_metadata
-         pr_info = fetch_pr_metadata(github_token=self.cloud_options.get('token'))
-         if (pr_info.id is None) or (pr_info.repository is None):
+         if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
            raise Exception('Cannot get the pull request information from GitHub.')

        if self.cloud_options.get('host', '').startswith('s3://'):
            logger.info("Store recce state to AWS S3 bucket")
-             return self._export_state_to_s3_bucket(pr_info)
+             return self._export_state_to_s3_bucket()
        else:
            logger.info("Store recce state to Recce Cloud")
-             return self._export_state_to_recce_cloud(pr_info)
+             return self._export_state_to_recce_cloud()

-     def _export_state_to_recce_cloud(self, pr_info) -> Union[str, None]:
+     def _export_state_to_recce_cloud(self) -> Union[str, None]:
        import tempfile
        import requests

-         presigned_url = self._get_presigned_url(pr_info, RECCE_STATE_COMPRESSED_FILE, method='upload')
+         presigned_url = self._get_presigned_url(self.pr_info, RECCE_STATE_COMPRESSED_FILE, method='upload')
        with tempfile.NamedTemporaryFile() as tmp:
            self._export_state_to_file(tmp.name, compress=True)
            response = requests.put(presigned_url, data=open(tmp.name, 'rb').read())
@@ -333,12 +343,12 @@ def _export_state_to_recce_cloud(self, pr_info) -> Union[str, None]:
                return 'Failed to upload the state file to Recce Cloud.'
        return 'The state file is uploaded to Recce Cloud.'

-     def _export_state_to_s3_bucket(self, pr_info) -> Union[str, None]:
+     def _export_state_to_s3_bucket(self) -> Union[str, None]:
        import boto3
        import tempfile
        s3_client = boto3.client('s3')
        s3_bucket_name = self.cloud_options.get('host').replace('s3://', '')
-         s3_bucket_key = f'github/{pr_info.repository}/pulls/{pr_info.id}/{RECCE_STATE_COMPRESSED_FILE}'
+         s3_bucket_key = f'github/{self.pr_info.repository}/pulls/{self.pr_info.id}/{RECCE_STATE_COMPRESSED_FILE}'

        rc, error_message = check_s3_bucket(s3_bucket_name)
        if rc is False:
@@ -364,3 +374,40 @@ def _export_state_to_file(self, file_path: Optional[str] = None, compress: bool
        with open(file_path, 'w') as f:
            f.write(json_data)
        return f'The state file is stored at \'{file_path}\''
+
+     def _purge_state_from_cloud(self) -> bool:
+         import requests
+         logger.debug('Purging the state from Recce Cloud...')
+         token = self.cloud_options.get('token')
+         api_url = f'{RECCE_CLOUD_API_HOST}/api/v1/{self.pr_info.repository}/pulls/{self.pr_info.id}/artifacts'
+         headers = {
+             'Authorization': f'Bearer {token}'
+         }
+         response = requests.delete(api_url, headers=headers)
+         if response.status_code != 204:
+             self.error_message = response.text
+             return False
+         return True
+
+     def _purge_state_from_s3_bucket(self) -> bool:
+         import boto3
+         from rich.console import Console
+         console = Console()
+         delete_objects = []
+         logger.debug('Purging the state from AWS S3 bucket...')
+         s3_client = boto3.client('s3')
+         s3_bucket_name = self.cloud_options.get('host').replace('s3://', '')
+         s3_key_prefix = f'github/{self.pr_info.repository}/pulls/{self.pr_info.id}/'
+         list_response = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=s3_key_prefix)
+         if 'Contents' in list_response:
+             for obj in list_response['Contents']:
+                 key = obj['Key']
+                 delete_objects.append({'Key': key})
+                 console.print(f'[green]Deleted[/green]: {key}')
+         else:
+             return False
+
+         delete_response = s3_client.delete_objects(Bucket=s3_bucket_name, Delete={'Objects': delete_objects})
+         if 'Deleted' not in delete_response:
+             return False
+         return True
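Both S3 helpers key off the same per-pull-request prefix. As a reference, a small standalone sketch that only lists what a purge would delete; the bucket name and PR values are placeholders:

import boto3

# Placeholders; the real values come from cloud_options['host'] and self.pr_info.
bucket = 'my-recce-bucket'
prefix = 'github/org/repo/pulls/123/'   # github/{repository}/pulls/{id}/

s3 = boto3.client('s3')
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
for obj in response.get('Contents', []):
    print(obj['Key'])                   # objects _purge_state_from_s3_bucket() would delete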