Skip to content

Commit 3978d47

Browse files
feat: add arxiv toolkit (#994)
Co-authored-by: dreamflyfreya <dreamfly@sas.upenn.edu>
1 parent cc83c94 commit 3978d47

File tree

9 files changed

+611
-217
lines changed

9 files changed

+611
-217
lines changed

camel/toolkits/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,26 @@
2020
)
2121
from .open_api_specs.security_config import openapi_security_config
2222

23-
from .google_maps_toolkit import GoogleMapsToolkit
23+
2424
from .math_toolkit import MathToolkit, MATH_FUNCS
25-
from .open_api_toolkit import OpenAPIToolkit
26-
from .retrieval_toolkit import RetrievalToolkit
2725
from .search_toolkit import SearchToolkit, SEARCH_FUNCS
28-
from .twitter_toolkit import TwitterToolkit
2926
from .weather_toolkit import WeatherToolkit, WEATHER_FUNCS
30-
from .slack_toolkit import SlackToolkit
3127
from .dalle_toolkit import DalleToolkit, DALLE_FUNCS
32-
from .linkedin_toolkit import LinkedInToolkit
33-
from .reddit_toolkit import RedditToolkit
3428

29+
from .base import BaseToolkit
30+
from .google_maps_toolkit import GoogleMapsToolkit
3531
from .code_execution import CodeExecutionToolkit
3632
from .github_toolkit import GithubToolkit
33+
from .arxiv_toolkit import ArxivToolkit
34+
from .linkedin_toolkit import LinkedInToolkit
35+
from .reddit_toolkit import RedditToolkit
36+
from .slack_toolkit import SlackToolkit
37+
from .twitter_toolkit import TwitterToolkit
38+
from .open_api_toolkit import OpenAPIToolkit
39+
from .retrieval_toolkit import RetrievalToolkit
3740

