From ce194e33e08e454318c58483e5bee2b1ffdb9421 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 4 Aug 2024 17:09:27 -0400 Subject: [PATCH 01/14] Update (base update) [ghstack-poisoned] --- test/test_recipe.py | 13 - test/test_specs.py | 47 +++ test/test_storage_map.py | 249 +++++++++++++ torchrl/data/__init__.py | 12 + torchrl/data/map/__init__.py | 9 + torchrl/data/map/hash.py | 183 ++++++++++ torchrl/data/map/query.py | 197 +++++++++++ torchrl/data/map/tdstorage.py | 323 +++++++++++++++++ torchrl/data/map/tree.py | 447 ++++++++++++++++++++++++ torchrl/data/replay_buffers/storages.py | 71 +++- torchrl/data/tensor_specs.py | 120 ++++++- 11 files changed, 1641 insertions(+), 30 deletions(-) delete mode 100644 test/test_recipe.py create mode 100644 test/test_storage_map.py create mode 100644 torchrl/data/map/__init__.py create mode 100644 torchrl/data/map/hash.py create mode 100644 torchrl/data/map/query.py create mode 100644 torchrl/data/map/tdstorage.py create mode 100644 torchrl/data/map/tree.py diff --git a/test/test_recipe.py b/test/test_recipe.py deleted file mode 100644 index 5387a23f503..00000000000 --- a/test/test_recipe.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import argparse - -import pytest - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_specs.py b/test/test_specs.py index 2d597d770f0..2a47d2680b9 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3740,6 +3740,53 @@ def test_device_ordinal(): assert spec.device == torch.device("cuda:0") +class TestSpecEnumerate: + def test_discrete(self): + spec = DiscreteTensorSpec(n=5, shape=(3,)) + assert ( + spec.enumerate() + == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) + ).all() + + def test_one_hot(self): + spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5)) + assert ( + spec.enumerate() + == torch.tensor( + [ + [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + [[0, 1, 0, 0, 0], [0, 1, 0, 0, 0]], + [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]], + [[0, 0, 0, 1, 0], [0, 0, 0, 1, 0]], + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]], + ], + dtype=torch.bool, + ) + ).all() + + def test_multi_discrete(self): + spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3)) + enum = spec.enumerate() + assert enum.shape == torch.Size([60, 2, 3]) + + def test_multi_onehot(self): + spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12)) + enum = spec.enumerate() + assert enum.shape == torch.Size([60, 2, 12]) + + def test_composite(self): + c = CompositeSpec( + { + "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)), + ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)), + }, + shape=[3], + ) + c_enum = c.enumerate() + assert c_enum.shape == torch.Size((20, 3)) + assert c_enum["b"].shape == torch.Size((20, 3)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_storage_map.py b/test/test_storage_map.py new file mode 100644 index 00000000000..90aaa5c7dda --- /dev/null +++ b/test/test_storage_map.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import functools +import importlib.util + +import pytest + +import torch + +from tensordict import TensorDict +from torchrl.data import LazyTensorStorage, ListStorage +from torchrl.data.map import ( + BinaryToDecimal, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, +) +from torchrl.envs import GymEnv + +_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec( + "gym", None +) + + +class TestHash: + def test_binary_to_decimal(self): + binary_to_decimal = BinaryToDecimal( + num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True + ) + binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) + decimal = binary_to_decimal(binary) + + assert decimal.shape == (2,) + assert (decimal == torch.Tensor([3, 2])).all() + + def test_sip_hash(self): + a = torch.rand((3, 2)) + b = a.clone() + hash_module = SipHash(as_tensor=True) + hash_a = torch.tensor(hash_module(a)) + hash_b = torch.tensor(hash_module(b)) + assert (hash_a == hash_b).all() + + @pytest.mark.parametrize("n_components", [None, 14]) + @pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000]) + def test_randomprojection_hash(self, n_components, scale): + torch.manual_seed(0) + r = RandomProjectionHash(n_components=n_components) + x = torch.randn(10000, 100).mul_(scale) + y = r(x) + if n_components is None: + assert r.n_components == r._N_COMPONENTS_DEFAULT + else: + assert r.n_components == n_components + + assert y.shape == (10000,) + assert y.unique().numel() == y.numel() + + +class TestQuery: + def test_query_construct(self): + query_module = QueryModule( + in_keys=[(("key1",),), (("another",), "key2")], + index_key=("some", ("_index",)), + hash_module=SipHash(), + clone=False, + ) + assert not query_module.clone + assert query_module.in_keys == ["key1", ("another", "key2")] + assert query_module.index_key == ("some", "_index") + assert isinstance(query_module.hash_module, dict) + assert isinstance( + query_module.aggregator, + type(query_module.hash_module[query_module.in_keys[0]]), + ) + query_module = QueryModule( + in_keys=[(("key1",),), (("another",), "key2")], + index_key=("some", ("_index",)), + hash_module=SipHash(), + clone=False, + aggregator=SipHash(), + ) + # assert not isinstance(query_module.aggregator is not query_module.hash_module[0] + assert isinstance(query_module.aggregator, SipHash) + query_module = QueryModule( + in_keys=[(("key1",),), (("another",), "key2")], + index_key=("some", ("_index",)), + hash_module=[SipHash(), SipHash()], + clone=False, + ) + # assert query_module.aggregator is not query_module.hash_module[0] + assert isinstance(query_module.aggregator, SipHash) + + @pytest.mark.parametrize("index_key", ["index", ("another", "index")]) + @pytest.mark.parametrize("clone", [True, False]) + def test_query(self, clone, index_key): + query_module = QueryModule( + in_keys=["key1", "key2"], + index_key=index_key, + hash_module=SipHash(), + clone=clone, + ) + + query = TensorDict( + { + "key1": torch.Tensor([[1], [1], [1], [2]]), + "key2": torch.Tensor([[3], [3], [2], [3]]), + }, + batch_size=(4,), + ) + res = query_module(query) + if clone: + assert res is not query + else: + assert res is query + assert index_key in res + + assert res[index_key][0] == res[index_key][1] + for i in range(1, 3): + assert res[index_key][i].item() != res[index_key][i + 1].item() + + def test_query_module(self): + query_module = QueryModule( + in_keys=["key1", "key2"], + index_key="index", + hash_module=SipHash(), + ) + + embedding_storage = LazyTensorStorage(23) + + tensor_dict_storage = TensorDictMap( + query_module=query_module, + storage=embedding_storage, + ) + + index = TensorDict( + { + "key1": torch.Tensor([[-1], [1], [3], [-3]]), + "key2": torch.Tensor([[0], [2], [4], [-4]]), + }, + batch_size=(4,), + ) + + value = TensorDict( + {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ) + + tensor_dict_storage[index] = value + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + + new_index = index.clone(True) + new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + retrieve_value = tensor_dict_storage[new_index] + + assert (retrieve_value["index"] == value["index"]).all() + + +class TesttTensorDictMap: + @pytest.mark.parametrize( + "storage_type", + [ + functools.partial(ListStorage, 1000), + functools.partial(LazyTensorStorage, 1000), + ], + ) + def test_map(self, storage_type): + query_module = QueryModule( + in_keys=["key1", "key2"], + index_key="index", + hash_module=SipHash(), + ) + + embedding_storage = storage_type() + + tensor_dict_storage = TensorDictMap( + query_module=query_module, + storage=embedding_storage, + ) + + index = TensorDict( + { + "key1": torch.Tensor([[-1], [1], [3], [-3]]), + "key2": torch.Tensor([[0], [2], [4], [-4]]), + }, + batch_size=(4,), + ) + + value = TensorDict( + {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ) + assert not hasattr(tensor_dict_storage, "out_keys") + + tensor_dict_storage[index] = value + if isinstance(embedding_storage, LazyTensorStorage): + assert hasattr(tensor_dict_storage, "out_keys") + else: + assert not hasattr(tensor_dict_storage, "out_keys") + assert tensor_dict_storage._has_lazy_out_keys() + assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + + new_index = index.clone(True) + new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + retrieve_value = tensor_dict_storage[new_index] + + assert (retrieve_value["index"] == value["index"]).all() + + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + def test_map_rollout(self): + torch.manual_seed(0) + env = GymEnv("CartPole-v1") + env.set_seed(0) + rollout = env.rollout(100) + source, dest = rollout.exclude("next"), rollout.get("next") + storage = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=["observation", "action"], + ) + storage_indices = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=["observation"], + out_keys=["_index"], + ) + # maps the (obs, action) tuple to a corresponding next state + storage[source] = dest + storage_indices[source] = source + contains = storage.contains(source) + assert len(contains) == rollout.shape[-1] + assert contains.all() + contains = storage.contains(torch.cat([source, source + 1])) + assert len(contains) == rollout.shape[-1] * 2 + assert contains[: rollout.shape[-1]].all() + assert not contains[rollout.shape[-1] :].any() + +class TestMCTSForest: + def test_forest_build(self): + ... + def test_forest_extend_and_get(self): + ... + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 3749e6e8cbc..5641b17dd75 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -3,6 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .map import ( + BinaryToDecimal, + HashToInt, + MCTSChildren, + MCTSForest, + MCTSNode, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, + TensorMap, +) from .postprocs import MultiStep from .replay_buffers import ( Flat2TED, diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py new file mode 100644 index 00000000000..7ef1f61a845 --- /dev/null +++ b/torchrl/data/map/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hash import BinaryToDecimal, RandomProjectionHash, SipHash +from .query import HashToInt, QueryModule +from .tdstorage import TensorDictMap, TensorMap +from .tree import MCTSChildren, MCTSForest, MCTSNode diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py new file mode 100644 index 00000000000..f5ba93e900f --- /dev/null +++ b/torchrl/data/map/hash.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from typing import Callable, List + +import torch + + +class BinaryToDecimal(torch.nn.Module): + """A Module to convert binaries encoded tensors to decimals. + + This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to + its decimal value (e.g. `9`) + + Args: + num_bits (int): the number of bits to use for the bases table. + The number of bits must be lower or equal to the input length and the input length + must be divisible by ``num_bits``. If ``num_bits`` is lower than the number of + bits in the input, the end result will be aggregated on the last dimension using + :func:`~torch.sum`. + device (torch.device): the device where inputs and outputs are to be expected. + dtype (torch.dtype): the output dtype. + convert_to_binary (bool, optional): if ``True``, the input to the ``forward`` + method will be cast to a binary input using :func:`~torch.heavyside`. + Defaults to ``False``. + + Examples: + >>> binary_to_decimal = BinaryToDecimal( + ... num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True + ... ) + >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]]) + >>> decimal = binary_to_decimal(binary) + >>> assert decimal.shape == (2,) + >>> assert (decimal == torch.Tensor([3, 2])).all() + """ + + def __init__( + self, + num_bits: int, + device: torch.device, + dtype: torch.dtype, + convert_to_binary: bool = False, + ): + super().__init__() + self.convert_to_binary = convert_to_binary + self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype) + self.num_bits = num_bits + self.zero_tensor = torch.zeros((1,), device=device) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + num_features = features.shape[-1] + if self.num_bits > num_features: + raise ValueError(f"{num_features=} is less than {self.num_bits=}") + elif num_features % self.num_bits != 0: + raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}") + + binary_features = ( + torch.heaviside(features, self.zero_tensor) + if self.convert_to_binary + else features + ) + feature_parts = binary_features.reshape(shape=(-1, self.num_bits)) + digits = torch.vmap(torch.dot, (None, 0))( + self.bases, feature_parts.to(self.bases.dtype) + ) + digits = digits.reshape(shape=(-1, features.shape[-1] // self.num_bits)) + aggregated_digits = torch.sum(digits, dim=-1) + return aggregated_digits + + +class SipHash(torch.nn.Module): + """A Module to Compute SipHash values for given tensors. + + A hash function module based on SipHash implementation in python. + + Args: + as_tensor (bool, optional): if ``True``, the bytes will be turned into integers + through the builtin ``hash`` function and mapped to a tensor. Default: ``True``. + + .. warning:: This module relies on the builtin ``hash`` function. + To get reproducible results across runs, the ``PYTHONHASHSEED`` environment + variable must be set before the code is run (changing this value during code + execution is without effect). + + Examples: + >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code + >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + >>> b = a.clone() + >>> hash_module = SipHash(as_tensor=True) + >>> hash_a = hash_module(a) + >>> hash_a + tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521]) + >>> hash_b = hash_module(b) + >>> assert (hash_a == hash_b).all() + """ + + def __init__(self, as_tensor: bool = True): + super().__init__() + self.as_tensor = as_tensor + + def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]: + hash_values = [] + if x.dtype in (torch.bfloat16,): + x = x.to(torch.float16) + for x_i in x.detach().cpu().numpy(): + hash_value = x_i.tobytes() + hash_values.append(hash_value) + if not self.as_tensor: + return hash_value + result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64) + return result + + +class RandomProjectionHash(SipHash): + """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`. + + This module requires sklearn to be installed. + + Keyword Args: + n_components (int, optional): the low-dimensional number of components of the projections. + Defaults to 16. + dtype_cast (torch.dtype, optional): the dtype to cast the projection to. + Defaults to ``torch.bfloat16``. + as_tensor (bool, optional): if ``True``, the bytes will be turned into integers + through the builtin ``hash`` function and mapped to a tensor. Default: ``True``. + + .. warning:: This module relies on the builtin ``hash`` function. + To get reproducible results across runs, the ``PYTHONHASHSEED`` environment + variable must be set before the code is run (changing this value during code + execution is without effect). + + init_method: TODO + """ + + _N_COMPONENTS_DEFAULT = 16 + + def __init__( + self, + *, + n_components: int | None = None, + dtype_cast=torch.bfloat16, + as_tensor: bool = True, + init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None, + **kwargs, + ): + if n_components is None: + n_components = self._N_COMPONENTS_DEFAULT + + super().__init__(as_tensor=as_tensor) + self.register_buffer("_n_components", torch.as_tensor(n_components)) + + self._init = False + if init_method is None: + init_method = torch.nn.init.normal_ + self.init_method = init_method + + self.dtype_cast = dtype_cast + self.register_buffer("transform", torch.nn.UninitializedBuffer()) + + @property + def n_components(self): + return self._n_components.item() + + def fit(self, x): + """Fits the random projection to the input data.""" + self.transform.materialize( + (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device + ) + self.init_method(self.transform) + self._init = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self._init: + self.fit(x) + elif not self._init: + raise RuntimeError( + f"The {type(self).__name__} has not been initialized. Call fit before calling this method." + ) + x = x.to(self.dtype_cast) @ self.transform + return super().forward(x) diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py new file mode 100644 index 00000000000..73680999c8c --- /dev/null +++ b/torchrl/data/map/query.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Mapping, TypeVar + +import torch +import torch.nn as nn +from tensordict import NestedKey, TensorDictBase +from tensordict.nn.common import TensorDictModuleBase +from torchrl._utils import logger as torchrl_logger +from torchrl.data.map import SipHash + +K = TypeVar("K") +V = TypeVar("V") + + +class HashToInt(nn.Module): + """Converts a hash value to an integer that can be used for indexing a contiguous storage.""" + + def __init__(self): + super().__init__() + self._index_to_index = {} + + def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor: + result = [] + if extend: + for _item in key.tolist(): + result.append( + self._index_to_index.setdefault(_item, len(self._index_to_index)) + ) + else: + for _item in key.tolist(): + result.append( + self._index_to_index.get(_item, len(self._index_to_index)) + ) + return torch.tensor(result, device=key.device, dtype=key.dtype) + + def state_dict(self) -> Dict[str, torch.Tensor]: + values = torch.tensor(self._index_to_index.values()) + keys = torch.tensor(self._index_to_index.keys()) + return {"keys": keys, "values": values} + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + keys = state_dict["keys"] + values = state_dict["values"] + self._index_to_index = { + key: val for key, val in zip(keys.tolist(), values.tolist()) + } + + +class QueryModule(TensorDictModuleBase): + """A Module to generate compatible indices for storage. + + A module that queries a storage and return required index of that storage. + Currently, it only outputs integer indices (torch.int64). + + Args: + in_keys (list of NestedKeys): keys of the input tensordict that + will be used to generate the hash value. + index_key (NestedKey): the output key where the index value will be written. + Defaults to ``"_index"``. + + Keyword Args: + hash_key (NestedKey): the output key where the hash value will be written. + Defaults to ``"_hash"``. + hash_module (Callable[[Any], int] or a list of these, optional): a hash + module similar to :class:`~tensordict.nn.SipHash` (default). + If a list of callables is provided, its length must equate the number of in_keys. + hash_to_int (Callable[[int], int], optional): a stateful function that + maps a hash value to a non-negative integer corresponding to an index in a + storage. Defaults to :class:`~torchrl.data.map.HashToInt`. + aggregator (Callable[[int], int], optional): a hash function to group multiple hashes + together. This argument should only be passed when there is more than one ``in_keys``. + If a single ``hash_module`` is provided but no aggregator is passed, it will take + the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is + provided but no aggregator is passed, it will default to ``SipHash``. + clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be + returned. This can be used to retrieve the integer index within the storage, + corresponding to a given input tensordict. + Defaults to ``False``. + d + Examples: + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... hash_module=SipHash(), + ... ) + >>> query = TensorDict( + ... { + ... "key1": torch.Tensor([[1], [1], [1], [2]]), + ... "key2": torch.Tensor([[3], [3], [2], [3]]), + ... "other": torch.randn(4), + ... }, + ... batch_size=(4,), + ... ) + >>> res = query_module(query) + >>> # The first two pairs of key1 and key2 match + >>> assert res["index"][0] == res["index"][1] + >>> # The last three pairs of key1 and key2 have at least one mismatching value + >>> assert res["index"][1] != res["index"][2] + >>> assert res["index"][2] != res["index"][3] + """ + + def __init__( + self, + in_keys: List[NestedKey], + index_key: NestedKey = "_index", + hash_key: NestedKey = "_hash", + *, + hash_module: Callable[[Any], int] | List[Callable[[Any], int]] | None = None, + hash_to_int: Callable[[int], int] | None = None, + aggregator: Callable[[Any], int] = None, + clone: bool = False, + ): + if len(in_keys) == 0: + raise ValueError("`in_keys` cannot be empty.") + in_keys = in_keys if isinstance(in_keys, List) else [in_keys] + + super().__init__() + in_keys = self.in_keys = in_keys + self.out_keys = [index_key, hash_key] + index_key = self.out_keys[0] + self.hash_key = self.out_keys[1] + + if aggregator is not None and len(self.in_keys) == 1: + torchrl_logger.warn( + "An aggregator was provided but there is only one in-key to be read. " + "This module will be ignored." + ) + elif aggregator is None: + if hash_module is not None and not isinstance(hash_module, list): + aggregator = hash_module + else: + aggregator = SipHash() + if hash_module is None: + hash_module = [SipHash() for _ in range(len(self.in_keys))] + elif not isinstance(hash_module, list): + try: + hash_module = [ + deepcopy(hash_module) if len(self.in_keys) > 1 else hash_module + for _ in range(len(self.in_keys)) + ] + except Exception as err: + raise RuntimeError( + "failed to deepcopy the hash module. Please provide a list of hash modules instead." + ) from err + elif len(hash_module) != len(self.in_keys): + raise ValueError( + "The number of hash_modules must match the number of in_keys. " + f"Got {len(hash_module)} hash modules but {len(in_keys)} in_keys." + ) + if hash_to_int is None: + hash_to_int = HashToInt() + + self.aggregator = aggregator + self.hash_module = dict(zip(self.in_keys, hash_module)) + self.hash_to_int = hash_to_int + + self.index_key = index_key + self.clone = clone + + def forward( + self, + tensordict: TensorDictBase, + extend: bool = True, + write_hash: bool = True, + ) -> TensorDictBase: + hash_values = [] + + for k in self.in_keys: + hash_values.append(self.hash_module[k](tensordict.get(k))) + if len(self.in_keys) > 1: + hash_values = torch.stack( + hash_values, + dim=-1, + ) + hash_values = self.aggregator(hash_values) + else: + hash_values = hash_values[0] + + td_hash_value = self.hash_to_int(hash_values, extend=extend) + + if self.clone: + output = tensordict.copy() + else: + output = tensordict + + output.set(self.index_key, td_hash_value) + if write_hash: + output.set(self.hash_key, hash_values) + return output diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py new file mode 100644 index 00000000000..77dbc229b9e --- /dev/null +++ b/torchrl/data/map/tdstorage.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import abc +import functools +from abc import abstractmethod +from typing import Any, Callable, Dict, Generic, List, TypeVar + +import torch +from tensordict import is_tensor_collection, NestedKey, TensorDictBase +from tensordict.nn.common import TensorDictModuleBase +from torchrl.data.map.hash import RandomProjectionHash, SipHash +from torchrl.data.map.query import QueryModule +from torchrl.data.replay_buffers.storages import ( + _get_default_collate, + LazyTensorStorage, + TensorStorage, +) + +K = TypeVar("K") +V = TypeVar("V") + + +class TensorMap(abc.ABC, Generic[K, V]): + """An Abstraction for implementing different storage. + + This class is for internal use, please use derived classes instead. + """ + + @abstractmethod + def clear(self) -> None: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, item: K) -> V: + raise NotImplementedError + + @abstractmethod + def __setitem__(self, key: K, value: V) -> None: + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def contains(self, item: K) -> torch.Tensor: + raise NotImplementedError + + def __contains__(self, item): + return self.contains(item) + + +class TensorDictMap( + TensorDictModuleBase, TensorMap[TensorDictModuleBase, TensorDictModuleBase] +): + """A Map-Storage for TensorDict. + + This module resembles a storage. It takes a tensordict as its input and + returns another tensordict as output similar to TensorDictModuleBase. However, + it provides additional functionality like python map: + + Keyword Args: + query_module (TensorDictModuleBase): a query module, typically an instance of + :class:`~tensordict.nn.QueryModule`, used to map a set of tensordict + entries to a hash key. + storage (Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]]): + a dictionary representing the map from an index key to a tensor storage. + collate_fn (callable, optional): a function to use to collate samples from the + storage. Defaults to a custom value for each known storage type (stack for + :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage` + subtypes and others). + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from typing import cast + >>> from torchrl.data import LazyTensorStorage + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... ) + >>> embedding_storage = LazyTensorStorage(1000) + >>> tensor_dict_storage = TensorDictMap( + ... query_module=query_module, + ... storage={"out": embedding_storage}, + ... ) + >>> index = TensorDict( + ... { + ... "key1": torch.Tensor([[-1], [1], [3], [-3]]), + ... "key2": torch.Tensor([[0], [2], [4], [-4]]), + ... }, + ... batch_size=(4,), + ... ) + >>> value = TensorDict( + ... {"out": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,) + ... ) + >>> tensor_dict_storage[index] = value + >>> tensor_dict_storage[index] + TensorDict( + fields={ + out: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False) + >>> assert torch.sum(tensor_dict_storage.contains(index)).item() == 4 + >>> new_index = index.clone(True) + >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) + >>> retrieve_value = tensor_dict_storage[new_index] + >>> assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all() + """ + + def __init__( + self, + *, + query_module: QueryModule, + storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]], + collate_fn: Callable[[Any], Any] | None = None, + out_keys: List[NestedKey] | None = None, + write_fn: Callable[[Any, Any], Any] | None = None, + ): + super().__init__() + + self.in_keys = query_module.in_keys + if out_keys is not None: + self.out_keys = out_keys + assert not self._has_lazy_out_keys() + + self.query_module = query_module + self.index_key = query_module.index_key + self.storage = storage + self.batch_added = False + if collate_fn is None: + collate_fn = _get_default_collate(self.storage) + self.collate_fn = collate_fn + self.write_fn = write_fn + + @property + def out_keys(self) -> List[NestedKey]: + out_keys = self.__dict__.get("_out_keys_and_lazy") + if out_keys is not None: + return out_keys[0] + storage = self.storage + if isinstance(storage, TensorStorage) and is_tensor_collection( + storage._storage + ): + out_keys = list(storage._storage.keys(True, True)) + self._out_keys_and_lazy = (out_keys, True) + return self.out_keys + raise AttributeError( + f"No out-keys found in the storage of type {type(storage)}" + ) + + @out_keys.setter + def out_keys(self, value): + self._out_keys_and_lazy = (value, False) + + def _has_lazy_out_keys(self): + _out_keys_and_lazy = self.__dict__.get("_out_keys_and_lazy") + if _out_keys_and_lazy is None: + return True + return self._out_keys_and_lazy[1] + + @classmethod + def from_tensordict_pair( + cls, + source, + dest, + in_keys: List[NestedKey], + out_keys: List[NestedKey] | None = None, + storage_constructor: type | None = None, + hash_module: Callable | None = None, + collate_fn: Callable[[Any], Any] | None = None, + write_fn: Callable[[Any, Any], Any] | None = None, + ): + """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. + + Args: + source (TensorDict): An example of source tensordict, used as index in the storage. + dest (TensorDict): An example of dest tensordict, used as data in the storage. + in_keys (List[NestedKey]): a list of keys to use in the map. + out_keys (List[NestedKey]): a list of keys to return in the output tensordict. + All keys absent from out_keys, even if present in ``dest``, will not be stored + in the storage. Defaults to ``None`` (all keys are registered). + storage_constructor (type, optional): a type of tensor storage. + Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`. + Other options include :class:`~tensordict.nn.storage.FixedStorage`. + hash_module (Callable, optional): a hash function to use in the :class:`~tensordict.nn.storage.QueryModule`. + Defaults to :class:`SipHash` for low-dimensional inputs, and :class:`~tensordict.nn.storage.RandomProjectionHash` + for larger inputs. + collate_fn (callable, optional): a function to use to collate samples from the + storage. Defaults to a custom value for each known storage type (stack for + :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage` + subtypes and others). + + Examples: + >>> # The following example requires torchrl and gymnasium to be installed + >>> from torchrl.envs import GymEnv + >>> torch.manual_seed(0) + >>> env = GymEnv("CartPole-v1") + >>> env.set_seed(0) + >>> rollout = env.rollout(100) + >>> source, dest = rollout.exclude("next"), rollout.get("next") + >>> storage = TensorDictMap.from_tensordict_pair( + ... source, dest, + ... in_keys=["observation", "action"], + ... ) + >>> # maps the (obs, action) tuple to a corresponding next state + >>> storage[source] = dest + >>> print(source["_index"]) + tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]) + >>> storage[source] + TensorDict( + fields={ + done: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([14, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([14]), + device=None, + is_shared=False) + + """ + # Build query module + if hash_module is None: + # Count the features, if they're greater than RandomProjectionHash._N_COMPONENTS_DEFAULT + # use that module to project them to that dimensionality. + n_feat = 0 + hash_module = [] + for in_key in in_keys: + n_feat = source[in_key].shape[-1] + if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: + _hash_module = RandomProjectionHash() + else: + _hash_module = SipHash() + hash_module.append(_hash_module) + query_module = QueryModule(in_keys, hash_module=hash_module) + + # Build key_to_storage + if storage_constructor is None: + storage_constructor = functools.partial(LazyTensorStorage, 1000) + storage = storage_constructor() + result = cls( + query_module=query_module, + storage=storage, + collate_fn=collate_fn, + out_keys=out_keys, + write_fn=write_fn, + ) + return result + + def clear(self) -> None: + for mem in self.storage.values(): + mem.clear() + + def _to_index(self, item: TensorDictBase, extend: bool) -> torch.Tensor: + item = self.query_module(item, extend=extend) + return item[self.index_key] + + def _maybe_add_batch( + self, item: TensorDictBase, value: TensorDictBase | None + ) -> TensorDictBase: + self.batch_added = False + if len(item.batch_size) == 0: + self.batch_added = True + + item = item.unsqueeze(dim=0) + if value is not None: + value = value.unsqueeze(dim=0) + + return item, value + + def _maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: + if self.batch_added: + item = item.squeeze(dim=0) + return item + + def __getitem__(self, item: TensorDictBase) -> TensorDictBase: + item, _ = self._maybe_add_batch(item, None) + + index = self._to_index(item, extend=False) + + res = self.storage[index] + res = self.collate_fn(res) + res = self._maybe_remove_batch(res) + return res + + def __setitem__(self, item: TensorDictBase, value: TensorDictBase): + if not self._has_lazy_out_keys(): + # TODO: make this work with pytrees and avoid calling select if keys match + value = value.select(*self.out_keys, strict=False) + if self.write_fn is not None: + if len(self): + modifiable = self.contains(item) + if modifiable.any(): + to_modify = (value[modifiable], self[item[modifiable]]) + v1 = self.write_fn(*to_modify) + result = value.empty() + result[modifiable] = v1 + result[~modifiable] = self.write_fn(value[~modifiable]) + value = result + else: + value = self.write_fn(value) + else: + value = self.write_fn(value) + item, value = self._maybe_add_batch(item, value) + index = self._to_index(item, extend=True) + self.storage.set(index, value) + + def __len__(self): + return len(self.storage) + + def contains(self, item: TensorDictBase) -> torch.Tensor: + item, _ = self._maybe_add_batch(item, None) + index = self._to_index(item, extend=False) + + res = self.storage.contains(index) + res = self._maybe_remove_batch(res) + return res diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py new file mode 100644 index 00000000000..50d7ef75a15 --- /dev/null +++ b/torchrl/data/map/tree.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from typing import List + +import torch +from tensordict import LazyStackedTensorDict, NestedKey, tensorclass, TensorDict +from torchrl.data import ListStorage, TensorDictMap +from torchrl.envs import EnvBase + + +@tensorclass +class MCTSNode: + """An MCTS node. + + The batch-size of a root node is indicative of the batch-size of the tree: + each indexed element of a ``Node`` corresponds to a separate tree. + + A node is characterized by its data (a tensordict with keys such as ``"observation"``, + or ``"done"``), a ``children`` field containing all the branches from that node + (one per action taken), and a ``count`` tensor indicating how many times this node + has been visited. + + """ + + data_content: TensorDict + children: MCTSChildren | None = None + count: torch.Tensor | None = None + + +@tensorclass +class MCTSChildren: + """The children of a node. + + This class contains data of the same batch-size: the ``action``, ``reward``, ``index`` and ``hash`` + associated with each ``node``. Therefore, each indexed element of a ``Children`` + corresponds to one child with its associated action, reward and index. + + """ + + node: MCTSNode + action: torch.Tensor | None = None + reward: torch.Tensor | None = None + index: torch.Tensor | None = None + hash: torch.Tensor | None = None + + +class MCTSForest: + """A collection of MCTS trees. + + The class is aimed at storing rollouts in a storage, and produce trees based on a given root + in that dataset. + + Keyword Args: + data_map (TensorDictMap, optional): the storage to use to store the data + (observation, reward, states etc). If not provided, it is lazily + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. + data_map (TensorDictMap, optional): the storage to use to store the data + (observation, reward, states etc). If not provided, it is lazily + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. + done_keys (list of NestedKey): the done keys of the environment. If not provided, + defaults to ``("done", "terminated", "truncated")``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + action_keys (list of NestedKey): the action keys of the environment. If not provided, + defaults to ``("action",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + reward_keys (list of NestedKey): the reward keys of the environment. If not provided, + defaults to ``("reward",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + observation_keys (list of NestedKey): the observation keys of the environment. If not provided, + defaults to ``("observation",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + + """ + + def __init__( + self, + *, + data_map: TensorDictMap | None = None, + node_map: TensorDictMap | None = None, + done_keys: List[NestedKey] | None = None, + reward_keys: List[NestedKey] = None, + observation_keys: List[NestedKey] = None, + action_keys: List[NestedKey] = None, + ): + + self.data_map = data_map + + self.node_map = node_map + + self.done_keys = done_keys + self.action_keys = action_keys + self.reward_keys = reward_keys + self.observation_keys = observation_keys + + @property + def done_keys(self): + done_keys = getattr(self, "_done_keys", None) + if done_keys is None: + self._done_keys = done_keys = ("done", "terminated", "truncated") + return done_keys + + @done_keys.setter + def done_keys(self, value): + self._done_keys = value + + @property + def reward_keys(self): + reward_keys = getattr(self, "_reward_keys", None) + if reward_keys is None: + self._reward_keys = reward_keys = ("reward",) + return reward_keys + + @reward_keys.setter + def reward_keys(self, value): + self._reward_keys = value + + @property + def action_keys(self): + action_keys = getattr(self, "_action_keys", None) + if action_keys is None: + self._action_keys = action_keys = ("action",) + return action_keys + + @action_keys.setter + def action_keys(self, value): + self._action_keys = value + + @property + def observation_keys(self): + observation_keys = getattr(self, "_observation_keys", None) + if observation_keys is None: + self._observation_keys = observation_keys = ("observation",) + return observation_keys + + @observation_keys.setter + def observation_keys(self, value): + self._observation_keys = value + + def get_keys_from_env(self, env: EnvBase): + """Writes missing done, action and reward keys to the Forest given an environment. + + Existing keys are not overwritten. + """ + if getattr(self, "_reward_keys", None) is None: + self.reward_keys = env.reward_keys + if getattr(self, "_done_keys", None) is None: + self.done_keys = env.done_keys + if getattr(self, "_action_keys", None) is None: + self.action_keys = env.action_keys + if getattr(self, "_observation_keys", None) is None: + self.observation_keys = env.observation_keys + + @classmethod + def _write_fn_stack(cls, new, old=None): + if old is None: + result = new.apply(lambda x: x.unsqueeze(0)) + result.set( + "count", torch.ones(result.shape, dtype=torch.int, device=result.device) + ) + else: + + def cat(name, x, y): + if name == "count": + return x + if y.ndim < x.ndim: + y = y.unsqueeze(0) + result = torch.cat([x, y], 0).unique(dim=0, sorted=False) + return result + + result = old.named_apply(cat, new, default=None) + result.set_("count", old.get("count") + 1) + return result + + def _make_storage(self, source, dest): + self.data_map = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=[*self.observation_keys, *self.action_keys], + ) + + def _make_storage_branches(self, source, dest): + self.node_map = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=[*self.observation_keys], + out_keys=[ + *self.data_map.query_module.out_keys, + *self.action_keys, + *[("next", rk) for rk in self.reward_keys], + "count", + ], + storage_constructor=ListStorage, + collate_fn=TensorDict.lazy_stack, + write_fn=self._write_fn_stack, + ) + + def extend(self, rollout): + source, dest = rollout, rollout.get("next") + if self.data_map is None: + self._make_storage(source, dest) + + # We need to set the action somewhere to keep track of what action lead to what child + # # Set the action in the 'next' + # dest[1:] = source[:-1].exclude(*self.done_keys) + + self.data_map[source] = dest + value = source + if self.node_map is None: + self._make_storage_branches(source, dest) + self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) + + def get_child(self, root): + return self.data_map[root] + + def get_tree( + self, + root, + *, + inplace: bool = False, + recurse: bool = True, + max_depth: int | None = None, + as_tensordict: bool = False, + ): + if root.batch_size: + func = self._get_tree_batched + else: + func = self._get_tree_single + return func( + root=root, + inplace=inplace, + recurse=recurse, + max_depth=max_depth, + as_tensordict=as_tensordict, + ) + + def _get_tree_single( + self, + root, + inplace: bool = False, + recurse: bool = True, + max_depth: int | None = None, + as_tensordict: bool = False, + ): + if root not in self.node_map: + if as_tensordict: + return TensorDict({"data_content": root}) + return MCTSNode(root) + branches = self.node_map[root] + + index = branches["_index"] + hash_val = branches["_hash"] + count = branches["count"] + action = ( + branches.select(*self.action_keys) + if len(self.action_keys) > 1 + else branches.get(*self.action_keys) + ) + reward = ( + branches.get("next").select(*self.reward_keys) + if len(self.reward_keys) > 1 + else branches.get(("next", *self.reward_keys)) + ) + + children_node = self.data_map.storage[index] + if not inplace: + root = root.copy() + if recurse: + children_node = children_node.unbind(0) + children_node = tuple( + self.get_tree( + child, + inplace=inplace, + max_depth=max_depth - 1 if isinstance(max_depth, int) else None, + ) + for child in children_node + ) + if not as_tensordict: + children_node = LazyStackedTensorDict( + *(child._tensordict for child in children_node) + ) + children_node = MCTSNode.from_tensordict(children_node) + else: + children_node = LazyStackedTensorDict(*children_node) + if not as_tensordict: + return MCTSNode( + data_content=root, + children=MCTSChildren( + node=children_node, + action=action, + index=index, + hash=hash_val, + reward=reward, + batch_size=children_node.batch_size, + ), + count=count, + ) + return TensorDict( + { + "data_content": root, + "children": TensorDict( + { + "node": children_node, + "action": action, + "index": index, + "hash": hash_val, + "reward": reward, + }, + batch_sizde=children_node.batch_size, + ), + "count": count, + } + ) + + def _get_tree_batched( + self, + root, + inplace: bool = False, + recurse: bool = True, + max_depth: int | None = None, + as_tensordict: bool = False, + ): + present = self.node_map.contains(root) + if not present.any(): + if as_tensordict: + return TensorDict({"data_content": root}, batch_size=root.batch_size) + return MCTSNode(root, batch_size=root.batch_size) + if present.all(): + root_present = root + else: + root_present = root[present] + branches = self.node_map[root_present] + index = branches.get_nestedtensor("_index", layout=torch.jagged) + hash_val = branches.get_nestedtensor("_hash", layout=torch.jagged) + count = branches.get("count") + + children_node = self.data_map.storage[index.values()] + if not root_present.all(): + children_node = LazyStackedTensorDict( + *children_node.split(index.offsets().diff().tolist()) + ) + for idx in (~present).nonzero(as_tuple=True)[0].tolist(): + children_node.insert(idx, TensorDict()) # TODO: replace with new_zero + if not any(d == -1 for d in children_node.batch_size): + action = ( + branches.get(*self.action_keys) + if len(self.action_keys) == 1 + else branches.select(*self.action_keys) + ) + reward = ( + branches.get(("next", *self.reward_keys)) + if len(self.reward_keys) == 1 + else branches.get("next").select(*self.reward_keys) + ) + else: + if len(self.action_keys) == 1: + action = branches.get_nestedtensor( + *self.action_keys, layout=torch.jagged + ) + else: + action = branches.select(*self.action_keys) + if len(self.reward_keys) == 1: + reward = branches.get_nestedtensor( + ("next", *self.reward_keys), layout=torch.jagged + ) + else: + reward = branches.get("next").select(*self.reward_keys) + + if not inplace: + root = root.copy() + if recurse: + children_node = children_node.unbind(0) + children_node = tuple( + self.get_tree( + child, + inplace=inplace, + max_depth=max_depth - 1 if isinstance(max_depth, int) else None, + ) + if present[i] + else child + for i, child in enumerate(children_node) + ) + children = TensorDict.lazy_stack( + [ + TensorDict( + { + "node": _children_node, + "action": _action, + "index": _index, + "hash": _hash_val, + "reward": _reward, + }, + batch_size=_children_node.batch_size, + ) + for (_children_node, _action, _index, _hash_val, _reward) in zip( + children_node, + action.unbind(0), + index.unbind(0), + hash_val.unbind(0), + reward.unbind(0), + ) + ] + ) + if not as_tensordict: + return MCTSNode( + data_content=root, + children=MCTSChildren._from_tensordict(children), + count=count, + batch_size=root.batch_size, + ) + return TensorDict( + { + "data_content": root, + "children": children, + "count": count, + }, + batch_size=root.batch_size, + ) + + def __len__(self): + return len(self.data_map) + + def plot(self, tree, backend="plotly"): + if backend == "plotly": + import plotly.graph_objects as go + + parents = [""] + labels = [ + f"{tree.data_content['_hash'].item()}, R={tree.data_content['next', 'reward'].item(): 4.4f}" + ] + _tree = tree + + def extend(tree, parent): + children = tree.children + if children is None: + return + for child in children: + labels.append(f"{child.hash.item()}, R={child.reward.item(): 4.4f}") + parents.append(parent) + extend(child.node, labels[-1]) + + extend(_tree, labels[-1]) + fig = go.Figure(go.Treemap(labels=labels, parents=parents)) + fig.show() diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 58b1729296d..eb399fbc029 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -185,6 +185,13 @@ def load(self, *args, **kwargs): """Alias for :meth:`~.loads`.""" return self.loads(*args, **kwargs) + def __contains__(self, item): + return self.contains(item) + + @abc.abstractmethod + def contains(self, item): + ... + class ListStorage(Storage): """A storage stored in a list. @@ -194,13 +201,16 @@ class ListStorage(Storage): (like lists, tuples, tensors or tensordicts with non-empty batch-size). Args: - max_size (int): the maximum number of elements stored in the storage. + max_size (int, optional): the maximum number of elements stored in the storage. + If not provided, an unlimited storage is created. """ _default_checkpointer = ListStorageCheckpointer - def __init__(self, max_size: int): + def __init__(self, max_size: int | None = None): + if max_size is None: + max_size = torch.iinfo(torch.int64).max super().__init__(max_size) self._storage = [] @@ -233,7 +243,7 @@ def set( np.ndarray, ), ): - for _cursor, _data in zip(cursor, data): + for _cursor, _data in zip(cursor, data, strict=True): self.set(_cursor, _data, set_cursor=set_cursor) else: raise TypeError( @@ -305,6 +315,20 @@ def __getstate__(self): def __repr__(self): return f"{self.__class__.__name__}(items=[{self._storage[0]}, ...])" + def contains(self, item): + if isinstance(item, int): + if item < 0: + item += len(self._storage) + + return 0 <= item < len(self._storage) + if isinstance(item, torch.Tensor): + return torch.tensor( + [self.contains(elt) for elt in item.tolist()], + dtype=torch.bool, + device=item.device, + ).reshape_as(item) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + class TensorStorage(Storage): """A storage for tensors and tensordicts. @@ -782,6 +806,30 @@ def repr_item(x): maxsize_str = textwrap.indent(f"max_size={self.max_size}", 4 * " ") return f"{self.__class__.__name__}(\n{storage_str}, \n{shape_str}, \n{len_str}, \n{maxsize_str})" + def contains(self, item): + if isinstance(item, int): + if item < 0: + item += self._len_along_dim0 + + return 0 <= item < self._len_along_dim0 + if isinstance(item, torch.Tensor): + + def _is_valid_index(idx): + try: + torch.zeros(self.shape, device="meta")[idx] + return True + except IndexError: + return False + + if item.ndim: + return torch.tensor( + [_is_valid_index(idx) for idx in item], + dtype=torch.bool, + device=item.device, + ) + return torch.tensor(_is_valid_index(item), device=item.device) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + class LazyTensorStorage(TensorStorage): """A pre-allocated tensor storage for tensors and tensordicts. @@ -1269,10 +1317,14 @@ def _collate_list_tensordict(x): return out -def _stack_anything(x): - if is_tensor_collection(x[0]): - return LazyStackedTensorDict.maybe_dense_stack(x) - return torch.stack(x) +def _stack_anything(data): + if is_tensor_collection(data[0]): + return LazyStackedTensorDict.maybe_dense_stack(data) + return torch.utils._pytree.tree_map( + lambda *x: torch.stack(x), + *data, + is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x), + ) def _collate_id(x): @@ -1281,10 +1333,7 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, ListStorage): - if _is_tensordict: - return _collate_list_tensordict - else: - return torch.utils.data._utils.collate.default_collate + return _stack_anything elif isinstance(storage, TensorStorage): return _collate_id else: diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7c787b3ccfc..2afd6b1f3d6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -756,6 +756,16 @@ def contains(self, item): """ return self.is_in(item) + @abc.abstractmethod + def enumerate(self): + """Returns all the samples that can be obtained from the TensorSpec. + + The samples will be stacked along the first dimension. + + This method is only implemented for discrete specs. + """ + ... + def project(self, val: torch.Tensor) -> torch.Tensor: """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic. @@ -1152,6 +1162,11 @@ def __eq__(self, other): return False return True + def enumerate(self): + return torch.stack( + [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1 + ) + def __len__(self): return self.shape[0] @@ -1601,6 +1616,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: return np.array(vals).reshape(tuple(val.shape)) return val + def enumerate(self): + return ( + torch.eye(self.n, dtype=self.dtype, device=self.device) + .expand(*self.shape, self.n) + .permute(-2, *range(self.ndimension() - 1), -1) + ) + def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: if not isinstance(index, torch.Tensor): raise ValueError( @@ -1832,6 +1854,11 @@ def __init__( domain=domain, ) + def enumerate(self): + raise NotImplementedError( + f"enumerate is not implemented for spec of class {type(self).__name__}." + ) + def __eq__(self, other): return ( type(other) == type(self) @@ -2107,6 +2134,9 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) + def enumerate(self): + raise NotImplementedError("Cannot enumerate a NonTensorSpec.") + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: if isinstance(dest, torch.dtype): dest_dtype = dest @@ -2273,6 +2303,9 @@ def is_in(self, val: torch.Tensor) -> bool: def _project(self, val: torch.Tensor) -> torch.Tensor: return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape) + def enumerate(self): + raise NotImplementedError("enumerate cannot be called with continuous specs.") + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -2361,8 +2394,6 @@ class UnboundedDiscreteTensorSpec(TensorSpec): (should be an integer dtype such as long, uint8 etc.) """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -2409,6 +2440,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: return self return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) + def enumerate(self): + raise NotImplementedError("Cannot enumerate an unbounded tensor spec.") + def clone(self) -> UnboundedDiscreteTensorSpec: return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) @@ -2553,8 +2587,6 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, nvec: Sequence[int], @@ -2586,6 +2618,18 @@ def __init__( ) self.update_mask(mask) + def enumerate(self): + nvec = self.nvec + enum_disc = self.to_categorical_spec().enumerate() + enums = torch.cat( + [ + torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype) + for nv, enum_unb in zip(nvec, enum_disc.unbind(-1)) + ], + -1, + ) + return enums + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -2975,6 +3019,12 @@ def __init__( ) self.update_mask(mask) + def enumerate(self): + arange = torch.arange(self.n, dtype=self.dtype, device=self.device) + if self.ndim: + arange = arange.view(-1, *(1,) * self.ndim) + return arange.expand(self.n, *self.shape) + @property def n(self): return self.space.n @@ -3428,6 +3478,29 @@ def __init__( self.update_mask(mask) self.remove_singleton = remove_singleton + def enumerate(self): + if self.mask is not None: + raise RuntimeError( + "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested." + ) + if self.nvec._base.ndim == 1: + nvec = self.nvec._base + else: + # we have to use unique() to isolate the nvec + nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0) + if nvec.ndim > 1: + raise ValueError( + f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}." + ) + arange = torch.meshgrid( + *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec], + indexing="ij", + ) + arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1) + arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1]) + arange = arange.expand(arange.shape[0], *self.shape) + return arange + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -3646,6 +3719,8 @@ def to_one_hot( def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec: """Converts the spec to the equivalent one-hot spec.""" + if self.ndim > 1: + return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)]) nvec = [_space.n for _space in self.space] return MultiOneHotDiscreteTensorSpec( nvec, @@ -4297,6 +4372,33 @@ def clone(self) -> CompositeSpec: shape=self.shape, ) + def enumerate(self): + # We are going to use meshgrid to create samples of all the subspecs in here + # but first let's get rid of the batch size, we'll put it back later + self_without_batch = self + while self_without_batch.ndim: + self_without_batch = self_without_batch[0] + samples = {key: spec.enumerate() for key, spec in self_without_batch.items()} + if samples: + idx_rep = torch.meshgrid( + *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij" + ) + idx_rep = tuple(idx.reshape(-1) for idx in idx_rep) + samples = { + key: sample[idx] + for ((key, sample), idx) in zip(samples.items(), idx_rep) + } + samples = TensorDict( + samples, batch_size=idx_rep[0].shape[:1], device=self.device + ) + # Expand + if self.ndim: + samples = samples.reshape(-1, *(1,) * self.ndim) + samples = samples.expand(samples.shape[0], *self.shape) + else: + samples = TensorDict(batch_size=self.shape, device=self.device) + return samples + def empty(self): """Create a spec like self, but with no entries.""" try: @@ -4547,6 +4649,12 @@ def update(self, dict) -> None: self[key] = item return self + def enumerate(self): + dim = self.stack_dim + return LazyStackedTensorDict.maybe_dense_stack( + [spec.enumerate() for spec in self._specs], dim + 1 + ) + def __eq__(self, other): if not isinstance(other, LazyStackedCompositeSpec): return False @@ -4842,7 +4950,7 @@ def rand(self, shape=None) -> TensorDictBase: # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: @TensorSpec.implements_for_spec(torch.stack) -def _stack_specs(list_of_spec, dim, out=None): +def _stack_specs(list_of_spec, dim=0, out=None): if out is not None: raise NotImplementedError( "In-place spec modification is not a feature of torchrl, hence " @@ -4879,7 +4987,7 @@ def _stack_specs(list_of_spec, dim, out=None): @CompositeSpec.implements_for_spec(torch.stack) -def _stack_composite_specs(list_of_spec, dim, out=None): +def _stack_composite_specs(list_of_spec, dim=0, out=None): if out is not None: raise NotImplementedError( "In-place spec modification is not a feature of torchrl, hence " From ed88a6387045df0d52ea6720ea893cdc3a182c57 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 4 Aug 2024 17:09:27 -0400 Subject: [PATCH 02/14] Update [ghstack-poisoned] --- torchrl/modules/mcts/__init__.py | 5 ++ torchrl/modules/mcts/scores.py | 100 +++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 torchrl/modules/mcts/__init__.py create mode 100644 torchrl/modules/mcts/scores.py diff --git a/torchrl/modules/mcts/__init__.py b/torchrl/modules/mcts/__init__.py new file mode 100644 index 00000000000..b983d492454 --- /dev/null +++ b/torchrl/modules/mcts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .scores import PUCTScore, UCBScore diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py new file mode 100644 index 00000000000..99b8772fc14 --- /dev/null +++ b/torchrl/modules/mcts/scores.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools +import math +from abc import abstractmethod +from enum import Enum + +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn + + +class MCTSScore(TensorDictModuleBase): + @abstractmethod + def forward(self, node): + pass + + +class PUCTScore(MCTSScore): + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + prior_prob_key: NestedKey = "prior_prob", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.prior_prob_key = prior_prob_key + self.score_key = score_key + self.in_keys = [ + self.win_count_key, + self.prior_prob_key, + self.total_visits_key, + self.visits_key, + ] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + prior_prob = node.get(self.prior_prob_key) + node.set( + self.score_key, + (win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits), + ) + return node + + +class UCBScore(MCTSScore): + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.score_key = score_key + self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + node.set( + self.score_key, + (win_count / visits) + self.c * n_total.sqrt() / (1 + visits), + ) + return node + + +class MCTSScores(Enum): + PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value + UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 + UCB1_TUNED = "UCB1-Tuned" + EXP3 = "EXP3" + PUCT_VARIANT = "PUCT-Variant" From b8a3823c1f85c1f90a8b1155224af464b5b21ca9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Nov 2024 11:26:00 +0000 Subject: [PATCH 03/14] Update (base update) [ghstack-poisoned] --- test/test_storage_map.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index b2b1a3ed8cb..5fb9e71cbf2 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict: def _make_forest(self) -> MCTSForest: r0, r1, r2, r3, r4 = self.dummy_rollouts() assert r0.shape - forest = MCTSForest(consolidated=True) + forest = MCTSForest() forest.extend(r0) forest.extend(r1) forest.extend(r2) @@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest: forest.extend(rollout5) return forest + @staticmethod + def make_labels(tree): + if tree.rollout is not None: + s = torch.cat( + [ + tree.rollout["observation"][:1], + tree.rollout["next", "observation"], + ] + ) + s = s.tolist() + return f"{tree.node_id}: {s}" + return f"{tree.node_id}" + def test_forest_build(self): r0, *_ = self.dummy_rollouts() forest = self._make_forest() tree = forest.get_tree(r0[0]) + # tree.plot(make_labels=self.make_labels) def test_forest_vertices(self): r0, *_ = self.dummy_rollouts() @@ -436,18 +450,6 @@ def test_forest_intersect(self): tree = forest.get_tree(state0) subtree = forest.get_tree(TensorDict(observation=19)) - def make_labels(tree): - if tree.rollout is not None: - s = torch.cat( - [ - tree.rollout["observation"][:1], - tree.rollout["next", "observation"], - ] - ) - s = s.tolist() - return f"{tree.node_id}: {s}" - return f"{tree.node_id}" - # subtree.plot(make_labels=make_labels) # tree.plot(make_labels=make_labels) assert tree.get_vertex_by_id(2).num_children == 2 From f716c6519d86d6bbe2400f01a8f08219e7f099e9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Nov 2024 14:21:43 +0000 Subject: [PATCH 04/14] Update (base update) [ghstack-poisoned] --- docs/source/_static/img/mcts_forest.png | Bin 0 -> 234790 bytes docs/source/reference/data.rst | 67 +++++++++++++++++++++++- test/test_storage_map.py | 7 +-- test/test_transforms.py | 2 +- torchrl/data/map/utils.py | 7 +-- 5 files changed, 75 insertions(+), 8 deletions(-) create mode 100644 docs/source/_static/img/mcts_forest.png diff --git a/docs/source/_static/img/mcts_forest.png b/docs/source/_static/img/mcts_forest.png new file mode 100644 index 0000000000000000000000000000000000000000..2ee7014dc54dee00e7a086d496a2142efcfd591d GIT binary patch literal 234790 zcmaI7byQSe7%n`tlpspCl%NPmcPK3&-62D#2!iAgN|$sF4blyg1Jcsn3?)4bGBnK0 zo!|Y|U0bF644O>002@|m3KM-01g@ez>Rr?k2%su zMM{SG!1B~lRshsYFdPB^ECAJa@_Jt_P67!t%oXcFd712?FIdQX659mjg6qm(J&#cd zEqpOt$nqTfIZyVB7sJnK3SYfuVHqwQ-lpokUc#kp#Z{Dl_HV@6NYK+>PEnD{R@H9r zVK!p+`m)+|UVlE&Z9{fs-ep(MjTF<&Cf_5;2)^u!4)AV;<&55Ts3FiBr z{sN1GgYl0aj0U2A=Xh^BypMhNILwZmq$jbo*QG`IT+BsXHNps!)wHu+%RQQhEA`8=XsCme0`*?;(BN%S!GJm+%dOAbNHwUofX1eE1b+6hV}4w!QP)rSX}6dOsndW~nl5UvZ_?7nsZ;J-(tzOZVMd zcq&TP!Qmd%)-$b534RG1ht)UBFwrK??QQ?%W3jqw?@H~*Zas&DGU)5xIfPGD$90RbB#vGh{pLP5JrO6)wY}mN52cRq#;=9!K5ng=@^d89Apxtj<2w)7|<&HNuMzz|Wdc{P+OC6l4N_ zI)sTpk3aFeoQ*H1`TV6JpIkV8|F@NkA-h#5>awGa<^882e!J_2hd7RXn|!WRy^p&s ze=aN_$O_ej4t=$E|6B4IhC>15=k?0wwKyvmvL zQpN92kG`!jvkt9?1a4l3FcFS4x3TIRT^wt84Wv#ZmVan;lqx4`MdPC>*%;@yT-P630RA}pv>7KKRJDs ze9i}bm3p_L|JLSTdF-qccE>t3HqpgW(mMTlnR!j8`yk750VfEn<19L@IM6lw)2k&> zc4+E^B)h&ofl*1Z#){=VQ%d%zZo=j1I&j`^Vp)QRvb#Q_|iXR27d4*Rx@cQWy7RZvDx zVk?SSa$jhi)(8D6^FPZCe}hEDD#4=^J8#?`_OHhvY!y@~ zlBXXNY*w;#yrr^$Aq9v#svW2(Iv(3u^1OQHnyB{q4|Aq)`REa^ZcGW&`~Uk1+{h*cbGRjGkf>hgPo^^r^{iYTAAJ+C z)j2x&1vNSGyoZD4G8V2hb8un)kZY{oFq;x>#Xo=@EOqQCPUXm9?no7dEJfEK-M=1(5qI}vbFbPRw3gc^ z)n6UEq?gQx@<}82yt{w?cvWq`Cc65y!&+rQrZ|y_iyC?=m%Y=Pu*|O3(lWlj5kPKA zuIEXt@xH`M^?u^inJ8UiTkeMlZ~UX`7V>|p<|XS69j0Fg`kpW@y^BHQtUaTbM#WLY zFBRfSQNh0?7Xw|>-~24jsC0cfht`w6Tug>ig>!US0jMAS6e7kHlJZ6~bWnME@I@>n zrpFvKfDJhk@u%k3>CLRf-+2C~pkj$osbgH=2A_L@Tb?7>7++t1Qz^&lAD8Op+rppS zp5_76SJnDvAf*u@#f6!+MZJLd!vyy1cSJ^Yw*@P;6jj#o?L2`;_&vc0o4Hoh4kmUEl6c%*_26Q4Hx%ey~uoCdNg)q@p{@)%akNaO58Npj}3=& z)=2vI9yqCda4(fz?h8luJ44;mA7pq3CC922WqEzSvK&^eQC+-Q^2YY@qD%47HSJ^1IPxE7MDL;1IObfW; zkJ?P=d?M46U9B87y5|Yp?R^%{`6@u_fBr#xVpMHGs$Ktb~Wc{nTgS?J!3C;mc)ud}VUgjnLpx6<~>Z+K@p`@>1wKd^b)mZ|WJ>geb zCBX`=C_-PE(vTy1oyOdR2>2_exT@gXe&2rJkhly4h2c~RT;V@(s&JOvQ4<@ywKBbP zow(|~T!2*->#h6jU|G$h#KrN_vdW%Wb7XSGjfXkwwV!f8RhCuDyo5`}KLC2z^hdCm zSE;t|i^7tEAV@DCNyAKm>K(_Kwzr=W`y5@3agD5u=p1IplILpFyz*H0QyHIFd6g%v z`|`uR!Nl2mDD{b(o=DFsc=4SR%ri>g#E%X^%DEZE*k=`{e-HQUU$cbzHuB9{)1+mS zkdkuXr8u8YvVOn2ee*Jer_i|ok1W^OrHQF@N+0V^K}oVD536DC=>>a)oUOJ>u3+gH zMFcNaQo%pZU|g(KGhX}!4;8s^Z!=b4u;L*w=f!2*Gll1ljTI-Ds72c~vqm{Sw+9TxJjkLa4mZnnM*ltNp z4~gi0o8sxfVpkB@R6Zj`gcc2FGAz_;5pZGS9TON34AWEy!`S=hX)LX^=p4yKtZjfg%!w}O_UdP79lDw{g zimT!ApvW4M`oG{??|~c3%~&+HclbM;K>PEasvD}cD^OtZZzS&*(jMQ27AQ0~fx{(7 zI_lr#4Adc#@}SJ=v|Qjbkow3`Ilt^Ej%f^CcJ#Dq-nK>A^j~|1yPT{^&2r1SFR$B= zv8QQIp00Fmg>Hq)Q68a@uOQN7r>)~wcvwX^0O8y8sy}zRwwl;9w_2(t1vAOkjNF`O z87o+8{qp-+CSoKK*eU_S+=IdITT4>(Jzk6BWAM-|CewH|j;tQoZf|dI1nsgc_^Z<&uX}y{T5^ zIDrV^?D0H#f4NHy$#enlgzN~tB)3(N#}+~`oY{Z43WBLI^yw>{1nrd<=6KnL>5g;zlyH)yYc>Sc#ciwvv3QSy z{>=10G>K$IVM(rK5yfZM+Fz)RS?~mzu)_W&Ubr0 z@fZyNqLHSQhB#@iKI9i}0Ak|cN;Tg8;1ZEfZ?UY1gjisWRx~jc#3Rpr(}eKNJ{9wn zXgHbw_V_F@zqoiE@@lk(e=zpq>gu%Hsy7t&XDKNn;(dM(4P5iTTDyF^XPa#B z>`A!eY-!iY0JY`XDh)ytIOZ4Q(}A9uoo#dk@6%shT)2peiG7U_4Wbu2$5AbGPeHz# zIr8`SH?(MSnC8iodgx53?Xp7?Q1hBN$RJ8^?7Gi>6JH}QYeC44z;x`~IHetZUobq? zcA0fBgHga-{O~}1+}jjxBDI3R!RbiH3Jt_V!Ht1qpKd{wv_K2@*fw-{Wq5aOGY18K z#XK^R@AK7Hk3Tn><*n{UwDdr5ldD_n{OXu0^S#OOtUfvDPiZMWfIkUrjabVs@#%LC zC=_DTNWNmAe@C_q<2hkcloCFts4{mJdeI3TOp8=Y;6SMN$UD*tMmei)SLT1q+&d<*e3;rP*4F{IZ zB80vlLea;#^fCVV0mmb2(keUA_bYafR6gT*S?7r0@fWkG6kf3Iv)oEwYPL40M9$e5 zF#oFW0i;@{a8r?}=T?>?rzQ*?Vhg%Ig>Y(m;|i4b2Ym%R8ojc@;Rp;2T(mQ`Q*B3H z9~7%5i}>DbMOb>lMCjrwsgvigm(UwQ$F12<6S|APKEpOfB9W$hz9?Sc80R;a?$3St zZ_%Ol4duRI5!aPJeBC$OO2%ofF0sphXb`V~W5o?HVcJ{Mb7Z$&qwmuuU@K~`eRQJj zf+o8V9aI=+YGAbW*WIkuJvlXA_8hx$&`EbpJU!qc#nc^w@;b^sCUd#Fcw7q&OkH%n zQ)(PMs+Dmb1UrLIAd>w2b+9c0u#D?sE?M3AG_JVG_=k47%6rxhnCvXJGrCpg7V3cH z9ZQ$-?D97wENF9s zS4uRPC4=y3=MvI+39X3t2P?yqVnyWVPJ|h%8SiNNcm?WoLU--7)%Dn6>UXM++AP1O zpBJCDuD#~iy0MZ>rFFfQew1Q;N0>uue8Xj1R$+>3@CI#C$1~lL7`7LxD$o(NB)0WS zXCJJF_rCdeeQ&W+(MO*Wq~{@lNcOmU2ddK!!JGH>=n(YY8yz@pEGy}MltLO`6nr&T zuIqDjc<5NU5ElpggMbTFR<|1vaNl2&<;eto-7V09UT@~wq544l*AozACR*zH(=+^aVQCw}1XfZ*n!rF^D=;>c0=r% zZ90Mlc(o(%oNR^WFn-4E9h=SNh!zh_zOjp~a~v9cqI;7=5AkNS0wy-~ZbcKhN68Cv zHD~ZDX^B&4Y&|8dAq`|RP5_P*(U9N-yFm&*;2KRtQ`TJ9Myta z_0#}K=>Jrd4MJt}6fes+RcUbgexst&+uVzFqe6<$GBfTGDR?bpx&i7| zew97TFSv`zy?&b$rp82(A0DiFMW1zRiEWiS7CwEUTU4o^11||qWlO2cFJ3s)k@(Xr zYx^f9T`PlsOLo&A>L!v+umV}X-PxDzm{b;#pp=6C!$1z_*ZDrv{L0!|_r_PFLBLpF zTY06%livFJ`h@F?iwBWO1Y5%^$-d^xt6Zr2W7(Tn1E&Mb-`IKxgCXU>``*;uH(eJ{ zH9pW^LEQxt_rHeK63bGzwUBl6*{1Y$SO2_5+WzggSEFvOBH8<^EBJ#M@Z}h7~!smWRfT3X+vnZMk+r573*{Jkxf7eBCVQZ!P@)ee6N2V1O^`t#xb+R}GP zY6(1dhsvToW$(_DwUDmgr>7N_W$yFJ#=a3IkC}rW2FhEJr={+^Ms-#yGn@;HQ?7(- zerJ&KOfGKUn0oCTjX1}rOGWW40s1zJ+03n~GwqXRZ+YIf2ix*H?9n=BCzX4FgevMT;&>^@=fBg@cPTy*IQu8(L6l< zVIrmyY(G3z7vNZ*gS-W_gXS^%t1k1`gBs&JiFLcBR^;k{K}CdOrqYnb((QVLCZwNj z0U|iqS-tFn8b$gowszi=UCZ7`s`rfd6DCv1oFDY@HKJFkuyceQXP(Z{nfrh4dv9(Z z20@;l<^^I}@%`#kdgHWQKV)i{2)XvN_#GF$qUkzCJ*AI*w^SI$ZOVz`ZA4 zq1=*=-&7;|`x)_UhU?Ie=mr&!IOASULiuETVM*e9qwL!HxYt=UbH~#YA?1!Y=CYc# zj?Q}U%rcR+v||re95wq6BBTiqof3PcabJ(0b{8H#^_^4F>Y*?M&)%#OKrQe?KWavO zV!`Q-xlNL}YN~Qt5;SIJsWUnA(RnBcO5ngP7cACTQDgTCmp1i4Hqm#rzF=vULSZ)n zCpk_D$P662E+E-xU96Aj@__?icLhuw?%gqi7-e1$%i?bOY#;?a?|dWLWfR`#hM{&crOa-Qxo~x8|O57oszy1G;d4t}#ng=F`Fx_w3v51%Ze+ zsZFqM+n;4JxA4P5-}%kWP2CjfUKd@;4!tXX*;~}v{=RI$BB80v6({14;gX-<07Su1 zblE0Z^%@F7D@5huh|sOydiq_ZLrAlrsI9WSGD!R0*47YH(~v-P7BL2=g@lB>QlJal z3bFfiz%^3B!kSLsp}_GW8Lt)xMs?LxBtKP3zxRdpqf~ktN}DMR+9{(txLcZ1w3K^h zQPR!k#ju6?&Sx{><4=`c#PPmC$l`Jlc^R52b)nBjOAhp`*h+;(%MlyA?@M|4hsVX6 zz3ZW>dkGCCXRQ=S)#?Y!WYFQsBAoYp9O zED#T_eSLkd5r1K#no`L1M%2R;+}y{`v4f&}zweo%FH5$kIV^MRORJh|mco~;{dabDTHj_~ zX{~>quunkmp=wQ<8@*XX_GNDWV|XS(i8$GOs^q>!aUyW`T zbnXFh$252D!h<;}wwb*7H(ua-EBKxO}Q(7nycP)$q@y5-DkYlb+L%zP=PY|zvnB}%IWU6`Pwf2!|u`GHdjX5 z-(!X7JcNr#Mde?_ft8Bnx8#Ib=BjhZ-UlERdHz)NyEES;LOLQM!ga;}*s^GP zsW(aNdUaN-z1{WkL(yZn)K8CEWSYQU)P5c)2nj(W;po@&e@Zp7MMfZw0D5@pVvcQ1 z*KG+sLHAXOWmfazi6PbqhL^%dvmc#YhPk3oY7iY8K^cdp_F?(pyIn|YtIu619x>A= zQVh_{Q{e{0am-^{?TyxnaChU6GtsjPMy{@&Q{7xyUo4O__Um+Y#S)_p2v*SaeDxA%`%ozNBns;$dxsGN$66bdhwk6IPd*Vj|^QkIr&-gUyJ zsA|`T8Gz^mlf8OD&<&*QK{FP4xf_oETt{#}AKczub`<2$)$9^4pemCxTe1Z2X}eym z5`8@{K|CI)b`{!W_o!@g=@cS#3E@Q3IqO3@^wj0lP~3gdSYb2&%CTf|1@uFSMO70& zIQsfO*tp$UKwfS_{1*E2>^ilMp8!KY=R@G7~W8S2gcxOV#!6726)F58&10& zblSVz7qq1(Qn*QbAd`aF>`+_or=9m@={Z3g5s&#!OumSGecb=*<|pmJDh*ze%&q9)1J5La zfSL?>`)brh8i)Q-0V?;)j{e)LM|~vOcpj&Vvf|fUvC+4tNJi3(Bh`4s_8zRgbcgz%4eUpXXpH@2`37{VU~634eyI8#J* zx3(8%Iz!OUGBxK-olbFJ4H@PX7g6GNJzF!0T`w1nE8!S3yQ=?g{jUzTQ8WoPYw|k^ zU>pk`zs?GdQpck+->Stkv%!kW+rr~gM<(v-rlqRchpYvb4YGMcf8u+6) z-SJcFsBLgip=4?Z>=7YBnb>TUUe_F5BIDFsb``ip=HdqYGWX<-&$~xC`?4@hv){mH zd0ubax)JntNa!_ZS_FnY$DQ}4XN}vf>ib~mwEeEH#hN{OK(|I3bcb)+>>X?l)Gtn2 zY1wEVVg(_Q%%JWxB-B#!-q5akgmhl2uXnlXyW?^9-b@7bLpj^%R!?BO=>17hiuBY5 zK0DSe+3Sz$rG{qzD9K(qY=Cb2qn_3_DdSebB+v~zb`KARjDM-B&LGMAw=o#_aE`5~ zrKrv(JC~B(9ds8VV8B{}pN%sGqHgU!;`fnZR1=Vx zH_Y-yX-kSxl||b*NsE)tgb(N!As^}?A3)q;dJ;!3?2kZNPF6y11|G>d{jrYx*$=KB zw0X%QRMj_~#6$QoDU{p^To0<~7FA}x8+hsMI~x?ScHQGXu4>;+fV%jD$xHCXCM2NS zrlO?K(`RRAk`L*a?)&u%-CNkd5bi}w@n71hUsM9>{5BTjbA_)L7pXuzvnUwa##PB` zol1TVBz60%&Gyu1{EIB=Zt3@TzKiy2s0a|`c_Uxe>V{nhvL8k_I@fehkbATk1TpP=zM_sh$`8DDM)1=pTncf6FH9t8YxU-2WM zrUgd7dIE>0+cBeIP?!4U_{Ce$tAvpzm2!7DecYhNV8;5%JrD?VmbzL3qDc<9w`EOm zONexAA#u1|MPEi1b)n3wkb7@FPOHkLTfIwJn@Q^j%_?Tnt>R~xsku?^s3 zS$@L@c>5pwpZo7Yr>`16?CFh`IiwZ_LDG6=USO%9gue?Ev+KW1cgvY9@=k6$`1#C6 zHBN-U$_k|R{Jn&cYW=%|;Wn8Y;GUeS(^1|%RqJg2JS)kE&+!`faZ~Y0N`rD64g&2d z!mJ)OM&z7cTJ9tTdv`yG37Xra2 zIq{1TK|FC2b(q(#ef))&wf_X=2G46(a(9PQp7O}9-e^bU=$jE7EMT*sy}2F^2MRks zVQRB78ynj-%EQgg=a*~CGjU~1W+AALN=%3Vc3=LyFZc*#X$j@Rv zF^X~pwY8HAI5+uRwebouZn5?`=#Pk1*`|5NPTRpNe+?sGGq-#9t-WD$BxL-IhKUJ9 zo)HGdEO!~;nDJmnIjT`=QucbBI0%BMG%#-5`GwQOm~ZfVg_!DCswaYOyxM;HU>^vS^3Jn=7Hi%o@KaaJ(BPm8>e|m2Y$+-w z<)VwFL?0I}t$P{?bqd02$H;rSd+=u0s8gUVv39OZ~Ku`!*#b-X<5)80lr+n7dY%K znmv**5Fh7*JBGxxJ;jQUr-wQS)rbr}bELALZ0~jXwQ+2A&sYqrsNL}HEw}udCyi~j zU>2DB0}2iF*YRLL{26we;P%NJkgyCmbdq?p-h<*T1Y z-QylQ5*z9NP)LX^YH%ToUMjVDEKb|oR{Q&u)z1l5E*-txer!=$)S@vV>nJG0F3zHgdyZ?H9nFc- zeS2kiEKO`9GP&!Et%HwR}#Rx19 zO|zsr(Vuz`h<(wNr2$0a1H?XPRR1G)oU(gF-qQ(`z88LD?L!t4E8)GNvx4atsf+># z19#3n;V-<=>#~)V_W=$oivQ{AgP8T(+71z=i;H2Toj?S@x3;#9kwF;Q$N3>Pptg2ICQKu zzcp_4w6sI^;_?!fX=^MnXjs*Bj)Zm{fVtMP?w7bbWW>ZC_DNwFowCL2`^`i5XYbptTBe+E$YtC|)?(>GaxB@9D*7}l=XGmLT;q7#Yy7)I-Q z39$Bg{uT(3m{%(0X#dt+U9UZ3B}h#YflR~)G|}`15?rlTM^uQ2uK7fI)Mb+7Y^?0F z{ywY)shS#raF&^z#wcdWm*j|V%s-SSy;2f0%vn!9com?0EO-Kc`)i|_dDcxunYhF5 zJ%``XHHL9y%^N{F)>K)-AAs6+MjFzCa$jNIGy&Cb`A9P(Z~VhYx}oh(DxXhPflV($ z2}TOZ9yi5PO328@868`R)$NIgvG&SEUgkkNJ~Kloff z`{Nf?45vB$S(SU0X6iAm#WCf)p$nKD@ZRa<{_bug+jg1ZI;gFPbh5iJ5F;nu(*UCx zjb8}|SXo)YxSeb7r-pYDI)5KeYBj#<@Y?Q=5)DKjljYhhhE@xnZ5Dw#&zc&^I?teJ z)HSdm8r@y(g;;<>L7L$GycY`wX@L$74nLQY$PPDBxOEpY_&@%JnnMiCK5`&RU9b5~ zA>Q_l|L{$-|JBzwULl7|?Q80O6TPGdE=)GK_yfnv80#_YIS2Eu?yQ+%7*oOCL>BZb2T#zaZd&o2{O1nTuHsT!VvyEq*8ULrN~$m|}IB zl4m+G!j1X2?a84HLt-wUO2sU!8l_Jl_4V3_ynV1VLc_nxZiwzq^yAT7`OCi8;E<3I zeGJCN=Lr3?ZKR5OowG~#bts$cO$75rn}Pp^2Vu6rC!JfXbi2meB=R_cApZV&hG1jW zztic(=30cRiz@;eS2^^x{WXKdQ=S<4U|Z{wUYCwEEkEV8$fKt^*74*@cht<;AHR}W zyyX5(lK~3@CLc7Q>0$R7=De##-qVUe;y4_cn9Qe;I5%8>v-@=?`o3v@hBY$9J~y`o zM|7-VB{LsA#Z(dq>rf8v{ zHNu~_-)~*~%kub#7`)~l<0)De(o^fe^W#f5k2Oj+>{ey4^(Fk(o;o^Oym_Yjs^>D+eKLxh+@u!e*e785zk5RfJJbOl-8Yp2C9Z1KVn%r9ycLUbsjWcPf z_qFm&5KnneH|uT9yC*-lU1JDMwPE4|#vSk>d>Ly7jjE?%KTN(3%Ff*0&j+g7hdu%8 zoThCSqOJF2g~FqC4n`0#Z}NIc1cOzzf^yrs$cV`{pzfHf%(=i_n2yvR59)ix=#R;m zH@dI*9=E=Cffav7EIA0BEMG0ILNT_}f9|_c3d_~Kkc3D(qJu2=>Rf8lRpi}2w*%;t$(%Qr5s^`Uv2nXc}7dZv2j4t@A&PzWBgyw3f_UEiP+48z$$ zx!axG#NRVBw)+6@TQMLsb#fEFeFLF7TjsfIX$NUv@LjkMP+Hy|cuq8t*_MNO5TI4;WTZ4CM(8W&kE)36&?LnUPGek3k`CZt99pZ2` z3!^1U{PYyQ-LG;-24p4z%oz0KNkZ$AVjihW_Y+|a(k)HF6X!%1daaU|<5c{5>RXsf zC`5kK7Kg^fvpthpnxS)jVd^x-FDahrZXUTc}bz7eIL~kByAA)kS>-K=MO|+f;D%^&k`D~bx6$o2_ z9-r^nV2pMlx$oQU4BpS?<&0@jfo2V6{AgZuS?f3+yRN;*l!v9^1^QqlSqK!fmUqsj z&^q;YUF$%8e3GA3%S@o#)%@E&>DDZNR6xh=B|PZskoCFCy^ybmNJtmcq@Bov;3Puw zKM{(Kf#DsF1EvO0P2NW%_XhIvoELmA{jNZR2Q%L=mkz_I4wUzj6@UMyN_N^2(Bsmu zbv)ey(~kM~?;j+k;n!COhXa*AnIL#W$AsigCt2Xbp*`x_yDRD`CvT37>!|8v*O@zF z*y{ME3);_w>6Zxp6VAsQPN+yf)U^WEZpQcc9=2`xI&X5sH-{Hre_Lj@=qTgu{_WH< zr^M==bdm7v=icBmq7z3|y31;}Cyu}ODA!dk$KQWqJ>qG1`&clAPt@=8mx1!qtAhQ& z?Noey53K%WczJH=st)e;RGGrHk3?GEu9O8pW!qZ+MJrG}4C`lbD{Qx?rzaLO2P%<< zU36@6&Eo}{pChvN>&N?}-*Ag}nJFjNnmvEyv`bIpE*cp*!frlyZRO&2OwA1}*z+}uJl>enqIKFclB z!*SuLo0eRrSB+SvJI>Y33u?)n8V%ofzzj@GG7IjL>CV9)IF=L6>7Q;iE-K3)x_uZu zE5`HBTy&binz<5R1~xV}nkU7!{D*3~40a9-s~buWVYxj_TJ_vWjB7%G$rgcOWjnxz z|Kbf9leyQuWlVk-AcMaPe&BuTmCzteYb^4@dUDsj>sMqRZ{#of>(v2K3RHV*>-9Q@ z5?iJRU5z@A8FqU}xooFNO+uukq#mADgVkIthRbdWip%?S810`q%z?L+`?>!(mv0UgSx0ew6AEl32 z;h2CU`f%GB0CB9vmV!6R3LRg%S7`V|1UHaA^2z?-sHzGuk^m()NphU=Y2)?%ytKjL z@n>$R=_;IJ5?xxefZqHutdj5u`(k;|WBTxEHm*g$xn>&O(t0ZSmuqRsmD~#I6LjIz zmmq5s9=@NwZ}~78qa};b@^=A-bd}a(9POd7u&}GI3YARnR8LKIZibRU0+2fx*_dUX z1L~|a&%!t3rPVN&v(syZ5Pq~VPi8#c6@{eAEqaNiwFj(JYnq`z1{Z6EE-;vGx^Oou z{c5fvO;meiUlsobD_Eou96G9S$;9qQHBWfl*XUEbMFK1o#=csSM%`3Y1lvAEB#Bep8v{q{f z!oNiRY9`799?M)BmDsUdp_7xY2z}4S! z>Tw}@A6OMwUEdyktpM;Eeaxc4hTX-Iqqi#r3twa+GaVMG{^+YsiAR)U^}59>Bt9h_ zQ`mjS!^bM`=pUDy`VPb$aU!SlzS;f{p<3>Xdhp7y&oX2)j*#NkVIlsTUI{tKMdR)+ zIqRHZMmMU;`l3AA`>lk@m8uOU6@>II2;2@0$2uRfvvUZ{xI%sghZ*({#!oVfHiLC* z-(B7%%#)~cI^}fjMOUfqZM5bIAR@?HsIW+G_8ib8i;nLBazJ%jX{YLaz2#og$hlzE zVSKfRu>SBnxPuS4?dqd{_d0PRViP~iZY|{Y;3sdvz#av(lJ$VXbpPxJ+BIsjazwVQ z{1Sn24$j6P9O!HfXvhoNtie<4|I8R-9e76jI!^IBj z^-WrNLt2-1i>2-p{;ojJak3d>pwusI732WEnPnpkBtQo zI-mmPVRu*SqX~~A=9F2no~F-Je);Porwn8Bf!Q4&Wb`Y|@5*|o&O=N~If5x4I#AJg z1MyyzN~t)%ONyefUq7>oii-w@M|^(4@UL*tbrr!#terx#>Z`~YdbPqZLIcbpF=ICw zDmAvz*NY?#OgQk@cJKLrQaWMh;rdN zR9Dr6EXs&z>Eo8S&%EBtv;HzNB4^DfT%7%!At|(gN$IGEUm!3Ti1Qs_@0X1~joWph zdf>-5lw7Q}zLLo5@@FimxKWf81N-M%bg|Uu9ITS1y{h$Bgm>W$wU42#%RgzPn6lt0 z0EyX}o4vTK2lXyP6c`x@7fqF z`5&=i$rB|V=k>*H#O*yw#r-K5_&6kuUG5wCB@Y8zwn9Q+dW?f}sLYJDlF7DKh#j+> z!$Pu>;YT&c3X4?sF5aEDKvDJ{iteZMW@mfdIK|GBSpMIL@Lx!{;ZG;EWVEG|mo5oB z7is2zqzK-ckr$bfr4~=IwG>|78U=az>CKY6(Ih#J<@!OPSLI@0}J79hNOT=Y2(U<-}rly zx>w4LyH&xCIkm&CaJ|j8p>@3q9fa7dXWmQ#6G=xYUNX%_SV@Q}P6J2f!qCCB(@q0; zoA`O&&_r8Ao>euwy>5{WYg4(rtl0mRdhX;+ZtFhc`?hv$KN^S2e-5ueK8>mU92N&K zoiS30^9LMwT+hV}5v1al+TEr;yk$`@bipN$JbkGUTo}xRwI&(8zI5H`Ogb-TDgoC& zd93P9DMrAoe#Za7K%Yf(`#nsu-bxsNophOk^}E7c)du6KG2>4~4Y6QV=mZY#y(gOa z6;w<^P|Z4 z7P7fcAWTm?Qsmu|&IsL}vX1v8H?!4c)5Y`v<=9`@`f=iqVpU9033y_(A8CCHn#T`g znk*h)k-mEIs+{Zhav{LNRr;uK9U)tr<@rZe-aHy6nIygXhsgu__)rU@6J!}eI?Plt z(ba3%<-g@T^dBxQw%`mL0hfP*)xAmpIa|pmwyV4ATUm(NoX>JYK>1W$8xPpZ{<=h- z@{>xMuyOIC&lz$Kaw=LNOKWVq6-O>)O`pXG`oH-_=Q1b6qjW=R zd7_i7xhG>^dq|$1Tzaws;&Z>J9&r^YJKzJ@lj~dQaU->!MeYJx11GBOqv8!55&URt zbehRra)V^Ff>+^tQMYkVo9O3&70-1&@*<*8&1!Ki!=wx0e68}3k$4IAPon1Z-y!h2 z|CI%(BO2!tAUL667a?LYOr|;_o);@v@Ou<^iP&ruky{8xAA85r0r->FZ~m%TH1Qe-}o zY4aQ~t@?ebP90DAyt0JeGtMZdj!v6F}O{DH<2)=yFXzKhDXs!o5lBQuLnk`jiAVa8_^TFFqNxc zNt1H!tndmZKe!|;Xt86iOi}na_(fgMco(kn!mjLK);z@cHG4h>dg|Lwg@QDdTKJ32_=t_UW+K1c1fZH1T3= z`k9LH@g`3x>Nt(HC=~ue*gPK2I^L<-c>NEW&VnJT_lwpuFf>SqgmiazcQ=xPbc4iD z0z*hi3`2K!cS%Z1cStBmhe-FG-~ZnG70!FkKKtE!t!G{ERz?68Lel-Nf@x=&qkD$m zofd%Cg)FKwd@A#kooH=)(!$lh2+(^Ir#Je+?n`#rfd!xr8$_2ih;%tl&PH^C5$lJo z*`s^r{vfr|Iq$mJ;CN61EVw$QI=F2@ z6+WAfBq|$z`)Pnan*}4i9gG_Z{mI~^S0Y=rJ1!Tir<&ubL_}`1EOe9ZhYfU%!Erd_ z7$FCNgCqp>6A%Nrfp3|WHL0C&nyHXpV_fYYC=n3I|3YDvI8w`iL_@8yVp!{(Cw)Ii z$TsG0@;Uu_*_>_D`($Nj#<|;ZAMt{?(y3slWsYy56I%tt@oVm)E5=$Ffn}$`9aAEW zr+gw<`{6WV`q{+TNPD^`FiGUlSh_kyw)9Ncbx2E_A#z)&o_mV8jdMkwpTVeXA!CCt zb@3LnG$jW>3JGL+f=>icTAU|ooVdQA1&3Xqt*A|9BBTE9eK zcoNm;6z4tao@UXGDCQ8wB%sL!)m~iuKpI)GFdI)w_@Ystn%(85JAzuC^D+7-y)?~f z=GWDl*&A6N0q3dhOc0pUKe94QG&f{B6t!L1C65#1Hj6|!5`nKFM7ms84~pDSageRj zU)}f4n;c)Mq2erHqfwnCBc??WZ?(_IGMbFvyx}rxZ23^H`lRIYU82O6L8N>WLo^^# zoad0ZQEy$G4lvV+E<6O53XIOdh@NEKa#+t<9$hl2vNwp^icE*3FP*Dc>+8>;G`-b% z-rrl9l!lkp@}?l44ww}BeIyVPgC)?ZTC#EYF;?zsk(Z>;0~w|CRmPIu`;Crll<@vc zd`dVSFjQmH6K=Q0JE_q5NubdcQ`Ozsn5BXgXc@xNVn7X?gSk#VsruwKLy!&}N7V5- zOa@pXZ-7R1Qkc%c^B#vfpQ< z41xeK1PS0<`eR9k_mAT3NwtecHj@Yf$I6Y|1z7D4&dkmG0#q^-_%Rz)sH^&~8DNfb zaJMn?Seq5j!s%3fxAUV1)u3}b^8J1I|uv&F>Y z!21Ng-k!_7-~KBD@e4?=d^9y%W~ua0l^1RuT(XSYKNBSkBcH=!3>|Z7_YFl@0v=ty zL=oeI zpZ}h#s+1cew3_fSo=QH29-ORBKiUx}zIOAmqDS9`zT=aY0LyS>#1j^MaP%g=P$8k3 z5*q4StL%t)`~C0(=QVA?j#$r=RWx=}K&~42xnTsR?D|px-Vm8tW3ZGMEBrQhLl2Z> zcV+d)-#nmdP3&<-d|2>T54jDMNBPmh@$XhvYVnsgS^XsL0UQzY_z-i8JNiL%q&T$xD5_6$ z@reL&m-zASBm_EHD*dB$dJ%(`T1)0DrGCjp*K}ktcj-_Zt^Hk6Xi8UL*f$Yg4~}w{ zEA7D;8vUQ`^noK$F>P=3F4aShJ<^cQd@F5!O4qL6&bL1*a=_G=f}kh99y!L0W*T|LCvT2Buv%Ki?WD$|*#HC_);cnCdOx&Sf1?5{ zufgxyf+?tg75IPkJo%E{c&p$#e-cx7WnQ4;J1dNy2f3B+Ag}paobEySZhhPpR`B`D ztB+cqB{wO4d(4gJPa&28*cb7dv0g8vrE_*4JSFN^66H7~cLLOb*q-DoEdVz_V(bcB zPV!fUbFy8QPAv7#Q*K&9vnH7$y~rA>a_YP0=13vg8YK)Venv)?Q|$|-K*R~*ha}Mq4YcZ2!6>K=Zvd_N{%AiX`b`t2+f zQHHzi%2ROhtEWnM+#>1e0*y$lE>ma9Z`9)b){cwEY3>nhRtCnK3=j@?2wgmMs93BF z6vzb0l6j(saO$TCnblaLNLlF?iyRYQdQoY~S_o~s)>&iVnneFZ9Lxh9)<=<~-tYDn zp-qH)aSQwrrc=CQ8w%0_2(_pWBd?9j0wGddV|7GZN1@hi?r&t@yQX3_Jujd2S?*3@ zA*_8}Xh}$vsLz^;cT+xMB34V8;I={J<~g<^j8=y=yF9kHZ{$2ll|1L%vK#Cf=S`8z z9;5E5o_=8Zn_QWd8?3P0#Ce?TM|^isoR3}7)DRc}cAE}+JjF3yeDPVaZS(VdIsW~d zgzCCZ<1vk<;Gw(37@=WTGFb*2E5e$(OyPZ%20a#pQIXC>1JxF&L`bt+tsAY`HZ}9T zDd`=FWeyJmu^O`8kahGsYKtE}n`0wveBs88JJG`LncM%4#Dt*jc2=wiAE{hS;y^X2 zCEjs4Vm#=5o_x!q+TK;RN{xQlrf?45=h5>@;sBis{g)%_eNu9+F#(HA-6()AtC4T$ z6hohFl)=Uj{bLNee%A_@LTjmOG}_AoH9TMN2dhUJt8%DmqFo+;Al9nmE!%Lfbg<&d zTKRTmFcsLAA;Jn~Kpp6Sv22mGAeTRp5uLVXw zA!}^{`bhADCR(OLoI&K}%7jQ3+ay-~m1SiV`%1cA{9dIqK(2O-sY3z=jAg|Nadd#N zWV+N;l3W^9J#17P=8foLpw@@jlvUBHS6E~W`#PyB7@Av9n9CjBciu)dYWj_17W+W6 zeM_sKB(^%tA{gi7`|tYXSUTqApZ;l}WFcYqEh=Z<=A1W)Ey)nPf=3Qf?Hq}OU--(n z9}hFPonw~Xd^-JfNxl}F7_S_f(vMmtog-;!{zs-m#QuV51(y-w|tmKhNGgyQ0; z{?d`6BWcd?UtepC_#=@F4Y&)6!_{(U^luKqx}FXo+K@m76(ei162Atr&%BuoP#bTG zP#QUkqKnHw_&u|#8kH8#*XeUYPCFu2NONF;xv&_fY>@&~(0({anS!z*5gry;1D181 z>(LJ+Syj(OnJdopGFhtSx99AvPS8L>6)A19#}nrG4gK$y6&71$nT#TJz6f`nXELwGkaSR9%m>IHDls;TX(;&VUq7vj@GqI7qHuls(K-e2!o52t}U(u;5-w)UCE(7ysqIDvq{G?J4UMJuj!5!+_J^= zh%$CNW%?zdsc+l~1Ku}UZ}Q`sU|lehldED{0a=mT8?Gcz-(aEKRRV|&nDo}{oLdz; zS`E!Y`|+lj{GYwP{IL~!TU-p2L5NLMX0L!9ND6Gu_&nt-kUWNujWQn4&$MNb&VuDw z8C|D@8?}zK_|OJNOM`<9*Com#nSl_Tc4Ut^f&y-?iBw`>aADURsR=cnh`O;Q3 zk$QE-+lk{oTFWDS-RFBy~duoRP+e8^q$s)Xde?fr8{D|dJ8SnqJs@;OHf zUf-PMX4f(|g$5%LKv$}dqtlRw6c0>T5sn8v+Ds(Xizmm?h!a~c{e)VZK)SC6VIk7< zC>AZnrbdZ*d#-1e0_h7$G4GgBIH`+?R)b3QokwO|d4oFZmPlfi!3Kx~aHYz#3M3@G2^ zMa@!?KP#)Y4!%`}+Wv2=?9DJNxUz;+g;cF0#enNO?uuzyzNRx@sW$&co?jRIQ)ajk zbVY#_%LJj`$bpgGB`^44@A~Mqou=06Fb`(S#?HLj1EdluV^LFs@sp(EsYXD6{7v2E zX$NmYhgt92~Jj{!$FcSBKz1N=-km7bse zANNY2DTDD#yLnC7@UgxDg8QeHDlQG7qr$?%O8(jspjTp_N;ZFNs3T@bs3CpHjMd?H zedznx1&clw20&jTe>_v@_~!2}j|O2}5y$H{c?$&w1-i^e=*hX@u{KC$KrZJXv6wVuR`L z`|_oT1p!L|ydQkwiaa`>nbO3G*9DeqZF4h(-BzcQvFJvCHMbO$*8prl^ooOb2N=E-B@l z4yca`DEIvMXDbY2i#F}cY~P&XM>)ft(RK<5z_8X!G1|S5LnKe`iMu}JK&Hg=&D!g0 z<0;OC&-3yD9`G~93@Tjef(Rg`qFP;zr+%RpM?cZ40Bc3L(ILGd1alr4Cge{Ku^OBG zb2k492|x2#ZsufmCqbnv@QnF3UHgdhaSp~TG;JEJ1FZKFFNWYuDGXUBafR2cSd`Nw zh5uXb5fjQ;L=L6{>szV1n16)gh)%z&s*=mc@0x;|sp5&7;Dyk!K>I09(W=X=U!(7B z(S-M}=9jS{vSF7)hpL3F%mh)jl?-slt!`5A7zySuFevYaP>+9c^lFJZ=(8W9rAK~IR+js34 zJYs?9rkb}{bzedj)4hX6yPG@3OHeV;5kIdfPY)6fmZmd4{7 zE0zrOydPunqliGNed_i;>azL2+Heq7KAE>K<>IyG(YsKK_T$p2M5O;T{mXA#LuD1# z!&vbFZ;4Szaju1Kmi2I|shE9y+QK`I?l^69j{u2G_S7>(O30uh_ikmO4vRX2bNtM~ zNzFd6{I#bzFY~5D7+zz&^cVkY|Fnir*Dr z;GW?${L58FkB#}GoLxs$!l2tpV7TuhnK*^iXC+PHtter|%LBM^B!sFCB~fSC!(Zdy zl=bQ`h%WsN@!Kv8|F3}rNvpI+jFGPht^5^%Gd{Cxm1B~ECjToip#fY;KFf5Xjtzr` zo;F0cSkB9;UHA&P3}@?jNOa%?GoaYIT=S>6jcX_#<=qEN3+gC-6@upQ7@8L?`1k6G z{-m|DG)`Idt~!y~s~B8{enXvF3<5-WuFU!y!w6nYovDvjEU*EC(sRs4ML@`8sNC4o zmqyg5`vI_t?BOEmZ%x2QFAb|Ttm>ih821kDLC|Kgzf?0j^3j6N6B&CYiwqp!*V;LB zSVd<@n$-_kz(ODA#~E~gYv=+>Av3XXOl9@y_Dc2Q;icWhTwvZmh)j_Un> z@=%w^yVsOo^EVz)pC!}KQdVykXlH^Z$?@S5#_iEUCG?ow_^GeuR8n(@iF(U9FlVrB zVb`43>bmO(e6i*D3Ki2@~|Q#y(Yl+UnR zgaAkNbdau@ubC896@4F-sx0{g%yp*C9i029tdpL4Jh%3_G@V~9ogI(@U4CTF!*?du zzkc+uN?N{I>BJ^~DfUVq7g5Qft+d7S@KREH%3*8R4h_cS-CH5gF4N*-<6ox@XQV&Q zI<7iD7V3uV!Yv<-*H0_NLWo+Db@cUoN>Uz+ls3sip>WaJieNEB`WR5q-!~fL^%x~_)b1lkKg^aS9#-<`A2nRKCBtlJ8&|0a#E9T>hJN11Cgdw8}iMLD(|7LYSOteE0_@``(BryG2yuseP z=wxoA$xZgrlV1GzI}++D(4ghNkx2Hbd!(3Q3ZwLGPDpODp0^V2Io zyd5$=MK3zaBu_ z#Fc_sXUe!)Qb*<&_-Imd2Qa<&x~KRG?oXjX2EkS*_$|4r_mxp$3!L^*d3d03U?e=>BNez^g$|kCs3&ME`@$7Vv%M|fzrc|3H5?37`O6gM%%KdEeG$h zS9b6q2CLeEow9r2-jd7PdXH;#xf?#~A{|74@gwoSUebTWfAMSL*wEf>?;!?!K#Gu3 z@-mr3kR;d~0=S{*i zMEjj&np{1e<5wO&X zpaJ}iAqPMF4k~Z@!$D6M2?*bPe zd!ym?FIa{pb9nFgl1+%UY}NiwM<^?b2G{M3vihi3GwS+YMhJBY5%6`u`KEijO*c*f zB@&m@1yc|A9<(6@{}P~8t)ksIxXYmnILZ|R_v->wQ4tZGNH}5nGW2#$xE_nhmALMS zULQ8{5IGuiR9K3&KcM;+e{n#y{NRTkCg7 z^hX6Dhw4)}^;7oR*60p06Wy_(#^mj>X5v?$u2AI+Qp>n!7^lZnz6~o$Rs9@?{Bs!O zULqm41|Gj?nlrax?x~@r#mo2vhe$jUpnrB;s^Y4Lyi+1Isio}e?KBqR;fBH*i z3YD=)in-zb7Yz3wICFJ!+>I!=2lc2aQKdZ50s-b6)w57^Ts%itMd5{7J5%H5c_G`6 z>kH3+wRsXr+zv_TT$H8o#Tb6QDl}#t_a@l*-ANavON8BYY|ve->V@j*`>T^9Z~pQr z*-p+-RleX_JS)bQ;m)r`0}Mh`Xbj~ydJ3j;r5?3Wf6C}J8(JRw-2U|1smHW-Ha!AYz(XZ>`Os?=UR%9I#DT%MXhSEWC2-` z{08M@^TbizqaxglMt>A1)w}0{&{S$RS&FX%pBsWp%-IAmCPLN}0>A{U0MlmTt_-P< zn=^69I_pPPy01zo-fYTq9MCG%U`GKS%Q2CgIFC)+aOawEUc-%Uw-*Y}UsJ*AFZ6R& z@UECn3+cWsbO7LkuJk7=YjDm4S{U+Xy7W0?-0%pZL6#Bq(^0tFRRemiCqkc zHjB%#o66JKioZtDQ}K>HPdk9@rIT4&4T94#0QL<8X+Vgn2@AtK3;+o}|ID&iWN3zADwa#DU+q1%<-8`*{kXU2AI&|5YbQ!zj3Qz``aaVCT;QHlt)9AnvJ}>e@R1{ z>5Wcbyzy`Ig-@3C)>TVMyXlcjc3Qp(ePgpEDzv(BuXSU+3dZm*F9%df^{PW#Md%8s zZ!Jl3j83roW_(H|dVO5UlmGv<0G1YY)PAP^qVA~74owvMvp>(c_qW^q3bVZ&yQtmu zE1;*X*O(sbrO`5Ld+K+aW33N`>eITYM9QCj~L2K zb#m_SxTL2faF!t;MNJ(M6l{g6Z*S(;PgGf77UioFx@;pUcPtd4g)cfYbZZ%U2N&xd zk@^al(eak&%>)%s@C)o4Qgy04x4C>298&6)M#7<&)~axUjVEK+4yZHNm7-N`5xhj9 zU16L?2Gp>Vq*+!W`0_BPd;K1n56Qd!%LmW`kFGttl`qNwN&_r{_m=`S8s3XpWI2Ijtx^ z&WQdo^m%@ce4H>^AyFWLienDeeHqF6kn>?0pe+|AB4?Y_497oY)C8bs?l%Sd%#w0Nbf$OQusj+@OZ#P z$*N&r7KAM-=D(l3Lq_P}r#DJ_5#_d4d5BwR&45&{?R|u7pv>t|$c*wCNMX^>X9DOO z3|47;sI%e%%|$g;nP$u!z<9QofjSi_zy%-)=%$hSiO(x>O64iZ*3OWW9M4jOZ=RND z$GsD@r$i??&KYnlLvR3Fajgi3Hp{P+t@dE@`A)o+9y3-MuN1?A3;PsysdBw3 zB;F+(^BK=vW&q-_=jZp+7o`@+xgR*AtbTw=c270((}h*U<3h~OEC7)Qz@f@)eESphnKrdeuCy~p646S1qTP2AWwM*R>@ zT`oS7@R_wMs7p-CWOlce7qjDjp5y%Z+_c8;J$6V~sK>s5l||Pa>u)+wD*Bf`9}&~p z(gZak?FA9+33akV_4KTAT3tMG?{f=}=tH+r34N%Iw#AF^a-Q9usgJ9ZZ|V~^7S^hj zXi%vf3b+bh%V_`iAr3-=(dtogBcKi2M7Dj*-g@C9;M?b$mOlO85KDrO%FP%#y+m3% zT@xA+`odrOY8}bOQiRCSeO2o6zV{R?1_G(S%$}7aMC$olqdGfzr=S+M+ZF8$jhWC^qm4pq>&GziQv@8j{d zAVP@PW2JAPa=h2@nxuIv+{ozGpnnVx9eAZeou57bkJHrsUjz03L5pSaod-*OMw*|f z#asYVM|^ZZz&S;;mN6S(K_RU~MF**(c=uE^IsF_o_hZb-oa^4R5jxT9Wf$V1w|2a* zV^cP(-#e;}iHeGwq}PYAkBn%-9T;IfR5dmWmQNIj>QA&koywOJ?0>L2r8}Wr$*>y3 zMlSr(D~|*vsb9Vk0@up&cmytIUt-^ryc%`1f>oIPRaU1~`WWS(o`XhkTn}9xw?!#N zK0>);`Xx+v{=EJ4xW5#*8KkXH>crEx;jR4`6+|RJ zXWhl$IE)apfJ-&?LFSds>@u{2X|QU@XJsYIo4xf%^7*-fEncXP@v$oJh|D(x-i4pj zhTRQw5M zCzvZvA(&G}N?+K$AhXD(s-%{W6gz#Ni;xrW!liij$?Gp~lRdHI_QQCVCnRrO?IR(o z6lHsDF|6JB9DNjR|AhUz6XepO8WeL}Zm%k8?6aS{+L}_Eb#2xMc)U~w0EhsId~L$J zL2E{{9}QBAlqMayV;kJ&r9b`E-hP#flxZ#eT-x{u#Q7;5l%HMvjCI)UV?A1n=+y>$ ze3M4Qxk!(>;&EAnwxSYNghx?W>($g7M?W7m2{+9as%Ep5`=KYwu;XGL^s>6fY}n>E zHjp>afhA-GI7yu4;0iizA+UI4XvM;r@Sif>6Nf`$8E@0dxdN5OX#|Ph8RAzVtLhis ze`*a#)RfzGfdglW{*}UOJHN1Rgb+)e;I=TlK-iORYca!XwqSK0Cq7;u_(c;w%a8BU zqJ}OO--9qMe*B6YtyGvHKvi44=#SMR6jlDvRln8|7(NAL#E2)+MNY|_hI@9|?^ACw z5N%MoqSkiFLEju zpXxL{9sM=Ue=&xdb%kO3gMus#fpSohHG{~dAZKV6sbq^mNd~3y^PuB^62b#@mW(xW zGm?AYx}6QKl_sv_IcJ*d{WKGs^hctTr8ttVI1uuf^8`t7UyJ4>QmHDpVgS=O6&ve^ zV0aS-=f|9LzQ^mEdV=o_alUDGo>m*i%)Wt6A4IV)G0^Q-K=kp3;<;-fARRlfJU{`a zfS#?b=>JhTqA)5Cc>=x9@d7)S+a7o@0nggQi+|An@ z%YBkZji}M#i2^>}=r!-Xlh!LWF~RBd9*0qE20^>JDpVcrV_?gcj(F&~eqcn;k)9*n zd=c$7ZquISXBL^hwZ9OCzcDp+Gnw)$JWh4Y%nA@L1|@Yu`o?~H9R(bb?txx^#5{#9xSm~6>;d1 z6EX>+=c8Tn*0q1-MiPwq{4D7DNgtR%k83DL^wZ=W(i4KH2a<2dG{2M(NKbm3F(%)a zuA7$Mh=;g9II=T_piBp4puX99i~Hd@@rLqDLj}xAFN%cIu5_v#^{wPA7o$-e1{US6 zcN{ql>9H4=Z{)S;Zz^x4Uis1w)hW|+u?nKMzHP_kb+N^3UNmKoCecy9n~L{r)akxm z=2#RE;x+H69Yz4t+=V4dOHPUym2La*H}1vzqV6HKY;h0f{qr20>h?;?FA{1LV9LL6 zE2*>7ZR3ZS=yxK)h4|ga=2%UPsLVTNfhC$Zep>VjsTP={jia(QVDrrj! z4sS)+z@{qq3W3zEavgFub$ktNSV6X1Vp7rW0H{w+>Z|8K&xcgdQZK8~RUD4t9|W$2 zl>&+T#_do(a<17k2DF^s0CtBdG_8z<(qqHLuB6rtss0)Z18h5U!c!u3cDPhG2yt9=C^rH@pZU}yu&&VM5sOXtdm`JZ3m!p%MhAy>Q__c8 z>8mvr7XH;SBUVl8>e-5em50JmxQ0k3%J-uvM?bM2y7bP%GluQ&DT5FyDO)ED9^_Z$s z$-T%{=x5_&p%L=NaL1F<-Y;z7jP#4NkWf)^Pi4*rP$O zcE>N%$+K(Ag!q^*ZnM1<9T^zx+7)Cy=~ovcZJUAFI!sfh7GO7#K<#m{g>GxY`#z?5f=^rb`LJ_NR@q3FerO;Bq%3s~y{T%zHE1 zk^V(Ypfs%2grfYDg|w?WnBfntlx^%dCCXo2u~a--CnD zyn>`PCEAAG20GhG)dy_yoCH0lhslo$AMsg5^)9tLj)4&FK0^pjpTz&#G|`K^I6RI) zfG3%B+G||S#0upN0Yxo=*uVuH4Z|2a94cw5ZDIE_AD0x^N00?%qw~GO zXns!MG=OVIy3R zQl%#Y5rPtvE!Zq|2ArJaMdib@W!7W4MM6so$n z83(K}i$Z2|AhC2SZkHV@1W!D*6@QcqvKloa>m4cUDg@8D#z1Bf^!2vq=HRN|Otwmj zR=gkTTmD$9e!PZPt(pH@wPK&lga1(Z5%cWJ$s#0|KB@EoF)S}0 zI~YGq1RF*mU-6`q@EK3upzS9(I6b7gDgNlVM>2Ym@s>iPj?W1gzg4TvR-~k}PRjs6T^|ay;(jZf ze2R6g9j4kL>@!U-pjl{@oa?wpy3ND);|I3Kh~*e`>$_JQ!z&(E`IeC@Jz(q=o=wJY zm+jS(gn~}Rl^=pj3Hs)pJ^3DofW@+!;e}+oU&*Q(2pX;rD=u!WLekCl(6xwR$}HXl zG6BJVd(4*=qn)8SUI7v8O15-w6lJv#9opRO3Q8^|wC2fY`<6_OG$>It0_S9GHr@RR zX^8%Ora#tEu5{L_@s6LJYANkwMAqxzpSa#~^Yn85Ww&+Rk0oIroAI#OE#V3m)ZtTz zdg(z?fMMLA?y_0?*p60-&wXr{Q*CQ!1h0iog-X_-hdP>K*iFfCd2es<53%WU4;Ndb z*2S5b8OkR&tuF_3t%xQVhMAnwjYFcX)MhDocTf%GlltM;L1g;!Z4L&z=;>|e#DGl~ zoiF#y6%-|vbxwO>4+ct{E&an#_U~(+ZQvMOCjNhU=cZJp3 zdG)o-G=O9Y<@7Vi*+KELVK8P)4G)O`4mtcEqG-7SjzBy`MnO48y(eYGaJlxY7XQ^@ z4JOAkDa8T|;$djKuOYAZn?C?7_|ylCTD}?Q&<@dbpo+en59xC}E?e^+rjSjZP?UPd zzJzuW$cKAOlrp1RJ0U*^4dj&D-~(X@8We;1eZ^#^Ot)q=Dfb6|6Q29|mog8cjA zLUEUMoXyP8hVXol-S6sf0(W51ubZd{6j%j~@&Wl>#X&2Uk3t-L{BsC z@;0I_`T|wiy{e@h)+2l&eG4UP39q(-Bc2uvIM7@NrBDb)|*D)2V zW#*pgpknsGKlKr6mMHD zwBxXSxIXj+d48TYx~*$dF*}7|tve<32`ly~E|0r^}ip8%5hp*)sku$QQix6&e|A>D_O4?c9%zToftP|LH zuHEI_Hj!o3!@b}1>i6r*euTfjKN#*aBTsHVuJgl8R@twgwSSr;X)0n2V~>QRc6;u- z5`r#%yXAgZ~7|DJAk?)m6$+!&_{?+5stQv3&#wovfp^n0_tT*5E<%R+;Y9n%k22 zrOB9CDIE8_2dCwjs57YDc&|ye!Qi%_aFERVXJr(UwbSTm8p|y#k~8ZOsnej+QlQS1 zx|69YbW^2b-~p*Y02FBaxIrck%^OZKHT?J4%OI|1p8ym3E%P)5hHX{)e=E?aboL z@czU9O7?Z^|Kapb1G3-rOW9&-hqY3y6%qhqTTXYw?!>M|ZujFeAN9IO?FApA-r+&r zHrjF5%`LmZqsjjyvp+bYRSs%C1{R#C2y+mRGKXG#p}2IcP8Qv$iRFu{`|ge*u+I(F zEYDLGWS|bbMikXHlyBU#5^$6G2T9j){6Xihuo_oElG$WNqo=KuQ{j0`icYRy$m>|-SiL)(~xVX&0naA zf=Z2+vE44OJ7p&#PVRTpL(|x6x1gr!Gex3)wAXdTyr-qeh<5=Nh>sdgtJ&@C+b`u; z0hu((djFBxV5{OAmux>aKYdRL%Y)}~ROS;Qb|yLZ3JY}FF98+1m&9HNf{o24E}zKY z^fpSoTKZOGBs;hNREH7YJMy~yZm%B&yAJpA0_RO1YIN5-XW@}h7jTZ#=SM|_$S0cS z*%6*?pU)d=@C>X~r)^nzAisuPRcn=64ay3I_TSLCq=g*g$&N!QXxH51lI$+}Xh=&% zI##N6jk}2Mk%S+Gtvpy{A&5YF#2K+#S$aX>cjHp1B7tx<9ojedVzMT((CuV6!U8Ll zJAbDJAxzfQ^S2^rrq6B}xJ{lU?DZ%oSVQLe(guDUuveU@*=+3+%8Z`*_Y8U5vNUm zetkt%aknJK-=krk(IFCS`WuWC?*LZUbjb=VX1qjl?NSP{FDQLT-$um!HqD#vQ#Hk| zuOu1P-Oidfa`%!VgG*e}cdUx?HjcxNDV<{dO1Sr4;K8>Q3!n6G%bT_w`A8+XhfjX@ z1dMa1nKe!lbrEC$)I=(jBti);=imA7o*C^2@A&dQt;os?GSu}a&2- zR+<9X^I2WmC|2AyeHXPZ4&it`&slv&?ANQMsz+1vMz(-kc#G7Mn`f`O$}zvTp<&18 zXW@+c*>qxvAUyGNJHZHN_*MKXvfJ=J`sRdpE|6e7cF z+9@tZZB>O@G*H~E%MPdUNv@=y+|nr;%KW?JQCf@n_Q{E_xJi$Tn*RG_P{Qbm^L$pL z`R}GPLoivHB#cU60`pgVo+F~GDP@Th@}uh)&n*GLY)qz{(4Tt~NworSMn~$k104D7 zoBFU(Vt`;q^_rqj&lnZsDL}EG%r7M?!$NniL~gAlEYp?Cz2&nNny7oS)`OSb5FURd zsf6ZLr*GL_Atld}f1)7Gg9`fWj$}hzRaEdvbq+)gl5*>7az?vJIsG*b^kR${L@gtI2hxyWSZdZPEAYwGWGF0ZV8DO9Y+Xv8ei5VE!|Y; z>L&x-P{q4~SVA zAw6-TqondrB_X^M|CgQw@2mEI-aVE&|2=s>{g8n*Oz_HPe|mhJWf|+|1XuO){s+>$ z*h!Gw^v|oWs4q_oWLG&441^q?W!UBM^}f!VdTuUDd{R%fsrzyH`MZqU`_S7G%B~rmLLF(-um(L7wjIhiz9I$TLQ9$M*+>eqr_1=Qm>$W1TVg(G7nV zOm&ZWTY3wu@Nw*vTtejb`A4`uXK(>M=iI)4n+rIU$M}dpdq1cC1_|9N+2#ihP50>E z(-ye4I5#cy^xvZXn_fH2d>T>t_OhWvA5u9MT<+iATQq`(M%(kUW#L#ed~dJh4QB*X zMafxpfhDUsZ&@Hne=PWSM!cw#|7Fa%r7N5+V>3y$a35KqdLsg>mEKhq+4P3KB|w#; z;RR(sZTrh)OXv@B`DzBnVv0j(dO|{*sHt_3d@!xR#8aY|oj-a$kBwk^Z}|D1udcirSR$C4Ly?L|2%*xvP-KQb|ym{`IsNbFX0tK&p8PdwPy_s$0<9BFYKqWJRn z&yHl5&v9!Q;;r!J$>UT4k-CGIcU4Vm=H1@uP48<#zoa@DsbAl>zXQR0z0Z><>RXSe zl-zZO4k_Clw{Wu96&%&8+r$>2Djd;uDc0-j>$`GO#hx@deg%IqvOf(i+GXaw9zE_t zFU7VLsFTLO?)hEDP~6)SilnS?216lt>x%Z1$yyM$a}mHn|y9cdpyDch`&E7ocGvWP-4NTxmUd#iDm-W1LX zPVoQN0-%ANbWFBc{$PqX%&w-_LMsf{UT^BQCXv`jTj9=^i^gkFcm`A@ z-78lOS}v{k!0&9*G@U=9ZsWMxBba>FxgIm=^Wjtaq%By$@_$A!Am&N2zGc1mK9-L` z(Ia~_-w22p!ES$O|AgtEZmHAMca}#dpGc?$)VG!kQf#E@{#g$v0%4iFK9yu5`2s^) z9;4Y^G(s=fg|hc?3r+1)Hu}#jr1b_4!JYptnJoZM-=RSMwjrjLCE)T${wq;8|JieU*MwTxq>(DrS$UPx2c=)#*-{VNrAkc8sX~|R*kKjQ`VL@qgJhH zO##Ug8y?5mk1rPqBLaaxvTk^t)PtB2U*XJE#~$}zi5SFj9Beup@Z)Tv`*F|=pW9rY zh#IAK!?BzD5>GEiysaA^%Wp1SMsYYj$yNeb^ z!HEn;a~F8v2>O~W{QhCr#OB1wrX>K|zu%9Dj68M09Fk<6N0Pmr6giv-Tz8~{xUFrx z+$S+=ZvLw6Ix`-usc=2#(tZChgri}txJ8sj8;EvASWm(>;hVI^$N_gjQ4|PUB5WL4 z@O?yVh@IHxj)tBhjttZnHdwwUESl7B4TqxXJVpG<#W};Y;ky$Fqi`y6{HcTpa8Ht| zELwTe>Nh<7gi!J3lVNU?de6r9c+%l*v1JDRzX_s2+VIq&YPec`5wASvNimz35lx+` zB&OVpv$GWy0ZDyV^vZ-Thh^>KNNJF^N;!RQ!?@`m9|X}uij5=@m6*FntkM&SHxUn9Ka`{ZCoRhMapi z@yy|P%4Yv*SnY#V&MMJpY&-`~?aT&gRn2pn4OOxlb1$JmZ#lBTA|?-z=u7Mty-?*e^LX8N62( zf9A9De*g_Z^1dbo+1+&d8~%CLN!P6Y;G=&u$IWilD0H8AcEW@SeUCo+=+m(OjYi`z zd#nPJ;r3g<-TdbFKPn!rvD&N6<>%hfvG#1nLzZPL3=a=K4%_JO@89D8v3Df!R!r|R zw_OV&QSZ*JMA|HWWdD2iC0cBit;LcOWowbVlC+=_B0?{cENz7TmR_<(vQ#S4yK`Tg zLWETB-TC*;n=`(9?pf#FY~A_&{Pgb3oH^$^-}&~%b=~PrLA2@0ad*F79+FQ; zz}sxw9zh5R`aJr^957@YO5ZIlEw$IKU7Mi$dgC2u6Iw(Po;j%DVWizT*N|ohA5CiS zwI3l?MB5lBQ2afmj~Xh`{%Ma(-{C24tl?dU_QH4~1oYHLMXmmLMjZNXPWcyM zK<@W{6$G-O{A&Qc@kE!qu-xP&9vippM4TXougWsy+CFG`p)@jQ+|{kud%j%ib#( zH1bmKnG0H9N4i`xAUF#R8#ZiGR#xU!l&4p(u{Z)WbJCLpolT3TA2&1VON+;!b$=sP3p)TuKF3=MgCd1pr= z5l?e&d)C$Ds!KfwcBpzQu~_VW*L81({}PGB7>rLn`kg}7tpC=l1u*2+r%Cs||NZ&b z-?r3Z0qnYNr$i$09R58TjdrvwYpM>8L0yJXcg{?qsSO%5Xpl@MA;My7Bob*-TwJ`F zj+K{}*EkZ1e2NwS(@(nOe{Ig{cIFq$-X{~Ey~}&FYSle{oi^dC`*3lx^RKiEEz1Jq zj(54P+cA+yKwn6xF#yp3clf`zn?2b%0^W69m$dEDAY~~%y9{4;_9>S(!$0br1z8?7 z9ZN@f7r`|wUG>(C3?%TkM5)v^q|5k_4!n5llzHXR^F}<5{f!~cgPnh{}lfKXOB3=q6R~6$vNfUgu3LO{l`Lp zRsJEobJ&1s+PRPh%95zv3Xho#C<#|vERyup6!fEC{F8rlr-ZBG#u7^yQ*zgQV znc!5mgmg{(56-B}dm@8{$|%QGfUR4szL8V zfBTY>l5g;uY}*Ef^b-HKk&fd)3W9gpwhbZaUT^Pi9q;fw(Cnpiru?yiD!of&%<67j$L@ye@?IDMZ>`5M__#J|9OVvI1rUU*AF@Q=hLxA z51u`xal^x)dxNLI_8Z!+{qk=XA9}{g9Z0WkqX15xG6Z!Xue#;<0571uVmux%4?iHJ zqHWuAX;j63l`+n9_6%7xv#0MLHvXr!ahmmM3_$ZCAlV*2~=#pIF@EDf==$YNGgxqAqn5 zb$?w(`=dxZ|E}de>?<4%D)aX2y^%)OcYm0&8Xm=%#yAWm~=NZjPRo?wD6Q@lXXUEwH`JrXILtm2RBT7>Z~h~OM1%uNSII_Fxo_qy;` zjUy6N1gSl&r%EbtsT6;NcWfKC`J&cRbj+k%Hosk$Mvw!Z2Vvv`Zk#}F8Qg{r`JLl9 z^<%NvLDZWL3g@J;3sY(VEcxog2X7zI;RblAa^~UnX=Z8jVizm3V+rT^!}pp?kA@W}ff` zn9DI$t5z-9x^?T%@B@U+@2r1t7@GebkT>tcrKob{${kj%S_K9G@FwRkTk+An0ryb5 zM)_N#hrB`#Y|t|B{zEZ7+L6I^-EN6Q!t>VGuU~(k3Kc4BqTf32j4pkyy?lrVJGAb| zYtvSraln2JUiY8#=xo2ow&XIO_GP8Or;09wHO#{tb#MxuJsfc>jE2SY$tnM) zSf6`-w}JpK?0@E@;y?hFFb!_ll@9O};dx#RE3a^x&e$R$LnbJhj4KmjL`4Qw;~QAW zl~!SiVwNKGgfZOUb?CmY<5gl6YMs4TDWEE*rP>n?15L>#Wz!H$mGxa|pyC)Pg7f4H zz@bw~I7~YQlqSlciIB<@XtGJ0L{T?Xxk4~4I`4|;;7BCm;RA%ick4KCNZYOrq3XYv zbH7v3yTEl_$U*l)_77~(vT)3hS2|-rLACfB2A=8tbkt#|k~?pHO6tkAZTo3L$mRa{ zC&c6N_ifuANeJl+zqLNS^PKKi+=cTHiY7fjJb(I})HHag0gy-}f;<7+wkHunQs%}x z8;3PJam>(uQ#_A1d+FkSG#b6nvaDO+{Rf7=HoBqR>}L4)uUmf}*6F&I1KDwh^?#1E zXnLGSw`|*f1(pd00mpHIbO7ARz;)f0i9`bG6nNL-x69m{2cJ(MMFU@#U3e?$(8ja7 zc=XioaNvb*Se7+D9*(b!lo%}9nNf?PlPAM)fP6_en!7u0eH#sG_XaBJf5as+U z6jE*ulL-)H%3CH(8qXnHVNVL-CA1Zo6;=$a^IXO+DV8RBK9m8+L>PcI$_xN3o}udN zN}*DckD0X^2fIp`L_${(`c3gHEYs)Uleh0OZ-dF1<>4!@La)jEO@ve)%-C<4;-*L- z_oGjOf=_5~PHEdY2EPx6->$iQ@H2z%>2`%}+xP7B>JH$+lZM*kbFgyNVsh(U?Yzg2 z`Olf;_TG;Rg|UVW8^+4Y%D$v|=cT2kb=Iz3yUDg~@KAU06dTK`H*->Pce)=-SA0ZR zk2jrLrAn3RQt!BJ+pBypbctnIZCuyA7Q*IHj{D25>pt}IWCzxUdpjGqC5y#kx4N!- zFU+Ig_2bD2tvp*r(WW2P=6Aa(wFY6IYe$n)Pk@pK<;r&3&N1+q1B^*oS=qj8)~xZA zb=$T#Kt#mOh=P^zc)VN__G!*ba~{brd>-oUrC!ylR*z4cuq;M@EiL=~)C;;bdE19$ zFiwL84f1JfMqa%Ge;GG)ZjFi+D=GWFfB)n9{H}eeIPmyt158ShR)c7}tV9lB*zowisM^1Xe$d}@%D)Oix%>YL2*?Wl=~!9AGJg>v3?r}z zgbQKVQzq?0SV6!qN*2l(Vxp+{#2Y*pG<=cm;~vWTovHGryZ=1F47Xuogfnf%&BBF; z3pxUe>@cf@%4|e@u-^~E^i@I^zyS+$rr%)3r-p@4c*hfLP2k8AG!0Rv@GjB6jg9%JRbML+;_INgSLEEz`to5L4eDG4`=7!`-tb+rWYu-kVIeZ z`Sc6!!M|g%*b3SP5Vs<%Q>RXijT<-q21E0VsBAGtZ(OE`8w-2ewjpu>A{e~d0o^*@ z;TaHCtXfQN^9>Zy5xhDLuInEC@?-}L5gtA5)T2dEV35KAS_kfgFBnI>3yKA(D@sIyP~wG&6b|2Qq?_tTb5!}Htjce~2-)v#g1x~o^O2Ezuv zD#l{54zBA0Z~`L$qsA>hM zqB%OoU%ECOK4YTz({WAKfS=9Q-kYos%{k4nu#1Wr(@%yfLMl4LvW`=J7XxKpcDrj{uqGKNsZTstfezNM=~i>#}eL^-Gk zHiiXLJ>5u9lSO;LHxMH<8W-Npc338}<{;k;{~kuT;t5BhGh#BnhTs`^YkT?rSm)WIP_#cKmEO(#oB-SI3_Pna5{b~Fb|sGEKuP(sXf!(AvMgF$?rFzy zAd(;{!&8RtX?t<*8CW>^(WeJ_#)>oYJ0@=EIx-6X#)>_J^@B}fc42%v<^?}&-axwa z4xHpcQ9>N{PYC(+gjVNn+49TglPO$%xIvM`HqmJGTFbH?_J4aakw{!c_qyNR7m}~O z{)8^Ow&OTI!XLP3BXdqac3pRqWm%Ad05uu*-5H(_+V^FSWp6)aSJr;{KbY1sbLdXlrWb*Ekk`m}dfG;Q*cxlB7l^&p}1?|qhhIH%X zH4ng-wh@(mQB~zW?3sn{$Ycs{Z!}s);`yo+YL0;zD~duGsgxkqeHNzhx2|_o^)9A& zcrOrt&SdYosS}wfX)JxF>oeVJ3>VhJO=}zJ6mI3gm;Ofgh=FU03}6~bA~N>Lc(I9G z#yMBQb0m|6!{9WBcIb3s2xXX*@gJjEeP_Ks z+5UHo3;?3au#$CXjjs|gSjHindzc(4M%YOqt|(MKc+6Bd(@_XB5nLoho6V&+R{Rm{ zQ!Kqmz|h)bg}~(rQ=-;!CWM5SlLuo9luo4($Mg26JV5Mz`}o4mhtm~81O^zMXCYLi z3=hWH5LR>qZAzn-f4zwGz54?C%FB-91eN~S_I-YS{b`Ic9jqOCxkFEP&Ffyr@j~y* z=<~5y>{Qov-x9qK9^pYZJxOZsRTmy9I{!+0t7Tc|_|JP)?W57?QI=(WjK+Arw6yeA z+;<)n@1-jil9j7IBOfh(!&$eX__0JHF-&%***s|2r}Z0(3uv@M*Ny{8`*W`K1i&^| z+K=s=*GGPI!;OyPJP2d6_yn*5nCU76+BWnkkEUoy|OIp)_6RAzir!ncGh7W zL!UpX)%iu2UU1XM`Uf1?50n(wO_DY1SCiM?e&*%bbD!7%M6;_|`&lM9ciZC$*_z?;;#>?+**@ z2*0!}>mA4OQVGhAJo3m&-+lMp|18VGbrQVPfk>nRjxq?Q?Yi#$iA3VoXf#^SvMlKC zP?OnH;h!+>y?4%J?<)<5s_}J|VAVK2j&M`)C|yx0NxW~ED%30=!-jK8qeM{;GFqi7 zY@1E~sln?)(?40nhYtN`_lPrBUTbF~FR;jfM8A#pF{LH5sxTufm$Eo3zz?8v!V~}# zoMvztx4M-Y%T2z4_eTSdLOV@l=RK3OTRMLMo~F>dobsL(g-{6|48;5D-zi};$AAKh^K8D2HF0XIDId5_y#AFhH2y^rTg6$uRgA2)wTwr$@E z;r-ASaO9a6NBirH{rdCq_m?jJ=(wM@{6N0?=96I`FMe}kadB}m5|jcVc>9iRXFeCZrt;mZ@%%`1F{#i@T*d#O7(BQ{no+v<~;k& zsG2qQzO7cx+R`%zcAq1UJhI}VMT-D_!rXr_0BqW{X{77A+%5^syxu4+ExigNF5vgP zyuA96NCct;Qi?b5)&p%GA;RO}MqE@3-QQlJ92kwudN(DtYqZly!Yb4puVhLm06`H` zr9v$-0In&lbv6 zIB`;BRjS;EQOadF%`*;!`G!bN#lP$9g(jY+8Rs)K_``PvDIahv;43-hUuTDM&+N7k zP$~Z$Mp=Y&7fpzSg6Rg#m9^$$W-MZCJmM|eR9%mTsb@h%SUJoov~t2>{#aO53Mo3s z&V!H1Uh}@e?h!8s*7O@9<|h%3luC!nKm~#UPXLpJ&if22EpO1I+ODrm1$d(fQB?N+h+2YW^U9bsGK)$w7) zw%YQRs{h5q3{?&LMqAVKB)`Du&Lt zOy*3-lPUS8=Yj_ZHp-J1@nI6F#^WU-AzRs;*l~T_PKox>rRBVR2&{(6W0X!OxViWaLE-1cNYp#o{~l z&V?DvrxT)%8%Bl+9pkB7 zi0CcvcZsnep!=#b!;ZnJ!DNi1GM*tg-^SnY6@P9q<807U1rG@2{8B|u;`?nn2i;sf zG|wsjI#Pe`ncW5gBIRF4U^Nl4nNZr4FgqRZCr|i>_jKl|P^7tDH?}%$SOo}rMR#m- zLw<{<#|IYLC11Ti*#9!%LwwZHg&i$jhzGX7{(-R6Z) zh88k)1w7S$BcU>MD?rgVbi971bOAq4&$=SAf@=_9X;~)oKEMd0zK_oL=?5`Pac;B@ zUU91Q!>}ol4VfCc74%iP7rOiI<7a5P%zlTHMgfmder;5f#)Rgp42Ze~5Cw{OiY&$p z-S?dZ%3{zu1^`LN1WzH>IV*_I30q5>jDz=C#yezngA91HtQ|HTCe#@a;fX0f6v0DA z@98&k%D*lqT#-nmsJOUT zXvE=J>veNHc>qv&YmEg9c08F9p<^U~IWsgS&6vV|Lq?M@O$#gNh&($q1U>AyNKj2H zIx6TG5=vQN;m*|GFy3<>UYSNos17F{5ASy-u$C!Rym4ir$qFTd3sZsetg3IrrY4zq zcT>E1>>9(ia^>xHHkIjY(;7DNeGS=U<1ic0H7WAqc8v%GgR=#dyLq1%Pq(` zWq!B8_uNANfWXd5v;_;9GNmcxNLzXda!^C1hroPP zxy>fR4t-YF^QFRb-G~y^I6R??XPjl8Gb$FR_(sa@FhVxZ_=5)xV*t<*R(a#<@=`oM zl^0#tP@*$ub}u-Q4IW55 z;_)IOc@Y?4VWFd-iVQFQzNQtlGu>|${V*L*qM+j_6lJ|08I2McWU^8mlgDE){r3&F zK^LnqG6jq$F|-4sIvoN!shqb7EoHyKD*upf(|xLXFRc$uYo zAa_3pfqw%8c*_6X1Ey(v_&-9$s!Qo9~%*JT=Wr72&Tc zZ`Kg+16cTq6i$5!hSG&CyZ~nd83)p8NbHn(hxpL|X&xROuN7XuENmXbK{S7z^Q3|& zZw!`h=;A+f#@IYn!1Vy0VGTZtbD({WWN?uvl*Z?ZE$Au=&&|Ym$D1#w?Iu1CQcsoX zezQsyPlyjAtm+s51nuXH_DRCyGN!A7u`@FG(~XpQ(oJX{=iH`L1Bd%8KC$#R69Z8u zxX+OpjBuMNPMSia!opCnKcQ$7$3T)b#ayJ!`g5o2<$rczTnjgI%D+mQ%H980K!80UpcKdB zai~HZ>XLb8GG#I_AyyCoVt~K-+>Fsfrp_9d=G|nKeVHJ@A>Kv`7B8fX2TUE|QzTUB z62UOl=ymovVYc^E;Dw+>zovCcUa!9)Z9 zjn0>-psQxpaOJlIf@^S|Obr=2v{yj0S!DswA+#dF^M39MUYhNF*Es{(!c$&!P8g-Y zmm{QI9a#K;cVrO&I)))%NlL5Gufxm3;Hzv6bl)=B8^LrUGdz#F$eAX>lQ4=u#StCn6UVS+Bh%S3>Okj-- zH6VndhyXSkV92b(xH1IR1ebIaHq-N@q9H7OqOiistBydXp`eFuyTy(zQ`k)Q8VPf0 zbd48LqB}O-BaW9;C;+O2W)=pVNE%+-P41cI(7TjKnlSsVA*?DRtpqp}ggWDvsQ6jk zeG#1n89rq~SF>H@PNq=v__hxhWk(v_6av;DBg)@q;i-Zlk4MPnZbDNzw1e9sk3OfG ztBjs%7?O+>K$$UtsX?P1g~?N3eTPH}e=~{j5Gloy$hDLLcksS$Dv3xwmEi>MUZyQv zWd?vuBU)re26`_H&0}?&S%vW+;zc%IPWji7qH@pt;~|i3GFYu#pLZA(ZmKIAvPg zil+qeyr3q+r3{~_l8@GJ1UH3ICM8Tc6FRN(eugzn%4k(4LN>h1oGbV#C4$KGZ0Rc0 zrXEKfFD3&gUST^*6H^R)@Sg782VH{WK6)IXLf`dwE&=po>dm;+uMuM5K zC0(UVWFtSBr!ZjV7KYKl`>KlMCc>5omw4f27-qVRM@MNg@nZ87bD452z#XA-qf-1S z3^)b=(!baEqT`Ql9T-d4)Eixxp*=~-o9CR7jFdFq-t%(BGr&`+XGd4iQ=L_+yhx*tAOJJYhyx zU|pyRsB!ayW++-+_sleK>VZ$ac{_54MSHS;eN_r-Yh8OCK6)Cz#U162y zMG@sHn+UDC9!wur${x1@#a|h!EMdwi!n3J|Uc&$ofJ=8TbQQKRG(&Z+c-8(9E-4jg z5lz5r59{es(LLUMU0f9?#Iio_xH`X3DaT>7VfJ58ivc`wY}XGXBan!e0=&W+dRW{t zIs3Y7Su}56F}h4s79?}hg0;xtPw~K%&xS4TsWK9=_m?dHkj!~&a?$%AznW|HtBaVq z=l&HCfK*tIue$Dh_&?P2F(gaE7j zgQXCv{eyK(6*8|7L~)LCs6t4hkT+Ht0b(|DQt^D<>fI_q!Blb96~1*4L^s7iXDrz> zjAE@N=)rilZSPG8`N1E%->ky;LKJqMwGc3w(5Q;S5Urcd>&Bs~h$(304a;y}rC7+6 zZ_z#7oOG2}SSO5B1V6Sj2(-@W*#72}+)35iRd}mwz)+1NQP63ts}v&)j)2isU{yvG zS&v?`3ppSORC$R}+YCE9!Rd#JHA6o2|i=<#I+3tp<0 z&K}C(AxQrUUQx96;+`RPP9_GDOp1Go_&WT)boqw`G&NAi%T5&}3JBaLmkoEv@P$xx z2sJmMPdF@{_Z|&NhkK&hMWXNVU;@wLGs|`YkQtOrWhWJlgt^dj&!#m|UP9>{7;JFc zWehj`JB-N-^@ra@Vg2}bIu>ty8tPBsElLegwS;EV7epx!@Eu;uWzl})aUl`~E**UL z$Md>hsf+;<8f9FILBXF+!Lp@E1OQL@cN|Acuh)9WZ>l<8 zTaf_DQv%pfM%7{F$Zm`$hxDv%bt>~Y@*ukTjdT3dO3Ig)1NyQoE_TqKO6rXIIRhjZ+T9<)Won5pi zY`O0EW=GY=$3v9H9TtmRFPQ^s?4fda*3tEwRIptr6Lgb#= zts%e*`=2$b7+^yCkxG>}K?^G30c8yUIKLBb1b&VxXFTPDZA)lEz)cCOj-nEFOp!v4 z>Lz$hGa8L{u`KH;n2Tju_2TjPIuiuu!N-KS+P2-B5Hgm26MCo8JYC$<&%~P`gR2Uj zI$@XKJe)M4rvhalEw`U%Y~nj8)D}2sWaFDGfTRG1<=@YWZBvtf==^8oshtqnr4Jf>`#jeAL zFUYUl=u_F;H4O0H@TTh*?KvwCXIhv~W$0zs*S>4RB8XD3EGw-&Hkz~7?$_r}DJ?SD zXbpTObW)O3^QGhI@@4vb{~PZg{OrsJQlb)bo_`Z6ntS$N4*^lmzYa6jvF7UtWUK-v z5}t6_F<9(b#a|>yv0hdBTqYsPqP9k17=uD6O0f`k?cHj@_nSAIg+ds$F}W92qD+eZ$brCfG;mBU3W|@v@wc`mQuliAGa-@Y0j)FY&#p>EDXFeMur8E zpc)20S^Q$7Zy@}C#JF1qo9;Irmly7n5;p;|X~@!=%m`RC(L50->T! zv_KgMZV+Y+ii}Fp&1CO)Rb3C-S`yYUKss1bVcF?gv#U zO@&o#Op^5QPAQB@`KIbyJj0sIn4&s&nAV#{eyEJx*-)x@HA*_Qzlnh2-~5ncjp?X> zKe6?U(*8CzSV_*W79SMRQ{5CJznvZolj8W1@mC$XC*kkjEmOm>#KDtNqXffWH*Jnb z|1n(-Lo-w`%9f#CLZjJ;MPc0`|Lnk&*;c>ggjTJK`dvRhi1fo}Rac3}urPf`SH6Jf z!1#Cj@D3@U&nf>ZOw8Titso#P_@}EZB7&&RSVf))=9|HIEJKe(BB#hK_o8rEyt!co z4NmkzP;UT@uf?!M!VqR$cr7>HaW*MVd__NVy5l%+=^!lg9q!z8h-)ks``mS1FVr8d zS%q=P%c4v-ZCsr(S%yo;Rw+OVMv~IP4+Ld4S6Rn04c|~%=fmO=oj)&n0Sg?a;lm7E zxM7Q#%o+Cdj^P~&s8^L#_N@&^=PF|`us1&Fo|JH z*Qm<}Ow%hBJdQns85PTvMR=L;B!K}SrCtvmFHAI*%n+j+70aWeqThJMjZJE)h$5zN zjWC`{=AI>7{zH><7v=P_@i zg<;|=S%KHjM984y@bJjLmyy6K%6UWrpK;9uMhW5;Fbq3)5p%`HI1>YauAnM_KNt2q zVJ`~;AZttty)Wz;!znT{1(!8|)``_E-WcD|fI0s|LF&XJFP(Zw!9z zk=X^Wj-TaC`=Fz}#vue70YJd)*sc%l+LERys+1rddMDxwGJ_y~2ZT!~ASWrNrDXs> zxPb9X2~U;sgA^Y&Ojnz(Dxy6o-Pcxd8M?&V*U)E*myU#$ewH8T8h+tDj-w(Ch8zQc zR=Z)uU%Y+Fwl~$#C?o9+uaaj}atIHWneaK5&tTaR*Hxjvo4`UaS7~h?o%Qpck)&+l zNF^yv;6X}lHDChYPi=nyq{*@3=DgQEpoeANl@9G#5SfN+<5XGk8U% zO!3Ytthrs620@-6n})+lXGjy6W!^N^deD7Q3p!rvhf=BH;8RwW6dB!<(NmTN_I|2w zpg6xc{Mq87DQ^Wv05&`mugMfw85~NfzU+^S!$75qC3yVDb{RE~D*EU4Y)j7V*~S|N zetZ4FSILS`m(b~^7>)4@rMhH7qAZa|EuF?6Kd3ed-P05 zgjH^}d=)JfoktL`MG6tTW4ibBI}*cZ=rMIely!tI_G~cGBeb`7N<@q65 zfFBhi`<~1&sUshWXo@afm&iC->)r$3;P9g$Vn%mARQz7#j4_3UB;t(QEyl!{B!M^F zN~`xFQr>0fQzn~=3`_z!)uU0E=Arouop2(0EtAQ0&bs8Rr5H-ub!|8#nM{JCF1U2{ zhFrCOlby&t`;Ucyj`GhEZoJsAL|Y)Nc|WsbX4C7B$N#GeCczw7OYNs$7LmdCbq>H% zqehJ?%a$$Mj^PxIMvt~EYa#sB`I7$R!qz>M!k3D1h!B?tn{SJk%)6*BDo^asx^xDE^$QsQ*d5_xk0> zUk-B02H>F@M+PR3rnwMq2K;|wT*N-ia)1#Dp$_n9V*k-gY165 zm@)$Zj{FeS2+%Q*>zoZ0J>s2v_MNbJ#MWTJa6#YD-gn7d)4cPhcq)Rotd(>a*+R5u zyvQX{YMe;~la9w+l;Q?3oD!atyCnEOx!?x(~r`S6p=4w2RO0MJt*py;WhDHoK0mzW#*tyZgdmO2-{}7U_TE zV@8UQDwH&gqAOqqX}OAFVBqh+u-y-*NRpC~iI_0O3r;S$(K%xUEm2;au3#q%qt+?H z$(}D(@C)IUk)<+lQ*lhcp%O5{;*!p|I1OOr!##^=$fZQ_54Ljt(2vOr(;oHyrs4BU zT7j10!}0igbvAtcPD&;$^~@|xN+oTbHk!80PopS~Z~03b4`shGR2K?y7VIBJ16 zK5q>Q!=%!XD_Vz4;mm@7Ej??dij^-%1gUcrp9CImB`wAB%RK|81-bqHx}`N}M9)x| zspNt%`+?8TJF_Ody(jcej9muJSJtkc=4t1}v^}%po{MZ92zr-j0^?FC(lv z9#L7|or%yd%2Q_BxxkohuC%vXmNg8)<_yPi-o&tqMx#?K%eoYPD=jUx=S*31m8fU9 zXq{n@n(oo(%ibe{MtYz3u4Ddl#=9@fe1di!FdACdV`wrcR-QQ_u-;H6`v=bpyZ|X) zW>J-7-rP0tkK3NepSw~(WCbmeF-7-0v+HL=l2zV+l@MUMAFA=F@R$w!=NUb+F)rRC z(R@K+pE`4FJ`FEdDVrGA-X67pOrJc7{*;z`j-J&j{HasXT|{#+e&7&!)wwn`0Kj-U zA=aYr%Dmg~j)~$=zacOHaE0Np=K!4tnSnAa?PJptf;@H>j=J-gDbxZ3fKa}`jsYFT zUqU-lQipuHFRX~J!d>=#UZe*o)03YcQSj0nXrAL;C4`V?lgZ>uB_$;vm$Rz$OW;lA zR$$keC;JR<54LUJO1@t8*`nr6k9#G|8N+D}A|nEaen@$p$iU6RFPZV0UB5~$Q5o)a z_%v@1MMeoWk_i_r4toZf@2Btp-PA^G0N6=-rE$Wpp~hiGvetfhr*sIbh#7D^C~RveBM&S`2uMRJDm2-){W>A!Ec$@! zx?K~AL?FR~Y-`*0dP2y4I}HFU91u?|C!2m;KWpY|k3Cpg zTKWO6h@c9~G2{N^^SATw8r>DYa+2dX?~2BVMx$pELIzltbtEC=C)ah~saUb%sOe9B z+ddn;$MJXByh(O1!?X~D@nvE7EFNv%yzzrYUltAc@u!WD@B7?KuRhcnr*4P@7Ccv0 zx#sOJO%dS1U6wzpBk1|UOCWGFR&}mGh6Mpx7@$h9QsE+J{=%rCFhVLv5TpXQcB>?v zHHh{I;{o>gn<)PD`!C-&YxCSmvpr8c-4mIIo8pu#JXe*A&W``wjL}1AuVtO61r~Np z4!tx5AfTIsgf>QtJo8}`e>xx505H2Sp4LLk#M>#uFA?txW3XcJ02_dq^ps#qDg%Jd zJai2JSn=njxM<+2u#lwAJax`FFLe%kQ$P6Rbu#?H?!2ME5b(HVS&zr#@pWP7w$A)i zMVFhZ}b)T>->o&)6ev~13 zEEYT7b=}hsZgluXEe<)RWkb8!mk&(5BF1T<001BWNklJohfntB9>{xZ`)TvRmYSrg0%Q}bmogue9O^#?06gSis9(CXgd&Dx!^RI}$`sm1_ zi}UIocxwub2}!}bCQX{u{Oz~jt|o*WWm#6V=|M*|ueX20+OfPQs1v$AY`6`h_|vuU)*w~WSe#+AB;7qwL={M}c z3Cq7;#I*o;`0Tn*5JJW#5{aM^CZ!V`1G?53<;_(jJ7L)b()8m4x{AMm_5_h5GQ*K5icn_EN8ih8*B&}92RF413#ZZQI3ea#O!R_N8?5g8o? zh$rM{LZmerYw#Grt8M4lH-wO$KYXxDXn@WyF%#OY+=;_^&{v+_HUo?SB1f30ViVJQX)%6y{H=t-7k z0ZeQ1)_%$f7ra?Gb?h0ud(mjLqh(oBMX-lRjR~XXf4~2JvEw9)ztUr=5G+FXpJ($6 z*>$KGw2t!6coR~DM;)3W5E6qtQeIwrSdXeMec}j=s!&zs9$}AW7>_>-C+c~|Mj0J} z%%d|JmCAYhlI=B!3;?RyDjKvf?GZfYz~LPNL#nDyf$Cd$4X7JKO6x?`Ig@=ur%$-9 zM_ZyqrQo;BaENKRiQ-S!$fHRpN-hYApIsPF0ZSMdiq3sM@3us`;+OO~H zIpFk7gb?8OjRa#pUs_sv>)N$z6L4QQJX>deynUo=$IAbH`_wO)IQsoO7}HPD#O(}O z+(@Byx{5y!UU*M+qNb36ikrH@o&!wl{O@ZUeFX9&_q&{^1S4aOXwOaO%;JR1IVEbn zO~=5ro>TCb*E%w-=%4ACRnalHzjXPB1-B2UC9%q9o2&M(Lk)7z>^2aPDgU~aZcVJF zs)LIGkLO02JdaTDh?j}=pPiGSK$a1qW`vlO+#?pbdJj5h#cF3+dBf{Ju8$qZX@!U^ zgx8+{ej6N75k$wK#3CG2MeG{>!fG|>?C{|n?YMfSGVB-N{ zy8ZR!9>x7sK-)q&?z(A!JpzUvPrfS#y%AP2=DnJ z4hSJSj>IAF{*`X>%Og7#{uU>D(_U>MKY-Ao{e+MEbN^pDi;?pW!NibUbw(Mg?)M~x zMZx)uu!XJgMCb{5DjMg{hKFdbic|2;bFkeog$Wo?TYExQ`P=aR3nrg4t1LovW5-^I5%MW4_5W3wVw&UWRBq9(d|R+@zLT&jn|60Z!Irz*3JB zYuEfmlQ9nD?CkrY32w1(jM<(e3W5+`LGJM zRRPlMhLxrrBKXk`ko}8&lLhW<$LD9gi@T`U3i9RSQ+v;Ck3%0VAh+@mPzgV-Z#R(_Q-shz_opoPB=+ zoM7Lk)Yb-Rv4%yA>E8G}L$-EW>Gm%M$*#OspM*N{rfdmDY6qUc++O>+>SVK}EfZ&# zP{1|JNhzKSjY|(;RbMIZ0_RiEy?N6Lg8%HZS3~=F=U7?}@#-uEhu^YBT*H^dtvl>W z3DwvAYQ9UOx#KxEP6t{jXJ4*bolH`A>amUS3|2{+zV8NY?zm`*(GnbnkEUCd%AO7j zcCWg^{Sc-*Q8ALO4P+*NAvZl2MUdi1$y@R?Vpg3@4xj7=bi(sc!LUE~m489Mw3Mi& zW8cZUIHCMHhCMvxiUmhPvb4Nj_&%#aDZw^+vcIZNx(Ae1t2+3P6c%B?+3#sK-NB}N zJoq|cD4Z$oA4%tXiA_s$^G--nvErJD=0tVG01p>zke9lUoO>xzcERI>ptHfn-0mnQ z%P3$?b>3f3tiyg!{9A^f@e}^j*$KsnwmouA3kAmZ_HZgUrko+BkouWth<0ljrQKS^Hl@Gxq7#)w)w$cerVX%q{kaR%$lr z>KB~<6e=LOsZ8+61gH6wC*4P6i4z#p}>%q*tOk`_45peTwyv`U4^8n=b$TX9k9zGb3sF)Wjj~GyGZB@ug1B_#so|Nz8SF1(Dy2%kmw`cpt(oQgt9sn5<5M0A2ykfU}=uE$euhItg{^ zU7K6h9Fj*^ZY@X4Q|d9^r=P{R^%oKdC58%Uy|lw;f5-V!_c)U9BclS>##MG3;nblc51SPszk0^+MhcHSFZB<+h`Pt&(VHtzD3Z32`#fX0pkh-8cXzai&#nsc&OGCmukiYK zDSysK$ufFiF^m~6_{;Z8+RrbZ46+e`@f)JrqllIWhcB7*$to;g2nXm0=dgDks*C1# z%=3au_J@M&x~0-TCSWKm9R?`$KpN~Rx-)8^;Q4}28P^n>C=){#Qhiz(ex-@jS=SRg zWMW>vTly^=fHO@LRe!H*k6|$00;OomnuqAw%7l8&WH zynI%RWjJ&nWUgXak~hG8G&bJ6@cXYO&|Lr#gU?g!K@|yUgN71IJv7*i`FA3KTO`I# zJsL@$C6oB)4)}dLOYwTeBXNgWyy|OiKvYRI=-d^?FTex!eHXy4GD0HeKL+=qYHS9( zk_9~O%Xvz=%thS;IHcR3gFDda4ab$Re^?X0G;^gxWj;0tvDRpf96M*EhPJ>@bEcni-ZpcRW(k7Ns(%rlu=i4 zuu`T_)8k{Jt-rCN~PmkB57!8h$FE6h`{a74$vd*m&)F{!b)R6 z$r_X+^}Lnn7ULHL9b%7RzFu_!uL}#DapFn}zVj@t`@3E&7?|pUh(2VUSd3R&9?mzA z0Gl6GEa+vjBPvfilNFiA`-MBc^c{bX&^~9(_QQYWKFmE*I}ZHHst6ae`fxs2jT4nU z&ikJXXD5qHz`*XOpRvAT<6SPOQnvkOVe0J5!|Cme&g;If%Ap$@dUj^MGgor>%ssMY4Ly-#v&t? z2MhmPx^(&a3dk9{;xU!{Ab(BODv!nKAWn*v2y`s{sHCqrK9(Q$_t5>9S3+zqg(#tz zYdQcehWAwfIfw!3`8ukTb-{^qyOE=HkWs4M_I?B7vmVMS`lK zZ5XO;N&7SK*4qP%%t{BW{zqE@X&x5fqys1n>09BQLaWeTx4HpuWsj2?(TOz3=dE&) zr&+r8t0%$R&8EjEAx6%0+LQL@^(fA+hyC0{=DSy=Sy+tmzc!`dVXWWPri7<&x&t-KurcWgWb=m z2j{ASwutMQh=9fX>(l5d4ZUuv$Nh6lCHU1@(r`7l#1D#ET9EaY)Hg>`1zC&P8}SH? zm0&bWsaeyqQBclxX-T3_`)8$P4cp)0gA2N&sfVwxC*u?f;$4stV<6E6U5%N$VJg=gS0-k=P!0q3_X*yXX(z? zu_f>)?JdF|x}(`&Z*&`t${G4t*@ZknQA|{d`}@tBT5{6*T;)oTDLeJ->3@=l9qpH9 zq3bPL$cE89dt-f*3$+bf?{oe3_ju%Jw)+{Nuh0=cR8UVv$gr!;^E()l^Qk{FTSC&J z-x9dA=bic19{p0{A1<9H)~QE1k1|}cX7Z#RD1Lf;y6B3BN(3ZHoox8uRki`z>cO){ zixvsN2N861lXvm`{0LPJ?Qc4q&!}Q44zhoT?+VX+2ml+*=RK&k7?G%gYF_SPs@NX* zqXjjt zi1A0O#gNq^e#yakwAifGa}KADPUKf5-ZiAE#qly^i{P8OT^cd4Iytzw?n=LHwPMEX z=rD&Q%H>bx{%0~yu28>FZFMIHL73^lccd*Baa7+0&{(-J$B;F8WUD?RyDKQZPB6Nn z)vc&i7|Kb{9!5c1BpVLT@>oXO@V~ufY1nF^&AfPhkb!NN@hC0oJCE9bF2z$MY-a`f zg(cof(&0~D4BYi1u}qlSHt6e1LkQ=Dtk#MAxV3+WLnH8B-5rp-zqcSkNUFGoz58(X zw&SoP)J@Dd$cNp(ijJ`^(y-~zPxE~EbLdmkdcTrO2KV6DPu~})#q3BbVYKAuT#+f|?O_|zvzUcjMS>MroMwC<@Ib|SYM6wx&PL#Jf1_T$eE z$_oQGIxXLe)&hfskFmGMzNXzmR6<{jt)K4BsN_tX7eDp2^AGg43fI)$;-RRAe;+W# zFQP7-PAV<9vcPQ` zlXRvIkgH-sB3*{ub;Tj3@>(E^^Ih~>nX|_*>je+m-)40K>stAlmD>|L>7BenwkyL? ze6#%yFKFxdP*~>#B1tF!i*aiSjSG2;d?KXEjmvFV@?v*4;A!D~ATYXF-28WAJQvCj zHl#kwh(vIjORlBb9;?=(TYAxG6fsj0%S(R+MG$~*xCMg5s1tXGVvg7EicJGv!JYt! z*U%W+L80m#R5;jbvvy-&rvdD16N*>>dfo$y%vvmIs;a_cZ(zv$sfN{_adW(=o(QLJHd7-vD zjQ)Cu%Dz9tF}i+9P6hi0PrBpOxQt!Dn77-kBPzz_%$dT0U-(lGvUxV89PDpc8_d(+ zH(48u0nH>mfI-KS?+)VL%e^!882Wv?fextznuK%d0s1rUgE2IH3}}5_TNL_$8L`Q^ zKa%9VuwvWsoRdcTHyPep8R74QUj_Wl|39m&$AwpG_G-NjxfS1XXR~YCmYx?>kU>_8 ztF`aU{wx2P^!5$Z+}cnQbwHQxRoW&Cj`md3p%7fTpKibPpj93<(WbSb8jX__gv*NZ#>y$D}E(DEM4A5|^P*98oJQtwb6l5(s z6y(;;BYkuR$Di~cLmqg(l!fnyK}A2U+^L#}yU^_dRri&GoN4xcQt}sWa>=7_*Yn8L zW^PsSnQZ8T%r7GvZYrHt#!MDMDzSLoE|H%~sU21^v$TI6yLkcQ1*$@N{QaXhqhnExvJ(5WoiP~@v z`i%eC;|q%^U8#$9T`e`;DOq)i}{dkSKhLQ*L%V=ft!|~ z9p-+YIv%}SR2E)icWt>gkgFC8+WZ8N{mgnJ6{GEwC3?%f4`sL}OE zRXSpFJ}t#$;Vey&IfsOv7%BqL1i6ue=J5i-D~)(Bkg|A?Bi*Yinf^r#W+?n1xw@lj z#|FSX=FGwX(Q8L42Y4iU()yhdmuUScEMg|>tT}OOj^0ykRYIY8|0VcQkkF-{vOTGb zE#IkgcYUf^1RF2hicmFIu5iUEjw0)?dw>s+fAjT>BTa+k@%pOX7BPUENZCwoFUCya>WCF}dzI`x+NG_X+NLdOSJ!7w{lJa_hiFq_ZU zZ`zBo51jn>UvbS=U8aF;`~OOwlq1oT$By^SHq&bkC*3HfkY3+m_BjEnDN2=S^n;BL z?25AVS;G^a&?K?FjT!idTjHfV5+Q|ORywZ_rc*9zB)BKTR}0(fxR=dnB2*Bh>Qfd7 zn~RhTMREXX+>GE$0pqTeL<-9fnO0Nevcq!n-x*YD}^>+rJJZ zPtz=ksd@d1z6&|qy&g_Xg5FWPJXY|;cHFIc?(Dio9dEv?sg>;u!#e?z{T0O%jw_87 z0{}8>`lZ|8FS(tZkiKqQsse$O& zWp^A?gEw1jF~1G7oawyab%d7b)uKu=Skcyf90*oiJC2N}S2}W0%IKbD_eOZe06Ce+ z<%4qs2R;2I%!V1)fso!hi?IxbQ=Us3geAkOx-r!d@Toup%LZVe1ZUDdT5e@dPfv?D zEL5(^Mx5T!to&1{ie{ONJg-mSdd))P)@jqy|Bx2VqrpoWWrBrEnCJcOAJ*_mFaSas zxw6XOLjjfW-+Iov;2}DebxMt_dIxptit9a4@ajeBh7je{AJ1i!l zgD6zDXvaCjz0XQZ?^~mXOx=ByPXykcQV+bq@&se_XICZwcs#aG?ff^MTCgKnssMT~ zB}nxl2^T{jg|AT30^;bgOdR&k^PWZrP0|H@QgGM`?LAA||8c1S;vpF4x#he?0ILP+{wsVpN>VW^6CGi6ji zWou>Il4}MXTi+6F=JnT&e4$_5$sJX)&X~IHCQaP4xSQf{_!*Oulevqf`;`rS6$Zl@ zquh$hk46j^M2$TaD=LhdzIR{d?+Eut9kd{_QgAruU?v5v!caX zy@@~KDEInsl<Q_v{qFSMo6WnZjzTUw0o&B~n+2@fccWA4dNO=NEk`%mv( zZS`=_z=hWMZXlse?}+BLALzfC$HRs`T1LB57fZrK|Hmgjf+Mp6kY`uUxl>K7yZj-@ zvn$sppT8QYb3ewKvh{~t;ZaXveC@ws={gI#LWXdI*GMSoye3=EL>s;1a3-apzd6E) zw>_7?#pqfZ9d^Rpx4%B(hGnz=+@4HU2&jH2{y8Y!x}B>Q%c3rkvW!&%dG$O5siPw< zq|HEz!81KE;|UvQk1pOFZ#tq5{-A7E0cJCe_T7Gay)SBd#57O7LxGX}20{3B2j%jy z0-T$~51o!3ZpB3k&1;rg+AXL-x|bg{hV(>d%>FN zr;aV!Kp3))5M_rWae55>YV^%Q$`V$r*ro(sc66l!5@%1ycrgXfCeWgEH!%^|89ok{4NByz z=%LnUHge%;b@yht;;R4bwNgftVV0uxyJYD4Pac@t3V?Ei#m^V}S_dwcNXDvPly{)8 zS!qm?YOQTjZVKbe?!L!1{iWq~k;$)=fRcL??dctc`r}*hs;4_zzwc5DaHT#HkVQ1o z?ucZ$JFXnCa=|Zm!q}Zh!Ij{8Q_HV@W%KK;t_(Xe=`>w-wi`CkTcbmt4Y(`Bv%u{l z2R8=VqOIP3_)}tG^`cPfQ<}Cy(cd^M#^&h_=(5{JY{28g)1b*h=3yq&$T3OmAl;os z^7JFL@nkUDXZ9$Ger}rM8n&~4alPa%bFz_^5Lq?@@zQ?pP+RI10e!dnm=ZuY;InyU zwBPak#T#P_wzERe{^w~C`Yr~YY00Sm+{EHEd8n}d(2SfSkSV>m0aEByY){6e1?RqS z&CqJnlgN>-U45oTg`X2{(h&OCs-=mP>k2KXDInz|^BXPc^*37&cT{e@Z*b=O9=RdN zo00m`nS|PFrzU#VG60MTTj#JCn2R{nz>4$g6bqYZ&rUwi<##qPGg``uy7V|}ojbTY z%m!1n`&XG)eSg1|Js+Qb`N;M2hPDn7B|GtWxWGm7(8{s2aXX=^08Q2UTx73prP}T-3k!J?e!`1mIJFK8;Zk zb9tR*=AaLk7!Xth9ENFv9s09-ai~Dc#%n>w{?jD_%Z}3O2ERi{z>c8Z24qz04>uJ& ze8{3?VP@YgP<#zLR$jl5LtZamb2BEhdB-s69xKfO=YhW1spPyjpE7nXFT$X?efTG8 z`n}Qjy&?JK$?Zlz#Ejb~dDQQ$W+Q%Rm)3pU+S*!EF{qC-Bvwoyn|M_#Yg+zjc}!eE z6?~O+W!fq3W^yY0q2*&F*GeNDQ;0E&q`2@Y*mMkZeD8IXbcL2+##*(YaDue?0aa9V zPaxamkZoa|vhyim?1@0^8{|rZ3f*iH`#vrXGQCTX`N$=PbP+hPA>#0{5x8mm z{Py-0`@$p^MLYlRzx*Xvh6t@xLEA}lg#0&1RVN#qFBJa%sOYpOf0LB$aBE%`)_I?gmzPrLO*IuD5@zA->j z=)nk*?T-|$=U8MnF${KLcZ!8$MM|3ROn^XP<|T(0SOZ`;pgKe07@q|Nk|4TdT0KLBaN67D@XT6O`8Mbcu~vlw<7;8d_077)Pows9y4*9wDq#1wzf9lmpydI#+hW3w3VgaUQOaT zj}>$9V#UEvdF8!HTD&ftN=VGz?qsIY5#B9Z^)1iyeNabslj9fETWoqxzLybFhD^rX zqA@6EQ4G3_w$C^Z_ntRdi`a&U4`7epS*4ltpgR@Q5H!(?DH+qNW6JlLMNyUb%6MME zGAXY|YFo+3$Y%1=NPxO@B!K06;(XVetdMtlr{rQw(xgXBeY$uDnwo8uc;P^C<^B4o zuQgJTOd|Fvxr zImG@u=~fH%ddcB%S;`gM9X+de*Fma!@2{p%Tp3BbdGKn&d8qpP@(W zpH~Q*f!4j#TVIyt3*%A`0ClCgScbb%6-e$ z?^|pIeE6Lze1{%7KAL=+a?+C_CXr>urSo)pm+ATcas0ov3=dyQH=gPQE^IOSO-g&N zu2r#Q1#GP&#($s4ZT9UoOx`XK3hU{>Vs+tkZO%OyDk)%h>W^@sN5~8KVt-X#n3ODD zlMcGJmv%ayY$%g(IBz&_v=2Amax?6n*%2joB2I63Jtw`hlQtrZH#|@u{EnW^(0?B9 zQkXB|U^4|XNEP35w^MzL>E77elB}fZkzHclbD*Hf)%0lIMx6Ev$M6tNQ);(pv!MNw zr#JUHoM-sw4R0xU4c+O&`j0@U^21<&bYLdJg7syq-8rEyTo*U6iF*n2#CnkAr&6}% z^_&X_^%)`Si_upb^Ix1%T@pj%Pgti9Ny2=POzi_JMvtt zLErFv6+|i_=x}uMVR|OW=xP$+ z0r_3s6&F}}_dY#k19i}IPgrz}r>TMP@p0SCa?5bEMf6mPk)G^eW(RoTW#r&0+|94#KDJaWbZ~!2f8GprkN&8T(YZA^Ms*U< z@dz|@70v!fQkYe_OgEyX5jj7&ho4LzoY2mP!b6_GqfwS4CUx6>wKy?eUm=P(hw zU>Wq|sddP<9Ne1kO16H5#UK>})fF*CJx^5tMtB zgh%|%!o~KEG_>Qv*{bWo2mhlJG}07Y93+A*`4W7$Hn5RQ_UPs(w6Q(fYelMecv-%nP;O!$PwYdHXpCrDFB3sPfgqqz}az&6DjhFj1J!b!nX)Z&Yg)7d$ChPGf z|JZ{7AAQZiNBVB6o?rt>4h?U)9lyP7j@e7|5$)@mYKGwAXmo(LA+j!()!*%_}F4`GFST6B4{VJULw%Q-ZcfPivq;_Myst+;8ghp7Tc~8R(3X=eTKOedE9`@3)qp?;7+fR;pV?So3od zKbh7ztvI6*(sbLQsD@)+?;jBa@b0|O&>-Lj%mKbxy#YgJr9vNbHlx^Rwix^{h-e` zMgs9=fTvCupq;)Y!KV0FuFrc1ID|zmqixqqZn^O`Cct|?m3C`vh8+slr);+KJ%65m z6LNZ1!KBy|{?PTmag$#gU38%`_3-y)l^1-l?#~#-#NvOw5wj03hQ0ras^^_!)u%FZ zp^o;h{SCz2__FeDA1QQy(MJ32Kr=HRsM`;~2JlE0!w}+Lu3t#OGMK5X2p)f$9^q3y zg9|iWxJQ)xWoOw-cGV<-AO;3PMrK(qxv3;?R|8@4Ft%bI)sIT=QBD6`MqrkGm=f#e zXr~*p;^Q<$q1 zm6ch7%*Q`~*v;5(4$CX0UlG+lyMBmc!xedAPGGxy(jf&noPv47TAQEPQA!RIx_QP0 z1={(n_d1~Jhi)g!w`EFeB5#+jPxkZFHmd#;7|OqyEFf`T0~pHW#b(zX?)YbOQKZe= z9?t`yea`9aa8WZz3%COGx8{&HO%T2eyq~f=pce{BHW>lm7M@My{jdAmbNVw8aN`9J z1(ugshi|$&V0WSKan>?J&*R%&Zq8N$%&I%wSLk`{LwaUBKO=a`Lwiy(Ym+%PT^)+bj+rY3E-BVchj$ikwd0b z?J2Ktzf{Nk7s?md&dp-uQyhm>sQ~cO=Mx6=QW#EZLp=QvC|>+;nh)Hgz5Dg|UzJTp zMec_o!WI*y?6WJfNM*hi<)!+K)Ip-;+eV%JN~qcgNkdSfYN$5zXzzkPh_Jl{EdB<*N&8TWHM9Fz}$ zA(?~<5gG4YlHara@(+^y0-Pc6Bq3U*;$%`h1|=D_t-|t*fBXA8%UYB!zvOXB6TcSC z$Qov{uFJbcdKAm?Z9;XvR$^y+J#ADv^l9f@(gz3?!^-H>+PYSx=nJm1_YLqyY*}-> zW#jhorZT_@&X?suVtdG)Rs5BICh0Bt&<#eq;In{e1=-7j3ezyKTDSc<{!2}y^ z^PS=TjW-<~r8<8a7lxfB@RnDqIep*};r&&g@JR}=Fv%Emk+)cy?g>MmQDQd?ujkvi zch8KTK-aseU-N(9n_gX-z5IHBMJ2)p!u`Pd(~Ui!SV3cNoA4u?W>Bf-j$hheSpY@% zI;}1@6v4MOXhk6F^+vG<%Y2v}o+EG(c@kkn%Mf)CAU)oeEB>xclsIdA9Ss`-BtT!> zOB%Pas|NJcds{6}+9=Kt8UCW(@wHh4UD~k1BdlHduH)Mj$Ceuu5o9?3vYwt>G6`ye z!dOO??}Os;Q?rmtLrPlhE(emXbmEw#NB|tZB@Y2jzcDI?_TyjP)FSWqYTjs~LYmk) zl9jx8p-^H=ABWAMn5BG0A0{Naa2C)Jms!|)!6t{{YwXgT~ zeSGw6yRqBzP!T>aaa&J&>n|P25Hj;s{g8EYOgmXJTdw_sO46(vSz&k4S5U@6%$YMf zc|m>Y-#;<{t4fRlFJQzQemkyl{s+K7StNo_k*x7+TqSCwDQ4Okbk59Fa<^aD%?Gpc z<5!oMm)_(2(2U&dXF9b=izkrD!a(~H6i;H-3X_>ILcm$7$Po84o>zDq@4~0NYyBP6 zVv%`_?(#_e*2N#+kiVe+OSepD4f*~9^@l8`WkHTSk|QA@mft;;vkUG%FuK5UW9x5t zfCF&Kh4f=8uh|^PikHpa8cRmgj|ZSkto)VfXGeT*DZUXX8kHo1DcHmas)l{^;`snj zK!IeQEl#b8DY^6I((_PJ{Ll6JN~(?Y`g>>Pt@B9*OpZ~Wf0bG=i_?R1pWp3Ie%=G) zM{DB@XJ0y5Z`zm0u0+0B$MwopaWbo~z3qAOE0>e`I#SmZjKu2zcy$%6K<4@3E^yDq zNkgJnvbw7(Uq8%viNd*JkR4tB)#SyXFFQq8&@uv#6j6NxwDhy^TCjTJ_nvcS{`&dQ zOlnZCVYE|L^7!du6npOR&nj-VR@z*K>&URDPgc7xLFM@9S)R*)Ay=a2NvV5K%NbM`u_7dP*b-2=o(hti(ZaIHmp?N-NYq-Cmbd>8?JY0*opo5xR zCp?Q$lRrE&@2U(8q)jOws0|*b^a|9nafZ_Ld`%koTH&hwbgv&rVvYWl^K!=#6mips zQ!hl{_`s5*_hZe@=eF4|CQSa^_uUWGGJ8XY7KRY7$FDj))u(I$%TH2yg^UgTw~Z1F z4*UMlf!UYqFAy(DsfpP7M3ngs#9R;Lj?nXAzB428dNf&|#>@Qk>RvB(9(n!rxXbjH zf6G<0#@MRn0FgR`11aaX`noV{*AbbRTYDKRm7$b(Ak~G8igyv6QB25N6fVf4UiN~Q z-HDg3X+1tRSMP0?77v@5CKyq^_`RuMWqNLmOc@@hkBW_sQy)dotT8!gKi`-5E}PEM9G8PT_och>Ej%~>Tr77>2&yhid=9ZuxS zdrAIU9+?n7M)J`ZITTa0;iS*H_BGJgOq;fFPd(z#_0B!U?-t^Jp(8geX@aA|vN*hz zhKc0msGRNI8>o99@{_;*yf?eze&=1QQi^`usOSsCQY4~MAB85tr`NM#JG&NrAvDjF z{btD;y7Zjy7oW1f$MHP~IiJXWgpx`v@l#suC?v4d-q+>p5Ao)&^Y+yKtO-bdKIG&& z{JHAw7bWwYGPGHz-R!XYi-XV6wAZVr7*SGFRl0TGR>+3}kQ1wB z8miI1JIcN3HnccvCPga1sg3&vyzGrHq{n+Jzy`wU+b_6(tN$(TlfSQS=IX}dd`Q`O zI4HudYru7n#NY9bFhuR+e|ByMU)J%{DYCuJguqU+?pNIqKA4B>2u5`lH$v^l@W{yc zNl%?D&r-$ipiID$Ziq$#<_jfE3%knocEKWYP#X&=sZ}Ov)EG7-?19ggbS|V)&Ouil z&K~FRdKsBd`Vmd+_ZC|+8Hd$V4pU01Q^ zqzpZ1?MjE)r~G|aw@&+$?sizrP_B7wj^fL&D}BnAnrBKf z%XP3SFsU_#s-pca3E9s3GP39+buF>kA1oznpO}S}on$+LIwOoh~`dx6j*%Hw^Ox1hh!m|G)?&uwA(~M?sW#V5cf>RI@p<31gu-!bCe) z8FC%ElLB9qXH9NwN5yu4cR2tYt_^Q7eSnG}>ixOT3QQvyba{p0S;3}^;JUwgMR3}2 zfDuVJQtQdzb)f%eH~O!B=u17G-5pM4TsIk%)zHr8NZLUBLdh(>R5GcKaFcGj!}w}7 zWJGSUs=cJ+yBs701D7-GWrG2^fT0UE+zqdcd|{k3s|KT-V@NGcCW(tS<>akC$<}mO zl+PNCKT3@`o!~6mrn4K_^M@tl9C9g!;a99V(wjf*Sb(iHTPNRJK9H)m_)V|9TAGB9 zW(v~OJ!gdxSMeAcoP=vg^E{H$)PDK0W|25cfBbSYmfC6E<2XRgr1PZeDaiZbFmudd z@WSx8jac9<)jeW+Jv)sh;333&1p|?`B>Ih~f^v`u@hs{#@ze7+)bHk5Ef#vvo+q~G zy88ho1u+P;7#Gn$`uj2zAY)GFX@V3DV1l%mmw2d%cryBFLL8q0Ud$6sJmohB4V*0M z^Ts0>UnR&M@HS0FAQ@>1;bCD}9DBk7R-9!Vq1Ufuq@JTgDUlox7Ro%mJKlYpHxOJ; zJZt@+2_NgQvn^yW&KHaTnGEyTKXi3^Rx92QmN8B4_A?Bl(f=4GiqDnled`i%+Oa@< ziC@dkm2)xw;MTt$a@6rRNDBYHyGvbT(%2pLYtkM8keQs{8E0Z^4ci_%VFyl*ED7TI z;krx(hDrRzBX}Adp>DW6VkB4vy-E*$5`q~Q;=6v~E5z0Qn`ls$CH!a()l;e-MaXmD z3UkPr+tE~7q+C-fy54Y2Zi_njwd1}UadbNJ_Pz=mQK`(jTwthtZg5_o%eZMk+u}|T z;*scF>ZVg{zPXPYi|z8veup@>F8s7-ga)&~J*j-GZ23Gjw?C7T>A2nuJulqPi%j_g zMPV)c+|kBT(oTjw!P`xD)bxtdBT+~@TNCp&wP~Pcq|CT)x}PR`wGw(9wfUv+R}#tY zOz#qhztY)-B17o9r{#|l_3|M2>m-zcn6#X`n*{kSUL7=!()(ve?OBIDN0+YNR0gq^ z;L5)%lT9Cya)uH@lzoOJwcd>mAyDB|b)H&PPJfZe?x8JnAt2+5bkQ?yWW?p_js&&e z<&?AiYC~CSXO-F4K@D5^NB1x=~c4{sGF{^3&TGWM^-%99l>Qd>) zwXP2TFY|L`=f;MmDyWDG4Rfg?zkjLV{ac%~OtYuRL)ZG3FvR6=kI!-wdSj#K2A?-% z)+xlpN=*x{5a!Qa(8e}}r^n>G*tc)GYgkTR#!Opf-Djz54n7rc(zSz}BKz3evxXB@ zcieHhyU;fsddoJAZw+n}95;>5eu875UFvKytV3^(G+qjYi0E;#>)oes(~h0nVClC> z;}v$i2?GBlt4^#6F*;)7DjV31)UB(Pp`QPs4N#Id{eURfAS%u3-!qAu_*of8c%S*E^VR?`~&VnD82!_Jxu(-pq z&1e<)N@d1>hOw@%%#3yp4=u6+gxC=AOnO+?bLORkuPThK5G=FU<6gdG=h;N4H($@b z83ES?Y_@=Cafyn~9AZRHQa}68KhPLGNKTkmhfe<~pKR>(XK>H;6huS#>q%grpRG1z zeuT0X>NFT2WUN6O-u9xNyu5Uvfb?=HnWy+BGb2aIILsSd9aezqr3}j6=8z3yS?y0X zD*HkZ$_f&`hn~?i>Xmz|IOcl2guJ?(5D#~Be|5Rh>C3M9@dbwCCiwOm=^oj*8m~oB zFa?Zzzi`7*GhU=-EiKQ({S}T(=Y8r)LqU96jMI^aw*Gp}mIO(|5c7?ZMD2qQ8lG^M z6htC}-Vg^?_zkwbyP5s^s6J7i-T1<9D9lzhd#qz|kRP65q_p{tox!*lJ4>a)CC%({ z2xnAAa@9QMp_gUMXdT)($v}4Yg0=Az?ncY@>lsgY!{;uFgt6!eYF!`P$@lzqus$fQ z|K zVw(bxYfaHKaMzaGlx4}W+%Z}ZMPq}%(Rs60GWfpHDpyv>Bs=^k9?TM#S5i{oAS{;d z?Qdyo&`-YpgJFW&_rp6o={4g3ZkedNpl6m!NiqYERgv>4-+?@e zRBD5>slYlrNPEwL1UItd0W!ikW@`%t3+8#Zt~|>jOz%X9ilF>ec8`jQB=^~Hd3~V^ z4`x)KZQf8;SNYT#hr^a3VH&T-j>by+OJXvc%VR8ZiEo+J4RzKetFyV25ZbP>sm^uy z+a86GLvi9|ygm!c3x)?7<1bB=*7AqPTN^XXP{yU`^y9C<$Tw)}h3l!7ujrm^}Hin-80#*Eb~YB$tHgV|za2@3>KvU}Oswhy_Zb z89sX-61r<9LL)iJj+-Gkh0;He16yCLyovbjyzF%%nk`Mt-1ySZusgdQF74TUX-goF z#9*fJ8B>2vRG(!-8Jf6%zED5qk{C{=@z92Qf6$#VmG|-fsarAtBRWH@o+k)AhOOu% z@*9Qr(#-J5Og@Sw2(0%#BUN8q$-J7nG#)u;GZyEDX;VGLN5MHM>}8cJ1}aD3+pLMg zOqS$(5g}1Fq%3*!u(tZz8SEPfEL2 zp}b1C{*X^kVf(8Ng~*^mv>FJs+`u{OaV$Y|lkU+j3HS`;jXQAK(dG@E!btV~(mvzo zcBNc*_(XPJRJre6MhsF!O?+NC_BKijR!_4yZ+oDH{H&|Orj?lt{K@+&2yF5(!2TJ# zX85TaGz8T=;0(D~Zcigj4G)YUqzGI)TIc6smi@&p!xkIwNZ2~vw0gHnB}FB}%?f%{ z(pr_L~_#Wu{}Bz*q@wQiy7dn{zz2j>Z-oCj~Ax@+P#C96H^HB-sm0_+~TdY z`|ezRYE(5*uoX8IPEgn{!{aEL{j+5m~d)l>8 z-gYedIb83;>N6)?|B}3(+hr4&=&_&SlQYLS{$%}5MwXyU11)esP-4SOt}#Hb^QCfO@>uQ1Z3`o)(Rls!m#VmU>!zhua{@L&*eGZc8wwWQ^v zQojV`Opj`!fd4x-a*Z+$S+#2ta^j(519A@b<%T^PhYU>1(W`w% zMW)7l6*?f$`vt#qc1XA91^>!fd*-{N|Hso;|26sk;XWJP-K`)EA|ahB9fEX;bV*CQ zfgm6vND4?vN=kQ!BHdjI3?w#sz}U_{=XJhkf5Ef+hxdKe`?}prv#Y7*s7$vgL<%KQ;t5ArLW`H^MW8XWDeH&LM58ovt+~8q%iS#9!q@iyp|l@Vof$>+NJ|kHq@UfC8X|GHAHi1D1NUY5X7tVHB>K|u z&<1?x+0hciQ^7Rv3a0&SBHo;hT$F{B3(W*%lfZAlab6RNb>ODlY zS(84HkS~PSXR52EDak5vqH1gu4NT$dY0YZwmrKi+(S*`xTZ=m-t2^$L_j3G;0i+-;$Yq^W5u@AdxpELR%I9(wP4E(I(N2@xwgcYnG`wW20eC zCtfMbIFJ}d|L#W9Z!(VKky7_TsnghC_--XUoNENfU@o!s#+h?*YiswsA%5vViV2!= zPO`s39b)9CrvjxV%tmVfiW3!ZbilP0i^frfsY?W0XW71!j>r z2A;@oLH1EOJDZ=K5Y(QPR?63(zj2*sgidZKpe@giMDI$?H?Gvq-Z#UY)+G-*omS;G zUWAY?M&P2$cx?P<$4sYKKeO8cYeeH;XOQ%J^d>SH6t zYYSU>E&Pbz>7oBLdO?KZGT~FC%XRo8-kfB>*LRR0nHQi<9P}hF4o0%DHiwG&sho_s zx6o>QsX<7|7s%i}=5gI3yq8RH`tP`t7U=NEAG^wkD)HCF>ImBhuCZ#$hOYNz5c;pt zUFJFy*<+wQaL-Pzx{L7)EQ}i`ByIHz*|N{nbrWxss- z6$tZmlrFq&8m9QdiTLj8WLx_qtP3`6siU^MtO^+or{g_$Rr8-A;1x#)*#s;G z-fYhh9;=HdS9d4+-5iNW)QBq|Sp_d1f-ls-TBM9}71gBu>fUz%${^evHRbn$c_(kM(*vgjzVPhhAjN%+zJ$aXKZ!go}b)-up-%Q~#%qqRi$lz3AbS+kGhxGV@ zCS(_s1M|1-4BMwWbLN61i}f@cPM8?nS#xKynGBupb%q)U!q;a#DD?rrhWxi`Ip0zKn4;tKcVus z5s!5%!ZIIThfR(w+_}r97wo?(@vBELUyJ?P-Ucl&15`_*)~C9*>HF_^q?Mxq&-w6$ zx4*6eOqkWFxe;C>h>5Fa!8eNMtgbJNffZVf&EL+?b1qNb-9r4&6b3V=KM?~&$omvX zg{kVz>*c4W6z8@VYtJO2gFCotxNfa?$^*;RyOMl<6;Y`w#pnO(E0vD;(wRMs4 zvx#0Dh(M2>sMco6U7EOI_Sa`ORO9lT6f`Mwd|J{77+vgdghXeyv;iGK3xV>+DM)9p zLSQD77gTAKPjcPy^}`c#^x5r(aqxYMV2$yPEO~$kW!hLSTgD+$_BCkH&S||MCPe+m zc9#2xhrq|2ldjnF-{Oh_(%0%%dGS_Y9?G#c@8 z=XM7kd+u7=T8n@ne9md+b$d78u58*rM&Mwse+nPu`TdJXkvzz4@(DVj9aF{g%wS-9 z!<@XL2B5#4lK$+gzjGG|bm=|!l7zT7J4tuMh&j@K2EQ#5@bc8t=)he@r+ zw{I06--LA=xrdh=7HXO1{h&&g|D`SpStrDUjwD!+JW3kp@CXXK?u)K>Rxy-CHX4GI zt42lFj|Z=8J00#1{a)Y$_87goLv?jgVHLf8wt*Wxfz@{tvZfe$ozZ`nTJSDS^{&v( zBs_K6@Tlt6x+B-fzBOzRaR=v6b$hr_1@AXG|vuRIuo~6i~}KG`>&Z>-5nsqIXv(YCG58=%^%5aFovEdq&GXs zFYZw)BV~QNJNg6`V7nG*r`x9_oU1K{(X}q$Q8K|ps0%dF;wkeXVJsPzl&XywRR~kQ zb7q~GU$Au?*V8pmW1*>V_n36Q3W)EOXg9ABMb!>KXTuSCSeq|5wj(QK-mcc8P#a0Y&Xh1VwB5 zQiz%pp(roc?)&`5h{}xePFKMgJR&Vm@fAxxE@Lo8SQA7EMu+#KJ@z9jZkvmmU!T#$w z{-jY@3=JQqff%q{rzAqbLYG3{ zVRA6ydANg?tCfgo2-q@S)~ zvQ=oAZw9VipyNhzBQ1)zqT9>42$s_WV!vg0byxb@zeq}y?adfZIUYuF!QrTQ;#!ml zgtBDBR4H%-iBRXk0E*0)Y1EGv_E%+>D3)Mb8!%hRD>T~j^{>^V>;fQ32t94FtNh$v zA?9rs6~6M%Bpe9RpMF6fwpdd3=Wp>9S(K7c*HPcM7cUDhWyTtnwf&~xuCzeY$7`v! zM&+eIb?oc%bYVj87X_g+XjgD8cBUWHeO9wP3YnYK?S&-cg|TW+tYqR?;d}D~Mc(qh zR~2-7tSG<1$NK+RzIiw0BwD@phGIxXC)+b8Y{8hLj&FX;%tR<3WWPBGYtBV_7iis3 z<+@#S&}CSZ)E=i$dpEeN~1~3pb@3dU!*s;m397b2x$C`765@P=-qD) zr<~4q$DTHRO7g%+L(M8)zGXMFPQX+CKy8mrgNaXqdajh}US1ZOIn{h72FOvaG8L?-g%=52f6u^r;72SL z#0EI1;&2e;ud2w^o2eHg2<#t)Wi?JLBp>CHAB`A4X1%6j&V4B--Tii$P3;SxkonyV zo*b)m&l;w9G0?T9Pm!+U9Ln=bQ)iow02A~qSU$r2Socs+sylY<2-DCy@}c~R)YqgQ zNOsMQJ%v&SWbCNfStP{R5_LTpQvp`SLEA@>#BV0Mu1sQ|MGMS{UME1GSZiC=nf@v~ zBZ*-#p@(*tl5~$fetS#;Dh`GV~pp${T&6}22b_uxiV8as3L->M$4UbC34_jZftZ=O9wwBGqRel-bNAOtmd z>GA;a#xLkp2Vl{HtYruF4tE|eh-AsLu=7{UKq6LYe-x%Xf>-7u(=X5~jyI;b=@p;7 z5V`bh42J-XC9Ot80~*avckjQFbs=r`1P9{vZknEbxtcmGW3frW{SqWF7lrRO1nuZK zdo&E{GS{4ti^HkjhFH`v9bLb}66MLrahZ|U)Qf`7GtW&9xZ)vi1cuyVunIu;8)L~Z zW?8qyTvqY1+IkE=>K2!w3JgGrs-EYHgpYkJk;>@<5u+oVELW9Q;7p5+d^jzkHcB<$ zwbKf-vR1!lWJ&KXo(&g2#FcePzhFg`^?i*Jn>FU0f=-l2@ZzBriyO6_&c!(m`XFFHF?_IRa)vl^O5_Ot^C@`T%J)i{-u+X4gDe`s>lEb+hPvtc^b(oQwc zjX$BO$2Cb^YQ5b>LQ~uvL!=n~bN%)Z!*;zb=-7?u*Msp}J-=LEM9uRqez<~s;ta`aD>401P=;PmWDDdB;lT{)S#nHp9-yB=il)pU;n z?g5q*OdQAzav@FcE1o@73&antjJNGIi%)F%2&9p=Sb^MmFdf=ZM4G=*!a3l0-Q3#h zK4gbeC`OETqmVuFu$2Gu@Gr_ibFafBsoWRfbiwzk#{Sqx_e4)@3U!F6Xj>5heQBjm zAD~Gbz%ENhl@!7j*&80ZmX`SKG_L*1@|ro>;6qdK&iZdYG2D1a!&8M9pR2b52_X`9 zDl97@)0N_AG9r}RrT3UDB0TXE+r;nf)0&)}*=i$cXrRO@9#&t-TtWX^S26hd>M3hL zoI)NL8j0yI9g9EB`s+*xtJn9*>VcG2>fq&6;iEO~nF3J`lP09~-*ss>Hw4RW2;Z3R zdG<-JwaCkTpR`pe9JBt#0K2XZ)C)e{iXdp#-Q0Mby&uf$1Biy1Eqs}zmd9@3rdXYZ zgT+sQ6(58Wrp58jPfy{8h6llseFFx;ZVEtE6L-l@h(Jx(r~H5xfa%*=qzgTce=0(f z>&h>mmuw-o(W*RTUH+6utf(4*Gtre8r>@3*|Jw~mebSiWm*Eu z83ny2%(tm7xLxI1xZ;p-=ibF76c80UnJ@U_FF;3)gm0&5xp*;bit+;zjp8c-%(R=f zRb&kwR$>n(GPJR>swb1?elwee=7F8q>qF!%mN;%C>Th;io4S1ye2?SOimOeZUN=CZ zL{Ng4l>X?}9T&Xl#qzky7TXFIF+L$|<~Wp)X^>P2a;iQVJai@Nmj5u!hgG3Hbm-{5 zZdgSBvbCVe>)|A-gC!9q6+PvH_|fvJt=gC`dUH^{cNv(kz4mq z+}zb>C2{b|`M?!RbfVKf2pX%-XN>rc`$=%}?Zo*s+K@_EfvSeH6pE=LqjgrF2R0Vz zz={>rTrNQnRu)sBh4!2lRddW%DOb^*L$wl5ot**Gk3+peXhlCKWZiEv9!xq&67Ze~gHMFqf zk9@SVk#j=Sf_K7s*h%k*uIYl3wO*}0IjT3_A+NK9cB%M`d92(wq3L4V@Mar6OVp35 z#(rT?i>mo1g@jmQl&TmV^@rYTyR&ccWWt1GKc}`YfvnG0w_ToEO8j#>#AlR((3V}c zRDRWOb)RGoXbR$R6v+(``i^U6a2XVP#z+M&rmNl}Q-W2~nE$@vXtx)SHNAyZ@sq zvq4G9Kl3&C3R9n_W2rDW70vQ&k>gMGKl6Sgk zOp2n1-VmTQq>KS&qI;%H#c#67ZNGP0fs}=jHU;4IDT}=`w5yBqyb=K>e1yQwjUvWu zTf!$-;x9z&a^`h~#1u_Gh&?9JIYqZ@qFapL!7I2&$!wB@lX0d7uk6GO#&;E#9gyOD ztKPG3_#rlhzYoOl>_C8g~*o5(G5ZnGh)a|z|9zE|czjnpJmA5jqG6(~ zR{xBsC_cpPRCA1n9Z8rTPuETASp$VDsZ?O-QHbvm9yz*EKJ}*I@b-|j7tT$gd{kNd zCT9S9m5vPDo{1Sfeaz<@e2mQ;@SIEru)D~oVCPU#nx@MU?Pel(VEKPg%n$U$`tjpg z)d@V-A8l0~$Ebk4XI4)yK=IKO2k)Ki3q_s}-)s~RsNZYf1Ol(GppY>A;?=(%muBd~ zwLQ;Gb#Df8aHI%833kA{dGHSLjuwrPsYL1UG`1CXlKtq1ej4Df81yVGY-shA z%hkpyCy{Ej$dleMQp8piad+>9a713?|FQtMH`~TU+y@dT0+Kdj4B_!6rYSRh3$;BN=xhHH(tEjl#wE zCpS*fh`+QC^fyj^s(jdT$8E5WMtGry2n~$Xy!eH8-S-{SInYEWfvyMFHOJx{Hfro3N6w=XxO+H&*5F zJjkKJS5aRsWcCSPbb2Lbc}A)60GC|D1d=N)DFG*0ksFlpAjzBh1?SpwsSeNSf$&yp zss2UQ+QV<}^3bWRn4JG7n7x@LlpVw2#6|smBJJiCpLwVDfRciS1nJdK?0ni#s4hIf zT*m*`>(^*;)6oE^d}rgz(`de`&zL2xVDCj+%ypCHtJjQ=!crUjCke3r#xH z0{hzFPcEthe74HLQ*(2YKC3Jnxm9=kw~!(Tyt8D|Axc8bHV1F*S@WBt_FWCb0Y67u z+>L*<&Wcat-mooQvZL5KPX+KXucil9;5IXf&gPH61EV$}rEAhX9YCnb75O~(nbVG6 zXJ9>iJz!@Xks~9v<&wO>bn&P68i;1!755u^UWa<@gZD1o$;G7%u&HQkAB(EB&(=^0 z$ot2Q2c#!CIhuVTdNb}7?dR75ELA3^qqDmBC zCyA`)UIl%_oo@&NOkck&H+!Z*VU{b?<{x1GA5HN~Gb~G7>bf79 zV=|%Ug=Iv~DQg=SFo&zs6_nd&**ww>5xbL(@K$X>f&(cES)uqnOYvpcxtwEfw{R`5 z!rvk|s7e!*pHn+q%_gHv=K3P|0TWI`uw0mpCbyVd!nfWxyIULfa8ePRp6>6|6t@ck z6xgeLTC3!k2|$97fj2cqw15agl-sG#b@)|uHJnXoRXx+2+2i8RnXYi03}MX^vovxh zJvm{`G_2jwMbj*Fv-Qo^XAKYE%3-UyE%0Ie>yuN9YJ2h$x?_iH{&+Nkq_R(+_S5Jdw&-i5cBw50$ip)4 z@Bp?{Bh~va-9Xhps_P$;Zgf%gxagvp?s2O}OBi&eTi>S3lw0O?iMh|yHf>2{fG~#C zPgdFn;HM7XvFgvG(C?mdKG3#XrOQQo$rh6~@o-q7gT^Ly$ty30CAn#wvh7P~>L+!L zph|KUt<7#C!7t!ZJ^!=7+A^B|dkzYJ|3+(5y~Wb5e)VD%*=Fv0K#f+5cgsCzdnnuh zPq_`5?M+#vpSlZlAKE;2OI&ZPXWdz2~r8HecKp0tTK(FhQZ=xs|3O%qO zV3F9atw!jjF^p!C_TO9fo|Lwdf?V9i2&p&n1JOB7376(mlJi!Iisp^6b0gK^1P1-y)YGUR=h z?_emegWJ7|syyEah^L8th>R-!6k`G#LoD-WL?VZBtG26>!wm!D?0?|{%HyO~=i|Dw zy$I2q8KrWzU7@KY=^mM$-p51NMi22+8=kq>)Q0$Os--z9XX{QvouvdrUaN@aeAFus z)?2GoOJRO)BS#CAfb29oitG?e0s~D8gm0vq38%2ZhQ@N4ZAZKNFE*aK(bGe(K@IkGt%p7D zk1Q1p1`LOY;ZmFT2kmwxuJUSKhPeF;-gg&sLzkCyjpHTsxcHvqt@rH-I;Ui8TPu^7D$YBAUUPU8sKJmfaOxvP= z$DkthI03vX(A>B}gWU<#;+AiZL$j(DAq?en7dYQKZHbk+4cnp-BZiu7>d(Khe#$8D zvd>~HTG!0ajlp+34;9+-x|QJREc1_5ipAzEldad(YZ&_;@s)N-cF$R64)H`*M(-V~ zB?CHB3}2}WDd@9(#b=5#^{8zS5T^GYnrkgJzO`S@kWCz%wR-2rb7?mx8|ZnDyGVvFN47W9CBE+k0h^{VsjH1{pTF7J zT2TlwVPuS{rta>yCV${g?_tj$^;?Q8N$PSEhmBfCzB%?M*D5LI@>{TJk+#)`t?qsw zK%`&pW*Z)nFIL_mKp^Cq#e6U4)579J5Qe;mBr$Toh_re-XW5>tcVhkpc#=?qIb`2_ z5L7+lxI!>1&-uJzn?yb#>%RJR->AJ10S2og&;m}qJrJpS?&x0bd z7R$YKQ2D}tdszQ^DUMfaFy!Xk7UL$;vCHgt)E2NCr-b4vJe8C?a=2udn;(@sd{{qn zxne4B*eZ97-XJH2QX&!yWwA2-BTilWEij{5DQ*&O+V}crk@?RF&|7i^B44FURawTQ zHh-M9orP8(xBt*P`HTyNF6rwQ%4~w~CS700mHSbZ9<<&|GUS~#B}8jSxa^1hVq)Ks zg_l1T-6q{?a9`URq{P$4mJ3_zA9gSw-ca8&!9(r;fwgkDM0QgjGz>*r8dnqKXNQdc zc)L|=%DKz3f&VhxZ%a;D^;B$;7BKAMXn{${gsUoOh4pX@jr2mgCJrvaV@i}2(mp`m zA?eoqoImE)#35vzk9^Sqe6AI@hBJ-+tyaGId^c^!Q1P~>TNqAp%K!fLK5M}jnwtk) za^q*R5g1}qAv*GGA`~2gv^)DQ8rF;Y7sp#S1s%i~%N@J02O+l(H$<01-2L{u@3Mn` zX0N~4ayhb2kUyWyim7^oX^y~P^`=o`HAwbjuiOS~R8c2vOmHK~9ALkN=UreB)sj*f zFFBONd7xy!W-jjW>?6#D5q}Sr4cbFb(XI0w%Hl3RDX2KKwoFzE^+s1BAFcr&9;nj!2{ql(S>A6o%I-m>)C^76COK9*oWGlnpajP-WeDDyVErw!=`9wkZ0yjEN_5deY#Mjql(E=S3%(yEuflg_*(1F z%YFBMNfe53t0gXeONrs@hO2ukE0UZySmQhFD&&PmVgp0N0oLY`#HPzLYW$kklD^zE>zv>v$6fd4)nrLobduUAQ4H&=+~zrxT?&0A&#pITW4mX9 zCJqzTH>1v3-+-4pm{Bo!=tkUHhE3am6fui*U;^{%d0Sgsr48Me-Y(aaO(f-9){NS3 z0yGaOnZwtjw=WRc>V{q>-}Gre(jDPw!l|GkWBUcNwCeWcd+|5#os{vKk1)cHXs&MI z<;F=8qi!`t?#EA)mJAhDbT5OhWi0+O{`@u0TA`W*)5!k-5Pi9&<)%)Q>HMXiuyS8% zV=wHd*|#(CIrFz7rTSG*8mk1it5TH;K5r^Tb!)UwP>=pBE7(Yn{e0B>8MMcxW(4p>;b9YwmyLddkDJqLA4Gd72#w zZgt{C%0cD?{I(hwvhJ={XRO{Xy9V5EJHI&?Te7{)lJU+9ADyPjt1#c#<7h-Dz8w_A zpI-}xx$I~2Bp*tK5T;#HKpnA4mn1<1J_?3ihBABZTA_u!&Veh-Ni`FoGNy4NnNWolBWl}CI-EfT^r4STK+35p5_!+ayb0Po zdGxT(pX5+9pThO^=mZrJN^P7nKWBA6;P)XeTC&1$>itlnE~DIh<*Q6Mo%NH)G8N4W zR&V;rc4zN@dgxcGUvmJf@9f+ae7yRyr0Xx+?l&5}ShU`+c!|pmx+~Gawdk-?qh+nv zcNPdy!c;_Q8PaB4lnhY{2Y|>tvW5s{pu320FEwq+N;X{VPMItaJ~3~(k=!P!9;M%S z@RcNp&_+@+pSP&`2RXN^*TQVXFVTg@zmZ#-^aN99r__0qj4f-@?RoI`$63Nk- zj=0qUD*5--qcxEbVdYZ$lqi+4Fb&jJU%BpNaU}Oq>BsDLd!42G%Ynu{9NjV)f#xe5%=Du5S4+=i*LwFObfSwQ(^$)U->KzGAV)RC{yV5J7{FUqyQIcPxY^x*V(9?$IJ<@X&s(cFEhCX+4nW z+gjt-n3op>2bf&@A&2Ta?C>S2d#*y=Q{mM)_w{0~#tDPobymu5bR9|@ypt0h7aqK! z2pHyvt6``A=Y!uGuB*=1U;7JumloPmDGOWDBk(eo-gUnd97LmV5ky%I=T5P;g zr{p%NQFKjHbsg{5h|zYzuMK{Ms~53^A9aQ&`Nq`17`o#LlUB5vYj-$obod!~7`=|VS;$sWI{O%m_^duUGa;XH$JPhp z6v}5(Vn3TusY!UhO2WDFl5t`~ulCk4*pvzh6Zs<08#zPXCy@{;^tZ1xZb|P0w63D` zY+e0v7a7Ic^9OJ^FG6>;G8)Fi_gq`RApQfNNHKI|hY86G zj4wSSlNLnwfoNqzGCN^L6Z(Tj$GuO@*P!weNgXxZMk!{g2R@PyPUW@1i%cvfFIf3E zPU6140N)}A51DRRyBv=Cr-a^*L;Gd0E#N@ZCK!S@E!c<}Yd5t!97Mc+6|8 zL)MKfaKP?lR=-w|$4OvIJ;y=B_x1dVhgF7Nq8V0VCgYa1gLp-CO?W9!;um0Cy`cCZ z>H*jHfhL#2-{S&0x&j|MxRf_BdjWfdG(^Rx5duw3p7(_}7hEnR-)$2^|HIWASUCQI z$w~KFiO+3{%jUHvVjAWXfjDkhDkMn+oj&`DHg3q`rHx7aPR}0-KulV=wwt`6mO#_jA}ES_5;8$GqxZ1vNz!=K*T?gI zagmt05rX{SgzVaQvkjX`Gt@{B2IN^S#{s|ICd;?J%UKh_MPz&?G64M;880N5-)OAX z5U73*eeeAR&Vfknd(!@bJeH{U;_LVIcJT;IQl^LMiV1dy;-UNr&FWg-(BFm+50KTV zl=0^6x4p#-LR*SsMV_pzhlNgSX=MVqs54vz#?qn$Kh?Gh|Eq;j*O&5+ziIOA+*leS zMigj}w4o)VJIBldbW*l;VEJI#2!*&_<7S>8+|9%gjw694}|0 z@Rd1+1@l!1j2xQBrgvkK)O)3Z##^kK!*73GS(716=P9nj%a8K>5OBZS2~YuXi(7;e z)#=xOv|2WO<~?wk1s=d(oEmJO8W)^#Ohg^eVax=$oVM(#Uy+1|K3bK7$w@tTT=Dv3 z0Tvf!(q@M;?-XJYNbY?pN1fwvgA}2ubnXT>xRNfvX@0vvv06-3oSn++Sb1d-WR5Fu z_6`yNe>s+UKNxUHbDid%Lbt3y$|AF`;3UP#$xX|Mgo$|R?uz$=vK}ic4FW_j#hFwi zl%RjGOoo*5aSHxIGxrGO?f!YYo5Pu+j|U=9RxxiRha0Nt6FvQY4osiRrzsnIY%6!Z zr&9myQ=d=c&7V0lqCM99LC$DGiMP(rM-;=xU{o_&Pswc19AR)VUXeSkD{5@)!%yev znV`>?caYpYg7qbA0xV&r8xNrO*=-Q|E=uF8L0V68?tSoSGh)a12~tSiA;;$`*lVMS z&4UIQAx<}raq!-ppk}4VwsF6hLmE{{+V+I5l^XonR2m5FYhp0d+*DmscKntkeZWZW z)=^*)3>o<7svV}D-DCPkm(JuZ%f$B!j;{LMbT8W4an&@&y`cjwc$V$z&-5>tvU=N7 z)NfFTi7-B7I+JrN8yCIe5JT>a2N0m&y?ZCQRI_47jNP7{`$xj-wcgR{DNaaKP)!&U7PstDJSg&V4ICEGcw&f%rPas_^ z#OGVOqbVT2Bv?b)`!BIJh|^LZh}7frN1R&r*9(MgRB9gCrY*h;crL1sCi(&Uh;FzL zIa$J=(R5mZ)|cr!20zly<8vh9N(<;;-gw$-Jp1_Y&|9oVFu-}m&oH=^GIxuqVXHLI zf{g6%I`Z<*+ki&FsDlD7ujw92=7OpUR`ag6k7_=PGmM!LYMC%DV6%cxBM3)q%1(~+ zPhG^x04`^Hk4KA>t(r%UW~Bcj1RKT`;C4 zPfe91EVsSuRxko1wJ~_!lk@ORrOf}0q&KHPM&~xnF->WIQ>A{^vu81NCIevbv8!K3 zR)$CY9XB8)wH589l>mLVv6WFEZ12kV@-wdZBy<|5E;Hlp;cB6?vdb4BOq7Sdd`H%1 z*ZVn+(iW_KcZPk>ksRVBxmnuqd`yX^p_mcwQm~wcNNjUvQ11JZyI-%)&=bYOCHy;6 ziaKcESOJ}z+F?wS$Cw=SUB^QXq!3VS;5rvISwD+k@{3h}X!*Q0+|TA#Xn@P_cM6r^ z)?+wpPa3yb)gTEE$34t$lt_I91U;6d*l6|+8(CKh8+&|5M#+xtbfA9~7A&p}GErF61EC+jW4W{1>~hD!h|xn508Q&;IS}ehY%qu*UykcM8faQ#Nbyog z4gZw!_*|`50WDX8(evx)c3(M2W+#~z+U)k`62e*b=({;zn&5d@!t?vz<)9jt0CZqO z`^gbaYlz7sg___w7GCP7PWS-Q+QIUE(?gNtlT;j!3b|t8vyN0o$v1DkNfQt^`N~NM z;$ZUmKZeysQEy($)DYznCYTlv6UIe;n6?(bP&D;BS}S@^Iz@3S!0vMVSgSo{s9#7m7mLb>FgA-KkCe-jmH2U9z4o{N%?uz-vpv_Y60(?^4ZasY z*K<2&85Z=!x|f*?KWNbzGyghiB)IyU+9OgASpRfrzQ%OTO@D*su;Wi^_HdLe5}>U| zl&SB9y_U^##Jf994h`;mz=nrO_;fad$p0X9uKEt$U;{>acJyp+xCA8U zZf%sCuo`TRU#mzWT&jLb)h-9e!HzuU7OJVNR%YFXuBvOOp>RvXQOpN-(eFflAD#nI z9e*qE4h7_5hIQ?}cHx#X7RtOCjgpQ3jv3{}ScasVS$Ot4cs@`~xv(6dIe6Jmh52^X(Bec?rJ{#J_PPPASU zRoMl{_DZe&gRTJ^u<~jNk!E+I0Z42e01`Z=sOkN7A=b-6V-0lFJS%W>SPq=^YRGT8?Rgf0i7b}@3HN3 zKCQ>n=+Vc##u0BVt_fTwAIR-$CDv``38I6(;GHs0{MP#n+EZW;i7^@k1c*1gjRj96 zVTPGftL+p!Oov8(O}AYWWL!CW{BiEtyy%(L!ltu6ryOkcg=*{;3;9mi*xj!Qg`?mP zf0I&!{Xb&SuSsZpo}ad1dko0`9wLTcgaqxWTA|;95EiWlzGc)-+)Xb$Ro&%w37%qrwCMv0ukd!t|FQrL)?cKoRJsznX;Ese9N9bifSdrFnX(e3k=Nyd&(oU) zdr|>V{>HxUu^~12UiA-CJOw;P+HEW=-RhR-Au)Z(N3T*HLInaASCwXdBUPc^Y39P? zGlW%aj%YFA{I=YrKn-rQm;W_Aodh!;P^e$LR}l8%MO5r=-kX}|AuWf+3eB$4FUap1 z!wBHEzGF)EQy)FL`yOw;B~JtWVl{XDMm(x_g%&KN-8|bm1qPZL65e#WY!e^Pneyak zxTsV7)Wdo*U4>v_{j7b3Qq|OIRBlR7H<@712HP`2(VXd)6Vie$@yIT{jZ23}6 z*z$7_kd7v)ZY}vMgqH3!hC=oPORQ4e<0OeuD^Xk)^FNFl&{d1yAHP&Y{R^5};*z6t^ zaX)Q8V#Wu0IE2vW>f&yXGj*TmRe(|MmG=t5@f=(~X?!rNfi1$Due*@11k(u~Y>Yk9+=PYi2 zMaPe`&H9;r5w*{TmDj)Hp!ZCw8h_oo%%9F&cGd-^_TRsM-&#A*{bbN=UlIn@&K$6J z6+TryVOr%t5O{CFxHuMVM?2d4SSe9wPr?+t#FPFs+%wF?Q6QB-u~mB_?~dBqJ?U>4 z*IiCYRJwEMb1D1d1vU!mphFiqZ{@a0Bd$EsA&2#vu>GR{(8)yrT^X&$Z?G4~`WJ}nB)bad?|?74y2%F|#rLraRUpr@ z{e=ARShD#9$%w|U#|_Flzn48Kru&4IswY7a>gzHW=IQdBZ6YT$L{DS?*mdO&Fy0YY zFiHCX1FVc4xq;&ZxhVy}VAE-r!IAohh0?%sQRhhav733mX&y(4g&;B4OKmVlOi-wE zhMzUR4vG()`9y#HtBBaxJEMg^+CL@`;z~2A`b;G2XxL0EyYdwD&51SfpJhuEj_}9=PwQ+pLZ|pKBZ6PNwO>DIz_>9|f z*Po+|8n;7s(?k1HdR}+R^CcGWF~VA?N+ZHC^C7MjCg68P-8ETJJ8n0VILxDFF0J_9 z#Jj8Nv45V=&FF}MmeAUc8bK8fQz4BH_iR^x;DU<+o}^MIDRhX`;6iwx1rPYKamy$N zQ`ed zoGQ&Zixxvi8w3co0~(s%s*+ibny-{xlk>ANO`c)252uSa^aLfQ4Bz8$F}?NN_$+}Q zxStR3BmkH$?d50Ux~k3p+Rn&K?&WI^PkKt=H!N8-WR^zBntz516>s2Njd@u?KD9#r zCc3>(~5(bsDiz$1d1)Aux(FdpETGK{EZZuBtjQ+Ah3ZH6_fxd5nkvGkjs&9%9t4 zw+TdI!Z)NG*ZzYLC+Mk3(VSwb{AxVRCvWhc>_1d#_%gO~t2YU|H~q9BUl-dG(Lekr zafY4*-N;ZW_-UfJv5;Q-yT><^A7nnpeIrWNH^Pgv`zd5wKwO|`aejk&-gn2`Tpq4s zJLQ zGO7E7e?lqy*MI5*J}oka(6xtypFD*+(PvFeizy!bjyKs5f+SxN67v%kB}IPl-~qx2cphzOWAp?6H+|DowD z+?ssD_P-HQN{2K^ND0!-2$62-R_PL?Yak#hB1lM=QqtWiDK)xbL%IheHe&nj`@X;T zFL;jQIiB6Suj@R|&-qQ5lS&rUtr@L1Yu?ss$W`dYU_`1JM+Iul!@i$4OVjPJQK--l zwtjysr6lQZm$*yFM5N=x;q5S^(FN>M>A!JN+R_pka+o!k5d!Wg*b&F(`+qT!)lT0Y zEFpwrF2+x-#>1yV_1fA|Arf$}@3=5X8R7#t_OzChp#P-j!UpTE8Nn&Y9g%SOISSat z#WcN#cuR)RG+8M`hG>k~tN{PUzEVvo6m)|ZR z#y%fJToG|%oW^1-I>u$&RrwM~&R6)_Dck12Q};BQ9Im-bDX0GYD_&Pne*;zUdMkvK z1lo8>q=9MTaOP!lOwFPHAWx;WRP>-P6IYWM#K z_VML8{0i51T`Qr~2reXge4cx(mK|3O-qIsr7!jpe>-IBNV?0gsGGtcMiW zG#mej=z7!BUC1^y`9J#2Z=a_zK-ZQZ%MJVlfdhFNGdq$0Ci%(St~<8&U-DsQv!U#} zx%2#g=SuLuIaxMQ05dxm z&q7&Mr(jXXrDNZj`ST}_grVxy|y%M zA)v;*Mw+MauV1K>L#i@>&QLu_ZcgeMtZB!(l2X|M_rrDvp<2FE?uDSKr%DzT2&F4} zntW@X#)zLYsm^ia9O1?C(tw9u42SdLupX%e{bm4=iU7wCX{p;hV1yG?NB4$60KBy? zdHaTxhawT==V9x4I&8&U?BXn*nW#Yt_jThCzC!Q~vn3i(xdO8Ng+d#K$Zf3Fh&I4~ zx}=FJOnsNi-vHPSiN(eo-jSYZs>tjy^KLVIk-iQk2IMh)$DX;V53yP?A3zE&#G%+I z>gM5zS-2+EeS4C00adcIAh+u{sV{)#F)6Of&~a73Rj_Nf1_YmO13Q6xZ5X>lFpp#y zwzp=iEwVsnUEY-nCy-JmpV_)yMKQ6j?b%227_-hu?q|=p`7?4MNsZ8^bKboR-q&J^ zIylA_b}?_@=6BXTOOg*Pr`1>Fh>j2E^B*DgmcZA3qZ9Ih>UCGQ)by~=-pGsOQXJJs z+O=M$AG9l#2mM|z&+~Oq)s=*PiIWR>5&wGH^mAT5xLuP>y~t_`9|ZvqvdUs%WrHB)%re6 zSuHAd27Y2Kj8UkxH2;U!49&FYvuFlCp`PO^nI&X}QNLR-5*Ro^T4&juqGjUoikB!UciSRVXmuBNa21YyWKEgx`7&m-9GOD zvK)OM{PniCsw$g-Fjh^EqqQqnuzU#wc%tntA$WhQs?fc?y?bB#DjC4T z0bxX#Mf@__*)TDrfc@;6u_zg3shxQ98v%i}jufL2^kLT0MF$O2qF_kH^~C;J>&Q0j z1K)JVEz3<|2>z8jR$EzopUBE~JJ5e~Eo#yC>?@Uu4_uYgLjDJl4Giy|g0u_}k3E|D-(w9E@ zw|p;wDkx1FfYUQGn$DMDZjE0JBd7W2F25aHzPWrK8H6><8$W6odATg8M@yW_#N8WM zGBI-QO1;j2b#(gR6g7?SL}<~%gVfJfEvw|1s=V2&$hWrlZmGI<&nJ_eQ~9=>J6ep! z!1bmcs?azM9Ks092e&HT2vFirb55G?t970h_Nd)7xrez<_PGinzMvn5=aFE>^eLM= zO2kIuD)yFPVPnL&%#E)&`orz!`DPC`mp=aGQ1JPGwaH$&pnqOY4eb*A%v~LE6*5uwuZy=4t`p; zwXqx$4Qdd3YW&)O-{CAQkvAY5HOvo29jd8u(f2gV%pgp>BXZ3%tUQG6md{uFdGv~E zmwUiR_xmy>)+`#>^VF5A3Es|53}AugB{Qcuw8W@9yL4ehmIx{bxwGK!O)bl24d&ur z{^zysM*fpF@32sNwplHY~}9rdfJMi}L?6nReY1&^rMh1J!q2rXE_rVOz7<~kS)6Mh0D2+j79F+8iF1|Hu^Fvhx^d=W(VhbtQ~fMy{MP&$J4cY ztOk=YKpLK=clSA#*>Q(C*m?X@p<#G!n$OcZ&F^qMfom!b=Ud9pljAyMqFFU%8p~TP z$jVx6wmqzsr{%aBerjoZquyXmcf3W};v4 z!LWZ)4EUMeUuWYsHr2+TvxY)i;iWwC=7V*3o6s7Ud(N^8`TfN+y-v68s#?n_kzTVShY{f=+L;Y^ zy=p8OF@-;I6nzlQM5@fs$DAp{iqP(VT)(o7Q5~qiFtUFf^E~#Da{TC2uBPDwMT6y1 z^ILWt9t(z;>Csfc9&=&}U`k84b8iD+bpcz4$m`o3_j1X4jzw3*CL+OE`b_VY5X8*o zX=IAms!i@lHSO*ffYQ&xC4&3LVQOFAH)+S2Do%6jS^O7cm>T_U(8D_PZcf~bPvFIc zRnZBe01M&V$@Mq37?sBS#<4pQHW~v`r8($~1cvn>B|R;=6LB!3Y1yv+!Klth(rL=d z!aTq(*;gTu))Sirg}E!CK$;*L^=$4ouoN;O$Bprmom&_DWCI4=yyvPeFT`^mHqpqj zkg`MEMTu)}vtXbrVytqEkC72d;7>p27u<4T(Jl!tAuw9E%V?74>|A|o%9||`m;A7;h2yaF! zeKQh&;&WA**TdOxaY*SjC}xp$SscvNVv-EVK1s~og`neQyzDF%d?9DxMSN@}PKD%v zw&1|h`x4~5hlEc91*_QMLy;gJ)st3K*_`ZAO|8TPe~qo}dzp{efw2fY(3mEbOfnGW@eF9LK<8gna3JD&ccH(GD~-I)#JeBXjN$gQXyL=OxtSCAJP z1!l>?L~zux=afw!&7QP6naOe_PEKG1O)ddbsoN=?*1su+HurSSNK9S1$Kwu=?+|lx zTS!QC$N3;*ee|en*5mpArQVvYmQ{RNTvp^JESuZGQEoo51-O{__hFOa@}2Kg(;GE z%UHB4|G-V#`4vYI;rR(iF?p>6i+bMuCy7CrAaeY`LezkZ&_|W&lf@7DivNRu34Ki;^#P!+RUzg(RZ_{E4+9j!2h%RX&M`+0@XkIZhLyRbZ^&ur$-a zd-j!^foF1?y#cGypon>P=);n@h1P}E<^K6AnXgf z@8vOoQrues%RCC#lX<0|4B5PEe1B?)!>JnaBhJ!awgtBF)Bzsb?WPFwNqh>NcE%uX z=FKzskJQQw0*=iOOtfx{iD=ks_ZY}esTg3b9%H2ViUk65gpwDeWqiq@!xKRK9(?S; zrEi7<|Ig6kKJjLUaow%3yxYUwff_&wkDcyP>2CL}6ozmKxDO`q0cr%BDqj@)dnS60 zhueMOo6C(xhcz~n=}?v{t(VrD>A?clVe zOd75NM-O&sElm&`68kY|N?xz?6Q^YGagHK{fess^nEX+Q1;4LMv}5(r{UPEPi`Q6B(fb*f0qouMT|7R zSM%S~TBX9d;lC>Rd@fQl;OTApwYA#pv z!9y_>Rs{y$nsoX1)sZJg>-@pL#Dm%2<12nLPUWAp1m1lIQOmi>1ELM9maxYLjO>ja z5g<;=oW+we1loFKtuB>Il8tbxZ8#4nFJw#qWMHV#1Niy?rGU?+7q1kVo52o~9~4>H zlh3_5wIW3-g2oY>3|!(k0QhnH`BY3{VP5GG3I4j(4>lUx29{XyfGX;we0ih2d!fJo zuzcK;+|m>uUT#i@TS+Bia^AzPXOZ9h&(mdjpPfco(voqse z>^<0}-m!rMK6D9`2299c>>28|cot4iiTZSIbW@{^v9;H4 zYP&AIs6Ws50!xR;@Tss6vZztg_gqg5rsuXFDsccBK7YrzvH|eo>9@abUiHVSC?iUL z%jW_TTvIPQn0ao6SGvixG<*oNH43Bo*Kf;0GV)meGeG^@Xu&dATSLknw`fP3`2^oe z)znmM=Q+<+&ex#inXPU>|$Y%4}57n|ARV-?ob+d6KllI=Dq;hBT}hkl+0Swzg{K z_AbTGpAEG*i2bm`dwbgDxzsvitl`Yq*`Wk{cS3h6+6hH!Uj61M&Y7%7{!DH!I&I>B zr%?mQHNF4Q+tR5>?#p-(z)lST5l&`%na!1SI>VY{7WJ1|r>1_7f%g(vf;$z%b`fXR zyHlo<^N#-vl6WKz4w2Nv=GqHY!x$c*vCDA`;dI1dYkwB;%HbZ<{rd!L?)JgC?-w-OxJARi&g!(of2(|%4VpL;HG zKY2Q9&Jl?|3UzrJ*WIu33h)jRuM0p+v1I)RTq2JkIFSy^*$?)#S>9hz#}cnbXxqBo zO+a~#lLKxixd}E=63;jIe}DcWt)oC1+gp&d{rwM+&;E^Cz|S*uIVIf4(SReO`a;Ob zXYbye2AR;A_C1{F&P|RL8&cl$*x_DxxP(A21{FxLp92xv<*Us^78?o2?o(Wh9Q!93 z|8$~5EANP+}vcXXl9b;6*f88Ma&8 zp+1%3*d?q5oUCm=5`8fF4?2Op^|A8;9CfoyU0DeW#Ne$A5W>pqI5OWM9)Nof>!*uI zqThcW?31OuI}BxcQ9zG=b6q<+Ae;K&+qUa%*Mepbq}6Bo4{ly?AgtFzeBMh5m`{3q zd2Gs^JQE%yUTs%xvn+|Y;5!kwmBOYFj@$DzjdpM`1y#aDM6Lz|#*tuNw;G|U1@X5n zEmI6$Zq$(IBiT5*3kMr9SG_86aKD(=vfcU+j#+`CA$PA~0gla* zN7P!nvmv1m9U?I*pxgK*zYEeG6a>omB{w}yfCVtYne*e`Vzz$wFON#)MO9<99DiXp zJAIgh5(uC@%Y!^m3b%pfjRF$4eC*5L0&ERozIj7O`Geh5I9HLGj?t*4m&}aU)vap3 z7ZUxum(RU>=xKr{CWSXM1|P8oWwBTYzc}~Du7l5Ri@9v!=J(%A^?^-XQM{M`;q3J_gd4aPB2^A*sB z9aXOND6l%r!Jq~&tDTmoesPmNdZa@vE$cxm%cNGZ1Ldkw%;U@hNEn+NhdHry+KvFg zYzJS;PetTyTjo|n4(00mu7rr?xqq%0O`1PI;sO~Z?w7Tt|M|S%ju9Pezlt;Z&`*lD zhLOp-P5VDDKmd4Ccc++!HSU-@giMa;n|!dBCc{3H%9st#SJCzN`csgDqk3+dbyi%zZ_4Cm9ZE+#+Ep z2T|^{yIY7v5QF&6a8_UI=Cr%#%qx>^A@L~U2>=8%>@Zg68lvY`!)eY%I3evV=vZA* zV?t+nyi!1<@(WpOlLEQKS~o49<_))2Fb&(hti4 zrynEyuu$FRt1#S#W7RQP7+?KwAVv-w4HunEfxFYe^-zx)&`;NPk8$L}4u3J{{;B{p zL_fxpQRV)GkBPKbiJ|D>_QEc!tqBc=Tl*mC|F`ZD57H46!M2$k*D&%ID~+J{ZlX;- zRi=*>`(eTUh8B$EHWmLua@Jru?RO=fTHMHZ*}9iw&on z0}`IiopGjcFJ!C<;p4lXFD)iNu*&FuX3g>Wy|R;Jd0+nRnjlS@yiPd7gb>+Gboi}| zMX3r!OBU0n5SmXAG`ixgs1ymB&B(j{d#-q$v>lK>Koy28ky_+9RpI=T$jFsIxfUq@D^UVQ8zCdHU-F+by;j_J}M z>dp^S->+MPbiIhe()Zu951zVp1CLjx0OVzau)|2`rsdCOfH)oi{JP`5lz{7Eig@Yi zmEC8JVG`)ix>j!Y;Z)0rTwrTRx9Ao@S2dQTeT}t@|F4ura?1|#61m1hzcDpq>@M*` zm8g_HdD&!bj>V$(07@?U_s z>(J5Wihog_B#xNZ>V+F3l0#7PR_C!tSJ%;}X|Wrbd_WLf%?X)Vp7+aeYbr!nlH+BA zt{iqH@HztI5dG};(}pRtDU}Hd=fn1_Fc{Q9gHw~l?+?hwq-;S=lyNW+h1&ZSAkkf3 zbSu|BW4Sk>n~G5+8F^s!jfX?jS7dS*G5!!}1NeG66#Zza2PB)uM)4H(Rio=(X!tVo zz&5!A!s0=R?Top5wiWbg6lDK7_?IsOBUdZs?aeynHZcBB79JnvmnM=oa!H7GDH2?#JA3>p zgi=E@Fh6-6Lak;d*V!3ZJ75=(P9HMx^3=#mBHLfJ?R$EifzAx6qM+;m!xks*{NobmCyQf})=p2rd^?*t`d0 z7y~Tf^dJ4<7cK7lshgDv1ZT9r##sseGr<8)8oqZOV!DsfCgQN8cJChWQzfMBCUd3I6wSXaJVs9ubXc7y^ z@+7LC_8CCTQKF}eH>dh~PyR9{?p)`GfFi-i#PmDIuX{l?T(g%QMrQl}MNV~iJJZi@ z1cHTOw&d$vKoy!>vNxL8O-@E7IF^KIQJF*AxdPvETJB9NnVVtHB?Zyl?5$rpzz3u*(;k5tPg+ zeIdxCR#7GFxz*&)3aZ18M$!U!NRBlH#rmUF#`$SHd#qS!BJa>w1zK6F>f98+ZjW=> zKe-m-Wxvy_bWd1143zXiW#D4cyWY3Hc%}UzjYa2AR!=wB!=2*U6Kk~S`a{#h?gw{$ zNO5lvd)%!jeaQp*p78;?yh_*Z&3S|Q z=y(_zI$&m0@P<2w4=|Vz7KchOsvfl=S~IRh7h=d?>c5dE zGysgD_Wr1OwOo$;)HXE)L{Ip*LfuLA3uwBImnAbhdo*u;3w*}3$cPM`haQEB+>8!FBt<6jzHO=uU1WxwfLB-j5ND1h%*}sU?l)Pp z#<8z6E9kF8I>7DRA^apMf|6J%7u@yG&?&YnWx*n)M z$Z{m`Jy|g})$rO+;`1lIo^IM`2^aO$&_M!p25rsoLarWfg54tlULUNOUM1}jKtg+^KxpedC>b5QQ`q2rY z&o&0eX_v`_sP1uCv5d9UzTvW{tH|ekShOKR|bR%XQt` zjkG^LuZNvrdBJxZ8PqY%`VzY~>a*pEqcnfdSw*AzwwUSQv5$ns#9mNMn5oG#S!3jh zFvQn(Z3QFI0A$_CG_D2-T4gionT3kQRQ}`UQ3uy__XM%8q&olrW`=IOls@tSa)!>z zq0)9Uu zEoiE{gDsG09B!rflkFO3%MLlwm?w7af%JN>dHqQIVm^jLsR^mo{-< zvE)#aKyB?t?ScJ2Vs=FU23p9qG&L@o*gcWrn$mfOy@PL*{D&uGkRgU>(?$^^5LP`V zLvqOffUiski2nQXe>@BEqK>nbOwrGXafgoj9C7UQ5OUc#b?pvg#hcY`G5!E{X5=$# zreJ(w$W(-8EBGk>ZGzxmQGyFTkR>GGKf4gEt*=jWHH4|Obd?F)aNxUYJ*aIx)%;fO z5}>XDSP3)sqQB5E*;_>L_Qy{QSd*e%vW3QnweSFqua(I0+%>>I{}4&-!6IJ&x4$w9 zu*)wQNeZO2WevXw*jZz}sRJfr`jhpXJa_H<6p%ggbw^k5_RT>(HFM0Ap_N|@jti3+ zH<=He)zzh}sA^Ab{gp)8#rk*i389-+G11GgH&K#;IO{Jz+OQc6S{K}|B#eEODHUjUS^YjP@`X%x1%1*x2bB6W7I$IGgfs--#R#lWuh%vxfw5S| zu^JT{iH;u14hT7ZgwD$X`xnz0$yjk(+aDRi5`f9YX1qSL!@w=(n4o`r>7Ia)08+VzYKpKOxFniyg(QxeAw)FzG~1wEF@4rPIQX`RrC zjEQia!)2-OlB2;5|6W>O#+#Mr4E)M}?DNSEC{+8dP=@OXglY9*97>NBT^V$0pGN!U zgHMntcz?KJb~S}V5YhqYYH@Cp-iugMDVA|tA%My`91Wi36e7)cQ!r<@#N{pfvE zdpptTl;gh}!L5;}w@*_F7^Zx*6R>N6^x9J>g!YMk{s1QKDc@tft%(tRM=!E=Uum_m zzamKbt`z{+#A;e(DJM;xPkGPICyrL@J zRdw$>YnKc^1p+I2v=YC#m{`MeygOXA*OhL^!#Y^8XuGV{Oq@7EAC?eGITEA45bAju z=WkrsN}8P6y99i%dp$RrW+M!y%^Bh(Y}hPg=@Qd8yGJp6wZZf}%QBIF71->6+9qmM z-u#IAx;zrKZC66LIh}1k*}TIP<$uxP{;?5ta*G*d@)e=7cgGa>cuCI1`8@bfZ( z3D$5{+4JKa(WfBOu#IZWEg5346Y>`gzSqHC-xhfl2vjmO^4ZCq_v_7?-6n|hrbpF` zjGbbSWoslQ-7P>weOby|RF-+sj^aXl>w4E+UdDyH<5*A#I`;3qm9N<O zM}pzN)l2_N0OrNb0?4R)oOE$KKdqE;wdku0QP3G~^;P4(mC;(&dWJY2Ws=z=Tu6 zB5}YPdkZ|Hp|_lX7t}f)8TXBwMNU*W#CYqq7t{B;`^m~F(OdT9;psuDe>y?fnSUHx zLeJNh2nhl=r(`bNjDj<*88*F0$F4X#Cz*RYNbgL@f%6 zg@T?tEU%JUn}9pmiX<|v0DVPAd3b1H4Ygr+4ID9CnqC02Oa1J_YQvt=@(|rY!Rh~c zH40YgQXcb#|79Srdu?d>jDXV()8&;l-TiM_tKTrVfEe(*Fc04@1RuzVY@RBl0Py5B z_@e%OuZ56fXP0+HJ{-W^1n@aOqW4$TbEEb8EOY;fUOz3zUUSLJ@@~JpAmjwPLuD5Q8GV41)03U=lW^Kpr$1)?j@7ZI3ctTSb%D7UnB$CiP#9Bz#?vbuw$2)tFZk~J_7z?jAABJGL z&?RHAHVi&2FZJT#5<)EWv5PJY+AJe;*9tkXrM+{ra+QgUp+qJ;a)S{(h?o3Y3&8H> z%jsE&eUaHMtGUjWD*;KA`BWxByVJvsH+v2u{$qIMMT{@YJJ~W+YhF&#@Qtb5V$mJP zFd6Zn`*`Vpi{;u7UV?NLvfyOb4I_NTU)1#A@Eh>>4)!|6wbD`ZjTud)MWtmZFijYCQ zW-t140Q)}c(u_?JXG0^9Ax0TsQ@v?mh!b5kKikDe3SSK02|HMYFuiiK+Vu!AJ0Q8n zpKQR!djsS5Bvj=`xE-{DeZFG<{*23ObZ*bg?jKp4FdyfCbl+L$x)&oU)iF0xF-e8x z+14cWF#n;_WE@L_tVtGnMVeT|HB<@mHZ(IwNT?c};}Pb3pJkD@!@Y<76XN};99i37 z>Om~yVJ_qxFj?_ErrlQ8t;Q}F5yAiDpyRz*{+~fQc;Hm;?kd$dqyBlIRE=(zN_vs#k> zYp>d{4sEeC&+GgN)8r$9?>+@pTKH?;KG;?q4O8cqGZzk$7H7#Mn-%xYq!Q#=d*M^igg!#U3t=% zNWb*_tanzO$t!uv;cGj#UYe5W*#gVZJ*7N=3Hu@Wxj*qs(qKRpx;RQsv zBsB?P17^ZWlx%PPjVvdC`#w;I-6`CmaT_^ed~3bds()h4P=~qFMNNTT^UEwa$ zQSYS+;^`x;Rsx?&u=Vf!M0qc*$I)WQ!4r$$?-nKtYb#9vfJ(LT)@AK5s1)|e0X*+K zUp`S`)GKJ~MdV37!M$z1z>J076}pY-dCGEW`In*$6j%AU2LGF8QHL1x+TGoEDE;tK z%~fpo^MJL`(4s0#*JEr=0mH^kb#3YHMXB6F_w~RUy%H=kq?HPH&%9Gkt4lqyxe-ii z=tHmgHHbDsW4u)pLb69k-N?w&)~BeBbI-*U;W$(*p+LH*2#P-p;2bCiwu<%Z9mewq z?l^=^GtrB^`q;lnV9#mkW0(|Z>7Wj~OXv@#mc=Zh*7tUe<2x1jBe`)Yo@%n|}!?zkA z2DP1cxxPEQi4YWyGOGjd-$GLpK3p5<{AOYe>e6nrkPJ~G{84eLKhKZ5o9->OP@!f$ zi^}C66o?mpTRn5_p#f1$Ji}60vn>FYm4Ht`%yA(g@Mga|L&ydlH5QCU#Smon6{E`& z?!_Uz&pmxnAUb=~Foc?;F%);_%=W&w;zYi!bedf3ae?ro8A2iuwdM^?0Bbn6uNTVd zQ<$d8I6bZh`8)m&;Njaxq=xz3iRq4o-%auMMY!XX26wTowYWhMt{niFxgk!9^_k_) zEyE;4dh`6i{&2wsMSJC7?f1;0m*ASJQieR4?%CkpjwLg;ge8yvXLeFgcS1Syg&aXr z^2@J}r}W%x+Im0_$=0JVp4bHc$r>p`SL!U^Gwd#|n+G2-LG}SGnC-0PtHH)bfDVjo zuXyk{#LSn2NaGIHNQBF)1%J0Fy*+4$dlyul;Unzm!C@>B!Af%Hac?MSSBk^UA<)OJ zh0@s*IbdSP3Dq%9(68Z$CkoP35x4jaz)Awon@mCgis4b_c{VQoi)hv2+m!3fsp=r} zSrONlIEsw+(C$qNL;~1d;&j<-K>*U*|9f=ose%Z5xv&1t(B`c;edkiAX+>DuN!SP$ z$fp21l@S#v(DKXGvQjKLC*I-q;e9+TYQb&ha29avhyHp#vkfFXEjj==`1iuY>9%nm zb6s)*_2`bno9SW7@r0BmD0ARC?gdkRfE^FAIp_{9%I1r&F8m)1e4gT$ZO#*{^Gjba zz%GOpdD_E{|J{lPxzL7=X<0ni&S@GeV7CZ0nIH&E<)S|4k~t}xp2fF!DBGK^R?qP$bMrn5j$Eo!62!;5D`fkJJ%NJODjzm>UhPh0V@!KcX+y7kfPRt^}EN+lfAijtC#g ze`z*Z0R9^o!Sb9J0X_52^@O0TF)(6jt^{aR&;#)1!fY7Qwi zxV->S*S)b_g`T{GGAFe6DECSI0<|xFrG~=3k#?)p3=g|v-PoH7TN)>)Ip2_uBoW$= zJlyY>!EjjUSwR|&X6IB%zc~p zlA0en-Sj5!t2;pjpsb{9-t8QF-K0zCOR&O6ciWpS>Tx{tj)&Rox`JGaZP(&fthpXm z>)_h_JRd&C$_K-ga=iW+EHb1@Wl5s{;It*nE$M!$v%;LB1sddlr{gn=MVv65z$J+8 z60Am_O#{#>B7R2kosUY@&puYYwc~j>`S=O6*pEnBBJC$(WVhf+h9|Jnf7SQBN^Di}t zDFj1z>DQo@+W93L0oGaUGor2#hy!i5>2Y0U&|R9IgHC1<;M32HPP_=z2doY;4@E$d z-c7zmsUpU<0y;{+=mp*a(!dusaET@{yU+J^a%>CMC}$lUZ9U3KHoAau^WI9|0%=iB zAzZ_0NEWo;@g710Fd@UWkP+ad;MI-i_b%O858%aYFuvP$!sJK@`L`Bc%S7M`q9TGZ zsmz+G#J(EdgH$0LNV4xhE6t#kFitOAD7xr^8*b~*^=Rgr3Rz8n5v%VyH6*D8An_c1 z{?g)*;A7ycPR1#k8B(OZ`Grp@!$K{iKZuX1Zt(GEM1Qza2(^n>4$VzD%P=HOdvcfM zuwOJ%DB}@KUk(aip#`y)A$=}mO)zq;#txnrgu|j;6bbau*0)RLe!dzwoyA^dcQuW> zHpb(=c?s*nD#cD)mXMDNQ?kFEutKb(b?d(q1Ne}xK0sd!WeF9V=-BW(6&3>?uD&Tw zH%ee+9slXr$pZwBnHExfYQ&n)S<(icH|dQ_sxzLVl(r@KsH6Zgxs~f(9*9wz+;`!x z>t^nl{h$onU_(4Hw-WIySvJ-lvI#aD3LG?P1nbGT97RL$qt?jM`D8Ed6smFixVRU zi?9E^c?PILs_?$B!^-%yKLb9zEZ^fl#>br|L;_>oro@_TA72mZ9Mb*SW-})4AyeFM z{v4t=ig)&w-FckiC;n~V6kstHQ}X#G@3x>;z)|oMx&W?RzAE(IfSy%SuLGID#ew#X zf~4wi>U=HfKez9Bca)QIz2joz9$-5H-QWTMaK%gqyn#%9nTQxboDM_^$pz^&`39ZO z{h!OCQd*JP)(;4Nr#KAxKQF+FO`#J$9a<&z(+|>Xxqa31i zE7jKP!-sJ#H9Vel>A64$+vplq@?@(QDftnMS(N(JL~Lwzaeu86kAL>iD0SGdN8nL& z3V7psaX+D4Vh1i|oDi}3TpB4Uv4wTbW>giowf$&5$esz)z3$E%QWFVl)19ksy2i{~ z8rSKtRi8ZFtgaiFSdbPH@nB?~#V5$~CY!({{xG_Ow%~QwJlfg={eD%bVD~eBBKe>o zwCmJ>!%^&K*hN{;?lW@#@_l!~fgKlbp5d@LANgAUteg=g>%>>g9@LM3XSR0=KV)L> z1<}~{SM4kT{i!N4{2?RmHsRREU+dQ7(0s5%6i_fN+p2Z=kOb8WA>}8e;UmLJ!;vA zbNxYlboT3+#9zeQ&%O!nf^4~LJQ~I6bfKsS-WHMc;?i1p8+hEE%J^yd^S12!IVD9KrC)(3vwzJx(nMNX7;1N@eZ;LSALkVI zS{4o=itA{y2V6xl>)Q{_Jm#9$no)zN7)CIyE-=g^vtJ%;L6=!lob+<9+=ZTdtEmaL zyvtqUMv4AMpO@;H0~Fvskav7|hSBct_5_JID3cL&E^KudRU7;{_j{(Y89ix3vE!Ca zvkc591kWHX+A#%M1Pfh$A49){8!*xP+NmGk!RYnqrd4)~_23f$zLuzn>U^|HwKPYdTB`b82bXygee+=K%P{UT4rkO22U% z@$w*Yd#BbN%ds?$Dcql|r4&*lG{F`gY%A@n#ReFi-|RWKJ^j@G@{}9M@*lK4iqd+L7 zg_m>v7M_6pb2w(*Z=}nr_e`h8t!2da;?&pxf*8I27e}VLc-S)F-I16#+`U}prc1*w zSk+ckGaWPecg5oeW7K(A1E(6}Tm(E&mu`RBr%lKKc`4T=CeHD3wZHMC0W-kciNp2T^F?@f4J5k{%&CIXf+O;OkKS~3NS|^`~&;4 z%cQD%9xL$6Cdd5)B#wS(q6|6mfa9XeUiCJEwgD}Ok9HYE|RPxi)lTTKdA0_yBjCq8*W(M&L#8}7@_ zgfmNf60S?SxvSTHM7ZK|Usbv^>Ax#_+PdSsM%>1LhDh_=@YVv|Ex*mG<6J^dG6V04 z%Elj^ovJkI<<1dmwr&-w@Q-{Op}jhN{>A86D{00=C0EBDeDK1H7-LGnnwK2yakh+c|D7izPz-waca?^wVD)n*B(&wM2LoOv4XH#IOIEcxNGd;c zU>7>)9Z5EsJUlS3*m+;@Us3AgwU4+>B1^ZL*_>1ysSGYE@~@0$9>hPW6`WR;f1%d7 z9J}o3XWyDqV{3#)p*WRaw0sT8>kd;UZM*6g+N)=n5Vif99f~{qag_p{m4>?RG zQzfyy!jGwoU0-?dZhLAW()*8Tm-y{S&~CF;j+;zPx(xT(zm+EuJWL46z$eyu>Wb^t z0gL|4ZyDc9?!G5(_5}TkBlKO;4?QwFA3jyHcDfZ?PCZhWykbqO{Oxm2f&S4dS}UH^ znlJqHQ>*UNX_}A@RiXVSWJN1qlYB<6sv+4zX*qw0{&lZ(mPMS#BK70~CThS2eqB_e zMI&K?Cka52EO2U#XB|qjpEGXjMR$3CyCxDz zGTfgf|ClFyE>)Y<#;FVq>JxbA1s~sb~;1fmcnQ^E>?+f=(wVY~oqkrzn8z{_rFb z*aadf`C7xW%6AuVH$-&m^N2kotlCB+y&>x2Pqzc3MQOce^G(TN{jB_j;D*-$qzbnu zC3!IUEs{1uM>2DjR%te!aZXglB6mib$m$!Xb<}fdKxtF-@9!D3^eyjlFo&Zrz1rV| zizW@t>7#&u+^Wv1?1R4Tyfz7K+qekWmfi}PQT|Y)8tP7F58{cvlxDVX%PMtuD3Ce# zh$9pD~Lr z>?APRnI0Xc*w{6T+*_Q?Z!$?Gn=wi+3+ROSNB^Mihm!Q$f0y_1{5Xo@#HHnJ*`KwZ z??-o+vmmd_-rGStzvZS59msGAnago7J}<9q+S$9vAye;p4TmAqVJ~$}ys>I@~(5p=`vjHy3eTe!fX0 zg$~ywehec)L2kg${aGsOpMyM5Ox!Q94k{TCq)(mN#3B8nuh%wAOaA!<_$}>*%ZP4c z4&YKNGnN$Yvr<7ii3ZP&WSz#vcr{AL^Fr`v_YS}B+{ur={rO{G;cnaI=0K_AP#iiPU|!1v{qRbAp4*z2Bx%0 z)xK7J$YhjUIA{Db*$r?4om$H=g!#gjrj`RfDx4%yo5XR-ipK!8-vC65oOtLMffn}B z2h^f_C?G-pSKKA9a}jZvb1tck6J=UJtfikNfO4P*c=Lqe}UB4L*1{hriAS zw3Bc$521w8wC zAetjquP@7gr>qjHds|s;c5wS`l$lX2-@!QvkGUz$B3AnpCIXV=+7P*u0tqnpo|nQK zYPi~7#j0I2?2uBf-(~c4_v7PFl;WHMuCw6l)smie`(fzPq=;8KoB5A{Lpef*l5Pdn zU1Lz)-bl8yoa8dy>#?bhS09|Kqf;Beuaf(Pxd!#sZIgZD3`(}yG}KIQpmUhCL%p9o z%st$enS$6WV*=i}J#-vlC~slNMPmPs@{Mv;$KeFU~_aK9QP8QiZo(VLi_QYMiS zCOV^?*0Qdy@!?dEh1eq{cQpTTE%!`u>h_P7XaoT=F7eZVoaHaSrax;K?itA}n!^jA zZu(PE&5XyFB?^%?Up@J6aJ4s~YTP_LOWU zEI6G%DsRrU{W&7IUs6eGT_ad$2E*=76`z={$2bEEXPq==ULj@zW_ICFvHPNP$=Dx(iW(^!vt zh$ex%Ef14`E3<(lhaAO94d-aSw?mW6fzI(pjuP#=B3@q?SDwcC#O+V|6rQU)ng{jk z?~B)9qhGm6cR+gBC}xDfYMDNJa5=shd*TwtZ}Gohfxz+GhtM7lv<9`-eB#s;Eol!ASt_5|0xfy|fV&QU0z zuQkPr((bPa9^b=3YfPGOk`=~RA-ru`VYu>ifE?__w8$VSul-M}f4OHD!-8$N711iwivss$!Zgu)$KHn#S_Wa z0|)rv_Raa4l^ZjAX^hHT`>rM&*`&e}J8;ZY@Olgw{O82gJc zR|22vgLz?A8oH&!tWe65rs>YrO440=U}Y)Qd@uo#g7d1__!(o|%=%WO>a~*0=ceSU zZ|I(9Ih|R{$@VembmV>4q`a#OB}U^DML1bUviBT3SHAEQH#MErq3QRd*?O=KU-%!hn;BtBiFCos-tX4V(v5R^1n``Lus0_p_b->Wd-$A}2Kf zuWR(`UoKi_p0YJ|kCeL$lol_lk6Z*d9;)l;aQsBCjxWRc{8=omZ+}JtG z+%v^!aBaod$rcl}KK~VHr74Nyd2rbO>E1N!TpbWm23Xyr5ruZC*4O7n`8! zQscmE$=s(?Ao!jNYLJbEmRv?puZ`YD7h~?n^s(3_*;DSrpi6Of)+S>gGt1$wb;)pc z(4xIS#H;(#tCUOe2nr{on!K^<9p=6Xp35X$zGn3?wAXRI>0@+w8OeqLb4?6}0i#to z_Y#|5Keg7t2k-s_{7R~^HGiX+K45tTODGeh*;Jdx^t5^BZj8DuDm%1Y&>(`E>?RVe zaJ6z)Q8^;eI?HyI%zeIjXk8RjTR6Fp3LmlWY#i?}*h^rB%#i{1oc1aN!->|?M=o2$ zk9f4rX|sohNQd1Cl}U$G5hooysiQ z0^}(uU-9!kUoXXg4)T8LaiE-*?ArV!c+zyqErgIx-!w2nNNw#N^>f^dH)Mj{bO&7L=Vwpb)>xgsdFRYwQm>Iq=4U=L> z!ez{Y&R5Sl)6>K#c7}2my*GSjS)hCXc;EYI5VX-WCcPOK>ckmz(hwx z9Y#?OGm>5$5vesoZ&*Tx%Uz$HfMSTDK`xPmfXrPLw|9Ts__)3_BUhBnmnbhS4A)L_ zTVy}T8wWzc%N|pY%$IA2FrT5LGR?}bU2o?mzXYtcM;+PRUQ2S#ZU~`8J*GqHAI$Rizv~o|TZKIr@C_ z{bEwfloYH;ddM8Qp!IAPmMfm*Ybs@*jbm)^DX>3m05CiFrbKkwl4Aa3ZkFBRzoM9GE0L}CmGfrhTPBn`pf$@(1h+WTj zv=EV6L-7or@3^2_aswx?gPQ)@LeVNFty+Nbo7ieZGE`dWvFs@~yo^bDcE6cnuTD{)ay>zFN1`QM{>- zA#owUD-vB~bzHP`TvSeh`9-;k+WfqNf`Te;wXomXtM527<^*cBNlTr7xYv-K~P{q=UcGTFqRB=DhsAdqK~j zW&+3JBoue{Oa^s-vXKo%0i><#GHGlMKARJLQ@QWqO03!#1gh}vMY zYHl$94sq-b3Eejp`Z~Ii9;t)aF2#OZQXBoMXw@#0igymJ!x#y(kr5Z(b)Zm)!x!90pb zRrjLSFEM5cuR!Uu@!Pj4t@v#QyR-A5Z|3ubhnSASb-b5+OM+=(ufJ`Z_=L-yxzX;` z&D&?&C^3x6(&)ze1a!P}%@e13_tk$CLmV8OeR0(pl)Y=7QSY%ca2b%_RCyM`A$8sD z`y`o@i?T;5%JbXGLn@}^M6}caovzn6#n=JB=N|qDLHuXCpbjgI1z2J~;BmmIJaZRw z)1v3J2b!lCpL1g|G_HZY+6*C04$dNBx3HI1LYz<%uo9zgoNV6RTn~|K!ra{3z>iC8 z!9Z;eSb%JCvM>3WFxL+{ad}duhU!zvQ=;#JfkoJ;6B zLk?yShU_Hz_bdsSRLKO7spfi{_H+mqcFiNQ?kk_etPp?z+eBF6#q+c;udfR@4N$hb zYy%hgZAm1mp{})M`az>A-4!pZt}HYfEhHXbG}5$_qU(|yel3Nm6lsNv&5_ng40yhb zAOHNZ$?pcN&i?ASg%L@JTZ}4?YcLE5Z~~}M+M}uxNi@w62~wv&-*tV8mlE>Je!f+> zbKloDR+`Xy^0GGywyk{zFeyyO=h^OA^{&2H@iqZ|e^1jy8E~|qPa1rAk}|YijTdBa z54<6Oj$BCNDhrS$|K!PYW)u5;Hn-vy;j$7Kt7sO<4MGeNY;Y0H$xoH?gy0@9b6B1V zm`ZMiD9tVAvs$k%#9v>fS3`#hl>ui(PO4?Mr#lt3?t-%<@hN&+DIe6l*UGjiWJ}$h z@cq?R4HgzGMS7SqP?)C5+Bqh2q@!rgBn+8jTG(6PF07DKsN2@}>0*IazYO)LibXi~ z_gp-C9tl#>WQX$+JQAm6G+!^_xc|h4h(8(d(b@LnuX)aH*yO932n(sxc$fO1;r=r5 z^uWuPYQTudu|BK+-Ttq-!_)9v`nuEQaKh}lavq*1U5dFloE*Kj-wv!{w|Tvbc6yo2 z_I^+xDSeM9gGv`~RK`%6Mcjr94lN$J7^_+%H+O@##^?>{SFa^f@>`Ai6GI%H3kvO1 zeKRo0y9Lzm(1Bu&^!ts=^w4s~4X87Ax0ee=inC0z<77n=jyOdq$aZh+93R+lQbdc8 z3HZ)&Az^f`dtnCVc;IJgOBvFmwx1A(H9oBj^ea|((R>z|!#>gmGTv4e#V-MW(EGqsEv1s!K zEI7P*Q{quSWRepJSWSaun3*(~ZdF#fT^k5%yMQ}o?mHsW?*k%1UDtb7 zj{(OQ`n%8b`Sg2uu>|Afp>K7B2qSwk4cr1_eJA8B`JLGmzDzKv9Iy*wgjm-n4^+ez z_9r{O=Bb9eaN5K*Rp_(W9!zaP$zVMfdvGb)mRRt0EVRBKnNi z&#BE2Y5c3VWpy(^_tC>a){GOKyCB)7*yxHUDnTo|o^~cOiQEDgPooPwnpx&d)!BT6 zW(6-eKV%C)9A}4u41s26#Sn!`LY(=_myog`*oc8bbczaTBUs5yb?qdyw(B;v7Pxb; zKWRg+%z_xWFQBZZLsT-TDCdV>=T`_f+eFSGf}Zi^{;0j{a=~*6%4df$`6xi+fYOX%s>@*xACOI7K5JU3JJRg4lu)MU{LBJ;9gHj{WXZ zq3iJ1t1eT;%K_m&d&=ZPxvzU~^c8BWPd)8Y43QZJDJiA1bK<(Zg*FYHuEEg&Dt%~X z216weXfnMH54(#W|Lvsa6lKoQ`@yi zOzYNC<4~!dQI&_z3g-08d$N8OmaDMuxSpJ;_1<%o1^dyDAy|)o_=vt2STdj}?)04D zmYn3zs44$2;^g|=`?^}q&mWMRCq98pLv2*`B*V^48*t%LAN5Hdm!>KOua<8kgqNoKQrQZ6_Jp54khSG*#B}GX59zL?>;X;NVC%ap zE%IN(TC}r6e=HR;&AR(t%f_-q6EaAo80$qfc}4`4VlC)qhSN}#5q_Kq1vyWO<-2Ke zcHFK`!^0s4rJ@cZo~kS-vtNU=Mw7g~%whNtR$6G9?A-NE6${i&Gvhd%=jV$o$)n-a zJn_6zGN`@F5q=4}fdK9VG>Me5gf^ufLq zVDFUo!92pj5IKwQWXuUu287C9D7HL^sxy#-)VOu<<%>s7jJ1{DWTV&R)tVW^k;%0X z1x}AfLt5gSw7Tq_WB0X?&@_-^P1?DLI*Vg^iAC-0Toh=XCy?bC^zAxRpl`QKoPf!? zEDM@lvUhlqy_MxVDin>a32T(r+|jXoTpk?(1$kA`ZlKa6AJcr>jDqLRj0QnBN9qd< zbg$B0K{=T4%_WPbjhLT!)8vAW-I|71c-r9f886y=i*+Q8do8$a^n*1l`*HfVHA=8K zPbMb??9NZcw|-IR)Mykt1@^R~XA?_qFs&3I|0?+qt8@Iu>H{%Zp&uajsqBMv5t4iaJ?=y*!9k(*#yAJ6XA}A2|*v3{Aw`fXaojFRR*c z<2;EDzT}vmdebe;$(S08w^J4W4zrcJetEdIs%3W94y?oePy>{=6 z4N!~?@WR}7RHm3tY&V=mp`Re}j0~~?pk^0y+-Eu;F!qi5EfYiEO28r36{qJlJ{MUy zuV2n3x+QF?@sK4~8jd+zYi`Fpf((S1_is%id6Wfq%`SG34N1WQJ_9k{FzPB;)|DyL z2^JUFMt>0^X%WFv%E?B5UYQdqYM=1aV6KV7A_?ZG_~HjCOwBz`q#HQGXznGxAOdNc zkR{REq)tT~jP@J`n5@0qeM~p8mN^XX5vx;D(|$v%NF;x+sy5H5r=BW_tp|3K3yP=> z(t;?2*10l+TfPe57c`gFv>YT%kvWRk40;ul8z6!d`iLa4;ZoXwmN|r3K)h^!?!@; z+X~&Wsy}IZcVU&kQ{DP@m(@5H)q*qt)$1}#g71a8Ro4s_+DnpZy|6Q4T45(I?6tz0 ziW3FpNdyGTGiqYvSzFWRW-7Nt3v06zwC+-hamuE*e|bBKiLn6lyn-yezqaoRXFS+I zS`>h)&%}EDPx8vt49acI1(h7u+9p*_YlzJP)NMHYChtbJm`SrJb&W6YS4sUSR+F>t zplqGgy*)K%Hr7a9lPxYy6nEqtMv14mxPPam{Y{7m}gwDjo^GDLXK*(jC zhuU4dVz|g)n}E_>dd?&^6h1PvdwL{B(~_s3zsoicXI(BCa1pE$_>r_;2ywail5JNZHgi+W1E7}{gQZfH>dnGT_U=A1?m}8#m zSNsjUB<{%_(=IaMyOFe)PTVd_tsn1zxrOjCq%DYKo@$xrZP7Z6+C(`aLDY(f_A9@y z^l7Nss3I<+9-$r6DXI1dx$+9fH5m{Ftqh8%TpJE)=bmKvX+t7m(&vPSwcWo`(#KGz zm(kj^e41;k&RF-_>3>U=?4|7s3BsphHlE|MX=~RQ=^G-va{H_+p5X>k$gVRwci|{! zNE1LKC`ilDEh#o8OPsKx1{1&i)WdHYt6;{d0MeK=Q)}UprmzB>&m7jpR(;m z^E|Kd$S6=*Eu**ritT*i714zxE1#3!Hl+m(qpDQV7mE%AhGkWQRI71O z9ssiowpl6WEFgAz%P{6drt{WfFm9DpH~P-#s8 z5E!E4khgG~TDPI<(TGu+8zx#^vBgu$wGa7p++4w2uAK$fbKE<93(gb~FXdUdbJdUT z0ygSWo)!3jrs3o6s{~%bq#RewaISnn|6GY_lD!khtj|aw;cTtirHqpCivWH580|?V zt+{Yy`Kjff@{^RdU)A*@6n61r<-VdWI&TjC@5<>Fw@CAUKxSM&am{QWZO_;MQTCZF|#ANMlr5ML3CLaaJ8N0X=qj*SkY@i zVOXQBJb&hHdQq|)Bc(|gRO{&l91#%Gnw>wt+nuU7%eISkmj^IOU68k$*}PYh3rj#Z zThdq%JIpY=8OHUfQZ7mGe^2_kvNm?KGiOw}^3uK?!FuD-%rW6R%bOv6-+_6#flPeI z;^O0L>RjR*_QpL<4X%;g>W~3C=>p(ivx|HlKHu?WPEOpA11X#x4o* z%dS|DMh>oBn>AzJVN)EJz!_QjGlyG61P>um_{E7LxiFLyauBtk(&kSknhP4GW|~JL7iZRZwk{=MP-#4xpa6Tsy0<{hEv)QI!=U}tzM#}eEh}tuH%`M-^nE|3%`1WBr%Jfuo6cOXSi2`PN`S{a}!RB z?!N5U##~Efcx0WG6`uhm4zbPDc0uuc;~_QR%ElPtR9nP8(Y8bu$_quxT&-JX#88WP z$fEuC`X?Edf^GL3(Xx4O=VjxaW=97Y<$U{(a+iZwe1U^6*E&o+8X@63@LpcQ7+3e| z6QOYz2tp0uG7r|;vxK#&s#^rVvF&DYRi{y2MNqfo`G6J{$`1-!2;S8CO)dFVx3Im` zRI%8=nH8QMZ~<;w2q(yr!A9w5Whj%&c&}Q4@~T@(0^6;ervu&v?Uhn*@9)MX|0zy^ zie^2n9>r5D45}vtaUu7|xxzd@Z5dayvNe^c2e6Mh?sqx?%(=T!%c)w%DyBHL{Bda+2065s8-#8Ip%~Qyo?*(of zGq&O&-LI=3f7}R=d(mH%H@!Ss$cFh2J*o?U1H+y#ywf*42&9TC`zF9CCTo&5VCx!` z4Xc3r%=%h6)j2-V-dD-J#hWvT&{~|ov~s93Mi>bzQ$7?j1?AtE8w!;HpGiK92q^4) zgRxzp=TR$9P`hQeMZTG1kwgzNXsG*+0^Jej4qsja8h&G`=X{e~ANDG?Z zTG__oN!jvO03=m86?_{p)m8#I{HH7F{JQUr+LI6a*Cmi^LYs+InTE)5-JuG8aZj&` zc5;r)j=f2-go&;{Kbg%(l^^e|B0>42M!oj*$ds9SBTM*5ZojvysS>7*!_d-wt$f%W zg$QXV3t?E?+k&0o>1l4Jl*mbI%LTGW&MuyTuxWSBA!3H@D9yp!fe@LL%A@%`Ze9T!y&>BzpY@9p zO2fG@5z79iKZ7)Ni0#c+AX-PwNPbTsyXX4s3Sn&{8EeoF`*El3w_QYEUMS1h0X$Gj z2obYiY&!rNxz^My4auQVChRq z3Xam8kj?tBkQxl(FCF1^)7dwWkbUNo^6kr@@b*nXm;*RRoMLM<;ISSViw}l0%wII^ zShpUixgRlj#&M?V;}}w;2QR*V64At9x49S&F1I>zFSE6>Y`bM`?z*M;A&PF{6{!t9 za>QW%cF5krVKzc%P|T`>#qRY}7yq*4IjR}L>~2r?7Ii?RGIRB7EGqzlV^7WCh}m4T zu8&&!0sq_$<~Vq3nv<*}hAYaw@S}K#!f(2{3=ORY%m&2?;=q>cA}jcoj>Q86-8 zm2Abs?6u!{zb^jjr|&H(oR+_=s`~V_XYPZldz^N+|92nks!ECL0W4=R=Zn_*cI2`|-T=T+9UHYzJTYVjUBDWDVL_Pn27Ca*C;Xl% zP83#3JTm$S7{X(1`j^c+9hRrG3}MDE@7tRSQ*hmiNTzY&1pL~ou{Zn*%^Yh@Nl$D2 z-Gh#a04+SDLlyeB#E&ch@xsUxErxe<>HS?}uobLINGcag{(%&}?7Ou+4Qf8;n-&oV znZCTBRYALp86ou)RyVoPC#@`IK{4W5&RB}nr9{Dcg>R%@!qj*d>>%9=m7siwN<;E^dLg_9m8l{#I&0Fb0t zc-v|0Zdcy*Khy=#0#_q}c1krbtGLhIv~D^XY6!=ncJFGQ^WLf*EqXb&RCiIv{!SmJ z61g2%dbdCp0}M)im4~&{Di3aUj%b*$nQy5BsFtA9%idJpY#wJA?#m9;0;mW1L-T;mpDHZ*eOslxNoE<1?HOoHY!_$QzW| zD{Ts_f-gz;H-mEaLh-q>yk^M*s*T9ntu&h374so4=WoB6;U9?!MZ^Z+nM|!BrgP(5 ze_FkKQ$E0_x>RoHtrMR za#PlpN-A7EIqtt@sl zK_(blYE==>UNjo>O+l}joGlf+NLFQzA*>yT1NE-r$ma-?Y(fZFO{w?2=ARH>P1evV zmsl~$u<&!nu94E$%k&K=%a_EX82ysIC<6Gxm-P_kxSc&HJ7=09Xx1EGa#fxsOL?Xn zW0U--PGJ%{*ZOrDV@+*y4iR|!NfYeer8}k==c&{0XYwSuv-!tR88Q(|-mE)^xc}8s z0@03>{rr-BQh=+&c>XRq;w4Y+^`4N&O%rbR*_56|YNS^4@!ew52L%2%F;9qwPf&Yh z*W3GucQZ`4xc!fZX?CkuQi6iJ;=W2Y=ajN}a}?i`zd`t)G;ni!hQ?Z(`bK$#qWFDxgPMGwO!pc5(-$L1KsoE}AjbA2z&96&h8idWu`GMA<_)^>Ee&K({0Qq%|Gn@{xy8-wEYInXc)?E-SYEglp|Pr6drf>CdhKJ>!t_Y9`FQc1 zzXoyVI~GYSGaUjWVn0P8eT4}`_t-C3Ba6-D5CAXut(K5GK3RUE?2yGQJ6lhYt5V`{ zny+j!tF}C0UCLwbsTMTQSf<%5yx=)#tYmhp2AP+oHQOk|LznKzmnd<5^Q}7gm^Y+i z&OD3rNtw_7t*i0Od!qR_vI&mg?F7<4&r6h=-3h-iebDwW_J;AW*SSpzvB%E-!S#r3 z6N$(#?UPzO)Hi!rAdsQ;EGW^;E)qsD#Yoh-i)CgX0VqD1Pw#mBJ=7hdtO zr-DDME~e}(m@^m3tPt-z+W3_mxih!JMC9~7lT~-yyeNIRDFcx68Sm!tAE*RPENie0 zR#kIqK=OTc`LyUc`B=qNc)YeF`SwD~ zG(WI6jIv|2+xED9n7)jUP>=arx-{>9C<^~sfrt3FIKQ%B(1I6~ zrrmx0=G7+!y&7+Q6MLKC$+NFVEJ>LJifDy#%n5A}9a1DJEvmVH;HKnQ()X>4#_5HR z;q?VA9w?@{EBtKWU_~AjPxRLVCYtb*a7<7Nx-_P2M#1?(XyH0=R16tpdrW(mR%h2& zrnIpEqbp9_aH6h;d)%5?R7^Gx*oC--aj`2oo2lB$BdZFuR=zSj8s<`tCW1MYFG`U4 z5dB+Hd!4bU1vVMw=5Wzw>Vq%fL7lZW+gtT4Sz+~FZuKlj%DK<@x?{(Vb@XW8S=1-e zKNYmVi00|Hs$5ps&Ke;!rZ$!1vo>}@O4p8I*`3iUH2G{EDZYc{_vxO^_B^FW=%N8b z3zLD6?h%ncs|Iu8P=b6QUP$VyOtVM*kMOLPrMHK?6B|Dl6TSP=v68=dFC;l;w~_L& z9vloJ&vvGnT*XoIe`52m&K(bj#Ks&H$6-@Rtd+rl6yvKgcOm_h1mijs_EkEf}DKyZc2YE?kZEn%%pW2*DU0YN2qh9fk zAxkAbD`LaVPqO?A=OWr{Inj&==e41T6iuNfvg*wGM2YY7f>xtFPCHz3dNxV15sT9vP(d+e=DD++Us!W{P)tA_Ui+T;=j2wugp)SeGq65bFyS{g4PKW}8U#58n&N>e?#oKfrHORjARga-;ld)DndsbhvuyC86mv zDV>h-b0zjmQ(p^}d?6JT>CaZRi_~y|oz|+h(yu?u+?b7I_Ol7`yYE=BirCZO_D1Tj zuF5|MOYC@8QTdY0Yfx(|n`2C~O(xi?Mmy0IHrXeA7*+w-aP1Frs90zpu5C!;SG|l$ zB*n|0KjN<=(O(LY+Rba8+qq95)q zNhZjkE$Y$6%!X_H5bf3HU7BXT%K;a`L!97hKD(H=j`Ck}Q^mhKs6^Tm$E z1+Ty!tM#`Z^bLqZkh#|Ie{!uZb>g%bCqMG*Nfzs7YS7>b`G^cb9Ae31D@B z9S0hCvObIXk?8zKYN`2%|42{3-;;g|aJWua^=XL7PF&e5q*~SC{YamICcV{Qbb)uY za)RKREsXEa$}QT(+8bRBk~{aX1Wj}GwAeH2bnnO`vY$SAKI} zww)BUfL-Fdd%OxvfI+~P0siP{PyEMT@IchtZ>x-Z5pCYHV7tWGJg?sTte3hKAmfdx zvnR;qHm=G)je8wvxDgJKw9GV$(6r~PCJSjJgN(WUUOGa@{nXR34f*E(s38q}8W_+t!Ps^33oo*(Fc74%fi5WR5!!$z5 z#tlV_L09eJ<9;MIDN)tO^^Y|)Km15y=Xt8ExJ#ph-NL3@g=+(Q$11m?rrs&2&y{n9 zl~fYiaZVx+%NlmY0#tqbLl7p40iLVeY_RkU5;{_;ug(h`;!dHqB2S1FYI_G8jVZ0c zku-r<*f^4IMGUSdwW{0k?(OjyRQEz%KPWXKSU2-2HuhExf?930W_^y%3=h-tRY`~7 zuVpKwJb;?sapZ9=N`5#~sc9%kICj-AP)2)x`qUM#bO+F+jr_0@H=|`3H~SRDGXIvf zx&m#^%JU8Sk93#OZP?+AP>r524zE$cVQm{6ugjafFLk81c+vxVMt7C%(T7x_-to)m zV8FgN5wcF&0#ljOp8|_|igd9Cg-aFJ%9IIQ3T#FAD+ZQXgJ;rDxec;0=f2J(^I7HU z`>Ttuwn%DzIoB?e8kSvFopU2cV;*m(7{xlsc@cX z{=J(;d!_9X;0zOg*1&>6sr%EV7~cpduPo824w!Oa{`+yKaa5qm^T0zRK1rjgr!#SW zZ{RNO%wK$U#t0lUR-N_p`OS)Yw7Y)=A$V!O-j^Q)F-k32ii2C$(mbyN|M&o9iMCtH<|9g&Om?tCR+gI> zVQXie>}`tEILuN{hxC=!p{>T80o%oPwj`6fnMB7fZ+?p=u-YgtpkX;@Ss_y;!YJhgV+2~xNVTG58gz1Gd*;N+C{YvnZ2poFGyWHPXcVf8PVpJwxQp2(k z^#i&^oF+FzXY}x!*H{4 ze-G5rSZ*|5!y&A85Zs?Xwe83uOR?cnRo_I6s$U^YiL>#MKb+*-eWhVgr9z^M$wY_L zXEL!eUeAI&UJGS>HmTC1Gae>eHJ!Q1l>nis3DY}!PK|c8pVNY7XGLI?}2@436-w=O%Ce4 zi)e&&e5jybzF3+9#uG1t6sOGC1zV80!KW(OaDA)8N^~?jKJB+8pkA_q^SOVvl*~^E zuG6oCBCcK{#QiTnN3dus1Ek8SrG$*War$QsO>|l=I#$U`!T>G za?9*_i`XvLdHM2fjQH%f%W`1xNYwR4K2yTIByq8SBW|n9=_I>)o>bBY8bBFa&w*BE zlUF=66a_QnKhL&i$ZBSx%ShiYmU4>Xtl;Q$)8TcG--`FLw-B@Q+UujWpes@E&6_IR ziI-Q-U!gq{iuO^R|7=p(ZIo;=ps9Xv&!b1k2B<0rW58+6wT;A3dp$l`eEQ$<|K}^` z>>oSjdwxBGL#sI}wBJNNdXoJ7HAD=6%ay5;^gKf)K>6$0w?}&=Dm?Lav-_S75fqj7 z+b{C*H#<&z$nSTrmrcLp;luL` zA|ejSSFc=MO9~^P{|+eK@hJt9&AC#{K)pNHvvb6)Buct^{u83Qs9*t|lZ;yjX=rNEI`1hH< zdP0Xdb3r2j-SGW^QPLqgqG`D>{*pRU1Lz1su34m2ZZP{nDk1WZBK@ElC(Il4|9**# zg2K2U&nZ-C$9eT>Kd40hqPS`b$Ke9w&gx9Jcu?>L5u|LfD|La1w4k2#p6>}PNWeFYEA(&+`7w*x^f>ANo6hOgCvE!)TfPQV zL%Qlr)MR8|kedetA*+mFUD3$LKd#67CiDogXiBdCCRh#2b;enx-n@>o=WL`Oc~@Y} zuSt~<-6v90Ech1*MGlPBSGTEM7RcQ*$Dvu_$UiCygm3>s`}y1NjPqVa2I3sp*R0WH zk)M_ghQ-X@y_vib+Y&Zs?J*I?-9AOyNP-7(^_<`Npg5WPy{WcCmyqjYe-b+MzoC{; z`>yy~<&#hf$z74vjs{xv92&<%Zk-77%nQEbYf9X52kR-&kP{l$lfN`$p9w8ZQE6EOl9{D() zM9TGc?R%Ime~h8#FyUWfc7i5@YXrqcOz#_~{}SYE1OXT19VPY_l%)So5# zv(d6GxUUj(;zP%yVjJgW9-q_yc^1UOxB}n!1~J)X^|oCz2#8Tr(^ht92WgY+JQaz zP^n(3QcWPT^IyY_g6Wx!Vu0u{L-z$4qahcxSn0p6MBxb>$|_oJ!~cbn6{ht4%QE)F zBx&S(AAgOW@~!<~6hFp%;8MkXmH#*A@^m5SkI`D#!?G0r47s1S-3m?vMIvdW`Tx>b zWftcE>D9yKm7zm{Yli4uDtd0_XJCcuuQ-w3{LT12BkBbmaCt7bB*CYY32AQ9EwO^@ z@N`~{y?fw#yF(xIHvylKDNlkw^*?s>#itNQT^sruf`w0N1juY7@!PZqo>f?nUJD@tifP6k0AzUl zn`l3R3O#o5eo(p+<)#$=s1hF+5GVIh)Hl=I|$=7?3Kv)R2s|LkB{_)Mv_8I-%)>6kp$xZt6@6~+{nYV zrOn1atr>PB!Jv};2aMvMKc28^pwMV*)fkjN?Rg^gx9Mi3LaJD=J{IrxTa-;GIw{Rl z)Xb1EEEFzo2=>2L>zoud_u3*jo?@gxTyL;Q-(ax zzX*LCtU*^;yWPRWmxgh&BXId!rHAfZ{|laCHkD$uFy2U|0o_ux~aN8|i2 z3e$~93lz}ecvoae@ED{$!D_qyM6K`25q_T-vV0~roSZTW6||Neb{d3m|w_UqVsipwV%ITVQT%s;lRMjd*E?I zM64-X{?q%zIR`Nd!?Q&|+|FIn{CBMy@1BOcbbk}yS#CyA2%8=BmNV#r$Jx{LA=9fj z;IRLfeQb5j1a595){&jBXoCO^Yqq?Up>F{G!57n*_e5hCQ`C9Fq5w(9I7qKuKXaE=&=JFtiPn%w_v ziblhw2F^Vrw!<-(CC?llVKe6QW;)>6aLB(f!}}NOAk>!h3z`70y=Z)~1&bLQ;C%HF zXwG0mxtCfy^uU?Mmqn}oXQnC?pkBmrXp*g?aG{Ez8Szh2A0IQJ_7S+V>+(87@(cqQ z2l$JSYK9F(dmd63%npKUuf41b0KbjAG=q{oBo-!I?R5Xb$6d|vqBd9*|8qSX z@E7nRaX$MAph&T%$_)RW-D&C&O@Mt9JN-11HkNh+OZDop2xlbtYv-0Is{Zw7Bl772 zqR93GP^kqyaYb;^RhimN2RaS%$_@1_e`QAYQ`qexD1yq{;m-Gk~@ z>Q$o-t(kdf)|Y?w3D47jfC(#Nxa9ZmXuqWN z7F=uZF%2P6a7%b2{%g^G>C7AOgxIhR`qW5pRfJ?&zLxv zL_WBv!4eb&3N7*G$ykbi)eig>wL{gVmEzJ~Vga{m&9lwwUm^XP$S63Gv3MLc-7|1K zNwSl|e{thCCPbV)qzjZjVY<%A$O828n~y0D<39J8nV4oufI7_n68JnhceB^18P5RL^MjlOjvHJe55*ei1Pg@{43tB*en|?k!k5hND4P_#h4u*9eN% zqxe^~!38`J?P1?Q_yzzUzF{szTpbBhQMk8RUPYnVG%M`c{a&T!?Qn1yoVLcP_k>%Z z60!X&BU&dA)f*bXOVlkQHSY2dWv{U)Ts|4jdS3Ov^cX^!F8r*9XjC)O{fOMCE0EkH zOP}2JuE|QH!O(UtOzz7R&{XBEKjsJbB>N{l3|689pKs}+{2$Gr5q*|NWaT1j;x&%qZ(5$+Mc?x{zLa2rXpqA5x=nih9t(DZMO9Vy~gZp)-u|U9LOc>4n5miU{?S z&r*p5E;1V_bZ9RqzdZ?!xn=df1W8(E(DtIlrBZ;62fhg8IdanemCL)Y)RM0{##HGl zH%5ZrOT6W4+c3IcHj;FrD|XvgOG^tVcO5RNIAdMCRV*^g7)U{JNr-2on?OW^4jbop zwgXQ6fR9Yndsj$td!EV^6A%flPHlldI{j|waO~aqVx;qfKqpRZEPQ*w92qkA#)&-q z;|3UyIl*g>a)_^MIL4Pd-~u31cN=RiY3PCKfn(xgYq7|`rbjb57vw7%Zb_!AWGD2=9igC@`?3Uf5I7HCW z=lSaAn9R(~MK6AAGxmgb@8p1s?^Msp5RG>l)aQACm!9zY5%}aQZ-yuei$(4T9`cTV zd8Zjh90^?^B8$6EY(8>#5ZeykHzj#{l9I#U?03nY=l>HUd3*f)(SHDb!3vBgIGD5! zr5v-TJ=bbB5b=1L*l%tVIGaD85xgdtPAwJ(&xqOB0AFy}JP6>SGp3snCfj4;gh{6d zehxc@`A&zFkMFHJtdBW=?W%u$ZfD=h!(O*bNhweU4sbzJ(Hp9Fx0}MafrpukJz7zK z(?G{FmVXWm)=IE#Lyr3>Jof!Uk4xEJpJDHfW`JgTdRo_KX)J)1&Y1qLB-uKTOoz`g zGnmgvh5ztf>7O>#w+ZNbk1eV|;ZLLoE8DLu8JxAB&j`}Nr^jWGv~;nE*-r4%z4Kds z$gP9!k}_fGKZkDtM!C6r=Q?L-?71CR{x;h4t@WRbcfZ|7*7@0rcTqiSf!eRY%`NX` zhr$A9UoqIPsKpm6o^Q2Dc)!r{%G_pGF%e28V)rX{j5pnnwp)_;%nypEus>WBQ2S9@ ziMCm&TCu>1RGzd>zwuSk6F(M)bK;8-1U!R>(_RwNqCCM4&=I9NQ561b^xiZIZ%R8usWsAj_bf%04WCD>p+oP)GIfVrMtCA zNgG_`OV&i_4wmrWq%tOf5a=IkodY@ZoVl-Cf7(ny}d9@~UDUgDuF zqeBH&o7p4ME}k0<^-M`09_E4dBo~;&0w0036#MzzGZx8Eb^I9&o(dQbX|Y1zPC6hN zhN-zILvQVR6eyYyp_)MFaQ*Z9f9J_SQ4cl@Z8@0J2{X(eaB*7NR8Uo|)vvRe&$%M5 zv$3fcF4fl7MyPmc|ERIC(MYZ}U`FEstuKEY96X{E>{LAMRP5SO$gtO0s3a4hP~3?hM@m-CYs#td?Sb zitPYOh1sLG$!hpI?MABcDfJ5VO^DUK>{piO@jX;=!}A2v)y;~p2`bT{Yn>c9YyM04 z_7^cj@rkH<@T%Pp1v_u|%5oEAH%W~RSk?2*1LzLAYZrz8?#F`EA$0h+4n3{yDZIZu zXcN48!5jpz{bTka9_m1dj`#=kA?$;x0e9Q@Yv&tk!)$GB5yi#DDJ_J^$Vfp!5Rkw$ z5f2L!lWMp5E?;1^JE-~`>aWhr%lq8oZ(GgteR}!5#7MDa0l7>38fd{JBfvhqZYZeP zK~3L+d)+L2xEOXZe|Iy_BcnTdywt|Pv9OAy@7I-Yl7vbR?DZe&x31zS^G49F)yIS` z_90bZyZeFa9eWXAnUC_=(&5I9G(GqDyKzS%K+H92_Dah?rhVPVz0C3;SK8^7{;Ykc zwi?e4$@_kR+5S6HNcN54sB`7v=UUA%Nk*L7iC_V+_AS)9Xex919jon0DdR=A)0f%&DDvGjCwi=BYOI!?sP) zXHSTNV2Yrq_PLR1=@V0cB-!<9Gn>ced%b^G;E%R??_Jj6a+P}2tW#;i#pHv{=+qY| zYdD5-rw|@oAvJszg~I#C4Y4CaLrJO=ov^!diKPi+Z|b3)&9363FeQJM)Gr5q7WIWU z?=6}E?0c|W$;*G8vUkSRM`&pY@$VO^bw1AZyIctSfmkio%)Sx$^4ux2rg5{p0kR+v zaq9epS?f?QxJ-9)==4hX6RMJ9_Q=}J4i)ePAE-Jq*dqJJmGVov*S9pvEYpXqmWfO( zECQentH^Fk2m{Ssm{8ZLXFL|uO5{Nzk%+#1Zz`Bgmc9KlDK5IOo;r^_c0WM(1W~(u zK@qu*fX=S;4R0B(4tzJ}Zv$UP^1B19cf%cD`Nj+NMEZ2bn zGl%ut%h;hgE4O>9m!DZKVYtuSZ40@ibc45s*k)ctaQj+ev;X(@Ic$g zOO|q7^H$645Nb1wt_M%6w7%2`K8DM|$G(9YzwRE~@;Bu0?)xtylgPxxL^r){1d71A za1*2joyY7E)7{e0aE}U?vbUuA8=r4C9pmVVgn^n-mweWG?}eU`+K|Vt3+WPQ*;1Cq z#h(DWAYeupsut9&Rr8&-DqpeNxXs^IlaQDQ=#$4bKhdm3v01SPSnRr_Z{%MeKLbBc z3U@;#p{1qXpqZLcyq^=8o8fcU*y#VnpER&+x17>_$5muyajzvj>uliIpV``Gqq z!Sm@`W?7d4H~)Oo@NRtic);2x%L2c8QYv_a6I;DVa=RX&ty&l4Pn^zx6=g zx%vY^Z0wY&`(D5INXwF+4O0t0A3W9D_Tm!_*`nrrd*KN;uNN2b>DyFnfZzM2Ksq+S zqC(>X?H`IAsvRD)fiF5R4g3tA6B3cA1Z(uzWnyMtHAbRYt?J=$=we*3OWAQA{5H5? zc;Yk>{dAq>^MDJMr@M|&w;uX64_P0i!{4Txx7g}cy917ZDo4(O*1VOEma*PAZ@|*5 zP4X3kqM85mblOf2=?S~)5SlnBs=D?)_%1=^HBN2(tT)|phdNsjlONaTx^XJVQC2+x z!m#sc7R)p!eMW_2O0lD7T8p1hD-7T9-uW5`b^qAPIx)(>g{(bq^KgevOY#GM<7@v6)ug+d~|NO7l%xNIktsN8vCJFITE^IqX#PGtyb9M z8d9kU4H8D#W#=Ffqzc2sJ-U+rc)bTYU@Q~I(ED# z7@B!M01!*tdkF?9XG7BfcTZ2;*Xx`IC5>z4L3Q;6P;C(c$B&OFB}_|wz(SjkP3IY^ zT*#PpqO1+8G@DKYlXEehr>;{X>vl#`5rK1W;G2N^ZDon_qofkEywDFDWDMxTjuLm{ zEel$)-?Eqk^wyds$C_hQ#F{VI$mD}Kx*kBWPFU@2Si_E&`w*S;A=MJ!Ryuj4sV(6V~A+^Q5AHhp&|+H|~u zOs+r`L)8ylI|+TzDsh!r=ner8TQ>KD1TP+}@4j8>&IuNbM4UgT=s|H|fw3 zAC1wSp`)wG5p2=Is0d<(NhFc=o+**}FWV0nk_CB?4yX2dE_&I+^3YL>h@c2N5Bzw# zzPm`g!Wlk`Pf|)qd||$Vw(_VOfCaCXp(Ih0dG?#WaJZ#c3s*%b=#YtDiD6Q=$QvML zkkK<(GfA$n7k(wkavP)v%4hM$8q?$G=;(Se5xVyh7Yjk=AP#qQc3ya?lmKJ>;jnV0 zb;xRFfUJ~Y#3v{S&A#zjwH=Kou2L!EinSK_t<*x=<)I@`PU#dMH+^aOf)dO5a(RVY7yo1LT@@|7QI0M%mp+LA@?U zVjH8(`LJu5k5!*xQ%_t$y%Zv^*7j=wo?{5uepBBcJ?`l~G}T6+)07&)L{$5U%^PqD z*{)xe1Bnyp);Oj9OnC2n6q|+mb}BWWG!-YK-b=SUpSXS_Z@~TeA!MF?QAk7Lu#Xp% ztSsG1m7Oz;%eOOQG9i;4%*O$diqt9j;Tdi^Zy)nV5>Lb2fOV6Odr7|^P&dQarx z75`^~urh4ugU;>TZ>2oeV9CcrC~&4=a46^F>QerRFcS)RPOV$xZWg~lII%*@bXS8k z-$VEObH{Q@h4I&$K=fwaY62q6d2U zc7Nxr)p70Ke9&j|kYMYjkpiLGn!Rm9{iz(~_(vT9HsZu{^}|q`7IS#al+l<^-iOO~ z5>ab2@auQOwgnSejG?-4^0T>bdnsc?lU2L0DMg;NK$e=R{U;nHRhu%5{sQ+JL|K|+ z1SyHzNd7**QxrgkE*<~wHob0(wBkWp=z$dM8hO9c#7Ru}01eyb*Z~=S63VL=>Wj7> zihl#U;}OP~n6R`)JB956>ji37IXwP`BF=MwW|(Blsjd1wbXzGw5M+SQXEWDBuw;|% zpgL~t-aZhdH~XkFJ(`QQ2d$7>Bz!j#hD4j+b}JhY2wTXEu{o#8Ur;t^d0!p>U5>;Q zUM@v=nMb94GzFygX%`d}25Se6t_6?zqC?T3S-MDyR;-=;ES?eYP?y;mN{W8y(3z~= ztwtDKj`4U^_24-n;b^-yX}SoNFldK&AVHvE&jbPf(2ggr>ya#al1x_GWh^O zdhAwQR-1iKsy-?3>gHvysDg zZP!)~<`W*MzY zTCagtzjQkwua@yXoqnC}NBZW5K3Eh8Ti4+$N$OI)FzhG4MX9XydlVV1pgcAQ zrh?ZTV5I)$?YyZ`2pOv?GEsrDc}7ZCIcBJO03GV;gEk#Ga^wZ7Zm&aXj(au2WyBu>iwQI!7^?zngXok;D??<$@fGG2fD#;cBsqQ!*}Z40&be@$Dm?V zM4O+Grk?ns;YJT;vE?e|0>*n|rKpy(z&NXAp}>q4qdE#bE{*1TAl{&p>d`^on2K9@ zF18v|zUf9SI23q%A2m;IUaJme4<=X5C(UCkw#{5;6>EhnB%_EGS*FN6zQ3V4t4Bh1 zf4tE@A-V`|%FitT4u=Da_tcA3?^h}cd{@z$*y5<~awl$-xpul$IAUw|V!zw$9r7V~ z7#SdjX1iv;GIzXRKfNk#zWK(c`2-yS8|YuM8ou5OtGQu^S&>y%t~DYAI)g`H-OCk@tMw1MGS)=J4?w0ekiA88LZrIgt3EA1{; ztlPcj>KP!RT=WdaTDh2PC2zhXqq>i|@{$Tlu>I+WDS_RFn>-ks%y9}yJ|VX}vvOW6 zn;P6~guz$M_x}*O!~Z;E2Xhng8X|9|*d{Q{Hl%&Vl^Fb~E~SYWrmzP~F{$S4bI(m4 z!&L^pExV0wY803msO`*LPt@c18&v+4$q0I-;X+k=0b_C-$w=at`BdU!@dU0G*m%rU)RBxXuVHUG((%=XPFtF{Ezl&=zw-nj`GRfSlxd- z4LN(&qZ)Ougz|3RFD$u#Ds*yG+4biR#$h}6#>dE4zzs9L^UVJVu=i*kn$}O}ufH^- zB!xFY)%JD$2&*Wk90C6u54b6zTQZ5P&8T$AX1G8B%f0{kIwT#eGfL;mI71P!SK_ep zTBhqR25zfc-aXIiJa_B%1;R4z+Fg&ZPOI}W^oP!xy8-nM$!VPWxzK-(JcRBRh*Mb+ zTl^7U3DC!C`&)f0Q+Y^R4>($77-|#pJapy2Bob{ z-xZApANt!s@d|?mb2n?@A6*hV2@=$_GZUcOJ_CR>T_jdmW{tbp`8wy*%@8mUXWWJT zU1V@;^Mv{;I6$wYh7c8~^3qL{yjJp97?C0Qx20=cM+vVZGh-I6bibA>CvoTVG+j1D<{3$?Q=))nL&) z>ZTjy&5v+!dz;}$e$ZX{Pa()h>Sld7B9_dZd{{=a7is{e^FEF-KFJnqSdW@VT)}vx zD11&J3ru0-gTV~XTg9`Fb37+tx~bweNB*i^PYgBD1=w`kxbhga_?6HV7m8QCYNP7J z@r`gc%+-azg^^h5#h=Jj6mfmSgIpnz@X@VocW-=uZQWwix#ZJ#@NO=i*~)Bb-G)S$ zc$*EwrK7%EY@M{#5=e7nH-FYTqpm)I5=;kvjdnWiD2_=! z!(vHe{s?vE=eqUg(8Jf)Mk(i)A#A9px8$n_9G8vy2J)65#wVz!6q$2qft+vbJryjF z$8an=RkxFV!>P9|ac9$1a&XTfY`~l&Sl@Ur< zK=9OQj;T?R61RXydgy{KJoY+GHHl?|Pe0?+V%&V;UavzuV07SVGvH#nd>7|xlO4oA zg}ON~VD)C;iKYpCoB$1E!^1rl%y?C<8mv?;71YQ7{Ij1Ly>wfV69l`n7@JI~t(OB} zm;LZFbK-906jxF^dTB4lM7YOk)Y%x7pO!$>$3cWB$}gPZC(ie+;?9USd?i)gs(0#b zCX>J2FK9^^_ruJfc+~J$H-Y)_KqtO@5f#Xx=J(e8Bh=VS;U(z0W6D3EJRiB-hbe3= zTT;B|rGB24+S-&G$54Oprmdt3Z*h%6$k_%X-K8$0r1i3o0#DWCq-{9Ub`Ws!vv|Hr z9HMGH-jTg>0qTnNV*v69L!uVj6`_?J8c!)D-cxBvr{4Y;2F0=30dIVT zPg4I;bQ@}kJ3~^9chppyN<~>}`lR(^S2n5E{4k-3z204_ZYO7Q6_bd*j}REo>n$?d z$`}gx(-3E^QR~H;djAXbCqyxs5rtbm*v`NiEJlRaXY~WZ3;?eYSRM(I5r;^N{}5I6S1^s+>^tnoI~UAMmKpts#&RR+ZghwDQQu za_MXy`a@!2re0D_-ELe9QeoDC6YC=GzPJYv<8cGGPn!&NL}rH?iVLA$*WMYBmTztM zEkT>1W!jTfwv@TJc`356pfbZd=R^GtefqX?D=Fr}hl1(DZ1my334@06&DB&h8C4m^ zYGd|gyvmYrP<<1u|B;kTwdamfMx~z-`L%0N@&Lfop|TOGQP6HCEM*a48_m-}amXD7N3Y|w{%|h3 zlM@UZ^O`cT?=NfYulXvjdBZk`D2uC?f;CmN`AstJJ+@q!{NG#j0IUz!2G-LnlmaX`^=%Bb9!Y}}7Wv#-5 z;H7mdMqCf>)$?0Zj&6hRAgN0&_HS8(kAa~_v$+6a{6(%9pM&Lm~&&vP!u^Z?&m zpvZS&(6O|i^g+9C3=GvGtwX*!TwyO|*hX*c^)FER%?k<`)-RP)L8f5X_t(sYb_|?N zWeDCcnSYV0cQw>L4-NCTVE)rTzJt%bDCflbz&7S;UKFlfLMxxU+Qph3aw%qX!*rXi z?MXmIJMf}~Aa3}&$C*BsXxf4@Y`)Pns-?AXEzsxk$N1ZdfMY$E*pqSG*^l2M%8ON|c{GyEiBtq%U2OQx38|U_`r?YUsMR`KZ_SS|w zN0q{)eMeEEp7{v~f;3c^<;5SDw0)R_ir-v4k}0)V9=i8&z;z7lSUS3aBVUeEXkPg6 zLxFDa%I&Sf;?0hEGWdK&#Y^xRyS|rq2UOhbuUD-UW4fg0g^Zg~$GPknG6}w$k4rvp zH(+y=!#!vAc2NRcw2U|SA?i&q-)0=V6Qd+VRQE z?z~&;LzeE{#Hjn5_7CrPcZM|Y_pNyQyJ_EAgB7h+_?Jh0j z+;T925L2TV`ds1`9T8JXr`8IY0Yu8x13V_OBOB?oc?I9x(Okr~6Ub@VxZs^6Ms0V)L3{hKR1{ z*4Og>2~Y0G`BQdrui*OV)pV!o(X%%*1xh$4*1EnrY~@K#(SR8&R5B4_Y>1fGyR!SGm}e?+Jqy^|X#1j?8(nRZezJOb zH6)AEr|ZaAmAM>zJMR%55g<}kA{!NBNXng^N#G_!$E0a;i9FYNRpnkVjqg9W|Cfv||hWm24 z&IH4rE#we*WD0fWf@wpDz`0Q5+`awuk!g>wxQ>a7D_}v%YV09gz`zIt%}Qna(O-OD z;~xE|I||O*0UP!vEfn|9)8#_G&;i_q@KIa)ot5a+FUk;2t-RStx^89sm(vwv)yC>G zhJ2es`tqSkbhSv&?_sd=bOI7B5S%i1l>MkoOQkUNL&q?2wE+3v- z$6s`Sbs^Gn0M+UK>(>QGcN={px}A1B;3|IKrdqxw)1^~se9-natQwczWZ763$qq?W z12P(d7Qo|aLR)RfiL&}-PglD}J&ydUxPnC+lG2^E0%-sEEu62yk8dL8Y&M^zLT*kP zS&$|T!@bw5^LTCCR6@F&t(pDN1%Athy%i@O%#T3V9zZ??zAJjxS+RcsP4e9UV``0r z7C8R92OpB_3;19gaejY(*T4GWoSU92^8`kmWr)z%7}u!k0V23<^3cQQIvgvZ2ixDS z**w^Pcbl+)G1n5&3&CEWnU9ph^bEZy0k01|&r=N@kCh~8f2*XLNNW>VT;K)yEgP^H z%zfi;4Uk;K7h&K*M|J30;VZa>UX3QkH(P1-^SK=?idaRAtp#8eXe;YpkHS0mHJhws zMu!WxPD9^*wD2GIr+x z+YU|$^Fv5rfaeb4fQ4}Gn4=kt61YT4DHCJE zTl-AP@Mx(~$OQM5ZQvD)Moe%?iPw3zt!8*8meQZ)yJ>5l%6@($6!2j)yqsR1M!)?q85zs0xDa8YG`~v^lVjC1L4dOYwKzeJ^5cIhUZ3#zK=vHn_oo;fNf>aEP_dp`ttn49o7-{;pel2c@rML1x|PeN=-ndp%&f zW<0sR3@HK|@0>|cS>(%737LreAqSa}IHC+_+L}B=r zr%9h%*J1=-GBJ=&F>y4ox|Fqlz-ep^LO*=xgh>V+pGyT&o=1(^p39Zr*4_Gg~?TTEPU8oNd-xWxb4u&26~?@|hbHK7=6ya0JNGUs~L z#{5T*y6x<#i`Ds{kYhro>O3KLg~_Iq-9~fJoo16467&KTxThS6>4&!o^(Pxooc7}H zj-6K84YfIQ+2$%*WKRha;-SwLBX%R;inQ3Eq+w02`04vT?XD52F#dLDS;jq zhVNs-6V*olqSMo+e@wWs3uW%atp692W?`8x%gq4;6taX4WkzC$EHEQX0^9inLTkj7 z!+)45Ud^U#fEPV&CtaMNZ1-488@#obzlo>R8I z<_F(|_otGnZPYQR_r{Kr-;*6cz`l0LZ_87>)hx9D>$AirsuwWBWVoNh0-_H>KR2ZN z+vG_4@s*3(S6JjafBrVupB3T>zn$bQ9tzWtjnGZJd$q&KgejUQo^p@1Vj(ZC%vB-B z%jcNTyN$i`txUxy@V?xN)uPMk z-k&f{*jdN~Hm{fYiAjF1B0u{{@CLEg>1Gsb!KNyv6LgRsEP3xc|Q&^5$je_&ygLq>gn0u^ov$8zPxy4-k&6f(u z0!szh*$7Hq+{^Mdr<}XMS_t}AC*>)@F5~!I#&vGEH$8aeQRK082J7~SN?+ab*$rf2 z5G-_dwlsiF4e+FG7oQ+xoW_pb2*K(Zz2cydYrLE=A9L z5>I8H=>6Yw%NCwE3e*Zo$Lw7&Q@@6xxZD_6t(8~cccr2U0q~V zv_*2-FlKN0mIR}zXf;=;X+w1veUUZq>7j4)qPP$)Z6dz*lOg*qPN&Z|@nZsL-j%5} zNx_s-xM|JGog^Gu8S*pjmxjS6&L-slzU}t%HAI_~#{waHU^<-G!FH2gI=&F&!o4I> zfo6OOm4;vZ%g5TOR2tTqJPz^$U>8QUJ0!PHmq#q~Bj|Q2OZ@uK!t~~6 z@_-iWs@_>(2MI66p0)1Kxqo5DvaV}Vat&hG@!ueV7w1dYQ0E8%`tDL$B`)b(jaTR7<^`76!h8J`AQHAfYI)sS5mJP0+^Mej<1qgA%anW;}0uze@%5a9XX4F z=}P|LpWTcH{JNGnNF1d7daut~H^MoEX!IhkW`;A6mIKutG67X4U6GXzCP>H&MP5?r zcd{eiA)5O86IrLht>ee0x-#?zHRT~^u2{Z#$h+FY7FwrD@0t*V6TCJdBL`_Oe9VcY zjDj%DJ{6$`xKZl#WN8NOh?c8#(N^U#kM_wMz2xf9Cn=MgtSsVSaRbO(XGTlwU5R z@xlRN_r2xJJ37E1xuCzK#{QB;WzF|nLJwglBT-C#yfcdTpTGVdOkQn(TUs7>;py#V z4cdRC7~8(7Wvf=lcaDy1ifxjbrTa5c4g8+LI!Xf~X^3DE@5*>|$uw7^=DdTvHg0|P z+r)=JiD}8a)bdn(g5X$85GM?!%e0RDJ~F z7KdxE_`3If-n-_I&3CZ}XKczfmkn!=BtE?q1*e{ghtwz~ew4Okp3CBqr4H-L z=e@AK#k0#SDoGj98J*<(sf9>V>%KMekw ztq{$5B&80nyJ_rc;g%MT_h-%i#ED>nvxp8Dq~=OF2GLZYYC;JR?Jv=jXqR~# z-wrIr#TumpHyc*O5mEO`93oT9kq7IF!EU%@|xSFho z^VmG~Y#v&Il9XjS5#e)xNk=xar*Qjt)5p4fRc692VM<$iMG)wQzMk;7m!(q#e)AnB zmdjYk{Or(BF4x_sPq(kxZZLkaPc`dtG=>`-aFAj;v@_ram>7nn-;nT(vYMfN0`B#b zp}~8tgxWtui&Qn|OaoUCRGS#UZ~g0CVtc{j zjt((aq#LetOWTe`rzhKqdSW@?-T{V$@#t~beW0|{4ITL;H5wwvgfwGblQPz**<|&g zxCJaDaSlnRjFn*OLm?}#{p8=>ZhUC@A)F_Xod{FNaQ#^=%>!faFfII`Lj7jV41Xlx z9Llro3Qq9s>}S}YV;BE1oDLxQ`UVlIvX1c5lb$j`%qE~&;m|%U)$-W!Wl-J`W8U3- zG4Cg1brr|0U-3g;vw3%|ERML*k0woun$VNShV0n;Tcf<}C#L;R#l(MdZV9#nKCVYX z87nA+q}jq@Mu_MTW$&i4;nEzF_3tfl#}}sADO4rHk5I}VclGEyvN`{KF?Zw+oyB|( zpT<Qf8;58l#5$`a-`u<+IMGGzmI&X*^o&vs*6hc4>pSbiJWU z??`^s+7efjD*`0aGH$0htR~7>9KC_sKl{Ps4SV46jlqgjp(RhZ+L8RplCVfC$! zN5a()XpgkxJ?ykW2zxy1bzOke+Xt&K32gwTBW>U$kKrgW;jA7x$kh`Hl&1{q_$fIG zWoT6?!;gfQrkrgop`x|jie@fS6{d)_hS%NK--OXI&1mCv7uJU>&&d*vjvak#pd$}q z0}Vz^`K8aozhz5tV2YJ7kTp($e z5Rr)EN`l-e`L1005iLZlz+EuFd0aZf8n@XH+XJiUGx<55e2sW@U)h5ioMq86&=h)> z0X71@m(bS)l8+1jLE(OePYtM(QM@d5IL+w(Yf@vF`Vd;Uy6@82 z9c4`XcS@|O zIuqUgPj$le6e0Xdz6^upO~x;=aedL_9s`oZeNy0ST)Sw3sq5xIO3`FPx1~zZ1k|IQ z?FR-)03!#@?G$-vnu{5?sV`{Q0qkFn$#eeE*iUAld28wRDw&KO-_mksCl?zzLY$2& z{5M1RDH z_^~~{ z@+(FR=q5NR-7bxs!|rfsep?b9$b7Pi9rX1vz$wJ2qh!ywW;4 z#mKdje2}@7VJZ{9wWy@|u1k^cwaa9@hF2ku-u0est3{+WOr}X`p726#W@LI_1tmG( z=qhFVJVpC=JMz>;Hd5;9)xESsc6G=Kj7>CTFwj#~!E@hqwYq%vDT3DB#u%~`T=f)* zehwWC?<7t8PtCr2Ne2Yf+O?0A(dMj;a54E4v}nh(MlYUadHYI&qGI+N)q#1z!))jQ zb);Yj+eq@+RojnqroIz`(S_C+e3o@h>m9~0{^@j-hhXi?wEPI*eGVd;@4@WPBhd2~ zOjVlVy5q$&!d?7F?bJ^v>KBFfeR0Uk%(4+wj{;GS^X&u(a%VYtzIL>&XH044K|CsW zxoTOAj}%!iUx>OEOt=@iyW!Ab#L&&s(mGW0Y7-&(%5G)uhwSd^M&qf#L}*P{HwzJO zf0+tW!yL{8NqO~r>}`X*Xg5joq0v2V96gBJZdwM+=BxeIm z0?xz+TENDAM1Hl~fss3^#R$6knL7xFU0g<*iyds4_@Z`bYToZ>Rr_WRzT&q=rKvNs zcBS#ALB~h79e=#dShDL^!YHEWzgLCsvqE2r5i2S=56`&wt=mzjRWivzS`8RT44)NdIC-u|*tyj8&CYz`Y^T<`{wq~$ z9bOXNCSYasJ`w9jQj)(_X)NdSu8|4>fiS^P0;kq1L$)A@0clH(%;iG-RaR3XA;?2@ zdoibInt;vJeZ+v%94^6aH%9izh-u5wsRa?n8dV{|n6k+Gs_G2F9I5Ef{ZADQjYdjl zEkKt6^Wt#dSz?3f$*#8~kA(Bq?SBns64-kuSuX!k7 zc?Sxvs#62!aa}gGrDb{gX6YVgF`=8{9c@u5#y`T}tT5i0O5>}W)x#$Tp6^1=tNT$m zNGC#(atLbKQa%VBSSlQGGv@Ef)kB6~F9+OecF?cof`iW2AV&=n34EVF%i3WSR+nar z%Ul{8QRIgiQs$I8)HhMnEYib#>Ws(bAy)(D2eN2^uxN6YF5g%320)5@vYM>mLR~{9 zQdJk+qM&&&`|xRj<)y))TSha_8Ua>uBNt&Hg!1xWmHGa36C7yl@m;o^^T*Q659i$w zTew-Cui)=^=Cgmd+Q_PPrGGs%@M!EzH!ESLbN-1!{`EI#L8vIPLPDhxrQPVtN#Gz2 z<8CrR@JB5hhpwfH{Dk1XS1jA%ck(99!Xlu3iL5~P_al&pa6TDT(?4cYh~^~2YwK-) zn!B<~w}WdbYgbVK%Qo~tevGbj)F)|hK5u!jaYZe7;z{sY1mSjNhJ}aRLL=(+R`>mN zU=AA=dA2j*G^UDZ-=n3c)+_)dZV>Dt)d>btglmy>2M3Ssc+& zWsqV^@b12s2@svAZR5E;m#<$Tw(V`6??18M(Af^W zy1TW;;ob8_Xb>qTkw5DS8`Zhf%u55&8bWMW85U>B#un3!_tBc=L${>u?*24_Hom)X z-xg$)xi`yhMx%9-T|sbukp?-$rcz_ov|r9PMByhBN2JAh^$i zgnW-b+*d{!RjxlysSphlC%CfwuCMVaJ#oV0bbA$Z-;&nG{kB#_aPu8+xe|bVnX6-E zjJV@Uij#&>@HDA7y#x68iIFc3Qm<9;(~V)gknalXEnnm%&U=ZR37iAVX+NgQPt(uo zA2ZKOwlMAM)+Ww3!Xg2Ut3OJz!Io~LbF*YwzAeOj*S-BCc+m-RrId$_`Im|PS>H?7 zlBEw+#|G@E0Y@pZ5W~Zh+A{SVLhf~hJU)qY8rhf!v?HUY%pNb@p$4wkg~$i|AxLyL zG`PrK?WFa2_rR1@{fJ@Qxo>>xHce1Pb?N7r!7(Vc-}(pz)&34c7^mToHWB7ML)Tx4 zH`MrgrLSAPukfq4C_5M`I5Uc%g`avzS2H;-h|IUh8e{PEUK<@*W%onwa1+Fl^nIdi zbWGB)rE%eqLx}WHI6P9EvARczR_7oqZL)-K=33GTK>@o!w|FyBlsOs76d_ zJiF7f95D%eq9mJiUxFzW}kVx{r1XpWkDltM?2G zX#X|F=f%EB&cOQq5!Xz<~r8_GT=s_fUQ}iUK%$ns6oR=gnZ}fkNdJBgpqxOGz zbPGr~(h|}oF$4hv6qIfxm6Yz$h#(+J$3#$&Zjjn2L5a~F10*Cz4;b5f^E}V@_x=aF z&(68d^{MOVY_2=nr5@%iXCt#9I;k__s+=7vLG-^|7q-w^g!yEil|bZGmYP#JGZOOa z9RHP#1Z#{8FFdV~=I=!l+70O4^d+KJKT=|iC_%3{LE(|+R;{(fm zgFZ(DSZXd+_9$9vNlD2$MF_FKtE)05PJ_DbR&~$c-x4PUabIpO^X?Y8V86Zc)=%2$ zyVTtaIkFxVdB+(I6@4CI+mYH2mS$^-7=xtEfk=EBy-q>&D=5@jW6!@LsPS_sPi-xbHk9?IT<683aUsQ0WdkqML z5Dy4gU{__OK)aX9WOM!p9PEYW)&gy5(qk&vRs8rSko{j$;3KNr@+a>8w|i0}Hn>4S z-urEYZFunKNC&%_;MHixebwXNf*cizOQB~oTdNKt^^K>)KNIpHVccOGKfL(hX@8`U z>+_-2Z%#jFTvCsP8-0%>#W8b2)6-+uc{ltB4fA@;x#m6#`+7RR7+Z8a&BC+UFHgq9 zt#RHz_Tb|T#11v4A8zun`$FeP-W#<|n-;208p(o?p6I~)vY?03=(=yeMQ_R~)JkhR z_J4b(Tf~MC*CzN0yyWCuX;TKGH~``|1vPrZrmxR`Z*sma-+Uubw6=y}Q2_EbVcY*a?(; zKkiN#QBCw$W#&7m!alGW9I6{557Nn$=D4A4L`cEbt&B!kisVzKA5IuUKdr8mJK?dQ z5UE#$pibDIdtapfSAB}Yf8qzBJbZ>NU1i5A1Ery|&X0>?cK6J9wr3VGTAOAZ9-quQxStyP<@KbF_m%5aeT&HuR_1y6&YBjr;XR5)lU#1CLcg1 zdjlCiMi|EpEiDiY*Y3UuOJnI+bfn*1K4~|a5S`_f27OMbBK))*Kt(=i={(a5HWcoj z8On)b!30t(dgyG;7LlDeFPJzqC;hRwJOC?PoG*K7bc7_`{)~!3(&0Js5x{|q;6a<$ z$04>)1MK&G9!;wsIDagW`hUovlb&V;T60_q5Q{KZ|mdX332ZYy3rwF9JF# zjblDTxE!kwW#4@azBb^1FebH;YRwy-cu6ON&4cR_PcdRZ4A=%{RUN#UWRpRPgN9)Z zsYX+Jv`HU6Aq>BTF77!w@YwNx0$g0yp9yhOglj4VRuQ-?7(#7$LGg~<)AEK+MbRNP zP$DRu1R(QuGF3P#Et~L;%9m-R@T0Ng)JC^{Gr^-Ek2vw5a4r&sR~tHLo8Pw|gVhf~hmC|q6Nf+QPk={(??1b3##i}ft}RW@zwxXTG+^!B#|8?p ziww}@@c?Kq*Xau5Y29SlQn1=AR@W#2X&tI`g@<+7;O}m;m0sx}SwqH1*Z8b8R2A#VT6geSV3c|@oU6~E;VF|s2u5KyN0H8Jy}zIXkNyeZW7naJ5vH|_ zMuNX)#1U)%7rxPG-mtVY@8siZMXq;sG%+fwx0RQKZ$66H*CkGeMdd0dNv6ex_;t4!W!FF8ptDCHi2W{6=H0zsW{Hr3TCLb!IG@ootaP*7tt+-~F??)+= z$DA^4n23>7^Dp-{0zw=i&hJdW<$hQ=*zf-NBulj)cAYSzVP5qXm6hw?KWIa?&Jjo= z$R=5zFg$?M0#TOirb&+G!W};^nJ6ImEM`Q>2s_Iz6;1!+z()X__;25eyZ+yauf(sF z#wBD_GMhCT=zkmB2%;Kb4_qO%I^34?n80L?iifi9syrCg)sFR7UevgY4}>1_PLu&^)a4 zg@xh~CMRLT$qNR-DpNiE6-K^%&w`+rO9;Hv~#A?($Z+(%NrvbJ{orOF@omNtpBI1M3ZmDT+b762^8D{#qb}i zP`7<^Z#y&?OBcJCRk-lpAuWA11BS#xNrK9&jICThY3dVdC6-9KA8**VIdHQRgRDpf z_rhs8S^xg(p`212qe<-<9Sr_ZfFT>AbCF1fmFv!mFeb5sc+k_WoB4(Zf`Trw6z%hr z1|6IxmWB55!g~+xvQq4sj0CyEXlpJVA1{}m(N(Np6GE&C)yzPzUU2~L;{C382(r=s zN((B0S)LEyF(!%S{RtK7qSj7Q%VV;R?f_tWF);O!Y^PsdplV0%*ppK&P?fZ ziZW))W_c#nHJ!w(&WD~}#^*sNF01hK?lP?_mSd12CNg1R)opJWgC{C1)}P$Deu$4= z3zry&%zxx9@HpPkJK9IW9{6vKe;6@JiLJ?}{VQhLD8&zoVAX0N(n@4b4LaGfVdn}4BD#bSRj zQ@@UI#1bKQ8dbqkx zM0X9HAaXd8=#ERe=r-MH37--!d8mx_xiahf3P&=m+i$J3{|RP~B7A`Az_l_$j;_`^ z?OHtahEai(|E7t&peJruxx*J<(5ZNyQikm*yR!FoIG^U-U4;m*>;nMG9x2Xg4Oje~ z$Y2EU8S;3Y*^Aua*=isi@5ceaGl5k^N@sv0?ZvuE%3iR@oWtiT{+o}H>Oh5dM1lR# zwoQlC$2o4OsqbVdDjBV_TYHaA^;_esWqk|l1J8Q@&kJBQ`G&MLfE@{d5WiR(GeKVj zC$oyvE!U+cag)X;7VI+RW(?WNnCjP|Rph8mT;bu(7Oe6kZg~2?EjW9_&`mrb3Ne5~ z(>gowBj88?rcx@NcdZR+EJYcX6}L@5Ql_{FEW)f_HMhRsgf8q7}dq+jROcl z?cD$>nt;!fqTU`m68`9eOk-IiGS#HUc^O~5oLSZ^j1h((1swrHpNih!(I;NeB?~&p zW&*T0azZK@;4G$j)FW4{9A`%3xobX)-{B28puq%Jzqjk_;5-Sq;@AhDev7YOPxv5# zs$22&UNKkB%xe(9Eu0BbKH23cx9`0|Pzss+bowMPQux-+B`@bQR$PoK_E{eJI1D3Y zB30>Wt8h4NZAJ)Er7AOwnpW}&cd_W0dhW{;f`-%w;>fUu##w(p{@+&qMb6v_o{}Me za$^pbpsv6s{u30)@P%foBF604^F1ndt*53xmDMZ;A_~TQkc1EaEDN~qW#|+NzvD99 z=qQiZYzsnG$@NHM%}QV(?BI@ArR4EQ85}lBOWMxHFqo8DeTaH|oXZ>^fa+@yW-&Hn z!CW+?vPdqp+y%bb}M3coT$$&O+){#TH)n?t{;0oqCq6r=6eZZNeL6` z>{C#`j^JUUycN;b-;WVsm*?p}tQn8o7_e6-T+B93^QP90o@MgdO(!GcokFw)jl;&2 z3TP*p#EB~7rwKa9Uu<6PmNKO>CB|%pB9XA*LsYU`$-`9Uyu2nm5KSk{;K$X=Kff3% zUi(Aif8L%^x)aZR2LWUFcy|SAQPBf*8Y~BghsSN;v!>NTahpSPsM=(7c8WyFx3u+> zqTO0DXlE^%G2OSBLrn0>Y2WWv7Tl_sxXZ~!|3|&c{9tt%^yaQ<$wUn=CiqNL6tW{G z3_TTleAu>Qh zb)h6&X*59$vorytTL~j3H6xa&$BTW2rQbxl393l*cmh_JPlvylM?g3YOdb6lNvmBz zBVm1z)!!h2!8$rTc=x?yv|UaqF)QbYb~1B%t^1Jh(#PA;kg)w3)h=;p4Hv|*@s`ww zLG>*|+0K5Fwih2ciq(=~FHuf8&t&-iJxrm&m+#pmXIHFg7q-F6w^77XL(a! zLiOQNzOFVm|MU9F@^%&^eSKL3jXOSh-3)UxcUuA@pwu;yPFxzeH7LaDDrCL+t=`|p zdF9dma{Wk%bj;SVkUn>})5wQZ-!HbnRiH*MY(s?Ad?!ay(8NPRpsh39RPxMC#IUn} zttz4_U|2<-FA`(;^^rl6&dfih)E$NQ zb^_=O@x9RmZREdTF0!z9wm4@HYVG29<}tIuCcpCnOwqtY|LM-rNF{XT0C!Ck^JI}~ zdNE}>+15FfIq-T^B$4deL`qFmoRuU1AsKsZhlb@KV1OJH-}A8I_ubGd^*wTi$e{!D zp*FL!Zz3&5YZPK{|VjUq?n7 z^lm9^a4Y#Eb7t~`NtA8l%E^)Ty91!V%1TW|^k#^&)}_v8Hoj)SE`S-&q%pW%u+yF3 zB-a#vb0ZVU)2AlNj(7_%JD8EO;d$T?+WoTCdk+VCXgZ;Jrx_1mCQt+gJQW-0yN^q| z`!bRY_lN#2KGE@J95k6D>KcamN>w%%dfu5BL2$?EZ1=m%z9ydP3N($E&ML9@O%BjG zFw_}ihu?qOGyEJ1H5M(Lm|D?Q8cS1ZdtLm(vXX`cuUuGb9au8%abcRfsAP#|m-Dl~ z5uw3E(_bgYKYa0szn!w#!!>rgqOpKm#uK~Vmp>91pacn{5ZxY8Ed~1IyKmIm`GSJj zbo~hMU~*Cf;iRoIXNwm{92O1Un~Ug88a>1QTJfX`(}W&-$z5P4dL;^v9D&}!U5KzX z9hra;+tKmMNUPM*yfuU>7Hy|QitW};BFjpB$(tr>`!&BH4!QixShy3`;k+UPs@8A6 z>;XB?l#zWd4Frn@gfixWy*bTp!PvL=C+!HadD#uO{1Rfn`YkGH{QQ$vFBJaUBbhqv zSKETJS`aN3vfo7%;Jt5Q(j3%~z7AD!c6Byyx;G(gy)?20yYAy!Nq$+3#PVn&e}!CF zqw~*_mBWk{Gp;^e0d6!GE|_sRDeX73=6Z8?$%Chg;o$*vGBC@8`$yY3a_^YIm5xjy3LHJjUrRq2dTbbRX7GeXIBFpUT9)1MjAlR-K+D zy{c|Jy)E#^y@wkuKX2K!*>^9$QvV`9^^Q1QAXYnXf%j^hFS zup)9Us@clR;J|5a4Thajidr`MT>bXm+;`gho%$zU?+26C@Ej7)0P#t@@poIl@JDXc2h5LuMPx*|?hD^HpDnDdZ7b*#0!%~K9dCU$z-B941@8{if7gi?BRH5}dD*Tn7=6>r4uFe*Y+NriSXRGIQxt$oAGF7J(1lc0>~7Al$e#ehMCZjp)n9%_-kzsK&?D z>g?)(7^s_cL`|EO(xs30SI?EY@^>$dl~(o#-xZMHo_Q1YZ=2)2NCYuA1$s+4Rx@!a zuYuRc5g8H(a#ssvXgfBmAncI()?U)0XR=p3SH)kDUOcyLb{XH(#7l9iJYXa$nn`0I zOCT)TAR+Y17StSokK!O@16Z6fQ5{ws{#gHL#5c-lM&|QZ^o} zKaib&t})Pf;JhAyUY|eW!Ta(L9*qynQzo0Ob&RQ!bB|u(f>3JA)kg@(x>-MFbsGSIVhhJxH+?`=yE2oP~HNiJ+kTg2XEFp_KnnU3NqUE7nBe<#5rYCq!eK zAcw~bBi)uA6Yc*hgBtD=TX7J+FlB}T3Ti%tmh$j`w8>9>)Vjn}+5r(y7Xzqc+xMce zFPZM0*b~#pf0YPOdAKK;+u;w3=LHKjrP)-x!^pv>GD&ABwUwFU<>{WX}$%z32+H3WN!uf$|fcqdt z<=(&RoQ^mLx^xhhi)$(6_U+2$8rG976T_d&6Jl}P1=s7x-9c_zk$6Zn6w*~T{{6vG zqnlY$_>yjpFhbvGgc1{|TNAb-b*Lv&BZz6^WiT3KW*aLS^h7#)$_WQk;c%ACq(`c01#j|4;9Lk3fOQNx3#ou`Kcndrj+H;|f;_5p6m zx3xgM{SV3{P7t{f?C$FTel>q%op4QfKGtWwChzg$>0StI(rKYbR>E)fO+|L4e!*k9 zt0CC4*|HKgV|+8mI9iY&mo$p${Ow*(zMjb8HZ9q% z|9axvJ6a&{S>0$zdn{x_vt#2jWW@F#eSe@OXY0|+b_pSg)}g0C@p!*MG_j{&?%fS}xZpj>WwsdtdOH`|hQ;<}7TjVW$(kfrl{?p^_9w|zdZZhn zGhYFZdxia`G$>60tt64 z#)j|xF7x~5w7Vo=ozZjDg|xs}wOeGAd1MEHIx1YV0ZxYI_Sg2&x< z`SJiL%rqPTnuF1S#>#dWixOytCX7z$cPf&mfmT60=}v%o z!Sv2>F%3b`BN)x6u`gC`GO2_fZ)lPiOizasTttsY=7}I#ey#`2HQE5D9EQ`pk0}cW zgB_NRxBWT8qbuX}8tnSTY#5EZVXHw%#Lgb^jeuwX;DwXQWf1DUDvVQxRoA4A*~o4A z)+tvIX^ZMHC4Kv~-_u3N_4#9GoJF4vhcv2B0(VxpnRP~cml!0z!bN4~@i|PiL~}0; zp%Ih5#8)T7kV~%c!p1�sUAJ#Qyt9+RYzNaQ6!d&K;C87+mH*HD6hzjIm{N%(@=B z9F2ldO%O!*1QM4AU>&}^k=Or9+jit#ZwIp3iA8!4;#hONIWHGm{owG9_Q=I!+|>OD zpLWHWtQ^Di&qZ$$VsY<4Q8tEV36-96`E%3Hmn zKcS+7%g&r_c-$5T-!f2;OEeDIZYCCadm`Xa@twzf&(ODTjUg}8#L2i9@sZ>Bdh$2_ z>fsGvdpO3RvDhs9ELuY#_T@nv=fXB#OPZ#{{bO+#YX~w zg2DLbAL@Zwl{F7J0YgC181F0}@(kUVMb`y#%w&;&3KDdNP{jd$V!i%f>%R;k;O|9U zf|ix}_k)`KNlH$3rB}>eY)=>9(_;_x)JmTLlnaX)VEz<;VpbWGc%O*dN@3%$532ZB zy_f{+SGSpvP)2fmhcjqDw}Tv02CsU;5>{@+D{X)nIW}IG1vJCf&-oV_?euM01Xrh# zDy{yL7AF^8o&za-_$~9n1|A>wLht~*XCQWa1xoUo5)W{7)>Ps)8Ym<$(v= zEKTP`Q+R*`p|2iu5M`Q~V%Ip7M$W z2pvclUMi=I&iG+*OJMkCh*Bh}pA}$6eai~CFC{g<;;9i~0xf|j3fU@ND^P;0x^Tu- z9R%|mf<(m@9s%M)lStZ0H{Ep}Jv@~Ia12fx`D!!;VE^F8M1T>AIL?n)Ks|`HXwoz2 z`y6R%FQ1RY+WRR37I}iT8+CwwbA<~`0}!q26+xT5XQw>kgOz`zbE-rU;i4>1!*GXL zEG;$*FPsIhtz~$GiXfaRmal9stHM+!B;OVZjdPSQ02~57Aq52Wfj^2jcUXbahP@91 z{wF$BGb#{n!He#5NJ)?48d4 z8M=(^F(?q%KAvs{NW~FbMry>kst+L|TNggJFAp@o$}xOl)L6I!(ilh= z&VE(@V(^nYF^=a;!scqA9Uhmu2T|Au^^>gc&kBrXQu~z`jquW8+sT$6w*)W=DZ@#0 zUl@g5M9m>pUXq!MpJ^7CS&v4YNYR_^3GRR(lqe4}DAD8mg@FFMu%e-xjp3yZ7e=4& znotG#lyCnO2qD2RgrmJ_9T$`E-`@X8F})EReMcjlzmLAt+TuWP&M%H6#|$Pn;+h-A z20Bs3QTslLj=QheO6=u+h4n2*T&f2>@A#6^5G(~Pzx>WhW~%epbkey;k1jAR`WM$=lXg ze1^?^z0t|#aVJyRzp{}^Gn0Y@P4YpCZW=xlUlvuR4OmK9lUO)_?1B1W6DWWK4>(c? zoe)?>9yMNo{lH;s1ChED!nFDdBVLy-UEG9_)qp__R03yv^ICW0kdw;7a z^ZjfEhUbkO&s(xSJ~$~V<3oL70nNzzqYvUmqnGwsk3E4-p2zm9ISK*2fGSL2mgg3Tqw$)c_suI?@^p+eRGdN%svgH=JUB4{xOw?fMs%y zVk_x4vG=KdykvrRh#PljfuDEk_L6tECmY1n>TW-nVxQf6^vGA)z{Y-|FsOoBB?Mhv zXMc|~DJT0s_6N8KQ6h3#7+o!Cb~tRu!EjJP_q^rLAp!A(5t6IV>O5YMAd{PxO9lj( z0AOqOZ&%G)vtsfa#8|CNX21uFjoTMoxX!Ba_2`#;dkpOLVLiX`5Rz{6m)N9PBS;Lm|aGR97S8o_F zXNzIYJ}W#6FKUGr^b2GbOp0hi(Dq6SfheX5Y++Q3`fRjdk=u2(mx?Hd1o=Hou-xKi zW-#|kH|-5w1BUtk5T_#&I)OHl{%6}1kY{#+(~a<;+y?tg&_wPZtG_q3ySfXQt@dIC zEu%GdbND&`s{Ato7#S|?FXkJyNP)kp?f6w}*sdBsX2BJjK+I}NAIpHw;E1B}n2Q4*z65`Q}i*9^tS7=P)0OgBq0a+pJ~4A3S*fNPG1z)GZb|(p3UB+{nldbOV8(l z00-U)Gi>r3`O)VCF%s;{o*ueSAa!`Oeh}TtNH1w4lRKhUp5goL_qFS9LTw zvMM3CT3>n5z}JMis#x&ddR(Iugl5Yc9Yp-M{q(!f`k5ev0_xYkEp*ced)ke*B1d8e zh*?6mllQ65RXDwTBv?CIrB9geo`5yP#dVVagykv(Km|u+cZ+X!;FwMc z)2nAZk7mN8|MbE&|HG|504Suf`85}vu;7sD?2IK(TE6>DZUmv6RW1(wm1$Jkgz$ej zY-Ej{HoTVXX?N0nlqSTJ_l{-2Yhi0z2g_t}m}t~8UeanF^Dm6NP;z)`M>-qD}I zZ;=u@L8YN6r$dBn#2ZUwxL>*4aZo_v@yRueygB9}Y98+-zm>#H$d87f)YkgA}m>^-TGdjw20 znzL_9W*2Jd5?8JKrU#W04e{%9CxrP`O3_ab?KXle2-0@81H=QmgM44$N2p1>fNw1u z@)Cr9|1R!^+@$z0B!C6V*AK4f2LcUc^r{1D|t!F~_GO!0IC;pAwg z-&?TrxHY}M`)u6ijiMdxbi#nt4Q{H8lX*}4F#>|{-+=4(mZ85V=xYNvt;wwqw@?O!u76S0?!e-%HK}M*K=l5HoxP0+8^J`JazThb7KC!{5NVExp7-F(oXUG z94t}rY#>D%({QOu%|ORSJixwu#jE4Cpa|UNKfNDhIdR6hc(mY!nbk@1?ZUdp?s{HZ zHYw$c$RaE8g9>nfdG2o9EsoNE?j94Y|EWzME5bR@Ob=3rM(%CBCy!5YmxYAquL7dt zE@g`$0@vrwNQf#2AE(yM+ZnXY2dx>L>X|?O>Le&tcth0|Rzvh*#n6DF7C=G>epgYG z_m;KYcGZo{YEPd%>XUu^$7`_yxt4cv5L(U!d|BDRi6@$6cs4^l^qJ@Tly0dZB@j|( z|8ioL`$mRz>Mu(5eoiA}oVdcn7C9RiWAokvt_X4AZe0QpHMX|y_%P;xB1`K^e)GY~ z*rDI|4O0$rN{|)`r&pYg3WVvS7RS=&a&J2&#mw)E;{m%WVj}sTcBAs;6L~RSFb)N+ zgW9cjY`oMaN#JkC2<)2GE*OZIUVTq13HKoHR|izb&`;3&JixgleF49RS)*_3S5Mxn z_4HSbUiyKjhWK<(JqXw-Xg)U^9|qB~fy#^Sn&b5fD~ETd6O9r-dEOiT0%>_t#jdw2F*w8JRf`E^cs-aq7LA_xS)>hHk4i)3Py_U|6)B`dyX z ze!;T09Ybt%PL(EqbMBV6GL(;veh~g=6O9&!1BF|dqMdTABB0qp#Af61CfEZ=O+BIu z!rj}Nq#Y=x08{r&*azP3sWBsh7+Ld2a}JoD1|@(J9#iEsIOnA@J~HQjs_aaHonqNAK8p-cnDjjG`IP1E)Cs|(AMv?_QxMGMbsKDmcRX<7Xal^Nt z-O!s-zxvc&pKuGLB)(TX9AV2lZQN-$B6?%=j0k%ncc^-`sN_#5pO9Yzr*z*NPWBZo z^8kjjz3l^{8dRWN+7D!YE6XApshJX^KYl5A)LBphkmgBkPvRNX+LpXc6%srLhq{q- z5*|9$+Iw5MagVkFKJ<`nzn~{qIOse!TJvCt`nS}9EZPnRB)c$efNmE{)(k`xzJ>|+ zJ(c1?*InNQLN-hI0B?Z_4(CMQC#cLvJZDiNdQog~>q`51*uu*OwCx6R)5U_3kcC5e|yK|TEJ zPTs$}Nkjh{!W~iKu^aOd5J16I5>NUgNnI8zoTl>LuD4@zFx!fN3ber6E-*U0z|-)o zrxRv9T@%>Az0!=IGNegpQmD9Xxw*;r6Iu?tZK&N1!&!xg?p&!==FYaIF4kK%QGqb` z{r?O5ocfu+j;@Kn?`H!&`CLWGC6?(}T4w^-Vnv!>Q_2Q4PKM{H#L2nFDI!iM_b5K$ zWLa^g>FhMn3vtS50vEa05cm79`4&#qOwoI-TsCIv(>83qTouIBunD}%){9?A1oq;Z zkbSm=EopzDI1P5yNv`)zLwdlwhZ|DfJ7kp?vsg!JIOcXTf^_5eL3}^3Vrb{E_(mfn{7IZZ@WFPN{ zO149OU~WW06S(hMZ(HppmM=EgWixB0*sy>Bz{vl!{!)dOYMrAKSCdfeq7$@vZ7*MS zxkgv`Bfy$h|1XfSL=$0~tiC}dS}pNiwZ?%J5&XNJ0sDlT((j|EWet|@zjoaF{kbF3qrZ0n#nMR)$QGnbs_&7 zw+sW*v^j(`z!7uzFltvj9>Hhb;^zJUV!=(h<2N@AYov%d-4cB9!ZL^a-D6-#fE)Sa zNuaPBhDS70R@%LKRa$l{=ef;yGr`*fEB@?BA~d8(Q~N!|5HyMgDA(lwjF3Z5Oy%6` z?mtuh$1}Jl8@9DFRp07W|Db=H6FF_pA1V5EAC~l3`C( zw{(0y9Jko7<&z^fj~n;J z_REkUojMQw;9L=6QD05@{5k3%sd9q+F0u)q@`xZJJpG$WA-sj_{;JCE%yTodY(Fb5 zuJ!|ULzL5A3Fs5uBl)AH`-epXNT@Z1q`^YzruBDfY`KMnq;}Fv(^iZdnO`D>6#BVy zzJdWl!(+V^%TVJ{^^v&7NrQjz($cT+{1v2n!==Y6!J5^6kQvON{A0k%_ zqylML8o#^aW5BA)Cu_@#ZNmuXc^PLFs>Z!bhDOl?JbC(lx3_UWt(AQT?mTjaC_Y$l zuxxRjdzYM#r+Kx?lqghwj~aLq10bXX+CPIx{EFTjLxHg(N2~9q*BEC zto!UaEL?-l4-SnBBhtC_UmTKG>s-~W7mAo4NQ(|>T4a}DZF?#9AmI_bjG%7)58iJ( z0hcBsnLOp&`|rhfrm9OgG>P~g-IkKMFyVD(W@dpS%7ZC9 zaa$1&BDdpnXMd|8IieGf6p1BOFJ(v@1mc%IO~~c#3-8%ndz6d+UUeu#`1(B4r-g#I zkNq$`4^sks2?fCBEW~0wH#WEP`mcR*sk4)VwU&md_LN8y63iP3 zHpc*~cOv)28ku|z9{0QNx`d^0#gwwqw;g&-ojuJMLk`YTcrG@qC=$b=xW2=!K$N~< z$V?5R%2uBg3kT+H5?~Vr;vOrp0I=wRTnDi)LFySMb|-I1aN()K+?NTY&Ca5qAXEvc z&-ZPHiEawM)!mfYaA?dw^U3kA5A20DC?z5xf}m)-X#x?!r9gc; z)ettYH~tvUpg3SwB>F&B{(Q8=h6JmLb|9v3+hwHVuYJq%pu>_NTo}F$nEOy3{-U)s zp-~5f1GxOwyo3jaN@WNBd>Q_PIE9*TJE;WElyNL97~V+`+-7j^Z&>YzrP^^_2vk35bJXL=oT@+4D z{wVCD(LTc+R>aRWMM#28`2vZxUbG)JQ6-&@?&OESMx87(h~^VbEjy{gt=;+EK^R-g zy;zSm3!0KVwU&=~@XBefd6{F3i{0i|a=^R7iJ6Oe#h&~=;@7jQ&`WB{636n*zWhEO6-KYzg~!wy6qi zWAmk5GXIr@oI^I|j7ZeKor&KIuX-64&H(j^c58*iWWX<%1Tt$r(&C4HZFL6kv&`*7hkvhy$Ish z>!X*9T#g7ZNTWyT{8KvGKXMUiXL}aL zab&pI4-WypBqyt*O+(8^hbcpoOCSSpvIF>yF_p6+cSQ zsw6>c@Mom8gKUhH-wKraQ_@s9)@c1bApwX$65NI@dES)4)SW8c5#_$e73nWfG07L) z_*zqp8lH5Jq<&mH(Dg3Dx8ljaCl^Hr33v%+qe@pE6TlRl%IlotFP{GbjDHEpr21yFjM=_#@#4p-;)-Z?T)oUq*m_fH>_o4kbcpk3|&(?-WChPaQO4j>`u{<7b% z_kvW1kaXZCimi}&o{={rEOgzeV6*#IAAA87@^ig%(T5b2_q)eGtecG9otui!!-O%} zcrBt6?XMGzJpZB!J}cHaI+L+~;7sESa;02l!QwKeBy#f|u5S1$8)8u7FZYKj@0LlV zhS!pT!S9`8v>+>poA?7AazdMuOeeUG(~Ucg4~oEq*dB!st3--_W+AnqQ&^-^F(+4Tvrb*g-n16qy(vWND%7X-Z;dwR}uTo(2iW>3# z?aD4pqB)GL`G=mq^FlCX+w|#=l#WZp5@>o65Z#D8KoVomiE)`a2Y?kyWqjaqcaEwR zsff1jggX7hD$loyUnWkFcIaGk9AG!6<^J7Y;*{IVD9M7PAf8-fECx^{r(d z0e@AsIY};aFkt$vCrM@|Q}6OA<>pjXdTJSWu(q-azZmR(SNhZZdXc*K(;w1QTnSTHX6L$+mTzd!pA6lDJdEU_QqB!lu$b(86F z{JolOHIbr11SGouX*E2;gVF1SLmWm9T?R*I058v4?yf+AT~NQxl<*g6bg!Iba#w0m zTBk$^i$N)XBX_>x<4m3-K-^J~!mHiWGfHq+JPD^s^VN3ymCfn*599tE3h>HDFT=zW zzY2H`?M#nUvAa+uuzLLCNUxzR$s%W(<$bV$Y;8_*;#6xO+&TIEj%#ng z1C+uV`I&y}dTD)d4^7{(&e)&IPqB+zXWBb(CJey!GOY|p9x3&>^}jNS zY4Z8;@JaI`3<_B6COJt2pJ6G^8QRIx2b~ENYn`Z`6>ok)*w~KU3kVJfb=H#~-h0j3 z*k^sQIk5v?-O=bs41IKc+Y`GBGqd)chfx}`#et98{#@Xkk1$*qKej=)<>%w*?49G7 za+hyY2sMniZiuPN_TN%`-Zf5!bv_~e3T!YfP?~>B10$-Hk%4zNP$}+j~vG*E`LZZ|>w{@uuuB7#0 z7?;EGTW_{_mAU-Q=GYGVlp}j9d9OIHbS72gjqN*%BI5i-OexT26=2ai+v6mhKTPzm zjyr*XdJ~gqcubkG#=gZ*o72{VE~-FrC-m9nN`=MB5q6EeUG-|YU|JAmcyx|hU-0v^ zuF+fjQ(%pae)C*6+)CU42(ITU`w}-E+$X;{UXOgi?;cTZ`bxz-Hu{$I9p7UI=VYHM zFrBUI)g0=o;M!~pa!%b5JtJyNbvK{PVXS)VI^*?S0nHu(-pON#>-0J5`u)JqyFJBv z>sKrL!t34qzTfvUkl^uqD!R(#^fV9>`q*+xIem%AYU)v;6|P&ux6r6hQWXA7(&ZT+ zL|B!ZuwvRzdZuZE`~V5x>8W(2hGs z0qq(j&ichk6t&s%ogQ9F`A)~2iBF>++W7p3Z}Xiis>goAjmaP&oorQUVc5Oz8Yy7N zz2<+h%6hG%aVaY1jQwfSB7fWOF^T45%LVOLJC~DIVi2Y)+Wo)+dqEvh<7cNRgH7X= zh)dBogLLttvU9kPWlpPw99S{kT=J{0K<6Va^-{1XTXWqBGmW&&i0 zinrbH{#D{r!bhmB*<>CPLZX*RDg)YL2KT=cg9;TOqODF74hUb zXyC?3QoSu3&$7NFD2l4OeAUclJ(E2?w9;WZ1c0mF71Rav@E zWb1x^)eFm?i7!!oNdaha3jxl#I1rfe!n=2M{ zU2GQ-%CstJO_>?;@hJJr>1bl9?yrbJX=VEX z1XPW`gSU33D;oZ6XW^nZ;$=1-hGfarl7h^G4J2NnJ_=)k zU!8uB?#K95%avBcVWtX+4R5&K5|pbLOgo$h`wz0+5TC{Sy%QGs!P?hb_5|`j zW0o#+F)@>sym%FQ&G_xFl(TB5>uJ3P{os;$+b?tY!4~hq$5s9zk;vsW&3^b0*HS&5 z#0nIZiCxW%;4k>`5uf%Ou;U;D1>`sdo=Lu-XI+bI6BSIz&gH=xAk1i>ya@DlGIGcT z1rPsl<7IM;T|57P(f4Y(irD4^mg}L7%BH7->Wii-Kqtw^DueOAt+8mm4SqCEx8je@ z&60I6_F!d6MPF~pO&Uo=q;E_%nn_SRET0L*5BQLAsK-AQNFm+#Xg`7>bL}DC%}>z{c`EgSVgWNEs90 zFk)XWf8Ny9)~kffu!D4_{zG(eG{p#5)+zbRcS1izz_t<81ZqdQhp5wuQN3bc@qlK{ zCZ89WtLeuIwH`_q-j-p_?Iw2--uNh_hdq3CIaqG%k;<`$KSW=}gPt3Gu4Vitjo$Mq zo(fu?K(gr)p|fWa0pUg2X3tXjGxgc_t_^&7);s=TjsDZZA?KS(SDb4{NT4bP8d|x1 z!GxGrT4#|a2+xWjrXuim)UC4R_sNSHk*;d`D5>QLqT%IiYp%_c2iZic*KJg|n5m@R zVFmbt-TN9C>B_+Fpt-jm>p8cjy8M%>P|12y|$^?m(=xz5uRz?ZXlEEJ-lY0yHuQ& zW`tv7_aX6g+zuO_OLpC9c4ZQv=C7C$*OB9-P`2SNr_Lakfqg@FOu0#y3qU**7!tc%q$G%B^%!;de@7>OWWZ^Bc1p1A1J5pHM5Stl3-9e-Gz;s=F+wBn_HR<;<@VJ)a*j;8p@D&u{Oa7Fr9}P(To3PX zPC`@%b%6a|LxEGcT}Km?HhN!WdH7P|*yq@8rF-#kGX(8-9n_F|QR^6-#PVY%5Tcs6 zFv!e<_i&mcGe+G*jjcuK^~I$3^?j%PSxoPtZ3!mVR64Rrvf9fvUm3mTV4imOQ&>+NP|j)0#ZwdNEoz~NC-$s z=Pn)6NQtzHf;5OoEZvQ?bW1J0xXbQ)_`JXGe_U_@XU{Xw%sn&r+%tY~>D|H(A~Jjf zdCmIN{k#zu5sRH-pR@w@<-|FX4*+Xvzf6ox=Pz*FPGu@1FE4W)Yyw=7oK4TaGdxl2 zowKIlp+8U5V85lXn(?Fo^+?qE+;hTEA{zo<`k73I70>L-8_NDoyEq#hTwQ9Fe+?P(=yF`*WQcjlKI$!(HeKHglO!I4!QDgtNxrUW>h5q%~UW)MDQJRf%xGgYiWQMLk?<$-e5^{J=b5P12PQAO0qdyNIg|_Ub!+d?>x>^?S&>TA|1K0kjfZ19r>Z zm%$RYufQ1VabVotZvYF-Cc!rP9Ei$&2sLf~bJ%s$I01% zd1_ChHN4`Mp|q__F%U&YrOYE{KpHpkf&+c9s2T;rm{iWYYSIjEp+g zxElnM-~JL{)d2BEGW{3NL|{@CV`cMg;55+$E&Y)RYKV*OdCk?XpI`j>SHZ*x(4ucf zui%V+gcV4crq{Y?jrYG-e9XZkf+zbIWH7j}U%0K6RnkK8UussT2#&=cX(+Q=pJ7h$ z!aqI$bvgYdT_`dibD6RS`t=AK=Qnqvn9U;?`#j>lfKU4$s^`?vEw$cpsen{uvlGGJ zZH=@`)hv!FV=MXs>EvrToNKdkpCV#??f6aGEPc$mkR)I`+q-$f<&fSTinb;I?`>wL zanG}r=BpXwnB%mSBEs`l_FdTs*lo<^$yEv6Mj{|RB$YqMpXN~x0vr#@Z-h_5!6*`J zvmxp&+ItxsRyZwp?@(wyMV*wyRF=L}1i7Z>OFePSGYWa@WnlBPP2vERMS98^JBH$1U+(jYK_y5fmudxEWKRu7QaY3H96 zbRdZu02zP}PY&rZAWih`L*Vldb=gPiNMHtj*~O6rks@twNIWIrq5+iK4tQWqM@ zM`pET;5-4zzztBSNfg$-aJYmh<7=Dz^wI=rPY*Ri)v@Q&)92 z#4w&prA>sg>1?^Zt|8|>%M&kKB6u!a=|68e_#O4X=_=sJzz3_Zuus02F85gEt_A^o zA)Y%F!TKQd@}V%;*_TC4h%6a3 zX8y^LKbrv3l>y%Q;FJ}Ub(3~x^1*rRUKJPw^*0wbCqY^5c#z$S?PW_w_JSHpUmHQ7 zAxQ;t^PNOKv|8E$#!a4G`srIBca{{(w&!-F-9d@!S&1fsy9Kd))8r1ty^o<=J}r~_ z_B_5sTs-OIJAA}YrtMBtScUEwU3qv zoBg){;Fh8EuYk&0k#>{JW$4Gh&z3&HKj3jaPu{ZBkt+TCdg{qjFC^NGaI?iYPs8?G787xiPrrrqFCa$6NKv@Z^$Z zJ*4ZIS+u|*1IC6aIhZX!bB4+?(|bY2UK44rElj?44n8;yrX%OiNPj7JXl33sNJ*Fw zK-QDt>aqWzUw^)4AGAFF%)gwtRzLtZ^mg_IMG>QnV1IMy^`F3(E!B9pY*kUy+HpN8 zsBeHicyMZ7&$)o-%k1;EGQIupKD_+x0)CU>KIu*1dt`YcSPOm&3hTUn2#TD}LBLYE zE-3-t+)@!GrfzAs1+Dw8rMnm`go-+Lx^0A?8`gPX`Jxbnw(fC}_sF&)n11X^3|US5 z*q;;SvbBjNeI$;LIzhuPUODZqs&D~naLfO30a_i1U|-BXZ9tH;828~-g|@uDcI?=^ zrS*iua_ULL&e7I|@Di~(gc%Eqe#PA?fCaOg^#Tc8a>J|pC6B+tN%ZbO{IJclB;qJ` zg}0BXU8;3a{;t1QV82&44iE^kF-v*6-SmlGl3 zXx=nXeC)?ad(2CpC_k;M{l$?|phxEgz72o*4PDU>plp7LDwvBCz(npdNL&pJ&!b7P zZ7|ga0{ANcc@d!USt=C9koJ}0F0}>dwXP;Csjw`6EX0}|;NTtJjn^f})f_6def8f3 z){sLQvE103CP#%w;#`Yw_P77(7@{}yPTsQ%2dCX*hc+IWMIr@m23IEJE#ro!%0}%v zNqpBXw+ddDnHqj^HJ(aw~UzhKSktcD=8BZA6GzT*DDP+Q1AC}!;bPF zAr_Y_OS~mw5x!f|Jat4wlXy$}W-S2@3^@sB8(OJWo70vZO$92PEvlqF{zys^BJK%L zmJ3FF9FTkjrKJQ2u>Z2dr3QccF>p{}EpFX)uiQ8P*1x-%0cjOMX7HldR`Noi1MP{S z{tSOB>Wwip@bG_$9Sap_Znoy%XzW4%BRi4kn2Hw+$~eJJzur~|loo&b8O<507qrcb z(au;GrhZoA{H^dV3*=f|90U}7m|GcwW?4@q%N=xPoBZOh8MfCdJ{7K(sLL6#nUkZQ z@_pufvAF^N9BL#I)KuOGd_V0HP*ws2XDYpIEgzo(01~v|a2@(T? z&pteXc>`x$_rUSl$HubOk&l`2$mfJVDTd$SAZM4vPb6XS|4pevgd@)gZhqw5oHK}j z`wgfx8R9$IwN0CzpO9uVi#}ugp@$2BJWi3WbQLL!&KnKb>m-eaVf~QpGx;{a=#@_j z%>L>(+%kg}lm#8$o3-9u@5^nD!U#F5FKH3rZlo)oQ^Ac2aL$<^6Mn{kc2Ar8XhXqm zW$w?7=L0_tOhElH8X$`hV~d_GC4uA1KC^y$7{1(nGMCR4iLvR)+@q4mW~2ueg+K!^ zV8Z9h?i$`Y@+ao44;GaYo|cj);27S4hO2*A3|fIc*Au zD~yytV0gKuhMnHk3Dv8i6vE^6{*73DqFfhn#w-{ctl~NjO1~D@FNT(n$l*67WoqmIgAkbnG38rxa(VO91Q5T-8sp72z@Wpn@1}?S&LX-I#yRsp zLOMf&ZhcLASRx|^@OWmDe#@qMt?WnnHM#btg*p(ue0YYw+ek~Rf(oKOG)_&jpI@%(C|GOb&x7YsP5BmPsG`20sh%2%uX2hfl@1M>`k|IlA$hB^YlWAq- z^Cny}9#(pjM#5e>dEPau`@#MBH2Xy6Dj|tj(A1N_Egu$9*qRmpHvk8+zFrLFRYeO) zGHJ{EwwDIU$8kr5lJ9^tD2?hQ*y9qTz@?RF0vO*77GSUW&}EtnI0Uiq~KTHU~k4rs0)u^af= z_S0u#H!?(RmKDDmeu>*vKGJZna%aIHMr&U`8vl1T*+Gh(xdKIa=g29tNGv}kbV%i0JG5(wVW zfpZa7$l*>GkB2;~)KtpQ<)vo;`i*h5w-68FN|OG3*%(X>UuxFpEyK<+eyyK^eHaD5 zZb&21G%JkLH&qt9wm_AiT#7eJkH}7g51Xy#BU~E@t6%lup<$P5XqKChMc9#=PM(|m zgI`{n;LrI+#D4AMxjGuJ}8o;t3Awjs~B~ZFDVyyvPdTB1r8t^^5%;Z%(x!2D<-HoQ(|V%Vmq zef(^9rMSUE9`ceDqw=jiqL}~lCpp%cN^8x6?)W-=7m?7tJ@{ zZuJwEE(I~U^kz^0H3VszzF62%HG^p*$O-9XC0&H^KLp%7;>;^K9K=rM-B#LZRh{X; zww1Io9`1C2Kh*R8=)aM>gO*!cV7fn*6$0fu>-dgMPWq#_p+-Xqke;zeHCpZ#JWcDD zeh!{;0sMImJ@WcPv#Yj@N04DO$WUuZVi@%NLZh`W`aCz9Ck(dt^<+X1mL~5sr9l(e zjJJpzAGg~6{cWy_=se#C-DwGXO#`5;3OhipBCsvhNP^Qbs~z7=umZ~s-Z=&L-dzml zD}8{x(%+pLJ%}0Q=?ek%?-Op3E;%oh?2HxavZe9K-+lJu&HO&UUS113s|G4DB=bEu z0kvZCT;}9Z)jemhqQ%}>kI8ZIKYha)H^(Lel8|qpfSG>zTVDRR`sOis&w5%QbW1sm z$D})NnhZc4w=n}aSTgw-VbKf419gpyj#fN?<#<9WFGrU6iRcnBI*i8T+c7JN0zCh* z-1o?l>>(>6-0GG^gm2Akj$F`9tb72Y3*tzfBJbR8oTw}oIikEk^H$8CiM?VOSD(sX z@U|zxl%2Uh!T$2yVsPX61tn+!(Iz!k^Hfm9h}Eu@R3po?7wjyUBFOUcxzMPMqUDihxHYJw?@ z0Zfol9d3Iz@)fXiZiM1eXnm%hc&Ik0Vkv_=yix|Gn<~W46v?E)ykIHDe`@zb8#Hfn zQ9#7ilmqZEDi7Dn3nC&;V4L6zS^1MWTJdJ#m;O!*`n%qs9VHERY$jwh@sg4m6JQIp zNnpFoEWGgo72fi^Z`sOCGQoJ93qJol^F$)ye`-XM_TLP3!g1a@*&TFcbL|J_vQ}10 z7H#TH;*s}*HyCDq`iOW2{=Lhbl*znDRm9Ofm_wdP2gLKdvj`A!w#GH7QyG@!jyEo~ z8hCD6zJc9M8$ZED;j6#0(DY^`;y>BTr?mqw4I-uH`6hjr+Y__8gHx)1_Mt#jNO*He7BP@ya0#N7;@rQcr?{>tWFIKPERZ zCzwlW0aXQ#({3oP>HMs>2UiJnl8jo0Rnh%V+4(J*So7K!3PW$ z%o6Im_?);F`E>$;)TaV(&VO*KnyX z;=ry^?G&m9wNyG$>}0=q z6fJc4szyd{@>m>Y?=@?05%CKbaB%DW&()RO@p_Tp<@%h`&IF5X&x#$bafyZ{1auaG zW2xa$O}bzm)9b60vvfWU`~qtuFA~S)VfY- zXfQQP4b7qEA9v<*WDB|~tv@g^)%@7%OA}Cqzd>K9ghBAZar4GvJ-sbTA2T5utWhAA zmtNk-bT=}}T11fv?Q7fr-l`oTt`@7AJ&EYyB8d9X2ZIX~aI%TYxW z3bVaCN!DN8TK!=q8fa~yd5W~&8oK)}Z_f1T-KUCrkY=-6qrFjEn4>0`dSLgPMT+aH z^cQdYu5q{@uFwGCG0@J4%nBp= zoV`AsEvbZlQgk1+1ZFIY(XDm&+M(LZhM2vjB!PXzaurZP!L=p)0uHWnnJlmBh$cR_ zf{|3#^=C+(q7G#Pjb+83RQbC9{HzGk@Q*=`NJ)YkbJRv9kAI!uEOuUcafLzukyP9E z6|%9kv2qdVWZ{|d>;0IEt8L1E_qzz}e8v7YVih~a);wK_#&h%g@cY)?q}L3hc(^#m z0}-=z1r4h9CSFTWeI-l^^ilXvzG*#eRRdA7IsLPOyEWmk!Ie#h-&GzHBo~&=mwbQR zCv@h%9?dmYNV$Y(@as)+FmskSyDen5p9(V3tiU=_KF_iE1WE?iQS`4PhEKlZoJW06 z+CY_^+R^hD58q}EZ-7{}7DGBXe+An!C-$XR!(sQwvf)-LywOkJhZpAOKc_)y>&J48 zPwELip%x^0qy-Q{2nFwij%zQc7Q)9L>#Nd#B2q2lBcQmWqxwj*7mHYme7MP1(^%uv z7!Y86?iVQPQ}b*wtGV&wsxCN3rcHuPCUdcPS}_q*Gyug$Wt5bV#*lBewl~yV+3tUo z5in#7y`+m}yo8mlpguT$I`ibZqI<1AI^?JV}HG~;@ zVQlQ?>-j76Veo(y?M$1z|4%#$kxiW27;P|QS@Z8mrOkvCGI&+_je@Qf?VfS~0{d%I zY3w3|*y^%7K1tNWPIJoT_N9Fcj!o=A?Y}7K0rsJSJN6?DPhXCaTbn4^>FtIe%7n(^ z{+^oF++2*XYM&GgaX_rj0UZrcNxnn4JWAB-!iy)#ozJwEIO6uA_(nqarOmg3W{pJU zr|Lu)&pqb%A{Mv)S_TMtez#cIG>G+?mV7O=H=B_$X&4C#AlUtWngR=Od*7@{@KISY z1y-EIz>Q-bkarvmCs6*26#g}rPIrKVf$Mx+2Xjh!QAHw;x2KAASX=jSc1FhuY^GYR zDR_);i&{<}T^W1}h+3M9zo^K0u;78$dJ%BpxPj}>xQkz8KNKpzvr}mmzQo~5WZiOM za$KQ->5D%Ml9s<~1M~L~Aptz5zlSt}66z%jZs2sLJXw6MAK$0rXlE@1;mKyLDBklF z*g%gChN$@%)^z(52}eSy^IaM@I!MhueBH)>c|(06ThnD>ldoT2re9fSOS-`(OI+G% zL=98RmXy2G&F7rr*Zh^%em31~Jtn&W(ZYsx=q&QmxRHZjE-1^m-kLg2T!)>`)Zh1U zks<^>lG~v)eBLxAwK57n@1>7}jpRmr+MlnBg|IFL8F?Ju*m&6o{Z(n>fM+paRHkRB z^hWoerMM5@U@p`MJt2^OfDgMOJThkBc)*~*BFFr9@vZi#+jcUH@cp{~*yM4t<@zi0 z#*1+zI83}MRjJL?3R`Fr-DKD1F6<#mw}c6Z(FAA87;HbTI;nkxISSUovMs)OLIsFZ z(nFB#p8}76g#%iK4#a%Mkb?XdnfnUAch4ISDC*wf`5|a_Mgr9H@QY(`ga#2khUMXu z#5WW-1j^xUl&tOh(=YI0_O@*}Ayb0SvGgHqiSk5hgu&s`Ra!OWcPKiR0*!V}O>rt3 z#woZ_vG?+zRj7hi!9O;%RDb)jC*2<0zwcL)JiEUmYTG*m9qg$gXUhLuf@;&oXQF+RIvQQDX)_LoT;26%3)*>O-@ zxDOU>Y5egc1@oiO0&3$&UEu$NdHk=?So*6HPDK}i!~@l~3n^`!-&^)8HvNtIn#;hhf%X&ONE zOMqT~InAdZjTr<0#eCz=N)#J6F7b3P2dR0n7%UgKtZm5ffb>Xfw)t7j+v2^sZ*M-% z^JFLpQBksjj08VW6s)EByjt%`{*;asp;!Lt>Q#KN$ZmW*10;PU3u5;4DY#=Om$H*H z^IbPU7(d<9B&3RQ7X`SKBT`$R8c--(y)+@ygvMM8X;POyNE;d65!^7q}J_Aa>s>=x#Z zQ{kjS6djGXvhxJ5?AdEmF!NEL_q4c{>G4u9vJrv^OK?|^=6MXDxYPi$5D8Ld7X`ku5WD-q+PMX%t^Ee^dE3pG#{$IcJ0ym)vCOCG?>9d*3X-m=k;2_^G z7V^faxi`X6Q&&8_C2JdY_+Yo;dAYuLj~MTAwJ>L!dYiZ0lC-yTZ0^g%jVh+CL+@kO z7ib|;)ZlSbz88tNe!mYxAN=0zY>)S$76Ja<7D;;8*bBWYt9b~-oJ_y-V88$B>!bd~ zGP7#t#cYL1SR-8@O^yN^Q73!UkH^@FiHTG_%ugY|&CtssuykB52Ncd!b+Xf0{iC;( z;5xli^71c)LkL&sAUk@a&4~VUR*A)<(Un%wYV8=fDhC0m_hX-M|>lZeAZ=K3~F2kmvqHZ+v;0kQ+ z1SFA2$HRGWy4<&B5DdqOKD4VUOH4PkrJBiPd!dP_HRb`svWduF^1o!JD@_{S+Qwt( zAj@CA!=&UMoYvX#Xyr3?(u~omr7p^d)vZj;d0oYzh0kf|FM9%>05kO-8Z|NRn9y{0 z{ApNk9~BRwe)nt#l*?in!i>Mo z2;D#Gk#z`B{Z2d^s|NZjKRzD3MGf!^p9{@I%!DDV7Uk5L*N3loXdq1p7#OTzGb|%n zR3clw815MpjrPvPuP^nEzguS~j`M`TyO+Cbu? z--J!a*G!ki5uaFx^Ry86WJ_bG3E+W!SKTdTW|NAlz_pVbsfuQygu3l#I=i@(3tdGI z5;pUx&P6m*>Lt0y-r~G``>)#R!268**P@D9aJbGhJEgzrd39!U-E~nv+^HOK2*-Z= zO2_`=K=Uz1mnr;w!yz*(gLj679rKGfYH+UT6FpSEmlJAtkj3G?{S6+kT_S^lw3;`9 zafdz9u(xBOz1?>`u0BDRsFH~4>OKN4hgEOLeG0g>u;zft4Ne*o8Zd~f5R9(w%N-+} zH)L8-b`~ESqk(1HLhX(YTG(U#jO^@E-V1Icb@fQ=+Y%ozYw`m|<*rJ{JRD;=QZN?= z+TDc10G(UbF}P2OTVGOVyk&cF<0I2}ZDAx@#jqKfnos@(RPx@^@)|R5gSWd1m^PJ$ zBeDSSwRcf*&A#BGzK6FLC3eoQ{V5{0qz-LTs743-Ros^tfAS80-OCiPB0X#+qO4sH zT=^NS5PT~Z6?k*jGBuqj2!WwJS2!YgUwxM*pXO*QLK!KV)Brivi{!zr_&Tm<(n6d- z_~}Ynp>I~R&aVBz+Rmftltf>N8rWurd3sp>aQqy~`1ia@NQDb`htUd}-%7$WHk|_{ zC*K-@n1Whl(#ke)MqGXINUmCxH_vdM1lRL=*wecqUdcuXbooP6pw@R_?T2(bHfDyp;q9GcZ3@E6~Kq?zmb z`t!a`lXqTG2iQX_3Iw99C`O^>l@AZPT<9Gb&f18vhN=LkFxFXmg`IEBe#ZzE8 zOsEgh%K2W02hUpl`kqRGNIsi}JvM>Bc`GfbfrMQ7 zQeIrCaUnsclnaGanvfIusKZg-Puq8f<5Uq;t1TnkrA_S5nljby0_#gRwXkmC1Yb$F-gQf4a|2L|Rhru`JMt2BAw>4?!!S~|`<*)BM&!fdm|Wfd%bzIs^Yz4qZF z2X4l7hx}os=fPuNsCxm7s+~7s6v0BTre#t6a)`)+G2q-pHS|we0Q*9XQYuUaQkX?9 z|BYbz3ozOQ!=2WuMBzmvp5Q@iic?7BXKtXY{ywmtg7Gwo^uc)co_m zXLcd1q3|Fj+!|P4wf$MVu(WUv_y!Cx8vh#*L0Xg=E;nYufOxQ#b2+6rRy02 zt-V8x2TIwf^5|L<&c_3+x zhASAT)x4jmV4K0{_#=`@KTKx5(4{tKR7ppPST-nR>2zx>btt^Mxc@pGbDU&-Q$Y>_ zzyf^s7fTR*DmpYORV*i4+&S}6)eV*ze>s~MdZ2oBVn^VJpi-*o= z|HlOw*G!|(!tMu$+QMidbObz?J{QrWpB{;T>s49Ctby-9Yl<|O+m!Au*A*{dj;`ii zV?(F%VC$|omlvqnsd!o}J)QrN=MLp@%xNe&+dmgtffe|ZvaqnUL!XjyDZJI8g@|zc z7iz4SE_+NS2$*m!nqQt?OOmR|m%bSTwmti3;Q=g0UlyWrbng%~*HwSU5o4mBSQFVT6y@5>=< zh}1o?BLqE!#6O!=WzUh3rj|k?5>9MEw^m`hy!+8hn6e@Z-nHkxl;AhhBxsKZwAFVm zK}+tCztFt3q=RK~c23dQ;tGT5H`2w~qFuSRF)C@La8uTyTbZ zznonKK(IYeR_tPbKg;pD|f z3izd3NRstQE#=5Yop(ZYO!$Q{H^BPW%I-|E|9m7;n_66_v5{CC+jqOr9#04$;J#Va zPQ2TYzxCOAVx4yCr7^Nq7u@(|D3H|6=5Kj9H|?3pLwmO9dRw_cmq zh|fxb-=;Xtv;!%IkFLhJ$jTGrJ*XluseY@oes7TELovia0{1FmD;c|UFP@fa@5xKi zAcwC+{ukNS_O@}FuLb&GpZ1TL9v_LB@Z?&s-Z6 z#IrY>%mnpmAOlh`0UI$%CrKpR+}F3IxZ}pLt7`ttBOnu#z^h(BYtFrz_<0X>@f=lK zuN3^J{L(S0rf1)ymAX66d29mv4{}Bu&6u=rMR%1yYn&Q-ZZP?$E37MmQDK`W=)NwO z2}+Gbuo*D}ee;bf9aE$HRCryIU2^?1O)X%>gNJ-8$GmaRN=59invoOdG5rVMB+UT? zNpy9*YO65l6Y9zPFjtkGh&O-2+z2siAcbqhkO0_!R}&PV$jR}MNgH$4fBDe*sCatF zPNwD7wGDM#^7K~m;=Kefn@D7;d6_aVPMsinFq!{`G%rwjwBX5Q)F$nd*3EeGxRM}) zZ#XQH!a!__JqBthMZzZ%kqm05IaPeF|86FgFAiT^Jo&z08i?y+_J5nFX^i*d0XNRB z9M_fv%O#f|wjC{1Sz{W`(P}BiwO^&jv|5f7C7Y)@OR)Jg*`p zq@-5y-F?n9wzIkHI1T97lYmQeoupi$~Ei*#kySnzhl!)1BFs;+6(-ZFaa*Unz690Z|$twtzoLB*da-+77?0CI`GN&RV6Aa z+}HQ_U)SL7C3_~n2?7X+0iS5+VLp7mX++GAgw!Z131!h0*<30m#WF367f;pXLd;O8mS z%W&qu_FIAw?i$hv;|L$MMK*H~BPL7H;Ju149z)E3CjsQ%U5qn0X(a8PXcqS2vXjj} zR?uA9lgeO+_!`L1#9H|_ko=-w8<-Qu3mMuAU8*Br$@Hb)dMSBVpTk7|%3q8>n2llxi zO(54k2`!n<_gHA{_Eh!1fR#rJ#}fI(KtmH3z)S@N6XP{P+7I=P|!2I{mF$34_;6j@W}D0Xx*Qo zUW7>`+Nuw#Js$rE!Zp@>XV}34BDd{QkCGjIV&4sg6 z@9fkxfGYw;hjD{M-%V%s7T_cqJRo6^K12&9oWAf@03}(YD(O}%#u&Qxr52V z&bthQT;uV5PjoSwXKHu5AYb!k^H#^zVn@Xnp$9VVgWnf?JjMAEQ z1mWal=NWNfK9Y!_}xbkBa5EAqil)k~;`J*zPC>KLu)Wh0GoJw(eV*nzA^Br3Y$U>#QX?Q^Zx@5LsYTFB1 zNcMG6ijRZvmzcm=PtEAOI)-IePB-Ip2+h_b9Hrx;v1m_PmlctHU^KFy$5&a~DEp&_ zAWRjg42{2YStkw?=v+!`TXd1KFj|8dwv&c;*!+p4!eEDZ_j= zwbfMWi7~lRv9lj}FPTutWH52_2FngGq=Wj+BYpw%2OAQ`4{RG?E9zo$Qf)-DD%4r?G_&Ci?~k3rWezZFO?lCgUbdJ{A-*AeJhBk!|%JljI_D4V?2A)1 z?TsupH-DG4p14QY^K44fr%GU%?VN@h4~RwFDw=JITeNpC5RgZn8+!^iCs8C{rLJ4j z+s^88S{ti}}5c z8OTG;*_-VaF=5~AU4KBGSvqhJgTh~3`rHc-&k!ARNL0(EtqG5@L*MB`kn~@J3H57l zPQ<5UIRF34Jf^#?Hf7)=eh)Vn#Z%loV_1o4AeU3Y;AGtK`qPCeyJG39Z27m3wd_!{ z;`!P+V6IF%`j$CHG5Y#3x!{a0OyI)pU;5cJ>_sbTn!l1G=HpQB{CSvtSqNUCHI7|Z z7jr6P;1Iy29Gt{!}}rKqc+2R-mrfah3`oL&LRTN?WzJSb|l(Dl8j z*o=qcrV?0I`D>X1OMb>TuSwv1-O&&eO{RGBQJffQ*+hm18n914zLith2T^2L;#KC@ z3gJp63Rd@qRZh~5qUrK<0;xHW{*qN>w8SZ}30m}4i3S9xOf$Z-s$9#EYoF*!r~2rs z>gPRG$PjS;1ymh<)F*wp*0Pqz-kjR-s03xj9dc&fA{( zVk=$RooU?xXyT`L2UUwVRW*N2dN^*xRgXY~UV`ssurKvOlR#M!d}I&g49Xr{_mg$C z-||*uU3n_YL%;B`R%4^&{zxnUKUh{JVCmJui~a-~$a(Of-hKQAzhbnzbTl*pH&^>w z)Mx0X=3`BGWc`9&Y)l%FDj+(LXZ3l+?FU%m{UUyNLMblel z)G#%}uiCy6TZAFNLgf`Sn6zZctF$YJqWNJq0rd&Mk+ordKLml8Rza%1`-1E_V}LAR z{5q3Q)dKXa+1m;&Qvbpe+9-{@!|N5tAAdPPlK&2h}moJ&B0zo z7My0U!8)eo$y4PqF==Iz#lK7{n#wBj)kECirq-5U53{n9?L!$>-+UG_X`GcW$bI6S zSH`DJzrZb0|4fyiiGjf@r=&zG4+_I>bwp9I(=_Y-NF`zBgrOgmEogN>0JZ-)+V++L7LdNlc9~}ZV1H*t5v_*YUWmQkO zJ_ec)V0O5u-{kOQMl$G^&ChW>xCyOV51+M44;bcw!aZfKWxVSQ5IFAQJnco092MKY zes47dy7m;Qyj7=2PKdY3IO%0gHl_x=TZ8w0?7d#5F%hs}Zkp)gM(lYw7!JX&_T{%H zSNeRUTuKbc!9ZriDiJC?r+&CIMpv@OfqgY-rNZ|oU^RzTD;lGrrnZn^QBwjc^V9RP z0WJOg{ifV%z?S6Z+Z`|lV09DYUDjYlj$x+A_Bf_nP+HlkH(|DPjBR6K8dj}?M$75{ zCfV3irCQf}y02@=SFSN+KvRovCsa(~_RXo1jem1Vm?ev{n}ckUgJV^)|7DHOYQ?w= zuItrf++4C5Br}nqEB~SNdX6R$lXz2tv+d5XdBwBaervWbrE=cu;@?jGAv?UFdpy#e zxTRdaM0Wc-V5OR*)+DML1O2hqP%dKLz_mMH9Zl|Dcd7R^DorLQPt(lx)!# zK#U05|7SFG*lF-jgWCY5lw!uF=UoeXD`)gS&6ye-pU&%qQ`STHRRt}ik;aPj?#|L6 z;rxN$iyK#2Y5!_sAx-+fE}9u<~5kGQSHROPZ8&-*C2 zEWKfZ&f#=*>TF6eIKYgOZ6(=0x78|&IL1=bO`0i;2RB%5;dz!v0qg_$1Y1dm`0(dy zPZS1*hUAlSa+-Dg@r|!g%}w()>xg_FL2_H8m`V1&U2gYbx-R|AzrULuE;Xote>cH? zs++V-1A8vFMB(ZXA-b) zo;NbhB`gC)?e7tbWS0yKEOC!E zSa{3U&Jq#LPF=rUTzKunqenvm;`Tq@vDYloBG`>7;D2FU_MWEYHcz}X5isTKBCFzxJ9)!4Wb>6?(j-3zpnM*E{=gqouAU}mo+9y za<1}!&XyPS&;PzHk1v&?-o?5aFc+%uf8@dzkK}~)y<_TtNjbx9`1KsMzXc9G82xG} z#Ee9C=zmj=s9ec3sTTE&XI~i!q1LhmcrU5utCI1x2?RmcfZ+Y9?6^mCSs!N10*W^S zTw)u7`ZDeZ_ZDvSA=Kpz(pbsbOtE9n6g%VpIA+Omr(ZJ_@5>yz+}m;?{h&i;7rjZPg?;#S!FMg5J4j(^qNK!v7x?=Gkj}r; z@14b@2nfycSbj_~2sIkz_(ZxJaF@ODS5p|&wsl6B)D0&=rMy?OZ_bu2#mtF&vo_aJ za(AF^5T|LSCOr2S`gWCdbSu)4tKP)1W^LMXWrvh*d9=!AF38~su~JfMwN-&|gk{`& zf~VlW!|U)O+p?;4+)7C@*GrTkk}A-kah@z#>4yFM8_;_ZecsdLTD}8b8(A_}#Vzkt zSnfRZ26Jr0VsDCzx?6nGI*7c)^}r0{;h+W<_xpmtX0O!*KJR zE>eGg=6x7mP|-x!-KuTd_*}S{rA9@GxQ4ED*`j94@BrV>Q2oG#ivW z|K_*S9$_D(U}>ZTfq!q$W8#pKy7)>#>IU=RwpKmI7f2zAr9+K_HXw)DZ*0$!lAGm2`&*UZ7fbGx%u*gPNm1TB9jQa78( zW9s|Z<^;mS-nYfW{?3c9)E_|}?dEq2$~7jz(Zy^jXX6BU|LNye7XM?Bd6%Mg$qHGW`8^ws-~l`+9!w6j-@eu#cn1OVa5- z^#+=vu?1UYe+C!inL+j86qi?hTcm&|dd#(j5KL@m1r?20XGtr^)4MCMB< zjaBV#|4Uq-j^ACP!?iE-ji_V^RKUe7^*n7fJH#}{p)W6c)}tubL@&M`shM3R?5U|ag#d|RAEXC>XCd0h3D5+THeH$~rvCgH~%C>__i(XYx)eIkh z;}jQ3BN{MnHmv=OCbnoBHv3K2?ug^Ef3M*4DluM_37?8q3mdP{F4&D@^->GK|JXW0 zNkMt%#ZM%siR5UG2u(#CS04n-Oxr}$m)Tf;e%i0N}y{8V7r zvXN1JB)9D9YUk?{(`>i@%j~Qd4M+RgB@=d8%D)-kG)jSKyO~SI$3($TDUieI+Q)~( z7RX*`-;|5`jXsL9Bu|>qge#)m1DE$&OL3^RpX%#*IZW92m<5vAE*!g*UG3}bnz3OS znl;^S?)8k@OAd)cy-<-AcK_J1m(d2HE_ONg7mD|(LcbD|27EaJ$Nn?H8tk#F=ctG& zP6NO73eX?#PL$Q50N_Qjg}3w#ytH(6hkdfLvh1C2sd#TFfeFQyE7>mgFdeA$9SlST z&(@0`H2RFap&PymU<_VJCsT{1k972zBCzDlTPu#P`&ScwcA|mq$Ss3zWfQ;d@z?uZ z!iryN!Rmr;s;XU*cXu~MWbRd1)Hxy#XKO0KR?N!jY|Y#*^9{H)WaXFS;Vi!Sn#e|( zm_P<|x!-Yj+WL}qBh8-|v6Bt6Fq^62rO@3@l<*GCGC>I=8Z7N|fxddH;Y!4;L06#E zPBpOCt&OKhf(`(am$A1d5a5I&n$HVNIt4FzPB>;4B*ev)xp3bDTJAJFm=8|JeAtJa z+pB`k-8WWb<{3%dCzLxghH zWEQ#E!M$?W(NplJ&ZvC{bVQX56)JrKd%v&nrd?GrYOAd=K&vT^SOFtSu|3{(&cB0h z4zO--k3e?lWsdV4whQ8dz?FZu%gmf+ac9ExWXP7n2y&GU)4JMMA>Nixet+QZ`~UBlu*K;Te`br zXhb>$r4bMXB&55$L%O>gW`JSl*`EKo&vWh<_Z!~u0zNZ)uk~Bk^}Vho6jR%JPAwNu ztPifu7+Kkmv1}i;eBqvr#0*m^ew5{kMU+i_&MQqM!?dfyYGATjuzu0! zW1ehDdKT+9=9hSI<(~dF75QI;Q2Qrud9C%Rf2$`PJ1u zo_u}>?j=5ul_%s(Ms_*Up$1B?U(c(L0xA1={|?LR_Fq2bdb5v~AT+#w zNx`d+L3YG|VV4;d`YIz;#5syD7ix~+pjOt<=3g%Jry%rw;y|BPW8T{N%W%6PmQpnx zRM=NsK=cJ#peh6YnQY&0{44LHpgI+QDwy0y8GFUt37@F{1bS36g7W$jb8LUTt2U0L zf>k2dDHY36Sb+GM;SUcR&WYDtLoB5jVXePj*i=!5GTu`lWhAta1~CFjVI>6h=gs$~ zSH+2t?w& zO!04~+Oa0~u44^4W#nVtK^SnZQAC?PtE0)=G~6I(9BJxIxMb*Xq?G&i<;$0(ew+}(Oi0`AYML0iU`X&B?ue{s2BFUDcqujQ^OpAhU23H#-y6M3g?&!422%HWoXw7F)Lr=5XNX7x@X3?q_oX>)^ zYEPTgZYNpL%zlw^u++R3Aw=jFt2_oiz&GZmiS4a>wt`-N~;XjsBC~LBenPa)+#ecPhvP1Gp$G^d9hjvFj@^c zR9XXDP(h;BlcUAx=uH2IRI!inJeItRuqP49R%@}`WeZG-VZBR%ITU%7dDCPNr89=J zx17)DvgHFPNkK&X$R4ETj#lhmp`57Nv?P8HXIAIU%b~0ic=ptdUZJ$Vf1X7-{mxzX zCHQW#XE{*4_nlYT|2ICMF@H|R_NOI(Xbo}M3w=Q##(OWmDl(HTIXm)`_;?vosNPrT z#~)j$Wj9OSpJUHlKlbZ#*WZ8;yn*YG%6s(xumH``@3#;UkL#~F{8^LpCEj5H6dF`) zY*Oum-%eJ<;jLdevIw=cNXro{vjye(@Xps>5Uthl!T0rdkuYC!ZLOR}S{{!Y7on;W^w7l-oau|~hHs3jUy2~7*RXNGPgmCo{a2`ynpgiN^hiMd- z*Q*LjF+j~foYgGbj()s4W!|6IpNK+MxKy~boIbKKxJNybz}_D4vCcGqsqQ;D`{#|1 z;L-g6&uodf{f@zB1s|XaPaYz`#fVozMRZYc7kX=`b5ws~m$h4a*@o#1c5uvHd2R$z5!n&2;fPfP84uqtH zYr7qr2Dus@aD>jiy$bdD_)$6&Um1C+AVzyH-ot<2@+08qC57b}jNE=Snm(F+5`7c# zKxT<%Ht-NvixeUqv0vr4hFvK$9Bc9IWF)YOiVZCym?k+=7@8*RCp4v-wDM5OSoSCA z)_lzJ-dJvZZjitI$8~cK>#fJR(Nw>;aT(tQbN<<24C$=JyS9Kz`Tr5$9br44;$Q~@ zg#=AwzNg5-!j~AY2|qUM%()51w;zWZ2ZK_2mr*sC0!Yi?-@0R5M3+-P2DCcsvxO`e zUh>8fu@q~ch1e}-3MhY0?QxsH3iudjc1t?jc%ZBB@sHy&9C0_At=|#X^F8sF2_i>$ zYWUDFVREE*XxWlk48oXIwxUJ}2m}Zn&0&kbiVk`q6K;)%aoXMq)8c>jbs!3HUVI0? ze26=g>G_V9r|9CX0ty^xNm~zHhDzqG&a!TgwM7>AO0hO5RXe43OZ;;cpXo9lJ3W95 z;sr1VXyZvez*s?cgUYX1w@47@HF0<7k}bBk(~)^+|4+kspXiUHpx*yewdk;K0o(~2 z--UJF=~aHn>|Odtu3=yU&WjUy+GjytF0L+FpFq8_Q9)~bqy%UCrDd-ptgqiFw9v?f z4;;1J+eMB|7X0n)k6eM;p0z_x;OXhLFGGHQ{$Y)0sjYz|$tQ|4-goB1oYix~hab0t zSbv-8QP<+16@N8o*+KZfd_bps6;yt8J~c!Rl+khX|C^^eM8!6&nOnR1`c6(iQ$pVR z*ZcYGjCOjng}xn-2;DuWvi3X074Xw6GONr`m=#cLDgfd+Q&|gRD&TzOfN> z^qQeVd4CK|u?ASq3-v)6>tsF+?ia^rWpZ z(|VFa1R#<+=!`{S(B<#aOhVFid<(7dL|aq}jr-m;QPqp$4vSZE9~W);U|ks|S@GXCw8ay0a$f4o$Cy#$)c{HL%eP z(M0Mt!sE&1_Cy+yBd{EW+uiM=w8*fJTH?7gnI4(>o685TyN$6RtvsC3vDj%iEI79O zd(#Q{Y0U~4c_w`4OsuG=NX!FBHrQo%>2#i6w}S@xraH~F0&g4fKtE6G5>dN~$e zPmH}9c1ZwC%!AppGmEXjT^qUE<$!X6x*m#7qZ8tRfaCiIKXYx#(Hh@?olT`oRb@OLMNAK?el)I)NJ{Lw|^Czd}%L!v%@|nP=Tba8;v~RQ$~8d3cqw= zcaj5&FarJxMB{`}^3B!Q)k*1~f{}z1iYN~XKnT)W^>w~Y5@>xZ{Nyt~xb@q-c#KW6 z@nOS`G$bIU52~Db?_pH48QV{!7vvaFrW!d!4A;I zd2f7!3~lc;oSZ9#m-8KNmYm!U56#W5icOpPyPMrY&nD;N?nGY}Z0aW<tpIQ23Bd^wf%cFjiFuvWldjEPR?MpWD5c{S6H}Y#tihz2AF5DVjuXG zf-}IHYPHFErugWWe){9;B0(D|cmhCTzsAc97}3s+zqu;CLkXzuzKOBOe94OHz-ivg zBIfPRk|4+dCq*?76ftmCPLHMAFci6khK!Rzd z|Eiw_tm!I#UA8CfkR%%$_29D`J%~>UpQX}KIUGc~-5NGK`w=wh9Eote1z(VDT|8w* z#5cCO4tMiS#GneLvcEQlO|q&*)XCJ$j9PqSYpA`7KFPUS})~G$S zYZ~!yn-bgd3m%6>f(uA)p*0TBdXVV-0w|OJw`++fr1``nb=gpYC{7@oB-?LH@aJ!B zrN#P2PT-BdZF)-*15=?rNW>W1Z$JXzWiB0eK0lT2tA;?tn#{jJn;EwISx}V$OgCLz z-V))J%}L{3_lW_HK6GdcDh%q=cw|^#s%KT}?j(u$$@%n81`esn)p+>AlYfGFg*Nyj zxCr;!u-^H%#t&ASpu+8DyKuKmN?_;ZDfl{Ay^-w`tCPI@S*W3X<#L-6Gy@sPzR)Jh1% zQV~Ar%!d_Bd#k6doV+tUadg00t&MU@y_b(3y+e~0b7BL*`M}D-k5})G4M4w^1Qyhw zJ^P#k9?rcRh9R3)tD;uW7Og@waHt!CFnQ6^8S7myBrJr=lKXf_F?< zV!pw!pBGDio6PpH|Ef>N`_wPN5wwRgdHXU7X>_*b;y1KWRbOGjC=f=8uL=4Iw4#1& zZq2`CGFG5fely+slI|X(&Q=ViWmLk{=~Pz&THgCg#WT;x!A?RWgvfSe>(P!VQ@pOS zf~ms;($*|DsN~N-FQ|jxAUOUAu-myRq3uS7hWxfx+JcugOZ4i>3kzd!e4BC9)tmaW zMRy@a9sZt81ibMu(N8VZV2%Rwb-Ee=(wx|OK422_(dPy#NfHOP$i?|*=xY=*V z+s~~XW1Vu;);|VB6j39`is8AKOK^siHMnfYm(%M)(g=Ev+2LT&<42o4t-niS>ub6D zdgg8F^y&eXqe z8lnj#;MByNair4TZRwKGy_cWo7dRUt$a=<7cJT0JmCfQC_4CGuC;UZ8My|;_{LTdf+fz@b#CIR9J^q1FYhdS zQTnoyBR9PI=)+|!khSSanr4Ybmg=0=aks2%4_G(tSkrvB*-eEfK15G!RUNjZiq`eF zO)hSpm{(e_hko0RAP`5 zMag!E@fScl6HvDAq62)QSW^`ep@A6r+Hz z>5vyV>z1k0^!oN+Sm(3d5p2GM->EwK?r@-Zo~T3xe;*TeYWU#7(>nWCcjdw&jdAezfLIdir;8D8{D-vL{>HxjbJ@G zRsO1;#a=h%(D6bet?_>#+0NfO;r@1BX_S+y76dR9D8x5%EdhNf5%ztvPk4le^!#tg zU+o}CHAWsc(KGM#+M;uOy&aFb9sw4^=`Z|FR@y2WOSo|Xc~Ii`&Dp4|7;SOER01s_oX$fA&D30~5E6}hxQ&znbF zocI-B0|IL!Yu?8uBT_~7df(S)5}TZ7)n-J@Vy0TiF~Ago=4a|q{zYBs-S9L_in9kL z#EZ9M{aG$6xDy|Z&oN_`(47Zc(anopWdO8p3_iRcD|iP-TP zt^3bYX&8AOn6uY!c>7w0=E0YFC}7WZ8|&I)Apg2G1I`B*!qukbnoKq%OB9%cPe?K!5E%W!S~&ej9uTw16sABDQPu)kmjKut18rd+g8=9+1-O|$w#vvTV@f3z>+K{^U4dUTJ{!IL(9zERVs z)O5Zfd!cpsdT!H{J~Vfvg}4re-iMXWWISiu$0;^Q^Y~&tjoUA=+=r53^r^&#&`ra@ z)QgmG8F_`NH4&lLtXrSUEyHv3pjTl**fJxF`lV!mC=rMf>$S*xl+*-I`pef?7hwR( zR|>^7C2TOzk771j@s0S@R$U;DUECd8akS=M+c;sGavw5!evk1MWa!q#|HZ>2LTD^$ zv`^ZNdLCAEKPa%Bx`{o{>!==-=8MY2>kh_B+bT*TIS`=4gU^z2QS3>B53k7NE4jDb zN)W~mlTA)AkzVL)*}w_CdKD4|B$g$H3Qk*kQP2-MHiC&z>C1p;l`u4i4VH z2`)rumcHh=C%vL>gIS;P>z|PL20RY7P$O{-H6pXwf;`+DA1GhQ)ykG1Ab` zm^%B1sGL5jbvD?-tyM;KNFnj?!H02_1qJy+9>j|3K}%p3sAh%JmU@i*PNf1a8r9Mj zh$1XHi)C4r!<_8lm%e<(omuJo6O{an*XyxJdcYk|*L$PIb=y47=jLhpHO55WM%|&I z0E`o%L^+L;NUOwDY2+J0uf7-DV&|ltrce4P#|Y07a9k~I!Cc~^0dz$^|Clrc@|DaY9e{hIbKxeUVS>_qPNFn|^gt^@+Q;0Z{ecD!j1 z{O(*xl1xe0WNyf#UAF5qE6&S4Ua}t@iV(X_2)ir`c^hsFo+d1SW|z_oq-_{)P4z~I z9JqLT@f}NuLTqMvU*k$VsUWxyRqEEoZclmnVGN8r5M);(!~ZDeEZDpS;_?9*$UFvDs3$ zLSKzKE#)&<Kpbu-U@UntSRRt{f$`+w=`=Z@P}~r|--3Hx;AFnOF!bbpNF;*% zjTShtzp!^Vkp7jVHY#BzW^&EZsszc-rsG?`tOdo>kK>nP_Rn~Pe4XmSU{WNn$2!F* zhU7qPc#er)_I}0KqjYwZ$W?enb1+|GOOV`zPn&*?raLAXn#LYn1Us{+9;QlMiY%Z> zxYZi)Bae)BP_UB0xAWM5{@HTT7xQcM?P25IJx#GODkTp)NL|5j>#B+u;)LuvX8@I@ zFTbfP-&Kx279SR?Du`X(*>qW{2B?X_pIg6#5Bv%QW#5BpPEr>OA#4|l z+z^;~XSh7_cc<4p-MY%!ycJnb*aO_oI@5+IN)dZ;^^MBo zjhI5-<}v}pw|BrvJ5Q|uj&N^F$|tDQUoLO~eVW$e7L2p@F^9(VfqgUXscQNZ`NJhr z%5kO}qb-kTR~e;7<2zTQ<#w(Le%LMo+`vxVpNguT|8}?JJVm%VTdv7frP9I;DUaq4 z!&Gnq?+u0sRJqi9NX7NKn^}9V8uCadK=B5(cmCr_r*9@q&Av51o$xExN$0-H>cS84 z>uf3FK(0)GNuRq7(bvcb3@#OKjUp;X^PyN20{y+_l~ja=hvYB!`)sEkw{h5fOQ1`$ zI(hIMC7Qnp#U)X;UnLB2;){ z+OwqF7GU>!JGvDvixm)gC|Cp>i+wYu?)Po-F13KBgNLlYEGEZ|>Fj91QxAR_pJzWk z-vW{e^Fl{bwHT;Xo(OgnCzW87&j%a%bz7X%9*KgHTl!Z`RYdp1r!g}9kVi(0MAw-^ z(4+7YzisQg8E(KeUrF-pi~tK)fSv<|W!ZoY&|mXj!#t#i8RS=xM}l@*Z@N9eT|=NA3>6lCO=N2q6oJu8ecZ^9~1sl>V1s=nH1hwP>Qvz zw6ru?4NH#e339+P!Jhh>V?)>suJS47S0cQQSjib2;vvZ1@4p$Hw9^AS!6v z$$_BzcB}C)??&e5_936wvoP{y`3}ukufs)&;KFSozH)`jfW?&Ex$e>F5Ep$5gOW!B zwMtYfrbWwAeig^*tpf5b&`MBot(i`P)6BorTe@M%tP^s4o%$eIaR%0et-1Y&USfG0)BWdo1boc7&G!6ACMlyNGH~hu_MP zVUKBmWAAe^257Bc#TJf=R&oN5)G5rWN493`AtUnB$~Y0-;3kx=d9KkAQ9kfWQae;IzG zR~u#ffs)-TX50A;vS&kKRn}6uppUHv(C-$UvI>va&RLl`kwY7>fVtd|4A^N{jZ|F< z{Z#+)*z_3W__hP$X~Q$$LZ0NE9e+g?cHhXSaWtc4O>`vCd+QN$c3Tba|0ET@Z)$A2Kb}3i3n8+W!PJl^eOIw1`8(X0xlKPYy16gpM zHgQQskN{1ukbRAokd>TFbfCdvgO%I13sX&b_@;dO=GC1$kT8J(aP&fzKgRzgIvDS&=*)VAQHl>7 zFy-qnXEDM`NS3rNC~}AOJz`A{j79?v>F5FCu`aji4DypMq|Qhg0xC% zB%h#7<`fWfF*w!%+Nt5{k2>+c#lc&bBFSfylvD2)GqnNhM$bJ}%*!{~GvOM(Pp>r( zuB;*OI{`mRWZGA(uDu09N_Mf|O&&+LSu2Q9Rg5{x`VtI^k`EAP{#a~J(zsXCJc_s{ zB$CEV9P6r_UTA&u9?9loMvSM|9*IfJ+8rV~I4RUuC0?cwj!pJqW1O~pJgt9S5a4B_ zS@Ts3P9>GVXc5|`!T`4WnES^wMj6rd9}dZ;302-6QFfGZ+|xfN8U0F*k5061Fr7|O zvMxM!9?ixxQSgliG6@YvOT}B9KwKn`MK6KFN<12LF)u1{TYGJ&U{n+oK(Y!wncODt zs2I?@lvq{qg81muuT|o^gbk=3hmMh6BnEI*`}S>8`Rweh;yVD05x-GBU`NL*An`6N zA9Ra2HeOf{Hk2!W%DDoRzpXRGWW-K%GT@+Y8o&$PDhRS2MjDKE2)s(XR~vT_uqN^L zvV=;!3Zu2}a;RV5Z=tAY1E8`9JOJ!g(Eg%=HQilXs!=P-e!1EI z+mGPX>kTbV66{qG$5kr-{?Lk715OnnEpg;WO~inc?}~Twg=ZB8y-f*r6!T}-6LB7Q zd}>%Xu2ur$;9O$uUM3pArl&(lY@mFIO1_DUr1xC?VGynOI{={(QC6DWu*4Q046wMA zZ_S@?xr@zIKXYI^dSEf|qG2VZsD++iu+DB0bs89P+^yPP?R_*4@}bkO{RpFI+h zA;^6%pR$&wSxjoH78#0wjkg*5Uqt_@{Iyhiu$m)39T1?VTJg$f!S4fWBz#S0r@8Qf z!b10V>X(Wa52<&&K>GXG!pw|K2?AoYDQPp{yz61bO0{kx>cvx+#aUSiFe)X;-d5+* zDYgrZO0C77FVwFyid`U{de-{U(Lle`rHiNBFKtF)Kh>_IEqtYqcP_C5`BKu<9>b=+ zC)0VyAo?cl8F{~H(KEj#X&9O_Zxj!6wh=G>D%1vBM9BQ@MpPQ>?Ehf_Y{S!#y4noY zT_Dz6@U(|e-BHL32hpFLys`Y`lTeLEDv~zd{|`k07EW5upM?zi>y%9jZkflf!JJ?- z9h6Cu>SwB*@ppU-Cj$$&7?ljDrtW?x7BqT5N^SsKWr*?RHLTS7p zoPlSXsh#+bPa5p+-}_x?4FgZyIh%*@Sz;4X=~i7*CcLV>?$*L-9<$r4tG!E$ZFNXM z_093}Rch|Z(X0o}oI9H4Wp#8|h?HIS?Zi2w*|IVEtDVkdKjEd zb3=Wh7Vb!AMSA9-F|Hbb2>F5_w+*ixt~gY$Mg6!qT(+x?=uR&VG(9TlgPsztbk%_f zCDk;JkHr0GZ(_V@A%fV~u&h7!x1DFDDI4+fD1y(Yf!%T=`=9-ZJp0H9ykE?GVnriN z6(Ev`^R|imGHhtU!1k|uosoR!nFf_B8=xcNdov@cu=qz)H*jJDJzbeQdY1DGj5y3h zOr&y1Ztf3A6}a;S%)(D^?A2S!bkV_mBP&9uXtUS+#hHBc-W+d&xtUr|t(Vi0*xATuEd#|P4(}-~sf-0` z68JEyQm2sMbRHLrM#_%o06h>co*ecfreKsb>76iCV=`aB3>Je#bYWY;My&eBGc+hDiU4N z5bKEjZIv*$isdj`>&d`i=yAyS<`c=tqq`i>U;JuSDlD8HZfu%-8`@m=3W|q?aWqys z?u{QV|NZ?L(WZ6gaMeMCN5!L~{+E=h_|_{==56H$m~~`bX~8ZZ%!+7$@-#@|(;Ik_ z5cZ+FOg+8rmoAV^2@9%N?{&b4{b8@4*qKA6KZ^rpDINp!%&46RMgaOn9ywbKgTRw- zAFw}twMA*PWzW42ySF&P_U#ag70}_uMG64fk%fg-kN?0Zyf6X1!sR64x8U#q!eX)G zs=p@;(E&gMtxrsr_gc(9_<5oUQ~PVZubxz9m78- z=uCe?9oHL@l7=n(8KfL^8ZNPSHGOBI^IIR?p_|pADX=tg#8A8{GmmFhGT9g(m+N|x3_3AJw0s$=n9+zu(stEV$jFkNa=2KS%)^hDmob(J{(?j& z1x1QSOg8<=(~m#p*N`t*M!PbK8;%jOksQy?QUxNy>atWp-+HO}#v%%%wa#UG57K%L z1TYX*RxyYJ!$7NB93QRNb zeZ!Lh2l*VSMOwkPt-rV^ilD<0aWG+ZcvG52N4n?JIE+z~861+uyx;G^HDSOUow8v_ z^Ky0+1LM(|!>axIkH)!~YPTav`Az{ygNC{rrZ^TCb2d$2P;XZ?>Sp!cL_+6spm=V> zMm(C0YsyxQLas=0B+|vAKo-6AnRRUd)ZKcaCEvDhMyfr|MQ2uXAz}_MU(V7gJ7ySo z9ECp868Qnu)4!&wErk<+=s2vk`tmj7?ZSiFW}T@07U&=upmcV8vFnJR>}X+HuZuu)I875|!;`gBBHU8~k#{LSrw0OF2hPXWkRN(%Vlw%59D zH{YI7v$Xf{qyG9jL=I^5v>&%u4oSN0I8Ap&L%k0wsn%mRGMiM=#R!&dRVKtRBn&I4 zZ|-a3f2C$TY<1d!XLX+p=0@9w2M=8^4AMVZDq}F0xXm+na(wNjsyd7D-f)MoODQR6 z^|eRvND71Md>j96TpT)WZwb#z)PE&&SD_l2b?r<5j+gh8Jf;wdjqM>|!y@#}>$A18 z1Vlp_*VF6R5`COkv3{>-m?g2sk2qckpRC=4iT57%98Q88+} z`jvu+4Otzv@p##}ikzKz_WjPz(e16aN+*%panrD7em=u`q8gRsgs-){zs% zfVcuOms5#OUA^@CkFx)NW_NTn5z&uKfEldu3X>t}#J$LFlvAZXEJ(>!zJXoMuT72pCs&lttIUozxf&&R20uYv4A z0^rM+9NEc$`Du3)pa9Ix%*?oK~lpE1kBxqydiFchKcZyv#I_xAR3l4lQ)- zm&e3rufGJP6Hps<{ZM}9JoQw>{VXrp1#ls*bt2t|{jTRbXC~R)%b^) zqEa!AySuDAh;LmvlAb1}i14LT-bklY<+F<8)mO{=k$Qq<1mOm_mvcQ z!dTr0OX|njFkIyARh2`DN}CHqA$fzT*b>xd0Lj@hs3SX0ypGk8oJaitAz4#oOYa~t z@>IAp)(~_4J^yN=BC*V(wok|I;SmfbTRATsXu!ILnXk$=rj;2GEM-o7U15v5-*HJq<5t$c8eeRix792TDQf_a5gHuoa1pFjBijL5)!A}z|Iy-o0e7jd zr3_c$pGiyXX}@2O^V%J?}wS z&QwXUJ|xPCK}JmglgP!lS>nOp(?3y@teD*zR)5qFD*UM^0>!3);YvzzplS_3dHJcw zfxv0(Wxe{ClGpfJChc80)F8uqa7L#;onU{PsiR- zgXTn9{AAV6i)?Isz+UIh$(ArJZL;PI?Ec6bgDuY0g0}GrqzUcXp2}RqDhkTc%zLm{ z`AUGZ((fz*CnW!Vbse9d!4JPWx{`kJAL@OUVB<*^?Trg))}@sb+jzVKaJi5RZNKAS z$$HF?Zdo4{!~fOOUK*pECY>c@2UMUNoW?ZJ#~M7Ln{Qp4m4wZWBwjU(d;dq-xsIl{ zG4_aRnLV=%+V7?oaFC!WMO-I&ij+BsUeI6SEokMI%vN$YIs9EyeOZ%+%&kICrdNe^ zW<_yh1Fv-t`ntjQ;LXx!T*)J*^0Us%%-XpD5ieExh))?wq}M%iq>z;_ZNaPC`nvku z9JEAgUK*I#ct@?jfLQv{A6!?OUvlK|^Ut^F76?YIV?F*K z)?5GV{{8ib_~WA*;4~>}7>!2|BuHFPwI-$+G_oBe*X zSjm02p6eVur&EQ6)ZF-!h-5j1-i(JsI*4pXfqRpsrA=&tbNSPYC-iF7i0|dS z4t|-IbX`XWpG`0Aa8^FOmLFFoTNp#fCOUR3io>Oe8r{)DbE&t=+Sp(0owjQp$ACFI zFEb$H3b!sU5fh3&q5SZB0{3&u6LLqQIE3n+4+F&$J&Mm1yb!j5B$Gz{MJ(E;Ciep)i7n z5~BojaN5J2B0Ki335kFa{dAogFoK1x*Km`G46kj(C^$2FJ#s#f$W-VPBJ^K z3y7Gw6JGlDHdf=Qf}f|=*1ex(?!B=sf3GN=wjse_wm{gZpz5=ws z7Kok~FJ3I0nVV;xUg5KtxP#qxdfJi^FQ_Zf;h5@ki|7mn4euwz%Y?D2479kvbg!(3 z+T(FYZT&a=k`dE|hM4qUuVZ!&tE+Rl7vU0Dk1F@vAIAp$#?(J}qwZp?geHnh^jli? zTV(#7`r!qM}Kk>*xl? zGTa~NHV6xT7!RvSoM%G*Kcs@^uX%NR2lVX>>-3l0&wHWO0)#j_&s)b5Wx9{V&rE8o zR{dgP)^3_14gu~GjwifnGnzbeu}JGaq4i4{C;l`9|*EH^P4feYERV2(7ht;qdBB|}Qg842y~=FQiXOTvAaPU-OS>E!+;-z(Ec-1J^lo{bi;3i9t3T%M zpG%#NjCGX3j$>25%nz4=8O2x{(zV#cw$CJV^HoZ{LqmU`$gWTlL4cmF^NPy>S;Xx~ zFer;T>vqU$t%1i!-8mhnyKHIcrmxmIa=49PI(t9ja#REZLy}MK?U9E180>bT^F(D30Oght76uV}Z2rJt%#jE{H@oG^MB? zvnRtS+Z+>n4;L#RZgoO`9ZJO-Gf287R+k)w(_a@vFszl|6`Ae!pcE#2`!lHIN2A5X zH*snYkA$wv5M1k;U&oK7IMquq(xE;&TzH5X+6ue2C@>(n{<*G4@GO>EA3x#l@Nlb6 z@$S?H%q2cQuBRS7cWloc%2wRwfXK)l)LO>MR3auivp#gTK1Kgm3pvpwOBpc>L{j~4 zDTUlG@CerQ=zq~l5v^uEANvn*)SgWp`m$~5y}{U__#}8r+-1Eti45AEw!vrx!<7G$ zsJAK3K6hz3TNSDV9N!7VtDy6>eX}6MTe)69(U;x0ncMk@&87Cv~^vH z<$Q}VY^Bpk0TY`tPd-bku1s~ zx}@Rm0@}0xQrX_Bar}LC>m>C{nR2pxBgFXP%6VNHmUibE+GwO;W%%pfvx_g(ckewC zcg8*K6p#SF|J)Bewt^Tmwuz5m5VeG2P?7!5hi7~nY?rwayhM&Zu5O`(ov5qs>BV+z z6O`r%4GuA}H?6aa%W4NC!}fKY!aB*ysbl;By7qS4Q|_&xrV@Dg+AyKMrppGh3#~cS z)!%kUeb=ZQkpyN6IQwt4wY4L_<>H92?6Yo}EH!ojm>smN+B$VEA~bR=5^64D1RFKi zd~Yd!ybA#JH(#tF3;&Hy_pnJ0E?ZJ3LTrH78&bv=8bo?-**bF>ncyH)XUN%^m8lBe zpe-(%^IeR;^Bvc@` z8-IXS7Stm20nWE(XP$J#N|=#2@UZ7HWUb(dboqPA;Qi6Sh69jZlAF@SG8Cw2%u21G zQW-K>H_XxRWDw002l9ZGRwn1{*II1087oCY`nm-No0=Z2t~r#9+VQrXnZje@!$N|Ke2&z6}#ohZX3| z5Mdud4A#n5RJIg&N=c7#HcRW;rhbF{mO8(uhra!c=DQiHVSM=`<_g`pn8Aj&4^vnX zLy0?4m^V)0^e2_#9mI$3>7IbgO8g4 zl6kbGzNHkv!o1kk0HHh)+Y6d+FZOv3E}sPqL?h07GbZu@FFsVP|5;c8*sqtkBL``H zycTO?k$6Wc<_X^Rv8LLcf)U`eOqvsu_kNFyExRvh3 z32pj|OJUhXcX^4YJR7DmTi}UeU21VP(1Ls=W=h|->vd@!1MgH?cW6ckZOM@cd$Qa@ zv;SjPpG$EyUoWcTBCcf&xzAs4A5hanj;3-||3KU5IQuTY7gaH0mfNEhP^aSc_yM%r z9Fw8v&UD#5;h9}obPoUSAR*WqQIduP)PwzzlhF~Aw_tk@4CPw{Ciu{RBn!Z3Rn0dH zD=jSS0kWUd(lfdMTFXuv(N*ArrCt0V3gr9Ut~zAJnGGeDa}PTpQ1y9Wn8)Z_;i(Y{ zADq|7EYJGqAAFDAVPucWJC@>1QPPv|sh4^=Xi0+s6VxZpw1e2}!31>ic0p@-lNCzY zv3)J;_>J8nK6#Dpm_hewkT|!EA~67`$6SpEQ`})nLKh+!-#1`aL0wN29jR3?YAn`M zPW~M?ye>#w{jGwy{*}N#g_y!@zK5oCdN(iuzE>zmh50*~*2gkT0Cs>8FJX-?zf)OD zD&)0M{6gX9!Y`Ks{a3LYP_Tsk!HCxN{&#bdW75VEE8+yB>0z$jPxK6Yd~TQ_L=`Br zYM4L>PY0Y!vt)rLdJG6Cq#*@jFj}7g2wGsy1uK9}BI`J}va<3ng~>EI<&kwAv;s20 zJ{Z#ZjX>-%Fxt7yqiH1|)_5afx2c>yH6a8E08ufXO1~cj>24Qdqv$$8XVKOoouKGXoZe!1VNb6h-ZH@W) z!L#3YQ#Lxls4SpvHI}pJ$MQk_vZs#xO!X$)scpKE9hi|rI{oWY@_Is`X2_1ws*3p5 z=85yaFrzJ;KsK~WX9IYHOI#-uJ(~?a4nnLLc@Z#H-P_AdW+BpST@2v1%SQOL=khG! zZgt%*gb19ee*78{nqtqO0%^57Q`N1tS^&>} z%LNfUYmKj1%>YcnCxcpVE4RU45}n+nbzax5&WaA3i$^cx}(<6hi!Of zw5F{Fkf-=OE|-pNYP4HJgK4=g1qQ9p8pTu=$>zAQw@MO!y+Brfvg^<}%DRlE1s( z#~=1cX`$t>+<&<@^N0XBz1UBp6Y1_j(Qb;sM*Szza}@)F@dsi+r^TC#h$=H{=MFBC zB@0~yB{LxtkTHla;c zVVCZWNx%46M>n{Y|AN%*Q*x5_oh+1i_cQH=Jt3LBSo7jTa5hIJg(4oD5A%ZJUAV=O zQiY<@ne4zQ2aV}U*^iX*HV3$}VmCJ?Znud=I1~a5UaUvwk*;Ob(Ylq!69>jVozTU^ zlLArFgN*bRtw;1hdzB|@dR|bk-+)T{B8npMDJk&A5@AY-VbTkLJsUD1obk6^S01kF zn%_s!9+_fZm-{r7&sX;#R`1IPrEC%Ogh2r3SQ#NPz}Ck=3vKU%^VOf9sCNb7K530Q z!K3KmY=F|JtlWIaj}6q4>M>+*#*@n$BZ!xmyf%`L+W+AQQPi7TnUWCd17#)9N_@jw zff*{xv&u(;eSr-Sx7!rq!c=A>6pns9kYHNMuH1i@N>_w~UGYt{l$9li>pPuJuppJP zP>x$H$$QsoPWRE!S>My<2RUEAZccsr^l67yow0!H*y0B8=ngD+(EFrg`S5sew(fxA zy(iHes;=-^%bU?tIjDpeaor;ZKns<=c=p8|iyOm^)hxBnlL4v@>r8S!Wygt2b~ZbA zDGweiH+i+BGUfURVL4T4(VN=`kmZWFhf=_t-#AY(4m^vL%F<~Zc@13_#-8b5&-Ux> z9Nv!!Qi-Uk=%RIDOSfvv^^THnb9rZugNFX?~m+ST({>{h~NU(;{ zLYo|U{n28+%>;;OFU`5hG^LROp3*;88xh;?utN7S5AGS8jV!uO+w;j|;Eeym{k0;n@D9k1<z*18Lkh7L@{&MqcTAiESvign?y^T>RD#G4 zg9i*zQmQP%2`~FCV2#qmLCzU@xzOpUvrDcc*KLE(?Y6)-FN)rBtWKm1loa@McyM|p z^f+LZaJDPCwqzxK!WIvRCmuat&cAWip)AkycwUv~h;PmR#7^*ZMOaCPN&j8uWur05 zk)M+pAkX%coW~CRdMLhkRqfhlq~LL!&4?K*3CI8Kc|{+=Ny1k5Q&%z1Fisgd{4^SQKqY&9hlMqEeRU?&4I32;Wo9y%?qLj z8-wA~x-Il|HyRbypRo9Kmqcv=aZEsGYmg2x*0vDi@?+7}R!<#IWh3t{YH@Tm_G!M@ z1eh*!71Vtt7>a+|iQU1X2_71tdh>cV5fnO{iQx#JZD%#Kqs{n3_rfUHhzdtGiw$LK z8S~`1uiuJD?d=lva=m5Yh6Wo}V|2kC%$W|RGB6*yxx!Wz9XA+-@C(1Cr+~{+T_4{| zAoFNv;>&fhT5RP&qf5FBkMYbsD6)wrP!uZHcdkUu07$tHE?UKcs8!|I3>>v0Y^wC~ z3j>P6QTEpMn*ML1dlBuZTld~k$f*)(qvS;~HIVJxpU`pqu^pyHV;diDGug0P>;cf5 zCH|c*-X0WI#nMlWpY~kw=#uYFI_`dOn<2Y`ZM|N(@3DOlP6PWmlA)^{KAjzq7i2oN zK=I+E+#8qc>@^&Sg3XW_-{VMXeW;H5imNdYG~$R4C7f|a@Ok^=95a%j+NC%Y$%KZS$;oL79&eAE-Z@KceyLN@EXKbSDt zmpAflmf8zRoA%6P*LdGhjzZH&f_p2Ac`!Os2H*~kC8c>CbptmAc!;J^SfLL4j#`>7 zh%{fK?S##W3r@I}T}MbX20m-63<1Lh^_l`aIZn<8?&C(AqCsvm*G$eFVAuB2C&yT+ z_2M^q1S&DBvnu(T<(}iHssIsgOm#r8`%6C!9q%||_BdwZoGxLZ49zuFCev*H^iY%JTO-0(>(cpeLV1q-F z{H1Elq{7ynGET`(K}+mR$!K0U3PuveD+U0qAaRKhfc+K)1oUawi$L-YJiOz`N&lF_ z!_t&EqzyPS0!?6~eBIuC+fdkS3k{5~9{hge#m-&-!zsB4+kx!~PkQI}gja zNsa!De3fncVoV{1Q*;Cv{)%F75Z!s@AN*v835Ob6Ltph4PIcs}(G-R2@aGy?ez|n* zZBr>K#t4q%nv%(W_%+4b8~L;oF|+W3YYbbX>6M#MXC<;2$QF8ASDR_V zeio#{iZXxv;cZ9-XPIK=>LajcefSQ6e)3b30fjcZYy8r0dd%rD zGXcW@+z>d#f=jub{JU;qYt5^P^sQv2?rujCGH_cy1qY#M`|a;nU3q~bTy>bOw3=%q z)*#@kYyfrd+2SXJeE=GMt3iw%i%kn3^ z4yWi&+iTx-(4|O~zQ@696#^ItfkIrB(M`%;fTJ{A=X7AI4t)5WE6grIVMmmFM}xfG z;O{RY{zP)#yD%-(?4h+@ath59#sy&7RaA4XbDm_QbZHqiXlL(PH}J;##7H__yb%M8 zT9Q|#v{n{nC&rQo@P9?ZSRQar+Uo&U#WK#Deh7c#zy?-|%l7zGQh8WF%ori64fz&Q08~o;=gSxb>AaTc)@d$2HUF=&y!29Zn445 zS~772v{Kk=9FLO=ga(sY_*v!pZL_;3Fc~vf#x8syau(k5VcH1OOY1~To7sK~;KT!L zoLtK4hP^XR^Z#y>Z@0KhU z8~j&*+qbj^N*xwuYq*^rI%N9t>Xlhwucp?e(s_sjqdgreg z9&lSQ7cw;I2tc!jdrl}RH<*M|)D-a4^Y9c-#m36DBLf{ALmizY=t4)3+APqC5e2PZ zo%cm%Qt5vtq3T|L7zSEp*?-Bh=hJUekO4_rYXRzdkOxJ8p?Ejj*W%*m_aOlGK{LyM z4zaWG=bU<1^P9o0ldfTcI~aX?f~3Q`qwu_f>bG#8GTw#HReC5#_b(i9a_4oe&j-j2 zoh<5&#l#Gm<=2rCo~x9>PdlRp0^Gi8?Ef%f>{6u4vTbMY%8|3{9vihC04ygM^62+0 zm{a*&22Al(z6&?Yz`lTME++sE3-`z`>~l3QzL+ar`LNVVrf^Lr+s^P;ltq?Rjwrr5tsNET?t3ltX6p_^6f9pJ+VzH2&O=6l zg*;&?!{*(FUlN8Z34~h=#e6?SSn*oXAWjRoI^UCjvb%6Hekw40{Q0rlZ07br?`<9; z*pRFAC-?yVyZM93h9SG<3VUD^j9RAL&w%u;4a*a5jJw-Nwg1neNEmm>imt-4)E&1y zDL`w*P_nm6UOQ#Jj=*df_MDqqPaIH-!>d$zeJs>!>5*zGcUO`v=dlHPrQPpEvU=p? ztnsH2720ea#dOt-a2pJ%H z&xw@+&$8sSoG~r1W8Db{#nyD%j#hf(giDRJ>w9t$hX0NMND=K`s3d{$dUuJ2U`$6*z5_&jZ zZsz8&+U>TYr`eW6&ian5KjmN};IfDbpORe+n1X|&w&y)3!wTzl ztljR3`|BM1ETvnb+~C%%ljiqZZ??6~sGq?+7C z=phtKUPjIY+~McH(&VIS+OQBFX18&=!3eAfztzJ7zFXEo;k>q;Ei5w{V`Nl^q&T#f zB^!bg+^E|7f-Q%4H%E#;2)fUin2|}hoY6Ib^Cl{;p@XsngO0NkDM87S#W#Rm)2D=3 ze&B|C14fH>La&Mm_F&dS6gK|*KD_2@fo`1j2M){4KJn3%XdfFs@t==;0(12M+ z!bHK^J{xi-fr$)vq!}zbHDB#-8$Ic>XLp^qeLefbQFP89 zTl$h7yPRQgZ1ud!5_U9`0zODB^=#YXI#`DFb6ngKO`x}L?KSTbJ5ta(jpaIu)zG*) zelAOmC9g;u+K|Tscz5tGO;jN8wU{Kn$VgLvtO49wW_4z{kGX(a4NeFG78>+aVv>GP zSx4S^P}*zsfin$8GNa>WMKk2!?cF7Llx5mnK0$`tTGFAu4C#iDF)o3m1CVX;`i8`> z{)g$o(~!<9ELRU%#7!8V)GYpo1WgX=EtA)SWvG6;=H5{|-9|z&DmOgP(k(`kG&3Hh zl;fy)dDDm6SE+=F!DwSGoMvQrMy@@zpZ-M7B{yTEe$VmHZ28@$VU|+d^o|Od%Z)+M zc?|#Z>zqwKBL9^K-OsBFE1H5YgL*xQi4%d>VQ zH%9#110qteBCv!9Y?1@?TB5)vv5H68F9x7yUELK3*o`Mk4Bk^SqYjRx8pdy#rU~+G zTY1GzfjntRpxFCUV82P z4*D>H6rX16@_I5|SVeF@q*g=$mhXOi7kq8T=R6XC+n{g)%@)=G*i?hSJ7a8?2ps1i z5*2`eN)}EL0^d%@TmPWTR9iZ}Z;muJwL36l-==`fwh^mrgGAEAmFXRcQZ9Tm5d6J@ zlKR$y{IR9O<8qw@ia|+;DryIgFsqTdy9uWcU3%{t9Siz zy>C#MIr30(d7^Q{rF-1WLwWL2`p)lsM4`3h$knC_T>icSv$OPc%NU6oLt*@_qx46g z3%B#KAINps#&$Zgd244yR^Hg|3?G(Rn^A zq9B!rgN0mpa9sdcsrqb#|LwJJ(MwJhd=J8$O;_{Z>`Zj*c#XVDRN}yt?;>mod(G~I zS|(tCtwgzL*SQ?p88bdIuO$vDD-Lm=sfp^sL2Vb8L;}C(-U}!NAa>95s%VQKb@uuXCwF$Wx+p>0nIxE zrul62_N#l!h;6yCt_*dXaO&Kb-KFU_O2^i-puak6p*lz&t~e>W25d z(RN*$t>=q+K z!v~nQ82~S_ZaTi`V?IV;TS)DcAoI^35DxX_eaN>1J)y{hoQufnaT?ggbhF!yUGu|& z%Sy9C><39nk-bm~%Mp5u&IZKATJrLY%RC1-Ga6hRDY=BUi4XR*&n`iQk%M-HSACjC0$o z7Ck%&+^m-Et`GN9VQmgZ;N1z`zS&D8c2U7Dw-qFdh8-$>O7ud%U;I%uYj;<=xR5K=tv@9#iz*hEf8Ey(!Cb;Ibv(raTR?@jitl;^2hX^ir00I{Aoqg_x3kd(a<@K zE*H+++YK|&=+D{Ymb6>mk6Lc$_SYGw!kTwzG-7W>{jus)u{kGiR^WZKMGvh~B-cI* zXg#xzJsX@@3xu7AAtxJKdg@zRl!ULMghg zW&2HjpxpE{&0M}kV;e(p+~aw_hFs%MisK6{^u|o|#$do%*f}`dKy^c?tN3By-?Uo6 zJ+4gC{gtWHxzcsOZ|if^=pVj<(}#3xaqbwQ=tbKAYO`#uv#^E8!_B~-l zGD>YQA(>|>xVX6Z0~lUT=kc1mnuXpiehaDO7nAd`XXj=vV$D-D&emw54_79};7po- zC59c1!~#1#igYKG!zTbG8Nh%d4?THt{IM}e(dlXnYSM|=+!wZp7*<2XHn_c# z;7+BA+*r>c6aPqccPc&JC~8l{VUx&_tN;Da+ocOd4ReNf_LvweSYO<^)P0h!-j)Ex zL}ek7yWxgCgE|g?+X@DuNA48`I~uc~JE@`%UoeVk@CZ}nx_U{Vczq}b+P;RleKEyPHAdDnT)}()Hx{c*5Y>v+rY9=6dfFa4g-8ji;4Tb>iNtrXPl7?04IQdPSTH2W^;oZ#+%Ab%2UH_R* zWfuVd*nn$Ui_+yr*Lp=s}Z4P&pH@C+dB&F7^9xauLw^b4rd z4XqXBz_(+#RA2Aq_#^kpCdy0<)s1$!K(ng!z3(|)ID&3Y*qF8Xfc?jc>eO-Hqw`M3Nzesm zzO)rmeuBNMH|kG%`?uPy-k}nz`NG4)Gfz%Vyi-$BAg1*$pkz|>@$qeffhT0}f115yj&dh{V*EcFg=UwG`|SO zud-nN#K=T1_gsM3^R){fiXKdP1vCt9s&fI1z$tx7Ck!H`;Q3J$m(PLbvv-WVkRZ?V zLLoC7oJdjYfF@#%eyc$77qys9Ih66c&9jD%&oqaz&iA&b@W6-O5qM%{wX7anX?R%D z)YQZ&2KZP)8*L+l5ZFh97qc=fgT9TA`D(^_?)4X2C8E5&WY^`lwlYxtT=Zl>lr`XR zrV-Cfj(jTVO{YTKoV%2L-3PrNjt$GbXN69@b^8?HU;<9X_&$$0&>3(ZWvj>B*B5=n z@sS?B{I_&EhkEFZ_4VtPni>zhvivjfyG({8%ww#H zl>}_Q9QK?KpRtnxGx|PVKOVbp8GF9v(rg2ViU7ngg`-Kfk*|8tELEm<3rA)=9P$`H zvP1%vd9z0=i$Ahy#*(iH)#3n|Z>7zv<$8t5LC8I2rX+tNA00G3EvGC92re=$MU0Tr(-J+Yj^oXpRXI61bzX(! zz#t(clznZ+Pgc}|8msddlBixc-foc|?gVTGj032}Zw;gM){aM2mV6XUav((-bpuYYr1)qCLl zPDUVyx21r_vifv81xelC-rk-STZzoW@{tkZ<6{U@Bst&t>f3_5e`}|N|dMFIPCEYNh`;R1U zIUSZ&-pi@Qb+IK0aKqbZ3MHRom)Y21u-SM$SV0ng_oUPa$R3JVDT4yLFhxPI`}uGZ$h}hb!&$lq0jRO(*Y9&uq5vw;8n%|;o8HEN z`?YdoD`R9gD0*bq7a5||_9?>e!a6#9fyV%c^*AT$8dEVc)ygo-K!_i$|I(1U^j=6C zNBoB@T0)|0Uz3=Wg!sm6uQiIP>QQEPE{j%TG&8ru9Ym z{x_Y~4CNa@k4QIK&!)>tXbLF2(Ay;$<8EB z8v&sHKb+JCWper~a|g3IeggD<#!;ce2*$hO?f!Jnu{shx!{wO?ToFEi9AT7$3e8l% zc}>%Myt*bbeJOGQ;=YX1D`&!(;E1H_ntDB1_*Fnm5VyL{*Qm5*v!J{7b#lwDKM5)Bc z_mb#5Jsx4<1JJ^%05Vj7P9*tZ1o#gTO-%v$QkF4BcD9&26aoAT__1FXRG!zln1Os* zJuqcC+ho%dugJ()8X+GC#kd=+eXf4**A8o~rEiBj9vKYGaoGdl(BLkBgW&w$^5SNa zqk)l_Zecj$yYcliCzg$X=f1udn$KZD2!u(Qidh0|BV&=4INXleSPpDq23xFKZ$m8Q`>BVMw)TKoY-P~MkTNKZppap*e!0EVnYVQscBr%;$|3^b0Uqpw7X!JE30LL6 zCfpQAj_cn=1|eXSki&Z&H*88jfxsI0Rdf?+5$${*imeE{o7=X61TAsw446K6_3}=> zy}L+*^Y`*p%BkGaQ3SLxjot@AgrQ&hVuH8O;GFhQ z2pL__)n47QPt47mhkSlVPkk-H`}2Nd3>q-+J>89rNh*ih?0g^d;C0yz!?$;HH)5TL z!Rx^T6qpKkL%}zHBSvLzWQ4lA$0FP%1sW2jyCP4T7W!V_7hy2{d}d%|C)G$-3cPy<+1t&7NAWvxzdrlz-o#xJWbFQjpP9= zFMCGb$b-wS%$_|f=qAePVE~JF-7-9P*1bUbh8b83^ymMDJLBYHeSQ5b-YY<9aTUnHXmraWIBAI$g@cEX7Xbz%L))KE z_yq)-!bH!XoHiznkB`F~hjVZsTOIyr28yNG-@l*4{t@!$-~+(~@1XCA7Wc7^xw*Oh z;DyCE)_lE1(7fXUd{O6dtFXHaNCM!HW`o5HTLId5KGU7dcdPf`cr2LJHh@33B|17f zvFT!s#!K8Y&dCZjb3|9Alm2b z>0gcg(IoY4^){1kBqm{A+5wy?M3(LMp=~Id zqXNCl1aPt$Bji)+X@VqE%f;=f+Sf6TVF;;jrcwYmHP*pW7pD1`QnDjbI)+}npi7>sT2mG1P3 zgy;eCs-hu8%LWyS3=e{>McV(>zKM0D;ZA$|dLmI=K2TTQPxQaWVC zy1P?AAM_b_xp2FN7bxmaCw)KfN8ONzP>*`E$Gr>N3Jg<0krC;sXa1gsQ{TI zWVBxQkLU57Fh73bXd9=#0nm#SIRShMm3U)w7DGwA5RFr(Z|%U6`f^2EA|62JxkdPa z!s5dM@3D$AHI~hYn~C#3_wk_hY{eoc3H$~bkUnCw*NhYoA6n1x^& zYjAdToh-9K5$&wbnRK4_u<6`Pf?xHc*QZ|$%#?9aaP{HpWwJPZZfaR|!aIapW0KYK zB*erGe9<9nB*j|sN*17PNC#~L*14wasgbpudiGEqDIgSk9=gD3140F5oE+W;im?** zkplt{cEcQOjxPdMaX9D>g5863_-OCiz%N0}qnB}FKWws6w-_XEMlu8aaow+vexmnY z4{WI4Zpndqg7{0{9@=>2uZnW4{y+vC?PDF+%Ro*c4aR2xxFU3&D)wSDDt08%39J+- zz0C|f_i(0&V`H+hqS$wg?nGQoglVY#byFO;A9hrDk2O{X3Sb; za!eAB-TQim5@@MbJwrZyViJ89Ma^kN#a(;J~(6*(UGoYKISOvZ^63J|MgWf>7 ze^u=O9vw~G% zsx?L)%2xmJL(x}^omYqxkZ41f_OhSk#eKBD)dz>zf0ccS8>cP`i{319`dt%LYm6f| zuJOnsKl6_Jpe}j%-E5D8h32++R{5IuKI~)B4`a*A6@8du_fHhVhqf6AYa*{k4pJqY zrrg4IvsgWeG_4XQ2@Z?gU&XT5fF`@Xpy1ERCelYSXOQZl#5YSUNFXGr{fY%OziGy| zsXVUieljls?D}ldrz>Y+_2fci23A!QvUxN3<*D^F>9<^u?Scl`zjAsfI>eBDyS_D_ zN=5IzKW;GKhr3zse@7rti9PFg3H_U!>aqq(&Xaw7@Y48B25N@23h*b3?yuBRa-B@l z6gY%Msd3_gVvOOs>-U=z1lr)<6o!Skp?Tar-ED`KTd89{YtaNaWH=~p66@4r92L@i zC^KVc8j(2&2&&fKfsP*W(EJ$_|kogs~DsNWV#_0(MW!e|%NHBV`4<|fa1>c;@ z`iGSB3x zZ64qH$U&BU>~);G(lg45%@xcL8V&Nr9zH~3y;uT&7ddv7pMiWSE2{<%3ikBw9UQJT z6d5YqI7M&uSH(zRE@YmjXaD5l?WD%F#aHp{TWywyt|O;ml}^-0xV+t~%O@8>RONCx z8E(?x+aBHYL=v>oi2XbF3i6b{S$ z{+2xa=-%5jlSJAF*cy9^wT*SvYF_OpU(M*ME=`K-DXTTTX=q<$g40 zp7Zb;E_(iooBG0gv_KmVL{sH`Lqo3bQ&YXwJUu<@93Cj_(K2$Cd;iWW06SL`Zsfz6 z_7LB`9XfS0EQIzD=xi?1E5FyMT9!aiQ5%8k>}4s>HfiWX@MS0QtCGA9j-BsiTMn%c^{h4~=Id3-vn-k(IW5X6(o|LTDM5E!iwLNgM^+ zFw)ok&mxfheF_jtKX!f~NjhwIxtxK$O-tNdY5*>>K+Os=(VVi;h9U?6g($;=QSt75 z5`0{5H4>X|O;1!&%S2b4w6huRzlh-wJW*VQ=!4O`BgpnCl=!{fz$eLehht4P-!;|0 zp7mkB21mo^t7{Ah=*~toL@1XJJf0A*Ov-398I}|OoT-c>$HyIFz)#gN3pMME&ZY0J zOUp|X_44MRY~?5OG`#{1gCBqN@i#97n+-iaZm{(au4(}i5)#Acg&QMjgq}~%L0@=s zd~%|kpvJjKa?KjLKmfkcq@@#qFN@yL`?AGeVZk;F+hDVPZd7l(cZ*OmhzGJ-kD`e>Wm>ZK^@r@;x+!6gkzj^9&AG6rw zmFaZ$h@k}(tx&xcWrqAYE}mqmU>+mzPq9UcCNpP687!`W{2(V)+Ri+rT7pidWqabS zSI?E9SqRu6&YXxn`J93iL+llo9eM43pB)&)!*dr`c@ zdCX^1kNHmD89_hIUSnVBqK4>|H!>=o(5@GNpv)M=2z}9c4&KWKNd+I`qHDd^$3l+l zA3p|TRperWCw~?xxV{t*wS4U&d6>22>{jkZjMflg=tPw0O2}a}_71>tlA2D7ErFUD z-*?pVdoqre-B7t$3=&b?C~xj5eCN(KSZwfnGbp;Zd^FN{hJD91onGW7~22yhlzcfpr`!Syc!;Z~N+xGUWQ?bNRj4J46<74Y0NY69fyR^_y z^EVek$|>K)i?L(w(2l|5YvtndEN?E4eS@CCKfRbSQb03fnC;k}JKkEDKxub`*fm1+ z*8n?Kz%MCs{U7y79-1FgB00fcyBhgyWDb|qXE{sk<(6KrkTB3?>Px-A0n6Iw6FHE7 zfQK4KKOvZ~;4*2dn9orSUC1~&0hWw_Rz6I`WVA-pW+cMQ`e~qoZ(yC&3MQsj?-Vv> zLPMKnU|^8^cxb0c(>=ZuME;sQ>DX2r5+SsJJ)7w zJ5N`w_ebT`ff^E(;eMusFlWvP~rIY#F5?X4CF)PiOYZi*+qC3 zSv`HeA$d!?5~IyLUNHPmU+}mQC~~sVd2_c*bvY?0?lrxsq>RF%O7?3Gdy)c>+Op`y z*wp3pgqD!tRpGbQ6~#1m9Z7~Q`?(#mpHh$X)ss31oH&sv6Z zpKP{17z_&YWGo(lG=g%FI%N$0KE9G_Ehuf;m9z(7(HCvNWE4P-O*EvnU(S7F;kd}nYF7a)X z$qtMkncN3w@o{k^a5t6I)zt&Rj7jt-Dd$UARJZM0PoWPW%!So<@tqp|9HfEy(kiqE z-~M=IuPy;(D4pX*n$A$*elQ@iAfrVs$2kn$a2TsCyK{Mz6vs!yLcCijJ+)&u^cEI` zy5$sj@j`Ml+bqEW*qpnN;rrw(SUN_waKBy>yQjx7b_vc<5+aBZ)pM zgu#X#ykR7R%|X{mqE=z~N74LSpY`%JQgHw#lf!*~|L!T>4ENXyN&r4i#w6jx6?N6P zCou3U>Ze^iX!Qcm{R*eO=1+#^wxRnMXTGqA$&EdVsff1;>=9?$uYg1okL1Ii%I5EO zWTn$JLty91V8pCUmjn~&3vyh3j>08B?tQIGwH*KDc4?1XVmn#4`cM1HUqz)XYH z$-1lvd?Cp)@`umi>dK38@$ICylA8%h9J}!DjPn%mSeOlt=Z(hJ`Aef4yoG6NoT3VU zGj4#?xI?B+@wlVz%pNu{xHA!968aYtPHA(TAYM!<-g}h|W?W0bqLc;fewWs{4l#^{ zjClY1{7joRiEBS9Xfy=7tde-Iuw+o54A5Y^m{GIbdcrmSC^gEBTo`dJrErZzbbLQR zLRnkWTbAjC%=Y=|xt(JdXe*)mHt@=r77JotLTM#$k~w5xGbR5Sgd_WM z)^iB9F{E5lbfq-s!7=uzj+}KV{jt^=(7G|0#I;;}GBorR3fKWE!T_m<&pF;|v-c&K zHfu?Zea`!Dszw*&{v7nZ$XBHap5y)d@z-6A+d47WE&7EUvvto}$6)gm5%g?HBV>0P zu2)qO!>NvhvoBYpq=_v`vS9(JTc*nX-)D9|LNfsO?zqr<5v|FuPt{Ukdc{ydc_YA(RP@2Lw9AqMa zX82dxT>oT^!ok)Y@Rq^a6K4_KuQR6Bf{PQ?K1&t=n|N!xRVP!F92ZLaJg$A=TnjzhpZfg26xR?H zcyB*=@^Stq5m%HkY?KGr$V_i8)47roF6KYsjoA^B{RROXSn+zsORldguVH)p&Klg)|}PQQEiZ^3;gEp|%B zqd#u5SG;PV(2$(Yt`2I`TtD=yzlAU7`Bj{Ez%7K08x2GeCVqS*5?3rdV}6*y0++hXAX=Tqn(@I{=ZC>>3vYp`BS@^ zL$?>2{X-S~&n3!>i=TAr?a$U#Y;Wc#loLZj<335dWnrGKNWCodH?i34n zADHXfccF_Bwb_vE)_Z!hUmZ5TPPc^(uN2DM?H#6aaM1R% zQU<51?gQ%dZEO8P{Zo#&iD5(Lg}KSt*gf{`w12gC?(c`lPj=PjZ;ux0V!k`i`ne4b zy@;(HgZyMDS;>Iq@%{Xx#r=jCpuN8s_|(|=;^1%Fo!j1;gJssOCH12O=<#ygtaa?3 zS1#E9rBis_alhsAaB*pLa-v-qdp!cG8^HxlIMyF#MpT_0AiRv3(D}X|i8bnrZ^SO{ zP2nH~pMQp+XtGtC{n~on2XSIC*}KMP7aOl0v=+pbo)uH#T`x0897Nf!zFtO-e$RH_ zDc0hydcwqf1MR+SJ(Bn@IBE?7b3A+Y8J>!?H*d0<=46kfD;7H0}N9F z;NJHjtK)RXxEto}j{7I4>`2D_pdKt>m>9k45bYdZrDWngjH^!u>4Vxe?4 z83w~vAVMm#Ve(!xi>qqNj~HfY{5%%1`7iexfYNIT2656QBI$Ahrk5g1*{G>yOv0W` z7?$1p;8~+wjjwy}P|2#ymn)aO=|=izLFXSoi7w6eu!ZL3$5-vLz7I!qKiSEVWI0?y zkq3P%wMb{;JaVNmdX@6;<23z!9O+5H)qD*+U?V_eBOq4bI2j~X-;O?h`ts#Vv&Gri zp-9LQ53V^i_RXct$Le|SBCzDCkHpr>%H6_%nZYP;k7d{McILx@6W1;`U4~hb?u5Of z#(@omiy+n^fVN483Ha6Gj}(56-)~s)_P~3u>M0sZg9-Izr|eRWzp_E$Dp^~kWrpal zdv+o6_yYTX{(L4YNw^*H?-3*6CZI@r_doD)WBI>a_XMyHHz&dQHgMV@0IqwIkn75f zt-KMPr|@Hz?{ zgM>eczy>Hp`A(W0JLf<9ZiN|zsEV|dUQOW`u;lZxeM5S z<7V3T;_;x@)OOc9PAE0@c1ZrKf3-F-^Dm!dqd`uJe*gdgFaS%_FOBcr$bdcbXI-KO zHXlNOo)n%P`Qs4OB$ycV4)_};x&5$b#^N8eoF>TauU9Qqy}2cB1$@&-$3QW8$S|29 zwZiM8sh<#_3G%onZMY~IY``2%7+Ea#q9vmp>k_PO_DGbdfNvCVw8Li({(Ke$`ZPVUnJwe-N`9T6>gUM1}LES zd+!O(vQ?EJ{zxX3A5KCxZ{WcVRxCWuCS$e}+Sz}!)Vg^&UC+=V;nfd#zp??S$i?Nw zy_TQW)?7&_X`Gm7>VJ=TPSb0AWC7-b2vVh~qfUt+p&ymI`$C$*r19X_fkR1SF<|M} zvsGBEXl&KEE+KV24WMKbtN6sgp!OXpVEm1U}}{-`yz|k z%xBsx7FWq~xOTX$SpRiZX_yCcl5N0<)9<-Jke`phSBv74-{OE}H)M?-WDGk?B^v^z zlvs`57m`^hSmkJt_G#VTOg`I09sOXD4Zd4A4!qJ}cBtjWCzieJeRSdFv&V}6mKhJ_ zf-`P{`JM#F7)=DBQ5`tEp;Z=Ps`ozl325Q4?5_t&7;Nmf1CSu0rPL(C$+fR+g}rNe zFiN-cF#SGNo#lzW+oTz|j|yld7f&J{-eps7!(KM6dxjx+MVJ@FD1It+P;vp!z_xNH zKL~)W99&1MQR_^L$+s8M znHd^Qe%lp$gHaj)Jw<~#*xL1vguIDS(TRpBoYCvzGT4;JY#gloPNrA66jsT53ePz4 zX0Q?L1rn7uOkLVGPIBu8X+H+f7yju>i%GiHzfT5ju$&dlu%);m%?zwP4XP*aEku3@ z%-t=3T8&-a|2fH23+f9vbPr96q%SRfjTUD*cdYm)wo zWdBdcODQ8!?YlVqIbLuBryG|GCTfB^83M4R&c*^qVc4)r#pMC!?BLq${UN4T>rsd% zC9sE40Z%1j8Sg0toTiw$@$!Q>CtC90gPIGbk9RBD5k%g}f;55uMPafU9QI&)ex21R zF31pJ_?VR?LG&5atX~Z2%;XXD$v7MXH+omLU8u2nB?v}oSRDhfRzmC|25b^x#a$5$ zF7@BUf7S2%YK9@c>vVJrM4wH*>iRWx9ECo7e>Jg{Y&F)h_+x*MChmcZOf!w|NP64X zbXZlC-6<2ufUxbq_rKC?JQb)j!i|`i9UY^nJe~hG*dbp}+!aPB3Ml$LStePd{$rU;4yWrb4??^!Y?wo^Bd=JXS{bRLQ5T z90yQV@a&eU|I%^4{0LWz7rn0X)@~l1eAKk(6noZyl`UTrL-}?dBpX!z0TcBTZl9n1 zmq&0UTq{NO+>8aGbUAin-xX|$>rZ};=N5{hzu_KJzGoFEVU+!0M||Z&_~CvG%rL!X z7JvsvQk~TL+O9CcnZuQjjT(i34h`WaBVaevNA`01Zk*mzNr0=}g7k^+n?44Hr*9%Y z%rbm;HYL&3WqDSpEwBH)TMqPXTXIiq{!8}-f(~2OP#+UT@dIOmsQjyle_O*sz36iZ zJWjCh+!x@Y6Zd}Ns(`v~5nbvaIttpXgzm-#exaszsq#FWR-(4|FY(}6hX|W#ZIObn z_q%n_U0pZH09K&`0)HX-!oiss^e`Y|)teY=hWGb^auR&*DCgP2x`|5I!RZS$kg~mv zxfw&9GtGs-XB9%Vn*TsasW_)x@tupt1UFu^n$j5-3kPMG9-|1U+1a z8!3|7U#XOu4xBM8LBEfII92lL;vL<)%PAJPa91ABUtN$MsMtg>JbqT6`vDxXL-jfk z8LT$q*$0tnn9cLs`LP`_JDm?#BFH<^HLq_)^PAp=x)`ydhzp!L4tkbqJHu>!V_*Gk zo1pQ#aAZ+EH}TG5LyDBdK}@@>3z)l@vZjAk<;ZNdC;?{iM#cKgYfO&)90kaI2ZU|~ z<(DTU2_f$%kJHbzKsKVNf-IZr#4V^34q^tC>)EP5>WFpENfHy{U&5 zP#y`-{DQ~_NwF!vdFA||t!1mj78DaEyVQX1x+Kv4uKy>}%r7+@4^bJQlp-E_FA3c< z+R3~D)yncNzb4x*MtpfWAbMhA%x&kMG@I<8c(QW%6CtSNj<*Fs+Lm{|>xUl8<_!WF zQS6io2At7f^R;8hA?OsVq4^e-&9iURp$$*VKzUe0{8fEXSKjfoojAgBJJ|=7XIBze zpxR9qYPfMaEanlSBfa@?;4A+XE|Zj<_T}LCGtjl|pZ;#4^v&Ky8zpPeD+8yO9>T_( z{aNyd_J1~Av$5VQ8~yI07EXV7+@hO-U$wn&awjPaSN@!7qceUqu@*vmK;DG4QL-dx zJK#+jnVGS>--_vd`kG9dA3J&)Q*yyQ7zQ7O&~I1N;&y)XAG`qtMmZ}f9;54pTRjH_ zrRGaJ0>ijl)X>vRokv_r@8vvn@A50tEQJaz#LKS6$qXGb-g)Nx_OZ{T34A8dJo+V2 z_Xe~Ra=!nW=T99fx)xS*PzUr56ncMQkd-B?{DP>cn7ZO~uff>M$~$villUhm4+QV; zd`ODSA%RSHK%IZdlT_S*b=^Q`O;YI!ZcMkPuAs(P>H7Rw`U9z7vY^`ZN1 zbuhTS$!0N^eu`^lH{>~gPQumVI>Nw-s1)@a1}Pvx!#%|gs@q@bbs=a5BrbUk>K8gd z`{EfYGa#U&*j~0HB&IwO+3^&2PpH9t6wI@!@o$C3(VX?oAkvYki@PJ2ZG5*Lf?FMQ zwcnXIf$szDx>R(^pY$=#(C2&~@8Q;iocTa@apv-Njga3{%}IoWITtmQxL%jI4hs5& zAuYlkLWmQLc;Tl-p8&aXQ0p;ZXlRIe86Ra(pu^DB-W&D&+REFN?PwTMw4D3rDvOFR z1@4P3wXQ#h0v5x5Z@k;brelNhdKYw;`409&52&UMILeXNbH|c$6_BIks3CF6JNrIZ z%T%u8ibaJZ+39p;QvLqsa%ckxy2v_qrGed0gvwyGqR&3u5K935P9LtgW@m%|oj@nzXVqu$vD4A9Q%}Nq<=7HeeD(-D$&odG;qzcC0=HR4i zg}CAGbi0(g>DhvY@^1UuZvz`rhRyj(w+D!Ae*95!TKC&tTlIWL0(so;Mwa|~l3aKR zWu5eK@JdJu|ye*S4=4cnUGae(+S=}_c+Lf=v9UJbHbApzex0e-SLv=OV1 z{h~>%P6DyKe?5A+Py=KCbY-o?3QcQv4!Xe_pje)!oFf(jc-Eq#V3@~9f(Q-KfM(s z`JVi94mp;CP%5z;6GzlGZ(!N)r=qE7G;m?qDihUx&;8I1xsfFBDEqOQ z#m?-N217!me8FSe(Wo6=)v)R#7%a#FfL|#9zufYFpare|Mw6g-C*1Gt;!~^lZoQS% zy9{Qb)UcWOTn+PEeU1`Wzae?AD%HJpZg{nd)uc{jB4V++>=s9i#Wuq;u0Vro9DtSgx2b)b`YQKK@LITV6=|a#QVcz-p#O~@$mAm?S zQ>tM6erBgS9g8LigyQ(&qIY%2KLW_ zYsnuye-4v-q4FvghMjH=*z=XK2>2Di>UaJ8*Tg%pg;5jYG8`y_TPVRH&2iUP>Zjrq zbKlW;CkFMu!97GY@AZ4)E{cnRv*g?o3`9y&`h7)q`sYEEw7qbrr}o(7kO2EXJ3SWS zVuyNxZ``5%op|%X?V2F2R>E3{PxQp<{=pfUnfs0BIXw0|dlhh16Jh~1w^??c%QLL4 zuW_oJKJxnMJ+jUc6!^hcj;egalvOfSLaXu@XZgk)i%rjO zLZ9i**v<6%wG5*Fb~TQ>>jqj&NT)PKlKBn}cC zsQX>wb~&&?wX51+bmip0xnDlgu}^NWBYB0OE0|2B__<@;q7=aIz=HPpb5KJdb9(^t zc!C=aFZE$_-@@e~`Kn1$5;tQ^drVVCnxo%1AGl9KLus10Ijg~ z=CX`jyhSDH(%TYc6h#eXC(34ucjYE51&>l1XH9Z*=Njy{-rU<8KIy*Il}Y1{*b=n3;t=cGxg{mPNkkMK9q%_e`cc$A4ojN(>}0NzvI(w6Cz)!IlA^Z;wec&?9Y-lz zYd;o1D8S_p=CmaHqtDw1a?-w*z3@^tKTA_BD_}=>n1x*z>!TBJ3QV5Zp2Z2IgpNm; zi0I9DbtzI$gvY5}2JTTdDrI{lRBY-)tQaOjpf<$~(Yf`h{bX1s_BhO;ixeBQnRX`k z2Gi4S)=B#RaT?39w9ne{ezE7Ve0`v{yP1-4u2L@VaLWahn@w@k%0reBG8sCHpy9W^^=5N@0r>xuP;O9{ z(?K-xLIO%-ckdvgC~?3etGffZ3&!6%{BfEu-OyKhu-RLiaik=5+#Q$P>ud6ySvF2$ zh&b<0*A;GAUwoCMYWC}w0glpmly*LMjk+X0+aHUg4j}YY-Hr%fO;EG?_G+9((Q%R+ zAqUbDE1Ij$!UIEJhYkAsLi4kG8S{QcW^9#O61lm^`6b`T?Bc*0>}SdhS{X&VJ9CeCuaA3e|EOJ{zVYU+m1%h=?EIFx@}vcTncNJhDn>_ z7FhT?fBK!-kXvp1AukoMe@@GHNv3ZRm-s?A#P{Opv;zsGjd@Ki6}Qw8Jtbn}@685z zFYOssO%*)$_gFw+KePO_HfLWa6Qo{vP{|!Y_}qDrlPh5>^lqkU0d~j>#wz7^b-@z^ zkZ0N^$Fl!mqY|9Tr#>H8ow@kc|D;(-(U|)!TbQfdJ~BEEOqr|}R_w-xccqITM;LjB zUSelk6_~5_-lAXj#dk*BDroNOK?p4E-zKcuAGE&-!#x=;3Ff6$VA80;ah1lz<*Mjh zT9Q7jnkc`1xtUpuYHwybyGBZJ5VF}mz}ecIDa8>47rdgaSr{apN%DPvZqiz7G49Dq zvrO2CSlhv$c+tzZrB4lRoZYoz9GXhrxyU)XC9OH6l}W$a%RPL>wh{D*(KzNDXQc)~ zUK0Vj#fN)7?4|-Y=F%(AOHD{&^+^ImI8z7T$Zd64_$&R zUp_^3k0OjF3O+D~U%S_N-*(_&Ky@KdauqR|Y(G)Xb+E7`a(y3+jH`eA>>6mDudwKq zIgUTTcn}=lWQqlmW$8_A8d3vTNW%nu;<^XPA!tF#q~aSx1)Wy3Pc3Hy)A|@6F+OpV zgrs*dZriR|`S~Hg!FN6wL-qF`uQNWXiNgHRDWta?T%>{=Y(>C0(5pMNYHI8qpMyR1lZx9WcFvngp%& z#_Qe6z%|fD0=hiQF7NP44Bmb_^PWvC^)|-P$9Un>Sl}*wo78=(<`iSZIi49#??3=9 z?6i+|ef|<9X8ZlJ@!jm<5H=rF9&n{f&yilxIm6N{oheqk zRb93K@sHYB21kI2J2`kEXdnJ~B8e^JF%rZJ?zgV!FR3u&ZdnY@CrZ(RIQLPnj}UFk zqZ6Zdn#ai4J+#H?iQHS)sP(W1o$ft-RnMp{y0TS1s&Rmjwk&!@f{jcw3c7oe!DX8Dh#AG;)lU#TPEz}C5BHYP42Kg6a3Wq+$Cc;Y51`2X<8l(G*QQm}~fct_p4O51>2&hM6e5`EjbmH5{ zZk*f;YKT#ul@xS%Wk2EXp}cNDT)+yAoWhjY(N#e!NvvY)=kdr>u?A_v!A3a$N)M|2 z%~;s$v$O3K6Tpp%PES=m25Fyy_mV=U%3Q%QNob9)3BxT9l~k`*_z$!9X}#|QP7-juali?^Nz*vC&iueS3((^>+Lc}|Z1D9U z32UqG?^nQT!y)v5e8TB0B@kgcQ!2c(1Wm*}T=4@`(4a78ihNL0bvuY{ATfDMHggzl z{-_3w6&A##Bui^05m4Z|C0H76I$vqi|jydrfa3Uf<)`J8UNqe*HIo43_ws)pPH^2tJ6~@%a|rgFIf&<)ibmtKZY7 z449SeJ%@Ki2AG@Nu9DDn&dLB|kpuKftW@Bc)5p#zYB}exr>^ZgA{W~=F25KJ9;n=3 zx+{gg=h@4M4X-GDWOr7$_co7gcVH zCf!A~Hcf{h(R>)%_t9r>Es7rwcLsw^ms>$an(jOsq$9#c7_70RUV3m2cL}t~G2)iT zLyPx&J}v7;VS6$(sE@)F^Z>q;BtI4K@DVt4HeHf`?)+6|LQofwca;jgq4zr9MXjPn zIham2F-eD5z1|vSryOPP1rmc;e6B-IwLyP98f!|gca^y*?`dVjrbJwPAWlDzuCouv zyOqul<_xBLjsZ_Bhf;O(f4LU$Vv%&KAu*3wp^NVQ3%fC=A7n#k4(EkMfgyj`FRS{n zXc(NoKHKQKzuRK<(P#(pT2aZ}Wu#&t$ z-k)0JLV))?dD0M~JM6CU-Y(z1MyVU{^9L{q=)R{;bxieeXS9rfl5~?(^KuMAe)# z*Kqoqc+kN+f2atziLcZtR8ag)(XGVC{ z`a*YaD~2X7m9{~8YCDH%f5~~0_dU@+vk{})WhZ1V=p6eZ0D-c;_TE+z_vdF# zsCVh*%YaXpQyK5R+x`hWNKu;h93=XNdyR#=T-j~XcrY}%9ufEH>Atije(GI{OkGmw->;?LM5t;CUjkBi*0OrWIg#=ZXAH9G`) zJGUbkZ?vwA%bV#*K^u|cu+-PM<1(YL%EBgNRw$7fD~*xP#PVSk_ORazkKciBY3I*) z`~cfsr8s}lrnDHa7LV^jto1U>>;jA+f9`DNuCSW&xCfPJB}@W!f8!p8Pq;0rlI!|$ zdf%@;*DuHpQ?An`iQD@Oszu}#99f-*3yocRg?KGK8=&s&mTg(S?LX3@+F`IG&3QFN z+~=w{5;A_0&;ls4;($h8frFU+Z#m3t&{UHajXdSvb052Ap$Q&dCNNM2atB2m6+zHy zF>333IvOVcb6Qqa_V_W0Fi%}5(jR@^GOo`}X}-e~P&R&Ilk%MjB+dg+gS!J`Ket1V zYvqC#;DAMDY8BflCpqL^6p2Mrxc_f7Yi&jB_w}LH-r1T3*XKova0|jN#)E4a4J)uh zdt2-IR6WfKqJ*lIkU8?g&}E!Jq3;!x!?gga?(aN$G9Ek&BZcR)VUWRYRh4_`lmJ#* zOm;If?>OpEuElq^p8kUPgyk5%_Q|RZ2?vM%wa-^3eHU%#lubaJzJaPPw3=q!@VCwV zhHx(@y?P3=?!md0)%wid5PVq&IBo|V4L^@XkX#f97vjV8SX(>c>+68acqbI7a8k+X z0o;f$Y#~vI4O&PuuUS;Kl8e8CYrgu#1lkBoJ&z8Te)_PcX!wzzgd(U`f`y*l;rwYP zEoV$w0COHLKkfZJ!dlwJ;i1&^FXx*%r&43DMl zG*8B?iJDuJ&d=Dv86jWjT9`uegowcsZm4%MNoTA2E*xJ_`-mt@=#&j@5U8T*tFmIo zn7Q#`7AJJ_j*J0BUlEOkDO{}@t{;+r?IT~04mtf^(S8)1Bt1#)RsjaH!oi5eI`NF1$D8T)i+b7ge*(t2V0Mupl@(X&JFdiR&k2vajR_m+!ECj zJ?TY@S<4!asFa?9c0rjnay@8@&!>>A9uI| zaM!(Nba7HM)1|3bj$Br>A_1Sx6k3B1UPuAhIda|bJxvf_Jfy_|VFGd8iR(RVH!lF> zc@vL3Ps^^^Cd+{w!{$k)^=;uGqPl}#q3)V?D@#-Mi${c%Rso3$&hM>6M4zy7B`QS) zDSSZ?H3#&z9tRRCDv9gtO+RWt#P{SoDZx&_C8~`J{M}0)d(ZD|E!rl~)4PWtKlTJl z{xLepEXI2@!uG1lV0On*5V3L^h#h%~lkQ7j&rDtBBT@!~Qh4kw!@S4`s}D>d7L8G>;gEB%;F41%wR62T<=d ze5k0GB!x3HYmukXJ@M zWYglXcMq)nb+hFrq$&USHj1L0#$ zlWCvKM}t%#MD{v}aX=)RnYih_f}(pnBAy+Y?j8GSm8F78-+d?3g)}p@b6!^weIIpm z<#{nzt;Q>N)0v%k!?RuEyY#WO&BeQA*JNZ5E<%r=TvMuj4}*8iB8ah?h3vq)yY{iR zmkbDD2wCbl7-fhXy%m>K@?>P6n&+L;U&2S@m&Qa7C1E+x$|`4hZJHotCNBGCMD@UZ z)|*v%DJIsnWImY zQYC@iex-j#85sdR3BV+JV)Ne$&ZV@br)*Ii+ZWhA@boUWywbr@KX`lNC4H^I!C!O% zN)#TQk6?omF5K?F3mT04d~n<@U8!~N^UTXa&T<&iNIhJxHp$d}(DpNSeeGvqoEE0v zQTNsp`OUxf7y_$%YzcO9+UDxjb~1djO?oD7-j)D_Mfh31pRYHB{HeAQSVJw&ZfZ3B z&<6d-DdtkFe?5u6+oXfxtJd9a3_eT54){^&;?vmLlA0=yk@^JWlaMorDTvW!GBahG zQ8=L$YyFRZtHAy~V44JfFbv2H&LX5T1vRYixJZ(}e9}w?TA|;`%w$S5In__B%;5`l zYelJ9p;P!#HkfNLf%JoSRmULrUsi=dpAKG+5~}Lp&cC_}r~7le)J0gA=wCBp53T4( za|SYoLx)2exgbtbX!Owzbazj8_{p&_9*(HNm#Tq|;00K)Z++_l;zYTORrAdHC_DySDt(B(H+5!ZGzefsOzB)&ND!wd*jH#I4*k&LnF~6Sc4H7 zP_oC71*%60BePlp3Azn%ED$_-K>aHjXa;)#i5Y?qQ-v~y6U2e`$+E{6{tCm9J%bZ` z0)@1wC-bh22i&hS_HjZ49d%BxhTK4tahdQFaedkDILB0a8@8DMT8K%asY6SjYd#`7 zKR@sdm_SseDzR(%EGdyo`JXWa&d#? z$_mN9L#1bKGpWor%PdDeWXvEANX^M{;%9^III4A9{CA)D`elut5pvZLoDgJ_06u+G zVN3LY&0@WToM*r~{;K<~{9=Sbe+ zZt8+UHMEb2RM09^KjF>;-F<3Bu0Bm2Nk;ESF=AwCc_8}`H!ZxrA znD7P1C&*qhcUFj#gPKmzihL*eyi}4WR$q)2({0#bcw9*T&zteO-SaWN1w!+}X`$Es zK?%fNpJCXVDmmK7ckdZ5VGVW3)W||d#CjCS=T&JKkuu&B{EyEnZk!m{MA8v4SrXA- zS6oeemj8tF!6|wb7A5J=FO57D!(Kg3wO}BlPz>s_lpyEfBoCou5I=BE))`3qZm@kn zV1_pOtD(ZCe9Bl4PtA-lPva$}Fpq?c1K)f1cH)AAgEdzbAs7)I^5~DCJ7kfVW>$L= z=tg6@!x+IFx^`k>B5xqya=*T-tBZY%*qxAo43375|ClyLPa}SJ?XrPNGndqE2d#)t z8cY#?e|vPDo;WNTM?|n>52ae3EQRsWv2X?b=ZY8=@Q^p^5>OL#0qbJrN3NyV zWUlNoLq-zD5FaZy*Q5X3%wz;^o@gOqg>2~H6RfM!|Jw_GSE&vJd^jOxLuyD}Xg@8c zCqO>g2)a~Q&h-CV#{XpV0%}D)sp4sofnqi6QmqPT)2_L2h5zS8@JI=;j~GH`8wlC= z$_Nr_4sYiDA0hBt^<{wXb|Q-jk_l_!NRk_BC;n3P@MS^rfVB348j2>c!V~bzifO|C z&=HPCk`Sw7C75H-gwT6}mTR2ZH2=~W9?1`0epcn3WgOa#P|U2}7h-!RWTf;D>tICR znc%y+24lYld5mFM(QHI5?#|il|9wgt*k|+*EaeLviIjG|{GZ$D9LOUvGikgJIU&6~ zTYjX@4ebl`$8vGBgDkMbcg*3RVYJ;JK3~F_u znfZ!KJ@cNQX#UGQI5&9nZrF=VGu88CU4eQvd$B=~?t;a`|FF;23dqSi4~20Z8TdLF z*H?3b2anK||K-qEJcmv*ts3&RDMChT z%=v$P!2je5mgA~Ky#I-v2q;BPjD8hA{D~+D!CSpFs{fD-t_pa{>QQ8Mw=oxH=*dRb z;@6eU{V&5}^wHBbG1tdFchgt%*II}X!ZyzR_Yu6{Wo+ci#yNFJr12?5h5v5o4-sHk z(OHWt=M!n+lG+@R`W-WQut)PEl4SqVQ2Z=VvWaoqY$p#YqT=BJ>uf`kxy?P~%OOq6 z|32^(5D^xk1o9I8jv%Oy$e5UqNu4YBUqYwT!*@lsv!)y!wCstVIn&x0A>}>Z8UDL! z!8pL)Pm?Ea-=1P8b1+FLBW+|?piuiq27*peg7?T&2$klQmHC(z4JKwd4jBslDrU{8 zw?R^6F+KkGm0ti)L~mJ04CF|g5VN|`**r!5uGe|=@5dk>j2-|vd9nSnq zX=$*r9)vzin5Pl4(j)I20ToFamy_^+^jipAlS$NhbCEcjooYPx!G_X*cJ&nO%JfuH zk2kwH(NN#=(`+_QE>^~~Urzn=5Tq2~ND)1$)5KMD9TcHw5%TXw$XL zmLR+{YLq4a&mz;A!9Klqo(|k2A$-DOVD=vYp8&eqz@e|yMEs19=z{iZ;q*jiFRu%S zJh!()_I`{B{r6Q@!Cut(ugH9&S9cOa2Cb+DjohUae)9j9Am4xoQl7M8pHV^Q5m?2P zll&_&Yj}y7`Ezyq2hp{YbR6f1|Fwij3Lun1sp~Bm@ioN=e=S|%74&q?*c*@9L?4 zxb6W|q4}KFE}5+|8W!}7=z$YBEZz`5tGO*>Q_^)k!vm}qDFQa*$tEQnj=W3I=j4@< zU0j`|@msGhi%*CYMikAY*IL=Pp78B=)^hH}V%7l-p}eJelVE6$-K(Z&b8cD7z;wND znGLBfz9vrz_$*ir=(~DTi;{lj9NmC0jD#p4^Z^GQ60QW}{ote7$Qxj5sIr;+K>FDQ zVI5#AVMJ|0$TjPZ@*XFm0LsP-@KoXYRP?et-dQLoV|`p5^Re@m@uh;0ygs5$zx%f7 zNGlGds>8!2`UE%|`d?@A=t;v!h$&d~*&iZZiF?P?ga=Ki1+Ue+aWInTdJbD=O9Prj zmVz*9GNnq61->CP$FMxdiM;`k@FZ5rkvT6GINg=4gV@pPuotRhR} zi=0NIoldpiFrbi(P=PDLylk>%l|{d$jY(&Tev|zX#X*>!cp}m!bNVWtnjwIi5}cel zA;@mxNiMoS7W}oT$Zmh8#WPQl`{($hs*)O?z7}dtaZ4O``{^@=-FLRKH~npZ>FZ33 z$$NnbFCZcnm6q(%O#Osc8bSPN78I=6L)Y>S$v{Yu6io1*q2;FTn+`iBbtm7NsKVO` z{Ti=cNLz?z*X&CzB+l9MBiRrN#cbeax(CxKRzUDS(K&Q(WFKHZh^+pS53;yRDr~A0 zRDNr}M}4Czq?45{Y<5|J1ooNa!8|Z~K@Y$|#=;jVNf;zf!AeN5SlV1c^g~`SHM!X- zNPJ6rC@};*dS?&r)q3aca=OBKY7#NBapIKQ_*)6^x5C3DVK5>oa=!ArrJH&r#DDIS zH44yqmyRVa2z2&t60}nADZl~6KCb?9 z{K809QC9P(wc0$5_uf7%7;s=+5SSMX2Sm#2cJYJ$zQHKB-$DZ{ZBa6?G&C;O(I!1vrwjJB#R%sl5cI987w3;E1=^OYI#Mk7}yq z$1YsZ<)eDXHF=#5eFB<&7?O!UjG_RG)#}xG@S-kBV9HJ>X4Ys<_V)3)bY~{`YNxlZ z4Fn%vedCS)JuBFPnBT}#pig{!k(-LJFSX1M9?D+s;5NuJPPb%ywc4rbK1u?G8Awek zpTi%yHNc<+ektd|GdT>&OiDz>ZcJ&L_QLi(yGZ7qT4~zn%1ZzF^EaEcd>2awlC8ja z$91-wEh=7#Bv9{6C5Rz9RW4`o0a~DI!=P3~5|BhX!BmsNb!V(ONRKYY_p*>KT0n52rvK9QD>=_4?Gs(yF#itwt*g{7sXOU3%xBlZUS zJUopqAZAkPDfZp;J1F^M#)i65Vzw9{^ow69Y;3k7&=8}`RPR_q(vEu(26f_{=C_RC zWax;F#_QAwxvg(3I>?cWNpI|C3390--hF7MVBRkc$DG~gM5-aga?h|LW6(L!$HC#x z)rf&iKhqMkgABS!i5vz?OG{Gu^{*0?Q~8lSu<*0E2=kORyjVIApE}@nEjD_G7d0rT zm}pH*Fx6SsgfeTnZ+CN&?;;8jbrV)NXCmr(D#z*n-H9TT?Ny%%bvrKNXDp1(Bl^Y@Q% zmU^A(%qk$;r}RtOa==Ynw|%$uWGimp%M$0S?^W`Yw+MGd1wDg*{IMT{*9ASNAbDzF zzX+RBk)|XhhR?oVqn{|Y4QQwjZ+U?;%iB^B`eicA-9pU4uv_-yzO=u;zmJ2|>l8Yr zfR4}X$e!YDDfV#lk2`Oc)Uo^NF$g-Pg*-GxM8!3Rk&Z|48n0Z}x|u z9WoFAluRW9@~aYd$_39valyzk#h{L$pttXt|ELe0(k)M_NWNUps;S`T=hx9Ax%kmL z>E(Py_3L4^cR<>B(hXZb4l$>eM#~R5NM@9e#`Sh?k9c;v2*CG<1(1+(}++U2<__gJ&!0e`u-d>uW9W)^OP?Ok zpfd35tMfO{#A+18tOaw>5Mk10FN>KBY_b`1KLKO#!mW#LedAFIyQnE-;*=m&9(b_l zA17kDmoq>cW#$NfDE)rWd5eN4IadFYgrww#08eA#m+@8=j+bkw{x=Io`fX#z!m<9L zp`qN7`dGWxXns+@yr*OH3nVXoY%iNc&*RB>7kJS<9y8fsPY?kY(GzjO)P#g4DS|kK z?cZdL;$${2FwJXc`(8bt#<^pS%yt)irh_nC;H8C|SIj-{b#Xwb$pg>CDsU<18HbA%C7F z$Mc*e2lE86F)1zzqeiB6=uu5!s!dCs6=I1B(zhox+%`G==?t|gQKl*Y_P)&EkQVp(N zPRKYp;2_B*WnGN_&BtFa%>&i)#8&hoKF#;%?RR~&YfpBwHV1}g=891B#n<1)z}UJ) z(R7S!CGGu>-3}JN-h5RT{qRJn6_wV=gg1$er-8Qr_38P9BDarDp_wcOj}$^^t~=^$ zhdKWfk8`^wu}eSJadMUe5p%P%SKIg8mKL5my|mWYH4*xyq}>zvetz#}(tu%Z-|cf& zL6h`%}!CkA9oWBRbs2 z>-)88+MQNEou$U?O6+TlbG1#jOZMC1*u$Gpr+3*5IzP^7^5L$i1Mz%PL6rQq({&enB^6}Yz19Y|8dkvST`fio``hM)# z{$go#qfFLR*>j6Sq|D|Va!vQ}%%BrYM+T~0mD`rUb zEaX{Yn*<8TBlL@Uk>-9`{=6D{Qe|bOd6mdA$$IsDpxd*|te=4vN>ncD1-HhDM zIiQy1I67FR1;T0FW~y=JQwgdZzIZQ6PZ_uar&MF54r4JQW(R@ydHW#PoRhfX9gyz} zPgP(QL(JK;o^a=2TB)Q!#%4O2p{lIh^4wR}`eZaYeWHK#^mry@I5vI7OmRfd9LUbe zWSHdz8j-!BX_KhtGI1lm_dDL6J;^1$6o9;6Gw94Ux3Tw`obj|4I~P~r+mi$P zKxlAR4yw)aawLe%h6(OkF6*ld4)!^fL}&=1eqeVWRpg_30u~>VGiFfH(jvZNe_jCp zGJ;umD%lfN3VR34hv@V5UewbQCxb{rKX35lAEJ(7I>0TE)zmmd)daJUBe_IrF-<&P zybrEWQM{_}o8i5i%&3sdKj|StyF5E9qthdK!9Ty+c($TC$*2;X2+vLSG1>@oqin-l zz4M#>(SB38nVBJMb(hkMxn^|D9y&<8Y^NTZB!-37IaM`;)K5mxThGqqvy^@O(;i{LbWZ*AFqjA@+ z3E`p4FiWsG#in60Z0dZ4 zYiwW7h#We3fjOwpH~{jCLwey|Fu|g`Illc$PX<4}Suu0aM9OE&uHh zL=4k~=PSZ;(=i?j?eCsX38T|q;GV~Od;pfwQoGr_0Yy z1P*g%ZqlhFCwjh8`y$r0fgD%kDD}D%AdZ$`r!bD6?Cx)8XL2RA<%+u<&WcwL^z^Lz z0=>RfaqD@^8lEgeA)s|uDJ)wltSmkdI}WmgKX3KEDf<5VU<<$1mriJ#vQcl2i{1-A z#!AO7rT%1C)lJG}m2c|lu}kvDUcoAS@W7v{krlw7pW{izTP29VjxJ2%imtC-tp@~e zZJ5{^^%?t6@OiDx>>VP{wCL7nO9mGU(CH_sU8xI>m^pvo{8eRuaf-jKl6c&6%(?p{ z7rSp2`RCeJ_HWJ)ploNqr;GluG|w|mk;*gH z>(n2He!BX-s5@DapNnW-v-2`mPBfRf;Yjod9}e(?2wc>0+))a@MgUMvvNILRt#o2O zQxEqXENqq$dnW9b_>Ed8N_+va&dZoJ#)ORf+1(D9%s9aTrxJydt8#6aY`RGe%YkJN^&fo{!MfQ-1WoYZS+f$e@dS`qSf-|?R`_K+ zW|rq&9rN84fx98}>sei$;sqpKyUC2rb3d|H)Y)DpJS?Rw0uCh^A!7WBL6~G!i|G%_ zioR=fX8krw8|0C6&ZrG z$R(05{c>Tq$V|BUI{kzB`&=oVcI0H&^2YTd?C1nw0f&hcWj=5lJ>M6+fB}I!0ek&x z@J+7$khoZlel}GWLDlEdh{p4h_j8D)Hv~xB^~0%X@mMPcgthzYlTP^P0eHHiHlR|s z;-@0*+|AZ|4tu>6pyoC_nG}!fMF9j#O-Cc>P=v|S1jZvB%#GXRUl$x=vXYfUp{A2A z7x%gEB6$=h?kEhp!XKB%=5lwLfw(OMi2wOX^)rVrp<|^l)cAopJGd-=P4C2eN3)tZ zeysQPLXMcQ{b^K%iaV?l#?f(`%(X2EFNz6(wvKb_&POY}!5~HhGxT3?;9IBtv97mBC()Oi1h?t$eWLU&^ulg9wQ`dzC=V)v0n+zF4m z?f@iqTf2$+h$3*hN02%6R*%X$q`|pn+z_`?9fRkEEk3VxCMuQd=&~5ST2}~N1~rbc zS=`#cf|;U)>EcPxfJeY$UWXpB9gWOH_k5Koh<%z_Q~v73PIH3}-jPM-0AggXt-Bv@ zK##zc5Hj_%yPnMZTWYvqs7W6uulA|}&#ofCQpLtay1T4kiz zQ;Y@a9=%&~LwJV80KBM%>S{*#(fg+@YNz1H+wfGaP{N#t))&k)1_jcc4}LbY*bDB| zPU?g0z?Z?sd?VnSN11|DWr9m2Ff&5f972VIiY$Jm%#TuFFYeFf%Le!$pcDi%mU!p@ zzpOg~eV0laEj45i5rhO+`7j}{c8ur1Xf{A+;B=3{CitHH514f%BVr>e%*ph% zsb%F9%h$Y#2rz67R5LgIouVr6D<01eRp4$)Fp%Z^O=Z>*4+TRx=<4zGdzA*%Og}6g z-=B3CzT$G(y&Tc<^M2VAx0f5aL62?~4a#z-E`;(`)UA4g#igo)O{qLN?Q)d&Ol>6i zZ3HdG2h@;`pHS^wVnRKQWq3~_sMYI82G%(aV0T^G;}mJo2<6?=CBnNdlI%S%Ea)SN zmCJyCY=1E0)f)>#pCyWr#k1@@9zLeh0=O6W3)pD^`dc;bDtVsoCx3=v!bz9SCPJz{ ziC$RXoQbGC;}1Su-p5{>D*L8@539Jz>NmTnW(=vTEU$s-Ao9l((X^`$xq2@_hjUZX$lc0 z@!Rn22*0^Vb3pmJCG_h1KBsl(p?WE_Sm7b&WNQed^j**!v*X9uh|-x%E)Q5xGreNm z=`h`VFPaIyDC5@n=@4!z09*{bsrKzy6w^VsNX_+q|2c_%UrnqLY1>D8Avx)Ayj$@< z-2PB#k#Ij)sIpO2(WOn)z>g{~xZh(3&0-DT>tK;(U{yHL)4IM-=)Xc7708T<$(~)$oTc1 zJk*CE0!lC#BwY2TAM@*<@nX5AfkJlHKW|~F%{yQ%H$>m*fcGQCz|w78NB$<;&r?iq zZhEQ!GKR51VISTZ-~-vnR6y_tKjbNn)og@#Enz43<%#pVfGJg-cZl)Vkb(!vD?gS! zI=J;1g5(9TDmqwz8A_9;@^R}W0L1tcu&13~JAboXU+7!>C&zRm(`8Q0l>{2+!tlvR zF5a0nVPs!gpyh!5?+-4$cywG4?_G8CYV9Akj-rr{CHDsmNNZS|QV~w@~!tv8Bop`&r1N3sLDBbpW%?^}I3ae%DyRM5~WwJzte@-_1 z(uQ(&6AAS;bdZ3oa8E$EcuMjUN2La-*9cr6+g!a>e~v_2<25FiYk}E^1bVV5lX(HK z`aWRwLUS1V5vwE3VP!;xf{zgrxPHr+V90a7JAD#_CrEG=0ZHuXV>hQ1J^a(nsOA(u zmJhe3(BT5XJKgT7?|U|_=*Kc$5 z_Hm~{U)0RU69@ymjQ#FO;P@O01SnpatciL!JvJ8Fu&7P9_+6T`mivqeUeim#zV03$ z(i$RNQ4e=DO4Y}uyts0!4wBkO;=6xh8k%aaw&K>_>;lSWf!)0wk*7K&Ly$B1HGXR= z+G*ZZ?${QMb)UPI1M1k(eej+2CDp@qD+6zc`_U=8qh&`L1a%C~c}sBGbu>=e(_lNuSUWMC07A zZiqvAO^WI6IU{(F)ft~1y?FE%J$=|~+jc2OfmNR+K|b6BLVe@IO30xj^aJQg&&4B} z2xg;Grc+D@yOWgpb3nj=JZN-vL-=D4Q01)9id)l3ydJUuJrvcjI4Uea;`2JY0;}97 zF*Gl*8tLM|jhKtf=cx1JoGH}uG-hGh)ylAT6 zs|s!yIx^x_4{|sSd_g5k#a8>**=XaZMP&i` zYC(3ohWlLpQRPS~oWNLO;CK3r6=S~mX)Mhd0aZV>$my~2!C4nd;@QCJ*ApcOV6IB?9uB)TwHU#n_^| zQRb#Fo6*-*G|;t8&9XWXL<;xe%YnUIOBran>UJcL1Qb8`({iK>9@kWqE=cr>7ze#T zy(}2^apra-plgCZ8C9VANQ9w$I1JRO49Mpv>XRpqd@cAZbl=%5**q(}SC;rY-S9#s zs4uYo{i{dhan{W>w~Ow2TYXGcd&wy<=qCJ28^>+YAw2g~fP9GROFB~ikA|9=eNx$v(hAxwb9UUH*IK{*0=nY~ z2$@J4Id^!#!pDtjZ@--`k@vD3d?66Z|EIJSRWxWAeC~cyhaKW_oo+ z+b5gO=aI;Rv}JlccV#P_??1l6>DO%;dH);#o-cvT?+>%qA%;Ako2AFG; z@;05qn{y4z>0S)+)xCJN_NnB;iORfOhi(#&7v>^s1?>!TNqPf&7&8RjQwLJo&F(Sw zE@;gXt3u)W89qY$T#ib$U`)ABlFc9~)cZ&9vI_|MHo}uG2+SW?fMeJarlt1pa(3%KN0VqLQU-jPYkfGy2|qq=2J6@-&T5ad?3yG^5_B)($W^0% zrvxwK()pCfd=8?rt1cjbJW!5U^8vpQL-an_iNHn_)}T_LvVa!1|W45&j^ z)iq>fBmrffHp`NQkFCb>4w_#oC*UalcDsgfg5YP-(b2&4!okzuzjOP)@<7Q89Jn(G zXSJ+7;vs(tE%%^(4@+R+uz06 z|8<-|&8e2AhdFQ63`Ca;X4&u3BgaRJnd&3Lroi0DQqMlSIhM!e99kd6Ga7_aZSqeM zDR3H06H;;^ETv9B!Qqb^tK`^(GUdPr}zE#1nb7D4jU_!$iA`x=As&0*D(2cz+CMA>kPF@m4|__dQgpR7go+~?T& z_x6!H=AcN=35J~K{ixw)tAIPUub%OVKDJ79kMryk(IXcf)Vd2Y;6VU{Dm6WYH~R~) zn{B4_o7>%_Qis#(*MU&4{GtALkVyi`OFYc0D*wM2_usYxSpX=#yQm)RWD^IQS;W7w zlNXGNbnUw%|NSwTS%BWb*5zJdR~Zpu*LZWeh5>@Gq=>28@3;C_aVM|0z4Lrd@5i3{ z-94OJ(%!TC2{v+AB^qbuvI%&Mo+r1PJ;!9X9-pYUP;)i$-j~1HxbLD7A9wfctc^X& zzTOF7OJ##suF@OV&v4dX`ty{09p$c1-T9`%`3dFUor}%a6G7m2=3gIhn>vf!kcxxopvMtH|13lKRI3+oE#3Z`S-^LcbM%#-aI?lSZPR6HWQ3{3?@*6~U`I z{>Q@+O^>^2#>A?syEbO#?^#n?R?CWIFh}7~70=s`SMNzyU*GSLC*t5BdDNb*W@4uS zYdo8Vv3px?3N)N#`MSt`xZ}`B3Uu}4?Hzh&&n#ZKlq|y%kCN~)4npXKV9#JAF;bM2 zlaqgE=PQ=crjaVcbaUygOWLYz9{3S>Cj4Oy`J2veAEyvJo$1o89Pmcncz68c+n;e= zA1_T9AH@D13Cdt-X2Nor?R0Y*@tyzJvxE%fpP1e_RXeQoXf{? z7IJ;XPe*4dvxd*r&TmcACH4jZ41_I#Zu6uqT%)DNjJpiTL;nHx&66;yyJ^Oc1nz!- zn-w7D9`&6GOL556JF4BXtLpO}_xJlh^qh?oFXex(QZ^jH_}XKQPRPSmuFv$eE01O( zmvPXLenU9iO1F8X`C!7AsiG!&DC2eMx3(4BS<-a~AL?PkKr`lWHg)H|=;-zMKz8Cc z-PqI3aK}Q@&PvXW=@o|I1k}}QB^~J@vq%2ab(|PcAlFTOemAu_)msyxqEa{>Yf=o= zbA%cec_b`pbvcjfRcDp(z$p}q@a;zF0lynd+xy)mAeP9PTvM&lQ7*rMm4hw9T83}Y zKi_~O7>BTH*P24jBdZLS_qRq%mh1uWzngQN7`~x@0P_wJ7SgUe+<`-{MzjpLTc|-d z1<@8&HXI8-R;|i)F1(g(KZv-IP=Rt#pCx@gObUuzu^LpH-P#j|=#WGxAoSmY4S|Z? zGj!s~IMUgKg*2<(8WvchXXB8U7AJ2n80!wrO^aWUAM%m#-xEo=`X$>4HS{6zFkeRU z4dsu38Oa+Q9@N}Hv3XA1(J6`P=thjLigHc~bw6AklQ$E$|LR!BVJYK(`KK1;WIM+d zg)Mxi$wS?7T@)=s=M0DPBu7rDS3%~G_(&Zxze{lxytW5 zx5VX<-+UanT9T_Nx)XC(tNL=zAO4knKGOlSw+A_ggh=vKy4kA>E3s!uN1ObPUc4`@ zvB*7nOJAe4!}J6@U8?|H3Q8DZ<6%VD?;WBmJ=oS8flFIL+y$<-GkEqRuDBpN;i4(M z(m>kcFHYY`II_n+{zwwoRQO}e$BuJGVno8YwqgD3aI9rr3O{o(jofEJX%#Z{?7vn` zCNxxg?dRn67a2nD%lJ^X5nVIMoc8w9dtF9682D5*tGsWJN$Gk}+D66&)@hMe#V41S`(sHyJhA31>v&G*5Uyap}}$K)Wxw#uCdUa8)RU%T(- z^g2rot6TL^4kbInO);BTe^!gn`t z>PFngsvCb@C)5}?;_3{;F3sa{Yy5AID< zXuM**DkjP1w!aEz9XpH~in_8%ER0?)r+d?zz#G@(m_heb2OxMwa!WCsCI@E;o-JN{ zjQ(1q5cIa^Ya5glg7&v}Va1xJuU2@p)YrA z1)_et8y^}k|6$w@Ns*2;n3aMmDoWwpni#brN`GOBx_Yxj9u44V6$7)Qp2GU2&A zN2&GjbwPOez5iTQ!#Ss%x!PWJ3xlBRim=NxY-m}^Fg!U+&_^A$x2L6FFQz#ahR&T< zbnm|R4ql4;; zVrr07DNn|J3BiPd0zIONyH_epdA17HcnPE6brH`lb%Km#1>>z#KQw&#W$kW{G+Aiw z9eP5>?SGdp2KrqO1O?jZ8u;!f$;%SqQNd|{yx3W;EmClMI> z?m+p#(z#z>T2|h0-8#v{$HRpl87i!*5R-kpEKIt{6VA)Yk2`ZW1N4R(?CtQ{@W*@{Io5Ea{cKJ=#FHU_ zb0$gIqrIAP_nl2xw{3%pmq-Bw^iLVMX4Zahk0@QW&}-_?3x9wdf+PZR{S7GgBb19t z(MZ^9-nkH3h_!mElS3AoQJYd>qM8k0vNKd83lr9wh;=HSnLAA?yr8Ri7C0tzee`NllR zqPzg4Ptt-uvsr$RgP$6q@(Cg0Ec&M+i#42QXfS%~(tMP1%bn$tWU(>kb3O!WCYg!pr zF(sE6l~uyPj+}IU{<ba+PVe5E#JjvIPM=hR%xk z^T z$L}lmxXgQh6n-*#3Ceed=@p}C+1g9W?PH{th^aRR(ee|aw=8FWqmzA;{Yif6_k|a{ z2m>cAgW#6bb+{o33AYjp%HOEBkHOOf=<=~L3WqAar!?9~!se6pb=?OuO(WzT#J#pL zVh3g4Nz=toc`khwI1In|d5xDdn@tls^}22d`!Lqk@~d$miHx+pQ24Ok#jv24l^A+; zzr?%+dZNKv1|A&@VSJ4)s923ETIS3dLr#UFkxzQiQyq@eKhEXbLaQXxJw_QJQP{rw z_%9crmcS9%G<0z`LYcuxDQsIu2E*YaLUO_W=EnwCtLRVpGgfX=>q7?j7+;FjKTT)h zbYN0B)y5f;ws^0?-3(N_sEq3!?BXec;~+;43-XkMXROPsoV*UJ%B;q{dhvrjX+h=EY9miR z@&(hkZ9xka_++1!6QwsMRROiTWsu;x>er*xSGuxp#5zYKTqI|n}m^#-_C$^#w=k57N&lQ;b#e;zd*hOf`Oz!@8ICq8F=gO}fWUTE+Q@245 zI*y1tCQ-0^103?~SFP%uhJBmRy){re=St@~vyMnyqczHCQS5}}NR0c1<0ZfD&wT=q zYV8{mHwG_x>&lxTC-cthTH*1wa-qJg*pW9ul)e1O1v%=>OjlNOzR8rZ6+E_;`sB6W zc2=uIDDx<_#sls5PBb3=`ECzJ?*qGmVxb2%L$EK6P*rv57~JcWua^t<-2b)|(-^K# zBfK~<4JOKLsM3U8VB2|Zzo|ik^oJ}b$3D9nxyKVGKU!hKSQixPF(V#k`Ma zgB5;H97Gh8<_|FNDAk}6l%8o*t@z7ax_6CaSbLvlUY<^59c&_4{16?lLgV+o0MuGUm*EK^fTfni7Kj0I#@*eRhF*l!9$? zQ&uf$Xa+R0OD#*2V^GurFtbu;r{fLEx!|{#Ae0@osBGMf8)F8r7?1xe%%@v)7$f}mA{YNJ(@oNgB z)pkeDvP%rRK^}E_V?i_+s%xf@gui|T&qM0*e`VmS=}Ryz zezNv$A5Se;59yG7i0k|1b>nRYPYSsmJ+S(z?Q*>f4Ar>veeNDs%K&pRweFE zy)9(y_Z*V)5blfdem1(dUDvdKnNV7)?>Y9%C^DKra3zF3=L1>F+1MZeNPvV9Tg5J2 zvDNdRU&1{QBia80q!&dpEztPERk=c1JofXCPi--!QDRCOy0qud*F6sBTjJCn3xvF`4s(IB_mEQ7 zzPCBARdRcf%kt~eoJHkj#X=q>KVaCNm8pI6;8X-F;rN*o&XC2Y@JWxNZno|@v2c1twuGj+ zS8W*fnV%J{#&@a-J#D|}vd=c3#+$a@upgOIH0R6$t7z{4kxj=(t5l*LZEYqcwS;;j z;s`(M;2J_g3bUhI`Ivp8 zEq60>0vB$B>M?D_Nqv@Z)7Hh6$FCK5Ku)oTB=ql)8g~NFnx{S7=`a4@B}j}%*a6EIunM1NkVz~-#{Acpo^Nx6QJpAr$wYR{L z_kdGn3Ulh4fHmur$`t!qi4zP;(|xa6_v zA!OpZ+_LFo- zzN6hEkwHJy4ZWw=tR}Zf<$NI7DrUv4eL>IbwD$?~@vss?uQa4dT1jCrK(2;lZ7gy@ zaJ1B?wc~}bK#6}k#K4Ut_aH83Yh?ZTTq3j>BAlf-Rzs3pVdj;lrmKhFF zv@$M5Rc3{vy<1y(VC%MAiHE&Xj@_4w)r6dLtkJqPDcgjE&q(%wfBniQk+8|`8dNLE zpU6n48;AKKL5+~_b?x13Bg-UMON{1n;zh@h`j$RzriVH%BLVOlYYkL&O{YUZK?wR7 zvPqnB>of#n>z5@+9e-{MWVWxiL8J7%WHSC$Y$ufw!Fy)wf27+xgQ99FD=I zry0aDnKljtC^VKmfuFLAnE$C~>g6p{giV7G)vh0qL62BI_sSE*t@k0OZp`Cn%ys;? z7enF(Zmp0qAm%+Mw&nB=CUhe(9huzNTXAP5j&lnv0KF3>GBSF@5SxzK4S+GeY zSoy@o7BgAj5_vaZqM-9z6!9r_*=qThMDXY0lw+%YycMe8K#@Dq?|2Z`=OxBEnt`4} zfGsS`6$KSu5G!99R!=XXA3tA0fQ{p{?TGiYjeg@fmcI34zl(dY&M3luX)G%l8BK5` z)(>QadSJ8y+mif1@TXXU9edcp_8Lbi3pU<`$YrsPQc>j#t}W-Eb3dVEs_FXvt`_oJ z5B+Fo48HDdW8eaP+Sh;*eVSba4xWv+`%#R?r^j5q5#^+Euq~NBHm#^~wiDQ{^PYwk7qsp_x*LF)!|T z92@(Xw70EdReJA|H5n(d{3@4+qf^b}^c<&Y{>|(Nx%XH1e&E1Fh%-90b86t2G(a*7 zsLsCs^IiP5cT!(Bm!;R)=F2yEB^}t>5UovSWXP;~PC(Bh{WMSJ^1bRa2xVi!1?a@( zS(m0#A|kNt#x6wgVq3#fK-p&><2~VsRz^@4ANutH1cPO_s7x3dbrB!R{(OYU@xaDM zZ-gRqfUPY!17LD^v**Ts8327%xN%+4tG|Fzlrwe9=Y6ntTrkupcAzS>gVnmy(9h+< z3VC5(ykKvM*k^Gy9C46`%YoqjJUNtZs2ZLgww@h6oiZ9C38RMG%6 zw>aG%AUODs6}N!X6gbo%AyQJ)<0&z6eddabz1 z`A8&k!c`b^ddCoORJ(Jq-ihNKoyQ1k*Nh2kHV;VkNZXLYPNYM187T4x13rZ*JGylv zF$9H{&-`5JvX2%jO#7pyBe8q48ThL%i1pzh^(A2RdgRf#dy_)(`x@0Un!;k68j|D= zUGW7#u85BKy8+|`zG5zb?M1+z6XxL@RfKx-7w^hd^2Hp_=U9(#*FByF31j}se!9CF z4GwX@%B&O`{;pim;NjTU=}Dj54py%KZX79Y_3utAnKD034%rZ_xH$7s4IUkV5y4LD z-CB40+%(bTs>--rr1Ce~gn~m$1m*y5+2Smzd<7)2DLDlkdYY-Fg{u2|6rL+{YwnPd7Y2Ao<_@2R`J*_GGIdauw-d5Xye- z#%v?pV_9G$pPh~uJhSW)|NZ>> storage = TensorStorage(...) + >>> data = storage[3] + +:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is +when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial observation, action +pair. In tensor terms, this could be written with the following pseudocode: + + >>> next_state = storage[observation, action] + +(if there is more than one next state associated with this pair one could return a stack of ``next_states`` instead). +This API would make sense but it would be restrictive: allowing observations or actions that are composed of +multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage +know what ``in_keys`` to look at to query the next state: + + >>> td = TensorDict(observation=observation, action=action) + >>> next_td = storage[td] + +Of course, this class also allows us to extend the storage with new data: + + >>> storage[td] = next_state + +This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken +at a given node (ie, for a given observation). All `(observation, action)` pairs that have been observed may lead us to +a (set of) rollout that we can use further. + +MCTSForest +~~~~~~~~~~ + +Building a tree from an initial observation then becomes just a matter of organizing data efficiently. +The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to hashes and +indices of actions encountered in the past in the dataset: + + >>> data = TensorDict(observation=observation) + >>> metadata = forest.node_map[data] + >>> index = metadata["_index"] + +where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance. +Then, a second storage keeps track of the actions and results associated with the observation: + + >>> next_data = forest.data_map[index] + +The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since at each index +corresponds one action). Once ``next_data`` is obtrained, it can be put together with ``data`` to form a set of nodes, +and the tree can be expanded for each of these. The following figure shows how this is done. + +.. figure:: /_static/img/collector-copy.png + + Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object. + The flowchart represents a tree being built from an initial observation `o`. The :class:`~torchrl.data.MCTSForest.get_tree` + method passed the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance + that returns a set of hashes and indices. These indices are then used to query the corresponding tuples of + actions, next observations, rewards etc. that are associated with the root node. + A vertex is created from each of them (possibly with a longer rollout when a compact representation is asked). + The stack of vertices is then used to build up the tree further, and these vertices are stacked together and make + up the branches of the tree at the root. This process is repeated for a given depth or until the tree cannot be + expanded anymore. .. currentmodule:: torchrl.data diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 5fb9e71cbf2..6c9348e5ae9 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -372,15 +372,16 @@ def make_labels(tree): tree.rollout["next", "observation"], ] ) + a = tree.rollout["action"].tolist() s = s.tolist() - return f"{tree.node_id}: {s}" - return f"{tree.node_id}" + return f"node {tree.node_id}: states {s}, actions {a}" + return f"node {tree.node_id}" def test_forest_build(self): r0, *_ = self.dummy_rollouts() forest = self._make_forest() tree = forest.get_tree(r0[0]) - # tree.plot(make_labels=self.make_labels) + tree.plot(make_labels=self.make_labels) def test_forest_vertices(self): r0, *_ = self.dummy_rollouts() diff --git a/test/test_transforms.py b/test/test_transforms.py index a5a4fad4e40..1c1903a3b1b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7098,7 +7098,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): torch.manual_seed(0) env.set_seed(0) r1 = env.rollout(100, break_when_any_done=break_when_any_done) - tensordict.tensordict.assert_allclose_td(r0, r1) + tensordict.assert_close(r0, r1) def test_callable_default_value(self): def create_tensor(): diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index 9a54913ca2a..570214f1cb2 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -53,7 +53,7 @@ def make_labels(tree): x=Xe, y=Ye, mode="lines", - line={"color": "rgb(210,210,210)", "width": 1}, + line={"color": "rgb(210,210,210)", "width": 5}, hoverinfo="none", ) ) @@ -61,16 +61,17 @@ def make_labels(tree): go.Scatter( x=Xn, y=Yn, - mode="markers", + mode="markers+text", name="bla", marker={ "symbol": "circle-dot", - "size": 18, + "size": 40, "color": "#6175c1", # '#DB4551', "line": {"color": "rgb(50,50,50)", "width": 1}, }, text=labels, hoverinfo="text", + textposition="middle right", opacity=0.8, ) ) From 7a44821ef91a3ac51f963ac9da70c4dd34be809d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Nov 2024 14:40:57 +0000 Subject: [PATCH 05/14] Update (base update) [ghstack-poisoned] --- test/test_specs.py | 5 +++++ torchrl/data/tensor_specs.py | 22 +++++++++++----------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 1a7dd41621e..39b09798ac2 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3823,6 +3823,7 @@ def test_discrete(self): spec.enumerate() == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) ).all() + assert spec.is_in(spec.enumerate()) def test_one_hot(self): spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5)) @@ -3839,15 +3840,18 @@ def test_one_hot(self): dtype=torch.bool, ) ).all() + assert spec.is_in(spec.enumerate()) def test_multi_discrete(self): spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3)) enum = spec.enumerate() + assert spec.is_in(enum) assert enum.shape == torch.Size([60, 2, 3]) def test_multi_onehot(self): spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12)) enum = spec.enumerate() + assert spec.is_in(enum) assert enum.shape == torch.Size([60, 2, 12]) def test_composite(self): @@ -3859,6 +3863,7 @@ def test_composite(self): shape=[3], ) c_enum = c.enumerate() + assert c.is_in(c_enum) assert c_enum.shape == torch.Size((20, 3)) assert c_enum["b"].shape == torch.Size((20, 3)) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b641b808cf3..3590d76d2ce 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -835,7 +835,7 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool: return self.is_in(item) @abc.abstractmethod - def enumerate(self): + def enumerate(self) -> Any: """Returns all the samples that can be obtained from the TensorSpec. The samples will be stacked along the first dimension. @@ -1281,7 +1281,7 @@ def __eq__(self, other): return False return True - def enumerate(self): + def enumerate(self) -> torch.Tensor | TensorDictBase: return torch.stack( [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1 ) @@ -1747,7 +1747,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: return np.array(vals).reshape(tuple(val.shape)) return val - def enumerate(self): + def enumerate(self) -> torch.Tensor: return ( torch.eye(self.n, dtype=self.dtype, device=self.device) .expand(*self.shape, self.n) @@ -2078,7 +2078,7 @@ def __init__( domain=domain, ) - def enumerate(self): + def enumerate(self) -> Any: raise NotImplementedError( f"enumerate is not implemented for spec of class {type(self).__name__}." ) @@ -2402,7 +2402,7 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) - def enumerate(self): + def enumerate(self) -> Any: raise NotImplementedError("Cannot enumerate a NonTensorSpec.") def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: @@ -2641,7 +2641,7 @@ def is_in(self, val: torch.Tensor) -> bool: def _project(self, val: torch.Tensor) -> torch.Tensor: return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape) - def enumerate(self): + def enumerate(self) -> Any: raise NotImplementedError("enumerate cannot be called with continuous specs.") def expand(self, *shape): @@ -2808,7 +2808,7 @@ def __init__( ) self.update_mask(mask) - def enumerate(self): + def enumerate(self) -> torch.Tensor: nvec = self.nvec enum_disc = self.to_categorical_spec().enumerate() enums = torch.cat( @@ -3253,7 +3253,7 @@ def __init__( ) self.update_mask(mask) - def enumerate(self): + def enumerate(self) -> torch.Tensor: arange = torch.arange(self.n, dtype=self.dtype, device=self.device) if self.ndim: arange = arange.view(-1, *(1,) * self.ndim) @@ -3766,7 +3766,7 @@ def __init__( self.update_mask(mask) self.remove_singleton = remove_singleton - def enumerate(self): + def enumerate(self) -> torch.Tensor: if self.mask is not None: raise RuntimeError( "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested." @@ -4682,7 +4682,7 @@ def clone(self) -> Composite: shape=self.shape, ) - def enumerate(self): + def enumerate(self) -> TensorDictBase: # We are going to use meshgrid to create samples of all the subspecs in here # but first let's get rid of the batch size, we'll put it back later self_without_batch = self @@ -4959,7 +4959,7 @@ def update(self, dict) -> None: self[key] = item return self - def enumerate(self): + def enumerate(self) -> TensorDictBase: dim = self.stack_dim return LazyStackedTensorDict.maybe_dense_stack( [spec.enumerate() for spec in self._specs], dim + 1 From 5a59f00d758ec9b9aa40d7bd9fa2e9761ccca3e4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 18:17:27 +0000 Subject: [PATCH 06/14] Update [ghstack-poisoned] --- torchrl/data/map/tree.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 645f7704ddd..fd3f84913ee 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -163,9 +163,6 @@ def vertices( if h in memo and not use_path: continue memo.add(h) - r = tree.rollout - if r is not None: - r = r["next", "observation"] if use_path: result[cur_path] = tree elif use_id: From 441fa9a69b841e30c0ef5686812fbbf35eb8d85d Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Jun 2025 12:27:11 +0530 Subject: [PATCH 07/14] Added EXP3 scoring function --- torchrl/modules/mcts/scores.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index 99b8772fc14..17660277d83 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -91,6 +91,13 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node +class EXP3Score(MCTSScore): + def __init__( + self, + *, + gamma: float = 0.07 + ): + pass class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value From bc32e0ef2811d199b4e2a80fd501c014db6bf939 Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Wed, 18 Jun 2025 16:33:28 +0530 Subject: [PATCH 08/14] Add EXP3 scoring function --- torchrl/modules/mcts/scores.py | 108 ++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index 17660277d83..e97684578ca 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -9,6 +9,8 @@ from abc import abstractmethod from enum import Enum +import torch + from tensordict import NestedKey, TensorDictBase from tensordict.nn import TensorDictModuleBase from torch import nn @@ -91,13 +93,115 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node + class EXP3Score(MCTSScore): def __init__( self, *, - gamma: float = 0.07 + gamma: float = 0.1, + weights_key: NestedKey = "weights", + action_prob_key: NestedKey = "action_prob", + reward_key: NestedKey = "reward", + score_key: NestedKey = "score", + num_actions_key: NestedKey = "num_actions", ): - pass + super().__init__() + if not 0 <= gamma <= 1: + raise ValueError(f"gamma must be between 0 and 1, got {gamma}") + self.gamma = gamma + self.weights_key = weights_key + self.action_prob_key = action_prob_key + self.reward_key = reward_key + self.score_key = score_key + self.num_actions_key = num_actions_key + + self.in_keys = [self.weights_key, self.num_actions_key] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + num_actions = node.get(self.num_actions_key) + + if self.weights_key not in node.keys(include_nested=True): + batch_size = node.batch_size + if isinstance(num_actions, torch.Tensor) and num_actions.numel() == 1: + k = int(num_actions.item()) + elif isinstance(num_actions, int): + k = num_actions + else: + raise ValueError( + f"'{self.num_actions_key}' ('num_actions') must be an integer or a scalar tensor." + ) + weights_shape = (*batch_size, k) + weights = torch.ones(weights_shape, device=node.device) + node.set(self.weights_key, weights) + else: + weights = node.get(self.weights_key) + + k = weights.shape[-1] + if isinstance(num_actions, torch.Tensor) and num_actions.numel() == 1: + if k != num_actions.item(): + raise ValueError( + f"Shape of weights {weights.shape} implies {k} actions." + f"but num_actions is {num_actions.item()}" + ) + elif isinstance(num_actions, int): + if k != num_actions: + raise ValueError( + f"Shape of weights {weights.shape} implies {k} actions, " + f"but num_actions is {num_actions}." + ) + + sum_weights = torch.sum(weights, dim=-1, keepdim=True) + sum_weights = torch.where( + sum_weights == 0, torch.ones_like(sum_weights), sum_weights + ) + + p_i = (1 - self.gamma) * (weights / sum_weights) + (self.gamma / k) + node.set(self.score_key, p_i) + if self.action_prob_key != self.score_key: + node.set(self.action_prob_key, p_i) + return node + + def update_weights( + self, node: TensorDictBase, action_idx: int, reward: float + ) -> None: + if not (0 <= reward <= 1): + ValueError( + f"Reward {reward} is outside the expected [0, 1] range for EXP3." + ) + + weights = node.get(self.weights_key) + action_probs = node.get(self.score_key) + k = weights.shape[-1] + + if weights.ndim == 1: + current_weight = weights[action_idx] + prob_i = action_probs[action_idx] + elif weights.ndim > 1: + current_weight = weights[..., action_idx] + prob_i = action_probs[..., action_idx] + else: + raise ValueError(f"Invalid weights dimensions: {weights.ndim}") + + if torch.any(prob_i <= 0): + ValueError( + f"Probability p_i(t) for action {action_idx} is {prob_i}, which is <= 0." + "This might lead to issues in weight update." + ) + prob_i = torch.clamp(prob_i, min=1e-9) + + reward_tensor = torch.as_tensor( + reward, device=current_weight.device, dtype=current_weight.dtype + ) + exponent = (self.gamma / k) * (reward_tensor / prob_i) + new_weight = current_weight * torch.exp(exponent) + + if weights.ndim == 1: + weights[action_idx] = new_weight + else: + weights[..., action_idx] = new_weight + node.set(self.weights_key, weights) + class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value From 4de024a64ebadaedbf17cf76857df26295d46cc5 Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Fri, 20 Jun 2025 09:46:15 +0530 Subject: [PATCH 09/14] Update scores.py --- torchrl/modules/mcts/scores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index e97684578ca..24616e8d1b7 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -207,5 +207,5 @@ class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 UCB1_TUNED = "UCB1-Tuned" - EXP3 = "EXP3" + EXP3 = functool.partial(EXP3Score, gamma=0.1) PUCT_VARIANT = "PUCT-Variant" From 73952eec91746cb490fe31e3ce7af74f1dff312e Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Fri, 20 Jun 2025 09:51:08 +0530 Subject: [PATCH 10/14] Update scores.py --- torchrl/modules/mcts/scores.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index 24616e8d1b7..d79fb1426ed 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -21,7 +21,6 @@ class MCTSScore(TensorDictModuleBase): def forward(self, node): pass - class PUCTScore(MCTSScore): c: float @@ -61,7 +60,6 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node - class UCBScore(MCTSScore): c: float @@ -93,7 +91,6 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node - class EXP3Score(MCTSScore): def __init__( self, @@ -202,7 +199,6 @@ def update_weights( weights[..., action_idx] = new_weight node.set(self.weights_key, weights) - class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 From 57e39406dc955661d6524a50b80542a367dd9f65 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 20 Jun 2025 11:59:37 +0530 Subject: [PATCH 11/14] Added tests --- test/test_mcts.py | 681 +++++++++++++++++++++++++++++++ torchrl/modules/mcts/__init__.py | 2 +- torchrl/modules/mcts/scores.py | 5 + 3 files changed, 687 insertions(+), 1 deletion(-) create mode 100644 test/test_mcts.py diff --git a/test/test_mcts.py b/test/test_mcts.py new file mode 100644 index 00000000000..c5e25c91e2e --- /dev/null +++ b/test/test_mcts.py @@ -0,0 +1,681 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tensordict import TensorDict +import math +from torchrl.modules.mcts.scores import UCBScore, PUCTScore, EXP3Score + +# Sample TensorDict for testing +def create_node(num_actions, weights=None, batch_size=None, device="cpu", custom_keys=None): + if custom_keys is None: + custom_keys = { + "num_actions_key": "num_actions", + "weights_key": "weights", + "score_key": "score", + } + + if batch_size: + data = {custom_keys["num_actions_key"]: torch.tensor([num_actions] * batch_size, device=device)} + if weights is not None: + if weights.ndim == 1: + weights = weights.unsqueeze(0).repeat(batch_size, 1) + data[custom_keys["weights_key"]] = weights.to(device) + td = TensorDict(data, batch_size=[batch_size], device=device) + else: + data = {custom_keys["num_actions_key"]: torch.tensor(num_actions, device=device)} + if weights is not None: + data[custom_keys["weights_key"]] = weights.to(device) + td = TensorDict(data, batch_size=[], device=device) + return td + +# Sample TensorDict node for UCBScore +def create_ucb_node(win_count, visits, total_visits, batch_size=None, device="cpu", custom_keys=None): + if custom_keys is None: + custom_keys = { + "win_count_key": "win_count", + "visits_key": "visits", + "total_visits_key": "total_visits", + "score_key": "score", + } + + win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32) + visits = torch.as_tensor(visits, device=device, dtype=torch.float32) + total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32) + + if batch_size: + if win_count.ndim == 0: + win_count = win_count.unsqueeze(0).repeat(batch_size) + elif win_count.shape[0] != batch_size: + raise ValueError("Batch size mismatch for win_count") + if visits.ndim == 0: + visits = visits.unsqueeze(0).repeat(batch_size) + elif visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for visits") + if total_visits.ndim == 0: + total_visits = total_visits.unsqueeze(0).repeat(batch_size) + elif total_visits.shape[0] != batch_size and total_visits.numel() != 1 : + raise ValueError("Batch size mismatch for total_visits") + if total_visits.numel() == 1 and batch_size > 1: + total_visits = total_visits.repeat(batch_size) + + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + } + td = TensorDict(data, batch_size=[batch_s for batch_s in batch_size] if isinstance(batch_size, (list, tuple)) else [batch_size], device=device) + else: + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + } + td = TensorDict(data, batch_size=win_count.shape[:-1] if win_count.ndim > 1 else [], device=device) + + return td + +# Helper function to create a sample TensorDict node for PUCTScore +def create_puct_node(win_count, visits, total_visits, prior_prob, batch_size=None, device="cpu", custom_keys=None): + if custom_keys is None: + custom_keys = { + "win_count_key": "win_count", + "visits_key": "visits", + "total_visits_key": "total_visits", + "prior_prob_key": "prior_prob", + "score_key": "score", + } + + win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32) + visits = torch.as_tensor(visits, device=device, dtype=torch.float32) + total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32) + prior_prob = torch.as_tensor(prior_prob, device=device, dtype=torch.float32) + + if batch_size: + if win_count.ndim == 0: win_count = win_count.unsqueeze(0).repeat(batch_size) + elif win_count.shape[0] != batch_size: raise ValueError("Batch size mismatch for win_count") + if visits.ndim == 0: visits = visits.unsqueeze(0).repeat(batch_size) + elif visits.shape[0] != batch_size: raise ValueError("Batch size mismatch for visits") + if prior_prob.ndim == 0: prior_prob = prior_prob.unsqueeze(0).repeat(batch_size) + elif prior_prob.shape[0] != batch_size: raise ValueError("Batch size mismatch for prior_prob") + + if total_visits.numel() == 1 and batch_size > 1: # scalar total_visits for batch + total_visits = total_visits.repeat(batch_size) + elif total_visits.ndim == 0 : total_visits = total_visits.unsqueeze(0).repeat(batch_size) # make it (batch_size,) + elif total_visits.shape[0] != batch_size : raise ValueError("Batch size mismatch for total_visits") + + + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + custom_keys["prior_prob_key"]: prior_prob, + } + if isinstance(batch_size, (list, tuple)): + td_batch_size = batch_size + else: + td_batch_size = [batch_size] + td = TensorDict(data, batch_size=td_batch_size, device=device) + + else: + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + custom_keys["prior_prob_key"]: prior_prob, + } + td_batch_size = win_count.shape[:-1] if win_count.ndim > 1 else [] + + td = TensorDict(data, batch_size=td_batch_size, device=device) + + return td + +class TestEXP3Score: + @pytest.fixture + def default_scorer(self): + return EXP3Score() + + @pytest.fixture + def custom_key_names(self): + return { + "weights_key": "custom_weights", + "score_key": "custom_scores", + "num_actions_key": "custom_num_actions", + "action_prob_key": "custom_actions_prob", + "reward_key": "custom_reward" + } + + @pytest.mark.parametrize("gamma_val", [0.1, 0.5, 0.9]) + def test_initialization(self, gamma_val): + scorer = EXP3Score(gamma=gamma_val) + assert scorer.gamma == gamma_val + scorer_default = EXP3Score() + assert scorer_default.gamma == 0.1 + + def test_forward_initial_weights(self, default_scorer): + num_actions = 3 + node = create_node(num_actions=num_actions) + + default_scorer.forward(node) + + assert default_scorer.weights_key in node.keys() + expected_weights = torch.ones(num_actions) + torch.testing.assert_close(node.get(default_scorer.weights_key), expected_weights) + + expected_scores = torch.ones(num_actions) / num_actions + torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) + torch.testing.assert_close(node.get(default_scorer.score_key).sum(), torch.tensor(1.0)) + + def test_forward_custom_weights(self, default_scorer): + num_actions = 3 + weights = torch.tensor([1.0, 2.0, 3.0]) + node = create_node(num_actions=num_actions, weights=weights) + + default_scorer.forward(node) + + gamma = default_scorer.gamma + sum_w = weights.sum() + expected_scores = (1 - gamma) * (weights / sum_w) + (gamma / num_actions) + + torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) + torch.testing.assert_close(node.get(default_scorer.score_key).sum(), torch.tensor(1.0)) + + @pytest.mark.parametrize("batch_s", [2, 4]) + def test_forward_batch(self, default_scorer, batch_s): + num_actions = 3 + node_initial = create_node(num_actions=num_actions, batch_size=batch_s) + default_scorer.forward(node_initial) + + expected_weights_initial = torch.ones(batch_s, num_actions) + torch.testing.assert_close(node_initial.get(default_scorer.weights_key), expected_weights_initial) + + expected_scores_initial = torch.ones(batch_s, num_actions) / num_actions + torch.testing.assert_close(node_initial.get(default_scorer.score_key), expected_scores_initial) + torch.testing.assert_close(node_initial.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s)) + + weights_custom = torch.rand(batch_s, num_actions) + 0.1 + node_custom = create_node(num_actions=num_actions, weights=weights_custom, batch_size=batch_s) + default_scorer.forward(node_custom) + + gamma = default_scorer.gamma + sum_w_custom = weights_custom.sum(dim=-1, keepdim=True) + expected_scores_custom = (1 - gamma) * (weights_custom / sum_w_custom) + (gamma / num_actions) + torch.testing.assert_close(node_custom.get(default_scorer.score_key), expected_scores_custom, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(node_custom.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s)) + + def test_update_weights_single_node(self, default_scorer): + num_actions = 3 + action_idx = 0 + reward = 1.0 + node = create_node(num_actions=num_actions) + + default_scorer.forward(node) + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) + + torch.testing.assert_close(updated_weights[action_idx], torch.tensor(expected_new_weight_val)) + torch.testing.assert_close(updated_weights[action_idx+1:], initial_weights[action_idx+1:]) + + default_scorer.forward(node) + sum_w_updated = updated_weights.sum() + expected_scores_after_update = (1-gamma)*(updated_weights/sum_w_updated) + (gamma/k) + torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores_after_update) + + + def test_update_weights_zero_reward(self, default_scorer): + num_actions = 3 + action_idx = 1 + reward = 0.0 + weights = torch.tensor([1.0, 2.0, 1.5]) + node = create_node(num_actions=num_actions, weights=weights) + + default_scorer.forward(node) + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) + torch.testing.assert_close(updated_weights[action_idx], expected_new_weight_val) + torch.testing.assert_close(updated_weights[action_idx], initial_weights[action_idx]) + + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_update_weights_batch(self, default_scorer, batch_s): + num_actions = 3 + node = create_node(num_actions=num_actions, batch_size=batch_s) + default_scorer.forward(node) + + initial_weights_batch = node.get(default_scorer.weights_key).clone() + probs_batch = node.get(default_scorer.score_key).clone() + + rewards = torch.rand(batch_s) + action_indices = torch.randint(0, num_actions, (batch_s,)) + + expected_updated_weights_batch = initial_weights_batch.clone() + gamma = default_scorer.gamma + k = num_actions + + for i in range(batch_s): + action_idx = action_indices[i].item() + reward = rewards[i].item() + + single_node_td = node[i] + + current_weight_item = initial_weights_batch[i, action_idx] + prob_i_item = probs_batch[i, action_idx] + + exp_val = math.exp((gamma / k) * (reward / prob_i_item)) + expected_updated_weights_batch[i, action_idx] = current_weight_item * exp_val + + node_item_to_update = node[i:i+1] + default_scorer.update_weights(node_item_to_update, action_idx, reward) + + torch.testing.assert_close(node.get(default_scorer.weights_key), expected_updated_weights_batch, atol=1e-5, rtol=1e-5) + + def test_single_action(self, default_scorer): + num_actions = 1 + node = create_node(num_actions=num_actions) + default_scorer.forward(node) + + assert default_scorer.weights_key in node.keys() + torch.testing.assert_close(node.get(default_scorer.weights_key), torch.ones(num_actions)) + torch.testing.assert_close(node.get(default_scorer.score_key), torch.ones(num_actions)) # p_i = 1.0 + + action_idx = 0 + reward = 0.5 + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) + torch.testing.assert_close(updated_weights[action_idx], torch.tensor(expected_new_weight_val)) + + @pytest.mark.parametrize("gamma_val, expected_behavior", [ + (0.0, "exploitation"), (1.0, "exploration") + ]) + def test_gamma_extremes(self, gamma_val, expected_behavior): + scorer = EXP3Score(gamma=gamma_val) + num_actions = 3 + weights = torch.tensor([1.0, 2.0, 7.0]) + node = create_node(num_actions=num_actions, weights=weights) + + scorer.forward(node) + scores = node.get(scorer.score_key) + + if expected_behavior == "exploitation": + expected_scores = weights / weights.sum() + torch.testing.assert_close(scores, expected_scores) + elif expected_behavior == "exploration": + expected_scores = torch.ones(num_actions) / num_actions + torch.testing.assert_close(scores, expected_scores) + + def test_custom_keys(self, custom_key_names): + gamma = 0.2 + scorer = EXP3Score( + gamma=gamma, + weights_key=custom_key_names["weights_key"], + score_key=custom_key_names["score_key"], + num_actions_key=custom_key_names["num_actions_key"], + action_prob_key=custom_key_names["action_prob_key"], + ) + num_actions = 2 + + node1 = create_node(num_actions=num_actions, custom_keys=custom_key_names) + scorer.forward(node1) + + assert custom_key_names["weights_key"] in node1.keys() + expected_weights1 = torch.ones(num_actions) + torch.testing.assert_close(node1.get(custom_key_names["weights_key"]), expected_weights1) + expected_scores1 = torch.ones(num_actions) / num_actions + torch.testing.assert_close(node1.get(custom_key_names["score_key"]), expected_scores1) + if scorer.action_prob_key != scorer.score_key: # Check if action_prob_key was also populated + torch.testing.assert_close(node1.get(custom_key_names["action_prob_key"]), expected_scores1) + + weights2_val = torch.tensor([1.0, 3.0]) + node2 = create_node(num_actions=num_actions, weights=weights2_val, custom_keys=custom_key_names) + scorer.forward(node2) + + sum_w2 = weights2_val.sum() + expected_scores2 = (1 - gamma) * (weights2_val / sum_w2) + (gamma / num_actions) + torch.testing.assert_close(node2.get(custom_key_names["score_key"]), expected_scores2) + + action_idx = 0 + reward = 1.0 + initial_weights2 = node2.get(custom_key_names["weights_key"]).clone() + prob_i2 = node2.get(custom_key_names["score_key"])[action_idx] + + scorer.update_weights(node2, action_idx, reward) + updated_weights2 = node2.get(custom_key_names["weights_key"]) + k = num_actions + + expected_new_weight_val2 = initial_weights2[action_idx] * math.exp((gamma / k) * (reward / prob_i2)) + torch.testing.assert_close(updated_weights2[action_idx], torch.tensor(expected_new_weight_val2)) + + def test_forward_raises_error_on_mismatched_num_actions(self, default_scorer): + num_actions_prop = 3 + weights = torch.tensor([1.0, 2.0, 3.0, 4.0]) # K=4 from weights + node = create_node(num_actions=num_actions_prop, weights=weights) # num_actions=3 + + with pytest.raises(ValueError, match="Shape of weights .* implies 4 actions, but num_actions is 3"): + default_scorer.forward(node) + + weights_ok = torch.tensor([1.0, 2.0, 3.0]) + node_ok = create_node(num_actions=torch.tensor(4), weights=weights_ok) # num_actions=4 from tensor + + with pytest.raises(ValueError, match="Shape of weights .* implies 3 actions, but num_actions is 4"): + default_scorer.forward(node_ok) + + def test_update_weights_handles_prob_zero(self, default_scorer): + num_actions = 2 + action_idx = 0 + reward = 1.0 + scorer_exploit = EXP3Score(gamma=0.0) + weights = torch.tensor([0.0, 1.0]) + node = create_node(num_actions=num_actions, weights=weights) + + scorer_exploit.forward(node) # p_0 will be 0 + assert node.get(scorer_exploit.score_key)[0] == 0.0 + + with pytest.warns(UserWarning, match="Probability p_i\\(t\\) for action 0 is 0.0"): + scorer_exploit.update_weights(node, action_idx, reward) + torch.testing.assert_close(node.get(scorer_exploit.weights_key)[action_idx], torch.tensor(0.0)) + + def test_init_raises_error_gamma_out_of_range(self): + with pytest.raises(ValueError, match="gamma must be between 0 and 1"): + EXP3Score(gamma=-0.1) + with pytest.raises(ValueError, match="gamma must be between 0 and 1"): + EXP3Score(gamma=1.1) + + def test_update_weights_reward_warning(self, default_scorer): + num_actions = 2 + node = create_node(num_actions=num_actions) + default_scorer.forward(node) + with pytest.warns(UserWarning, match="Reward .* is outside the expected \\[0,1\\] range"): + default_scorer.update_weights(node, 0, 1.5) + with pytest.warns(UserWarning, match="Reward .* is outside the expected \\[0,1\\] range"): + default_scorer.update_weights(node, 0, -0.5) + initial_weight = node.get(default_scorer.weights_key)[0].clone() + default_scorer.update_weights(node, 0, 1.5) + assert node.get(default_scorer.weights_key)[0] != initial_weight # it changed + + +class TestUCBScore: + @pytest.fixture + def default_ucb_scorer(self): + return UCBScore(c=math.sqrt(2)) + + @pytest.fixture + def ucb_custom_key_names(self): + return { + "win_count_key": "custom_wins", + "visits_key": "custom_visits", + "total_visits_key": "custom_total_visits", + "score_key": "custom_ucb_score", + } + + @pytest.mark.parametrize("c_val", [0.5, 1.0, math.sqrt(2), 5.0]) + def test_initialization(self, c_val): + scorer = UCBScore(c=c_val) + assert scorer.c == c_val + + def test_forward_basic(self, default_ucb_scorer): + win_count = torch.tensor([10.0, 5.0, 20.0]) + visits = torch.tensor([15.0, 10.0, 25.0]) + total_visits_parent = torch.tensor(50.0) + + node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + exploitation_term = win_count / visits + exploration_term = c * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + def test_forward_zero_visits(self, default_ucb_scorer): + win_count = torch.tensor([0.0, 0.0]) + visits = torch.tensor([10.0, 0.0]) + total_visits_parent = torch.tensor(10.0) + + node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + scores = node.get(default_ucb_scorer.score_key) + + expected_score_0 = (win_count[0] / visits[0]) + c * total_visits_parent.sqrt() / (1 + visits[0]) + torch.testing.assert_close(scores[0], expected_score_0) + assert torch.isnan(scores[1]), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." + + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_forward_batch(self, default_ucb_scorer, batch_s): + win_count = torch.rand(batch_s, 2) * 10 + visits = torch.rand(batch_s, 2) * 5 + 1 + total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) + + node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent, batch_size=batch_s) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + exploitation_term = win_count / visits + exploration_term = c * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + def test_forward_exploration_term(self, default_ucb_scorer): + win_count = torch.tensor([0.0, 0.0, 0.0]) + visits = torch.tensor([10.0, 5.0, 1.0]) + total_visits_parent = torch.tensor(100.0) + + node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + expected_scores = c * total_visits_parent.sqrt() / (1 + visits) + + torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + def test_custom_keys(self, ucb_custom_key_names): + c_val = 1.5 + scorer = UCBScore( + c=c_val, + win_count_key=ucb_custom_key_names["win_count_key"], + visits_key=ucb_custom_key_names["visits_key"], + total_visits_key=ucb_custom_key_names["total_visits_key"], + score_key=ucb_custom_key_names["score_key"], + ) + + win_count = torch.tensor([1.0, 2.0]) + visits = torch.tensor([3.0, 4.0]) + total_visits_parent = torch.tensor(10.0) + + node = create_ucb_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + custom_keys=ucb_custom_key_names + ) + scorer.forward(node) + + exploitation = win_count / visits + exploration = c_val * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation + exploration + + assert ucb_custom_key_names["score_key"] in node.keys() + torch.testing.assert_close(node.get(ucb_custom_key_names["score_key"]), expected_scores) + + assert "score" not in node.keys() + assert "win_count" not in node.keys() + assert "visits" not in node.keys() + assert "total_visits" not in node.keys() + + +class TestPUCTScore: + @pytest.fixture + def default_puct_scorer(self): + return PUCTScore(c=5.0) + + @pytest.fixture + def puct_custom_key_names(self): + return { + "win_count_key": "custom_puct_wins", + "visits_key": "custom_puct_visits", + "total_visits_key": "custom_puct_total_visits", + "prior_prob_key": "custom_puct_priors", + "score_key": "custom_puct_score", + } + + @pytest.mark.parametrize("c_val", [0.5, 1.0, 5.0, 10.0]) + def test_initialization(self, c_val): + scorer = PUCTScore(c=c_val) + assert scorer.c == c_val + + def test_forward_basic(self, default_puct_scorer): + win_count = torch.tensor([10.0, 5.0, 20.0]) + visits = torch.tensor([15.0, 10.0, 25.0]) + prior_prob = torch.tensor([0.4, 0.3, 0.3]) + total_visits_parent = torch.tensor(50.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + exploitation_term = win_count / visits + exploration_term = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores) + + def test_forward_zero_visits(self, default_puct_scorer): + win_count = torch.tensor([0.0, 0.0]) + visits = torch.tensor([10.0, 0.0]) + prior_prob = torch.tensor([0.6, 0.4]) + total_visits_parent = torch.tensor(10.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + scores = node.get(default_puct_scorer.score_key) + + expected_score_0 = (win_count[0] / visits[0]) + \ + c * prior_prob[0] * total_visits_parent.sqrt() / (1 + visits[0]) + torch.testing.assert_close(scores[0], expected_score_0) + + assert torch.isnan(scores[1]), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." + + + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_forward_batch(self, default_puct_scorer, batch_s): + num_actions = 2 + win_count = torch.rand(batch_s, num_actions) * 10 + visits = torch.rand(batch_s, num_actions) * 5 + 1 + prior_prob = torch.rand(batch_s, num_actions) + prior_prob = prior_prob / prior_prob.sum(dim=-1, keepdim=True) + total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + batch_size=batch_s + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + exploitation_term = win_count / visits + exploration_term = c * prior_prob * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores, atol=1e-6, rtol=1e-6) + + def test_forward_exploration_term(self, default_puct_scorer): + num_actions = 3 + win_count = torch.zeros(num_actions) + visits = torch.tensor([10.0, 5.0, 1.0]) + prior_prob = torch.tensor([0.3, 0.5, 0.2]) + total_visits_parent = torch.tensor(100.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + # exploitation_term is effectively 0 + expected_scores = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) + + torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores) + + def test_custom_keys(self, puct_custom_key_names): + c_val = 2.5 + scorer = PUCTScore( + c=c_val, + win_count_key=puct_custom_key_names["win_count_key"], + visits_key=puct_custom_key_names["visits_key"], + total_visits_key=puct_custom_key_names["total_visits_key"], + prior_prob_key=puct_custom_key_names["prior_prob_key"], + score_key=puct_custom_key_names["score_key"], + ) + + win_count = torch.tensor([1.0, 2.0]) + visits = torch.tensor([3.0, 4.0]) + prior_prob = torch.tensor([0.5, 0.5]) + total_visits_parent = torch.tensor(10.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + custom_keys=puct_custom_key_names + ) + scorer.forward(node) + + exploitation = win_count / visits + exploration = c_val * prior_prob * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation + exploration + + assert puct_custom_key_names["score_key"] in node.keys() + torch.testing.assert_close(node.get(puct_custom_key_names["score_key"]), expected_scores) + + # Check that default keys are not present + assert "score" not in node.keys() + assert "win_count" not in node.keys() + assert "visits" not in node.keys() + assert "total_visits" not in node.keys() + assert "prior_prob" not in node.keys() \ No newline at end of file diff --git a/torchrl/modules/mcts/__init__.py b/torchrl/modules/mcts/__init__.py index b983d492454..b225f4c0cca 100644 --- a/torchrl/modules/mcts/__init__.py +++ b/torchrl/modules/mcts/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .scores import PUCTScore, UCBScore +from .scores import EXP3Score, PUCTScore, UCBScore diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index d79fb1426ed..f9046f9df2c 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -91,6 +91,7 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node + class EXP3Score(MCTSScore): def __init__( self, @@ -198,6 +199,10 @@ def update_weights( else: weights[..., action_idx] = new_weight node.set(self.weights_key, weights) +<<<<<<< Updated upstream +======= + +>>>>>>> Stashed changes class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value From 168c29306f23d1d4ebf3ed74f0cf3861312e9cdc Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 20 Jun 2025 12:11:04 +0530 Subject: [PATCH 12/14] Updates --- test/test_mcts.py | 478 ++++++++++++++++++++++----------- torchrl/modules/mcts/scores.py | 7 +- 2 files changed, 325 insertions(+), 160 deletions(-) diff --git a/test/test_mcts.py b/test/test_mcts.py index c5e25c91e2e..5ad4467a875 100644 --- a/test/test_mcts.py +++ b/test/test_mcts.py @@ -3,14 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import math + import pytest import torch from tensordict import TensorDict -import math -from torchrl.modules.mcts.scores import UCBScore, PUCTScore, EXP3Score +from torchrl.modules.mcts.scores import EXP3Score, PUCTScore, UCBScore # Sample TensorDict for testing -def create_node(num_actions, weights=None, batch_size=None, device="cpu", custom_keys=None): +def create_node( + num_actions, weights=None, batch_size=None, device="cpu", custom_keys=None +): if custom_keys is None: custom_keys = { "num_actions_key": "num_actions", @@ -19,21 +22,30 @@ def create_node(num_actions, weights=None, batch_size=None, device="cpu", custom } if batch_size: - data = {custom_keys["num_actions_key"]: torch.tensor([num_actions] * batch_size, device=device)} + data = { + custom_keys["num_actions_key"]: torch.tensor( + [num_actions] * batch_size, device=device + ) + } if weights is not None: if weights.ndim == 1: weights = weights.unsqueeze(0).repeat(batch_size, 1) data[custom_keys["weights_key"]] = weights.to(device) td = TensorDict(data, batch_size=[batch_size], device=device) else: - data = {custom_keys["num_actions_key"]: torch.tensor(num_actions, device=device)} + data = { + custom_keys["num_actions_key"]: torch.tensor(num_actions, device=device) + } if weights is not None: data[custom_keys["weights_key"]] = weights.to(device) td = TensorDict(data, batch_size=[], device=device) return td + # Sample TensorDict node for UCBScore -def create_ucb_node(win_count, visits, total_visits, batch_size=None, device="cpu", custom_keys=None): +def create_ucb_node( + win_count, visits, total_visits, batch_size=None, device="cpu", custom_keys=None +): if custom_keys is None: custom_keys = { "win_count_key": "win_count", @@ -41,7 +53,7 @@ def create_ucb_node(win_count, visits, total_visits, batch_size=None, device="cp "total_visits_key": "total_visits", "score_key": "score", } - + win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32) visits = torch.as_tensor(visits, device=device, dtype=torch.float32) total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32) @@ -50,36 +62,55 @@ def create_ucb_node(win_count, visits, total_visits, batch_size=None, device="cp if win_count.ndim == 0: win_count = win_count.unsqueeze(0).repeat(batch_size) elif win_count.shape[0] != batch_size: - raise ValueError("Batch size mismatch for win_count") + raise ValueError("Batch size mismatch for win_count") if visits.ndim == 0: visits = visits.unsqueeze(0).repeat(batch_size) elif visits.shape[0] != batch_size: - raise ValueError("Batch size mismatch for visits") + raise ValueError("Batch size mismatch for visits") if total_visits.ndim == 0: total_visits = total_visits.unsqueeze(0).repeat(batch_size) - elif total_visits.shape[0] != batch_size and total_visits.numel() != 1 : - raise ValueError("Batch size mismatch for total_visits") + elif total_visits.shape[0] != batch_size and total_visits.numel() != 1: + raise ValueError("Batch size mismatch for total_visits") if total_visits.numel() == 1 and batch_size > 1: total_visits = total_visits.repeat(batch_size) - + data = { custom_keys["win_count_key"]: win_count, custom_keys["visits_key"]: visits, custom_keys["total_visits_key"]: total_visits, } - td = TensorDict(data, batch_size=[batch_s for batch_s in batch_size] if isinstance(batch_size, (list, tuple)) else [batch_size], device=device) + td = TensorDict( + data, + batch_size=[batch_s for batch_s in batch_size] + if isinstance(batch_size, (list, tuple)) + else [batch_size], + device=device, + ) else: data = { custom_keys["win_count_key"]: win_count, custom_keys["visits_key"]: visits, custom_keys["total_visits_key"]: total_visits, } - td = TensorDict(data, batch_size=win_count.shape[:-1] if win_count.ndim > 1 else [], device=device) - + td = TensorDict( + data, + batch_size=win_count.shape[:-1] if win_count.ndim > 1 else [], + device=device, + ) + return td + # Helper function to create a sample TensorDict node for PUCTScore -def create_puct_node(win_count, visits, total_visits, prior_prob, batch_size=None, device="cpu", custom_keys=None): +def create_puct_node( + win_count, + visits, + total_visits, + prior_prob, + batch_size=None, + device="cpu", + custom_keys=None, +): if custom_keys is None: custom_keys = { "win_count_key": "win_count", @@ -95,18 +126,29 @@ def create_puct_node(win_count, visits, total_visits, prior_prob, batch_size=Non prior_prob = torch.as_tensor(prior_prob, device=device, dtype=torch.float32) if batch_size: - if win_count.ndim == 0: win_count = win_count.unsqueeze(0).repeat(batch_size) - elif win_count.shape[0] != batch_size: raise ValueError("Batch size mismatch for win_count") - if visits.ndim == 0: visits = visits.unsqueeze(0).repeat(batch_size) - elif visits.shape[0] != batch_size: raise ValueError("Batch size mismatch for visits") - if prior_prob.ndim == 0: prior_prob = prior_prob.unsqueeze(0).repeat(batch_size) - elif prior_prob.shape[0] != batch_size: raise ValueError("Batch size mismatch for prior_prob") - - if total_visits.numel() == 1 and batch_size > 1: # scalar total_visits for batch + if win_count.ndim == 0: + win_count = win_count.unsqueeze(0).repeat(batch_size) + elif win_count.shape[0] != batch_size: + raise ValueError("Batch size mismatch for win_count") + if visits.ndim == 0: + visits = visits.unsqueeze(0).repeat(batch_size) + elif visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for visits") + if prior_prob.ndim == 0: + prior_prob = prior_prob.unsqueeze(0).repeat(batch_size) + elif prior_prob.shape[0] != batch_size: + raise ValueError("Batch size mismatch for prior_prob") + + if ( + total_visits.numel() == 1 and batch_size > 1 + ): # scalar total_visits for batch total_visits = total_visits.repeat(batch_size) - elif total_visits.ndim == 0 : total_visits = total_visits.unsqueeze(0).repeat(batch_size) # make it (batch_size,) - elif total_visits.shape[0] != batch_size : raise ValueError("Batch size mismatch for total_visits") - + elif total_visits.ndim == 0: + total_visits = total_visits.unsqueeze(0).repeat( + batch_size + ) # make it (batch_size,) + elif total_visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for total_visits") data = { custom_keys["win_count_key"]: win_count, @@ -133,11 +175,12 @@ def create_puct_node(win_count, visits, total_visits, prior_prob, batch_size=Non return td + class TestEXP3Score: @pytest.fixture def default_scorer(self): return EXP3Score() - + @pytest.fixture def custom_key_names(self): return { @@ -145,9 +188,9 @@ def custom_key_names(self): "score_key": "custom_scores", "num_actions_key": "custom_num_actions", "action_prob_key": "custom_actions_prob", - "reward_key": "custom_reward" + "reward_key": "custom_reward", } - + @pytest.mark.parametrize("gamma_val", [0.1, 0.5, 0.9]) def test_initialization(self, gamma_val): scorer = EXP3Score(gamma=gamma_val) @@ -158,16 +201,20 @@ def test_initialization(self, gamma_val): def test_forward_initial_weights(self, default_scorer): num_actions = 3 node = create_node(num_actions=num_actions) - + default_scorer.forward(node) assert default_scorer.weights_key in node.keys() expected_weights = torch.ones(num_actions) - torch.testing.assert_close(node.get(default_scorer.weights_key), expected_weights) + torch.testing.assert_close( + node.get(default_scorer.weights_key), expected_weights + ) expected_scores = torch.ones(num_actions) / num_actions torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) - torch.testing.assert_close(node.get(default_scorer.score_key).sum(), torch.tensor(1.0)) + torch.testing.assert_close( + node.get(default_scorer.score_key).sum(), torch.tensor(1.0) + ) def test_forward_custom_weights(self, default_scorer): num_actions = 3 @@ -175,13 +222,15 @@ def test_forward_custom_weights(self, default_scorer): node = create_node(num_actions=num_actions, weights=weights) default_scorer.forward(node) - + gamma = default_scorer.gamma sum_w = weights.sum() expected_scores = (1 - gamma) * (weights / sum_w) + (gamma / num_actions) - + torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) - torch.testing.assert_close(node.get(default_scorer.score_key).sum(), torch.tensor(1.0)) + torch.testing.assert_close( + node.get(default_scorer.score_key).sum(), torch.tensor(1.0) + ) @pytest.mark.parametrize("batch_s", [2, 4]) def test_forward_batch(self, default_scorer, batch_s): @@ -190,22 +239,39 @@ def test_forward_batch(self, default_scorer, batch_s): default_scorer.forward(node_initial) expected_weights_initial = torch.ones(batch_s, num_actions) - torch.testing.assert_close(node_initial.get(default_scorer.weights_key), expected_weights_initial) - + torch.testing.assert_close( + node_initial.get(default_scorer.weights_key), expected_weights_initial + ) + expected_scores_initial = torch.ones(batch_s, num_actions) / num_actions - torch.testing.assert_close(node_initial.get(default_scorer.score_key), expected_scores_initial) - torch.testing.assert_close(node_initial.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s)) + torch.testing.assert_close( + node_initial.get(default_scorer.score_key), expected_scores_initial + ) + torch.testing.assert_close( + node_initial.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s) + ) weights_custom = torch.rand(batch_s, num_actions) + 0.1 - node_custom = create_node(num_actions=num_actions, weights=weights_custom, batch_size=batch_s) + node_custom = create_node( + num_actions=num_actions, weights=weights_custom, batch_size=batch_s + ) default_scorer.forward(node_custom) gamma = default_scorer.gamma sum_w_custom = weights_custom.sum(dim=-1, keepdim=True) - expected_scores_custom = (1 - gamma) * (weights_custom / sum_w_custom) + (gamma / num_actions) - torch.testing.assert_close(node_custom.get(default_scorer.score_key), expected_scores_custom, atol=1e-6, rtol=1e-6) - torch.testing.assert_close(node_custom.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s)) - + expected_scores_custom = (1 - gamma) * (weights_custom / sum_w_custom) + ( + gamma / num_actions + ) + torch.testing.assert_close( + node_custom.get(default_scorer.score_key), + expected_scores_custom, + atol=1e-6, + rtol=1e-6, + ) + torch.testing.assert_close( + node_custom.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s) + ) + def test_update_weights_single_node(self, default_scorer): num_actions = 3 action_idx = 0 @@ -217,21 +283,30 @@ def test_update_weights_single_node(self, default_scorer): prob_i = node.get(default_scorer.score_key)[action_idx] default_scorer.update_weights(node, action_idx, reward) - + updated_weights = node.get(default_scorer.weights_key) gamma = default_scorer.gamma k = num_actions - - expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) - - torch.testing.assert_close(updated_weights[action_idx], torch.tensor(expected_new_weight_val)) - torch.testing.assert_close(updated_weights[action_idx+1:], initial_weights[action_idx+1:]) + + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) + + torch.testing.assert_close( + updated_weights[action_idx], torch.tensor(expected_new_weight_val) + ) + torch.testing.assert_close( + updated_weights[action_idx + 1 :], initial_weights[action_idx + 1 :] + ) default_scorer.forward(node) sum_w_updated = updated_weights.sum() - expected_scores_after_update = (1-gamma)*(updated_weights/sum_w_updated) + (gamma/k) - torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores_after_update) - + expected_scores_after_update = (1 - gamma) * ( + updated_weights / sum_w_updated + ) + (gamma / k) + torch.testing.assert_close( + node.get(default_scorer.score_key), expected_scores_after_update + ) def test_update_weights_zero_reward(self, default_scorer): num_actions = 3 @@ -249,9 +324,13 @@ def test_update_weights_zero_reward(self, default_scorer): gamma = default_scorer.gamma k = num_actions - expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) torch.testing.assert_close(updated_weights[action_idx], expected_new_weight_val) - torch.testing.assert_close(updated_weights[action_idx], initial_weights[action_idx]) + torch.testing.assert_close( + updated_weights[action_idx], initial_weights[action_idx] + ) @pytest.mark.parametrize("batch_s", [2, 3]) def test_update_weights_batch(self, default_scorer, batch_s): @@ -261,7 +340,7 @@ def test_update_weights_batch(self, default_scorer, batch_s): initial_weights_batch = node.get(default_scorer.weights_key).clone() probs_batch = node.get(default_scorer.score_key).clone() - + rewards = torch.rand(batch_s) action_indices = torch.randint(0, num_actions, (batch_s,)) @@ -272,19 +351,26 @@ def test_update_weights_batch(self, default_scorer, batch_s): for i in range(batch_s): action_idx = action_indices[i].item() reward = rewards[i].item() - + single_node_td = node[i] - + current_weight_item = initial_weights_batch[i, action_idx] prob_i_item = probs_batch[i, action_idx] - + exp_val = math.exp((gamma / k) * (reward / prob_i_item)) - expected_updated_weights_batch[i, action_idx] = current_weight_item * exp_val + expected_updated_weights_batch[i, action_idx] = ( + current_weight_item * exp_val + ) - node_item_to_update = node[i:i+1] + node_item_to_update = node[i : i + 1] default_scorer.update_weights(node_item_to_update, action_idx, reward) - torch.testing.assert_close(node.get(default_scorer.weights_key), expected_updated_weights_batch, atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + node.get(default_scorer.weights_key), + expected_updated_weights_batch, + atol=1e-5, + rtol=1e-5, + ) def test_single_action(self, default_scorer): num_actions = 1 @@ -292,9 +378,13 @@ def test_single_action(self, default_scorer): default_scorer.forward(node) assert default_scorer.weights_key in node.keys() - torch.testing.assert_close(node.get(default_scorer.weights_key), torch.ones(num_actions)) - torch.testing.assert_close(node.get(default_scorer.score_key), torch.ones(num_actions)) # p_i = 1.0 - + torch.testing.assert_close( + node.get(default_scorer.weights_key), torch.ones(num_actions) + ) + torch.testing.assert_close( + node.get(default_scorer.score_key), torch.ones(num_actions) + ) # p_i = 1.0 + action_idx = 0 reward = 0.5 initial_weights = node.get(default_scorer.weights_key).clone() @@ -304,19 +394,23 @@ def test_single_action(self, default_scorer): updated_weights = node.get(default_scorer.weights_key) gamma = default_scorer.gamma k = num_actions - - expected_new_weight_val = initial_weights[action_idx] * math.exp((gamma / k) * (reward / prob_i)) - torch.testing.assert_close(updated_weights[action_idx], torch.tensor(expected_new_weight_val)) - @pytest.mark.parametrize("gamma_val, expected_behavior", [ - (0.0, "exploitation"), (1.0, "exploration") - ]) + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) + torch.testing.assert_close( + updated_weights[action_idx], torch.tensor(expected_new_weight_val) + ) + + @pytest.mark.parametrize( + "gamma_val, expected_behavior", [(0.0, "exploitation"), (1.0, "exploration")] + ) def test_gamma_extremes(self, gamma_val, expected_behavior): scorer = EXP3Score(gamma=gamma_val) num_actions = 3 weights = torch.tensor([1.0, 2.0, 7.0]) node = create_node(num_actions=num_actions, weights=weights) - + scorer.forward(node) scores = node.get(scorer.score_key) @@ -340,47 +434,73 @@ def test_custom_keys(self, custom_key_names): node1 = create_node(num_actions=num_actions, custom_keys=custom_key_names) scorer.forward(node1) - + assert custom_key_names["weights_key"] in node1.keys() expected_weights1 = torch.ones(num_actions) - torch.testing.assert_close(node1.get(custom_key_names["weights_key"]), expected_weights1) + torch.testing.assert_close( + node1.get(custom_key_names["weights_key"]), expected_weights1 + ) expected_scores1 = torch.ones(num_actions) / num_actions - torch.testing.assert_close(node1.get(custom_key_names["score_key"]), expected_scores1) - if scorer.action_prob_key != scorer.score_key: # Check if action_prob_key was also populated - torch.testing.assert_close(node1.get(custom_key_names["action_prob_key"]), expected_scores1) + torch.testing.assert_close( + node1.get(custom_key_names["score_key"]), expected_scores1 + ) + if ( + scorer.action_prob_key != scorer.score_key + ): # Check if action_prob_key was also populated + torch.testing.assert_close( + node1.get(custom_key_names["action_prob_key"]), expected_scores1 + ) weights2_val = torch.tensor([1.0, 3.0]) - node2 = create_node(num_actions=num_actions, weights=weights2_val, custom_keys=custom_key_names) + node2 = create_node( + num_actions=num_actions, weights=weights2_val, custom_keys=custom_key_names + ) scorer.forward(node2) - + sum_w2 = weights2_val.sum() expected_scores2 = (1 - gamma) * (weights2_val / sum_w2) + (gamma / num_actions) - torch.testing.assert_close(node2.get(custom_key_names["score_key"]), expected_scores2) + torch.testing.assert_close( + node2.get(custom_key_names["score_key"]), expected_scores2 + ) action_idx = 0 reward = 1.0 initial_weights2 = node2.get(custom_key_names["weights_key"]).clone() prob_i2 = node2.get(custom_key_names["score_key"])[action_idx] - + scorer.update_weights(node2, action_idx, reward) updated_weights2 = node2.get(custom_key_names["weights_key"]) k = num_actions - - expected_new_weight_val2 = initial_weights2[action_idx] * math.exp((gamma / k) * (reward / prob_i2)) - torch.testing.assert_close(updated_weights2[action_idx], torch.tensor(expected_new_weight_val2)) + + expected_new_weight_val2 = initial_weights2[action_idx] * math.exp( + (gamma / k) * (reward / prob_i2) + ) + torch.testing.assert_close( + updated_weights2[action_idx], torch.tensor(expected_new_weight_val2) + ) def test_forward_raises_error_on_mismatched_num_actions(self, default_scorer): num_actions_prop = 3 - weights = torch.tensor([1.0, 2.0, 3.0, 4.0]) # K=4 from weights - node = create_node(num_actions=num_actions_prop, weights=weights) # num_actions=3 - - with pytest.raises(ValueError, match="Shape of weights .* implies 4 actions, but num_actions is 3"): + weights = torch.tensor([1.0, 2.0, 3.0, 4.0]) # K=4 from weights + node = create_node( + num_actions=num_actions_prop, weights=weights + ) # num_actions=3 + + with pytest.raises( + ValueError, + match="Shape of weights .* implies 4 actions, but num_actions is 3", + ): default_scorer.forward(node) weights_ok = torch.tensor([1.0, 2.0, 3.0]) - node_ok = create_node(num_actions=torch.tensor(4), weights=weights_ok) # num_actions=4 from tensor - - with pytest.raises(ValueError, match="Shape of weights .* implies 3 actions, but num_actions is 4"): + node_ok = create_node( + num_actions=torch.tensor(4), weights=weights_ok + ) # num_actions=4 from tensor + + with pytest.raises( + ValueError, + match="Shape of weights .* implies 3 actions, but num_actions is 4", + ): default_scorer.forward(node_ok) def test_update_weights_handles_prob_zero(self, default_scorer): @@ -391,12 +511,16 @@ def test_update_weights_handles_prob_zero(self, default_scorer): weights = torch.tensor([0.0, 1.0]) node = create_node(num_actions=num_actions, weights=weights) - scorer_exploit.forward(node) # p_0 will be 0 + scorer_exploit.forward(node) # p_0 will be 0 assert node.get(scorer_exploit.score_key)[0] == 0.0 - with pytest.warns(UserWarning, match="Probability p_i\\(t\\) for action 0 is 0.0"): + with pytest.warns( + UserWarning, match="Probability p_i\\(t\\) for action 0 is 0.0" + ): scorer_exploit.update_weights(node, action_idx, reward) - torch.testing.assert_close(node.get(scorer_exploit.weights_key)[action_idx], torch.tensor(0.0)) + torch.testing.assert_close( + node.get(scorer_exploit.weights_key)[action_idx], torch.tensor(0.0) + ) def test_init_raises_error_gamma_out_of_range(self): with pytest.raises(ValueError, match="gamma must be between 0 and 1"): @@ -408,13 +532,17 @@ def test_update_weights_reward_warning(self, default_scorer): num_actions = 2 node = create_node(num_actions=num_actions) default_scorer.forward(node) - with pytest.warns(UserWarning, match="Reward .* is outside the expected \\[0,1\\] range"): + with pytest.warns( + UserWarning, match="Reward .* is outside the expected \\[0,1\\] range" + ): default_scorer.update_weights(node, 0, 1.5) - with pytest.warns(UserWarning, match="Reward .* is outside the expected \\[0,1\\] range"): + with pytest.warns( + UserWarning, match="Reward .* is outside the expected \\[0,1\\] range" + ): default_scorer.update_weights(node, 0, -0.5) initial_weight = node.get(default_scorer.weights_key)[0].clone() default_scorer.update_weights(node, 0, 1.5) - assert node.get(default_scorer.weights_key)[0] != initial_weight # it changed + assert node.get(default_scorer.weights_key)[0] != initial_weight # it changed class TestUCBScore: @@ -441,59 +569,80 @@ def test_forward_basic(self, default_ucb_scorer): visits = torch.tensor([15.0, 10.0, 25.0]) total_visits_parent = torch.tensor(50.0) - node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) default_ucb_scorer.forward(node) c = default_ucb_scorer.c exploitation_term = win_count / visits exploration_term = c * total_visits_parent.sqrt() / (1 + visits) expected_scores = exploitation_term + exploration_term - - torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) def test_forward_zero_visits(self, default_ucb_scorer): win_count = torch.tensor([0.0, 0.0]) visits = torch.tensor([10.0, 0.0]) total_visits_parent = torch.tensor(10.0) - node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) default_ucb_scorer.forward(node) c = default_ucb_scorer.c scores = node.get(default_ucb_scorer.score_key) - expected_score_0 = (win_count[0] / visits[0]) + c * total_visits_parent.sqrt() / (1 + visits[0]) + expected_score_0 = ( + win_count[0] / visits[0] + ) + c * total_visits_parent.sqrt() / (1 + visits[0]) torch.testing.assert_close(scores[0], expected_score_0) - assert torch.isnan(scores[1]), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." + assert torch.isnan( + scores[1] + ), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." @pytest.mark.parametrize("batch_s", [2, 3]) def test_forward_batch(self, default_ucb_scorer, batch_s): win_count = torch.rand(batch_s, 2) * 10 - visits = torch.rand(batch_s, 2) * 5 + 1 + visits = torch.rand(batch_s, 2) * 5 + 1 total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) - node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent, batch_size=batch_s) + node = create_ucb_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + batch_size=batch_s, + ) default_ucb_scorer.forward(node) c = default_ucb_scorer.c exploitation_term = win_count / visits exploration_term = c * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) expected_scores = exploitation_term + exploration_term - - torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) def test_forward_exploration_term(self, default_ucb_scorer): win_count = torch.tensor([0.0, 0.0, 0.0]) visits = torch.tensor([10.0, 5.0, 1.0]) total_visits_parent = torch.tensor(100.0) - node = create_ucb_node(win_count=win_count, visits=visits, total_visits=total_visits_parent) + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) default_ucb_scorer.forward(node) c = default_ucb_scorer.c expected_scores = c * total_visits_parent.sqrt() / (1 + visits) - - torch.testing.assert_close(node.get(default_ucb_scorer.score_key), expected_scores) + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) def test_custom_keys(self, ucb_custom_key_names): c_val = 1.5 @@ -504,28 +653,30 @@ def test_custom_keys(self, ucb_custom_key_names): total_visits_key=ucb_custom_key_names["total_visits_key"], score_key=ucb_custom_key_names["score_key"], ) - + win_count = torch.tensor([1.0, 2.0]) visits = torch.tensor([3.0, 4.0]) total_visits_parent = torch.tensor(10.0) node = create_ucb_node( - win_count=win_count, - visits=visits, + win_count=win_count, + visits=visits, total_visits=total_visits_parent, - custom_keys=ucb_custom_key_names + custom_keys=ucb_custom_key_names, ) scorer.forward(node) exploitation = win_count / visits exploration = c_val * total_visits_parent.sqrt() / (1 + visits) expected_scores = exploitation + exploration - + assert ucb_custom_key_names["score_key"] in node.keys() - torch.testing.assert_close(node.get(ucb_custom_key_names["score_key"]), expected_scores) - + torch.testing.assert_close( + node.get(ucb_custom_key_names["score_key"]), expected_scores + ) + assert "score" not in node.keys() - assert "win_count" not in node.keys() + assert "win_count" not in node.keys() assert "visits" not in node.keys() assert "total_visits" not in node.keys() @@ -552,15 +703,15 @@ def test_initialization(self, c_val): def test_forward_basic(self, default_puct_scorer): win_count = torch.tensor([10.0, 5.0, 20.0]) - visits = torch.tensor([15.0, 10.0, 25.0]) - prior_prob = torch.tensor([0.4, 0.3, 0.3]) + visits = torch.tensor([15.0, 10.0, 25.0]) + prior_prob = torch.tensor([0.4, 0.3, 0.3]) total_visits_parent = torch.tensor(50.0) node = create_puct_node( - win_count=win_count, - visits=visits, - total_visits=total_visits_parent, - prior_prob=prior_prob + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, ) default_puct_scorer.forward(node) @@ -568,32 +719,36 @@ def test_forward_basic(self, default_puct_scorer): exploitation_term = win_count / visits exploration_term = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) expected_scores = exploitation_term + exploration_term - - torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores) + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), expected_scores + ) def test_forward_zero_visits(self, default_puct_scorer): - win_count = torch.tensor([0.0, 0.0]) - visits = torch.tensor([10.0, 0.0]) + win_count = torch.tensor([0.0, 0.0]) + visits = torch.tensor([10.0, 0.0]) prior_prob = torch.tensor([0.6, 0.4]) total_visits_parent = torch.tensor(10.0) node = create_puct_node( - win_count=win_count, - visits=visits, + win_count=win_count, + visits=visits, total_visits=total_visits_parent, - prior_prob=prior_prob + prior_prob=prior_prob, ) default_puct_scorer.forward(node) c = default_puct_scorer.c scores = node.get(default_puct_scorer.score_key) - expected_score_0 = (win_count[0] / visits[0]) + \ - c * prior_prob[0] * total_visits_parent.sqrt() / (1 + visits[0]) + expected_score_0 = (win_count[0] / visits[0]) + c * prior_prob[ + 0 + ] * total_visits_parent.sqrt() / (1 + visits[0]) torch.testing.assert_close(scores[0], expected_score_0) - assert torch.isnan(scores[1]), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." - + assert torch.isnan( + scores[1] + ), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." @pytest.mark.parametrize("batch_s", [2, 3]) def test_forward_batch(self, default_puct_scorer, batch_s): @@ -605,20 +760,27 @@ def test_forward_batch(self, default_puct_scorer, batch_s): total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) node = create_puct_node( - win_count=win_count, - visits=visits, - total_visits=total_visits_parent, + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, prior_prob=prior_prob, - batch_size=batch_s + batch_size=batch_s, ) default_puct_scorer.forward(node) c = default_puct_scorer.c exploitation_term = win_count / visits - exploration_term = c * prior_prob * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + exploration_term = ( + c * prior_prob * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + ) expected_scores = exploitation_term + exploration_term - - torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores, atol=1e-6, rtol=1e-6) + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), + expected_scores, + atol=1e-6, + rtol=1e-6, + ) def test_forward_exploration_term(self, default_puct_scorer): num_actions = 3 @@ -628,18 +790,20 @@ def test_forward_exploration_term(self, default_puct_scorer): total_visits_parent = torch.tensor(100.0) node = create_puct_node( - win_count=win_count, - visits=visits, + win_count=win_count, + visits=visits, total_visits=total_visits_parent, - prior_prob=prior_prob + prior_prob=prior_prob, ) default_puct_scorer.forward(node) c = default_puct_scorer.c # exploitation_term is effectively 0 expected_scores = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) - - torch.testing.assert_close(node.get(default_puct_scorer.score_key), expected_scores) + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), expected_scores + ) def test_custom_keys(self, puct_custom_key_names): c_val = 2.5 @@ -651,31 +815,33 @@ def test_custom_keys(self, puct_custom_key_names): prior_prob_key=puct_custom_key_names["prior_prob_key"], score_key=puct_custom_key_names["score_key"], ) - + win_count = torch.tensor([1.0, 2.0]) visits = torch.tensor([3.0, 4.0]) prior_prob = torch.tensor([0.5, 0.5]) total_visits_parent = torch.tensor(10.0) node = create_puct_node( - win_count=win_count, - visits=visits, + win_count=win_count, + visits=visits, total_visits=total_visits_parent, prior_prob=prior_prob, - custom_keys=puct_custom_key_names + custom_keys=puct_custom_key_names, ) scorer.forward(node) exploitation = win_count / visits exploration = c_val * prior_prob * total_visits_parent.sqrt() / (1 + visits) expected_scores = exploitation + exploration - + assert puct_custom_key_names["score_key"] in node.keys() - torch.testing.assert_close(node.get(puct_custom_key_names["score_key"]), expected_scores) - + torch.testing.assert_close( + node.get(puct_custom_key_names["score_key"]), expected_scores + ) + # Check that default keys are not present assert "score" not in node.keys() assert "win_count" not in node.keys() assert "visits" not in node.keys() assert "total_visits" not in node.keys() - assert "prior_prob" not in node.keys() \ No newline at end of file + assert "prior_prob" not in node.keys() diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index f9046f9df2c..c6f62377d23 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -21,6 +21,7 @@ class MCTSScore(TensorDictModuleBase): def forward(self, node): pass + class PUCTScore(MCTSScore): c: float @@ -60,6 +61,7 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: ) return node + class UCBScore(MCTSScore): c: float @@ -199,14 +201,11 @@ def update_weights( else: weights[..., action_idx] = new_weight node.set(self.weights_key, weights) -<<<<<<< Updated upstream -======= ->>>>>>> Stashed changes class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 UCB1_TUNED = "UCB1-Tuned" - EXP3 = functool.partial(EXP3Score, gamma=0.1) + EXP3 = functools.partial(EXP3Score, gamma=0.1) PUCT_VARIANT = "PUCT-Variant" From 578ba4b4707677cca17e033ffac5c83d9a96269e Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Mon, 23 Jun 2025 23:36:33 +0530 Subject: [PATCH 13/14] Added UCB-1 Tuned --- torchrl/modules/mcts/scores.py | 81 +++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index c6f62377d23..5e2ad515bd1 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -203,9 +203,88 @@ def update_weights( node.set(self.weights_key, weights) +class UCB1TunedScore(MCTSScore): + def __init__( + self, + *, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + sum_squared_rewards_key: NestedKey = "sum_squared_rewards", + score_key: NestedKey = "score", + exploration_constant: float = 2.0, + ): + super().__init__() + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.sum_squared_rewards_key = sum_squared_rewards_key + self.score_key = score_key + self.exploration_constant = exploration_constant + + self.in_keys = [ + self.win_count_key, + self.visits_key, + self.total_visits_key, + self.sum_squared_rewards_key, + ] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + q_sum_i = node.get(self.win_count_key) + n_i = node.get(self.visits_key) + n_parent = node.get(self.total_visits_key) + sum_sq_rewards_i = node.get(self.sum_squared_rewards_key) + + if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim: + n_parent_expanded = n_parent.unsqueeze(-1) + else: + n_parent_expanded = n_parent + + safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0) + log_n_parent = torch.log(safe_n_parent_for_log) + + scores = torch.zeros_like(q_sum_i, device=q_sum_i.device) + + visited_mask = n_i > 0 + + if torch.any(visited_mask): + q_sum_i_v = q_sum_i[visited_mask] + n_i_v = n_i[visited_mask] + sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask] + + log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask] + + avg_reward_i_v = q_sum_i_v / n_i_v + + empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2) + bias_correction_v = ( + self.exploration_constant * log_n_parent_v / n_i_v + ).sqrt() + + v_i_v = empirical_variance_v + bias_correction_v + v_i_v.clamp(min=0) + + min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v) + exploration_component_v = ( + log_n_parent_v / n_i_v * min_variance_term_v + ).sqrt() + + scores[visited_mask] = avg_reward_i_v + exploration_component_v + + unvisited_mask = ~visited_mask + if torch.any(unvisited_mask): + scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0 + + node.set(self.score_key, scores) + return node + + class MCTSScores(Enum): PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 - UCB1_TUNED = "UCB1-Tuned" + UCB1_TUNED = functools.partial( + UCB1TunedScore, exploration_constant=2.0 + ) # Auer et al. (2002) C=2 for rewards in [0,1] EXP3 = functools.partial(EXP3Score, gamma=0.1) PUCT_VARIANT = "PUCT-Variant" From e29f6d8f13978f4933433916b97e4f5823fbdd2b Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Mon, 23 Jun 2025 23:47:14 +0530 Subject: [PATCH 14/14] Added Docstrings --- torchrl/modules/mcts/scores.py | 223 +++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py index 5e2ad515bd1..2b9de222f62 100644 --- a/torchrl/modules/mcts/scores.py +++ b/torchrl/modules/mcts/scores.py @@ -23,6 +23,55 @@ def forward(self, node): class PUCTScore(MCTSScore): + """Computes the PUCT (Polynomial Upper Confidence Trees) score for MCTS. + + PUCT is a widely used score in MCTS algorithms, notably in AlphaGo and AlphaZero, + to balance exploration and exploitation. It incorporates prior probabilities from a + policy network, encouraging exploration of actions deemed promising by the policy, + while also considering visit counts and accumulated rewards. + + The formula used is: + `score = (win_count / visits) + c * prior_prob * sqrt(total_visits) / (1 + visits)` + + Where: + - `win_count`: Sum of rewards (or win counts) for the action. + - `visits`: Visit count for the action. + - `total_visits`: Visit count of the parent node (N). + - `prior_prob`: Prior probability of selecting the action (e.g., from a policy network). + - `c`: The exploration constant, controlling the trade-off between exploitation + (first term) and exploration (second term). + + Args: + c (float): The exploration constant. + win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase` + containing the sum of rewards (or win counts) for each action. + Defaults to "win_count". + visits_key (NestedKey, optional): Key for the tensor containing the visit + count for each action. Defaults to "visits". + total_visits_key (NestedKey, optional): Key for the tensor (or scalar) + representing the visit count of the parent node (N). Defaults to "total_visits". + prior_prob_key (NestedKey, optional): Key for the tensor containing the + prior probabilities for each action. Defaults to "prior_prob". + score_key (NestedKey, optional): Key where the calculated PUCT scores + will be stored in the output `TensorDictBase`. Defaults to "score". + + Input Keys: + - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions) + or matching `visits_key`. + - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action + has zero visits, its exploitation term (win_count / visits) will result in NaN + if win_count is also zero, or +/-inf if win_count is non-zero. The exploration + term will still be valid due to `(1 + visits)`. + - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs, + representing the parent node's visit count. + - `prior_prob_key` (torch.Tensor): Tensor of shape (..., num_actions) containing + prior probabilities. + + Output Keys: + - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing + the calculated PUCT scores. + """ + c: float def __init__( @@ -63,6 +112,53 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: class UCBScore(MCTSScore): + """Computes the UCB (Upper Confidence Bound) score, specifically UCB1, for MCTS. + + UCB1 is a classic algorithm for the multi-armed bandit problem that balances + exploration and exploitation. In MCTS, it's used to select which action to + explore from a given node. The score encourages trying actions with high + empirical rewards and actions that have been visited less frequently. + + The formula used is: + `score = (win_count / visits) + c * sqrt(log(total_visits) / visits)` + However, the implementation here uses `1 + visits` in the denominator of the + exploration term to handle cases where `visits` might be zero for an action, + preventing division by zero and ensuring unvisited actions get a high exploration score. + The formula implemented is: + `score = (win_count / visits) + c * sqrt(total_visits) / (1 + visits)` + Note: The standard UCB1 formula's exploration term is `c * sqrt(log(N) / N_i)`, + where N is parent visits and N_i is action visits. This implementation uses `sqrt(N)` + instead of `sqrt(log N)`. For the canonical UCB1 `sqrt(log N / N_i)` term, + total_visits would need to be `log(parent_visits)` and then use `c * sqrt(total_visits / visits_i)`. + The current form is simpler and common in some MCTS variants. + + Args: + c (float): The exploration constant. A common value is `sqrt(2)`. + win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase` + containing the sum of rewards (or win counts) for each action. + Defaults to "win_count". + visits_key (NestedKey, optional): Key for the tensor containing the visit + count for each action. Defaults to "visits". + total_visits_key (NestedKey, optional): Key for the tensor (or scalar) + representing the visit count of the parent node (N). This is used in the + exploration term. Defaults to "total_visits". + score_key (NestedKey, optional): Key where the calculated UCB scores + will be stored in the output `TensorDictBase`. Defaults to "score". + + Input Keys: + - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions). + - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action + has zero visits, its exploitation term (win_count / visits) will result in NaN + if win_count is also zero, or +/-inf if win_count is non-zero. The exploration + term remains well-defined due to `(1 + visits)`. + - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs, + representing the parent node's visit count (N). + + Output Keys: + - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing + the calculated UCB scores. + """ + c: float def __init__( @@ -95,6 +191,68 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: class EXP3Score(MCTSScore): + """Computes action selection probabilities for the EXP3 algorithm in MCTS. + + EXP3 (Exponential-weight algorithm for Exploration and Exploitation) is a bandit + algorithm that performs well in adversarial or non-stationary environments. + It maintains weights for each action and adjusts them based on received rewards. + Actions are chosen probabilistically based on these weights, with a mechanism + to ensure a minimum level of exploration. + + The `forward` method calculates the probability distribution over actions: + `p_i(t) = (1 - gamma) * (w_i(t) / sum_weights) + (gamma / K)` + where `w_i(t)` are the current weights, `sum_weights` is the sum of all weights, + `gamma` is the exploration factor, and `K` is the number of actions. + These probabilities are typically stored in the `score_key` and used for action selection. + + The `update_weights` method updates the weights after an action is chosen and a + reward is observed. This method is typically called after a simulation/rollout + and backpropagation phase in MCTS. The update rule is: + `w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t)))` + where `reward` is the reward for the chosen action (typically normalized to [0,1]) + and `p_i(t)` is the probability with which the action was chosen. + + Reference: "Bandit based Monte-Carlo Planning" (Kocsis & Szepesvari, 2006), though + the specific EXP3 formulation can vary (e.g., "Regret Analysis of Stochastic and + Nonstochastic Multi-armed Bandit Problems", Bubeck & Cesa-Bianchi, 2012 for EXP3 details). + + Args: + gamma (float, optional): Exploration factor, balancing uniform exploration + and exploitation of current weights. Must be in [0, 1]. Defaults to 0.1. + weights_key (NestedKey, optional): Key in the input `TensorDictBase` for + the tensor containing current action weights. If not found during the first + `forward` call, weights are initialized to ones. Defaults to "weights". + action_prob_key (NestedKey, optional): Key to store the calculated action + probabilities `p_i(t)`. If different from `score_key`, it allows storing + these probabilities separately, which might be useful if `score_key` is + used for a different purpose by the selection strategy. Defaults to "action_prob". + The `update_weights` method will look for `p_i(t)` in `score_key`. + score_key (NestedKey, optional): Key where the calculated action probabilities + (scores for MCTS selection) will be stored. Defaults to "score". + num_actions_key (NestedKey, optional): Key for the number of available + actions (K). Used for weight initialization and in formulas. Defaults to "num_actions". + + Input Keys for `forward`: + - `weights_key` (torch.Tensor): Tensor of shape (..., num_actions) containing + current weights. Initialized to ones if not present on first call. + - `num_actions_key` (int or torch.Tensor): Scalar representing K, the number of actions. + + Output Keys for `forward`: + - `score_key` (torch.Tensor): Tensor of shape (..., num_actions) containing + the calculated action probabilities `p_i(t)`. + - `action_prob_key` (torch.Tensor, optional): Same as `score_key` if this key + is set and different from `score_key`. + + `update_weights` Method: + This method is designed to be called externally after an action has been + selected (using probabilities from `forward`) and a reward obtained. + Args for `update_weights(node: TensorDictBase, action_idx: int, reward: float)`: + - `node`: The `TensorDictBase` for the current MCTS node, containing + at least `weights_key` and `score_key` (with `p_i(t)` values). + - `action_idx`: The index of the action that was chosen. + - `reward`: The reward received for the chosen action (assumed to be in [0,1]). + """ + def __init__( self, *, @@ -165,6 +323,11 @@ def forward(self, node: TensorDictBase) -> TensorDictBase: def update_weights( self, node: TensorDictBase, action_idx: int, reward: float ) -> None: + """Updates the weight of the chosen action based on the reward. + + w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t))) + Assumes reward is in [0, 1]. + """ if not (0 <= reward <= 1): ValueError( f"Reward {reward} is outside the expected [0, 1] range for EXP3." @@ -204,6 +367,66 @@ def update_weights( class UCB1TunedScore(MCTSScore): + """Computes the UCB1-Tuned score for MCTS, using variance estimation. + + UCB1-Tuned is an enhancement of the UCB1 algorithm that incorporates an estimate + of the variance of rewards for each action. This allows for a more refined + balance between exploration and exploitation, potentially leading to better + performance, especially when reward variances differ significantly across actions. + + The score for an action `i` is calculated as: + `score_i = avg_reward_i + sqrt(log(N) / N_i * min(0.25, V_i))` + + The variance estimate `V_i` for action `i` is calculated as: + `V_i = (sum_squared_rewards_i / N_i) - avg_reward_i^2 + sqrt(exploration_constant * log(N) / N_i)` + + Where: + - `avg_reward_i`: Average reward obtained from action `i`. + - `N_i`: Number of times action `i` has been visited. + - `N`: Total number of times the parent node has been visited. + - `sum_squared_rewards_i`: Sum of the squares of rewards received from action `i`. + - `exploration_constant`: A constant used in the bias correction term of `V_i`. + Auer et al. (2002) suggest a value of 2.0 for rewards in the range [0,1]. + - The term `min(0.25, V_i)` implies that rewards are scaled to `[0,1]`, as 0.25 is + the maximum variance for a distribution in this range (e.g., Bernoulli(0.5)). + + Reference: "Finite-time Analysis of the Multiarmed Bandit Problem" + (Auer, Cesa-Bianchi, Fischer, 2002). + + Args: + exploration_constant (float, optional): The constant `C` used in the bias + correction term for the variance estimate `V_i`. Defaults to `2.0`, + as suggested for rewards in `[0,1]`. + win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase` + containing the sum of rewards for each action (Q_i * N_i). Defaults to "win_count". + visits_key (NestedKey, optional): Key for the tensor containing the visit + count for each action (N_i). Defaults to "visits". + total_visits_key (NestedKey, optional): Key for the tensor (or scalar) + representing the visit count of the parent node (N). Defaults to "total_visits". + sum_squared_rewards_key (NestedKey, optional): Key for the tensor containing + the sum of squared rewards received for each action. This is crucial for + calculating the empirical variance. Defaults to "sum_squared_rewards". + score_key (NestedKey, optional): Key where the calculated UCB1-Tuned scores + will be stored in the output `TensorDictBase`. Defaults to "score". + + Input Keys: + - `win_count_key` (torch.Tensor): Sum of rewards for each action. + - `visits_key` (torch.Tensor): Visit counts for each action (N_i). + - `total_visits_key` (torch.Tensor): Parent node's visit count (N). + - `sum_squared_rewards_key` (torch.Tensor): Sum of squared rewards for each action. + + Output Keys: + - `score_key` (torch.Tensor): Calculated UCB1-Tuned scores for each action. + + Important Notes: + - **Unvisited Nodes**: Actions with zero visits (`visits_key` is 0) are assigned a + very large positive score to ensure they are selected for exploration. + - **Reward Range**: The `min(0.25, V_i)` term is theoretically most sound when + rewards are normalized to the range `[0, 1]`. + - **Logarithm of N**: `log(N)` (log of parent visits) is calculated using `torch.log(torch.clamp(N, min=1.0))` + to prevent issues with `N=0` or `N` between 0 and 1. + """ + def __init__( self, *,