
Commit 97cdde3

Merge pull request #68 from open-sciencelab/copilot/change-code-style-dataclass
Remove incorrect @dataclass decorators from ABC base classes and all subclasses
2 parents 862e1d4 + b252226 commit 97cdde3

29 files changed: +71 −106 lines
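
The motivation, in brief: @dataclass generates an __init__ from class-level field annotations, and on an ABC whose subclasses need their own constructors or non-defaulted fields, the generated constructor gets in the way. Below is a minimal, self-contained sketch of one common failure mode and the explicit-__init__ replacement; it is illustrative only and not code from this repository.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class DataclassBase(ABC):
    name: str = "base"  # defaulted field in the generated __init__

    @abstractmethod
    def run(self) -> str: ...


# Decorating a subclass that adds a non-defaulted field fails at class
# definition time, because the generated __init__ would place 'client'
# after the inherited defaulted 'name':
#
# @dataclass
# class BrokenChild(DataclassBase):
#     client: object
# raises TypeError ("non-default argument 'client' follows default argument")


class PlainBase(ABC):
    """Explicit __init__: subclasses stay in control of their own constructors."""

    def __init__(self, name: str = "base"):
        self.name = name

    @abstractmethod
    def run(self) -> str: ...


class Child(PlainBase):
    def __init__(self, client: object, name: str = "child"):
        super().__init__(name)
        self.client = client

    def run(self) -> str:
        return f"{self.name} using {self.client!r}"


print(Child(client="llm").run())  # -> child using 'llm'
```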

graphgen/bases/base_generator.py

Lines changed: 2 additions & 3 deletions
@@ -1,17 +1,16 @@
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any
 
 from graphgen.bases.base_llm_client import BaseLLMClient
 
 
-@dataclass
 class BaseGenerator(ABC):
     """
     Generate QAs based on given prompts.
     """
 
-    llm_client: BaseLLMClient
+    def __init__(self, llm_client: BaseLLMClient):
+        self.llm_client = llm_client
 
     @staticmethod
     @abstractmethod

graphgen/bases/base_kg_builder.py

Lines changed: 4 additions & 8 deletions
@@ -1,21 +1,17 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import dataclass, field
 from typing import Dict, List, Tuple
 
 from graphgen.bases.base_llm_client import BaseLLMClient
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk
 
 
-@dataclass
 class BaseKGBuilder(ABC):
-    llm_client: BaseLLMClient
-
-    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
-    _edges: Dict[Tuple[str, str], List[dict]] = field(
-        default_factory=lambda: defaultdict(list)
-    )
+    def __init__(self, llm_client: BaseLLMClient):
+        self.llm_client = llm_client
+        self._nodes: Dict[str, List[dict]] = defaultdict(list)
+        self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
 
     @abstractmethod
     async def extract(
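
A side effect of dropping the decorator is that mutable per-instance state no longer needs field(default_factory=...). A small sketch of why (generic Python behaviour, not repository code): a bare mutable class-level default is rejected by @dataclass outright, whereas an explicit __init__ simply builds a fresh container per instance.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List

# With @dataclass, a bare mutable default is rejected at class definition:
#
# @dataclass
# class Broken:
#     _nodes: Dict[str, List[dict]] = defaultdict(list)
# raises ValueError ("mutable default ... is not allowed: use default_factory")


@dataclass
class DataclassBuilder:
    # the workaround the old code needed
    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))


class PlainBuilder:
    def __init__(self) -> None:
        # explicit __init__: each instance gets its own fresh container
        self._nodes: Dict[str, List[dict]] = defaultdict(list)


a, b = PlainBuilder(), PlainBuilder()
a._nodes["n1"].append({"id": "n1"})
print(len(a._nodes), len(b._nodes))  # -> 1 0  (state is not shared between instances)
```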

graphgen/bases/base_partitioner.py

Lines changed: 0 additions & 2 deletions
@@ -1,12 +1,10 @@
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, List
 
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Community
 
 
-@dataclass
 class BasePartitioner(ABC):
     @abstractmethod
     async def partition(

graphgen/bases/base_splitter.py

Lines changed: 15 additions & 8 deletions
@@ -1,25 +1,32 @@
 import copy
 import re
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Callable, Iterable, List, Literal, Optional, Union
 
 from graphgen.bases.datatypes import Chunk
 from graphgen.utils import logger
 
 
-@dataclass
 class BaseSplitter(ABC):
     """
     Abstract base class for splitting text into smaller chunks.
     """
 
-    chunk_size: int = 1024
-    chunk_overlap: int = 100
-    length_function: Callable[[str], int] = len
-    keep_separator: bool = False
-    add_start_index: bool = False
-    strip_whitespace: bool = True
+    def __init__(
+        self,
+        chunk_size: int = 1024,
+        chunk_overlap: int = 100,
+        length_function: Callable[[str], int] = len,
+        keep_separator: bool = False,
+        add_start_index: bool = False,
+        strip_whitespace: bool = True,
+    ):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.length_function = length_function
+        self.keep_separator = keep_separator
+        self.add_start_index = add_start_index
+        self.strip_whitespace = strip_whitespace
 
     @abstractmethod
     def split_text(self, text: str) -> List[str]:
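
For illustration, here is what a concrete splitter built on the same constructor pattern could look like. This is a standalone sketch (it defines its own trimmed-down base rather than importing the real BaseSplitter, and assumes split_text is the only abstract member); names like SplitterSketch and CharacterSplitter are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import List


class SplitterSketch(ABC):
    """Standalone analogue of the new BaseSplitter constructor (trimmed)."""

    def __init__(self, chunk_size: int = 1024, chunk_overlap: int = 100):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @abstractmethod
    def split_text(self, text: str) -> List[str]: ...


class CharacterSplitter(SplitterSketch):
    """Fixed-size character windows with overlap."""

    def split_text(self, text: str) -> List[str]:
        step = self.chunk_size - self.chunk_overlap
        return [text[i : i + self.chunk_size] for i in range(0, len(text), step)]


print(CharacterSplitter(chunk_size=8, chunk_overlap=2).split_text("a" * 20))
# -> ['aaaaaaaa', 'aaaaaaaa', 'aaaaaaaa', 'aa']
```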

graphgen/bases/base_storage.py

Lines changed: 0 additions & 3 deletions
@@ -16,7 +16,6 @@ async def query_done_callback(self):
         """commit the storage operations after querying"""
 
 
-@dataclass
 class BaseListStorage(Generic[T], StorageNameSpace):
     async def all_items(self) -> list[T]:
         raise NotImplementedError
@@ -34,7 +33,6 @@ async def drop(self):
         raise NotImplementedError
 
 
-@dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
     async def all_keys(self) -> list[str]:
         raise NotImplementedError
@@ -58,7 +56,6 @@ async def drop(self):
         raise NotImplementedError
 
 
-@dataclass
 class BaseGraphStorage(StorageNameSpace):
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError

graphgen/bases/base_tokenizer.py

Lines changed: 2 additions & 3 deletions
@@ -1,13 +1,12 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import List
 
 
-@dataclass
 class BaseTokenizer(ABC):
-    model_name: str = "cl100k_base"
+    def __init__(self, model_name: str = "cl100k_base"):
+        self.model_name = model_name
 
     @abstractmethod
     def encode(self, text: str) -> List[int]:
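
The "cl100k_base" default suggests a tiktoken-backed encoding. A hedged sketch of a concrete tokenizer following the new constructor pattern; TiktokenTokenizer and its decode method are illustrative assumptions, not code from this commit.

```python
from typing import List

import tiktoken  # third-party dependency; assumed available


class TiktokenTokenizer:
    """Illustrative tokenizer mirroring the new BaseTokenizer constructor."""

    def __init__(self, model_name: str = "cl100k_base"):
        self.model_name = model_name
        self._enc = tiktoken.get_encoding(model_name)

    def encode(self, text: str) -> List[int]:
        return self._enc.encode(text)

    def decode(self, token_ids: List[int]) -> str:
        return self._enc.decode(token_ids)


tok = TiktokenTokenizer()
ids = tok.encode("explicit __init__ over @dataclass")
print(len(ids), tok.decode(ids))
```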

graphgen/models/evaluator/base_evaluator.py

Lines changed: 3 additions & 4 deletions
@@ -1,16 +1,15 @@
 import asyncio
-from dataclasses import dataclass
 
 from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.bases.datatypes import QAPair
 from graphgen.utils import create_event_loop
 
 
-@dataclass
 class BaseEvaluator:
-    max_concurrent: int = 100
-    results: list[float] = None
+    def __init__(self, max_concurrent: int = 100):
+        self.max_concurrent = max_concurrent
+        self.results: list[float] = None
 
     def evaluate(self, pairs: list[QAPair]) -> list[float]:
         """

graphgen/models/evaluator/length_evaluator.py

Lines changed: 3 additions & 6 deletions
@@ -1,16 +1,13 @@
-from dataclasses import dataclass
-
 from graphgen.bases.datatypes import QAPair
 from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop
 
 
-@dataclass
 class LengthEvaluator(BaseEvaluator):
-    tokenizer_name: str = "cl100k_base"
-
-    def __post_init__(self):
+    def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
+        super().__init__(max_concurrent)
+        self.tokenizer_name = tokenizer_name
         self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
 
     async def evaluate_single(self, pair: QAPair) -> float:

graphgen/models/evaluator/mtld_evaluator.py

Lines changed: 4 additions & 8 deletions
@@ -1,4 +1,3 @@
-from dataclasses import dataclass, field
 from typing import Set
 
 from graphgen.bases.datatypes import QAPair
@@ -8,18 +7,15 @@
 nltk_helper = NLTKHelper()
 
 
-@dataclass
 class MTLDEvaluator(BaseEvaluator):
     """
     Metric for measuring the lexical diversity of a text.
     """
 
-    stopwords_en: Set[str] = field(
-        default_factory=lambda: set(nltk_helper.get_stopwords("english"))
-    )
-    stopwords_zh: Set[str] = field(
-        default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
-    )
+    def __init__(self, max_concurrent: int = 100):
+        super().__init__(max_concurrent)
+        self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
+        self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
 
     async def evaluate_single(self, pair: QAPair) -> float:
         loop = create_event_loop()

graphgen/models/generator/aggregated_generator.py

Lines changed: 0 additions & 2 deletions
@@ -1,12 +1,10 @@
-from dataclasses import dataclass
 from typing import Any
 
 from graphgen.bases import BaseGenerator
 from graphgen.templates import AGGREGATED_GENERATION_PROMPT
 from graphgen.utils import compute_content_hash, detect_main_language, logger
 
 
-@dataclass
 class AggregatedGenerator(BaseGenerator):
     """
     Aggregated Generator follows a TWO-STEP process: