| 
 | 1 | +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========  | 
 | 2 | +# Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 3 | +# you may not use this file except in compliance with the License.  | 
 | 4 | +# You may obtain a copy of the License at  | 
 | 5 | +#  | 
 | 6 | +#     http://www.apache.org/licenses/LICENSE-2.0  | 
 | 7 | +#  | 
 | 8 | +# Unless required by applicable law or agreed to in writing, software  | 
 | 9 | +# distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 11 | +# See the License for the specific language governing permissions and  | 
 | 12 | +# limitations under the License.  | 
 | 13 | +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========  | 
 | 14 | +import argparse  | 
 | 15 | +import json  | 
 | 16 | +import logging  | 
 | 17 | +import os  | 
 | 18 | +import shutil  | 
 | 19 | +import subprocess  | 
 | 20 | +import sys  | 
 | 21 | +import time  | 
 | 22 | +from logging.handlers import RotatingFileHandler  | 
 | 23 | +from queue import Empty, Queue  | 
 | 24 | +from threading import Thread  | 
 | 25 | + | 
 | 26 | +import requests  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +def setup_dispatcher_logging():  | 
 | 30 | +    log_formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')  | 
 | 31 | +    log_file = 'dispatcher.log'  | 
 | 32 | + | 
 | 33 | +    logger = logging.getLogger('dispatcher')  | 
 | 34 | +    logger.setLevel(logging.INFO)  | 
 | 35 | + | 
 | 36 | +    # File Handler  | 
 | 37 | +    file_handler = RotatingFileHandler(  | 
 | 38 | +        log_file, maxBytes=1024 * 1024 * 5, backupCount=2  | 
 | 39 | +    )  | 
 | 40 | +    file_handler.setFormatter(log_formatter)  | 
 | 41 | +    logger.addHandler(file_handler)  | 
 | 42 | + | 
 | 43 | +    # Stream Handler (for console output)  | 
 | 44 | +    stream_handler = logging.StreamHandler(sys.stdout)  | 
 | 45 | +    stream_handler.setFormatter(log_formatter)  | 
 | 46 | +    logger.addHandler(stream_handler)  | 
 | 47 | + | 
 | 48 | +    return logger  | 
 | 49 | + | 
 | 50 | + | 
