Skip to content

Commit 860fde8

Browse files
authored
Merge pull request #558 from ScrapeGraphAI/screenshot_scraper
Screenshot scraper integration
2 parents fec3582 + d248646 commit 860fde8

File tree

9 files changed

+336
-3
lines changed

9 files changed

+336
-3
lines changed

examples/openai/screenshot_scraper.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Basic example of a scraping pipeline using ScreenshotScraperGraph
3+
"""
4+
5+
import os
6+
import json
7+
from dotenv import load_dotenv
8+
from scrapegraphai.graphs import ScreenshotScraperGraph
9+
from scrapegraphai.utils import prettify_exec_info
10+
11+
load_dotenv()
12+
13+
# ************************************************
14+
# Define the configuration for the graph
15+
# ************************************************
16+
17+
18+
# LLM credentials/model are read from the environment (populated by load_dotenv()).
llm_settings = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "model": "gpt-4o",
}

graph_config = {
    "llm": llm_settings,
    "verbose": True,
    "headless": False,
}

# ************************************************
# Build the ScreenshotScraperGraph, execute it and print the answer
# ************************************************

screenshot_scraper_graph = ScreenshotScraperGraph(
    prompt="List me all the projects",
    source="https://perinim.github.io/projects/",
    config=graph_config,
)

result = screenshot_scraper_graph.run()

# Pretty-print whatever the graph returned as indented JSON.
formatted_result = json.dumps(result, indent=4)
print(formatted_result)

examples/openai/smart_scraper_openai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
import os
66
import json
7+
from dotenv import load_dotenv
78
from scrapegraphai.graphs import SmartScraperGraph
89
from scrapegraphai.utils import prettify_exec_info
9-
from dotenv import load_dotenv
10+
1011
load_dotenv()
1112

1213
# ************************************************
@@ -17,7 +18,7 @@
1718
graph_config = {
1819
"llm": {
1920
"api_key": os.getenv("OPENAI_API_KEY"),
20-
"model": "gpt-3.5-turbo",
21+
"model": "gpt-4o",
2122
},
2223
"verbose": True,
2324
"headless": False,

scrapegraphai/graphs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from .markdown_scraper_graph import MDScraperGraph
2525
from .markdown_scraper_multi_graph import MDScraperMultiGraph
2626
from .search_link_graph import SearchLinkGraph
27+
from .screenshot_scraper_graph import ScreenshotScraperGraph
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
ScreenshotScraperGraph Module
3+
"""
4+
from typing import Optional
5+
import logging
6+
from pydantic import BaseModel
7+
from .base_graph import BaseGraph
8+
from .abstract_graph import AbstractGraph
9+
from ..nodes import ( FetchScreenNode, GenerateAnswerFromImageNode, )
10+
11+
class ScreenshotScraperGraph(AbstractGraph):
    """
    Two-step scraping graph: capture screenshots of a page, then ask a model
    about them.

    The graph wires a FetchScreenNode (which receives the source as its
    ``link``) to a GenerateAnswerFromImageNode (which receives the graph
    configuration) and returns the ``answer`` entry of the final state.

    Attributes:
        prompt (str): The question to answer from the screenshots.
        config (dict): Configuration parameters for the graph.
        source (str): The URL handed to the screenshot node.

    Methods:
        __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
            Forwards the arguments to AbstractGraph.

        _create_graph()
            Builds the fetch-screenshots -> generate-answer BaseGraph.

        run()
            Executes the graph and returns the answer to the prompt.
    """

    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
        # AbstractGraph takes (prompt, config, source, schema) in that order.
        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Assemble the screenshot-scraping workflow.

        Returns:
            BaseGraph: A graph whose entry point captures screenshots and whose
            second node turns them into an answer.
        """
        screenshot_fetcher = FetchScreenNode(
            input="url",
            output=["screenshots"],
            node_config={"link": self.source},
        )
        image_answerer = GenerateAnswerFromImageNode(
            input="screenshots",
            output=["answer"],
            node_config={"config": self.config},
        )

        pipeline = [screenshot_fetcher, image_answerer]
        connections = [(screenshot_fetcher, image_answerer)]

        return BaseGraph(
            nodes=pipeline,
            edges=connections,
            entry_point=screenshot_fetcher,
            graph_name=type(self).__name__,
        )

    def run(self) -> str:
        """
        Execute the graph with the user prompt as its initial state.

        Returns:
            The value stored under ``"answer"`` in the final state, or the
            string ``"No answer found."`` when that key is absent.
            NOTE(review): the annotation says ``str``, but the returned value is
            whatever the answer node stored - confirm and adjust the annotation.
        """
        initial_state = {"user_prompt": self.prompt}
        outcome = self.graph.execute(initial_state)
        self.final_state, self.execution_info = outcome

        return self.final_state.get("answer", "No answer found.")
82+

scrapegraphai/nodes/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@
1919
from .graph_iterator_node import GraphIteratorNode
2020
from .merge_answers_node import MergeAnswersNode
2121
from .generate_answer_omni_node import GenerateAnswerOmniNode
22-
from .merge_generated_scripts import MergeGeneratedScriptsNode
22+
from .merge_generated_scripts import MergeGeneratedScriptsNode
23+
from .fetch_screen_node import FetchScreenNode
24+
from .generate_answer_from_image_node import GenerateAnswerFromImageNode
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
fetch_screen_node module
3+
"""
4+
from typing import List, Optional
5+
from playwright.sync_api import sync_playwright
6+
from .base_node import BaseNode
7+
from ..utils.logging import get_logger
8+
9+
class FetchScreenNode(BaseNode):
    """
    FetchScreenNode captures screenshots from a given URL and stores the image data as bytes.

    The URL is read from ``node_config["link"]``. Two screenshots are taken with
    Chromium via Playwright: one at the top of the page and one after scrolling
    down by one viewport height.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "FetchScreenNode",
    ):
        """
        Args:
            input (str): Input expression forwarded to BaseNode.
            output (List[str]): Output keys forwarded to BaseNode.
            node_config (Optional[dict]): Node settings; ``"link"`` holds the URL
                to capture.
            node_name (str): Name of the node, used in log messages.
        """
        super().__init__(node_name, "node", input, output, 2, node_config)
        # node_config defaults to None, so guard before reading the link from it.
        self.url = (node_config or {}).get("link")

    def execute(self, state: dict) -> dict:
        """
        Captures screenshots from the configured URL and stores them in the state
        dictionary as bytes.

        Args:
            state (dict): The current graph state; updated in place.

        Returns:
            dict: The same state with ``"link"`` set to the URL and
            ``"screenshots"`` set to the list of captured images.

        Raises:
            ValueError: If no ``"link"`` was provided in ``node_config``.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        if not self.url:
            # Fail early with a clear message instead of an opaque browser error.
            raise ValueError(
                "FetchScreenNode requires a 'link' entry in node_config."
            )

        screenshot_data_list = []

        with sync_playwright() as p:
            browser = p.chromium.launch()
            try:
                page = browser.new_page()
                page.goto(self.url)

                viewport_height = page.viewport_size["height"]

                # First shot at the top of the page, second one viewport further down.
                for scroll_position in (0, viewport_height):
                    page.evaluate(f"window.scrollTo(0, {scroll_position});")
                    screenshot_data_list.append(page.screenshot())
            finally:
                # Close the browser even when navigation or a screenshot fails.
                browser.close()

        state["link"] = self.url
        state['screenshots'] = screenshot_data_list

        return state
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
GenerateAnswerFromImageNode Module
3+
"""
4+
import base64
5+
import asyncio
6+
from typing import List, Optional
7+
import aiohttp
8+
from .base_node import BaseNode
9+
from ..utils.logging import get_logger
10+
11+
class GenerateAnswerFromImageNode(BaseNode):
    """
    GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
    and updates the state with the consolidated answers.

    Each image in ``state["screenshots"]`` is sent, together with the user
    prompt, to the chat completions endpoint; the replies are joined with a
    space and stored under ``state["answer"]["consolidated_analysis"]``.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateAnswerFromImageNode",
    ):
        """
        Args:
            input (str): Input expression forwarded to BaseNode.
            output (List[str]): Output keys forwarded to BaseNode.
            node_config (Optional[dict]): Node settings; ``"config"`` holds the
                graph configuration, whose ``"llm"`` entry supplies ``"model"``
                and ``"api_key"``.
            node_name (str): Name of the node, used in log messages.
        """
        super().__init__(node_name, "node", input, output, 2, node_config)

    @staticmethod
    def _detect_mime_type(image_data: bytes) -> str:
        """
        Return the MIME type to advertise in the data URL for ``image_data``.

        PNG data is recognised by its 8-byte signature; anything else keeps the
        previous default of ``image/jpeg``.
        """
        if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        return "image/jpeg"

    async def process_image(self, session, api_key, image_data, user_prompt):
        """
        Send one image plus the prompt to the chat completions endpoint.

        Args:
            session: The aiohttp client session used for the POST request.
            api_key: The API key placed in the ``Authorization`` header.
            image_data (bytes): Raw image bytes; they are base64-encoded here.
            user_prompt (str): The text part of the message.

        Returns:
            str: The content of the first choice, or ``"No response"`` when the
            reply carries no usable content.
        """
        base64_image = base64.b64encode(image_data).decode('utf-8')
        mime_type = self._detect_mime_type(image_data)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        payload = {
            "model": self.node_config["config"]["llm"]["model"],
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": user_prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }

        async with session.post("https://api.openai.com/v1/chat/completions",
                                headers=headers, json=payload) as response:
            result = await response.json()

        if "error" in result:
            # Surface API failures in the log instead of hiding them behind
            # the generic fallback text.
            self.logger.warning(f"Image analysis request failed: {result['error']}")

        # An empty or missing 'choices' list must not raise IndexError, and a
        # null 'content' must not reach the later str.join.
        choices = result.get('choices') or [{}]
        message = choices[0].get('message') or {}
        content = message.get('content')
        return content if content is not None else 'No response'

    async def execute_async(self, state: dict) -> dict:
        """
        Processes images from the state, generates answers,
        consolidates the results, and updates the state asynchronously.

        Args:
            state (dict): The current graph state; ``"screenshots"`` and
                ``"user_prompt"`` are read from it.

        Returns:
            dict: The same state with ``"answer"`` set.

        Raises:
            ValueError: If the configured model is not one of the supported models.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        images = state.get('screenshots', [])

        supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo")

        llm_config = self.node_config["config"]["llm"]
        model = llm_config["model"]

        if model not in supported_models:
            # Single-line message: the previous triple-quoted f-string leaked
            # source newlines and indentation into the error text.
            raise ValueError(
                f"Model '{model}' is not supported. "
                f"Supported models are: {', '.join(supported_models)}."
            )

        api_key = llm_config.get("api_key", "")
        user_prompt = state.get("user_prompt", "Extract information from the image")

        async with aiohttp.ClientSession() as session:
            tasks = [
                self.process_image(session, api_key, image_data, user_prompt)
                for image_data in images
            ]

            # All images are analysed concurrently; results keep input order.
            analyses = await asyncio.gather(*tasks)

        consolidated_analysis = " ".join(analyses)

        state['answer'] = {
            "consolidated_analysis": consolidated_analysis
        }

        return state

    def execute(self, state: dict) -> dict:
        """
        Wrapper to run the asynchronous execute_async function in a synchronous context.

        When no event loop is running in the current thread the coroutine is
        driven with ``asyncio.run``. When a loop is already running (for
        example when called from a notebook or from async code), the coroutine
        is executed in a short-lived worker thread with its own event loop,
        because a running loop cannot be re-entered with ``run_until_complete``.

        Args:
            state (dict): The current graph state.

        Returns:
            dict: The state returned by ``execute_async``.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: the simple path is safe.
            return asyncio.run(self.execute_async(state))

        # Local import keeps this block self-contained within the module.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(lambda: asyncio.run(self.execute_async(state)))
            return future.result()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
import pytest
3+
import json
4+
from scrapegraphai.graphs import ScreenshotScraperGraph
5+
from dotenv import load_dotenv
6+
7+
# Load environment variables
8+
load_dotenv()
9+
10+
# Define a fixture for the graph configuration
11+
@pytest.fixture
def graph_config():
    """
    Provide the configuration dictionary used to build the graph under test.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.
    """
    llm_settings = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model": "gpt-4o",
    }
    settings = {"llm": llm_settings}
    settings["verbose"] = True
    settings["headless"] = False
    return settings
24+
25+
def test_screenshot_scraper_graph(graph_config):
    """
    Run ScreenshotScraperGraph end to end and check that it produced an answer.

    This test calls the live OpenAI API, so it is skipped when no API key is
    configured instead of failing.
    """
    if not graph_config["llm"]["api_key"]:
        pytest.skip("OPENAI_API_KEY is not set; skipping live ScreenshotScraperGraph test")

    screenshot_scraper_graph = ScreenshotScraperGraph(
        prompt="List me all the projects",
        source="https://perinim.github.io/projects/",
        config=graph_config
    )

    result = screenshot_scraper_graph.run()

    assert result is not None, "The result should not be None"
    # run() falls back to this string when the graph stored no answer, so a
    # bare None check would pass even when nothing was produced.
    assert result != "No answer found.", "The graph should produce an answer"
    assert "consolidated_analysis" in result, "The answer should hold the consolidated analysis"

    print(json.dumps(result, indent=4))

0 commit comments

Comments
 (0)