diff --git a/redis/cluster.py b/redis/cluster.py index dbcf5cc2b7..93e4188492 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1403,12 +1403,12 @@ class LoadBalancer: """ def __init__(self, start_index: int = 0) -> None: - self.primary_to_idx = {} + self.slot_to_idx = {} self.start_index = start_index def get_server_index( self, - primary: str, + slot: int, list_size: int, load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, ) -> int: @@ -1416,26 +1416,26 @@ def get_server_index( return self._get_random_replica_index(list_size) else: return self._get_round_robin_index( - primary, + slot, list_size, load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, ) def reset(self) -> None: - self.primary_to_idx.clear() + self.slot_to_idx.clear() def _get_random_replica_index(self, list_size: int) -> int: return random.randint(1, list_size - 1) def _get_round_robin_index( - self, primary: str, list_size: int, replicas_only: bool + self, slot: int, list_size: int, replicas_only: bool ) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) + server_index = self.slot_to_idx.setdefault(slot, self.start_index) if replicas_only and server_index == 0: # skip the primary node index server_index = 1 # Update the index for the next round - self.primary_to_idx[primary] = (server_index + 1) % list_size + self.slot_to_idx[slot] = (server_index + 1) % list_size return server_index @@ -1575,9 +1575,8 @@ def get_node_from_slot( if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: # get the server index using the strategy defined in load_balancing_strategy - primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]), load_balancing_strategy + slot, len(self.slots_cache[slot]), load_balancing_strategy ) elif ( server_type is None @@ -1835,11 +1834,7 @@ def close(self) -> None: node.redis_connection.close() def reset(self): - try: - self.read_load_balancer.reset() - except TypeError: - # The read_load_balancer is None, do nothing - pass + pass def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: """ diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 25f487fe4c..fad8c797b2 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2445,33 +2445,31 @@ async def test_load_balancer(self, r: RedisCluster) -> None: slot_1: [node_1, node_2, node_3], slot_2: [node_4, node_5], } - primary1_name = n_manager.slots_cache[slot_1][0].name - primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary1_name, list1_size) == 1 - assert lb.get_server_index(primary1_name, list1_size) == 2 - assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 1 + assert lb.get_server_index(slot_1, list1_size) == 2 + assert lb.get_server_index(slot_1, list1_size) == 0 # slot 2 - assert lb.get_server_index(primary2_name, list2_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 1 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 1 + assert lb.get_server_index(slot_2, list2_size) == 0 lb.reset() - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 0 # reset the indexes before load balancing strategy test lb.reset() # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS for i in [1, 2, 1]: srv_index = lb.get_server_index( - primary1_name, + slot_1, list1_size, load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, ) @@ -2482,7 +2480,7 @@ async def test_load_balancer(self, r: RedisCluster) -> None: # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA for i in range(5): srv_index = lb.get_server_index( - primary1_name, + slot_1, list1_size, load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, ) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d360ab07f7..29a412caea 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2549,32 +2549,30 @@ def test_load_balancer(self, r): slot_1: [node_1, node_2, node_3], slot_2: [node_4, node_5], } - primary1_name = n_manager.slots_cache[slot_1][0].name - primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary1_name, list1_size) == 1 - assert lb.get_server_index(primary1_name, list1_size) == 2 - assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 1 + assert lb.get_server_index(slot_1, list1_size) == 2 + assert lb.get_server_index(slot_1, list1_size) == 0 # slot 2 - assert lb.get_server_index(primary2_name, list2_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 1 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 1 + assert lb.get_server_index(slot_2, list2_size) == 0 lb.reset() - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(slot_1, list1_size) == 0 + assert lb.get_server_index(slot_2, list2_size) == 0 # reset the indexes before load balancing strategy test lb.reset() # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS for i in [1, 2, 1]: srv_index = lb.get_server_index( - primary1_name, + slot_1, list1_size, load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, ) @@ -2585,7 +2583,7 @@ def test_load_balancer(self, r): # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA for i in range(5): srv_index = lb.get_server_index( - primary1_name, + slot_1, list1_size, load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, )