# Module-level logger used by the functions below; it is configured once,
# when this module is first imported or run.
log = setup_dispatcher_logging()
 | 52 | + | 
 | 53 | + | 
 | 54 | +def download_website_assets(project_name: str):  | 
 | 55 | +    log.info(f"Checking for assets for project: {project_name}")  | 
 | 56 | + | 
 | 57 | +    # Check if assets already exist  | 
 | 58 | +    project_dir = project_name  | 
 | 59 | +    templates_path = os.path.join(project_dir, 'templates')  | 
 | 60 | +    static_path = os.path.join(project_dir, 'static')  | 
 | 61 | +    if os.path.exists(templates_path) and os.path.exists(static_path):  | 
 | 62 | +        log.info(  | 
 | 63 | +            "Templates and static folders already exist. Skipping download."  | 
 | 64 | +        )  | 
 | 65 | +        return True  | 
 | 66 | + | 
 | 67 | +    log.info(f"Attempting to download assets for project: {project_name}")  | 
 | 68 | +    try:  | 
 | 69 | +        from huggingface_hub import snapshot_download  | 
 | 70 | +    except ImportError:  | 
 | 71 | +        log.error(  | 
 | 72 | +            "huggingface_hub not installed. "  | 
 | 73 | +            "Please install it with `pip install huggingface-hub`"  | 
 | 74 | +        )  | 
 | 75 | +        return False  | 
 | 76 | + | 
 | 77 | +    repo_id = "camel-ai/mock_websites"  | 
 | 78 | +    local_dir_root = "hf_mock_website"  | 
 | 79 | +    project_pattern = f"{project_name}/*"  | 
 | 80 | + | 
 | 81 | +    try:  | 
 | 82 | +        # Download only the project's folder  | 
 | 83 | +        snapshot_path = snapshot_download(  | 
 | 84 | +            repo_id=repo_id,  | 
 | 85 | +            repo_type="dataset",  | 
 | 86 | +            allow_patterns=project_pattern,  | 
 | 87 | +            local_dir=local_dir_root,  | 
 | 88 | +            local_dir_use_symlinks=False,  | 
 | 89 | +            # Use False for Windows compatibility  | 
 | 90 | +        )  | 
 | 91 | +        log.info(f"Snapshot downloaded to: {snapshot_path}")  | 
 | 92 | + | 
 | 93 | +        # The downloaded content for the project is in a subdirectory  | 
 | 94 | +        project_assets_path = os.path.join(snapshot_path, project_name)  | 
 | 95 | + | 
 | 96 | +        if not os.path.isdir(project_assets_path):  | 
 | 97 | +            log.error(  | 
 | 98 | +                f"Project folder '{project_name}' not found in downloaded "  | 
 | 99 | +                f"assets at '{project_assets_path}'"  | 
 | 100 | +            )  | 
 | 101 | +            return False  | 
 | 102 | + | 
 | 103 | +        # Copy templates and static folders into project root  | 
 | 104 | +        for folder in ['templates', 'static']:  | 
 | 105 | +            src = os.path.join(project_assets_path, folder)  | 
 | 106 | +            # Destination is now inside the project folder  | 
 | 107 | +            dst = os.path.join(".", project_name, folder)  | 
 | 108 | +            if not os.path.exists(src):  | 
 | 109 | +                log.warning(f"'{src}' not found in downloaded assets.")  | 
 | 110 | +                continue  | 
 | 111 | + | 
 | 112 | +            if os.path.exists(dst):  | 
 | 113 | +                log.info(f"Removing existing '{dst}' directory.")  | 
 | 114 | +                shutil.rmtree(dst)  | 
 | 115 | +            log.info(f"Copying '{src}' to '{dst}'.")  | 
 | 116 | +            shutil.copytree(src, dst)  | 
 | 117 | +        log.info(f"Assets for '{project_name}' are set up.")  | 
 | 118 | +        return True  | 
 | 119 | +    except Exception as e:  | 
 | 120 | +        log.error(f"Failed to download or set up assets: {e}")  | 
 | 121 | +        return False  | 
 | 122 | + | 
 | 123 | + | 
 | 124 | +def enqueue_output(stream, queue):  | 
 | 125 | +    # If `stream` is from a subprocess opened with `text=True`,  | 
 | 126 | +    # then `readline()` returns strings, and `''` is the sentinel for EOF.  | 
 | 127 | +    for line in iter(stream.readline, ''):  | 
 | 128 | +        queue.put(line)  | 
 | 129 | +    stream.close()  | 
 | 130 | + | 
 | 131 | + | 
 | 132 | +def run_project(project_name: str, port: int):  | 
 | 133 | +    # 1. Prepare environment  | 
 | 134 | +    log.info(f"Setting up project: {project_name}")  | 
 | 135 | +    if not download_website_assets(project_name):  | 
 | 136 | +        log.error("Failed to download assets. Aborting.")  | 
 | 137 | +        return  | 
 | 138 | + | 
 | 139 | +    # Load task configuration  | 
 | 140 | +    task_file = 'task.json'  | 
 | 141 | +    if not os.path.exists(task_file):  | 
 | 142 | +        log.error(f"'{task_file}' not found. Aborting.")  | 
 | 143 | +        return  | 
 | 144 | +    with open(task_file, 'r') as f:  | 
 | 145 | +        task_data = json.load(f)  | 
 | 146 | +    products = task_data.get("products", [])  | 
 | 147 | +    ground_truth = task_data.get("ground_truth_cart", [])  | 
 | 148 | + | 
 | 149 | +    # Prepare project-specific files  | 
 | 150 | +    project_dir = project_name  | 
 | 151 | +    app_path = os.path.join(project_dir, 'app.py')  | 
 | 152 | +    if not os.path.exists(app_path):  | 
 | 153 | +        log.error(f"Application file not found: {app_path}. Aborting.")  | 
 | 154 | +        return  | 
 | 155 | + | 
 | 156 | +    # Write the products to a file inside the project directory  | 
 | 157 | +    with open(os.path.join(project_dir, 'products.json'), 'w') as f:  | 
 | 158 | +        json.dump(products, f)  | 
 | 159 | +    log.info(f"Wrote products.json to '{project_dir}'.")  | 
 | 160 | + | 
 | 161 | +    # Start the web server app  | 
 | 162 | +    # Use sys.executable to ensure we use the same python interpreter  | 
 | 163 | +    cmd = [sys.executable, app_path, '--port', str(port)]  | 
 | 164 | +    process = subprocess.Popen(  | 
 | 165 | +        cmd,  | 
 | 166 | +        stdout=subprocess.PIPE,  | 
 | 167 | +        stderr=subprocess.PIPE,  | 
 | 168 | +        text=True,  | 
 | 169 | +        bufsize=1,  | 
 | 170 | +        universal_newlines=True,  | 
 | 171 | +    )  | 
 | 172 | +    log.info(f"Started {app_path} on port {port} with PID: {process.pid}")  | 
 | 173 | + | 
 | 174 | +    # Non-blocking stream reading for both stdout and stderr  | 
 | 175 | +    q_out: Queue[str] = Queue()  | 
 | 176 | +    t_out = Thread(target=enqueue_output, args=(process.stdout, q_out))  | 
 | 177 | +    t_out.daemon = True  | 
 | 178 | +    t_out.start()  | 
 | 179 | + | 
 | 180 | +    q_err: Queue[str] = Queue()  | 
 | 181 | +    t_err = Thread(target=enqueue_output, args=(process.stderr, q_err))  | 
 | 182 | +    t_err.daemon = True  | 
 | 183 | +    t_err.start()  | 
 | 184 | + | 
 | 185 | +    time.sleep(5)  # Wait for server to start  | 
 | 186 | + | 
 | 187 | +    # 5. Start the task and then wait for user to terminate  | 
 | 188 | +    base_url = f"http://127.0.0.1:{port}"  | 
 | 189 | +    try:  | 
 | 190 | +        # Start task  | 
 | 191 | +        log.info(f"Starting task with ground truth: {ground_truth}")  | 
 | 192 | +        r = requests.post(  | 
 | 193 | +            f"{base_url}/task/start", json={"ground_truth_cart": ground_truth}  | 
 | 194 | +        )  | 
 | 195 | +        r.raise_for_status()  | 
 | 196 | +        log.info(f"Task started on server: {r.json()['message']}")  | 
 | 197 | +        print(  | 
 | 198 | +            "Server is running. Interact with the website at "  | 
 | 199 | +            f"http://127.0.0.1:{port}"  | 
 | 200 | +        )  | 
 | 201 | +        print("Dispatcher is now polling for task completion...")  | 
 | 202 | +        print("Press Ctrl+C to stop the dispatcher early and get results.")  | 
 | 203 | + | 
 | 204 | +        # Poll for task completion  | 
 | 205 | +        while True:  | 
 | 206 | +            # Check if the subprocess has terminated unexpectedly  | 
 | 207 | +            if process.poll() is not None:  | 
 | 208 | +                log.error("App process terminated unexpectedly.")  | 
 | 209 | +                break  | 
 | 210 | + | 
 | 211 | +            # Check for completion via API  | 
 | 212 | +            try:  | 
 | 213 | +                r_check = requests.get(f"{base_url}/task/check")  | 
 | 214 | +                r_check.raise_for_status()  | 
 | 215 | +                status = r_check.json()  | 
 | 216 | +                if status.get('completed', False):  | 
 | 217 | +                    log.info(  | 
 | 218 | +                        "Task completion reported by API. "  | 
 | 219 | +                        "Proceeding to final report and shutdown."  | 
 | 220 | +                    )  | 
 | 221 | +                    break  # Exit the loop on completion  | 
 | 222 | +            except requests.exceptions.RequestException as e:  | 
 | 223 | +                log.error(f"Could not poll task status: {e}")  | 
 | 224 | +                break  # Exit if the server becomes unresponsive  | 
 | 225 | + | 
 | 226 | +            # Log any error output from the app  | 
 | 227 | +            try:  | 
 | 228 | +                err_line = q_err.get_nowait()  | 
 | 229 | +                log.error(f"APP STDERR: {err_line.strip()}")  | 
 | 230 | +            except Empty:  | 
 | 231 | +                pass  # No new error output  | 
 | 232 | +            time.sleep(2)  # Wait for 2 seconds before polling again  | 
 | 233 | + | 
 | 234 | +    except KeyboardInterrupt:  | 
 | 235 | +        log.info(  | 
 | 236 | +            "\nCtrl+C detected. Shutting down and checking final task status."  | 
 | 237 | +        )  | 
 | 238 | + | 
 | 239 | +    except requests.exceptions.RequestException as e:  | 
 | 240 | +        log.error(f"Failed to communicate with the web app: {e}")  | 
 | 241 | +    finally:  | 
 | 242 | +        # 6. Check task completion  | 
 | 243 | +        log.info("Checking final task completion status.")  | 
 | 244 | +        op_steps = 0  | 
 | 245 | +        try:  | 
 | 246 | +            r = requests.get(f"{base_url}/task/check")  | 
 | 247 | +            r.raise_for_status()  | 
 | 248 | +            result = r.json()  | 
 | 249 | +            log.info(f"Final task check result: {result}")  | 
 | 250 | +            success = result.get('completed', False)  | 
 | 251 | +            op_steps = result.get('steps', 0)  | 
 | 252 | +        except requests.exceptions.RequestException as e:  | 
 | 253 | +            log.error(f"Could not get final task status: {e}")  | 
 | 254 | +            success = False  | 
 | 255 | + | 
 | 256 | +        log.info("--- FINAL FEEDBACK ---")  | 
 | 257 | +        log.info(f"Project: {project_name}")  | 
 | 258 | +        log.info(f"Success: {success}")  | 
 | 259 | +        log.info(f"Total Operation Steps: {op_steps}")  | 
 | 260 | +        log.info("----------------------")  | 
 | 261 | + | 
 | 262 | +        # 7. Shutdown server  | 
 | 263 | +        log.info("Shutting down web server.")  | 
 | 264 | +        process.terminate()  | 
 | 265 | +        try:  | 
 | 266 | +            process.wait(timeout=5)  | 
 | 267 | +        except subprocess.TimeoutExpired:  | 
 | 268 | +            log.warning("Server did not terminate gracefully. Killing.")  | 
 | 269 | +            process.kill()  | 
 | 270 | +        log.info("Web server process stopped.")  | 
 | 271 | +        # Log any remaining stderr for debugging  | 
 | 272 | +        # Drain the queue first  | 
 | 273 | +        while not q_err.empty():  | 
 | 274 | +            log.error(f"APP STDERR: {q_err.get_nowait().strip()}")  | 
 | 275 | + | 
 | 276 | + | 
 | 277 | +def main():  | 
 | 278 | +    parser = argparse.ArgumentParser(  | 
 | 279 | +        description="Dispatcher for running mock website benchmarks."  | 
 | 280 | +    )  | 
 | 281 | +    parser.add_argument(  | 
 | 282 | +        '--project',  | 
 | 283 | +        type=str,  | 
 | 284 | +        default='shopping_mall',  | 
 | 285 | +        help='The name of the project to run (e.g., shopping_mall).',  | 
 | 286 | +    )  | 
 | 287 | +    parser.add_argument(  | 
 | 288 | +        '--port',  | 
 | 289 | +        type=int,  | 
 | 290 | +        default=5001,  | 
 | 291 | +        help='Port to run the project web server on.',  | 
 | 292 | +    )  | 
 | 293 | +    args = parser.parse_args()  | 
 | 294 | + | 
 | 295 | +    run_project(args.project, args.port)  | 
 | 296 | + | 
 | 297 | + | 
# Run the dispatcher only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
0 commit comments