3841
__all__ = [
42+
'BaseToolkit',
3943
'FunctionTool',
4044
'OpenAIFunction',
4145
'get_openai_function_schema',
@@ -54,6 +58,7 @@
5458
'LinkedInToolkit',
5559
'RedditToolkit',
5660
'CodeExecutionToolkit',
61+
'ArxivToolkit',
5762
'MATH_FUNCS',
5863
'SEARCH_FUNCS',
5964
'WEATHER_FUNCS',

camel/toolkits/arxiv_toolkit.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
15+
from typing import Dict, Generator, List, Optional
16+
17+
from camel.toolkits.base import BaseToolkit
18+
from camel.toolkits.function_tool import FunctionTool
19+
from camel.utils import dependencies_required
20+
21+
22+
class ArxivToolkit(BaseToolkit):
23+
r"""A toolkit for interacting with the arXiv API to search and download
24+
academic papers.
25+
"""
26+
27+
@dependencies_required('arxiv')
28+
def __init__(self) -> None:
29+
r"""Initializes the ArxivToolkit and sets up the arXiv client."""
30+
import arxiv
31+
32+
self.client = arxiv.Client()
33+
34+
def _get_search_results(
35+
self,
36+
query: str,
37+
paper_ids: Optional[List[str]] = None,
38+
max_results: Optional[int] = 5,
39+
) -> Generator:
40+
r"""Retrieves search results from the arXiv API based on the provided
41+
query and optional paper IDs.
42+
43+
Args:
44+
query (str): The search query string used to search for papers on
45+
arXiv.
46+
paper_ids (List[str], optional): A list of specific arXiv paper
47+
IDs to search for. (default::obj: `None`)
48+
max_results (int, optional): The maximum number of search results
49+
to retrieve. (default::obj: `5`)
50+
51+
Returns:
52+
Generator: A generator that yields results from the arXiv search
53+
query, which includes metadata about each paper matching the
54+
query.
55+
"""
56+
import arxiv
57+
58+
paper_ids = paper_ids or []
59+
search_query = arxiv.Search(
60+
query=query,
61+
id_list=paper_ids,
62+
max_results=max_results,
63+
)
64+
return self.client.results(search_query)
65+
66+
def search_papers(
67+
self,
68+
query: str,
69+
paper_ids: Optional[List[str]] = None,
70+
max_results: Optional[int] = 5,
71+
) -> List[Dict[str, str]]:
72+
r"""Searches for academic papers on arXiv using a query string and
73+
optional paper IDs.
74+
75+
Args:
76+
query (str): The search query string.
77+
paper_ids (List[str], optional): A list of specific arXiv paper
78+
IDs to search for. (default::obj: `None`)
79+
max_results (int, optional): The maximum number of search results
80+
to return. (default::obj: `5`)
81+
82+
Returns:
83+
List[Dict[str, str]]: A list of dictionaries, each containing
84+
information about a paper, including title, published date,
85+
authors, entry ID, summary, and extracted text from the paper.
86+
"""
87+
from arxiv2text import arxiv_to_text
88+
89+
search_results = self._get_search_results(
90+
query, paper_ids, max_results
91+
)
92+
papers_data = []
93+
94+
for paper in search_results:
95+
paper_info = {
96+
"title": paper.title,
97+
"published_date": paper.updated.date().isoformat(),
98+
"authors": [author.name for author in paper.authors],
99+
"entry_id": paper.entry_id,
100+
"summary": paper.summary,
101+
# TODO: Use chunkr instead of atxiv_to_text for better
102+
# performance
103+
"paper_text": arxiv_to_text(paper.pdf_url),
104+
}
105+
papers_data.append(paper_info)
106+
107+
return papers_data
108+
109+
def download_papers(
110+
self,
111+
query: str,
112+
paper_ids: Optional[List[str]] = None,
113+
max_results: Optional[int] = 5,
114+
output_dir: Optional[str] = "./",
115+
) -> str:
116+
r"""Downloads PDFs of academic papers from arXiv based on the provided
117+
query.
118+
119+
Args:
120+
query (str): The search query string.
121+
paper_ids (List[str], optional): A list of specific arXiv paper
122+
IDs to download. (default::obj: `None`)
123+
max_results (int, optional): The maximum number of search results
124+
to download. (default::obj: `5`)
125+
output_dir (str, optional): The directory to save the downloaded
126+
PDFs. Defaults to the current directory.
127+
128+
Returns:
129+
str: Status message indicating success or failure.
130+
"""
131+
try:
132+
search_results = self._get_search_results(
133+
query, paper_ids, max_results
134+
)
135+
136+
for paper in search_results:
137+
paper.download_pdf(
138+
dirpath=output_dir, filename=f"{paper.title}" + ".pdf"
139+
)
140+
return "papers downloaded successfully"
141+
except Exception as e:
142+
return f"An error occurred: {e}"
143+
144+
def get_tools(self) -> List[FunctionTool]:
145+
return [
146+
FunctionTool(self.search_papers),
147+
FunctionTool(self.download_papers),
148+
]

camel/toolkits/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
from typing import List
1616

17+
from camel.toolkits import FunctionTool
1718
from camel.utils import AgentOpsMeta
1819

19-
from .function_tool import FunctionTool
20-
2120

2221
class BaseToolkit(metaclass=AgentOpsMeta):
2322
def get_tools(self) -> List[FunctionTool]:

camel/toolkits/code_execution.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515

1616
from camel.interpreters import InternalPythonInterpreter
1717
from camel.toolkits import FunctionTool
18-
19-
from .base import BaseToolkit
18+
from camel.toolkits.base import BaseToolkit
2019

2120

2221
class CodeExecutionToolkit(BaseToolkit):

camel/toolkits/github_toolkit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818

1919
from pydantic import BaseModel
2020

21+
from camel.toolkits import FunctionTool
22+
from camel.toolkits.base import BaseToolkit
2123
from camel.utils import dependencies_required
2224

23-
from .base import BaseToolkit
24-
from .function_tool import FunctionTool
25-
2625

2726
class GithubIssue(BaseModel):
2827
r"""Represents a GitHub issue.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
15+
from camel.agents import ChatAgent
16+
from camel.configs.openai_config import ChatGPTConfig
17+
from camel.messages import BaseMessage
18+
from camel.models import ModelFactory
19+
from camel.toolkits import ArxivToolkit
20+
from camel.types import ModelPlatformType, ModelType
21+
22+
# Define system message
23+
sys_msg = BaseMessage.make_assistant_message(
24+
role_name='Tools calling opertor', content='You are a helpful assistant'
25+
)
26+
27+
# Set model config
28+
tools = ArxivToolkit().get_tools()
29+
model_config_dict = ChatGPTConfig(
30+
temperature=0.0,
31+
).as_dict()
32+
33+
model = ModelFactory.create(
34+
model_platform=ModelPlatformType.OPENAI,
35+
model_type=ModelType.GPT_4O_MINI,
36+
model_config_dict=model_config_dict,
37+
)
38+
39+
# Set agent
40+
camel_agent = ChatAgent(
41+
system_message=sys_msg,
42+
model=model,
43+
tools=tools,
44+
)
45+
camel_agent.reset()
46+
47+
# Define a user message
48+
usr_msg = BaseMessage.make_user_message(
49+
role_name="CAMEL User",
50+
content="""Search paper 'attention is all you need' for me""",
51+
)
52+
53+
# Get response information
54+
response = camel_agent.step(usr_msg)
55+
print(str(response.info['tool_calls'])[:1000])
56+
'''
57+
===============================================================================
58+
[FunctionCallingRecord(func_name='search_papers', args={'query': 'attention is
59+
all you need'}, result=[{'title': "Attention Is All You Need But You Don't
60+
Need All Of It For Inference of Large Language Models", 'published_date':
61+
'2024-07-22', 'authors': ['Georgy Tyukin', 'Gbetondji J-S Dovonon', 'Jean
62+
Kaddour', 'Pasquale Minervini'], 'entry_id': 'http://arxiv.org/abs/2407.
63+
15516v1', 'summary': 'The inference demand for LLMs has skyrocketed in recent
64+
months, and serving\nmodels with low latencies remains challenging due to the
65+
quadratic input length\ncomplexity of the attention layers. In this work, we
66+
investigate the effect of\ndropping MLP and attention layers at inference time
67+
on the performance of\nLlama-v2 models. We find that dropping dreeper
68+
attention layers only marginally\ndecreases performance but leads to the best
69+
speedups alongside dropping entire\nlayers. For example, removing 33\\% of
70+
attention layers in a 13B Llama2 model\nresults in a 1.8\\% drop in average
71+
performance ove...
72+
===============================================================================
73+
'''
74+
75+
76+
# Define a user message
77+
usr_msg = BaseMessage.make_user_message(
78+
role_name="CAMEL User",
79+
content="""Download paper "attention is all you need" for me to my
80+
local path '/Users/enrei/Desktop/camel0826/camel/examples/tool_call'""",
81+
)
82+
83+
# Get response information
84+
response = camel_agent.step(usr_msg)
85+
print(str(response.info['tool_calls'])[:1000])
86+
'''
87+
===============================================================================
88+
[FunctionCallingRecord(func_name='download_papers', args={'query': 'attention
89+
is all you need', 'output_dir': '/Users/enrei/Desktop/camel0826/camel/examples/
90+
tool_call', 'paper_ids': ['2407.15516v1', '2107.08000v1', '2306.01926v1',
91+
'2112.05993v1', '1912.11959v2']}, result='papers downloaded successfully')]
92+
===============================================================================
93+
'''

0 commit comments

Comments
 (0)