diff --git a/src/crewai/agents/cache/cache_handler.py b/src/crewai/agents/cache/cache_handler.py index 09dd76f268..0cf2e259bd 100644 --- a/src/crewai/agents/cache/cache_handler.py +++ b/src/crewai/agents/cache/cache_handler.py @@ -1,15 +1,28 @@ from typing import Any, Dict, Optional +import threading +from threading import local from pydantic import BaseModel, PrivateAttr +_thread_local = local() + + class CacheHandler(BaseModel): """Callback handler for tool usage.""" _cache: Dict[str, Any] = PrivateAttr(default_factory=dict) + def _get_lock(self): + """Get a thread-local lock to avoid pickling issues.""" + if not hasattr(_thread_local, "cache_lock"): + _thread_local.cache_lock = threading.Lock() + return _thread_local.cache_lock + def add(self, tool, input, output): - self._cache[f"{tool}-{input}"] = output + with self._get_lock(): + self._cache[f"{tool}-{input}"] = output def read(self, tool, input) -> Optional[str]: - return self._cache.get(f"{tool}-{input}") + with self._get_lock(): + return self._cache.get(f"{tool}-{input}") diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d488783eae..35e6ad4f28 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -88,7 +88,7 @@ class Crew(BaseModel): _rpm_controller: RPMController = PrivateAttr() _logger: Logger = PrivateAttr() _file_handler: FileHandler = PrivateAttr() - _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler()) + _cache_handler: InstanceOf[CacheHandler] = PrivateAttr() _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr() _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr() _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr() diff --git a/src/crewai/telemetry/telemetry.py b/src/crewai/telemetry/telemetry.py index 984a4938de..c4d15165fd 100644 --- a/src/crewai/telemetry/telemetry.py +++ b/src/crewai/telemetry/telemetry.py @@ -4,11 +4,15 @@ import json import os import platform +import threading import warnings from contextlib import contextmanager from importlib.metadata import version +from threading import local from typing import TYPE_CHECKING, Any, Optional +_thread_local = local() + @contextmanager def suppress_warnings(): @@ -76,12 +80,20 @@ def __init__(self): raise # Re-raise the exception to not interfere with system signals self.ready = False + def _get_lock(self): + """Get a thread-local lock to avoid pickling issues.""" + if not hasattr(_thread_local, "telemetry_lock"): + _thread_local.telemetry_lock = threading.Lock() + return _thread_local.telemetry_lock + def set_tracer(self): if self.ready and not self.trace_set: try: - with suppress_warnings(): - trace.set_tracer_provider(self.provider) - self.trace_set = True + with self._get_lock(): + if not self.trace_set: # Double-check to avoid race condition + with suppress_warnings(): + trace.set_tracer_provider(self.provider) + self.trace_set = True except Exception: self.ready = False self.trace_set = False @@ -90,7 +102,8 @@ def _safe_telemetry_operation(self, operation): if not self.ready: return try: - operation() + with self._get_lock(): + operation() except Exception: pass diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py new file mode 100644 index 0000000000..797af3ea7d --- /dev/null +++ b/tests/concurrency_test.py @@ -0,0 +1,186 @@ +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +import pytest +from unittest.mock import patch + +from crewai import Agent, Crew, Task + + +class MockLLM: + """Mock LLM for testing.""" + def __init__(self, model="gpt-3.5-turbo", **kwargs): + self.model = model + self.stop = None + self.timeout = None + self.temperature = None + self.top_p = None + self.n = None + self.max_completion_tokens = None + self.max_tokens = None + self.presence_penalty = None + self.frequency_penalty = None + self.logit_bias = None + self.response_format = None + self.seed = None + self.logprobs = None + self.top_logprobs = None + self.base_url = None + self.api_version = None + self.api_key = None + self.callbacks = [] + self.context_window_size = 8192 + self.kwargs = {} + + for key, value in kwargs.items(): + setattr(self, key, value) + + def complete(self, prompt, **kwargs): + """Mock completion method.""" + return f"Mock response for: {prompt[:20]}..." + + def chat_completion(self, messages, **kwargs): + """Mock chat completion method.""" + return {"choices": [{"message": {"content": "Mock response"}}]} + + def function_call(self, messages, functions, **kwargs): + """Mock function call method.""" + return { + "choices": [ + { + "message": { + "content": "Mock response", + "function_call": { + "name": "test_function", + "arguments": '{"arg1": "value1"}' + } + } + } + ] + } + + def supports_stop_words(self): + """Mock supports_stop_words method.""" + return False + + def supports_function_calling(self): + """Mock supports_function_calling method.""" + return True + + def get_context_window_size(self): + """Mock get_context_window_size method.""" + return self.context_window_size + + def call(self, messages, callbacks=None): + """Mock call method.""" + return "Mock response from call method" + + def set_callbacks(self, callbacks): + """Mock set_callbacks method.""" + self.callbacks = callbacks + + def set_env_callbacks(self): + """Mock set_env_callbacks method.""" + pass + + +def create_test_crew(): + """Create a simple test crew for concurrency testing.""" + with patch("crewai.agent.LLM", MockLLM): + agent = Agent( + role="Test Agent", + goal="Test concurrent execution", + backstory="I am a test agent for concurrent execution", + ) + + task = Task( + description="Test task for concurrent execution", + expected_output="Test output", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task], + verbose=False, + ) + + return crew + + +def test_threading_concurrency(): + """Test concurrent execution using ThreadPoolExecutor.""" + num_threads = 5 + results = [] + + def generate_response(idx): + try: + crew = create_test_crew() + with patch("crewai.agent.LLM", MockLLM): + output = crew.kickoff(inputs={"test_input": f"input_{idx}"}) + return output + except Exception as e: + pytest.fail(f"Exception in thread {idx}: {e}") + return None + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(generate_response, i) for i in range(num_threads)] + + for future in as_completed(futures): + result = future.result() + assert result is not None + results.append(result) + + assert len(results) == num_threads + + +@pytest.mark.asyncio +async def test_asyncio_concurrency(): + """Test concurrent execution using asyncio.""" + num_tasks = 5 + sem = asyncio.Semaphore(num_tasks) + + async def generate_response_async(idx): + async with sem: + try: + crew = create_test_crew() + with patch("crewai.agent.LLM", MockLLM): + output = await crew.kickoff_async(inputs={"test_input": f"input_{idx}"}) + return output + except Exception as e: + pytest.fail(f"Exception in task {idx}: {e}") + return None + + tasks = [generate_response_async(i) for i in range(num_tasks)] + results = await asyncio.gather(*tasks) + + assert len(results) == num_tasks + assert all(result is not None for result in results) + + +@pytest.mark.asyncio +async def test_extended_asyncio_concurrency(): + """Extended test for asyncio concurrency with more iterations.""" + num_tasks = 5 # Reduced from 10 for faster testing + iterations = 2 # Reduced from 3 for faster testing + sem = asyncio.Semaphore(num_tasks) + + async def generate_response_async(idx): + async with sem: + crew = create_test_crew() + for i in range(iterations): + try: + with patch("crewai.agent.LLM", MockLLM): + output = await crew.kickoff_async( + inputs={"test_input": f"input_{idx}_{i}"} + ) + assert output is not None + except Exception as e: + pytest.fail(f"Exception in task {idx}, iteration {i}: {e}") + return False + return True + + tasks = [generate_response_async(i) for i in range(num_tasks)] + results = await asyncio.gather(*tasks) + + assert all(results)