
Commit 77a41eb

fix: fix lint errors
Parent: f80c4de

8 files changed (+41 / −30 lines)

generate.py

Lines changed: 5 additions & 1 deletion

@@ -78,5 +78,9 @@
 graph_gen.judge(re_judge=False)

 graph_gen.traverse()
-with open(os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml"), "w", encoding='utf-8') as f:
+
+config_path = os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml")
+if not os.path.exists(os.path.dirname(config_path)):
+    os.makedirs(os.path.dirname(config_path))
+with open(config_path, "w", encoding='utf-8') as f:
     yaml.dump(traverse_strategy.to_yaml(), f)
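
The intent of the new lines is to guarantee that the cache/configs directory exists before the YAML file is written (note that makedirs must target the parent directory, not the file path itself, or the subsequent open() would fail). A minimal standalone sketch of the same pattern; dump_config and the example path are illustrative, not part of the commit:

    import os
    import yaml  # PyYAML

    def dump_config(config: dict, config_path: str) -> None:
        # Create the parent directory, not the file path itself; exist_ok=True
        # also avoids the exists-check/makedirs race between concurrent runs.
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    dump_config({"example_key": "example_value"}, "cache/configs/graphgen_example.yaml")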

models/llm/openai_model.py

Lines changed: 2 additions & 2 deletions

@@ -71,8 +71,8 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token

-        # Limit max_tokens to 2 to avoid long completions
-        kwargs["max_tokens"] = 2
+        # Limit max_tokens to 1 to avoid long completions
+        kwargs["max_tokens"] = 1

         completion = await self.client.chat.completions.create(
             model=self.model_name,
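
The cap matters because this method only needs the top-k candidate log-probabilities for the next token, not a full completion, and generating a single token is the cheapest request that still returns them. A sketch of the request shape against the openai>=1.x async client; the function name and model argument are placeholders:

    from openai import AsyncOpenAI

    async def next_token_logprobs(client: AsyncOpenAI, model: str, text: str, k: int = 5):
        # logprobs=True enables per-token log-probabilities; top_logprobs=k asks
        # for the k most likely alternatives at each position; max_tokens=1
        # keeps the completion to a single token.
        completion = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": text}],
            logprobs=True,
            top_logprobs=k,
            max_tokens=1,
        )
        return completion.choices[0].logprobs.content[0].top_logprobs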

templates/kg_extraction.py

Lines changed: 8 additions & 3 deletions

@@ -1,3 +1,5 @@
+# pylint: disable=C0301
+
 TEMPLATE_EN: str = """You are an NLP expert, skilled at analyzing text to extract named entities and their relationships.

 -Goal-
@@ -174,12 +176,14 @@
 输出:
 """

-CONTINUE_EN: str = """MANY entities and relationships were missed in the last extraction. Add them below using the same format:
+CONTINUE_EN: str = """MANY entities and relationships were missed in the last extraction. \
+Add them below using the same format:
 """

 CONTINUE_ZH: str = """很多实体和关系在上一次的提取中可能被遗漏了。请在下面使用相同的格式添加它们:"""

-IF_LOOP_EN: str = """It appears some entities and relationships may have still been missed. Answer YES | NO if there are still entities and relationships that need to be added.
+IF_LOOP_EN: str = """It appears some entities and relationships may have still been missed. \
+Answer YES | NO if there are still entities and relationships that need to be added.
 """

 IF_LOOP_ZH: str = """看起来可能仍然遗漏了一些实体和关系。如果仍有实体和关系需要添加,请回答YES | NO。"""
@@ -199,7 +203,8 @@
         "tuple_delimiter": "<|>",
         "record_delimiter": "##",
         "completion_delimiter": "<|COMPLETE|>",
-        "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene",
+        "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, \
+science, technology, mission, gene",
         "language": "English",
     },
 }
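
Two things happen here: C0301 (line-too-long) is disabled for the whole module, and the over-long prompt literals are wrapped with backslash continuations. Inside a string literal, a backslash before the newline deletes the line break, so the wrapped prompts stay byte-for-byte identical to the originals. A quick illustration:

    WRAPPED: str = """first half of a long prompt \
    second half on the next source line"""

    # The continuation leaves no newline behind.
    assert "\n" not in WRAPPED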

templates/search_judgement.py

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+# pylint: disable=C0301
+
 TEMPLATE: str = """-Goal-
 Please select the most relevant search result for the given entity.
 The name and description of the entity are provided. The search results are provided as a list.
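
Same fix as in kg_extraction.py: a module-level disable exempts every line in the file from the length check. When only one literal is too long, a narrower, line-scoped disable is also possible; the snippet below is illustrative, not from the commit:

    # Module-wide, as in this commit: applies to the whole file.
    # pylint: disable=C0301

    # Line-scoped alternative: applies to this statement only.
    PROMPT = "one deliberately long template string ..."  # pylint: disable=line-too-long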

utils/detect_lang.py

Lines changed: 1 addition & 2 deletions

@@ -26,8 +26,7 @@ def is_english_char(char):

     if chinese_ratio >= 0.5:
         return 'zh'
-    else:
-        return 'en'
+    return 'en'

 def detect_if_chinese(text):
     """

utils/format.py

Lines changed: 12 additions & 12 deletions

@@ -43,12 +43,12 @@ async def handle_single_entity_extraction(
     entity_type = clean_str(record_attributes[2].upper())
     entity_description = clean_str(record_attributes[3])
     entity_source_id = chunk_key
-    return dict(
-        entity_name=entity_name,
-        entity_type=entity_type,
-        description=entity_description,
-        source_id=entity_source_id,
-    )
+    return {
+        "entity_name": entity_name,
+        "entity_type": entity_type,
+        "description": entity_description,
+        "source_id": entity_source_id,
+    }

 def is_float_regex(value):
     return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
@@ -65,12 +65,12 @@ async def handle_single_relationship_extraction(
     edge_description = clean_str(record_attributes[3])

     edge_source_id = chunk_key
-    return dict(
-        src_id=source,
-        tgt_id=target,
-        description=edge_description,
-        source_id=edge_source_id,
-    )
+    return {
+        "src_id": source,
+        "tgt_id": target,
+        "description": edge_description,
+        "source_id": edge_source_id,
+    }

 def load_json(file_name):
     if not os.path.exists(file_name):
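
Both hunks swap dict(key=value, ...) calls for dict literals, which recent pylint releases flag as R1735 (use-dict-literal). Besides silencing the warning, the literal skips a global name lookup and a function call, so it is also marginally faster. A rough way to see this; exact timings will vary by machine:

    from timeit import timeit

    call_form = timeit("dict(a=1, b=2)", number=1_000_000)
    literal_form = timeit("{'a': 1, 'b': 2}", number=1_000_000)
    print(f"dict() call: {call_form:.3f}s   literal: {literal_form:.3f}s")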

utils/help_nltk.py

Lines changed: 7 additions & 8 deletions

@@ -1,7 +1,7 @@
 import os
+from typing import Dict, List, Optional
 import nltk
 import jieba
-from typing import Dict, List, Optional

 resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources")

@@ -30,11 +30,10 @@ def get_stopwords(self, lang: str) -> List[str]:
 def word_tokenize(text: str, lang: str) -> List[str]:
     if lang == "zh":
         return jieba.lcut(text)
-    else:
-        nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
-        try:
-            nltk.data.find("tokenizers/punkt_tab")
-        except LookupError:
-            nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))
+    nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
+    try:
+        nltk.data.find("tokenizers/punkt_tab")
+    except LookupError:
+        nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))

-        return nltk.word_tokenize(text)
+    return nltk.word_tokenize(text)
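
Two lint fixes in one file: the typing import moves into the standard-library group ahead of the third-party imports (pylint C0411, wrong-import-order), and the else after jieba's early return is flattened (R1705), which also de-indents the NLTK fallback path. The conventional grouping looks like this:

    # Standard library first, then third-party, each group separated by a
    # blank line (PEP 8 ordering, enforced by pylint C0411).
    import os
    from typing import Dict, List, Optional

    import jieba
    import nltk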

utils/log.py

Lines changed: 4 additions & 2 deletions

@@ -13,18 +13,20 @@ def set_logger(log_file: str, log_level: int = logging.INFO, if_stream: bool = T
     file_handler.setLevel(log_level)
     file_handler.setFormatter(formatter)

+    stream_handler = None
+
     if if_stream:
         stream_handler = logging.StreamHandler()
         stream_handler.setLevel(log_level)
         stream_handler.setFormatter(formatter)

     if not logger.handlers:
         logger.addHandler(file_handler)
-        if if_stream:
+        if if_stream and stream_handler:
             logger.addHandler(stream_handler)


 def parse_log(log_file: str):
-    with open(log_file, "r") as f:
+    with open(log_file, "r", encoding='utf-8') as f:
         lines = f.readlines()
     return lines
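
Binding stream_handler = None up front means the name is defined on every code path, so the later reference no longer trips pylint's possibly-used-before-assignment; the explicit encoding addresses W1514 (unspecified-encoding), where open() would otherwise fall back to the platform's locale default. A self-contained sketch of the same guard; build_handlers is illustrative, not the module's API:

    import logging

    def build_handlers(log_file: str, if_stream: bool = True) -> list:
        # Define the name unconditionally so every later read is safe.
        stream_handler = None
        if if_stream:
            stream_handler = logging.StreamHandler()

        handlers: list = [logging.FileHandler(log_file, encoding="utf-8")]
        if stream_handler is not None:
            handlers.append(stream_handler)
        return handlers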
