|
12 | 12 | # limitations under the License. |
13 | 13 | # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= |
14 | 14 | import abc |
| 15 | +import copy |
| 16 | +import json |
15 | 17 | import os |
16 | 18 | import re |
17 | 19 | from abc import ABC, abstractmethod |
@@ -122,6 +124,14 @@ def __init__( |
122 | 124 | ) |
123 | 125 | self._log_dir = os.environ.get("CAMEL_LOG_DIR", "camel_logs") |
124 | 126 |
|
| 127 | + # Snapshot cleaning configuration |
| 128 | + self._snapshot_cleanup_trigger = int( |
| 129 | + os.environ.get("CAMEL_SNAPSHOT_CLEANUP_TRIGGER", "3") |
| 130 | + ) |
| 131 | + self._snapshot_keep_recent = int( |
| 132 | + os.environ.get("CAMEL_SNAPSHOT_KEEP_RECENT", "1") |
| 133 | + ) |
| 134 | + |
125 | 135 | @property |
126 | 136 | @abstractmethod |
127 | 137 | def token_counter(self) -> BaseTokenCounter: |
@@ -258,6 +268,146 @@ def preprocess_messages( |
258 | 268 |
|
259 | 269 | return formatted_messages |
260 | 270 |
|
| 271 | + def _count_snapshots(self, messages: List[OpenAIMessage]) -> int: |
| 272 | + """Count the number of messages containing snapshot key""" |
| 273 | + count = 0 |
| 274 | + for msg in messages: |
| 275 | + if msg.get('role') == 'tool' and msg.get('content'): |
| 276 | + try: |
| 277 | + content_str = msg['content'] |
| 278 | + if ( |
| 279 | + isinstance(content_str, str) |
| 280 | + and "'snapshot':" in content_str |
| 281 | + ): |
| 282 | + count += 1 |
| 283 | + continue |
| 284 | + |
| 285 | + # Try to parse as JSON |
| 286 | + if isinstance(content_str, str): |
| 287 | + try: |
| 288 | + content = json.loads(content_str) |
| 289 | + except: |
| 290 | + import ast |
| 291 | + |
| 292 | + content = ast.literal_eval(content_str) |
| 293 | + else: |
| 294 | + content = content_str |
| 295 | + |
| 296 | + if isinstance(content, dict) and 'snapshot' in content: |
| 297 | + count += 1 |
| 298 | + logger.debug( |
| 299 | + f"Found snapshot #{count} in message (parsed)" |
| 300 | + ) |
| 301 | + except Exception as e: |
| 302 | + logger.debug(f"Error parsing message content: {e}") |
| 303 | + pass |
| 304 | + logger.info(f"Total snapshot count: {count}") |
| 305 | + return count |
| 306 | + |
| 307 | + def _clean_old_snapshots( |
| 308 | + self, messages: List[OpenAIMessage], keep_recent: int = 1 |
| 309 | + ) -> List[OpenAIMessage]: |
| 310 | + """Clean old snapshots, keeping only the most recent ones""" |
| 311 | + logger.info( |
| 312 | + f"Starting snapshot cleanup with keep_recent={keep_recent}" |
| 313 | + ) |
| 314 | + messages_copy = copy.deepcopy(messages) |
| 315 | + |
| 316 | + # Find all messages with snapshots |
| 317 | + snapshot_indices = [] |
| 318 | + for i, msg in enumerate(messages_copy): |
| 319 | + if msg.get('role') == 'tool' and msg.get('content'): |
| 320 | + try: |
| 321 | + content_str = msg['content'] |
| 322 | + |
| 323 | + if ( |
| 324 | + isinstance(content_str, str) |
| 325 | + and "'snapshot':" in content_str |
| 326 | + ): |
| 327 | + snapshot_indices.append(i) |
| 328 | + logger.debug( |
| 329 | + f"Found snapshot " |
| 330 | + f"at message index {i} (string search)" |
| 331 | + ) |
| 332 | + continue |
| 333 | + |
| 334 | + # Try to parse as JSON |
| 335 | + if isinstance(content_str, str): |
| 336 | + try: |
| 337 | + content = json.loads(content_str) |
| 338 | + except: |
| 339 | + import ast |
| 340 | + |
| 341 | + content = ast.literal_eval(content_str) |
| 342 | + else: |
| 343 | + content = content_str |
| 344 | + |
| 345 | + if isinstance(content, dict) and 'snapshot' in content: |
| 346 | + snapshot_indices.append(i) |
| 347 | + logger.debug( |
| 348 | + f"Found snapshot at message index {i} (parsed)" |
| 349 | + ) |
| 350 | + except Exception as e: |
| 351 | + logger.debug(f"Error checking message {i}: {e}") |
| 352 | + pass |
| 353 | + |
| 354 | + logger.info(f"Found {len(snapshot_indices)} snapshots in messages") |
| 355 | + |
| 356 | + # Keep only the last 'keep_recent' snapshots |
| 357 | + if len(snapshot_indices) > keep_recent: |
| 358 | + indices_to_clean = snapshot_indices[:-keep_recent] |
| 359 | + logger.info( |
| 360 | + f"Will clean {len(indices_to_clean)} old snapshots" |
| 361 | + f", keeping the last {keep_recent}" |
| 362 | + ) |
| 363 | + |
| 364 | + for idx in indices_to_clean: |
| 365 | + msg = messages_copy[idx] |
| 366 | + try: |
| 367 | + content_str = msg['content'] |
| 368 | + |
| 369 | + if isinstance(content_str, str): |
| 370 | + # Try to parse the content |
| 371 | + try: |
| 372 | + content = json.loads(content_str) |
| 373 | + is_json = True |
| 374 | + except: |
| 375 | + import ast |
| 376 | + |
| 377 | + content = ast.literal_eval(content_str) |
| 378 | + is_json = False |
| 379 | + |
| 380 | + # Replace snapshot |
| 381 | + content['snapshot'] = ( |
| 382 | + 'snapshot history has been deleted' |
| 383 | + ) |
| 384 | + |
| 385 | + # Convert back to string in the same format |
| 386 | + if is_json: |
| 387 | + msg['content'] = json.dumps(content) |
| 388 | + else: |
| 389 | + # Keep as Python dict string |
| 390 | + msg['content'] = str(content) |
| 391 | + |
| 392 | + logger.debug(f"Cleaned snapshot at index {idx}") |
| 393 | + elif isinstance(msg['content'], dict): |
| 394 | + msg['content']['snapshot'] = ( |
| 395 | + 'snapshot history has been deleted' |
| 396 | + ) |
| 397 | + logger.debug(f"Cleaned snapshot at index {idx}") |
| 398 | + except Exception as e: |
| 399 | + logger.error( |
| 400 | + f"Failed to clean snapshot at index {idx}: {e}" |
| 401 | + ) |
| 402 | + pass |
| 403 | + else: |
| 404 | + logger.info( |
| 405 | + f"No cleaning needed only {len(snapshot_indices)} snapshots" |
| 406 | + f"keep_recent is {keep_recent}" |
| 407 | + ) |
| 408 | + |
| 409 | + return messages_copy |
| 410 | + |
261 | 411 | def _log_request(self, messages: List[OpenAIMessage]) -> Optional[str]: |
262 | 412 | r"""Log the request messages to a JSON file if logging is enabled. |
263 | 413 |
|
@@ -410,6 +560,32 @@ def run( |
410 | 560 | `ChatCompletionStreamManager[BaseModel]` in the structured |
411 | 561 | stream mode. |
412 | 562 | """ |
| 563 | + # Check if we should clean snapshots |
| 564 | + logger.info( |
| 565 | + f"Snapshot cleanup config: trigger={self._snapshot_cleanup_trigger}, keep_recent={self._snapshot_keep_recent}" |
| 566 | + ) |
| 567 | + snapshot_count = self._count_snapshots(messages) |
| 568 | + |
| 569 | + if snapshot_count > 0: |
| 570 | + logger.info( |
| 571 | + f"Checking if {snapshot_count} % {self._snapshot_cleanup_trigger} == 0: {snapshot_count % self._snapshot_cleanup_trigger == 0}" |
| 572 | + ) |
| 573 | + if snapshot_count % self._snapshot_cleanup_trigger == 0: |
| 574 | + logger.info( |
| 575 | + f"Snapshot count ({snapshot_count}) is multiple of {self._snapshot_cleanup_trigger}, " |
| 576 | + f"cleaning old snapshots..." |
| 577 | + ) |
| 578 | + messages = self._clean_old_snapshots( |
| 579 | + messages, keep_recent=self._snapshot_keep_recent |
| 580 | + ) |
| 581 | + logger.info( |
| 582 | + f"Cleaned snapshots, keeping only the {self._snapshot_keep_recent} most recent" |
| 583 | + ) |
| 584 | + else: |
| 585 | + logger.info( |
| 586 | + f"No cleaning needed. {snapshot_count} is not a multiple of {self._snapshot_cleanup_trigger}" |
| 587 | + ) |
| 588 | + |
413 | 589 | # Log the request if logging is enabled |
414 | 590 | log_path = self._log_request(messages) |
415 | 591 |
|
@@ -464,6 +640,32 @@ async def arun( |
464 | 640 | `AsyncChatCompletionStreamManager[BaseModel]` in the structured |
465 | 641 | stream mode. |
466 | 642 | """ |
| 643 | + # Check if we should clean snapshots |
| 644 | + logger.info( |
| 645 | + f"Snapshot cleanup config: trigger={self._snapshot_cleanup_trigger}, keep_recent={self._snapshot_keep_recent}" |
| 646 | + ) |
| 647 | + snapshot_count = self._count_snapshots(messages) |
| 648 | + |
| 649 | + if snapshot_count > 0: |
| 650 | + logger.info( |
| 651 | + f"Checking if {snapshot_count} % {self._snapshot_cleanup_trigger} == 0: {snapshot_count % self._snapshot_cleanup_trigger == 0}" |
| 652 | + ) |
| 653 | + if snapshot_count % self._snapshot_cleanup_trigger == 0: |
| 654 | + logger.info( |
| 655 | + f"Snapshot count ({snapshot_count}) is multiple of {self._snapshot_cleanup_trigger}, " |
| 656 | + f"cleaning old snapshots..." |
| 657 | + ) |
| 658 | + messages = self._clean_old_snapshots( |
| 659 | + messages, keep_recent=self._snapshot_keep_recent |
| 660 | + ) |
| 661 | + logger.info( |
| 662 | + f"Cleaned snapshots, keeping only the {self._snapshot_keep_recent} most recent" |
| 663 | + ) |
| 664 | + else: |
| 665 | + logger.info( |
| 666 | + f"No cleaning needed. {snapshot_count} is not a multiple of {self._snapshot_cleanup_trigger}" |
| 667 | + ) |
| 668 | + |
467 | 669 | # Log the request if logging is enabled |
468 | 670 | log_path = self._log_request(messages) |
469 | 671 |
|
|
0 commit comments