4
4
from functools import partial
5
5
from itertools import count
6
6
from types import TracebackType
7
- from typing import Any
7
+ from typing import Any , TypeAlias
8
8
9
9
from github import Auth , Github
10
10
14
14
15
15
__all__ = ["GithubClient" ]
16
16
17
+ GraphQLQuery : TypeAlias = dict [str , Any ]
18
+ GraphQLError : TypeAlias = dict [str , Any ]
19
+ GraphQLResponseData : TypeAlias = dict [str , Any ]
20
+ GraphQLResponse : TypeAlias = tuple [GraphQLResponseData , list [GraphQLError ]]
21
+
22
+
23
+ class GithubClientError (Exception ):
24
+ """An error occurred with the Github Client."""
25
+
26
+
27
+ class GithubRequestError (GithubClientError ):
28
+ """An error occurred while making a request with the Github Client."""
29
+
30
+
31
+ class GithubGraphQLRequestError (GithubRequestError ):
32
+ """An error occurred while making a GraphQL request with the Github Client."""
33
+
17
34
18
35
class GithubClient :
19
36
DEFAULT_DAYS_OF_HISTORY : int = 30
@@ -70,6 +87,35 @@ async def close(self) -> None:
70
87
if self ._session is not None :
71
88
await self ._session .close ()
72
89
90
+ @property
91
+ def session (self ) -> ClientSession :
92
+ if self ._session is None :
93
+ raise GithubClientError ("Session not initialized" )
94
+ return self ._session
95
+
96
+ @staticmethod
97
+ def _parse_datetime (datetime_str : str ) -> datetime :
98
+ return datetime .strptime (datetime_str , "%Y-%m-%dT%H:%M:%SZ" )
99
+
100
+ @staticmethod
101
+ def safe_get (data : dict [str , Any ], * path : str , default : Any = None , raise_on_missing : bool = False ) -> Any :
102
+ current : Any = data
103
+ for key in path :
104
+ if not isinstance (current , dict ) or current is None :
105
+ if raise_on_missing :
106
+ raise KeyError (f"Key { '.' .join (path )} not found in data" )
107
+ return default
108
+ current = current .get (key )
109
+ return current
110
+
111
+ @staticmethod
112
+ def build_repo_id (organization : str , repo : str ) -> str :
113
+ return f"{ organization } /{ repo } "
114
+
115
+ @staticmethod
116
+ def is_error_not_found_organization (data : dict [str , Any ]) -> bool :
117
+ return data .get ("type" ) == "NOT_FOUND" and data .get ("path" ) == ["organization" ]
118
+
73
119
@property
74
120
def days_of_history (self ) -> int :
75
121
if self ._days_of_history is None :
@@ -116,9 +162,6 @@ def page_size(self) -> int:
116
162
def ignore_old_repos (self ) -> bool :
117
163
return to_bool (self .config .integration .properties .get ("ignoreOld" , True ))
118
164
119
- def build_repo_id (self , organization : str , repo : str ) -> str :
120
- return f"{ organization } /{ repo } "
121
-
122
165
async def _make_request (
123
166
self ,
124
167
method : str ,
@@ -131,7 +174,7 @@ async def _make_request(
131
174
) -> Any :
132
175
try :
133
176
return await make_request (
134
- self ._session ,
177
+ self .session ,
135
178
method ,
136
179
url ,
137
180
** kwargs ,
@@ -142,24 +185,32 @@ async def _make_request(
142
185
)
143
186
except RequestError as e :
144
187
if raise_on_error :
145
- raise
188
+ raise GithubRequestError ( str ( e )) from e
146
189
self .logger .error (f"Error while making request, defaulting to empty response. ({ e } )" )
147
190
return None
148
191
149
- async def _make_graphql_request (self , query : dict [str , Any ], * , ignore_file_not_found_errors : bool = True ) -> Any :
192
+ async def _make_graphql_request (
193
+ self , query : GraphQLQuery , * , raise_on_organization_not_found : bool = True
194
+ ) -> GraphQLResponse :
150
195
response = await self ._make_request ("POST" , f"{ self .base_url } /graphql" , json = query , raise_on_error = True )
151
- if response .get ("errors" ):
152
- if not ignore_file_not_found_errors :
153
- errors = response ["errors" ]
154
- else :
155
- errors = [e for e in response ["errors" ] if isinstance (e , dict ) and e .get ("type" ) != "NOT_FOUND" ]
156
196
157
- if errors :
158
- self .logger .warning ("GraphQL error: %r" , errors )
159
- return response
197
+ if response is None :
198
+ raise GithubGraphQLRequestError ("Failed to make GraphQL request" )
160
199
161
- def _parse_datetime (self , datetime_str : str ) -> datetime :
162
- return datetime .strptime (datetime_str , "%Y-%m-%dT%H:%M:%SZ" )
200
+ if not isinstance (response , dict ):
201
+ raise GithubGraphQLRequestError ("Failed to parse GraphQL response" )
202
+
203
+ data = response .get ("data" )
204
+ if data is None :
205
+ raise GithubGraphQLRequestError ("Failed to parse GraphQL data" )
206
+
207
+ errors = response .get ("errors" ) or []
208
+ if errors :
209
+ if raise_on_organization_not_found and any (self .is_error_not_found_organization (error ) for error in errors ):
210
+ raise GithubGraphQLRequestError ("Organization not found" )
211
+ self .logger .warning ("GraphQL error: %r" , errors )
212
+
213
+ return data , errors
163
214
164
215
async def get_repos (
165
216
self , * , limit : int | None = None , ignore_archived : bool = True , ignore_old : bool = True , page_size : int = 50
@@ -190,8 +241,11 @@ async def get_repos(
190
241
191
242
async def get_repo (self , organization : str , repo : str ) -> dict [str , Any ]:
192
243
query = build_graphql_query (query_type = QueryType .REPOSITORY , repo_id = self .build_repo_id (organization , repo ))
193
- response = await self ._make_graphql_request (query )
194
- return response ["data" ]["repository" ]
244
+ data , errors = await self ._make_graphql_request (query )
245
+ repository = self .safe_get (data , "repository" )
246
+ if repository is None :
247
+ raise GithubClientError (f"Failed to get repository { organization } /{ repo } : { errors } " )
248
+ return repository
195
249
196
250
async def get_pull_requests (self , organization : str , repo : str , states : list [str ]) -> list [dict [str , Any ]]:
197
251
all_pull_requests = []
@@ -205,17 +259,17 @@ async def get_pull_requests(self, organization: str, repo: str, states: list[str
205
259
after = cursor ,
206
260
page_size = self .page_size ,
207
261
)
208
- response = await self ._make_graphql_request (query )
262
+ data , _ = await self ._make_graphql_request (query )
209
263
210
- edges = response [ " data" ][ " repository"][ "pullRequests" ][ "edges" ]
264
+ edges = self . safe_get ( data , " repository", "pullRequests" , "edges" )
211
265
if not edges :
212
266
break
213
- all_pull_requests .extend (edge ["node" ] for edge in edges )
214
267
215
- page_info = response ["data" ]["repository" ]["pullRequests" ]["pageInfo" ]
268
+ all_pull_requests .extend (edge ["node" ] for edge in edges )
269
+ page_info = self .safe_get (data , "repository" , "pullRequests" , "pageInfo" )
216
270
if (
217
271
not page_info ["hasNextPage" ]
218
- or self ._parse_datetime (edges [- 1 ][ "node" ][ "createdAt" ] ) < self .history_limit_timestamp
272
+ or self ._parse_datetime (self . safe_get ( edges [- 1 ], "node" , "createdAt" ) ) < self .history_limit_timestamp
219
273
):
220
274
break
221
275
@@ -235,19 +289,20 @@ async def get_issues(self, organization: str, repo: str, state: str) -> list[dic
235
289
state = state ,
236
290
page_size = self .page_size ,
237
291
)
238
- response = await self ._make_graphql_request (query )
292
+ data , _ = await self ._make_graphql_request (query )
239
293
240
- edges = response [ " data" ][ " search"][ "edges" ]
294
+ edges = self . safe_get ( data , " search", "edges" )
241
295
if not edges :
242
296
break
243
- all_issues .extend (edge ["node" ] for edge in edges )
244
297
245
- page_info = response ["data" ]["search" ]["pageInfo" ]
298
+ all_issues .extend (edge ["node" ] for edge in edges )
299
+ page_info = self .safe_get (data , "search" , "pageInfo" )
246
300
if (
247
301
not page_info ["hasNextPage" ]
248
- or self ._parse_datetime (edges [- 1 ][ "node" ][ "createdAt" ] ) < self .history_limit_timestamp
302
+ or self ._parse_datetime (self . safe_get ( edges [- 1 ], "node" , "createdAt" ) ) < self .history_limit_timestamp
249
303
):
250
304
break
305
+
251
306
cursor = page_info ["endCursor" ]
252
307
253
308
return all_issues
@@ -315,14 +370,14 @@ async def get_members(self, organization: str) -> list[dict[str, Any]]:
315
370
query = build_graphql_query (
316
371
query_type = QueryType .MEMBERS , owner = organization , after = cursor , page_size = self .page_size
317
372
)
318
- response = await self ._make_graphql_request (query )
373
+ data , _ = await self ._make_graphql_request (query )
319
374
320
- edges = response [ " data" ][ " organization"][ "membersWithRole" ][ "edges" ]
375
+ edges = self . safe_get ( data , " organization", "membersWithRole" , "edges" )
321
376
if not edges :
322
377
break
323
378
all_members .extend (edges )
324
379
325
- page_info = response [ " data" ][ " organization"][ "membersWithRole" ][ "pageInfo" ]
380
+ page_info = self . safe_get ( data , " organization", "membersWithRole" , "pageInfo" )
326
381
if not page_info ["hasNextPage" ]:
327
382
break
328
383
cursor = page_info ["endCursor" ]
@@ -336,8 +391,8 @@ async def get_teams(self, organization: str) -> list[dict]:
336
391
query = build_graphql_query (
337
392
query_type = QueryType .TEAMS , owner = organization , after = cursor , page_size = self .page_size
338
393
)
339
- response = await self ._make_graphql_request (query )
340
- teams = response . get ( " data" , {}). get ( "organization" , {}). get ( "teams" , {}). get ( "nodes" , [] )
394
+ data , _ = await self ._make_graphql_request (query )
395
+ teams = self . safe_get ( data , "organization" , "teams" , "nodes" )
341
396
if not teams :
342
397
break
343
398
@@ -350,9 +405,10 @@ async def get_teams(self, organization: str) -> list[dict]:
350
405
351
406
all_teams .extend (teams )
352
407
353
- page_info = response [ " data" ][ " organization"][ "teams" ][ "pageInfo" ]
408
+ page_info = self . safe_get ( data , " organization", "teams" , "pageInfo" )
354
409
if not page_info ["hasNextPage" ]:
355
410
break
411
+
356
412
cursor = page_info ["endCursor" ]
357
413
358
414
return all_teams
@@ -368,17 +424,17 @@ async def get_team_members(self, organization: str, team_id: str) -> list[dict]:
368
424
after = cursor ,
369
425
page_size = self .page_size ,
370
426
)
371
- response = await self ._make_graphql_request (query )
427
+ data , _ = await self ._make_graphql_request (query )
372
428
373
- team = response [ " data" ][ " organization"][ "team" ]
429
+ team = self . safe_get ( data , " organization", "team" )
374
430
if not team :
375
431
break
376
432
377
- all_members .extend (team .get ("members" , {}).get ("nodes" , []))
378
-
379
- page_info = team ["members" ]["pageInfo" ]
433
+ all_members .extend (self .safe_get (team , "members" , "nodes" ))
434
+ page_info = self .safe_get (team , "members" , "pageInfo" )
380
435
if not page_info ["hasNextPage" ]:
381
436
break
437
+
382
438
cursor = page_info ["endCursor" ]
383
439
384
440
return all_members
@@ -394,17 +450,17 @@ async def get_team_repositories(self, organization: str, team_id: str) -> list[d
394
450
after = cursor ,
395
451
page_size = self .page_size ,
396
452
)
397
- response = await self ._make_graphql_request (query )
453
+ data , _ = await self ._make_graphql_request (query )
398
454
399
- team = response [ " data" ][ " organization"][ "team" ]
455
+ team = self . safe_get ( data , " organization", "team" )
400
456
if not team :
401
457
break
402
458
403
- all_repositories .extend (team .get ("repositories" , {}).get ("nodes" , []))
404
-
405
- page_info = team ["repositories" ]["pageInfo" ]
459
+ all_repositories .extend (self .safe_get (team , "repositories" , "nodes" ))
460
+ page_info = self .safe_get (team , "repositories" , "pageInfo" )
406
461
if not page_info ["hasNextPage" ]:
407
462
break
463
+
408
464
cursor = page_info ["endCursor" ]
409
465
410
466
return all_repositories
@@ -438,16 +494,17 @@ async def get_deployments(
438
494
after = cursor ,
439
495
page_size = self .page_size ,
440
496
)
441
- response = await self ._make_graphql_request (query )
497
+ data , _ = await self ._make_graphql_request (query )
442
498
443
- edges = response [ " data" ][ " repository"][ "deployments" ][ "edges" ]
499
+ edges = self . safe_get ( data , " repository", "deployments" , "edges" )
444
500
if not edges :
445
501
break
446
- all_deployments .extend (edge ["node" ] for edge in edges )
447
502
448
- page_info = response ["data" ]["repository" ]["deployments" ]["pageInfo" ]
503
+ all_deployments .extend (edge ["node" ] for edge in edges )
504
+ page_info = self .safe_get (data , "repository" , "deployments" , "pageInfo" )
449
505
if not page_info ["hasNextPage" ]:
450
506
break
507
+
451
508
cursor = page_info ["endCursor" ]
452
509
453
510
return all_deployments
@@ -473,10 +530,13 @@ async def get_commits(
473
530
cursor = None
474
531
while True :
475
532
query = query_builder (after = cursor )
476
- response = await self ._make_graphql_request (query )
533
+ data , _ = await self ._make_graphql_request (query )
477
534
478
535
try :
479
- data = response ["data" ]["repository" ]["commits" ]["history" ] or {}
536
+ data = self .safe_get (data , "repository" , "commits" , "history" ) or {}
537
+ if not data :
538
+ break
539
+
480
540
commits = data ["nodes" ] or []
481
541
all_commits .extend (
482
542
[commit for commit in commits if not exclude_merge_commits or not self ._is_merge_commit (commit )]
@@ -492,5 +552,6 @@ async def get_commits(
492
552
493
553
return all_commits
494
554
495
- def _is_merge_commit (self , commit : dict ) -> bool :
496
- return ((commit .get ("parents" ) or {}).get ("totalCount" ) or 0 ) > 1
555
+ @classmethod
556
+ def _is_merge_commit (cls , commit : dict ) -> bool :
557
+ return (cls .safe_get (commit , "parents" , "totalCount" ) or 0 ) > 1
0 commit comments