From a285470e2f93f3562cf73dcd921cf9a23efa762f Mon Sep 17 00:00:00 2001 From: Eric Hasegawa Date: Tue, 4 Mar 2025 17:31:38 -0800 Subject: [PATCH 1/4] chore: format code with ruff --- .github/workflows/ci.yml | 34 + TEST_get_emission.py | 10 +- docs/stream_tutorial/client.py | 12 +- docs/stream_tutorial/config.py | 4 +- docs/stream_tutorial/miner.py | 20 +- docs/stream_tutorial/protocol.py | 4 +- neurons/miner.py | 105 +-- neurons/test_miner.py | 18 +- neurons/validator.py | 687 +++++++++++------- omega/audio_scoring.py | 116 +-- omega/augment.py | 34 +- omega/base/miner.py | 1 - omega/base/neuron.py | 27 +- omega/base/validator.py | 118 +-- omega/constants.py | 16 +- omega/diarization_metric.py | 29 +- omega/diarization_pipeline.py | 124 ++-- omega/imagebind_wrapper.py | 111 +-- omega/miner_utils.py | 126 ++-- omega/mock.py | 13 +- omega/protocol.py | 41 +- omega/test_audio.py | 7 +- omega/text_similarity.py | 13 +- omega/unstuff.py | 13 +- omega/utils/config.py | 10 +- omega/utils/logging.py | 2 +- omega/utils/uids.py | 11 +- omega/video_utils.py | 107 ++- purchase_focus_video.py | 348 +++++---- requirements.txt | 3 +- test_audio_dataset.py | 63 +- validator-api/_generate_api_key.py | 4 +- validator-api/app.py | 386 ++++++---- validator-api/check_vali_api.py | 19 +- validator-api/validator_api/check_blocking.py | 4 +- .../validator_api/communex/_common.py | 8 +- .../validator_api/communex/client.py | 74 +- .../validator_api/communex/errors.py | 2 +- validator-api/validator_api/communex/key.py | 9 +- validator-api/validator_api/communex/types.py | 6 +- validator-api/validator_api/config.py | 50 +- .../validator_api/cron/confirm_purchase.py | 132 +++- .../validator_api/database/__init__.py | 10 +- .../validator_api/database/crud/focusvideo.py | 206 ++++-- .../validator_api/database/encrypted_json.py | 34 +- .../database/models/boosted_task.py | 3 +- .../database/models/focus_video_record.py | 33 +- .../database/models/miner_bans.py | 41 +- .../validator_api/database/models/scoring.py | 58 +- .../validator_api/database/models/task.py | 3 +- .../validator_api/database/models/user.py | 2 +- .../validator_api/database/schemas.py | 55 +- validator-api/validator_api/dataset_upload.py | 171 +++-- .../validator_api/imagebind_loader.py | 5 +- validator-api/validator_api/limiter.py | 28 +- validator-api/validator_api/score.py | 104 ++- .../validator_api/scoring/deepseek_chat.py | 21 +- .../scoring/legitimacy_checks.py | 65 +- .../validator_api/scoring/query_llm.py | 79 +- .../validator_api/scoring/scoring_service.py | 310 +++++--- .../scoring/video_description.py | 47 +- validator-api/validator_api/utils/__init__.py | 8 +- .../validator_api/utils/marketplace.py | 37 +- validator-api/validator_api/utils/wallet.py | 59 +- 64 files changed, 2695 insertions(+), 1605 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..7e4def3d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: ci + +on: + workflow_dispatch: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + ci: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11.4' + + - name: Set up UV + uses: astral-sh/setup-uv@v5 + + - name: Install dependencies + run: | + python -m venv env + source env/bin/activate + uv pip install --pre -r requirements.txt + uv pip install --pre -r 
requirements_api.txt
+      - name: Run Ruff
+        run: |
+          source env/bin/activate
+          ruff format --check
\ No newline at end of file
diff --git a/TEST_get_emission.py b/TEST_get_emission.py
index 21f71ac1..3c969933 100644
--- a/TEST_get_emission.py
+++ b/TEST_get_emission.py
@@ -4,15 +4,19 @@
 import time
 import random
 
+
 def main():
     subtensor = bt.subtensor(network="test")
     while True:
         subnet = subtensor.subnet(netuid=96)
-        print(f"Tempo: {subnet.tempo} Block: {subtensor.block} alpha_out_emission: {subnet.alpha_out_emission.tao} alpha_out: {subnet.alpha_out.tao} ")
-
+        print(
+            f"Tempo: {subnet.tempo} Block: {subtensor.block} alpha_out_emission: {subnet.alpha_out_emission.tao} alpha_out: {subnet.alpha_out.tao} "
+        )
+
         sleep_time = 60 + random.uniform(-30, 30)
         time.sleep(sleep_time)
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/docs/stream_tutorial/client.py b/docs/stream_tutorial/client.py
index 67e6f05c..d6cfa2d6 100644
--- a/docs/stream_tutorial/client.py
+++ b/docs/stream_tutorial/client.py
@@ -29,9 +29,7 @@ async def query_synapse(my_uid, wallet_name, hotkey, network, netuid):
     wallet = bt.wallet(name=wallet_name, hotkey=hotkey)
 
     # instantiate the metagraph with provided network and netuid
-    metagraph = bt.metagraph(
-        netuid=netuid, network=network, sync=True, lite=False
-    )
+    metagraph = bt.metagraph(netuid=netuid, network=network, sync=True, lite=False)
 
     # Grab the axon you're serving
     axon = metagraph.axons[my_uid]
@@ -40,9 +38,7 @@
     dendrite = bt.dendrite(wallet=wallet)
 
     async def main():
-        responses = await dendrite(
-            [axon], syn, deserialize=False, streaming=True
-        )
+        responses = await dendrite([axon], syn, deserialize=False, streaming=True)
 
         for resp in responses:
             i = 0
@@ -73,9 +69,7 @@ async def main():
         required=True,
         help="Your unique miner ID on the chain",
     )
-    parser.add_argument(
-        "--netuid", type=int, required=True, help="Network Unique ID"
-    )
+    parser.add_argument("--netuid", type=int, required=True, help="Network Unique ID")
     parser.add_argument(
         "--wallet_name", type=str, default="default", help="Name of the wallet"
     )
diff --git a/docs/stream_tutorial/config.py b/docs/stream_tutorial/config.py
index 7cbe82ca..7507076a 100644
--- a/docs/stream_tutorial/config.py
+++ b/docs/stream_tutorial/config.py
@@ -37,9 +37,7 @@ def get_config() -> "bt.Config":
         help="Chain endpoint to connect to.",
     )
     # Adds override arguments for network and netuid.
-    parser.add_argument(
-        "--netuid", type=int, default=1, help="The chain subnet uid."
-    )
+    parser.add_argument("--netuid", type=int, default=1, help="The chain subnet uid.")
 
     parser.add_argument(
         "--miner.root",
diff --git a/docs/stream_tutorial/miner.py b/docs/stream_tutorial/miner.py
index 5625cf50..36d6297f 100644
--- a/docs/stream_tutorial/miner.py
+++ b/docs/stream_tutorial/miner.py
@@ -60,9 +60,7 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None):
         bt.logging.info(f"Running miner on uid: {self.my_subnet_uid}")
 
         # The axon handles request processing, allowing validators to send this process requests.
-        self.axon = axon or bt.axon(
-            wallet=self.wallet, port=self.config.axon.port
-        )
+        self.axon = axon or bt.axon(wallet=self.wallet, port=self.config.axon.port)
         # Attach determiners which functions are called when servicing a request.
         bt.logging.info(f"Attaching forward function to axon.")
         print(f"Attaching forward function to axon. 
{self._prompt}") @@ -79,13 +77,11 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None): self.request_timestamps: Dict = {} @abstractmethod - def config(self) -> "bt.Config": - ... + def config(self) -> "bt.Config": ... @classmethod @abstractmethod - def add_args(cls, parser: argparse.ArgumentParser): - ... + def add_args(cls, parser: argparse.ArgumentParser): ... def _prompt(self, synapse: StreamPrompting) -> StreamPrompting: """ @@ -162,9 +158,7 @@ def run(self): self.axon.serve(netuid=self.config.netuid, subtensor=self.subtensor) # Start starts the miner's axon, making it active on the network. - bt.logging.info( - f"Starting axon server on port: {self.config.axon.port}" - ) + bt.logging.info(f"Starting axon server on port: {self.config.axon.port}") self.axon.start() # --- Run until should_exit = True. @@ -206,7 +200,7 @@ def run(self): f"Stake:{metagraph.S[self.my_subnet_uid]} | " f"Rank:{metagraph.R[self.my_subnet_uid]} | " f"Trust:{metagraph.T[self.my_subnet_uid]} | " - f"Consensus:{metagraph.C[self.my_subnet_uid] } | " + f"Consensus:{metagraph.C[self.my_subnet_uid]} | " f"Incentive:{metagraph.I[self.my_subnet_uid]} | " f"Emission:{metagraph.E[self.my_subnet_uid]}" ) @@ -347,9 +341,7 @@ async def _prompt(text: str, send: Send): processing steps or modify how tokens are sent back to the client. """ bt.logging.trace("HI. _PROMPT()") - input_ids = tokenizer( - text, return_tensors="pt" - ).input_ids.squeeze() + input_ids = tokenizer(text, return_tensors="pt").input_ids.squeeze() buffer = [] bt.logging.debug(f"Input text: {text}") bt.logging.debug(f"Input ids: {input_ids}") diff --git a/docs/stream_tutorial/protocol.py b/docs/stream_tutorial/protocol.py index 26e91fdc..25c4e92b 100644 --- a/docs/stream_tutorial/protocol.py +++ b/docs/stream_tutorial/protocol.py @@ -85,9 +85,7 @@ async def process_streaming_response(self, response: StreamingResponse): """ if self.completion is None: self.completion = "" - bt.logging.debug( - "Processing streaming response (StreamingSynapse base class)." - ) + bt.logging.debug("Processing streaming response (StreamingSynapse base class).") async for chunk in response.content.iter_any(): bt.logging.debug(f"Processing chunk: {chunk}") tokens = chunk.decode("utf-8").split("\n") diff --git a/neurons/miner.py b/neurons/miner.py index 2b3e4edd..47c41843 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -17,6 +17,7 @@ # DEALINGS IN THE SOFTWARE. import os + # Set USE_TORCH=1 environment variable to use torch instead of numpy os.environ["USE_TORCH"] = "1" @@ -32,16 +33,21 @@ from omega.base.miner import BaseMinerNeuron from omega.imagebind_wrapper import ImageBind -from omega.miner_utils import search_and_diarize_youtube_videos, search_and_embed_youtube_videos +from omega.miner_utils import ( + search_and_diarize_youtube_videos, + search_and_embed_youtube_videos, +) from omega.augment import LocalLLMAugment, OpenAIAugment, NoAugment from omega.utils.config import QueryAugment from omega.constants import VALIDATOR_TIMEOUT, VALIDATOR_TIMEOUT_AUDIO from omega.diarization_pipeline import CustomDiarizationPipeline + class Miner(BaseMinerNeuron): """ Your miner neuron class. You should use this class to define your miner's behavior. In particular, you should replace the forward function with your own logic. You may also want to override the blacklist and priority functions according to your needs. 
""" + def __init__(self, config=None): super(Miner, self).__init__(config=config) query_augment_type = QueryAugment(self.config.neuron.query_augment) @@ -53,59 +59,74 @@ def __init__(self, config=None): self.augment = OpenAIAugment(device=self.config.neuron.device) else: raise ValueError("Invalid query augment") - - + self.diarization_pipeline = CustomDiarizationPipeline( - overlap_detection_model_id = "tezuesh/overlapped-speech-detection", + overlap_detection_model_id="tezuesh/overlapped-speech-detection", diarization_model_id="tezuesh/diarization", # device="cuda" ) self.imagebind = ImageBind(v2=True) - async def forward_videos( - self, synapse: omega.protocol.Videos - ) : + async def forward_videos(self, synapse: omega.protocol.Videos): # Scrape Youtube videos - bt.logging.info(f"Received scraping request: {synapse.num_videos} videos for query '{synapse.query}'") - + bt.logging.info( + f"Received scraping request: {synapse.num_videos} videos for query '{synapse.query}'" + ) + start = time.time() - + synapse.video_metadata = search_and_embed_youtube_videos( self.augment(synapse.query), synapse.num_videos, self.imagebind ) - + time_elapsed = time.time() - start - - if len(synapse.video_metadata) == synapse.num_videos and time_elapsed < VALIDATOR_TIMEOUT: - bt.logging.info(f"–––––– SCRAPING SUCCEEDED: Scraped {len(synapse.video_metadata)}/{synapse.num_videos} videos in {time_elapsed} seconds.") - else: - bt.logging.error(f"–––––– SCRAPING FAILED: Scraped {len(synapse.video_metadata)}/{synapse.num_videos} videos in {time_elapsed} seconds.") + if ( + len(synapse.video_metadata) == synapse.num_videos + and time_elapsed < VALIDATOR_TIMEOUT + ): + bt.logging.info( + f"–––––– SCRAPING SUCCEEDED: Scraped {len(synapse.video_metadata)}/{synapse.num_videos} videos in {time_elapsed} seconds." + ) + else: + bt.logging.error( + f"–––––– SCRAPING FAILED: Scraped {len(synapse.video_metadata)}/{synapse.num_videos} videos in {time_elapsed} seconds." + ) return synapse - + async def forward_audios( self, synapse: omega.protocol.Audios ) -> omega.protocol.Audios: - bt.logging.info(f"Received youtube audio scraping and diarization request: {synapse.num_audios} audios for query '{synapse.query}'") - + bt.logging.info( + f"Received youtube audio scraping and diarization request: {synapse.num_audios} audios for query '{synapse.query}'" + ) + start = time.time() - + synapse.audio_metadata = search_and_diarize_youtube_videos( - self.augment(synapse.query), synapse.num_audios, self.diarization_pipeline, self.imagebind + self.augment(synapse.query), + synapse.num_audios, + self.diarization_pipeline, + self.imagebind, ) - + time_elapsed = time.time() - start - - if len(synapse.audio_metadata) == synapse.num_audios and time_elapsed < VALIDATOR_TIMEOUT_AUDIO: - bt.logging.info(f"–––––– SCRAPING SUCCEEDED: Scraped {len(synapse.audio_metadata)}/{synapse.num_audios} audios in {time_elapsed} seconds.") + + if ( + len(synapse.audio_metadata) == synapse.num_audios + and time_elapsed < VALIDATOR_TIMEOUT_AUDIO + ): + bt.logging.info( + f"–––––– SCRAPING SUCCEEDED: Scraped {len(synapse.audio_metadata)}/{synapse.num_audios} audios in {time_elapsed} seconds." + ) else: - bt.logging.error(f"–––––– SCRAPING FAILED: Scraped {len(synapse.audio_metadata)}/{synapse.num_audios} audios in {time_elapsed} seconds.") + bt.logging.error( + f"–––––– SCRAPING FAILED: Scraped {len(synapse.audio_metadata)}/{synapse.num_audios} audios in {time_elapsed} seconds." 
+ ) return synapse - async def blacklist( - self, synapse: bt.Synapse - ) -> typing.Tuple[bool, str]: + async def blacklist(self, synapse: bt.Synapse) -> typing.Tuple[bool, str]: """ Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should define the logic for blacklisting requests based on your needs and desired security parameters. @@ -156,20 +177,25 @@ async def blacklist( return True, "Non-validator hotkey" stake = self.metagraph.S[uid].item() - if self.config.blacklist.validator_min_stake and stake < self.config.blacklist.validator_min_stake: - bt.logging.warning(f"Blacklisting request from {synapse.dendrite.hotkey} [uid={uid}], not enough stake -- {stake}") + if ( + self.config.blacklist.validator_min_stake + and stake < self.config.blacklist.validator_min_stake + ): + bt.logging.warning( + f"Blacklisting request from {synapse.dendrite.hotkey} [uid={uid}], not enough stake -- {stake}" + ) return True, "Stake below minimum" bt.logging.trace( f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}" ) return False, "Hotkey recognized!" - + async def blacklist_videos( self, synapse: omega.protocol.Videos ) -> typing.Tuple[bool, str]: return await self.blacklist(synapse) - + async def blacklist_audios( self, synapse: omega.protocol.Audios ) -> typing.Tuple[bool, str]: @@ -206,16 +232,12 @@ async def priority(self, synapse: bt) -> float: ) return prirority - async def priority_videos( - self, synapse: omega.protocol.Videos - ) -> float: + async def priority_videos(self, synapse: omega.protocol.Videos) -> float: return await self.priority(synapse) - - async def priority_audios( - self, synapse: omega.protocol.Audios - ) -> float: + + async def priority_audios(self, synapse: omega.protocol.Audios) -> float: return await self.priority(synapse) - + def save_state(self): """ We define this function to avoid printing out the log message in the BaseNeuron class @@ -223,6 +245,7 @@ def save_state(self): """ pass + # This is the main function, which runs the miner. if __name__ == "__main__": with Miner() as miner: diff --git a/neurons/test_miner.py b/neurons/test_miner.py index a53f73f2..244f5ed3 100644 --- a/neurons/test_miner.py +++ b/neurons/test_miner.py @@ -13,20 +13,28 @@ if time_elapsed > VALIDATOR_TIMEOUT or len(video_metadata_list) < num_videos: if time_elapsed > VALIDATOR_TIMEOUT: - print(f"Searching took {time_elapsed} seconds, which is longer than the validator timeout of {VALIDATOR_TIMEOUT} seconds") + print( + f"Searching took {time_elapsed} seconds, which is longer than the validator timeout of {VALIDATOR_TIMEOUT} seconds" + ) if len(video_metadata_list) < num_videos: - print(f"Only got {len(video_metadata_list)} videos, which is less than the requested {num_videos} videos") + print( + f"Only got {len(video_metadata_list)} videos, which is less than the requested {num_videos} videos" + ) else: - print(f"SUCCESS! Search and embed took {time_elapsed} seconds and got {len(video_metadata_list)} videos") + print( + f"SUCCESS! 
Search and embed took {time_elapsed} seconds and got {len(video_metadata_list)} videos" + ) if len(video_metadata_list) == 0: print("No videos found") else: - videos = Videos(query=query, num_videos=num_videos, video_metadata=video_metadata_list) + videos = Videos( + query=query, num_videos=num_videos, video_metadata=video_metadata_list + ) response = requests.get( "https://dev-sn24-api.omegatron.ai/api/count_unique", - json=videos.to_serializable_dict(videos) + json=videos.to_serializable_dict(videos), ) print(response.json()) diff --git a/neurons/validator.py b/neurons/validator.py index 38d19407..1e232f86 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -17,6 +17,7 @@ # DEALINGS IN THE SOFTWARE. import os + os.environ["USE_TORCH"] = "1" @@ -81,6 +82,7 @@ import asyncio from aiohttp import ClientSession, BasicAuth import os + # Set USE_TORCH=1 environment variable to use torch instead of numpy os.environ["USE_TORCH"] = "1" @@ -97,6 +99,7 @@ CLIENT_TIMEOUT_SECONDS = VALIDATOR_TIMEOUT + VALIDATOR_TIMEOUT_MARGIN CLIENT_TIMEOUT_SECONDS_AUDIO = VALIDATOR_TIMEOUT_AUDIO + VALIDATOR_TIMEOUT_MARGIN + class Validator(BaseValidatorNeuron): """ Your validator neuron class. You should use this class to define your validator's behavior. In particular, you should replace the forward function with your own logic. @@ -119,17 +122,19 @@ def __init__(self, config=None): self.successfully_started_wandb = True else: bt.logging.exception( - "WANDB_API_KEY not found. Set it with `export WANDB_API_KEY=`. Alternatively, you can disable W&B with --wandb.off, but it is strongly recommended to run with W&B enabled.") + "WANDB_API_KEY not found. Set it with `export WANDB_API_KEY=`. Alternatively, you can disable W&B with --wandb.off, but it is strongly recommended to run with W&B enabled." + ) self.successfully_started_wandb = False else: bt.logging.warning( - "Running with --wandb.off. It is strongly recommended to run with W&B enabled.") + "Running with --wandb.off. It is strongly recommended to run with W&B enabled." + ) self.successfully_started_wandb = False self.api_root = ( "https://dev-sn24-api.omegatron.ai" - if self.config.subtensor.network == "test" else - "https://sn24-api.omegatron.ai" + if self.config.subtensor.network == "test" + else "https://sn24-api.omegatron.ai" ) # load topics from topics URL (CSV) or fallback to local topics file self.load_topics_start = dt.datetime.now() @@ -139,7 +144,9 @@ def __init__(self, config=None): self.load_focus_rewards_start = dt.datetime.now() self.FOCUS_REWARDS_PERCENT = self.load_focus_rewards_percent() self.AUDIO_REWARDS_PERCENT = AUDIO_REWARDS_PERCENT - self.YOUTUBE_REWARDS_PERCENT = 1.0 - self.FOCUS_REWARDS_PERCENT - self.AUDIO_REWARDS_PERCENT + self.YOUTUBE_REWARDS_PERCENT = ( + 1.0 - self.FOCUS_REWARDS_PERCENT - self.AUDIO_REWARDS_PERCENT + ) def new_wandb_run(self): # Shoutout SN13 for the wandb snippet! 
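For context on the emission split set up in __init__ above: the focus share is fetched from the validator API (with a constant fallback), the audio share is a fixed constant, and the YouTube share is simply the remainder. Below is a minimal sketch of that arithmetic with made-up example values; the real FOCUS_REWARDS_PERCENT and AUDIO_REWARDS_PERCENT are defined outside this patch.

    # Illustrative values only; the actual numbers are loaded at runtime.
    FOCUS_REWARDS_PERCENT = 0.05
    AUDIO_REWARDS_PERCENT = 0.10
    YOUTUBE_REWARDS_PERCENT = 1.0 - FOCUS_REWARDS_PERCENT - AUDIO_REWARDS_PERCENT
    # YouTube receives whatever remains, so the three shares always sum to 1.0.
    assert abs(FOCUS_REWARDS_PERCENT + AUDIO_REWARDS_PERCENT + YOUTUBE_REWARDS_PERCENT - 1.0) < 1e-9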
@@ -173,31 +180,42 @@ def load_topics(self): # split the response text into a list of topics and trim any whitespace all_topics = [line.strip() for line in response.text.split("\n")] bt.logging.info( - f"Loaded {len(all_topics)} topics from {self.config.topics_url}") + f"Loaded {len(all_topics)} topics from {self.config.topics_url}" + ) except Exception as e: bt.logging.error( - f"Error loading topics from URL {self.config.topics_url}: {e}") + f"Error loading topics from URL {self.config.topics_url}: {e}" + ) traceback.print_exc() + bt.logging.info(f"Using fallback topics from {self.config.topics_path}") + all_topics = [ + line.strip() for line in open(self.config.topics_path) if line.strip() + ] bt.logging.info( - f"Using fallback topics from {self.config.topics_path}") - all_topics = [line.strip() for line in open( - self.config.topics_path) if line.strip()] - bt.logging.info( - f"Loaded {len(all_topics)} topics from {self.config.topics_path}") + f"Loaded {len(all_topics)} topics from {self.config.topics_path}" + ) return all_topics def load_focus_rewards_percent(self): # get focus rewards percent from API endpoint or fallback to default try: - focus_rewards_percent_endpoint = f"{self.api_root}/api/focus/get_rewards_percent" + focus_rewards_percent_endpoint = ( + f"{self.api_root}/api/focus/get_rewards_percent" + ) response = requests.get(focus_rewards_percent_endpoint) response.raise_for_status() rewards_percent = float(response.text) - bt.logging.info(f"Loaded focus rewards percent of {rewards_percent} from {focus_rewards_percent_endpoint}") + bt.logging.info( + f"Loaded focus rewards percent of {rewards_percent} from {focus_rewards_percent_endpoint}" + ) except Exception as e: - bt.logging.error(f"Error loading rewards percent from {focus_rewards_percent_endpoint}: {e}") + bt.logging.error( + f"Error loading rewards percent from {focus_rewards_percent_endpoint}: {e}" + ) traceback.print_exc() - bt.logging.info(f"Using fallback focus rewards percent of {FOCUS_REWARDS_PERCENT}") + bt.logging.info( + f"Using fallback focus rewards percent of {FOCUS_REWARDS_PERCENT}" + ) rewards_percent = FOCUS_REWARDS_PERCENT return rewards_percent @@ -209,7 +227,7 @@ async def forward(self): - Getting the responses - Rewarding the miners - Updating the scores - + The forward function is called by the validator every time step. It is responsible for querying the network and scoring the responses. @@ -226,8 +244,7 @@ async def forward(self): return """ START YOUTUBE AUDIO PROCESSING AND SCORING """ - bt.logging.info( - "===== YOUTUBE REQUESTS, AUDIO PROCESSING, AND SCORING =====") + bt.logging.info("===== YOUTUBE REQUESTS, AUDIO PROCESSING, AND SCORING =====") # The dendrite client queries the network. 
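        # A topic is sampled at random and " podcast" is appended, presumably
        # to steer the audio query toward speech-heavy YouTube results.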
query = random.choice(self.all_topics) + " podcast" bt.logging.info(f"Sending query '{query}' to miners {miner_uids}") @@ -246,11 +263,18 @@ async def forward(self): audio_finished_responses = [] for response in audio_responses: - if response.audio_metadata is None or not response.axon or not response.axon.hotkey: + if ( + response.audio_metadata is None + or not response.axon + or not response.axon.hotkey + ): continue - uid = [uid for uid, axon in zip( - miner_uids, axons) if axon.hotkey == response.axon.hotkey][0] + uid = [ + uid + for uid, axon in zip(miner_uids, axons) + if axon.hotkey == response.axon.hotkey + ][0] audio_working_miner_uids.append(uid) audio_finished_responses.append(response) @@ -261,7 +285,9 @@ async def forward(self): bt.logging.info(f"Received audio responses: {audio_responses}") # Adjust the scores based on responses from miners. try: - audio_rewards_list = await self.handle_checks_and_reward_audio(input_synapse=audio_input_synapse, responses=audio_finished_responses) + audio_rewards_list = await self.handle_checks_and_reward_audio( + input_synapse=audio_input_synapse, responses=audio_finished_responses + ) except Exception as e: bt.logging.error(f"Error in handle_checks_and_rewards_audio: {e}") traceback.print_exc() @@ -278,24 +304,25 @@ async def forward(self): # give min reward to miners who didn't respond bad_miner_uids = [ - uid for uid in miner_uids if uid not in audio_working_miner_uids] + uid for uid in miner_uids if uid not in audio_working_miner_uids + ] penalty_tensor = torch.FloatTensor( - [NO_RESPONSE_MINIMUM] * len(bad_miner_uids)).to(self.device) + [NO_RESPONSE_MINIMUM] * len(bad_miner_uids) + ).to(self.device) self.update_audio_scores(penalty_tensor, bad_miner_uids) for reward, miner_uid in zip(audio_rewards, audio_reward_uids): bt.logging.info( - f"Rewarding miner={miner_uid} with reward={reward} for audio dataset") + f"Rewarding miner={miner_uid} with reward={reward} for audio dataset" + ) for penalty, miner_uid in zip(penalty_tensor, bad_miner_uids): - bt.logging.info( - f"Penalizing miner={miner_uid} with penalty={penalty}") + bt.logging.info(f"Penalizing miner={miner_uid} with penalty={penalty}") """ END YOUTUBE AUDIO PROCESSING AND SCORING """ """ START YOUTUBE SYNAPSE REQUESTS, PROCESSING, AND SCORING """ - bt.logging.info( - "===== YOUTUBE REQUESTS, PROCESSING, AND SCORING =====") + bt.logging.info("===== YOUTUBE REQUESTS, PROCESSING, AND SCORING =====") # The dendrite client queries the network. query = random.choice(self.all_topics) bt.logging.info(f"Sending query '{query}' to miners {miner_uids}") @@ -313,11 +340,18 @@ async def forward(self): finished_responses = [] for response in responses: - if response.video_metadata is None or not response.axon or not response.axon.hotkey: + if ( + response.video_metadata is None + or not response.axon + or not response.axon.hotkey + ): continue - uid = [uid for uid, axon in zip( - miner_uids, axons) if axon.hotkey == response.axon.hotkey][0] + uid = [ + uid + for uid, axon in zip(miner_uids, axons) + if axon.hotkey == response.axon.hotkey + ][0] working_miner_uids.append(uid) finished_responses.append(response) @@ -329,10 +363,11 @@ async def forward(self): # Adjust the scores based on responses from miners. 
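        # handle_checks_and_reward_audio (below) returns one reward per finished
        # response; if it raises, this forward pass logs the error and aborts.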
try: - rewards_list = await self.handle_checks_and_rewards_youtube(input_synapse=input_synapse, responses=finished_responses) + rewards_list = await self.handle_checks_and_rewards_youtube( + input_synapse=input_synapse, responses=finished_responses + ) except Exception as e: - bt.logging.error( - f"Error in handle_checks_and_rewards_youtube: {e}") + bt.logging.error(f"Error in handle_checks_and_rewards_youtube: {e}") traceback.print_exc() return @@ -347,19 +382,17 @@ async def forward(self): self.update_scores(rewards, reward_uids) # give min reward to miners who didn't respond - bad_miner_uids = [ - uid for uid in miner_uids if uid not in working_miner_uids] + bad_miner_uids = [uid for uid in miner_uids if uid not in working_miner_uids] penalty_tensor = torch.FloatTensor( - [NO_RESPONSE_MINIMUM] * len(bad_miner_uids)).to(self.device) + [NO_RESPONSE_MINIMUM] * len(bad_miner_uids) + ).to(self.device) self.update_scores(penalty_tensor, bad_miner_uids) for reward, miner_uid in zip(rewards, reward_uids): - bt.logging.info( - f"Rewarding miner={miner_uid} with reward={reward}") + bt.logging.info(f"Rewarding miner={miner_uid} with reward={reward}") for penalty, miner_uid in zip(penalty_tensor, bad_miner_uids): - bt.logging.info( - f"Penalizing miner={miner_uid} with penalty={penalty}") + bt.logging.info(f"Penalizing miner={miner_uid} with penalty={penalty}") """ END YOUTUBE SYNAPSE REQUESTS, PROCESSING, AND SCORING """ """ START FOCUS VIDEOS PROCESSING AND SCORING """ @@ -383,29 +416,39 @@ async def forward(self): self.update_focus_scores(focus_rewards, focus_reward_uids) for reward, uid in zip(focus_rewards, focus_reward_uids): - bt.logging.info(f"Scoring miner={uid} with reward={reward} for focus videos") + bt.logging.info( + f"Scoring miner={uid} with reward={reward} for focus videos" + ) """ END FOCUS VIDEOS PROCESSING AND SCORING """ def metadata_check(self, metadata: List[VideoMetadata]) -> List[VideoMetadata]: return [ - video_metadata for video_metadata in metadata + video_metadata + for video_metadata in metadata if ( - video_metadata.end_time - video_metadata.start_time <= MAX_VIDEO_LENGTH and - video_metadata.end_time - video_metadata.start_time >= MIN_VIDEO_LENGTH + video_metadata.end_time - video_metadata.start_time <= MAX_VIDEO_LENGTH + and video_metadata.end_time - video_metadata.start_time + >= MIN_VIDEO_LENGTH ) ] - def audio_metadata_check(self, metadata: List[AudioMetadata]) -> List[AudioMetadata]: + def audio_metadata_check( + self, metadata: List[AudioMetadata] + ) -> List[AudioMetadata]: return [ - audio_metadata for audio_metadata in metadata + audio_metadata + for audio_metadata in metadata if ( - audio_metadata.end_time - audio_metadata.start_time <= MAX_VIDEO_LENGTH and - audio_metadata.end_time - audio_metadata.start_time >= MIN_VIDEO_LENGTH + audio_metadata.end_time - audio_metadata.start_time <= MAX_VIDEO_LENGTH + and audio_metadata.end_time - audio_metadata.start_time + >= MIN_VIDEO_LENGTH ) ] - def filter_embeddings(self, embeddings: Embeddings, is_too_similar: List[bool]) -> Embeddings: + def filter_embeddings( + self, embeddings: Embeddings, is_too_similar: List[bool] + ) -> Embeddings: """Filter the embeddings based on whether they are too similar to the query.""" is_too_similar = torch.tensor(is_too_similar) if embeddings.video is not None: @@ -416,7 +459,9 @@ def filter_embeddings(self, embeddings: Embeddings, is_too_similar: List[bool]) embeddings.description = embeddings.description[~is_too_similar] return embeddings - def filter_stuffed_embeddings(self, 
embeddings: Embeddings, stuffed: List[Tuple[bool, float]]) -> Embeddings: + def filter_stuffed_embeddings( + self, embeddings: Embeddings, stuffed: List[Tuple[bool, float]] + ) -> Embeddings: """Filter the embeddings based on whether they are too similar to the query.""" stuffed = torch.tensor([s for s, _ in stuffed]) if embeddings.video is not None: @@ -434,7 +479,7 @@ async def deduplicate_videos(self, embeddings: Embeddings) -> Videos: cossim = CosineSimilarity(dim=1) is_similar = [] for i in range(num_videos): - similarity_score = cossim(video_tensor[[i]], video_tensor[i + 1:]) + similarity_score = cossim(video_tensor[[i]], video_tensor[i + 1 :]) has_duplicates = (similarity_score > SIMILARITY_THRESHOLD).any() is_similar.append(has_duplicates.item()) @@ -447,26 +492,26 @@ async def deduplicate_audios(self, embeddings: Embeddings) -> Audios: cossim = CosineSimilarity(dim=1) is_similar = [] for i in range(num_audios): - similarity_score = cossim(audio_tensor[[i]], audio_tensor[i + 1:]) + similarity_score = cossim(audio_tensor[[i]], audio_tensor[i + 1 :]) has_duplicates = (similarity_score > SIMILARITY_THRESHOLD).any() is_similar.append(has_duplicates.item()) return is_similar def is_similar(self, emb_1: torch.Tensor, emb_2: List[float]) -> bool: - return F.cosine_similarity( - emb_1, - torch.tensor(emb_2, device=emb_1.device).unsqueeze(0) - ) > SIMILARITY_THRESHOLD + return ( + F.cosine_similarity( + emb_1, torch.tensor(emb_2, device=emb_1.device).unsqueeze(0) + ) + > SIMILARITY_THRESHOLD + ) def strict_is_similar(self, emb_1: torch.Tensor, emb_2: List[float]) -> bool: - return torch.allclose(emb_1, torch.tensor(emb_2, device=emb_1.device), atol=1e-4) + return torch.allclose( + emb_1, torch.tensor(emb_2, device=emb_1.device), atol=1e-4 + ) - async def get_random_youtube_video( - self, - metadata, - check_video: bool - ): + async def get_random_youtube_video(self, metadata, check_video: bool): if not check_video and len(metadata) > 0: random_metadata = random.choice(metadata) return random_metadata, None @@ -479,27 +524,34 @@ async def get_random_youtube_video( proxy_url = await self.get_proxy_url() if proxy_url is None: bt.logging.info( - "Issue getting proxy_url from API, not using proxy. Attempting download for random_video check") + "Issue getting proxy_url from API, not using proxy. Attempting download for random_video check" + ) else: bt.logging.info( - "Got proxy_url from API. Attempting download for random_video check") + "Got proxy_url from API. 
Attempting download for random_video check" + ) try: async with DOWNLOAD_SEMAPHORE: - random_video = await asyncio.wait_for(run_async( - video_utils.download_youtube_video, - random_metadata.video_id, - random_metadata.start_time, - random_metadata.end_time, - proxy=proxy_url - ), timeout=VIDEO_DOWNLOAD_TIMEOUT) + random_video = await asyncio.wait_for( + run_async( + video_utils.download_youtube_video, + random_metadata.video_id, + random_metadata.start_time, + random_metadata.end_time, + proxy=proxy_url, + ), + timeout=VIDEO_DOWNLOAD_TIMEOUT, + ) except video_utils.IPBlockedException: # IP is blocked, cannot download video, check description only bt.logging.warning( - "WARNING: IP is blocked, cannot download video, checking description only") + "WARNING: IP is blocked, cannot download video, checking description only" + ) return random_metadata, None except video_utils.FakeVideoException: bt.logging.warning( - f"WARNING: Video {random_metadata.video_id} is fake, punishing miner") + f"WARNING: Video {random_metadata.video_id} is fake, punishing miner" + ) return None except asyncio.TimeoutError: continue @@ -511,55 +563,64 @@ async def get_random_youtube_video( return random_metadata, random_video - async def random_youtube_check(self, random_meta_and_vid: List[VideoMetadata]) -> bool: + async def random_youtube_check( + self, random_meta_and_vid: List[VideoMetadata] + ) -> bool: random_metadata, random_video = random_meta_and_vid if random_video is None: - desc_embeddings = self.imagebind.embed_text( - [random_metadata.description]) + desc_embeddings = self.imagebind.embed_text([random_metadata.description]) is_similar_ = self.is_similar( - desc_embeddings, random_metadata.description_emb) + desc_embeddings, random_metadata.description_emb + ) strict_is_similar_ = self.strict_is_similar( - desc_embeddings, random_metadata.description_emb) + desc_embeddings, random_metadata.description_emb + ) bt.logging.info( - f"Description similarity: {is_similar_}, strict description similarity: {strict_is_similar_}") + f"Description similarity: {is_similar_}, strict description similarity: {strict_is_similar_}" + ) return is_similar_ # Video downloaded, check all embeddings - embeddings = self.imagebind.embed( - [random_metadata.description], [random_video]) + embeddings = self.imagebind.embed([random_metadata.description], [random_video]) is_similar_ = ( - self.is_similar(embeddings.video, random_metadata.video_emb) and - self.is_similar(embeddings.audio, random_metadata.audio_emb) and - self.is_similar(embeddings.description, - random_metadata.description_emb) + self.is_similar(embeddings.video, random_metadata.video_emb) + and self.is_similar(embeddings.audio, random_metadata.audio_emb) + and self.is_similar(embeddings.description, random_metadata.description_emb) ) strict_is_similar_ = ( - self.strict_is_similar(embeddings.video, random_metadata.video_emb) and - self.strict_is_similar(embeddings.audio, random_metadata.audio_emb) and - self.strict_is_similar( - embeddings.description, random_metadata.description_emb) + self.strict_is_similar(embeddings.video, random_metadata.video_emb) + and self.strict_is_similar(embeddings.audio, random_metadata.audio_emb) + and self.strict_is_similar( + embeddings.description, random_metadata.description_emb + ) ) bt.logging.debug( - f"Total similarity: {is_similar_}, strict total similarity: {strict_is_similar_}") + f"Total similarity: {is_similar_}, strict total similarity: {strict_is_similar_}" + ) return is_similar_ - async def random_audio_check(self, 
random_meta_and_audio: List[AudioMetadata]) -> bool: + async def random_audio_check( + self, random_meta_and_audio: List[AudioMetadata] + ) -> bool: random_metadata, random_video = random_meta_and_audio bt.logging.info( - f"inside random_audio_check, random_metadata: {random_metadata}, random_video: {random_video}") + f"inside random_audio_check, random_metadata: {random_metadata}, random_video: {random_video}" + ) if random_video is None: return True - audio_bytes_from_youtube = video_utils.get_audio_bytes( - random_video.name) - audio_bytes_from_youtube = base64.b64encode( - audio_bytes_from_youtube).decode('utf-8') + audio_bytes_from_youtube = video_utils.get_audio_bytes(random_video.name) + audio_bytes_from_youtube = base64.b64encode(audio_bytes_from_youtube).decode( + "utf-8" + ) audio_array_youtube, _ = sf.read( - BytesIO(base64.b64decode(audio_bytes_from_youtube))) + BytesIO(base64.b64decode(audio_bytes_from_youtube)) + ) submitted_audio_bytes = random_metadata.audio_bytes audio_array_submitted, _ = sf.read( - BytesIO(base64.b64decode(submitted_audio_bytes))) + BytesIO(base64.b64decode(submitted_audio_bytes)) + ) if np.array_equal(audio_array_youtube, audio_array_submitted) is False: bt.logging.warning("WARNING: Audio bytes do not match") @@ -572,7 +633,8 @@ def compute_novelty_score_among_batch(self, emb: Embeddings) -> List[float]: novelty_scores = [] for i in range(num_videos - 1): similarity_score = F.cosine_similarity( - video_tensor[[i]], video_tensor[i + 1:]).max() + video_tensor[[i]], video_tensor[i + 1 :] + ).max() novelty_scores.append(1 - similarity_score.item()) novelty_scores.append(1.0) # last video is 100% novel return novelty_scores @@ -583,7 +645,8 @@ def compute_novelty_score_among_batch_audio(self, emb: Embeddings) -> List[float novelty_scores = [] for i in range(num_audios - 1): similarity_score = F.cosine_similarity( - audio_tensor[[i]], audio_tensor[i + 1:]).max() + audio_tensor[[i]], audio_tensor[i + 1 :] + ).max() novelty_scores.append(1 - similarity_score.item()) novelty_scores.append(1.0) # last video is 100% novel return novelty_scores @@ -593,18 +656,18 @@ async def async_zero() -> None: # algorithm for computing final novelty score def compute_final_novelty_score(self, base_novelty_scores: List[float]) -> float: - is_too_similar = [ - score < DIFFERENCE_THRESHOLD for score in base_novelty_scores] - novelty_score = sum([ - score for score, is_too_similar - in zip(base_novelty_scores, is_too_similar) if not is_too_similar - ]) + is_too_similar = [score < DIFFERENCE_THRESHOLD for score in base_novelty_scores] + novelty_score = sum( + [ + score + for score, is_too_similar in zip(base_novelty_scores, is_too_similar) + if not is_too_similar + ] + ) return novelty_score async def check_videos_and_calculate_rewards_youtube( - self, - input_synapse: Videos, - videos: Videos + self, input_synapse: Videos, videos: Videos ) -> Optional[float]: try: # return minimum score if no videos were found in video_metadata @@ -612,21 +675,28 @@ async def check_videos_and_calculate_rewards_youtube( return MIN_SCORE # check video_ids for fake videos - if any(not video_utils.is_valid_youtube_id(video.video_id) for video in videos.video_metadata): + if any( + not video_utils.is_valid_youtube_id(video.video_id) + for video in videos.video_metadata + ): return FAKE_VIDEO_PUNISHMENT # check and filter duplicate metadata metadata = self.metadata_check(videos.video_metadata)[ - :input_synapse.num_videos] + : input_synapse.num_videos + ] if len(metadata) < len(videos.video_metadata): 
bt.logging.info( - f"Filtered {len(videos.video_metadata)} videos down to {len(metadata)} videos") + f"Filtered {len(videos.video_metadata)} videos down to {len(metadata)} videos" + ) # if randomly tripped, flag our random check to pull a video from miner's submissions check_video = CHECK_PROBABILITY > random.random() # pull a random video and/or description only - random_meta_and_vid = await self.get_random_youtube_video(metadata, check_video) + random_meta_and_vid = await self.get_random_youtube_video( + metadata, check_video + ) if random_meta_and_vid is None: return FAKE_VIDEO_PUNISHMENT @@ -640,42 +710,48 @@ async def check_videos_and_calculate_rewards_youtube( query_emb = await self.imagebind.embed_text_async([videos.query]) embeddings = Embeddings( - video=torch.stack([torch.tensor(v.video_emb) - for v in metadata]).to(self.imagebind.device), - audio=torch.stack([torch.tensor(v.audio_emb) - for v in metadata]).to(self.imagebind.device), - description=torch.stack([torch.tensor(v.description_emb) for v in metadata]).to( - self.imagebind.device), + video=torch.stack([torch.tensor(v.video_emb) for v in metadata]).to( + self.imagebind.device + ), + audio=torch.stack([torch.tensor(v.audio_emb) for v in metadata]).to( + self.imagebind.device + ), + description=torch.stack( + [torch.tensor(v.description_emb) for v in metadata] + ).to(self.imagebind.device), ) # check and deduplicate videos based on embedding similarity checks. We do this because we're not uploading to pinecone first. metadata_is_similar = await self.deduplicate_videos(embeddings) - metadata = [metadata for metadata, too_similar in zip( - metadata, metadata_is_similar) if not too_similar] - embeddings = self.filter_embeddings( - embeddings, metadata_is_similar) + metadata = [ + metadata + for metadata, too_similar in zip(metadata, metadata_is_similar) + if not too_similar + ] + embeddings = self.filter_embeddings(embeddings, metadata_is_similar) if len(metadata) < len(videos.video_metadata): bt.logging.info( - f"Deduplicated {len(videos.video_metadata)} videos down to {len(metadata)} videos") + f"Deduplicated {len(videos.video_metadata)} videos down to {len(metadata)} videos" + ) # return minimum score if no unique videos were found if len(metadata) == 0: return MIN_SCORE # first get local novelty scores - local_novelty_scores = self.compute_novelty_score_among_batch( - embeddings) + local_novelty_scores = self.compute_novelty_score_among_batch(embeddings) # bt.logging.debug(f"local_novelty_scores: {local_novelty_scores}") # second get the novelty scores from the validator api if not already too similar embeddings_to_check = [ (embedding, metadata) - for embedding, local_score, metadata in zip(embeddings.video, local_novelty_scores, metadata) + for embedding, local_score, metadata in zip( + embeddings.video, local_novelty_scores, metadata + ) if local_score >= DIFFERENCE_THRESHOLD ] # If there are embeddings to check, call get_novelty_scores once if embeddings_to_check: - embeddings_to_check, metadata_to_check = zip( - *embeddings_to_check) + embeddings_to_check, metadata_to_check = zip(*embeddings_to_check) global_novelty_scores = await self.get_novelty_scores(metadata_to_check) else: # If no embeddings to check, return an empty list or appropriate default value @@ -683,29 +759,37 @@ async def check_videos_and_calculate_rewards_youtube( if global_novelty_scores is None or len(global_novelty_scores) == 0: bt.logging.error( - "Issue retrieving global novelty scores, returning None.") + "Issue retrieving global novelty scores, 
returning None." + ) return None # #bt.logging.debug(f"global_novelty_scores: {global_novelty_scores}") # calculate true novelty scores between local and global true_novelty_scores = [ - min(local_score, global_score) for local_score, global_score - in zip(local_novelty_scores, global_novelty_scores) + min(local_score, global_score) + for local_score, global_score in zip( + local_novelty_scores, global_novelty_scores + ) ] # bt.logging.debug(f"true_novelty_scores: {true_novelty_scores}") pre_filter_metadata_length = len(metadata) # check scores from index for being too similar is_too_similar = [ - score < DIFFERENCE_THRESHOLD for score in true_novelty_scores] + score < DIFFERENCE_THRESHOLD for score in true_novelty_scores + ] # filter out metadata too similar - metadata = [metadata for metadata, too_similar in zip( - metadata, is_too_similar) if not too_similar] + metadata = [ + metadata + for metadata, too_similar in zip(metadata, is_too_similar) + if not too_similar + ] # filter out embeddings too similar embeddings = self.filter_embeddings(embeddings, is_too_similar) if len(metadata) < pre_filter_metadata_length: bt.logging.info( - f"Filtering {pre_filter_metadata_length} videos down to {len(metadata)} videos that are too similar to videos in our index.") + f"Filtering {pre_filter_metadata_length} videos down to {len(metadata)} videos that are too similar to videos in our index." + ) # return minimum score if no unique videos were found if len(metadata) == 0: @@ -713,25 +797,25 @@ async def check_videos_and_calculate_rewards_youtube( # Filter out "stuffed" descriptions. pre_filter_metadata_length = len(metadata) - stuffed = [ - unstuff.is_stuffed(meta.description) - for meta in metadata - ] + stuffed = [unstuff.is_stuffed(meta.description) for meta in metadata] if any([garbage and confidence > 0.75 for garbage, confidence in stuffed]): bt.logging.warning( - "Stuffed description found with high confidence, penalizing the miner.") + "Stuffed description found with high confidence, penalizing the miner." + ) return STUFFED_DESCRIPTION_PUNISHMENT # More stuffing. extraneous = [ unstuff.check_extraneous_chunks( - meta.description, meta.video_emb, meta.audio_emb, self.imagebind) + meta.description, meta.video_emb, meta.audio_emb, self.imagebind + ) for meta in metadata ] for really_bad, low_quality, total in extraneous: if really_bad > 5 or low_quality >= 16: bt.logging.info( - f"Extraneous garbage found in text check {really_bad=} {low_quality=} {total=}") + f"Extraneous garbage found in text check {really_bad=} {low_quality=} {total=}" + ) return STUFFED_DESCRIPTION_PUNISHMENT metadata = [ @@ -743,7 +827,8 @@ async def check_videos_and_calculate_rewards_youtube( ] if len(metadata) < pre_filter_metadata_length: bt.logging.info( - f"Filtering {pre_filter_metadata_length} videos down to {len(metadata)} videos to remove token-stuffed descriptions.") + f"Filtering {pre_filter_metadata_length} videos down to {len(metadata)} videos to remove token-stuffed descriptions." + ) if len(metadata) == 0: return MIN_SCORE embeddings = self.filter_stuffed_embeddings(embeddings, stuffed) @@ -764,23 +849,29 @@ async def check_videos_and_calculate_rewards_youtube( # Query relevance score now includes video cosim, audio cosim, and text cosim using higher quality text-only model. 
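            # i.e. each video's query-relevance score is the unweighted mean of
            # its video/query, audio/query, and description/query similarities.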
query_relevance_scores = [ - sum([ - video_query_relevance_scores[idx], - audio_query_relevance_scores[idx], - get_text_similarity_score( - metadata[idx].description, videos.query), - ]) / 3 + sum( + [ + video_query_relevance_scores[idx], + audio_query_relevance_scores[idx], + get_text_similarity_score( + metadata[idx].description, videos.query + ), + ] + ) + / 3 for idx in range(len(video_query_relevance_scores)) ] # Combine audio & visual description scores, weighted towards visual. description_relevance_scores = [ - sum([ - video_description_relevance_scores[idx] * - VIDEO_RELEVANCE_WEIGHT, - audio_description_relevance_scores[idx] * - (1.0 - VIDEO_RELEVANCE_WEIGHT), - ]) + sum( + [ + video_description_relevance_scores[idx] + * VIDEO_RELEVANCE_WEIGHT, + audio_description_relevance_scores[idx] + * (1.0 - VIDEO_RELEVANCE_WEIGHT), + ] + ) for idx in range(len(video_description_relevance_scores)) ] @@ -788,35 +879,46 @@ async def check_videos_and_calculate_rewards_youtube( length_scalers = [] for idx in range(len(description_relevance_scores)): unique_tokens = LENGTH_TOKENIZER(metadata[idx].description) - unique_tokens = set( - unique_tokens[unique_tokens != 0][1:-1].tolist()) + unique_tokens = set(unique_tokens[unique_tokens != 0][1:-1].tolist()) unique_token_count = len(unique_tokens) if unique_token_count <= MIN_LENGTH_BOOST_TOKEN_COUNT: bt.logging.debug( - f"Very few tokens, applying {DESCRIPTION_LENGTH_WEIGHT} penalty.") - description_relevance_scores[idx] *= ( - 1.0 - DESCRIPTION_LENGTH_WEIGHT) + f"Very few tokens, applying {DESCRIPTION_LENGTH_WEIGHT} penalty." + ) + description_relevance_scores[idx] *= 1.0 - DESCRIPTION_LENGTH_WEIGHT length_scalers.append(0) continue - length_scaler = min(math.log(MAX_LENGTH_BOOST_TOKEN_COUNT, 2), math.log( - unique_token_count, 2)) - math.log(MIN_LENGTH_BOOST_TOKEN_COUNT, 2) - length_scaler /= (math.log(MAX_LENGTH_BOOST_TOKEN_COUNT, - 2) - math.log(MIN_LENGTH_BOOST_TOKEN_COUNT, 2)) + length_scaler = min( + math.log(MAX_LENGTH_BOOST_TOKEN_COUNT, 2), + math.log(unique_token_count, 2), + ) - math.log(MIN_LENGTH_BOOST_TOKEN_COUNT, 2) + length_scaler /= math.log(MAX_LENGTH_BOOST_TOKEN_COUNT, 2) - math.log( + MIN_LENGTH_BOOST_TOKEN_COUNT, 2 + ) length_scalers.append(length_scaler) - bt.logging.debug( - f"Description length scaling factor = {length_scaler}") - description_relevance_scores[idx] -= description_relevance_scores[idx] * \ - DESCRIPTION_LENGTH_WEIGHT * (1.0 - length_scaler) + bt.logging.debug(f"Description length scaling factor = {length_scaler}") + description_relevance_scores[idx] -= ( + description_relevance_scores[idx] + * DESCRIPTION_LENGTH_WEIGHT + * (1.0 - length_scaler) + ) # Aggregate scores score = ( - (sum(description_relevance_scores) * DESCRIPTION_RELEVANCE_SCALING_FACTOR) + - (sum(query_relevance_scores) * QUERY_RELEVANCE_SCALING_FACTOR) - ) / 2 / videos.num_videos + ( + ( + sum(description_relevance_scores) + * DESCRIPTION_RELEVANCE_SCALING_FACTOR + ) + + (sum(query_relevance_scores) * QUERY_RELEVANCE_SCALING_FACTOR) + ) + / 2 + / videos.num_videos + ) score = max(score, MIN_SCORE) # Log all our scores - bt.logging.info(f''' + bt.logging.info(f""" is_unique: {[not is_sim for is_sim in is_too_similar]}, video cosine sim: {video_description_relevance_scores}, audio cosine sim: {audio_description_relevance_scores}, @@ -824,11 +926,19 @@ async def check_videos_and_calculate_rewards_youtube( query relevance scores: {query_relevance_scores}, length scalers: {length_scalers}, total score: {score} - ''') + """) # Upload our final results 
to API endpoint for index and dataset insertion. Include leaderboard statistics miner_hotkey = videos.axon.hotkey - upload_result = await self.upload_video_metadata(metadata, description_relevance_scores, query_relevance_scores, videos.query, None, score, miner_hotkey) + upload_result = await self.upload_video_metadata( + metadata, + description_relevance_scores, + query_relevance_scores, + videos.query, + None, + score, + miner_hotkey, + ) if upload_result: bt.logging.info("Uploading of video metadata successful.") else: @@ -838,7 +948,8 @@ async def check_videos_and_calculate_rewards_youtube( except Exception as e: bt.logging.error( - f"Error in check_videos_and_calculate_rewards_youtube: {e}") + f"Error in check_videos_and_calculate_rewards_youtube: {e}" + ) traceback.print_exc() return None @@ -848,15 +959,16 @@ async def handle_checks_and_rewards_youtube( input_synapse: Videos, responses: List[Videos], ) -> torch.FloatTensor: - - rewards = await asyncio.gather(*[ - self.check_videos_and_calculate_rewards_youtube( - input_synapse, - # replace with input properties from input_synapse - response.replace_with_input(input_synapse), - ) - for response in responses - ]) + rewards = await asyncio.gather( + *[ + self.check_videos_and_calculate_rewards_youtube( + input_synapse, + # replace with input properties from input_synapse + response.replace_with_input(input_synapse), + ) + for response in responses + ] + ) return rewards async def handle_checks_and_reward_audio( @@ -864,13 +976,15 @@ async def handle_checks_and_reward_audio( input_synapse: Audios, responses: List[Audios], ) -> torch.FloatTensor: - rewards = await asyncio.gather(*[ - self.check_audios_and_calculate_rewards( - input_synapse, - response, - ) - for response in responses - ]) + rewards = await asyncio.gather( + *[ + self.check_audios_and_calculate_rewards( + input_synapse, + response, + ) + for response in responses + ] + ) return rewards async def upload_video_metadata( @@ -881,10 +995,10 @@ async def upload_video_metadata( query: str, novelty_score: float, score: float, - miner_hotkey: str + miner_hotkey: str, ) -> bool: """ - Queries the validator api to get novelty scores for supplied videos. + Queries the validator api to get novelty scores for supplied videos. Returns a list of float novelty scores for each video after deduplicating. Returns: @@ -897,8 +1011,9 @@ async def upload_video_metadata( async with ClientSession() as session: # Serialize the list of VideoMetadata # serialized_metadata = [item.dict() for item in metadata] - serialized_metadata = [json.loads( - item.model_dump_json()) for item in metadata] + serialized_metadata = [ + json.loads(item.model_dump_json()) for item in metadata + ] # Construct the JSON payload payload = { "metadata": serialized_metadata, @@ -907,7 +1022,7 @@ async def upload_video_metadata( "topic_query": query, "novelty_score": novelty_score, "total_score": score, - "miner_hotkey": miner_hotkey + "miner_hotkey": miner_hotkey, } async with session.post( @@ -919,8 +1034,7 @@ async def upload_video_metadata( result = await response.json() return True except Exception as e: - bt.logging.debug( - f"Error trying upload_video_metadata_endpoint: {e}") + bt.logging.debug(f"Error trying upload_video_metadata_endpoint: {e}") traceback.print_exc() return False @@ -933,10 +1047,10 @@ async def upload_audio_metadata( audio_query_score: float, query: str, total_score: float, - miner_hotkey: str + miner_hotkey: str, ) -> bool: """ - Queries the validator api to get novelty scores for supplied audios. 
+ Queries the validator api to get novelty scores for supplied audios. Returns a list of float novelty scores for each audio after deduplicating. Returns: @@ -949,8 +1063,9 @@ async def upload_audio_metadata( async with ClientSession() as session: # Serialize the list of AudioMetadata # serialized_metadata = [item.dict() for item in metadata] - serialized_metadata = [json.loads( - item.model_dump_json()) for item in metadata] + serialized_metadata = [ + json.loads(item.model_dump_json()) for item in metadata + ] # Construct the JSON payload payload = { "metadata": serialized_metadata, @@ -960,7 +1075,7 @@ async def upload_audio_metadata( "audio_query_score": audio_query_score, "topic_query": query, "total_score": total_score, - "miner_hotkey": miner_hotkey + "miner_hotkey": miner_hotkey, } async with session.post( @@ -972,14 +1087,13 @@ async def upload_audio_metadata( result = await response.json() return True except Exception as e: - bt.logging.debug( - f"Error trying upload_audio_metadata_endpoint: {e}") + bt.logging.debug(f"Error trying upload_audio_metadata_endpoint: {e}") traceback.print_exc() return False async def get_novelty_scores(self, metadata: List[VideoMetadata]) -> List[float]: """ - Queries the validator api to get novelty scores for supplied videos. + Queries the validator api to get novelty scores for supplied videos. Returns a list of float novelty scores for each video after deduplicating. Returns: @@ -1034,31 +1148,36 @@ async def get_proxy_url(self) -> str: return None async def check_audios_and_calculate_rewards( - self, - input_synapse: Audios, - audios: Audios + self, input_synapse: Audios, audios: Audios ) -> Optional[float]: try: # return minimum score if no videos were found in video_metadata if len(audios.audio_metadata) == 0: return MIN_SCORE # check video_ids for fake videos - if any(not video_utils.is_valid_youtube_id(audio.video_id) for audio in audios.audio_metadata): + if any( + not video_utils.is_valid_youtube_id(audio.video_id) + for audio in audios.audio_metadata + ): return FAKE_VIDEO_PUNISHMENT # check and filter duplicate metadata metadata = self.audio_metadata_check(audios.audio_metadata)[ - :input_synapse.num_audios] + : input_synapse.num_audios + ] if len(metadata) < len(audios.audio_metadata): bt.logging.info( - f"Filtered {len(audios.audio_metadata)} audios down to {len(metadata)} audios") + f"Filtered {len(audios.audio_metadata)} audios down to {len(metadata)} audios" + ) # if randomly tripped, flag our random check to pull a video from miner's submissions check_video = CHECK_PROBABILITY > random.random() # pull a random video and/or description only - random_meta_and_vid = await self.get_random_youtube_video(metadata, check_video) + random_meta_and_vid = await self.get_random_youtube_video( + metadata, check_video + ) if random_meta_and_vid is None: return FAKE_VIDEO_PUNISHMENT @@ -1074,21 +1193,25 @@ async def check_audios_and_calculate_rewards( embeddings = Embeddings( video=None, - audio=torch.stack([torch.tensor(a.audio_emb) - for a in metadata]).to(self.imagebind.device), - description=None + audio=torch.stack([torch.tensor(a.audio_emb) for a in metadata]).to( + self.imagebind.device + ), + description=None, ) # check and deduplicate videos based on embedding similarity checks. We do this because we're not uploading to pinecone first. 
metadata_is_similar = await self.deduplicate_audios(embeddings) - metadata = [metadata for metadata, too_similar in zip( - metadata, metadata_is_similar) if not too_similar] - embeddings = self.filter_embeddings( - embeddings, metadata_is_similar) + metadata = [ + metadata + for metadata, too_similar in zip(metadata, metadata_is_similar) + if not too_similar + ] + embeddings = self.filter_embeddings(embeddings, metadata_is_similar) if len(metadata) < len(audios.audio_metadata): bt.logging.info( - f"Deduplicated {len(audios.audio_metadata)} audios down to {len(metadata)} audios") + f"Deduplicated {len(audios.audio_metadata)} audios down to {len(metadata)} audios" + ) # return minimum score if no unique videos were found if len(metadata) == 0: @@ -1096,20 +1219,26 @@ async def check_audios_and_calculate_rewards( # first get local novelty scores local_novelty_scores = self.compute_novelty_score_among_batch_audio( - embeddings) + embeddings + ) pre_filter_metadata_length = len(metadata) # check scores from index for being too similar is_too_similar = [ - score < DIFFERENCE_THRESHOLD for score in local_novelty_scores] + score < DIFFERENCE_THRESHOLD for score in local_novelty_scores + ] # filter out metadata too similar - metadata = [metadata for metadata, too_similar in zip( - metadata, is_too_similar) if not too_similar] + metadata = [ + metadata + for metadata, too_similar in zip(metadata, is_too_similar) + if not too_similar + ] # filter out embeddings too similar embeddings = self.filter_embeddings(embeddings, is_too_similar) if len(metadata) < pre_filter_metadata_length: bt.logging.info( - f"Filtering {pre_filter_metadata_length} audios down to {len(metadata)} audios that are too similar to audios in our index.") + f"Filtering {pre_filter_metadata_length} audios down to {len(metadata)} audios that are too similar to audios in our index." 
+ ) # return minimum score if no unique videos were found if len(metadata) == 0: @@ -1119,52 +1248,62 @@ async def check_audios_and_calculate_rewards( # Filter audios based on length constraints pre_filter_metadata_length = len(metadata) metadata = [ - meta for meta in metadata + meta + for meta in metadata if (meta.end_time - meta.start_time) >= MIN_AUDIO_LENGTH_SECONDS and (meta.end_time - meta.start_time) <= MAX_AUDIO_LENGTH_SECONDS ] if len(metadata) < pre_filter_metadata_length: bt.logging.info( - f"Filtered {pre_filter_metadata_length} audios down to {len(metadata)} audios based on length constraints") + f"Filtered {pre_filter_metadata_length} audios down to {len(metadata)} audios based on length constraints" + ) # Return minimum score if no audios remain after filtering if len(metadata) == 0: return MIN_SCORE total_audio_length = sum( - (meta.end_time - meta.start_time) for meta in metadata) + (meta.end_time - meta.start_time) for meta in metadata + ) bt.logging.info( - f"Average audio length: {total_audio_length/len(metadata):.2f} seconds") - audio_length_score = total_audio_length / \ - (NUM_AUDIOS*MAX_AUDIO_LENGTH_SECONDS) + f"Average audio length: {total_audio_length / len(metadata):.2f} seconds" + ) + audio_length_score = total_audio_length / ( + NUM_AUDIOS * MAX_AUDIO_LENGTH_SECONDS + ) - audio_query_score = sum(F.cosine_similarity( - embeddings.audio, query_emb - ).tolist())/len(metadata) + audio_query_score = sum( + F.cosine_similarity(embeddings.audio, query_emb).tolist() + ) / len(metadata) bt.logging.info(f"Audio query score: {audio_query_score}") # Randomly sample one audio for duration check selected_random_meta = random.choice(metadata) audio_array, sr = sf.read( - BytesIO(base64.b64decode(selected_random_meta.audio_bytes))) + BytesIO(base64.b64decode(selected_random_meta.audio_bytes)) + ) audio_duration = len(audio_array) / sr bt.logging.info( - f"Selected Youtube Video: {selected_random_meta.video_id}, Duration: {audio_duration:.2f} seconds") + f"Selected Youtube Video: {selected_random_meta.video_id}, Duration: {audio_duration:.2f} seconds" + ) audio_quality_scores = self.audio_score.total_score( audio_array, sr, selected_random_meta.diar_timestamps_start, selected_random_meta.diar_timestamps_end, - selected_random_meta.diar_speakers + selected_random_meta.diar_speakers, ) audio_quality_total_score = ( - audio_quality_scores["speech_content_score"] * SPEECH_CONTENT_SCALING_FACTOR + - audio_quality_scores["speaker_dominance_score"] * SPEAKER_DOMINANCE_SCALING_FACTOR + - audio_quality_scores["background_noise_score"] * BACKGROUND_NOISE_SCALING_FACTOR + - audio_quality_scores["unique_speakers_error"] * - UNIQUE_SPEAKERS_ERROR_SCALING_FACTOR + audio_quality_scores["speech_content_score"] + * SPEECH_CONTENT_SCALING_FACTOR + + audio_quality_scores["speaker_dominance_score"] + * SPEAKER_DOMINANCE_SCALING_FACTOR + + audio_quality_scores["background_noise_score"] + * BACKGROUND_NOISE_SCALING_FACTOR + + audio_quality_scores["unique_speakers_error"] + * UNIQUE_SPEAKERS_ERROR_SCALING_FACTOR ) # query score @@ -1172,20 +1311,18 @@ async def check_audios_and_calculate_rewards( miner_diar_segment = { "start": selected_random_meta.diar_timestamps_start, "end": selected_random_meta.diar_timestamps_end, - "speakers": selected_random_meta.diar_speakers + "speakers": selected_random_meta.diar_speakers, } diarization_score = calculate_diarization_metrics( - audio_array, - sr, - miner_diar_segment + audio_array, sr, miner_diar_segment ) inverse_der = diarization_score["inverse_der"] 
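            # Illustrative arithmetic (made-up sub-scores): with the factors from
            # omega/constants.py in this patch (DIARIZATION=0.6, AUDIO_LENGTH=0.1,
            # AUDIO_QUALITY=0.2, AUDIO_QUERY_RELEVANCE=1.0-0.6-0.1-0.2=0.1), an
            # inverse_der of 0.9, a length score of 0.5, a quality score of 0.87 and
            # a query score of 0.4 combine below to
            # 0.6*0.9 + 0.1*0.5 + 0.2*0.87 + 0.1*0.4 = 0.804.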
total_score = ( - DIARIZATION_SCALING_FACTOR * inverse_der + - AUDIO_LENGTH_SCALING_FACTOR * audio_length_score + - AUDIO_QUALITY_SCALING_FACTOR * audio_quality_total_score + - AUDIO_QUERY_RELEVANCE_SCALING_FACTOR * audio_query_score + DIARIZATION_SCALING_FACTOR * inverse_der + + AUDIO_LENGTH_SCALING_FACTOR * audio_length_score + + AUDIO_QUALITY_SCALING_FACTOR * audio_quality_total_score + + AUDIO_QUERY_RELEVANCE_SCALING_FACTOR * audio_query_score ) bt.logging.info( @@ -1197,9 +1334,17 @@ async def check_audios_and_calculate_rewards( ) # Upload our final results to API endpoint for index and dataset insertion. Include leaderboard statistics miner_hotkey = audios.axon.hotkey - bt.logging.info( - f"Uploading audio metadata for miner: {miner_hotkey}") - upload_result = await self.upload_audio_metadata(metadata, inverse_der, audio_length_score, audio_quality_total_score, audio_query_score, audios.query, total_score, miner_hotkey) + bt.logging.info(f"Uploading audio metadata for miner: {miner_hotkey}") + upload_result = await self.upload_audio_metadata( + metadata, + inverse_der, + audio_length_score, + audio_quality_total_score, + audio_query_score, + audios.query, + total_score, + miner_hotkey, + ) if upload_result: bt.logging.info("Uploading of audio metadata successful.") else: @@ -1207,8 +1352,7 @@ async def check_audios_and_calculate_rewards( return total_score except Exception as e: - bt.logging.error( - f"Error in check_audios_and_calculate_rewards: {e}") + bt.logging.error(f"Error in check_audios_and_calculate_rewards: {e}") traceback.print_exc() return None @@ -1248,13 +1392,15 @@ async def get_rewards( Returns a tensor of rewards for the given query and responses. """ # Get all the reward results by iteratively calling your reward() function. - rewards = await asyncio.gather(*[ - self.reward( - input_synapse, - response, - ) - for response in responses - ]) + rewards = await asyncio.gather( + *[ + self.reward( + input_synapse, + response, + ) + for response in responses + ] + ) return rewards """ @@ -1291,13 +1437,16 @@ async def get_rewards( async def get_focus_videos(self) -> Dict[str, Dict]: async with ClientSession() as session: try: - async with session.get(f"{self.api_root}/api/focus/miner_purchase_scores", timeout=10) as response: + async with session.get( + f"{self.api_root}/api/focus/miner_purchase_scores", timeout=10 + ) as response: if response.status == 200: return await response.json() else: error_message = await response.text() bt.logging.warning( - f"Retrieving miner focus videos failed. Status: {response.status}, Message: {error_message}") + f"Retrieving miner focus videos failed. 
Status: {response.status}, Message: {error_message}" + ) return {} except asyncio.TimeoutError: bt.logging.error("Request timed out in get_focus_videos") diff --git a/omega/audio_scoring.py b/omega/audio_scoring.py index 8c295bdc..0629f97f 100644 --- a/omega/audio_scoring.py +++ b/omega/audio_scoring.py @@ -1,5 +1,6 @@ import numpy as np -if hasattr(np, 'nan'): + +if hasattr(np, "nan"): np.NaN = np.nan np.NAN = np.nan from pyannote.audio import Pipeline @@ -12,42 +13,45 @@ dotenv.load_dotenv() + class AudioScore: def __init__(self, device="cuda"): - self.device = torch.device(device) - # Load the audio file + # Load the audio file self.pipeline = Pipeline.from_pretrained("salmanshahid/vad").to(self.device) - self.steepness = 5 self.midpoint = 0.3 - - def speech_content_score(self, audio_arr, sr): self.total_duration = librosa.get_duration(y=audio_arr, sr=sr) - output = self.pipeline({"waveform": torch.from_numpy(audio_arr.astype(np.float32)).unsqueeze(0).to(self.device), "sample_rate": sr}) - - self.total_speech_duration = 0 + output = self.pipeline( + { + "waveform": torch.from_numpy(audio_arr.astype(np.float32)) + .unsqueeze(0) + .to(self.device), + "sample_rate": sr, + } + ) + + self.total_speech_duration = 0 for speech in output.get_timeline().support(): self.total_speech_duration += speech.end - speech.start - ratio = self.total_speech_duration / self.total_duration - + ratio = self.total_speech_duration / self.total_duration return ratio - - def speaker_dominance_score(self, timestamps_start, timestamps_end, speakers, dominance_threshold=0.7): + + def speaker_dominance_score( + self, timestamps_start, timestamps_end, speakers, dominance_threshold=0.7 + ): if timestamps_start is None: self.rttm_data = None return 0 - self.rttm_data = pd.DataFrame({ - 'start': timestamps_start, - 'end': timestamps_end, - 'speaker': speakers - }) + self.rttm_data = pd.DataFrame( + {"start": timestamps_start, "end": timestamps_end, "speaker": speakers} + ) # If there's only one speaker, return 0 since dominance is expected if len(set(speakers)) == 1: @@ -56,8 +60,8 @@ def speaker_dominance_score(self, timestamps_start, timestamps_end, speakers, do # Calculate total duration for each speaker speaker_durations = {} for _, row in self.rttm_data.iterrows(): - speaker = row['speaker'] - duration = row['end'] - row['start'] + speaker = row["speaker"] + duration = row["end"] - row["start"] if speaker in speaker_durations: speaker_durations[speaker] += duration else: @@ -66,31 +70,34 @@ def speaker_dominance_score(self, timestamps_start, timestamps_end, speakers, do min_time = min(speaker_durations.values()) return 1 - (max_time - min_time) / self.total_duration - def background_noise_score(self, audio_arr, sr, noise_threshold=0.1): # Load audio and calculate SNR self.audio = audio_arr self.sr = sr - + # Calculate signal power signal_power = np.mean(self.audio**2) - + # Estimate noise power (using the lowest 10% of frame energies as noise estimate) frame_length = int(0.025 * self.sr) # 25ms frames - frames = librosa.util.frame(self.audio, frame_length=frame_length, hop_length=frame_length) + frames = librosa.util.frame( + self.audio, frame_length=frame_length, hop_length=frame_length + ) frame_energies = np.mean(frames**2, axis=0) noise_power = np.mean(np.percentile(frame_energies, 10)) - + # Calculate SNR in dB if noise_power == 0: snr = 100 # High SNR for very clean signal else: snr = 10 * np.log10(signal_power / noise_power) - + # Convert SNR to penalty score (higher SNR = lower penalty) - return 1 - max(0, 1 
- (snr / 50)) # Normalize to 0-1 range, assuming 50dB as reference - + return 1 - max( + 0, 1 - (snr / 50) + ) # Normalize to 0-1 range, assuming 50dB as reference + def unique_speakers_error(self, speakers): unique_speakers = len(set(speakers)) if unique_speakers == 2: @@ -98,7 +105,7 @@ def unique_speakers_error(self, speakers): elif unique_speakers == 1 or unique_speakers == 0 or unique_speakers > 4: return 0 else: - return 1/(unique_speakers-1) + return 1 / (unique_speakers - 1) def total_score(self, audio_arr, sr, timestamps_start, timestamps_end, speakers): audio_arr = np.array(audio_arr) @@ -106,54 +113,55 @@ def total_score(self, audio_arr, sr, timestamps_start, timestamps_end, speakers) timestamps_end = np.array(timestamps_end) # speakers = torch.tensor(speakers) speech_content_score = self.speech_content_score(audio_arr, sr) - speaker_dominance_score = self.speaker_dominance_score(timestamps_start, timestamps_end, speakers) + speaker_dominance_score = self.speaker_dominance_score( + timestamps_start, timestamps_end, speakers + ) background_noise_score = self.background_noise_score(audio_arr, sr) return { - "speech_content_score": speech_content_score, - "speaker_dominance_score": speaker_dominance_score, + "speech_content_score": speech_content_score, + "speaker_dominance_score": speaker_dominance_score, "background_noise_score": background_noise_score, "unique_speakers_error": self.unique_speakers_error(speakers), } -if __name__ == "__main__": - +if __name__ == "__main__": from datasets import load_dataset import huggingface_hub - repo_id = "diarizers-community/voxconverse" - ds = load_dataset(repo_id, split="test", cache_dir="/workspace/tezuesh/voxconverse/data_cache") + ds = load_dataset( + repo_id, split="test", cache_dir="/workspace/tezuesh/voxconverse/data_cache" + ) ds = next(ds.shuffle().iter(batch_size=64)) - audio_arr = ds['audio'][0]['array'] - sr = ds['audio'][0]['sampling_rate'] - timestamps_start = ds['timestamps_start'][0] - timestamps_end = ds['timestamps_end'][0] - speakers = ds['speakers'][0] - + audio_arr = ds["audio"][0]["array"] + sr = ds["audio"][0]["sampling_rate"] + timestamps_start = ds["timestamps_start"][0] + timestamps_end = ds["timestamps_end"][0] + speakers = ds["speakers"][0] # # Save test audio to WAV file import soundfile as sf - - output_audio_path = 'test_audio.wav' + + output_audio_path = "test_audio.wav" sf.write(output_audio_path, audio_arr, sr) print(f"Saved test audio to {output_audio_path}") # Create a DataFrame with timestamps and speakers import pandas as pd - - df = pd.DataFrame({ - 'start': timestamps_start, - 'end': timestamps_end, - 'speaker': speakers - }) - + + df = pd.DataFrame( + {"start": timestamps_start, "end": timestamps_end, "speaker": speakers} + ) + # Save to CSV file - output_path = 'speaker_timestamps.csv' + output_path = "speaker_timestamps.csv" df.to_csv(output_path, index=False) print(f"Saved speaker timestamps to {output_path}") audio_score = AudioScore() - - score = audio_score.total_score(audio_arr, sr, timestamps_start, timestamps_end, speakers) + + score = audio_score.total_score( + audio_arr, sr, timestamps_start, timestamps_end, speakers + ) print(score) diff --git a/omega/augment.py b/omega/augment.py index 5888f623..d1a26be7 100644 --- a/omega/augment.py +++ b/omega/augment.py @@ -21,7 +21,7 @@ def __call__(self, query: str) -> str: except Exception as e: print(f"Error augmenting query: {e}") return query - + def augment_query(self, query: str) -> str: raise NotImplementedError @@ -38,10 +38,20 @@ class 
LocalLLMAugment(AbstractAugment): def __init__(self, **kwargs): self.device = kwargs.get("device") if self.device == "cpu": - raise ValueError("Cannot run Local LLM on CPU. Please move to a GPU instance or restart miner with `--neuron.query_augment OpenAIAugment` to use the GPT-4 API for augmenting instead of a local LLM.") + raise ValueError( + "Cannot run Local LLM on CPU. Please move to a GPU instance or restart miner with `--neuron.query_augment OpenAIAugment` to use the GPT-4 API for augmenting instead of a local LLM." + ) model_name = "teknium/OpenHermes-2.5-Mistral-7B" - self.pipe = pipeline("text-generation", model=model_name, device=self.device, torch_dtype=torch.float16, pad_token_id=32000) - bt.logging.info(f"Running query augmentation with local LLM {model_name} (thanks Nous!)") + self.pipe = pipeline( + "text-generation", + model=model_name, + device=self.device, + torch_dtype=torch.float16, + pad_token_id=32000, + ) + bt.logging.info( + f"Running query augmentation with local LLM {model_name} (thanks Nous!)" + ) def augment_query(self, query: str) -> str: prompt = f"""<|im_start|>system @@ -50,7 +60,12 @@ def augment_query(self, query: str) -> str: {get_llm_prompt(query)}<|im_end|> <|im_start|>assistant Detailed query: """ - new_query = self.pipe(prompt, max_new_tokens=64)[0]["generated_text"][len(prompt):].strip().strip("\"").strip("'") + new_query = ( + self.pipe(prompt, max_new_tokens=64)[0]["generated_text"][len(prompt) :] + .strip() + .strip('"') + .strip("'") + ) return new_query @@ -62,14 +77,9 @@ def __init__(self, **kwargs): def augment_query(self, query: str) -> str: response = self.client.chat.completions.create( model="gpt-4-turbo-preview", - messages=[ - { - "role": "user", - "content": get_llm_prompt(query) - } - ], + messages=[{"role": "user", "content": get_llm_prompt(query)}], temperature=0.9, max_tokens=64, top_p=1, ) - return response.choices[0].message.content.strip("\"").strip("'") + return response.choices[0].message.content.strip('"').strip("'") diff --git a/omega/base/miner.py b/omega/base/miner.py index 43176f8c..406f935f 100644 --- a/omega/base/miner.py +++ b/omega/base/miner.py @@ -119,7 +119,6 @@ def run(self): while ( dt.datetime.now() - self.last_sync_check ).total_seconds() < self.sync_check_interval: - # Wait before checking again. time.sleep(1) diff --git a/omega/base/neuron.py b/omega/base/neuron.py index bdf9a850..00b94d8a 100644 --- a/omega/base/neuron.py +++ b/omega/base/neuron.py @@ -65,7 +65,7 @@ def __init__(self, config=None): self.config = self.config() self.config.merge(base_config) self.check_config(self.config) - + # Set up logging with the provided configuration. bt.logging.set_config(config=self.config.logging) @@ -82,12 +82,8 @@ def __init__(self, config=None): # The wallet holds the cryptographic key pairs for the miner. if self.config.mock: self.wallet = bt.MockWallet(config=self.config) - self.subtensor = MockSubtensor( - self.config.netuid, wallet=self.wallet - ) - self.metagraph = MockMetagraph( - self.config.netuid, subtensor=self.subtensor - ) + self.subtensor = MockSubtensor(self.config.netuid, wallet=self.wallet) + self.metagraph = MockMetagraph(self.config.netuid, subtensor=self.subtensor) else: self.wallet = bt.wallet(config=self.config) self.subtensor = bt.subtensor(config=self.config) @@ -101,9 +97,7 @@ def __init__(self, config=None): self.check_registered() # Each miner gets a unique identity (UID) in the network for differentiation. 
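        # The metagraph's hotkeys list is ordered by UID, so index() on our own
        # hotkey address below recovers this neuron's UID.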
- self.uid = self.metagraph.hotkeys.index(
-            self.wallet.hotkey.ss58_address
-        )
+        self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)
         bt.logging.info(
             f"Running neuron on subnet: {self.config.netuid} with uid {self.uid} using network: {self.subtensor.chain_endpoint}"
         )
@@ -117,8 +111,7 @@ def __init__(self, config=None):

     # ...
     @abstractmethod
-    def run(self):
-        ...
+    def run(self): ...

     def sync(self):
         """
@@ -128,7 +121,9 @@ def sync(self):
         try:
             self.check_registered()
         except Exception as e:
-            bt.logging.error(f"Error checking registration status: {e}. Continuing incase it is a temporary subtensor connection issue.")
+            bt.logging.error(
+                f"Error checking registration status: {e}. Continuing in case it is a temporary subtensor connection issue."
+            )

         if self.should_sync_metagraph():
             self.resync_metagraph()
@@ -173,10 +168,8 @@ def should_set_weights(self) -> bool:

         # Define appropriate logic for when set weights.
         return (
-            (self.block - self.metagraph.last_update[self.uid])
-            > self.config.neuron.epoch_length
-            and self.neuron_type != "MinerNeuron"
-        )
+            self.block - self.metagraph.last_update[self.uid]
+        ) > self.config.neuron.epoch_length and self.neuron_type != "MinerNeuron"

     def save_state(self):
         bt.logging.warning(
diff --git a/omega/base/validator.py b/omega/base/validator.py
index 37b21192..5981e96e 100644
--- a/omega/base/validator.py
+++ b/omega/base/validator.py
@@ -74,7 +74,7 @@ def __init__(self, config=None):
         self.audio_score_arr = torch.zeros(
             self.metagraph.n, dtype=torch.float32, device=self.device
         )
-
+
         # Serve axon to enable external connections.
         if not self.config.neuron.axon_off:
             self.serve_axon()
@@ -117,37 +117,38 @@ def serve_axon(self):
             pass

         except Exception as e:
-            bt.logging.error(
-                f"Failed to create Axon initialize with exception: {e}"
-            )
+            bt.logging.error(f"Failed to initialize Axon with exception: {e}")
             pass

     async def concurrent_forward(self):
         coroutines = [
-            self.forward()
-            for _ in range(self.config.neuron.num_concurrent_forwards)
+            self.forward() for _ in range(self.config.neuron.num_concurrent_forwards)
         ]
         await asyncio.gather(*coroutines)

     def is_git_latest(self) -> bool:
-        p = Popen(['git', 'rev-parse', 'HEAD'], stdout=PIPE, stderr=PIPE)
+        p = Popen(["git", "rev-parse", "HEAD"], stdout=PIPE, stderr=PIPE)
         out, err = p.communicate()
         if err:
             return False
         current_commit = out.decode().strip()
-        p = Popen(['git', 'ls-remote', 'origin', 'HEAD'], stdout=PIPE, stderr=PIPE)
+        p = Popen(["git", "ls-remote", "origin", "HEAD"], stdout=PIPE, stderr=PIPE)
         out, err = p.communicate()
         if err:
             return False
         latest_commit = out.decode().split()[0]
-        bt.logging.info(f'Current commit: {current_commit}, Latest commit: {latest_commit}')
+        bt.logging.info(
+            f"Current commit: {current_commit}, Latest commit: {latest_commit}"
+        )
         return current_commit == latest_commit

     def should_restart(self) -> bool:
         # Check if enough time has elapsed since the last update check, if not assume we are up to date.
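        # update_check_interval itself is set elsewhere (not in this hunk); the
        # effect is to rate-limit the comparatively expensive `git ls-remote` call
        # in is_git_latest above.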
- if (datetime.now() - self.last_update_check).seconds < self.update_check_interval: + if ( + datetime.now() - self.last_update_check + ).seconds < self.update_check_interval: return False - + self.last_update_check = datetime.now() return not self.is_git_latest() @@ -190,7 +191,7 @@ def run(self): break if self.config.neuron.auto_update and self.should_restart(): - bt.logging.info(f'Validator is out of date, quitting to restart.') + bt.logging.info(f"Validator is out of date, quitting to restart.") raise KeyboardInterrupt # Sync metagraph and potentially set weights. @@ -221,10 +222,14 @@ def run(self): if (dt.datetime.now() - self.load_focus_rewards_start) >= dt.timedelta( hours=1 ): - bt.logging.info("Reloading focus videos rewards percent after 1 hour.") + bt.logging.info( + "Reloading focus videos rewards percent after 1 hour." + ) self.FOCUS_REWARDS_PERCENT = self.load_focus_rewards_percent() self.AUDIO_REWARDS_PERCENT = AUDIO_REWARDS_PERCENT - self.YOUTUBE_REWARDS_PERCENT = 1.0 - self.FOCUS_REWARDS_PERCENT - self.AUDIO_REWARDS_PERCENT + self.YOUTUBE_REWARDS_PERCENT = ( + 1.0 - self.FOCUS_REWARDS_PERCENT - self.AUDIO_REWARDS_PERCENT + ) self.load_focus_rewards_start = dt.datetime.now() # If someone intentionally stops the validator, it'll safely terminate operations. @@ -236,9 +241,7 @@ def run(self): # In case of unforeseen errors, the validator will log the error and continue operations. except Exception as err: bt.logging.error("Error during validation", str(err)) - bt.logging.debug( - print_exception(type(err), err, err.__traceback__) - ) + bt.logging.debug(print_exception(type(err), err, err.__traceback__)) def run_in_background_thread(self): """ @@ -320,17 +323,30 @@ def set_weights(self): f"Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions." ) - self.scores, self.focus_scores, self.audio_score_arr = self.pad_tensors(self.scores, self.focus_scores, self.audio_score_arr) + self.scores, self.focus_scores, self.audio_score_arr = self.pad_tensors( + self.scores, self.focus_scores, self.audio_score_arr + ) - bt.logging.debug(f"Normalizing scores with YOUTUBE_REWARDS_PERCENT: {self.YOUTUBE_REWARDS_PERCENT}, FOCUS_REWARDS_PERCENT: {self.FOCUS_REWARDS_PERCENT}, AUDIO_REWARDS_PERCENT: {self.AUDIO_REWARDS_PERCENT}") + bt.logging.debug( + f"Normalizing scores with YOUTUBE_REWARDS_PERCENT: {self.YOUTUBE_REWARDS_PERCENT}, FOCUS_REWARDS_PERCENT: {self.FOCUS_REWARDS_PERCENT}, AUDIO_REWARDS_PERCENT: {self.AUDIO_REWARDS_PERCENT}" + ) # Calculate the average reward for each uid across non-zero values. # Replace any NaN values with 0. # Normalize the youtube rewards and scale by the percentage. - raw_weights_youtube = torch.nn.functional.normalize(self.scores, p=1, dim=0) * self.YOUTUBE_REWARDS_PERCENT + raw_weights_youtube = ( + torch.nn.functional.normalize(self.scores, p=1, dim=0) + * self.YOUTUBE_REWARDS_PERCENT + ) # Normalize the focus rewards and scale by the percentage. - raw_weights_focus = torch.nn.functional.normalize(self.focus_scores, p=1, dim=0) * self.FOCUS_REWARDS_PERCENT + raw_weights_focus = ( + torch.nn.functional.normalize(self.focus_scores, p=1, dim=0) + * self.FOCUS_REWARDS_PERCENT + ) # Normalize the audio rewards and scale by the percentage. 
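        # With p=1, normalize() divides each non-negative score pool by its sum, so
        # every pool totals 1.0 before scaling, and the combined raw_weights below
        # sum to YOUTUBE + FOCUS + AUDIO percent (i.e. 1.0) whenever each pool has
        # at least one nonzero score.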
- raw_weights_audio = torch.nn.functional.normalize(self.audio_score_arr, p=1, dim=0) * self.AUDIO_REWARDS_PERCENT + raw_weights_audio = ( + torch.nn.functional.normalize(self.audio_score_arr, p=1, dim=0) + * self.AUDIO_REWARDS_PERCENT + ) # Combine the youtube and focus rewards. raw_weights = raw_weights_youtube + raw_weights_focus + raw_weights_audio @@ -341,8 +357,10 @@ def set_weights(self): bt.logging.debug("raw_weights", raw_weights) bt.logging.debug("raw_weight_uids", self.metagraph.uids.to("cpu")) if raw_weights.shape[0] > self.metagraph.uids.shape[0]: - bt.logging.warning("More raw_weights than metagraph uids, truncating raw_weights.") - raw_weights = raw_weights[:self.metagraph.uids.shape[0]] + bt.logging.warning( + "More raw_weights than metagraph uids, truncating raw_weights." + ) + raw_weights = raw_weights[: self.metagraph.uids.shape[0]] # Process the raw weights to final_weights via subtensor limitations. try: ( @@ -358,7 +376,9 @@ def set_weights(self): bt.logging.debug("processed_weights", processed_weights) bt.logging.debug("processed_weight_uids", processed_weight_uids) except Exception as e: - bt.logging.error(f"Failed to process weights with exception: {e}, skipping set_weights this time") + bt.logging.error( + f"Failed to process weights with exception: {e}, skipping set_weights this time" + ) return # Convert to uint16 weights and uids. @@ -413,9 +433,7 @@ def resync_metagraph(self): # If so, we need to add new hotkeys and moving averages. if len(self.hotkeys) < len(self.metagraph.hotkeys): # Update the size of the moving average scores. - new_moving_average = torch.zeros((self.metagraph.n)).to( - self.device - ) + new_moving_average = torch.zeros((self.metagraph.n)).to(self.device) min_len = min(len(self.hotkeys), len(self.scores)) new_moving_average[:min_len] = self.scores[:min_len] self.scores = new_moving_average @@ -433,11 +451,15 @@ def update_scores(self, rewards: torch.FloatTensor, uids: List[int]): return if len(uids) == 0: - bt.logging.debug("self.update_scores: Miner UIDs list is empty, returning early") + bt.logging.debug( + "self.update_scores: Miner UIDs list is empty, returning early" + ) return if len(rewards) != len(uids): - bt.logging.exception("self.update_scores: Rewards are not the same size as UIDs list (THIS SHOULD NEVER HAPPEN!)") + bt.logging.exception( + "self.update_scores: Rewards are not the same size as UIDs list (THIS SHOULD NEVER HAPPEN!)" + ) return # Check if rewards contains NaN values. @@ -454,9 +476,11 @@ def update_scores(self, rewards: torch.FloatTensor, uids: List[int]): # Compute forward pass rewards, assumes uids are mutually exclusive. # shape: [ metagraph.n ] - scattered_rewards: torch.FloatTensor = self.scores.to(self.device).scatter( - 0, uids_tensor.to(self.device), rewards.to(self.device) - ).to(self.device) + scattered_rewards: torch.FloatTensor = ( + self.scores.to(self.device) + .scatter(0, uids_tensor.to(self.device), rewards.to(self.device)) + .to(self.device) + ) bt.logging.debug(f"Scattered rewards: {rewards}") # Update scores with rewards produced by this step. @@ -468,7 +492,7 @@ def update_scores(self, rewards: torch.FloatTensor, uids: List[int]): bt.logging.debug(f"Updated moving avg scores: {self.scores}") def update_focus_scores(self, rewards: torch.FloatTensor, uids: List[int]): - """ Unlike other update_*_scores functions, this function does not perform an exponential moving average. 
""" + """Unlike other update_*_scores functions, this function does not perform an exponential moving average.""" # Check if rewards contains NaN values. if torch.isnan(rewards).any(): bt.logging.warning(f"NaN values detected in rewards: {rewards}") @@ -483,9 +507,11 @@ def update_focus_scores(self, rewards: torch.FloatTensor, uids: List[int]): # Compute forward pass rewards, assumes uids are mutually exclusive. # shape: [ metagraph.n ] - self.focus_scores: torch.FloatTensor = self.focus_scores.to(self.device).scatter( - 0, uids_tensor.to(self.device), rewards.to(self.device) - ).to(self.device) + self.focus_scores: torch.FloatTensor = ( + self.focus_scores.to(self.device) + .scatter(0, uids_tensor.to(self.device), rewards.to(self.device)) + .to(self.device) + ) bt.logging.debug(f"Scattered rewards: {self.focus_scores}") def update_audio_scores(self, rewards: torch.FloatTensor, uids: List[int]): @@ -496,18 +522,20 @@ def update_audio_scores(self, rewards: torch.FloatTensor, uids: List[int]): bt.logging.warning(f"NaN values detected in rewards: {rewards}") # Replace any NaN values in rewards with 0. rewards = torch.nan_to_num(rewards, 0) - + # check if `uids` is already a tensor and clone it to avoid the warning. if isinstance(uids, torch.Tensor): uids_tensor = uids.clone().detach() else: uids_tensor = torch.tensor(uids).to(self.device) - + # compute forward pass rewards, assumes uids are mutually exclusive. # shape: [metagraph.n] - scattered_rewards: torch.FloatTensor = self.audio_score_arr.to(self.device).scatter( - 0, uids_tensor.to(self.device), rewards.to(self.device) - ).to(self.device) + scattered_rewards: torch.FloatTensor = ( + self.audio_score_arr.to(self.device) + .scatter(0, uids_tensor.to(self.device), rewards.to(self.device)) + .to(self.device) + ) bt.logging.debug(f"Scattered rewards: {rewards}") # update scores with rewards produced by this step. @@ -543,7 +571,9 @@ def load_state(self): return # Load the state of the validator from file. 
- state = torch.load(self.config.neuron.full_path + "/state.pt", map_location=self.device) + state = torch.load( + self.config.neuron.full_path + "/state.pt", map_location=self.device + ) self.step = state["step"] self.scores = state["scores"] if "focus_scores" in state: @@ -552,7 +582,7 @@ def load_state(self): state["focus_scores"] = torch.zeros( self.metagraph.n, dtype=torch.float32, device=self.device ) - + if "audio_score_arr" in state: self.audio_score_arr = state["audio_score_arr"] else: diff --git a/omega/constants.py b/omega/constants.py index b95d6079..e3d05ba1 100644 --- a/omega/constants.py +++ b/omega/constants.py @@ -37,10 +37,18 @@ SPEAKER_DOMINANCE_SCALING_FACTOR = 0.2 BACKGROUND_NOISE_SCALING_FACTOR = 0.1 UNIQUE_SPEAKERS_ERROR_SCALING_FACTOR = 0.5 -SPEECH_CONTENT_SCALING_FACTOR = 1.0 - BACKGROUND_NOISE_SCALING_FACTOR - \ - SPEAKER_DOMINANCE_SCALING_FACTOR - UNIQUE_SPEAKERS_ERROR_SCALING_FACTOR +SPEECH_CONTENT_SCALING_FACTOR = ( + 1.0 + - BACKGROUND_NOISE_SCALING_FACTOR + - SPEAKER_DOMINANCE_SCALING_FACTOR + - UNIQUE_SPEAKERS_ERROR_SCALING_FACTOR +) AUDIO_LENGTH_SCALING_FACTOR = 0.1 # max 1 AUDIO_QUALITY_SCALING_FACTOR = 0.2 # max 1 DIARIZATION_SCALING_FACTOR = 0.6 # max 1 -AUDIO_QUERY_RELEVANCE_SCALING_FACTOR = 1.0 - DIARIZATION_SCALING_FACTOR - \ - AUDIO_LENGTH_SCALING_FACTOR - AUDIO_QUALITY_SCALING_FACTOR +AUDIO_QUERY_RELEVANCE_SCALING_FACTOR = ( + 1.0 + - DIARIZATION_SCALING_FACTOR + - AUDIO_LENGTH_SCALING_FACTOR + - AUDIO_QUALITY_SCALING_FACTOR +) diff --git a/omega/diarization_metric.py b/omega/diarization_metric.py index c1a2b207..66032596 100644 --- a/omega/diarization_metric.py +++ b/omega/diarization_metric.py @@ -4,19 +4,17 @@ import numpy as np - - def calculate_diarization_metrics(audio_arr, sr, true_segments): """Calculate Diarization Error Rate (DER) and related metrics using pyannote metrics""" audio_arr = np.asarray(audio_arr).astype(np.float32) pred_segments = pipeline.process(audio_arr, sr) - + # Convert dictionary segments to pyannote Annotation format def segments_to_annotation(segments): annotation = Annotation() - for i in range(len(segments['start'])): - segment = Segment(segments['start'][i], segments['end'][i]) - annotation[segment] = segments['speakers'][i] + for i in range(len(segments["start"])): + segment = Segment(segments["start"][i], segments["end"][i]) + annotation[segment] = segments["speakers"][i] return annotation # Convert both predictions and ground truth @@ -27,23 +25,24 @@ def segments_to_annotation(segments): metric = DiarizationErrorRate(skip_overlap=True) der = metric(reference, hypothesis) # optimal_mapping = metric.optimal_mapping(reference, hypothesis) - + # Get detailed components components = metric(reference, hypothesis, detailed=True) - miss_rate = components['missed detection'] / components['total'] - false_alarm_rate = components['false alarm'] / components['total'] - speaker_error_rate = components['confusion'] / components['total'] + miss_rate = components["missed detection"] / components["total"] + false_alarm_rate = components["false alarm"] / components["total"] + speaker_error_rate = components["confusion"] / components["total"] return { "inverse_der": 1 - max(0, min(1, der)), "miss_rate": 1 - miss_rate, "false_alarm_rate": 1 - false_alarm_rate, - "speaker_error_rate": 1 - speaker_error_rate + "speaker_error_rate": 1 - speaker_error_rate, } diarization_model_id = "tezuesh/diarization" -overlap_detection_model_id = "tezuesh/overlapped-speech-detection" -pipeline = 
CustomDiarizationPipeline(overlap_detection_model_id=overlap_detection_model_id, - diarization_model_id=diarization_model_id) - +overlap_detection_model_id = "tezuesh/overlapped-speech-detection" +pipeline = CustomDiarizationPipeline( + overlap_detection_model_id=overlap_detection_model_id, + diarization_model_id=diarization_model_id, +) diff --git a/omega/diarization_pipeline.py b/omega/diarization_pipeline.py index 99b2e7ae..4bd174ba 100644 --- a/omega/diarization_pipeline.py +++ b/omega/diarization_pipeline.py @@ -2,7 +2,8 @@ import torch import torchaudio import numpy as np -if hasattr(np, 'nan'): + +if hasattr(np, "nan"): np.NaN = np.nan np.NAN = np.nan from pyannote.audio import Pipeline @@ -12,55 +13,73 @@ class CustomDiarizationPipeline: def __init__(self, overlap_detection_model_id, diarization_model_id, device="cuda"): self.device = torch.device(device) - self.overlapped_speech_detection_pipeline = Pipeline.from_pretrained(overlap_detection_model_id).to(self.device) - self.diarization_pipeline = Pipeline.from_pretrained(diarization_model_id).to(self.device) + self.overlapped_speech_detection_pipeline = Pipeline.from_pretrained( + overlap_detection_model_id + ).to(self.device) + self.diarization_pipeline = Pipeline.from_pretrained(diarization_model_id).to( + self.device + ) - def preprocess_audio(self, audio_arr, sr): waveform, sample_rate = torch.from_numpy(audio_arr), sr # Convert to mono if stereo if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) - + # Apply high-pass filter to remove low frequency noise - waveform = torchaudio.functional.highpass_biquad(waveform, sample_rate, cutoff_freq=100) - + waveform = torchaudio.functional.highpass_biquad( + waveform, sample_rate, cutoff_freq=100 + ) + # Apply noise reduction using spectral subtraction - spec = torch.stft(waveform[0], - n_fft=2048, - hop_length=512, - win_length=2048, - window=torch.hann_window(2048).to(waveform.device), - return_complex=True) - + spec = torch.stft( + waveform[0], + n_fft=2048, + hop_length=512, + win_length=2048, + window=torch.hann_window(2048).to(waveform.device), + return_complex=True, + ) + # Estimate noise from first few frames noise_estimate = torch.mean(torch.abs(spec[:, :50]), dim=1, keepdim=True) - + # Subtract noise estimate and apply soft thresholding spec_mag = torch.abs(spec) spec_phase = torch.angle(spec) - spec_mag = torch.maximum(spec_mag - 2 * noise_estimate, torch.zeros_like(spec_mag)) - + spec_mag = torch.maximum( + spec_mag - 2 * noise_estimate, torch.zeros_like(spec_mag) + ) + # Reconstruct signal spec = spec_mag * torch.exp(1j * spec_phase) - waveform = torch.istft(spec, - n_fft=2048, - hop_length=512, - win_length=2048, - window=torch.hann_window(2048).to(waveform.device)) + waveform = torch.istft( + spec, + n_fft=2048, + hop_length=512, + win_length=2048, + window=torch.hann_window(2048).to(waveform.device), + ) waveform = waveform.unsqueeze(0) - + # Normalize audio waveform = waveform / torch.max(torch.abs(waveform)) return waveform, sample_rate - + def detect_overlapping_speech_and_run_diarization(self, audio_arr, sr): # waveform, sample_rate = self.preprocess_audio(audio_arr, sr) - waveform, sample_rate = torch.from_numpy(audio_arr).unsqueeze(0).to(torch.float32), sr - - overlapping_segments = self.overlapped_speech_detection_pipeline({"waveform": waveform, "sample_rate": sample_rate}) - diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) + waveform, sample_rate = ( + 
torch.from_numpy(audio_arr).unsqueeze(0).to(torch.float32), + sr, + ) + + overlapping_segments = self.overlapped_speech_detection_pipeline( + {"waveform": waveform, "sample_rate": sample_rate} + ) + diarization = self.diarization_pipeline( + {"waveform": waveform, "sample_rate": sample_rate} + ) diar_segments = [] overlap_segments = [] @@ -71,7 +90,7 @@ def detect_overlapping_speech_and_run_diarization(self, audio_arr, sr): overlap_segments.append((speech.start, speech.end, None)) return overlap_segments, diar_segments - + def remove_overlapping_segments(self, overlap_segments, diar_segments): for overlap_segment in overlap_segments: overlap_start = overlap_segment[0] @@ -95,8 +114,6 @@ def remove_overlapping_segments(self, overlap_segments, diar_segments): diar_segments = [seg for seg in diar_segments if seg is not None] return diar_segments - - def write_segments_to_csv(self, segments, output_file, min_duration=0.5): """ Write the start, end, and duration times of diarization segments to a CSV file using pandas. @@ -115,15 +132,26 @@ def write_segments_to_csv(self, segments, output_file, min_duration=0.5): speaker = None duration = end - start if duration >= min_duration: - data.append({'Start': start, 'End': end, 'Duration': duration, 'Speaker': speaker}) + data.append( + { + "Start": start, + "End": end, + "Duration": duration, + "Speaker": speaker, + } + ) df = pd.DataFrame(data) df.to_csv(output_file, index=False) def filter_segments_by_duration(self, segments, min_duration=0.7): - return [segment for segment in segments if segment[1] - segment[0] >= min_duration] - - def generate_audio_patches(self, audio_arr, sr, segments, output_dir, min_duration=0.5): + return [ + segment for segment in segments if segment[1] - segment[0] >= min_duration + ] + + def generate_audio_patches( + self, audio_arr, sr, segments, output_dir, min_duration=0.5 + ): # Load the audio file using pydub audio, sr = self.preprocess_audio(audio_arr, sr) @@ -155,25 +183,23 @@ def generate_audio_patches(self, audio_arr, sr, segments, output_dir, min_durati audio_segment.export(output_path, format="wav") print(f"Audio patches generated and saved in {output_dir}") - + def segments_to_dict(self, segments): start_timestamps = [segment[0] for segment in segments] end_timestamps = [segment[1] for segment in segments] speakers = [segment[2] for segment in segments] - return { - "start": start_timestamps, - "end": end_timestamps, - "speakers": speakers - } - + return {"start": start_timestamps, "end": end_timestamps, "speakers": speakers} def process(self, audio_arr, sr, output_path=None): - overlapping_segments, diar_segments = self.detect_overlapping_speech_and_run_diarization(audio_arr, sr) - - filtered_overlapping_segments = self.filter_segments_by_duration(overlapping_segments) - diar_segments = self.remove_overlapping_segments(filtered_overlapping_segments, diar_segments) + overlapping_segments, diar_segments = ( + self.detect_overlapping_speech_and_run_diarization(audio_arr, sr) + ) + + filtered_overlapping_segments = self.filter_segments_by_duration( + overlapping_segments + ) + diar_segments = self.remove_overlapping_segments( + filtered_overlapping_segments, diar_segments + ) dataframe = self.segments_to_dict(diar_segments) return dataframe - - - diff --git a/omega/imagebind_wrapper.py b/omega/imagebind_wrapper.py index dccb1535..4e51f9c7 100644 --- a/omega/imagebind_wrapper.py +++ b/omega/imagebind_wrapper.py @@ -13,12 +13,17 @@ from omega import video_utils -BPE_PATH = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe", "bpe_simple_vocab_16e6.txt.gz") -V2_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".checkpoints", "videobind-v0.2.pth") +BPE_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe", "bpe_simple_vocab_16e6.txt.gz" +) +V2_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".checkpoints", "videobind-v0.2.pth" +) TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH) LENGTH_TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=1024) TOKEN_CHUNK_SIZE = 74 + class Embeddings(BaseModel): class Config: arbitrary_types_allowed = True @@ -52,7 +57,9 @@ def recursive_split(text, delimiters): result = [] current_segment = "" for part in parts: - candidate_segment = current_segment + (delimiter if current_segment else '') + part + candidate_segment = ( + current_segment + (delimiter if current_segment else "") + part + ) if fits_in_token_limit(candidate_segment): current_segment = candidate_segment else: @@ -73,12 +80,10 @@ def split_by_tokens(text): tokens = tokenizer(text) tokens = tokens[tokens != 0][1:-1].tolist() chunks = np.array_split(tokens, int(len(tokens) / max_tokens) or 1) - return [ - tokenizer.decode(segment_tokens) - for segment_tokens in chunks - ] + return [tokenizer.decode(segment_tokens) for segment_tokens in chunks] + + return recursive_split(text, ["\n", ".", "!", "?", ",", " "]) - return recursive_split(text, ['\n', '.', '!', '?', ',', ' ']) def load_and_transform_text_chunks(text, device): if not text: @@ -91,6 +96,7 @@ def load_and_transform_text_chunks(text, device): for segment in split_text_by_token_limit(text, LENGTH_TOKENIZER) ] + def run_async(func, *args, **kwargs): loop = asyncio.get_event_loop() return loop.run_in_executor(None, functools.partial(func, *args, **kwargs)) @@ -116,9 +122,9 @@ def __init__(self, device="cuda:0", v2=False): def generate_text_embeddings(self, text: str): if not self.v2: - return self.imagebind({ - ModalityType.TEXT: load_and_transform_text([text], self.device) - })[ModalityType.TEXT] + return self.imagebind( + {ModalityType.TEXT: load_and_transform_text([text], self.device)} + )[ModalityType.TEXT] chunks = load_and_transform_text_chunks(text, self.device) embeddings = [ self.imagebind({ModalityType.TEXT: chunk})[ModalityType.TEXT] @@ -160,45 +166,56 @@ def embed(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddi description=text_embeddings, ) else: - return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION])) - return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO])) - return_value.description = torch.cat((return_value.description, text_embeddings)) + return_value.video = torch.cat( + (return_value.video, embeddings[ModalityType.VISION]) + ) + return_value.audio = torch.cat( + (return_value.audio, embeddings[ModalityType.AUDIO]) + ) + return_value.description = torch.cat( + (return_value.description, text_embeddings) + ) return return_value @torch.no_grad() def embed_only_video(self, video_files: List[BinaryIO]) -> Embeddings: video_filepaths = [video_file.name for video_file in video_files] durations = [video_utils.get_video_duration(f.name) for f in video_files] - embeddings = self.imagebind({ - ModalityType.VISION: [ - data.load_and_transform_video_data( - [video_filepaths[idx]], - self.device, - )[0] - for idx in range(len(video_filepaths)) - ] - }) + embeddings = self.imagebind( + { + ModalityType.VISION: [ + data.load_and_transform_video_data( + 
[video_filepaths[idx]], + self.device, + )[0] + for idx in range(len(video_filepaths)) + ] + } + ) return Embeddings( video=embeddings[ModalityType.VISION], ) @torch.no_grad() - def embed_video_and_text(self, video_files: List[BinaryIO], descriptions: List[str]) -> Embeddings: + def embed_video_and_text( + self, video_files: List[BinaryIO], descriptions: List[str] + ) -> Embeddings: video_filepaths = [video_file.name for video_file in video_files] durations = [video_utils.get_video_duration(f.name) for f in video_files] - embeddings = self.imagebind({ - ModalityType.VISION: [ - data.load_and_transform_video_data( - [video_filepaths[idx]], - self.device, - )[0] - for idx in range(len(video_filepaths)) - ], - }) - description_embeddings = torch.stack([ - self.generate_text_embeddings(description) - for description in descriptions - ]) + embeddings = self.imagebind( + { + ModalityType.VISION: [ + data.load_and_transform_video_data( + [video_filepaths[idx]], + self.device, + )[0] + for idx in range(len(video_filepaths)) + ], + } + ) + description_embeddings = torch.stack( + [self.generate_text_embeddings(description) for description in descriptions] + ) return Embeddings( video=embeddings[ModalityType.VISION], description=description_embeddings, @@ -216,12 +233,16 @@ def embed_text(self, texts: List[str]) -> torch.Tensor: return return_value @torch.no_grad() - async def embed_async(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddings: + async def embed_async( + self, descriptions: List[str], video_files: List[BinaryIO] + ) -> Embeddings: return_value = None for idx in range(len(descriptions)): inputs = self.get_inputs(video_files[idx]) # cannot be async embeddings = await run_async(self.imagebind, inputs) - text_embeddings = await run_async(self.generate_text_embeddings, descriptions[idx]) + text_embeddings = await run_async( + self.generate_text_embeddings, descriptions[idx] + ) if not return_value: return_value = Embeddings( video=embeddings[ModalityType.VISION], @@ -229,9 +250,15 @@ async def embed_async(self, descriptions: List[str], video_files: List[BinaryIO] description=text_embeddings, ) else: - return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION])) - return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO])) - return_value.description = torch.cat((return_value.description, text_embeddings)) + return_value.video = torch.cat( + (return_value.video, embeddings[ModalityType.VISION]) + ) + return_value.audio = torch.cat( + (return_value.audio, embeddings[ModalityType.AUDIO]) + ) + return_value.description = torch.cat( + (return_value.description, text_embeddings) + ) return return_value async def embed_text_async(self, texts: List[str]) -> torch.Tensor: diff --git a/omega/miner_utils.py b/omega/miner_utils.py index 657fc89b..fc31a2a1 100644 --- a/omega/miner_utils.py +++ b/omega/miner_utils.py @@ -8,12 +8,18 @@ from omega.protocol import VideoMetadata, AudioMetadata from omega.imagebind_wrapper import ImageBind -from omega.constants import MAX_VIDEO_LENGTH, FIVE_MINUTES, MAX_AUDIO_LENGTH_SECONDS, MIN_AUDIO_LENGTH_SECONDS +from omega.constants import ( + MAX_VIDEO_LENGTH, + FIVE_MINUTES, + MAX_AUDIO_LENGTH_SECONDS, + MIN_AUDIO_LENGTH_SECONDS, +) from omega import video_utils from omega.diarization_pipeline import CustomDiarizationPipeline if os.getenv("OPENAI_API_KEY"): from openai import OpenAI + OPENAI_CLIENT = OpenAI() else: OPENAI_CLIENT = None @@ -22,7 +28,7 @@ def get_description(yt: video_utils.YoutubeDL, 
video_path: str) -> str: """ Get / generate the description of a video from the YouTube API. - + Miner TODO: Implement logic to get / generate the most relevant and information-rich description of a video from the YouTube API. """ @@ -32,7 +38,9 @@ def get_description(yt: video_utils.YoutubeDL, video_path: str) -> str: return description -def get_relevant_timestamps(query: str, yt: video_utils.YoutubeDL, video_path: str, max_length: int) -> Tuple[int, int]: +def get_relevant_timestamps( + query: str, yt: video_utils.YoutubeDL, video_path: str, max_length: int +) -> Tuple[int, int]: """ Get the optimal start and end timestamps (in seconds) of a video for ensuring relevance to the query. @@ -45,7 +53,9 @@ def get_relevant_timestamps(query: str, yt: video_utils.YoutubeDL, video_path: s return start_time, end_time -def search_and_embed_youtube_videos(query: str, num_videos: int, imagebind: ImageBind) -> List[VideoMetadata]: +def search_and_embed_youtube_videos( + query: str, num_videos: int, imagebind: ImageBind +) -> List[VideoMetadata]: """ Search YouTube for videos matching the given query and return a list of VideoMetadata objects. @@ -66,28 +76,38 @@ def search_and_embed_youtube_videos(query: str, num_videos: int, imagebind: Imag download_path = video_utils.download_youtube_video( result.video_id, start=0, - end=min(result.length, FIVE_MINUTES) # download the first 5 minutes at most + end=min( + result.length, FIVE_MINUTES + ), # download the first 5 minutes at most ) if download_path: clip_path = None try: - result.length = video_utils.get_video_duration(download_path.name) # correct the length - bt.logging.info(f"Downloaded video {result.video_id} ({min(result.length, FIVE_MINUTES)}) in {time.time() - start} seconds") - start, end = get_relevant_timestamps(query, result, download_path, max_length=MAX_VIDEO_LENGTH) + result.length = video_utils.get_video_duration( + download_path.name + ) # correct the length + bt.logging.info( + f"Downloaded video {result.video_id} ({min(result.length, FIVE_MINUTES)}) in {time.time() - start} seconds" + ) + start, end = get_relevant_timestamps( + query, result, download_path, max_length=MAX_VIDEO_LENGTH + ) description = get_description(result, download_path) clip_path = video_utils.clip_video(download_path.name, start, end) bt.logging.info(f"Clip video path: {clip_path}") embeddings = imagebind.embed([description], [clip_path]) - video_metas.append(VideoMetadata( - video_id=result.video_id, - description=description, - views=result.views, - start_time=start, - end_time=end, - video_emb=embeddings.video[0].tolist(), - audio_emb=embeddings.audio[0].tolist(), - description_emb=embeddings.description[0].tolist(), - )) + video_metas.append( + VideoMetadata( + video_id=result.video_id, + description=description, + views=result.views, + start_time=start, + end_time=end, + video_emb=embeddings.video[0].tolist(), + audio_emb=embeddings.audio[0].tolist(), + description_emb=embeddings.description[0].tolist(), + ) + ) finally: download_path.close() if clip_path: @@ -101,9 +121,12 @@ def search_and_embed_youtube_videos(query: str, num_videos: int, imagebind: Imag return video_metas - - -def search_and_diarize_youtube_videos(query: str, num_videos: int, diarization_pipeline: CustomDiarizationPipeline, imagebind: ImageBind) -> List[AudioMetadata]: +def search_and_diarize_youtube_videos( + query: str, + num_videos: int, + diarization_pipeline: CustomDiarizationPipeline, + imagebind: ImageBind, +) -> List[AudioMetadata]: """ Search YouTube for videos matching the given 
query and return a list of AudioMetadata objects. @@ -124,14 +147,25 @@ def search_and_diarize_youtube_videos(query: str, num_videos: int, diarization_p download_path = video_utils.download_youtube_video( result.video_id, start=0, - end=min(result.length, MAX_AUDIO_LENGTH_SECONDS) # download the first 5 minutes at most + end=min( + result.length, MAX_AUDIO_LENGTH_SECONDS + ), # download the first 5 minutes at most ) if download_path: clip_path = None try: - result.length = video_utils.get_video_duration(download_path.name) # correct the length - bt.logging.info(f"Downloaded audio {result.video_id} ({min(result.length, MAX_AUDIO_LENGTH_SECONDS)}) in {time.time() - start_time_loop} seconds") - start, end = get_relevant_timestamps(query, result, download_path, max_length=MAX_AUDIO_LENGTH_SECONDS) + result.length = video_utils.get_video_duration( + download_path.name + ) # correct the length + bt.logging.info( + f"Downloaded audio {result.video_id} ({min(result.length, MAX_AUDIO_LENGTH_SECONDS)}) in {time.time() - start_time_loop} seconds" + ) + start, end = get_relevant_timestamps( + query, + result, + download_path, + max_length=MAX_AUDIO_LENGTH_SECONDS, + ) # bt.logging.info(f"Audio Start: {start}, End: {end}") description = get_description(result, download_path) audio_bytes = video_utils.get_audio_bytes(download_path.name) @@ -143,24 +177,31 @@ def search_and_diarize_youtube_videos(query: str, num_videos: int, diarization_p clip_path = video_utils.clip_video(download_path.name, start, end) bt.logging.info(f"Clip video path: {clip_path}") embeddings = imagebind.embed([description], [clip_path]) - bt.logging.info(f"Embeddings: {type(embeddings)}, audio_emb: {type(embeddings.audio[0])}, audio_array: {type(audio_array)} {audio_array.shape}, audio_bytes: {type(audio_bytes)}, sr: {sr}, diar_timestamps_start: {type(diar_timestamps_start)}, diar_timestamps_end: {type(diar_timestamps_end)}, diar_speakers: {type(diar_speakers)}") - bt.logging.info(f"Audio duration: {end - start}, actual length: {result.length}") + bt.logging.info( + f"Embeddings: {type(embeddings)}, audio_emb: {type(embeddings.audio[0])}, audio_array: {type(audio_array)} {audio_array.shape}, audio_bytes: {type(audio_bytes)}, sr: {sr}, diar_timestamps_start: {type(diar_timestamps_start)}, diar_timestamps_end: {type(diar_timestamps_end)}, diar_speakers: {type(diar_speakers)}" + ) + bt.logging.info( + f"Audio duration: {end - start}, actual length: {result.length}" + ) bt.logging.info("Diarization Dataframe: ", dataframe) # Convert audio_bytes to base64 string for serialization import base64 - audio_bytes_b64 = base64.b64encode(audio_bytes).decode('utf-8') - - audio_metas.append(AudioMetadata( - video_id=result.video_id, - views=result.views, - start_time=start, - end_time=end, - audio_emb=embeddings.audio[0].tolist(), - audio_bytes=audio_bytes_b64, # Store base64 encoded string instead of raw bytes - diar_timestamps_start=diar_timestamps_start, - diar_timestamps_end=diar_timestamps_end, - diar_speakers=diar_speakers, - )) + + audio_bytes_b64 = base64.b64encode(audio_bytes).decode("utf-8") + + audio_metas.append( + AudioMetadata( + video_id=result.video_id, + views=result.views, + start_time=start, + end_time=end, + audio_emb=embeddings.audio[0].tolist(), + audio_bytes=audio_bytes_b64, # Store base64 encoded string instead of raw bytes + diar_timestamps_start=diar_timestamps_start, + diar_timestamps_end=diar_timestamps_end, + diar_speakers=diar_speakers, + ) + ) finally: download_path.close() if clip_path: @@ -168,10 +209,11 @@ def 
search_and_diarize_youtube_videos(query: str, num_videos: int, diarization_p if len(audio_metas) == num_videos: break end_time_loop = time.time() - bt.logging.info(f"Audio Time taken for loop: {end_time_loop - start_time_loop}") + bt.logging.info( + f"Audio Time taken for loop: {end_time_loop - start_time_loop}" + ) except Exception as e: bt.logging.error(f"Error searching for videos: {e}") return audio_metas - diff --git a/omega/mock.py b/omega/mock.py index 69eb78df..2b027ffb 100644 --- a/omega/mock.py +++ b/omega/mock.py @@ -37,9 +37,7 @@ def __init__(self, netuid, n=16, wallet=None, network="mock"): class MockMetagraph(bt.metagraph): def __init__(self, netuid=1, network="mock", subtensor=None): - super().__init__( - netuid=netuid, network=network, sync=False - ) + super().__init__(netuid=netuid, network=network, sync=False) if subtensor is not None: self.subtensor = subtensor @@ -57,6 +55,7 @@ class MockDendrite(bt.dendrite): """ Replaces a real bittensor network request with a mock request that just returns some static response for all axons that are passed and adds some random delay. """ + def __init__(self, wallet): super().__init__(wallet) @@ -69,7 +68,6 @@ async def forward( run_async: bool = True, streaming: bool = False, ): - if streaming: raise NotImplementedError("Streaming not implemented yet.") @@ -106,7 +104,10 @@ async def single_axon_response(i, axon): return s return await asyncio.gather( - *(single_axon_response(i, target_axon) for i, target_axon in enumerate(axons)) + *( + single_axon_response(i, target_axon) + for i, target_axon in enumerate(axons) + ) ) return await query_all_axons(streaming) @@ -118,4 +119,4 @@ def __str__(self) -> str: Returns: str: The string representation of the Dendrite object in the format "dendrite()". """ - return "MockDendrite({})".format(self.keypair.ss58_address) \ No newline at end of file + return "MockDendrite({})".format(self.keypair.ss58_address) diff --git a/omega/protocol.py b/omega/protocol.py index 26f74f82..bfd060f7 100644 --- a/omega/protocol.py +++ b/omega/protocol.py @@ -27,6 +27,7 @@ class VideoMetadata(BaseModel): """ A model class representing YouTube video metadata. """ + video_id: str description: str views: int @@ -38,11 +39,10 @@ class VideoMetadata(BaseModel): def __repr_args__(self): parent_args = super().__repr_args__() - exclude_args = ['video_emb', 'audio_emb', 'description_emb'] - return ( - [(a, v) for a, v in parent_args if a not in exclude_args] + - [(a, ["..."]) for a in exclude_args] - ) + exclude_args = ["video_emb", "audio_emb", "description_emb"] + return [(a, v) for a, v in parent_args if a not in exclude_args] + [ + (a, ["..."]) for a in exclude_args + ] class Videos(bt.Synapse): @@ -70,7 +70,8 @@ def to_serializable_dict(self, input_synapse: "Videos") -> dict: response (self). 
""" json_str = self.replace_with_input(input_synapse).json( - include={"query", "num_videos", "video_metadata"}) + include={"query", "num_videos", "video_metadata"} + ) return json.loads(json_str) def replace_with_input(self, input_synapse: "Videos") -> "Videos": @@ -80,13 +81,11 @@ def replace_with_input(self, input_synapse: "Videos") -> "Videos": return Videos( query=input_synapse.query, num_videos=input_synapse.num_videos, - video_metadata=self.video_metadata[:input_synapse.num_videos], - axon=self.axon + video_metadata=self.video_metadata[: input_synapse.num_videos], + axon=self.axon, ) - - class AudioMetadata(BaseModel): video_id: str views: int @@ -100,12 +99,17 @@ class AudioMetadata(BaseModel): def __repr_args__(self): parent_args = super().__repr_args__() - exclude_args = ['audio_emb', 'audio_bytes', 'diar_timestamps_start', 'diar_timestamps_end', 'diar_speakers'] - return ( - [(a, v) for a, v in parent_args if a not in exclude_args] + - [(a, ["..."]) for a in exclude_args] - ) - + exclude_args = [ + "audio_emb", + "audio_bytes", + "diar_timestamps_start", + "diar_timestamps_end", + "diar_speakers", + ] + return [(a, v) for a, v in parent_args if a not in exclude_args] + [ + (a, ["..."]) for a in exclude_args + ] + class Audios(bt.Synapse): """ @@ -132,7 +136,8 @@ def to_serializable_dict(self, input_synapse: "Audios") -> dict: response (self). """ json_str = self.replace_with_input(input_synapse).json( - include={"query", "num_audios", "audio_metadata"}) + include={"query", "num_audios", "audio_metadata"} + ) return json.loads(json_str) def replace_with_input(self, input_synapse: "Audios") -> "Audios": @@ -143,5 +148,5 @@ def replace_with_input(self, input_synapse: "Audios") -> "Audios": query=input_synapse.query, num_audios=input_synapse.num_audios, audio_metadata=self.audio_metadata, - axon=self.axon + axon=self.axon, ) diff --git a/omega/test_audio.py b/omega/test_audio.py index d531b125..e88f5e56 100644 --- a/omega/test_audio.py +++ b/omega/test_audio.py @@ -1,14 +1,15 @@ from omega.video_utils import get_audio_bytes import base64 + audio_bytes = get_audio_bytes("test_video.mp4") print(audio_bytes) # Save audio bytes to a WAV file -with open('output_audio.wav', 'wb') as f: +with open("output_audio.wav", "wb") as f: f.write(audio_bytes) -audio_bytes_b64 = base64.b64encode(audio_bytes).decode('utf-8') +audio_bytes_b64 = base64.b64encode(audio_bytes).decode("utf-8") print(audio_bytes_b64) # Save base64 encoded audio to file -with open('output_audio_b64.txt', 'w') as f: +with open("output_audio_b64.txt", "w") as f: f.write(audio_bytes_b64) diff --git a/omega/text_similarity.py b/omega/text_similarity.py index 645ac1bb..2d946ba3 100644 --- a/omega/text_similarity.py +++ b/omega/text_similarity.py @@ -6,11 +6,20 @@ revision = "104333d6af6f97649377c2afbde10a7704870c7b" TOKENIZER = AutoTokenizer.from_pretrained(model_path, revision=revision) DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" -MODEL = AutoModel.from_pretrained(model_path, trust_remote_code=True, revision=revision).to(DEVICE) +MODEL = AutoModel.from_pretrained( + model_path, trust_remote_code=True, revision=revision +).to(DEVICE) MODEL.eval() + def get_text_similarity_score(text_0, text_1): - tokens = TOKENIZER([text_0, text_1], max_length=1024, padding=True, truncation=True, return_tensors='pt').to(DEVICE) + tokens = TOKENIZER( + [text_0, text_1], + max_length=1024, + padding=True, + truncation=True, + return_tensors="pt", + ).to(DEVICE) outputs = MODEL(**tokens) embeddings = outputs.last_hidden_state[:, 0] 
embeddings = F.normalize(embeddings, p=2, dim=1) diff --git a/omega/unstuff.py b/omega/unstuff.py index b840ec0b..a3d67dec 100644 --- a/omega/unstuff.py +++ b/omega/unstuff.py @@ -10,10 +10,16 @@ BPE_PATH, split_text_by_token_limit, ) + CHUNK_SIZE = 60 TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=10000) -UNSTUFF = pipeline("text-classification", "jondurbin/unstuffer-v0.2", device="cuda" if torch.cuda.is_available() else "cpu") +UNSTUFF = pipeline( + "text-classification", + "jondurbin/unstuffer-v0.2", + device="cuda" if torch.cuda.is_available() else "cpu", +) + def is_stuffed(description: str) -> Tuple[bool, float]: result = UNSTUFF(description, truncation=True, max_length=512) @@ -22,9 +28,12 @@ def is_stuffed(description: str) -> Tuple[bool, float]: if stuffed and confidence > 0.75: print(f"Detected stuffed description [{confidence=}]: {description}") elif not stuffed and random.random() <= 0.01: - print(f"Description does not appear to be stuffed [{confidence=}]: {description}") + print( + f"Description does not appear to be stuffed [{confidence=}]: {description}" + ) return stuffed, confidence + def check_extraneous_chunks(description, video_emb, audio_emb, imagebind): bt.logging.info(f"Length of description: {len(description)}") bt.logging.info(f"Length of video_emb: {len(video_emb)}") diff --git a/omega/utils/config.py b/omega/utils/config.py index c6c8cb28..c6646e90 100644 --- a/omega/utils/config.py +++ b/omega/utils/config.py @@ -39,6 +39,7 @@ def is_cuda_available(): pass return "cpu" + def check_config(cls, config: "bt.Config"): r"""Checks/validates the config namespace object.""" bt.logging.check_config(config) @@ -173,7 +174,7 @@ def add_miner_args(cls, parser): help="If set, miners will accept queries from non registered entities. (Dangerous!)", default=False, ) - + parser.add_argument( "--blacklist.validator_min_stake", help="Minimum stake a validator must have to allow queries", @@ -276,16 +277,19 @@ def add_validator_args(cls, parser): "--topics_url", type=str, help="URL to fetch topics from.", - default="https://docs.google.com/spreadsheets/d/e/2PACX-1vR3jKfd4qkxXt5rTvXTTSsz_RYGkxcxh6-jvB9H0Mljiz-nai7xG-E63qEQ9jQhQabBrIAeJWtgKg5j/pub?gid=0&single=true&output=csv" + default="https://docs.google.com/spreadsheets/d/e/2PACX-1vR3jKfd4qkxXt5rTvXTTSsz_RYGkxcxh6-jvB9H0Mljiz-nai7xG-E63qEQ9jQhQabBrIAeJWtgKg5j/pub?gid=0&single=true&output=csv", ) parser.add_argument( "--topics_path", type=str, help="Path to text file containing a list of random topics to collect data for.", - default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "topics.txt") + default=os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "..", "topics.txt" + ), ) + def config(cls): """ Returns the configuration object specific to this miner or validator after adding relevant arguments. 
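[Editor's note: the omega/unstuff.py hunk above loads a HuggingFace text-classification pipeline and exposes is_stuffed(), which returns a (stuffed, confidence) tuple. A minimal usage sketch follows; it is not part of this patch, and the zero-out policy is an assumption — only the 0.75 confidence cutoff comes from is_stuffed() itself.]

```python
# Hypothetical caller for omega/unstuff.py (not part of this patch).
# Assumes the omega package and its model dependencies are installed;
# importing omega.unstuff loads the classifier at import time.
from omega.unstuff import is_stuffed


def score_description(description: str, base_score: float) -> float:
    stuffed, confidence = is_stuffed(description)
    if stuffed and confidence > 0.75:  # same cutoff is_stuffed() logs at
        return 0.0  # assumed policy: reject keyword-stuffed descriptions
    return base_score
```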
diff --git a/omega/utils/logging.py b/omega/utils/logging.py index 7834ba10..422cb661 100644 --- a/omega/utils/logging.py +++ b/omega/utils/logging.py @@ -32,4 +32,4 @@ def event(self, message, *args, **kws): file_handler.setLevel(EVENTS_LEVEL_NUM) logger.addHandler(file_handler) - return logger \ No newline at end of file + return logger diff --git a/omega/utils/uids.py b/omega/utils/uids.py index fc8c4fe8..6d7248c9 100644 --- a/omega/utils/uids.py +++ b/omega/utils/uids.py @@ -26,9 +26,7 @@ def check_uid_availability( return True -def get_random_uids( - self, k: int, exclude: List[int] = None -) -> torch.LongTensor: +def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.LongTensor: """Returns k available random uids from the metagraph. Args: k (int): Number of uids to return. @@ -60,8 +58,7 @@ def get_random_uids( new_avail_uids, min(len(new_avail_uids), k - len(candidate_uids)), ) - uids = torch.tensor(random.sample( - available_uids, - min(k, len(available_uids)) - )).to(self.device) + uids = torch.tensor(random.sample(available_uids, min(k, len(available_uids)))).to( + self.device + ) return uids diff --git a/omega/video_utils.py b/omega/video_utils.py index d03e06d0..d17b900a 100644 --- a/omega/video_utils.py +++ b/omega/video_utils.py @@ -1,4 +1,4 @@ -import re +import re import json import os import tempfile @@ -24,9 +24,10 @@ def seconds_to_str(seconds): def clip_video(video_path: str, start: int, end: int) -> Optional[BinaryIO]: temp_fileobj = tempfile.NamedTemporaryFile(suffix=".mp4") ( - ffmpeg - .input(video_path, ss=seconds_to_str(start), to=seconds_to_str(end)) - .output(temp_fileobj.name, c="copy") # copy flag prevents decoding and re-encoding + ffmpeg.input(video_path, ss=seconds_to_str(start), to=seconds_to_str(end)) + .output( + temp_fileobj.name, c="copy" + ) # copy flag prevents decoding and re-encoding .overwrite_output() .run(quiet=True) ) @@ -35,7 +36,7 @@ def clip_video(video_path: str, start: int, end: int) -> Optional[BinaryIO]: def skip_live(info_dict): """ - function to skip downloading if it's a live video (yt_dlp doesn't respect the 20 minute + function to skip downloading if it's a live video (yt_dlp doesn't respect the 20 minute download limit for live videos), and we don't want to hang on an hour long stream """ if info_dict.get("is_live"): @@ -71,9 +72,16 @@ def search_videos(query, max_results=8): video_id=entry["id"], title=entry["title"], description=entry.get("description"), - length=(int(entry.get("duration")) if entry.get("duration") else FIVE_MINUTES), - views=(entry.get("view_count") if entry.get("view_count") else 0), - ) for entry in result["entries"] + length=( + int(entry.get("duration")) + if entry.get("duration") + else FIVE_MINUTES + ), + views=( + entry.get("view_count") if entry.get("view_count") else 0 + ), + ) + for entry in result["entries"] ] except Exception as e: bt.logging.warning(f"Error searching for videos: {e}") @@ -83,8 +91,11 @@ def search_videos(query, max_results=8): def get_video_duration(filename: str) -> int: metadata = ffmpeg.probe(filename) - video_stream = next((stream for stream in metadata['streams'] if stream['codec_type'] == 'video'), None) - duration = int(float(video_stream['duration'])) + video_stream = next( + (stream for stream in metadata["streams"] if stream["codec_type"] == "video"), + None, + ) + duration = int(float(video_stream["duration"])) return duration @@ -101,16 +112,20 @@ def __init__(self, message: str): def is_valid_youtube_id(youtube_id: str) -> bool: return youtube_id is not 
None and len(youtube_id) == 11 + def download_youtube_video( - video_id: str, start: Optional[int]=None, end: Optional[int]=None, proxy: Optional[str]=None + video_id: str, + start: Optional[int] = None, + end: Optional[int] = None, + proxy: Optional[str] = None, ) -> Optional[BinaryIO]: if not is_valid_youtube_id(video_id): raise FakeVideoException(f"Invalid Youtube video ID: {video_id}") video_url = f"https://www.youtube.com/watch?v={video_id}" - + temp_fileobj = tempfile.NamedTemporaryFile(suffix=".mp4") - + ydl_opts = { "format": "worst", # Download the worst quality "outtmpl": temp_fileobj.name, # Set the output template to the temporary file"s name @@ -121,7 +136,9 @@ def download_youtube_video( } if start is not None and end is not None: - ydl_opts["download_ranges"] = lambda _, __: [{"start_time": start, "end_time": end}] + ydl_opts["download_ranges"] = lambda _, __: [ + {"start_time": start, "end_time": end} + ] if proxy is not None: ydl_opts["proxy"] = proxy @@ -139,31 +156,51 @@ def download_youtube_video( return temp_fileobj except Exception as e: temp_fileobj.close() - if ( - "Your IP is likely being blocked by Youtube" in str(e) or - "Requested format is not available" in str(e) - ): + if "Your IP is likely being blocked by Youtube" in str( + e + ) or "Requested format is not available" in str(e): raise IPBlockedException(e) # Quick check to see if miner passed an "unplayable" (sign-in required, paid video, etc.). fake_video = False try: result = requests.get(video_url, proxies={"https": proxy}) - json_match = re.search(r"ytInitialPlayerResponse\s*=\s*(\{(?:.*?)\})\s*;\s*<", result.text) + json_match = re.search( + r"ytInitialPlayerResponse\s*=\s*(\{(?:.*?)\})\s*;\s*<", result.text + ) if json_match: player_info = json.loads(json_match.group(1)) - status = player_info.get('playabilityStatus', {}).get('status', 'ok') - unacceptable_statuses = ('UNPLAYABLE',) - if status in unacceptable_statuses or (status == 'ERROR' and player_info['playabilityStatus'].get('reason', '').lower() == 'video unavailable'): + status = player_info.get("playabilityStatus", {}).get("status", "ok") + unacceptable_statuses = ("UNPLAYABLE",) + if status in unacceptable_statuses or ( + status == "ERROR" + and player_info["playabilityStatus"].get("reason", "").lower() + == "video unavailable" + ): if "sign in to confirm you’re not a bot" not in result.text.lower(): - if player_info['playabilityStatus']['errorScreen']['playerErrorMessageRenderer']['subreason']['simpleText'] != "This content isn’t available.": + if ( + player_info["playabilityStatus"]["errorScreen"][ + "playerErrorMessageRenderer" + ]["subreason"]["simpleText"] + != "This content isn’t available." + ): fake_video = True - print(f"Fake video submitted, youtube player status [{status}]: {player_info['playabilityStatus']}") + print( + f"Fake video submitted, youtube player status [{status}]: {player_info['playabilityStatus']}" + ) except Exception as fake_check_exc: print(f"Error sanity checking playability: {fake_check_exc}") if fake_video: raise FakeVideoException("Unplayable video provided") - if any(fake_vid_msg in str(e) for fake_vid_msg in ["Video unavailable", "is not a valid URL", "Incomplete YouTube ID", "Unsupported URL"]): + if any( + fake_vid_msg in str(e) + for fake_vid_msg in [ + "Video unavailable", + "is not a valid URL", + "Incomplete YouTube ID", + "Unsupported URL", + ] + ): if "Video unavailable. This content isn’t available." 
not in str(e): raise FakeVideoException(e) print(f"Error downloading video: {e}") @@ -173,14 +210,14 @@ def download_youtube_video( def copy_audio(video_path: str) -> BinaryIO: temp_audiofile = tempfile.NamedTemporaryFile(suffix=".aac") ( - ffmpeg - .input(video_path) - .output(temp_audiofile.name, vn=None, acodec='copy') + ffmpeg.input(video_path) + .output(temp_audiofile.name, vn=None, acodec="copy") .overwrite_output() .run(quiet=True) ) return temp_audiofile + def copy_audio_wav(video_path: str) -> BinaryIO: """ Extract audio from video file to 16-bit PCM WAV format. @@ -194,14 +231,13 @@ def copy_audio_wav(video_path: str) -> BinaryIO: temp_audiofile = tempfile.NamedTemporaryFile(suffix=".wav") ( - ffmpeg - .input(video_path) + ffmpeg.input(video_path) .output( temp_audiofile.name, - acodec='pcm_s16le', # 16-bit PCM - ac=1, # mono - ar=16000, # 16kHz sample rate - vn=None # no video + acodec="pcm_s16le", # 16-bit PCM + ac=1, # mono + ar=16000, # 16kHz sample rate + vn=None, # no video ) .overwrite_output() .run(quiet=True) @@ -209,9 +245,10 @@ def copy_audio_wav(video_path: str) -> BinaryIO: return temp_audiofile + def get_audio_bytes(video_path: str) -> bytes: audio_file = copy_audio_wav(video_path) - with open(audio_file.name, 'rb') as f: + with open(audio_file.name, "rb") as f: wav_bytes = f.read() # Clean up temp file diff --git a/purchase_focus_video.py b/purchase_focus_video.py index f118dfc2..a5993968 100644 --- a/purchase_focus_video.py +++ b/purchase_focus_video.py @@ -30,11 +30,11 @@ - Allows you to purchase a video by entering its ID. - You'll need to provide your wallet information (name, hotkey, path). - The script will initiate a transfer of TAO tokens to the OMEGA Focus App user who created the video. This secures the purchase of the video. - - After the transfer is complete, the script will attempt to verify the purchase. + - After the transfer is complete, the script will attempt to verify the purchase. - Once successful, you're all set! SN24 validators will automatically detect your purchase and reward your expected TAO emissions. Option 3: Verify Purchase - - This option is used when there are issues with the purchase verification during the purchase process. + - This option is used when there are issues with the purchase verification during the purchase process. - If you've successfully transferred the TAO tokens but the purchase wasn't verified, you can use this option to verify the purchase. - You'll need to provide the Video ID, Miner Hotkey, and Block Hash. @@ -71,15 +71,17 @@ from bittensor import wallet as btcli_wallet from tabulate import tabulate -parser = argparse.ArgumentParser(description='Interact with the OMEGA Focus Videos API.') +parser = argparse.ArgumentParser( + description="Interact with the OMEGA Focus Videos API." 
+) args = parser.parse_args() -SUBTENSOR_NETWORK = None # "test" or None +SUBTENSOR_NETWORK = None # "test" or None API_BASE = ( "https://dev-sn24-api.omegatron.ai" - if SUBTENSOR_NETWORK == "test" else - "https://sn24-api.omegatron.ai" + if SUBTENSOR_NETWORK == "test" + else "https://sn24-api.omegatron.ai" ) # API_BASE = "http://localhost:8000" @@ -88,69 +90,76 @@ RED = "\033[91m" RESET = "\033[0m" + def initialize_subtensor(): try: subtensor = bt.subtensor(network=SUBTENSOR_NETWORK) - #print(f"{GREEN}Subtensor initialized successfully.{RESET}") + # print(f"{GREEN}Subtensor initialized successfully.{RESET}") return subtensor except Exception as e: print(f"{RED}Error initializing subtensor: {str(e)}{RESET}") raise + def list_videos(): videos_response = requests.get( API_BASE + "/api/focus/get_list", headers={"Content-Type": "application/json"}, - timeout=30 + timeout=30, ) if videos_response.status_code != 200: print(f"{RED}Error fetching focus videos: {videos_response.status_code}{RESET}") return None - + videos_data = videos_response.json() return videos_data + def display_videos(videos_data): if not videos_data or len(videos_data) == 0: print(f"\n{RED}No videos available.{RESET}") return print(f"\n{CYAN}Available Focus Videos:{RESET}") - + # Prepare the data for tabulate table_data = [] for idx, video in enumerate(videos_data, 1): # Convert created_at to a more readable format - created_at = datetime.fromisoformat(video['created_at'].replace('Z', '+00:00')) + created_at = datetime.fromisoformat(video["created_at"].replace("Z", "+00:00")) formatted_date = created_at.strftime("%Y-%m-%d %H:%M:%S") - - table_data.append([ - idx, - video['video_id'], - f"{video['video_score']:.3f}", - f"{video['expected_reward_tao']:.5f}", - f"{float(video['expected_reward_tao']) / 0.9:.5f}", - #formatted_date - ]) - + + table_data.append( + [ + idx, + video["video_id"], + f"{video['video_score']:.3f}", + f"{video['expected_reward_tao']:.5f}", + f"{float(video['expected_reward_tao']) / 0.9:.5f}", + # formatted_date + ] + ) + # Create the table headers = ["#", "Video ID", "Score", "Cost (TAO)", "Expected Reward (TAO)"] table = tabulate(table_data, headers=headers, tablefmt="pretty") - + print(table) class TransferTimeout(Exception): pass + def reset_terminal(): # Try multiple methods to reset the terminal - os.system('stty sane') - os.system('reset') - sys.stdout.write('\033[0m') + os.system("stty sane") + os.system("reset") + sys.stdout.write("\033[0m") sys.stdout.flush() + def transfer_operation(wallet, transfer_address_to, transfer_balance, result_queue): try: subtensor = initialize_subtensor() @@ -165,42 +174,52 @@ def transfer_operation(wallet, transfer_address_to, transfer_balance, result_que except Exception as e: result_queue.put((False, None, str(e))) + def transfer_with_timeout(wallet, transfer_address_to, transfer_balance): result_queue = multiprocessing.Queue() - + transfer_process = multiprocessing.Process( target=transfer_operation, - args=(wallet, transfer_address_to, transfer_balance, result_queue) + args=(wallet, transfer_address_to, transfer_balance, result_queue), ) - + transfer_process.start() transfer_process.join(timeout=150) # 2m 30s = 150 seconds - + if transfer_process.is_alive(): transfer_process.terminate() transfer_process.join() reset_terminal() print("\nTransfer operation timed out after 2 minutes 30 seconds. 
Exiting.") - + if not result_queue.empty(): return result_queue.get() else: return False, None, "Transfer process exited without result" + def get_wallet(wallet_name=None, wallet_hotkey=None, wallet_path=None): if wallet_name is not None: name = wallet_name else: - name = input(f"{CYAN}Enter wallet name (default: Coldkey): {RESET}") or "Coldkey" + name = ( + input(f"{CYAN}Enter wallet name (default: Coldkey): {RESET}") or "Coldkey" + ) if wallet_hotkey is not None: hotkey_name = wallet_hotkey else: - hotkey_name = input(f"{CYAN}Enter wallet hotkey name (default: Hotkey): {RESET}") or "Hotkey" + hotkey_name = ( + input(f"{CYAN}Enter wallet hotkey name (default: Hotkey): {RESET}") + or "Hotkey" + ) if wallet_path is not None: path = wallet_path else: - path = input(f"{CYAN}Enter wallet path (default: ~/.bittensor/wallets/): {RESET}") or "~/.bittensor/wallets/" - + path = ( + input(f"{CYAN}Enter wallet path (default: ~/.bittensor/wallets/): {RESET}") + or "~/.bittensor/wallets/" + ) + wallet = btcli_wallet(name=name, hotkey=hotkey_name, path=path) try: hotkey = wallet.get_hotkey() @@ -209,63 +228,73 @@ def get_wallet(wallet_name=None, wallet_hotkey=None, wallet_path=None): return return wallet, name, hotkey_name, path + def get_auth_headers(wallet): hotkey = wallet.get_hotkey() miner_hotkey = hotkey.ss58_address miner_hotkey_signature = f"0x{hotkey.sign(miner_hotkey).hex()}" return miner_hotkey, miner_hotkey_signature + def purchase_video( - video_id=None, - wallet_name=None, - wallet_hotkey=None, - wallet_path=None + video_id=None, wallet_name=None, wallet_hotkey=None, wallet_path=None ): if not video_id: video_id = input(f"{CYAN}Enter focus video id: {RESET}") - wallet, name, hotkey_name, path = get_wallet(wallet_name, wallet_hotkey, wallet_path) + wallet, name, hotkey_name, path = get_wallet( + wallet_name, wallet_hotkey, wallet_path + ) miner_hotkey, miner_hotkey_signature = get_auth_headers(wallet) - + print(f"Purchasing video {video_id}...") - print(f"{RED}You will only have 2 minutes and 30 seconds to complete the transfer of TAO tokens, otherwise the purchase will be reverted.{RESET}") + print( + f"{RED}You will only have 2 minutes and 30 seconds to complete the transfer of TAO tokens, otherwise the purchase will be reverted.{RESET}" + ) purchase_response = requests.post( API_BASE + "/api/focus/purchase", auth=(miner_hotkey, miner_hotkey_signature), json={"video_id": video_id}, headers={"Content-Type": "application/json"}, - timeout=60 + timeout=60, ) purchase_data = purchase_response.json() if purchase_response.status_code != 200: - print(f"{RED}Error purchasing video {video_id}: {purchase_response.status_code}{RESET}") + print( + f"{RED}Error purchasing video {video_id}: {purchase_response.status_code}{RESET}" + ) if "detail" in purchase_data: print(f"{RED}Details: {purchase_data['detail']}{RESET}") return - + if "status" in purchase_data and purchase_data["status"] == "error": - print(f"{RED}Error purchasing video {video_id}: {purchase_data['message']}{RESET}") + print( + f"{RED}Error purchasing video {video_id}: {purchase_data['message']}{RESET}" + ) return - + try: transfer_address_to = purchase_data["address"] transfer_amount = purchase_data["amount"] print(f"Initiating transfer of {transfer_amount} TAO for video {video_id}...") - + transfer_balance = bt.Balance.from_tao(transfer_amount) - try: - success, block_hash, err_msg = transfer_with_timeout(wallet, transfer_address_to, transfer_balance) + success, block_hash, err_msg = transfer_with_timeout( + wallet, 
transfer_address_to, transfer_balance + ) except TransferTimeout: - print(f"\n{RED}Transfer operation timed out after 2 minutes and 30 seconds. Aborting purchase.{RESET}") + print( + f"\n{RED}Transfer operation timed out after 2 minutes and 30 seconds. Aborting purchase.{RESET}" + ) reset_terminal() revert_pending_purchase(video_id, miner_hotkey, miner_hotkey_signature) repurchase_input(video_id, name, hotkey_name, path) return - + """ success, block_hash, err_msg = subtensor._do_transfer( wallet, @@ -278,10 +307,16 @@ def purchase_video( if success: print(f"{GREEN}Transfer finalized. Block Hash: {block_hash}{RESET}") - save_purchase_info(video_id, miner_hotkey, block_hash, "purchased", transfer_amount) - verify_result = verify_purchase(video_id, miner_hotkey, block_hash, miner_hotkey_signature) + save_purchase_info( + video_id, miner_hotkey, block_hash, "purchased", transfer_amount + ) + verify_result = verify_purchase( + video_id, miner_hotkey, block_hash, miner_hotkey_signature + ) if not verify_result: - print(f"{RED}There was an error verifying your purchase after successfully transferring TAO. Please try the 'Verify Purchase' option immediately and contact an admin if you are unable to successfully verify.{RESET}") + print( + f"{RED}There was an error verifying your purchase after successfully transferring TAO. Please try the 'Verify Purchase' option immediately and contact an admin if you are unable to successfully verify.{RESET}" + ) else: print(f"{RED}Failed to complete transfer for video {video_id}.{RESET}") revert_pending_purchase(video_id, miner_hotkey, miner_hotkey_signature) @@ -290,11 +325,14 @@ def purchase_video( except Exception as e: print(f"{RED}Error transferring TAO tokens: {str(e)}{RESET}") if "EOF occurred in violation of protocol" in str(e): - print(f"{RED}Subtensor connection error detected. Re-initializing subtensor.{RESET}") + print( + f"{RED}Subtensor connection error detected. Re-initializing subtensor.{RESET}" + ) initialize_subtensor() revert_pending_purchase(video_id, miner_hotkey, miner_hotkey_signature) repurchase_input(video_id, name, hotkey_name, path) + def revert_pending_purchase(video_id, miner_hotkey, miner_hotkey_signature): print(f"Reverting Pending Purchasing of video {video_id}...") revert_response = requests.post( @@ -302,76 +340,104 @@ def revert_pending_purchase(video_id, miner_hotkey, miner_hotkey_signature): auth=(miner_hotkey, miner_hotkey_signature), json={"video_id": video_id}, headers={"Content-Type": "application/json"}, - timeout=60 + timeout=60, ) if revert_response.status_code != 200: - print(f"{RED}Error reverting pending purchase of video {video_id}: {revert_response.status_code}{RESET}") + print( + f"{RED}Error reverting pending purchase of video {video_id}: {revert_response.status_code}{RESET}" + ) return if revert_response.status_code == 200: - print(f"{GREEN}Pending purchase of video {video_id} reverted successfully.{RESET}") + print( + f"{GREEN}Pending purchase of video {video_id} reverted successfully.{RESET}" + ) return + def repurchase_input(video_id, wallet_name=None, wallet_hotkey=None, wallet_path=None): - repurchase = input(f"{CYAN}Do you want to repurchase video {video_id}? (y/n): {RESET}").lower() - if repurchase == 'y': + repurchase = input( + f"{CYAN}Do you want to repurchase video {video_id}? (y/n): {RESET}" + ).lower() + if repurchase == "y": purchase_video(video_id, wallet_name, wallet_hotkey, wallet_path) - elif repurchase != 'n': + elif repurchase != "n": print(f"{RED}Invalid input. 
Please enter 'y' or 'n'.{RESET}") repurchase_input(video_id, wallet_name, wallet_hotkey, wallet_path) + def display_saved_orders(for_verification=False): purchases_file = os.path.expanduser("~/.omega/focus_videos.json") if not os.path.exists(purchases_file): print(f"{RED}No saved orders found.{RESET}") return None - with open(purchases_file, 'r') as f: + with open(purchases_file, "r") as f: purchases = json.load(f) if not purchases: print(f"{RED}No saved orders found.{RESET}") return None - purchases.sort(key=lambda x: x.get('created_at', ''), reverse=True) + purchases.sort(key=lambda x: x.get("created_at", ""), reverse=True) print(f"\n{CYAN}Saved Orders:{RESET}") - + table_data = [] for idx, purchase in enumerate(purchases, 1): - created_at = purchase.get('created_at', 'N/A') - if created_at != 'N/A': - created_at = datetime.fromisoformat(created_at.replace('Z', '+00:00')).strftime("%Y-%m-%d %H:%M:%S") - - table_data.append([ - idx, - purchase['video_id'], - purchase['state'], - purchase.get('amount', 'N/A'), - f"{float(purchase.get('amount', 0)) / 0.9:.5f}", - purchase.get('miner_hotkey', 'N/A')[:5] + '...' + purchase.get('miner_hotkey', 'N/A')[-5:], - purchase['block_hash'][:5] + '...' + purchase['block_hash'][-5:], - created_at - ]) - - headers = ["#", "Video ID", "Purchase State", "Cost (TAO)", "Estimated Reward (TAO)", "Purchasing Hotkey", "Block Hash", "Purchase Date"] + created_at = purchase.get("created_at", "N/A") + if created_at != "N/A": + created_at = datetime.fromisoformat( + created_at.replace("Z", "+00:00") + ).strftime("%Y-%m-%d %H:%M:%S") + + table_data.append( + [ + idx, + purchase["video_id"], + purchase["state"], + purchase.get("amount", "N/A"), + f"{float(purchase.get('amount', 0)) / 0.9:.5f}", + purchase.get("miner_hotkey", "N/A")[:5] + + "..." + + purchase.get("miner_hotkey", "N/A")[-5:], + purchase["block_hash"][:5] + "..." + purchase["block_hash"][-5:], + created_at, + ] + ) + + headers = [ + "#", + "Video ID", + "Purchase State", + "Cost (TAO)", + "Estimated Reward (TAO)", + "Purchasing Hotkey", + "Block Hash", + "Purchase Date", + ] table = tabulate(table_data, headers=headers, tablefmt="pretty") - + print(table) return purchases + def select_order_for_verification(): purchases = display_saved_orders() while True: if purchases: - print(f"*** NOTE: A purchase is finalized when the purchase state is 'verified'. ***") - choice = input(f"{CYAN}Enter the number of the order to verify, 'm' for manual input, or 'n' to cancel: {RESET}").lower() + print( + f"*** NOTE: A purchase is finalized when the purchase state is 'verified'. ***" + ) + choice = input( + f"{CYAN}Enter the number of the order to verify, 'm' for manual input, or 'n' to cancel: {RESET}" + ).lower() else: - choice = 'm' + choice = "m" - if choice == 'n': + if choice == "n": return None, None, None - elif choice == 'm': + elif choice == "m": video_id = input(f"{CYAN}Enter video ID: {RESET}") miner_hotkey = input(f"{CYAN}Enter miner hotkey: {RESET}") block_hash = input(f"{CYAN}Enter block hash: {RESET}") @@ -380,17 +446,24 @@ def select_order_for_verification(): idx = int(choice) - 1 if 0 <= idx < len(purchases): selected = purchases[idx] - return selected['video_id'], selected.get('miner_hotkey', ''), selected['block_hash'] + return ( + selected["video_id"], + selected.get("miner_hotkey", ""), + selected["block_hash"], + ) else: print(f"{RED}Invalid selection. Please try again.{RESET}") else: print(f"{RED}Invalid input. 
Please try again.{RESET}") + def select_order_for_full_display(purchases): while True: - choice = input(f"{CYAN}Enter the number of the order to see full details, or 'n' to return to menu: {RESET}").lower() - - if choice == 'n': + choice = input( + f"{CYAN}Enter the number of the order to see full details, or 'n' to return to menu: {RESET}" + ).lower() + + if choice == "n": return elif choice.isdigit(): idx = int(choice) - 1 @@ -401,7 +474,9 @@ def select_order_for_full_display(purchases): print(f"Video ID: {selected['video_id']}") print(f"Purchase State: {selected['state']}") print(f"Cost (TAO): {selected.get('amount', 'N/A')}") - print(f"Estimated Reward (TAO): {float(selected.get('amount', 0)) / 0.9:.5f}") + print( + f"Estimated Reward (TAO): {float(selected.get('amount', 0)) / 0.9:.5f}" + ) print(f"Purchasing Hotkey: {selected.get('miner_hotkey', 'N/A')}") print(f"Block Hash: {selected['block_hash']}") print(f"Purchase Date: {selected.get('created_at', 'N/A')}") @@ -411,7 +486,10 @@ def select_order_for_full_display(purchases): else: print(f"{RED}Invalid input. Please try again.{RESET}") -def verify_purchase(video_id=None, miner_hotkey=None, block_hash=None, miner_hotkey_signature=None): + +def verify_purchase( + video_id=None, miner_hotkey=None, block_hash=None, miner_hotkey_signature=None +): if miner_hotkey_signature is None: wallet, name, hotkey_name, path = get_wallet() miner_hotkey, miner_hotkey_signature = get_auth_headers(wallet) @@ -430,45 +508,59 @@ def verify_purchase(video_id=None, miner_hotkey=None, block_hash=None, miner_hot verify_response = requests.post( API_BASE + "/api/focus/verify-purchase", auth=(miner_hotkey, miner_hotkey_signature), - json={"miner_hotkey": miner_hotkey, "video_id": video_id, "block_hash": block_hash}, + json={ + "miner_hotkey": miner_hotkey, + "video_id": video_id, + "block_hash": block_hash, + }, headers={"Content-Type": "application/json"}, - timeout=90 + timeout=90, + ) + print( + f"Purchase verification response for video {video_id}:", + verify_response.text, ) - print(f"Purchase verification response for video {video_id}:", verify_response.text) if verify_response.status_code == 200: print(f"{GREEN}Purchase verified successfully!{RESET}") save_purchase_info(video_id, miner_hotkey, block_hash, "verified") return True - + if attempt < retries - 1: - print(f"{CYAN}Attempt #{attempt + 1} to verify purchase failed. Retrying in 2 seconds...{RESET}") + print( + f"{CYAN}Attempt #{attempt + 1} to verify purchase failed. Retrying in 2 seconds...{RESET}" + ) time.sleep(2) except Exception as e: if attempt < retries - 1: - print(f"{CYAN}Attempt #{attempt + 1} to verify purchase failed. Retrying in 2 seconds...{RESET}") + print( + f"{CYAN}Attempt #{attempt + 1} to verify purchase failed. Retrying in 2 seconds...{RESET}" + ) print(f"{RED}Error: {str(e)}{RESET}") time.sleep(2) else: - print(f"{RED}All {retries} attempts failed. Unable to verify purchase.{RESET}") + print( + f"{RED}All {retries} attempts failed. 
Unable to verify purchase.{RESET}" + ) return False + def save_purchase_info(video_id, hotkey, block_hash, state, amount=None): purchases_file = os.path.expanduser("~/.omega/focus_videos.json") os.makedirs(os.path.dirname(purchases_file), exist_ok=True) - + purchases = [] if os.path.exists(purchases_file): - with open(purchases_file, 'r') as f: + with open(purchases_file, "r") as f: purchases = json.load(f) - + # Check if the video_id already exists for purchase in purchases: - if purchase['video_id'] == video_id: - purchase['state'] = state - purchase['miner_hotkey'] = hotkey - purchase['block_hash'] = block_hash + if purchase["video_id"] == video_id: + purchase["state"] = state + purchase["miner_hotkey"] = hotkey + purchase["block_hash"] = block_hash if amount is not None: - purchase['amount'] = amount + purchase["amount"] = amount break else: # If the video_id doesn't exist, create a new entry @@ -477,16 +569,19 @@ def save_purchase_info(video_id, hotkey, block_hash, state, amount=None): "miner_hotkey": hotkey, "block_hash": block_hash, "state": state, - "created_at": datetime.now().isoformat() # Add creation timestamp + "created_at": datetime.now().isoformat(), # Add creation timestamp } if amount is not None: - new_purchase['amount'] = amount + new_purchase["amount"] = amount purchases.append(new_purchase) - - with open(purchases_file, 'w') as f: + + with open(purchases_file, "w") as f: json.dump(purchases, f, indent=2) - - print(f"{GREEN}Purchase information {'updated' if state == 'verified' else 'saved'} to {purchases_file}{RESET}") + + print( + f"{GREEN}Purchase information {'updated' if state == 'verified' else 'saved'} to {purchases_file}{RESET}" + ) + def main(): while True: @@ -496,37 +591,42 @@ def main(): print("3. Verify Purchase") print("4. Display Order History") print("5. Exit") - + choice = input(f"{CYAN}Enter your choice (1-5): {RESET}") - - if choice == '1': + + if choice == "1": videos_data = list_videos() if videos_data: display_videos(videos_data) - purchase_option = input(f"\n{CYAN}Enter the number of the video you want to purchase or press 'n' to return to menu: {RESET}").lower() + purchase_option = input( + f"\n{CYAN}Enter the number of the video you want to purchase or press 'n' to return to menu: {RESET}" + ).lower() if purchase_option.isdigit(): video_index = int(purchase_option) - 1 if 0 <= video_index < len(videos_data): - purchase_video(videos_data[video_index]['video_id']) + purchase_video(videos_data[video_index]["video_id"]) else: print(f"{RED}Invalid video number.{RESET}") - elif purchase_option != 'n': + elif purchase_option != "n": print(f"{RED}Invalid input. Returning to main menu.{RESET}") else: print(f"\n{RED}No videos available for purchase at this time.{RESET}") - elif choice == '2': + elif choice == "2": purchase_video() - elif choice == '3': + elif choice == "3": verify_purchase() - elif choice == '4': + elif choice == "4": purchases = display_saved_orders() select_order_for_full_display(purchases) - elif choice == '5': - print(f"{GREEN}Thank you for using the OMEGA Focus Videos Purchase System. Goodbye!{RESET}") + elif choice == "5": + print( + f"{GREEN}Thank you for using the OMEGA Focus Videos Purchase System. Goodbye!{RESET}" + ) break else: print(f"{RED}Invalid choice. 
Please try again.{RESET}")
 
+
 if __name__ == "__main__":
     try:
         multiprocessing.freeze_support()
diff --git a/requirements.txt b/requirements.txt
index e5c2599d..e8661637 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,4 +29,5 @@ pyannote-audio==3.1.1
 librosa==0.10.2.post1
 substrate-interface==1.7.11
 asyncpg==0.30.0
-greenlet==3.1.1
\ No newline at end of file
+greenlet==3.1.1
+ruff==0.9.9
diff --git a/test_audio_dataset.py b/test_audio_dataset.py
index 9a9fc870..9918e1de 100644
--- a/test_audio_dataset.py
+++ b/test_audio_dataset.py
@@ -5,7 +5,7 @@
 import soundfile as sf
 
 # Set HF_TOKEN environment variable or pass directly
-HF_TOKEN = os.getenv('HF_TOKEN')
+HF_TOKEN = os.getenv("HF_TOKEN")
 
 # Login to Hugging Face
 # login(token=HF_TOKEN)
@@ -15,7 +15,7 @@
 print(f"Dataset loaded successfully with {len(dataset)} examples")
 
 # Get first row from the dataset
-first_row = dataset['train'][0]
+first_row = dataset["train"][0]
 print("\nFirst row of dataset:")
 # print(first_row)
 print("\nKeys in dataset:")
@@ -27,54 +27,63 @@
 print(f"{key}: {first_row[key]}")
 
 
-
 import librosa
 import numpy as np
+
 breakpoint()
-audio_bytes = first_row['audio_bytes']
+audio_bytes = first_row["audio_bytes"]
 audio_arr, sr = sf.read(BytesIO(audio_bytes))
 print(len(audio_arr), type(audio_arr))
 audio = np.array(audio_arr)
 # exit()
 print(audio.shape)
-youtube_id = first_row['youtube_id']
-os.makedirs('Dataset_audios/Original', exist_ok=True)
-sf.write(f'Dataset_audios/Original/{youtube_id}.wav', audio, sr)
+youtube_id = first_row["youtube_id"]
+os.makedirs("Dataset_audios/Original", exist_ok=True)
+sf.write(f"Dataset_audios/Original/{youtube_id}.wav", audio, sr)
 
-diar_timestamps_start = first_row['diar_timestamps_start']
-diar_timestamps_end = first_row['diar_timestamps_end']
-diar_speakers = first_row['diar_speakers']
+diar_timestamps_start = first_row["diar_timestamps_start"]
+diar_timestamps_end = first_row["diar_timestamps_end"]
+diar_speakers = first_row["diar_speakers"]
 
-for start, end, speaker in zip(diar_timestamps_start, diar_timestamps_end, diar_speakers):
+for start, end, speaker in zip(
+    diar_timestamps_start, diar_timestamps_end, diar_speakers
+):
     # Calculate start and end samples
     start_sample = int(start * sr)
     end_sample = int(end * sr)
-    
+
     # Extract the clip
     clip = audio[start_sample:end_sample]
-    
+
     # Create output directory if it doesn't exist
-    os.makedirs(f'Dataset_audios/Clips/{youtube_id}', exist_ok=True)
-    
+    os.makedirs(f"Dataset_audios/Clips/{youtube_id}", exist_ok=True)
+
     # Save the clip with speaker and timestamp info in filename
-    clip_filename = f'Dataset_audios/Clips/{youtube_id}/speaker_{speaker}_{start:.2f}-{end:.2f}.wav'
+    clip_filename = (
+        f"Dataset_audios/Clips/{youtube_id}/speaker_{speaker}_{start:.2f}-{end:.2f}.wav"
+    )
     sf.write(clip_filename, clip, sr)
-    
+
 # Create a list to store the diarization data
 diarization_data = []
 
-for start, end, speaker in zip(diar_timestamps_start, diar_timestamps_end, diar_speakers):
-    diarization_data.append({
-        'youtube_id': youtube_id,
-        'start_time': start,
-        'end_time': end,
-        'speaker': speaker,
-        "duration": end - start
-    })
+for start, end, speaker in zip(
+    diar_timestamps_start, diar_timestamps_end, diar_speakers
+):
+    diarization_data.append(
+        {
+            "youtube_id": youtube_id,
+            "start_time": start,
+            "end_time": end,
+            "speaker": speaker,
+            "duration": end - start,
+        }
+    )
 
 # Convert to pandas DataFrame and save as CSV
 import pandas as pd
+
 df = pd.DataFrame(diarization_data)
-os.makedirs('Dataset_audios/Metadata', exist_ok=True)
-df.to_csv(f'Dataset_audios/Metadata/{youtube_id}_diarization.csv', index=False) +os.makedirs("Dataset_audios/Metadata", exist_ok=True) +df.to_csv(f"Dataset_audios/Metadata/{youtube_id}_diarization.csv", index=False) diff --git a/validator-api/_generate_api_key.py b/validator-api/_generate_api_key.py index 7e5e68b8..c92930ec 100644 --- a/validator-api/_generate_api_key.py +++ b/validator-api/_generate_api_key.py @@ -1,7 +1,9 @@ import secrets + def generate_api_key(): return secrets.token_urlsafe(32) # Generates a 32-byte (256-bit) key + new_api_key = generate_api_key() -print(new_api_key) \ No newline at end of file +print(new_api_key) diff --git a/validator-api/app.py b/validator-api/app.py index 6cd374bf..43a98a21 100644 --- a/validator-api/app.py +++ b/validator-api/app.py @@ -18,8 +18,16 @@ import ulid import uvicorn from datasets import load_dataset -from fastapi import (BackgroundTasks, Body, Depends, FastAPI, HTTPException, - Path, Request, Security) +from fastapi import ( + BackgroundTasks, + Body, + Depends, + FastAPI, + HTTPException, + Path, + Request, + Security, +) from fastapi.responses import FileResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security.api_key import APIKeyHeader @@ -31,36 +39,64 @@ from validator_api.check_blocking import detect_blocking from validator_api.communex._common import get_node_url from validator_api.communex.client import CommuneClient -from validator_api.config import (API_KEY_NAME, API_KEYS, COMMUNE_NETUID, - COMMUNE_NETWORK, DB_CONFIG, ENABLE_COMMUNE, - FIXED_ALPHA_TAO_ESTIMATE, FOCUS_API_KEYS, - FOCUS_API_URL, FOCUS_REWARDS_PERCENT, - IMPORT_SCORE, IS_PROD, NETUID, NETWORK, PORT, - PROXY_LIST, SENTRY_DSN) -from validator_api.cron.confirm_purchase import (confirm_transfer, - confirm_video_purchased) +from validator_api.config import ( + API_KEY_NAME, + API_KEYS, + COMMUNE_NETUID, + COMMUNE_NETWORK, + DB_CONFIG, + ENABLE_COMMUNE, + FIXED_ALPHA_TAO_ESTIMATE, + FOCUS_API_KEYS, + FOCUS_API_URL, + FOCUS_REWARDS_PERCENT, + IMPORT_SCORE, + IS_PROD, + NETUID, + NETWORK, + PORT, + PROXY_LIST, + SENTRY_DSN, +) +from validator_api.cron.confirm_purchase import ( + confirm_transfer, + confirm_video_purchased, +) from validator_api.database import get_db, get_db_context from validator_api.database.crud.focusvideo import ( - MinerPurchaseStats, TaskType, FocusVideoCache, - check_availability, get_video_owner_coldkey, - mark_video_rejected, mark_video_submitted, set_focus_video_score) -from validator_api.database.models.focus_video_record import FocusVideoRecord, FocusVideoStateExternal -from validator_api.dataset_upload import (audio_dataset_uploader, - video_dataset_uploader) + MinerPurchaseStats, + TaskType, + FocusVideoCache, + check_availability, + get_video_owner_coldkey, + mark_video_rejected, + mark_video_submitted, + set_focus_video_score, +) +from validator_api.database.models.focus_video_record import ( + FocusVideoRecord, + FocusVideoStateExternal, +) +from validator_api.dataset_upload import audio_dataset_uploader, video_dataset_uploader from validator_api.limiter import limiter -from validator_api.scoring.scoring_service import (FocusScoringService, - LegitimacyCheckError, - VideoTooLongError, - VideoTooShortError, - VideoUniquenessError) -from validator_api.utils.marketplace import (TASK_TYPE_MAP, - get_max_focus_alpha_per_day, - get_variable_reward_pool_alpha, - get_fixed_reward_pool_alpha) +from validator_api.scoring.scoring_service import ( + FocusScoringService, + LegitimacyCheckError, + 
VideoTooLongError, + VideoTooShortError, + VideoUniquenessError, +) +from validator_api.utils.marketplace import ( + TASK_TYPE_MAP, + get_max_focus_alpha_per_day, + get_variable_reward_pool_alpha, + get_fixed_reward_pool_alpha, +) from validator_api.database.models.miner_bans import miner_banned_until from omega.protocol import AudioMetadata, VideoMetadata from sqlalchemy import select, update + print("IMPORT_SCORE:", IMPORT_SCORE) if IMPORT_SCORE is not False: @@ -105,31 +141,36 @@ def connect_to_db(): def get_timestamp_from_filename(filename: str): - return ulid.from_str(os.path.splitext(filename.split("/")[-1])[0]).timestamp().timestamp + return ( + ulid.from_str(os.path.splitext(filename.split("/")[-1])[0]) + .timestamp() + .timestamp + ) def pull_and_cache_dataset() -> List[str]: # Get the list of files in the dataset repository omega_ds_files = huggingface_hub.repo_info( - repo_id=HF_DATASET, repo_type="dataset").siblings + repo_id=HF_DATASET, repo_type="dataset" + ).siblings # Filter files that match the DATA_FILES_PREFIX recent_files = [ f.rfilename - for f in omega_ds_files if - f.rfilename.startswith(DATA_FILES_PREFIX) and - time.time() - get_timestamp_from_filename(f.rfilename) < MIN_AGE + for f in omega_ds_files + if f.rfilename.startswith(DATA_FILES_PREFIX) + and time.time() - get_timestamp_from_filename(f.rfilename) < MIN_AGE ][:MAX_FILES] # Randomly sample up to MAX_FILES from the matching files - sampled_files = random.sample( - recent_files, min(MAX_FILES, len(recent_files))) + sampled_files = random.sample(recent_files, min(MAX_FILES, len(recent_files))) # Load the dataset using the sampled files video_metadata = [] with TemporaryDirectory() as temp_dir: omega_dataset = load_dataset( - HF_DATASET, data_files=sampled_files, cache_dir=temp_dir)["train"] + HF_DATASET, data_files=sampled_files, cache_dir=temp_dir + )["train"] for i, entry in enumerate(omega_dataset): metadata = [] if "description" in entry and "description_embed" in entry: @@ -149,6 +190,8 @@ def pull_and_cache_dataset() -> List[str]: json.dump(video_metadata, f) return True + + # endregion Utility functions for OMEGA Metadata Dashboard @@ -156,20 +199,14 @@ async def get_api_key(api_key_header: str = Security(api_key_header)): if api_key_header in API_KEYS: return api_key_header else: - raise HTTPException( - status_code=401, - detail="Invalid API Key" - ) + raise HTTPException(status_code=401, detail="Invalid API Key") async def get_focus_api_key(focus_api_key_header: str = Security(focus_api_key_header)): if focus_api_key_header in FOCUS_API_KEYS: return focus_api_key_header else: - raise HTTPException( - status_code=401, - detail="Invalid API Key" - ) + raise HTTPException(status_code=401, detail="Invalid API Key") class VideoMetadataUpload(BaseModel): @@ -203,7 +240,7 @@ def get_hotkey(credentials: Annotated[HTTPBasicCredentials, Depends(security)]) print(f"Error verifying keypair: {e}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail=f"Error verifying keypair: {e}, make sure Basic Auth username is your hotkey SS58 address and the password is your hotkey's signature hex string (not private key!)." 
+ detail=f"Error verifying keypair: {e}, make sure Basic Auth username is your hotkey SS58 address and the password is your hotkey's signature hex string (not private key!).", ) @@ -247,16 +284,15 @@ def update_commune_keys(commune_client, commune_keys): async def run_focus_scoring( video_id: Annotated[str, Body()], focusing_task: Annotated[str, Body()], - focusing_description: Annotated[str, Body()] + focusing_description: Annotated[str, Body()], ) -> Dict[str, Any]: - score_details = None embeddings = None try: async with get_db_context() as db: query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.deleted_at.is_(None), ) result = await db.execute(query) video_record = result.scalar_one_or_none() @@ -266,15 +302,18 @@ async def run_focus_scoring( update_stmt = ( update(FocusVideoRecord) .where(FocusVideoRecord.video_id == video_id) - .values(processing_state=FocusVideoStateExternal.PENDING_HUMAN_REVIEW.value) + .values( + processing_state=FocusVideoStateExternal.PENDING_HUMAN_REVIEW.value + ) ) await db.execute(update_stmt) await db.commit() return {"success": True} - score_details, embeddings = await focus_scoring_service.score_video(video_id, focusing_task, focusing_description) - print( - f"Score for focus video <{video_id}>: {score_details.final_score}") + score_details, embeddings = await focus_scoring_service.score_video( + video_id, focusing_task, focusing_description + ) + print(f"Score for focus video <{video_id}>: {score_details.final_score}") MIN_FINAL_SCORE = 0.1 # todo: measure and tune these MIN_TASK_UNIQUENESS_SCORE = 0 @@ -289,7 +328,7 @@ async def run_focus_scoring( video_id, rejection_reason, score_details=score_details, - embeddings=embeddings + embeddings=embeddings, ) else: await set_focus_video_score(db, video_id, score_details, embeddings) @@ -332,7 +371,7 @@ async def shutdown_event(): print("Shutdown event fired, attempting dataset upload of current batch.") video_dataset_uploader.submit() audio_dataset_uploader.submit() - + async def lifespan(app: FastAPI): await startup_event() yield @@ -340,8 +379,7 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) # Mount the static directory to serve static files - app.mount( - "/static", StaticFiles(directory="validator-api/static"), name="static") + app.mount("/static", StaticFiles(directory="validator-api/static"), name="static") subtensor = bittensor.subtensor(network=NETWORK) metagraph: bittensor.metagraph = subtensor.metagraph(NETUID) @@ -350,8 +388,9 @@ async def lifespan(app: FastAPI): commune_client = None commune_keys = None if ENABLE_COMMUNE: - commune_client = CommuneClient(get_node_url( - use_testnet=True if COMMUNE_NETWORK == "test" else False)) + commune_client = CommuneClient( + get_node_url(use_testnet=True if COMMUNE_NETWORK == "test" else False) + ) commune_keys = update_commune_keys(commune_client, commune_keys) async def resync_metagraph(): @@ -367,8 +406,7 @@ async def resync_metagraph(): # Sync latest commune keys if ENABLE_COMMUNE: - commune_keys = update_commune_keys( - commune_client, commune_keys) + commune_keys = update_commune_keys(commune_client, commune_keys) print("commune keys synced") # In case of unforeseen errors, the api will log the error and continue operations. 
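[Editor's note: get_hotkey() above authenticates requests via HTTP Basic, where the username is a hotkey SS58 address and the password is the hex-encoded signature of that address with a "0x" prefix — the same scheme get_auth_headers() builds in purchase_focus_video.py. The client-side sketch below is not part of this patch; the wallet names and the choice of endpoint are illustrative.]

```python
# Illustrative client for the hotkey-authenticated endpoints above
# (not part of this patch). Assumes a local bittensor wallet whose
# hotkey is registered on the subnet's metagraph.
import bittensor as bt
import requests

wallet = bt.wallet(name="default", hotkey="default")
hotkey = wallet.get_hotkey()
address = hotkey.ss58_address
signature = f"0x{hotkey.sign(address).hex()}"  # password verified by get_hotkey()

resp = requests.post(
    "https://sn24-api.omegatron.ai/api/get_proxy",  # API_BASE from purchase_focus_video.py
    auth=(address, signature),
    timeout=30,
)
print(resp.status_code, resp.text)
```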
@@ -389,6 +427,7 @@ async def detect_blocking_middleware(request: Request, call_next): if auth and auth.startswith("Basic "): try: import base64 + decoded = base64.b64decode(auth.split()[1]).decode() username = decoded.split(":")[0] except: @@ -399,7 +438,9 @@ async def detect_blocking_middleware(request: Request, call_next): mem_after = process.memory_info().rss mem_diff = mem_after - mem_before - print(f"Memory change for {request.url.path}: {mem_diff / 1024 / 1024:.2f} MB, now at {mem_after / 1024 / 1024:.2f} MB") + print( + f"Memory change for {request.url.path}: {mem_diff / 1024 / 1024:.2f} MB, now at {mem_after / 1024 / 1024:.2f} MB" + ) return response @@ -414,7 +455,9 @@ async def get_pinecone_novelty( ) -> List[float]: print("get_pinecone_novelty()") - if not authenticate_with_bittensor(hotkey, metagraph) and not authenticate_with_commune(hotkey, commune_keys): + if not authenticate_with_bittensor( + hotkey, metagraph + ) and not authenticate_with_commune(hotkey, commune_keys): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Valid hotkey required.", @@ -437,7 +480,8 @@ async def get_pinecone_novelty( # query the pinecone index to get novelty scores novelty_scores = await score.get_pinecone_novelty(metadata) print( - f"Returning novelty scores={novelty_scores} for {validator_chain} validator={uid} in {time.time() - start_time:.2f}s") + f"Returning novelty scores={novelty_scores} for {validator_chain} validator={uid} in {time.time() - start_time:.2f}s" + ) return novelty_scores @app.post("/api/upload_video_metadata") @@ -446,7 +490,9 @@ async def upload_video_metadata( hotkey: Annotated[str, Depends(get_hotkey)], ) -> bool: print("upload_video_metadata()") - if not authenticate_with_bittensor(hotkey, metagraph) and not authenticate_with_commune(hotkey, commune_keys): + if not authenticate_with_bittensor( + hotkey, metagraph + ) and not authenticate_with_commune(hotkey, commune_keys): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Valid hotkey required.", @@ -475,17 +521,22 @@ async def upload_video_metadata( topic_query = upload_data.topic_query start_time = time.time() - video_ids = await score.upload_video_metadata(metadata, description_relevance_scores, query_relevance_scores, topic_query) + video_ids = await score.upload_video_metadata( + metadata, description_relevance_scores, query_relevance_scores, topic_query + ) print( - f"Uploaded {len(video_ids)} video metadata from {validator_chain} validator={uid} in {time.time() - start_time:.2f}s") + f"Uploaded {len(video_ids)} video metadata from {validator_chain} validator={uid} in {time.time() - start_time:.2f}s" + ) if upload_data.miner_hotkey is not None: # Calculate and upsert leaderboard data datapoints = len(video_ids) - avg_desc_relevance = sum( - description_relevance_scores) / len(description_relevance_scores) - avg_query_relevance = sum( - query_relevance_scores) / len(query_relevance_scores) + avg_desc_relevance = sum(description_relevance_scores) / len( + description_relevance_scores + ) + avg_query_relevance = sum(query_relevance_scores) / len( + query_relevance_scores + ) novelty_score = upload_data.novelty_score total_score = upload_data.total_score miner_hotkey = upload_data.miner_hotkey @@ -519,28 +570,36 @@ async def upload_video_metadata( last_updated = NOW(); """ cursor = connection.cursor() - cursor.execute(query, ( - miner_hotkey, - is_bittensor, - is_commune, - datapoints, - avg_desc_relevance, - avg_query_relevance, - novelty_score, - total_score - )) + 
cursor.execute( + query, + ( + miner_hotkey, + is_bittensor, + is_commune, + datapoints, + avg_desc_relevance, + avg_query_relevance, + novelty_score, + total_score, + ), + ) connection.commit() print( - f"Upserted leaderboard data for {miner_hotkey} from {validator_chain} validator={uid} in {time.time() - start_time:.2f}s") + f"Upserted leaderboard data for {miner_hotkey} from {validator_chain} validator={uid} in {time.time() - start_time:.2f}s" + ) except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching data from MySQL database: {err}", + ) finally: if connection: connection.close() else: - print("Skipping leaderboard update because either non-production environment or vali running outdated code.") + print( + "Skipping leaderboard update because either non-production environment or vali running outdated code." + ) return True @@ -552,7 +611,9 @@ async def upload_audio_metadata( ) -> bool: print("upload_audio_metadata()") - if not authenticate_with_bittensor(hotkey, metagraph) and not authenticate_with_commune(hotkey, commune_keys): + if not authenticate_with_bittensor( + hotkey, metagraph + ) and not authenticate_with_commune(hotkey, commune_keys): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Valid hotkey required.", @@ -650,11 +711,10 @@ async def upload_audio_metadata( return True @app.post("/api/get_proxy") - async def get_proxy( - hotkey: Annotated[str, Depends(get_hotkey)] - ) -> str: - - if not authenticate_with_bittensor(hotkey, metagraph) and not authenticate_with_commune(hotkey, commune_keys): + async def get_proxy(hotkey: Annotated[str, Depends(get_hotkey)]) -> str: + if not authenticate_with_bittensor( + hotkey, metagraph + ) and not authenticate_with_commune(hotkey, commune_keys): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Valid hotkey required.", @@ -681,11 +741,14 @@ async def get_focus_score( focusing_description: Annotated[str, Body()] = None, background_tasks: BackgroundTasks = BackgroundTasks(), ) -> Dict[str, bool]: - async def run_focus_scoring_task(video_id: str, focusing_task: str, focusing_description: str): + async def run_focus_scoring_task( + video_id: str, focusing_task: str, focusing_description: str + ): await run_focus_scoring(video_id, focusing_task, focusing_description) - + background_tasks.add_task( - run_focus_scoring_task, video_id, focusing_task, focusing_description) + run_focus_scoring_task, video_id, focusing_task, focusing_description + ) return {"success": True} @app.get("/api/focus/get_list") @@ -713,27 +776,33 @@ async def purchase_video( ) if focus_video_cache.already_purchased_max_focus_tao(): - print("Purchases in the last 24 hours have reached the max focus tao limit.") + print( + "Purchases in the last 24 hours have reached the max focus tao limit." 
+ ) raise HTTPException( - 400, "Purchases in the last 24 hours have reached the max focus tao limit, please try again later.") + 400, + "Purchases in the last 24 hours have reached the max focus tao limit, please try again later.", + ) print(f"purchase_video | video_id <{video_id}> hotkey <{hotkey}>") # run with_lock True availability = await check_availability(db, video_id, hotkey, True) - print('availability', availability) - if availability['status'] == 'success': - amount = availability['price'] - video_owner_coldkey = await get_video_owner_coldkey(db, video_id) # run with_lock True - + print("availability", availability) + if availability["status"] == "success": + amount = availability["price"] + video_owner_coldkey = await get_video_owner_coldkey( + db, video_id + ) # run with_lock True + # Create a standalone async function for the background task async def run_confirm_video_purchased(video_id: str): await confirm_video_purchased(video_id, True) - + background_tasks.add_task(run_confirm_video_purchased, video_id) - + return { - 'status': 'success', - 'address': video_owner_coldkey, - 'amount': amount, + "status": "success", + "address": video_owner_coldkey, + "amount": amount, } else: return availability @@ -763,57 +832,61 @@ async def run_stake(video_id): async with session.post( f"{FOCUS_API_URL}/auth/stake", json={"video_id": video_id}, - headers={"FOCUS_API_KEY": FOCUS_API_KEYS[0]} + headers={"FOCUS_API_KEY": FOCUS_API_KEYS[0]}, ) as response: res = await response.json() print(f"Got res={res} from {FOCUS_API_URL}/auth/stake") return res video_owner_coldkey = await get_video_owner_coldkey(db, video_id) - result = await confirm_transfer(db, video_owner_coldkey, video_id, miner_hotkey, block_hash) + result = await confirm_transfer( + db, video_owner_coldkey, video_id, miner_hotkey, block_hash + ) if result: background_tasks.add_task(run_stake, video_id) return { - 'status': 'success', - 'message': 'Video purchase verification was successful' + "status": "success", + "message": "Video purchase verification was successful", } else: return { - 'status': 'error', - 'message': f'Video purchase verification failed for video_id {video_id} on block_hash {block_hash} by miner_hotkey {miner_hotkey}' + "status": "error", + "message": f"Video purchase verification failed for video_id {video_id} on block_hash {block_hash} by miner_hotkey {miner_hotkey}", } - @app.get('/api/focus/miner_purchase_scores/{miner_hotkeys}') - async def miner_purchase_scores(miner_hotkeys: str) -> Dict[str, MinerPurchaseStats]: + @app.get("/api/focus/miner_purchase_scores/{miner_hotkeys}") + async def miner_purchase_scores( + miner_hotkeys: str, + ) -> Dict[str, MinerPurchaseStats]: return focus_video_cache.miner_purchase_stats() - @app.get('/api/focus/miner_purchase_scores') + @app.get("/api/focus/miner_purchase_scores") async def miner_purchase_scores() -> Dict[str, MinerPurchaseStats]: return focus_video_cache.miner_purchase_stats() class TaskTypeMap(BaseModel): task_type_map: Dict[TaskType, float] - @app.get('/api/focus/get_task_percentage_map') + @app.get("/api/focus/get_task_percentage_map") async def get_task_percentage_map(): return TaskTypeMap(task_type_map=TASK_TYPE_MAP) - @app.get('/api/focus/get_rewards_percent') + @app.get("/api/focus/get_rewards_percent") async def get_rewards_percent(): return FOCUS_REWARDS_PERCENT - @app.get('/api/focus/get_max_focus_alpha') + @app.get("/api/focus/get_max_focus_alpha") async def _get_max_focus_alpha() -> float: return await get_max_focus_alpha_per_day() - 
@app.get('/api/focus/get_variable_reward_pool_alpha') + @app.get("/api/focus/get_variable_reward_pool_alpha") async def _get_variable_reward_pool_alpha() -> float: return await get_variable_reward_pool_alpha() - - @app.get('/api/focus/get_fixed_reward_pool_alpha') + + @app.get("/api/focus/get_fixed_reward_pool_alpha") async def _get_fixed_reward_pool_alpha() -> float: return await get_fixed_reward_pool_alpha() - + async def cache_max_focus_alpha(): while True: """Re-caches the value of max_focus_tao.""" @@ -831,17 +904,17 @@ async def cache_max_focus_alpha(): except Exception as err: attempt += 1 print( - f"Error during recaching of max_focus_alpha (Attempt {attempt}/{max_attempts}):", str(err)) + f"Error during recaching of max_focus_alpha (Attempt {attempt}/{max_attempts}):", + str(err), + ) if attempt >= max_attempts: - print( - "Max attempts reached. Skipping this caching this cycle.") + print("Max attempts reached. Skipping this caching this cycle.") break # Sleep in seconds await asyncio.sleep(1800) # 30 minutes - - + ################ END OMEGA FOCUS ENDPOINTS ################ @app.get("/") @@ -866,7 +939,9 @@ async def get_mm_topics(api_key: str = Security(get_api_key)): return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching data from MySQL database: {err}", + ) @app.get("/api/mm/topic_video_count") async def get_mm_topic_video_count(api_key: str = Security(get_api_key)): @@ -882,10 +957,14 @@ async def get_mm_topic_video_count(api_key: str = Security(get_api_key)): return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching data from MySQL database: {err}", + ) @app.get("/api/mm/topic_relevant/{topic}") - async def get_mm_topic_relevant(api_key: str = Security(get_api_key), topic: str = Path(...)): + async def get_mm_topic_relevant( + api_key: str = Security(get_api_key), topic: str = Path(...) 
+ ): try: connection = connect_to_db() query = f"SELECT video_id, youtube_id, description, start_time, end_time FROM omega_multimodal where query = '{topic}' ORDER BY query_relevance_score DESC LIMIT 100" @@ -898,12 +977,19 @@ async def get_mm_topic_relevant(api_key: str = Security(get_api_key), topic: str return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching data from MySQL database: {err}", + ) + ################ END MULTI-MODAL API / OPENTENSOR CONNECTOR ################ ################ START LEADERBOARD ################ @app.get("/api/leaderboard") - async def get_leaderboard_data(hotkey: Optional[str] = None, sort_by: Optional[str] = None, sort_order: Optional[str] = None): + async def get_leaderboard_data( + hotkey: Optional[str] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + ): try: leaderboard_table_name = "miner_leaderboard" if not IS_PROD: @@ -928,17 +1014,13 @@ async def get_leaderboard_data(hotkey: Optional[str] = None, sort_by: Optional[s "avg_query_relevance": "avg_query_relevance", "avg_novelty": "avg_novelty", "avg_score": "avg_score", - "last_updated": "last_updated" + "last_updated": "last_updated", } sort_column = valid_sort_columns.get(sort_by, sort_column) if sort_order: # Validate and map sort_order to actual values if necessary - valid_sort_orders = { - "asc": "ASC", - "desc": "DESC" - } - sort_order = valid_sort_orders.get( - sort_order.lower(), sort_order) + valid_sort_orders = {"asc": "ASC", "desc": "DESC"} + sort_order = valid_sort_orders.get(sort_order.lower(), sort_order) query += f" ORDER BY {sort_column} {sort_order}" @@ -951,11 +1033,13 @@ async def get_leaderboard_data(hotkey: Optional[str] = None, sort_by: Optional[s return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching data from MySQL database: {err}", + ) @app.get("/leaderboard") async def leaderboard(): - return FileResponse('./validator-api/static/leaderboard.html') + return FileResponse("./validator-api/static/leaderboard.html") @app.get("/api/leaderboard-dataset-data") async def get_leaderboard_dataset_data(): @@ -971,7 +1055,9 @@ async def get_leaderboard_dataset_data(): return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching leaderboard dataset data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching leaderboard dataset data from MySQL database: {err}", + ) @app.get("/api/leaderboard-miner-data") async def get_leaderboard_miner_data(hotkey: Optional[str] = None): @@ -997,7 +1083,9 @@ async def get_leaderboard_miner_data(hotkey: Optional[str] = None): return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching leaderboard miner data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching leaderboard miner data from MySQL database: {err}", + ) @app.get("/api/leaderboard-focus-data") async def get_leaderboard_focus_data(): @@ -1013,7 +1101,10 @@ async def get_leaderboard_focus_data(): return data except mysql.connector.Error as err: raise HTTPException( - status_code=500, detail=f"Error fetching focus kpi data from MySQL database: {err}") + status_code=500, + detail=f"Error fetching focus kpi data from MySQL database: {err}", + ) + ################ 
END LEADERBOARD ################ ################ START DASHBOARD ################ @@ -1034,7 +1125,9 @@ async def resync_dataset(): except Exception as err: attempt += 1 print( - f"Error during dataset sync (Attempt {attempt}/{max_attempts}):", str(err)) + f"Error during dataset sync (Attempt {attempt}/{max_attempts}):", + str(err), + ) # print_exception(type(err), err, err.__traceback__) if attempt >= max_attempts: @@ -1049,7 +1142,7 @@ async def get_video_metadata( sort_by: Optional[str] = "submitted_at", sort_order: Optional[str] = "desc", page: Optional[int] = 1, - items_per_page: Optional[int] = 50 + items_per_page: Optional[int] = 50, ): print("get_video_metadata()") if os.path.exists(CACHE_FILE): @@ -1066,7 +1159,7 @@ async def get_video_metadata( "description_relevance_score": 5, "query_relevance_score": 6, "query": 7, - "submitted_at": 8 + "submitted_at": 8, } if sort_by and sort_by in sort_index_mapping: @@ -1087,13 +1180,14 @@ async def get_video_metadata( video[6] = round(video[6], 4) # Round query_relevance_score date_time = datetime.fromtimestamp(video[8]) video[8] = date_time.strftime( - '%Y-%m-%d %H:%M:%S') # Format submitted_at + "%Y-%m-%d %H:%M:%S" + ) # Format submitted_at return { "total_items": total_items, "page": page, "items_per_page": items_per_page, - "data": paginated_descriptions + "data": paginated_descriptions, } else: return {"error": "Cache file not found"} @@ -1101,7 +1195,8 @@ async def get_video_metadata( @app.get("/dashboard") async def dashboard(): print("dashboard()") - return FileResponse('validator-api/static/dashboard.html') + return FileResponse("validator-api/static/dashboard.html") + ################ END DASHBOARD ################ async def run_server(): @@ -1126,5 +1221,6 @@ async def run_server(): server_task.cancel() await server_task + if __name__ == "__main__": asyncio.run(main()) diff --git a/validator-api/check_vali_api.py b/validator-api/check_vali_api.py index f00cbd0a..0e397b8a 100644 --- a/validator-api/check_vali_api.py +++ b/validator-api/check_vali_api.py @@ -16,6 +16,7 @@ if len(sys.argv) > 1: API_URL = sys.argv[1] + async def check_validator_api(idx: int): await asyncio.sleep(idx * SECONDS_DELAY) start = asyncio.get_event_loop().time() @@ -30,6 +31,7 @@ async def check_validator_api(idx: int): print(f"Request {idx} timed out after {TIMEOUT_SECONDS} seconds") return "Timeout", TIMEOUT_SECONDS, None + async def main(): timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.makedirs(SAVE_DIR, exist_ok=True) @@ -41,19 +43,23 @@ async def main(): durations = [result[1] for result in results] total_time = sum(durations) - timeout_count = sum(1 for message, duration, status in results if message == "Timeout") + timeout_count = sum( + 1 for message, duration, status in results if message == "Timeout" + ) error_521_count = sum(1 for message, duration, status in results if status == 521) # Create histogram plt.figure(figsize=(10, 6)) - plt.hist(durations, bins=50, edgecolor='black') - plt.title(f'Distribution of API Request Durations ({API_URL}, {timeout_count} timeouts, {error_521_count} 521 errors)') - plt.xlabel('Duration (seconds)') - plt.ylabel('Frequency') + plt.hist(durations, bins=50, edgecolor="black") + plt.title( + f"Distribution of API Request Durations ({API_URL}, {timeout_count} timeouts, {error_521_count} 521 errors)" + ) + plt.xlabel("Duration (seconds)") + plt.ylabel("Frequency") plt.savefig(f"{SAVE_DIR}/duration_histogram_{timestamp}.png") plt.close() - with open(output_file, 'w') as f: + with open(output_file, 
"w") as f: f.write(f"API URL: {API_URL}\n") f.write(f"Number of requests: {NUM_REQUESTS}\n") f.write(f"Total time taken: {total_time:.2f} seconds\n") @@ -73,4 +79,5 @@ async def main(): print(f"Total number of timeouts: {timeout_count}") print(f"Total number of 521 errors: {error_521_count}") + asyncio.run(main()) diff --git a/validator-api/validator_api/check_blocking.py b/validator-api/validator_api/check_blocking.py index ced69763..bcc46687 100644 --- a/validator-api/validator_api/check_blocking.py +++ b/validator-api/validator_api/check_blocking.py @@ -15,7 +15,9 @@ def check_yield(): nonlocal last_yield, yielded current = loop.time() if current - last_yield > 0.1: # Blocked for >100ms - print(f"Blocking operation detected in {request_name}! Blocked for {current - last_yield:.2f}s (username: {username})") + print( + f"Blocking operation detected in {request_name}! Blocked for {current - last_yield:.2f}s (username: {username})" + ) last_yield = current yielded = True diff --git a/validator-api/validator_api/communex/_common.py b/validator-api/validator_api/communex/_common.py index 12fd9c50..637f3950 100644 --- a/validator-api/validator_api/communex/_common.py +++ b/validator-api/validator_api/communex/_common.py @@ -1,6 +1,7 @@ import random -class ComxSettings(): + +class ComxSettings: # TODO: improve node lists NODE_URLS: list[str] = [ "wss://commune-api-node-0.communeai.net", @@ -36,8 +37,7 @@ class ComxSettings(): "wss://commune-api-node-30.communeai.net", "wss://commune-api-node-31.communeai.net", ] - TESTNET_NODE_URLS: list[str] = [ - "wss://testnet-commune-api-node-0.communeai.net"] + TESTNET_NODE_URLS: list[str] = ["wss://testnet-commune-api-node-0.communeai.net"] def get_node_url( @@ -49,4 +49,4 @@ def get_node_url( node_url = random.choice(comx_settings.TESTNET_NODE_URLS) case False: node_url = random.choice(comx_settings.NODE_URLS) - return node_url \ No newline at end of file + return node_url diff --git a/validator-api/validator_api/communex/client.py b/validator-api/validator_api/communex/client.py index e8a32ec9..10fa4b60 100644 --- a/validator-api/validator_api/communex/client.py +++ b/validator-api/validator_api/communex/client.py @@ -113,7 +113,6 @@ def _get_storage_keys( queries: list[tuple[str, list[Any]]], block_hash: str | None, ): - send: list[tuple[str, list[Any]]] = [] prefix_list: list[Any] = [] @@ -121,7 +120,11 @@ def _get_storage_keys( with self.get_conn(init=True) as substrate: for function, params in queries: storage_key = StorageKey.create_from_storage_function( # type: ignore - storage, function, params, runtime_config=substrate.runtime_config, metadata=substrate.metadata # type: ignore + storage, + function, + params, + runtime_config=substrate.runtime_config, + metadata=substrate.metadata, # type: ignore ) prefix = storage_key.to_hex() @@ -157,16 +160,17 @@ def _get_lists( function_parameters: list[tuple[Any, Any, Any, Any, str]] = [] metadata_pallet = substrate.metadata.get_metadata_pallet( # type: ignore - storage_module) # type: ignore + storage_module + ) # type: ignore for storage_function, params in queries: storage_item = metadata_pallet.get_storage_function( # type: ignore - storage_function) # type: ignore + storage_function + ) # type: ignore value_type = storage_item.get_value_type_string() # type: ignore param_types = storage_item.get_params_type_string() # type: ignore key_hashers = storage_item.get_param_hashers() # type: ignore function_parameters.append( - (value_type, param_types, key_hashers, - params, storage_function) # type: 
ignore + (value_type, param_types, key_hashers, params, storage_function) # type: ignore ) return function_parameters @@ -196,15 +200,14 @@ def _send_batch( with self.get_conn(init=True) as substrate: try: substrate.websocket.send( # type: ignore - json.dumps(batch_payload)) # type: ignore + json.dumps(batch_payload) + ) # type: ignore except NetworkQueryError: pass while len(results) < len(request_ids): - received_messages = json.loads( - substrate.websocket.recv()) # type: ignore + received_messages = json.loads(substrate.websocket.recv()) # type: ignore if isinstance(received_messages, dict): - received_messages: list[dict[Any, Any]] = [ - received_messages] + received_messages: list[dict[Any, Any]] = [received_messages] for message in received_messages: if message.get("id") in request_ids: @@ -289,8 +292,7 @@ def estimate_size(request: tuple[T1, T2]): # Add the last batch if it's not empty if current_batch: result.append(current_batch) - chunk = Chunk(current_batch, current_prefix_batch, - current_params_batch) + chunk = Chunk(current_batch, current_prefix_batch, current_params_batch) chunk_list.append(chunk) return result, chunk_list @@ -384,7 +386,7 @@ def split_chunks(chunk: Chunk, chunk_info: list[Chunk], chunk_info_idx: int): mutaded_chunk_info.pop(chunk_info_idx) for i in range(0, keys_amount, max_n_keys): new_chunk = deepcopy(chunk) - splitted_keys = result_keys[i: i + max_n_keys] + splitted_keys = result_keys[i : i + max_n_keys] splitted_query = deepcopy(query) splitted_query[1][0] = splitted_keys new_chunk.batch_requests = [splitted_query] @@ -403,8 +405,7 @@ def split_chunks(chunk: Chunk, chunk_info: list[Chunk], chunk_info_idx: int): with ThreadPoolExecutor() as executor: futures: list[Future[list[str | dict[Any, Any]]]] = [] for idx, macro_chunk in enumerate(chunk_requests): - _, mutated_chunk_info = split_chunks( - macro_chunk, chunk_requests, idx) + _, mutated_chunk_info = split_chunks(macro_chunk, chunk_requests, idx) for chunk in mutated_chunk_info: request_ids: list[int] = [] batch_payload: list[Any] = [] @@ -521,7 +522,7 @@ def concat_hash_len(key_hasher: str) -> int: item_key_obj = substrate.decode_scale( # type: ignore type_string=f"({', '.join(key_type_string)})", - scale_bytes="0x" + item[0][len(prefix):], + scale_bytes="0x" + item[0][len(prefix) :], return_scale_obj=True, block_hash=block_hash, ) @@ -630,11 +631,9 @@ def recursive_update( return d # type: ignore def get_page(): - send, prefix_list = self._get_storage_keys( - storage, queries, block_hash) + send, prefix_list = self._get_storage_keys(storage, queries, block_hash) with self.get_conn(init=True) as substrate: - function_parameters = self._get_lists( - storage, queries, substrate) + function_parameters = self._get_lists(storage, queries, substrate) responses = self._rpc_request_batch(send) # assumption because send is just the storage_function keys # so it should always be really small regardless of the amount of queries @@ -648,8 +647,7 @@ def get_page(): _, chunks_info = self._make_request_smaller( built_payload, prefix_list, function_parameters ) - chunks_response, chunks_info = self._rpc_request_batch_chunked( - chunks_info) + chunks_response, chunks_info = self._rpc_request_batch_chunked(chunks_info) return chunks_response, chunks_info if not block_hash: @@ -779,7 +777,8 @@ def compose_call( ) extrinsic = substrate.create_signed_extrinsic( # type: ignore - call=call, keypair=key # type: ignore + call=call, + keypair=key, # type: ignore ) # type: ignore response = substrate.submit_extrinsic( 
extrinsic=extrinsic, @@ -789,7 +788,8 @@ def compose_call( if wait_for_inclusion: if not response.is_success: raise ChainTransactionError( - response.error_message, response # type: ignore + response.error_message, + response, # type: ignore ) return response @@ -891,7 +891,8 @@ def compose_call_multisig( if wait_for_inclusion: if not response.is_success: raise ChainTransactionError( - response.error_message, response # type: ignore + response.error_message, + response, # type: ignore ) return response @@ -1267,8 +1268,7 @@ def multiunstake( params = {"netuid": netuid, "module_keys": keys, "amounts": amounts} - response = self.compose_call( - "remove_stake_multiple", params=params, key=key) + response = self.compose_call("remove_stake_multiple", params=params, key=key) return response @@ -1309,8 +1309,7 @@ def multistake( "netuid": netuid, } - response = self.compose_call( - "add_stake_multiple", params=params, key=key) + response = self.compose_call("add_stake_multiple", params=params, key=key) return response @@ -1346,8 +1345,7 @@ def add_profit_shares( params = {"keys": keys, "shares": shares} - response = self.compose_call( - "add_profit_shares", params=params, key=key) + response = self.compose_call("add_profit_shares", params=params, key=key) return response @@ -1391,11 +1389,9 @@ def add_custom_proposal( key: Keypair, cid: str, ) -> ExtrinsicReceipt: - params = {"data": cid} - response = self.compose_call( - fn="add_custom_proposal", params=params, key=key) + response = self.compose_call(fn="add_custom_proposal", params=params, key=key) return response def add_custom_subnet_proposal( @@ -1548,14 +1544,14 @@ def add_dao_application( params = {"application_key": application_key, "data": data} - response = self.compose_call( - "add_dao_application", key=key, params=params) + response = self.compose_call("add_dao_application", key=key, params=params) return response def query_map_curator_applications(self) -> dict[str, dict[str, str]]: query_result = self.query_map( - "CuratorApplications", params=[], extract_value=False) + "CuratorApplications", params=[], extract_value=False + ) applications = query_result.get("CuratorApplications", {}) return applications @@ -2859,4 +2855,4 @@ def get_existential_deposit(self, block_hash: str | None = None) -> int: "Balances", "ExistentialDeposit", block_hash ).value # type: ignore - return result \ No newline at end of file + return result diff --git a/validator-api/validator_api/communex/errors.py b/validator-api/validator_api/communex/errors.py index 39420047..fbe994e9 100644 --- a/validator-api/validator_api/communex/errors.py +++ b/validator-api/validator_api/communex/errors.py @@ -11,4 +11,4 @@ class NetworkQueryError(NetworkError): class NetworkTimeoutError(NetworkError): - """Timeout error""" \ No newline at end of file + """Timeout error""" diff --git a/validator-api/validator_api/communex/key.py b/validator-api/validator_api/communex/key.py index d955b9af..5d98ab4d 100644 --- a/validator-api/validator_api/communex/key.py +++ b/validator-api/validator_api/communex/key.py @@ -21,7 +21,9 @@ def is_ss58_address(address: str, ss58_format: int = 42) -> TypeGuard[Ss58Addres return ss58.is_valid_ss58_address(address, valid_ss58_format=ss58_format) -def check_ss58_address(address: str | Ss58Address, ss58_format: int = 42) -> Ss58Address: +def check_ss58_address( + address: str | Ss58Address, ss58_format: int = 42 +) -> Ss58Address: """ Validates whether the given string is a valid SS58 address. 
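
The assert-based check_ss58_address hard-fails on bad input, while is_ss58_address only
returns a bool; a minimal usage sketch of the two (the substrateinterface Keypair import
is an assumption based on how Keypair is used in this module):

    from substrateinterface import Keypair  # assumed source of the Keypair type

    from validator_api.communex.key import (
        check_ss58_address,
        generate_keypair,
        is_ss58_address,
    )

    kp: Keypair = generate_keypair()             # fresh mnemonic-backed keypair
    assert is_ss58_address(kp.ss58_address)      # soft check: returns True/False
    addr = check_ss58_address(kp.ss58_address)   # hard check: AssertionError if invalid
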
@@ -36,8 +38,7 @@ def check_ss58_address(address: str | Ss58Address, ss58_format: int = 42) -> Ss5 AssertionError: If the address is invalid. """ - assert is_ss58_address( - address, ss58_format), f"Invalid SS58 address '{address}'" + assert is_ss58_address(address, ss58_format), f"Invalid SS58 address '{address}'" return Ss58Address(address) @@ -47,4 +48,4 @@ def generate_keypair() -> Keypair: """ mnemonic = Keypair.generate_mnemonic() keypair = Keypair.create_from_mnemonic(mnemonic) - return keypair \ No newline at end of file + return keypair diff --git a/validator-api/validator_api/communex/types.py b/validator-api/validator_api/communex/types.py index b1322479..9ab3604c 100644 --- a/validator-api/validator_api/communex/types.py +++ b/validator-api/validator_api/communex/types.py @@ -65,8 +65,8 @@ class SubnetParams(TypedDict): # redundant "TypedDict" inheritance because of pdoc warns. # see https://github.com/mitmproxy/pdoc/blob/26d40827ddbe1658e8ac46cd092f17a44cf0287b/pdoc/doc.py#L691-L692 class SubnetParamsWithEmission(SubnetParams, TypedDict): - """SubnetParams with emission field. - """ + """SubnetParams with emission field.""" + emission: int """Subnet emission percentage (0-100). """ @@ -93,4 +93,4 @@ class ModuleInfoWithBalance(ModuleInfo): class ModuleInfoWithOptionalBalance(ModuleInfo): - balance: int | None \ No newline at end of file + balance: int | None diff --git a/validator-api/validator_api/config.py b/validator-api/validator_api/config.py index 3429fa3a..c5c777c0 100644 --- a/validator-api/validator_api/config.py +++ b/validator-api/validator_api/config.py @@ -7,35 +7,39 @@ load_dotenv(override=True) + def get_secret(secret_name, region_name): # Create a Secrets Manager client session = boto3.session.Session() client = session.client( - service_name='secretsmanager', + service_name="secretsmanager", region_name=region_name, ) - get_secret_value_response = client.get_secret_value( - SecretId=secret_name - ) + get_secret_value_response = client.get_secret_value(SecretId=secret_name) # For a list of exceptions thrown, see # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html # Decrypts secret using the associated KMS key. 
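     # Note: get_secret_value exposes text secrets under SecretString (read below);
     # binary secrets instead come back under SecretBinary, base64-encoded per the
     # AWS API docs, and are not handled here.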
- secret = get_secret_value_response['SecretString'] + secret = get_secret_value_response["SecretString"] return secret + def parse_proxies(proxy_list: List[str]) -> List[str]: transformed_proxies = [] for proxy in proxy_list: - proxy_ip, proxy_port, proxy_user, proxy_pass = proxy.split(':') - transformed_proxies.append(f"http://{proxy_user}:{proxy_pass}@{proxy_ip}:{proxy_port}") + proxy_ip, proxy_port, proxy_user, proxy_pass = proxy.split(":") + transformed_proxies.append( + f"http://{proxy_user}:{proxy_pass}@{proxy_ip}:{proxy_port}" + ) return transformed_proxies + def robust_json_loads(json_str: str) -> List[str]: - return json.loads(json_str.replace("\\\"", '"')) + return json.loads(json_str.replace('\\"', '"')) + PORT = int(os.environ.get("PORT", 8002)) NETWORK = os.environ["NETWORK"] @@ -67,10 +71,10 @@ def robust_json_loads(json_str: str) -> List[str]: UPLOAD_AUDIO_BATCH_SIZE = int(os.environ.get("UPLOAD_AUDIO_BATCH_SIZE", 256)) DB_CONFIG = { - 'user': os.environ["DBUSER"], - 'password': os.environ["DBPASS"], - 'host': os.environ["DBHOST"], - 'database': os.environ["DBNAME"] + "user": os.environ["DBUSER"], + "password": os.environ["DBPASS"], + "host": os.environ["DBHOST"], + "database": os.environ["DBNAME"], } # Omega Focus Constants @@ -87,10 +91,14 @@ def robust_json_loads(json_str: str) -> List[str]: BT_TESTNET = "test" BT_MAINNET = "finney" -assert NETWORK in [BT_TESTNET, BT_MAINNET], "SUBTENSOR_NETWORK must be either test or finney" -TAO_REFRESH_INTERVAL_MINUTES = int(os.getenv('TAO_REFRESH_INTERVAL_MINUTES', 10)) - -FOCUS_REWARDS_PERCENT = float(os.getenv('FOCUS_REWARDS_PERCENT', constants.FOCUS_REWARDS_PERCENT)) +assert NETWORK in [BT_TESTNET, BT_MAINNET], ( + "SUBTENSOR_NETWORK must be either test or finney" +) +TAO_REFRESH_INTERVAL_MINUTES = int(os.getenv("TAO_REFRESH_INTERVAL_MINUTES", 10)) + +FOCUS_REWARDS_PERCENT = float( + os.getenv("FOCUS_REWARDS_PERCENT", constants.FOCUS_REWARDS_PERCENT) +) FOCUS_API_KEYS = robust_json_loads(os.environ["FOCUS_API_KEYS"]) FOCUS_API_URL = os.environ["FOCUS_API_URL"] GOOGLE_AI_API_KEY = os.environ["GOOGLE_AI_API_KEY"] @@ -100,9 +108,13 @@ def robust_json_loads(json_str: str) -> List[str]: AWS_S3_REGION = os.environ["AWS_S3_REGION"] AWS_S3_BUCKET_NAME = os.environ["AWS_S3_BUCKET_NAME"] -MAX_FOCUS_POINTS_PER_HOUR = int(os.getenv("MAX_FOCUS_POINTS_PER_HOUR", 80)) # $80 / hour +MAX_FOCUS_POINTS_PER_HOUR = int( + os.getenv("MAX_FOCUS_POINTS_PER_HOUR", 80) +) # $80 / hour FIXED_TAO_USD_ESTIMATE = float(os.getenv("FIXED_TAO_USD_ESTIMATE", 300.0)) -FIXED_ALPHA_TAO_ESTIMATE = float(os.getenv("FIXED_ALPHA_TAO_ESTIMATE", 0.0208)) # 1 alpha to tao, changes over time, you can find this with `btcli subnet list` +FIXED_ALPHA_TAO_ESTIMATE = float( + os.getenv("FIXED_ALPHA_TAO_ESTIMATE", 0.0208) +) # 1 alpha to tao, changes over time, you can find this with `btcli subnet list` FIXED_TAO_ALPHA_ESTIMATE = 1 / FIXED_ALPHA_TAO_ESTIMATE FIXED_ALPHA_USD_ESTIMATE = FIXED_ALPHA_TAO_ESTIMATE * FIXED_TAO_USD_ESTIMATE BOOSTED_TASKS_PERCENTAGE = float(os.getenv("BOOSTED_TASKS_PERCENTAGE", 0.7)) @@ -116,4 +128,4 @@ def robust_json_loads(json_str: str) -> List[str]: f.write(get_secret("prod/gcp_service_user", region_name=AWS_S3_REGION)) SENTRY_DSN = os.getenv("SENTRY_DSN") -IMPORT_SCORE = os.getenv("IMPORT_SCORE", "true").lower() == "true" \ No newline at end of file +IMPORT_SCORE = os.getenv("IMPORT_SCORE", "true").lower() == "true" diff --git a/validator-api/validator_api/cron/confirm_purchase.py b/validator-api/validator_api/cron/confirm_purchase.py index 
02109769..6632bff4 100644 --- a/validator-api/validator_api/cron/confirm_purchase.py +++ b/validator-api/validator_api/cron/confirm_purchase.py @@ -8,20 +8,35 @@ from sqlalchemy import select from validator_api.database import get_db_context from validator_api.database.models.focus_video_record import ( - FocusVideoRecord, FocusVideoStateInternal) + FocusVideoRecord, + FocusVideoStateInternal, +) from validator_api.database.models.miner_bans import ( - increment_failed_purchases, reset_failed_purchases) + increment_failed_purchases, + reset_failed_purchases, +) from validator_api.utils.wallet import get_transaction_from_block_hash async def extrinsic_already_confirmed(db: AsyncSession, extrinsic_id: str) -> bool: - query = select(FocusVideoRecord).filter(FocusVideoRecord.extrinsic_id == extrinsic_id) + query = select(FocusVideoRecord).filter( + FocusVideoRecord.extrinsic_id == extrinsic_id + ) result = await db.execute(query) return result.scalar_one_or_none() is not None -async def check_payment(db: AsyncSession, recipient_address: str, sender_address: str, amount: float, block_hash: str = None): + +async def check_payment( + db: AsyncSession, + recipient_address: str, + sender_address: str, + amount: float, + block_hash: str = None, +): try: - print(f"Checking payment of {amount} from {sender_address} to {recipient_address}") + print( + f"Checking payment of {amount} from {sender_address} to {recipient_address}" + ) # Get all transfers associated with the recipient address transfers = await get_transaction_from_block_hash(recipient_address, block_hash) @@ -29,65 +44,80 @@ async def check_payment(db: AsyncSession, recipient_address: str, sender_address # Filter transfers to find the specific payment for transfer in transfers: if ( - transfer["from"] == sender_address and - transfer["to"] == recipient_address and - round(float(transfer["amount"]), 5) == round(amount, 5) + transfer["from"] == sender_address + and transfer["to"] == recipient_address + and round(float(transfer["amount"]), 5) == round(amount, 5) ): if await extrinsic_already_confirmed(db, transfer["extrinsicId"]): continue - print(f"Payment of {amount} found from {sender_address} to {recipient_address}") + print( + f"Payment of {amount} found from {sender_address} to {recipient_address}" + ) return transfer["extrinsicId"] - print(f"Payment of {amount} not found from {sender_address} to {recipient_address}") + print( + f"Payment of {amount} not found from {sender_address} to {recipient_address}" + ) return None except Exception as e: - print(f'Error in checking payment: {e}') + print(f"Error in checking payment: {e}") return None # finally: # sub.close() + SUBTENSOR_RETRIES = 5 SUBTENSOR_DELAY_SECS = 2 + async def confirm_transfer( db: AsyncSession, video_owner_coldkey: str, video_id: str, miner_hotkey: str, block_hash: str = None, - with_lock: bool = False + with_lock: bool = False, ): subtensor = bt.subtensor(network=config.NETWORK) query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, - FocusVideoRecord.processing_state == FocusVideoStateInternal.PURCHASE_PENDING.value, + FocusVideoRecord.processing_state + == FocusVideoStateInternal.PURCHASE_PENDING.value, FocusVideoRecord.miner_hotkey == miner_hotkey, FocusVideoRecord.deleted_at.is_(None), ) if with_lock: query = query.with_for_update() - + result = await db.execute(query) video = result.scalar_one_or_none() if not video: - print(f"confirm_transfer | video <{video_id}> not found or not in PURCHASE_PENDING state") + print( + f"confirm_transfer | 
video <{video_id}> not found or not in PURCHASE_PENDING state" + ) return False - + tao_amount = video.expected_reward_tao current_time = datetime.utcnow() - print(f"[{current_time}] | Scanning block hash <{block_hash}> for address <{video_owner_coldkey}> payment transaction from ...") + print( + f"[{current_time}] | Scanning block hash <{block_hash}> for address <{video_owner_coldkey}> payment transaction from ..." + ) for attempt in range(SUBTENSOR_RETRIES): try: miner_coldkey = subtensor.get_hotkey_owner(miner_hotkey) print(f"Miner coldkey: {miner_coldkey}") - - extrinsic_id = await check_payment(db, video_owner_coldkey, miner_coldkey, tao_amount, block_hash) + + extrinsic_id = await check_payment( + db, video_owner_coldkey, miner_coldkey, tao_amount, block_hash + ) if extrinsic_id is not None: - print(f"Miner <{miner_hotkey}> successfully purchased focus recording <{video_id}>!") + print( + f"Miner <{miner_hotkey}> successfully purchased focus recording <{video_id}>!" + ) video.miner_hotkey = miner_hotkey video.processing_state = FocusVideoStateInternal.PURCHASED.value video.updated_at = datetime.utcnow() @@ -101,16 +131,26 @@ async def confirm_transfer( except Exception as e: if attempt < SUBTENSOR_RETRIES - 1: # if it's not the last attempt - if "Broken pipe" in str(e) or "EOF occurred in violation of protocol" in str(e) or "[SSL: BAD_LENGTH]" in str(e): - print(f"Connection to subtensor was lost. Re-initializing subtensor and retrying in {SUBTENSOR_DELAY_SECS} seconds...") + if ( + "Broken pipe" in str(e) + or "EOF occurred in violation of protocol" in str(e) + or "[SSL: BAD_LENGTH]" in str(e) + ): + print( + f"Connection to subtensor was lost. Re-initializing subtensor and retrying in {SUBTENSOR_DELAY_SECS} seconds..." + ) subtensor = bt.subtensor(network=config.NETWORK) await asyncio.sleep(SUBTENSOR_DELAY_SECS) else: - print(f"Attempt #{attempt + 1} to sub.get_hotkey_owner() and check_payment() failed. Retrying in {SUBTENSOR_DELAY_SECS} seconds...") + print( + f"Attempt #{attempt + 1} to sub.get_hotkey_owner() and check_payment() failed. Retrying in {SUBTENSOR_DELAY_SECS} seconds..." + ) print(f"Error: {str(e)}") await asyncio.sleep(SUBTENSOR_DELAY_SECS) else: - print(f"All {SUBTENSOR_RETRIES} attempts failed. Unable to retrieve miner coldkey and confirm payment.") + print( + f"All {SUBTENSOR_RETRIES} attempts failed. Unable to retrieve miner coldkey and confirm payment." + ) print(f"Final error: {str(e)}") return False # we got here because we could not confirm the payment. Let's return false to let the miner know @@ -120,17 +160,17 @@ async def confirm_transfer( DELAY_SECS = 30 # 30s RETRIES = 6 # 30s x 10 retries = 180s = 3 mins -async def confirm_video_purchased( - video_id: str, - with_lock: bool = False -): + +async def confirm_video_purchased(video_id: str, with_lock: bool = False): """ - The purpose of this function is to set the video back to the SUBMITTED state + The purpose of this function is to set the video back to the SUBMITTED state if the miner has not confirmed the purchase in time. """ current_time = datetime.utcnow() - print(f"BACKGROUND TASK | {current_time} | Checking if video_id <{video_id}> has been marked as purchased or reverted back to SUBMITTED ...") + print( + f"BACKGROUND TASK | {current_time} | Checking if video_id <{video_id}> has been marked as purchased or reverted back to SUBMITTED ..." 
+ ) try: for i in range(0, RETRIES): await asyncio.sleep(DELAY_SECS) @@ -142,23 +182,37 @@ async def confirm_video_purchased( ) if with_lock: query = query.with_for_update() - + result = await db.execute(query) video = result.scalar_one_or_none() if not video: print(f"Video <{video_id}> not found") return False - - if video is not None and video.processing_state == FocusVideoStateInternal.PURCHASED.value: - print(f"Video <{video_id}> has been marked as PURCHASED. Stopping background task.") + + if ( + video is not None + and video.processing_state + == FocusVideoStateInternal.PURCHASED.value + ): + print( + f"Video <{video_id}> has been marked as PURCHASED. Stopping background task." + ) await reset_failed_purchases(db, video.miner_hotkey) return True - elif video is not None and video.processing_state == FocusVideoStateInternal.SUBMITTED.value: - print(f"Video <{video_id}> has been marked as SUBMITTED. Stopping background task.") + elif ( + video is not None + and video.processing_state + == FocusVideoStateInternal.SUBMITTED.value + ): + print( + f"Video <{video_id}> has been marked as SUBMITTED. Stopping background task." + ) return True - print(f"Video <{video_id}> has NOT been marked as PURCHASED. Retrying in {DELAY_SECS} seconds...") + print( + f"Video <{video_id}> has NOT been marked as PURCHASED. Retrying in {DELAY_SECS} seconds..." + ) # close the db connection until next retry await db.close() @@ -167,7 +221,9 @@ async def confirm_video_purchased( # we got here because we could not confirm the payment in time, so we need to revert # the video back to the SUBMITTED state (i.e. mark available for purchase) - print(f"Video <{video_id}> has NOT been marked as PURCHASED. Reverting to SUBMITTED state...") + print( + f"Video <{video_id}> has NOT been marked as PURCHASED. Reverting to SUBMITTED state..." 
+ ) await increment_failed_purchases(db, video.miner_hotkey) video.processing_state = FocusVideoStateInternal.SUBMITTED.value video.updated_at = datetime.utcnow() diff --git a/validator-api/validator_api/database/__init__.py b/validator-api/validator_api/database/__init__.py index b21eb907..6052bd93 100644 --- a/validator-api/validator_api/database/__init__.py +++ b/validator-api/validator_api/database/__init__.py @@ -13,7 +13,9 @@ DB_POOL_SIZE = config.FOCUS_DB_POOL_SIZE DB_MAX_OVERFLOW = config.FOCUS_DB_MAX_OVERFLOW -DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" +DATABASE_URL = ( + f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" +) engine = create_async_engine( DATABASE_URL, @@ -23,10 +25,13 @@ pool_pre_ping=True, # Good practice for most scenarios pool_recycle=300, # Recycle connections after 5 minutes ) -SessionLocal = async_sessionmaker(class_=AsyncSession, autocommit=False, autoflush=False, bind=engine) +SessionLocal = async_sessionmaker( + class_=AsyncSession, autocommit=False, autoflush=False, bind=engine +) Base = declarative_base() metadata = MetaData() + async def get_db(): db = SessionLocal() try: @@ -34,6 +39,7 @@ async def get_db(): finally: await db.close() + @asynccontextmanager async def get_db_context(): async for db in get_db(): diff --git a/validator-api/validator_api/database/crud/focusvideo.py b/validator-api/validator_api/database/crud/focusvideo.py index 19143163..1710b3e0 100644 --- a/validator-api/validator_api/database/crud/focusvideo.py +++ b/validator-api/validator_api/database/crud/focusvideo.py @@ -11,15 +11,24 @@ from validator_api.config import NETWORK, NETUID from validator_api.database import get_db_context -from validator_api.database.models.focus_video_record import FocusVideoRecord, FocusVideoInternal, FocusVideoStateInternal, TaskType +from validator_api.database.models.focus_video_record import ( + FocusVideoRecord, + FocusVideoInternal, + FocusVideoStateInternal, + TaskType, +) from validator_api.database.models.user import UserRecord -from validator_api.utils.marketplace import get_max_focus_alpha_per_day, get_variable_reward_pool_alpha, get_max_focus_points_available_today +from validator_api.utils.marketplace import ( + get_max_focus_alpha_per_day, + get_variable_reward_pool_alpha, + get_max_focus_points_available_today, +) from pydantic import BaseModel from validator_api.scoring.scoring_service import VideoScore, FocusVideoEmbeddings MIN_REWARD_TAO = 0.001 -MIN_REWARD_ALPHA = .5 +MIN_REWARD_ALPHA = 0.5 class CachedValue: @@ -44,7 +53,9 @@ async def _background_update(self): self.is_initialized = True print(f"Cache {self._fetch_func.__name__} initialized") else: - print(f"Cache {self._fetch_func.__name__} updated at {datetime.utcnow()}") + print( + f"Cache {self._fetch_func.__name__} updated at {datetime.utcnow()}" + ) except Exception as e: # Log error or handle as needed; do not crash the loop print(f"Background cache update failed: {e}\n{traceback.format_exc()}") @@ -56,33 +67,43 @@ def get(self): raise Exception("Cache is not initialized yet") return self._value + async def _fetch_available_focus(): async with get_db_context() as db: # Show oldest videos first so they get rewarded fastest - query = select(FocusVideoRecord).filter( - FocusVideoRecord.processing_state == FocusVideoStateInternal.SUBMITTED.value, - FocusVideoRecord.deleted_at.is_(None), - FocusVideoRecord.expected_reward_tao > MIN_REWARD_TAO, - # FocusVideoRecord.expected_reward_alpha > 
MIN_REWARD_ALPHA, - ).order_by(FocusVideoRecord.updated_at.asc()).limit(10) - + query = ( + select(FocusVideoRecord) + .filter( + FocusVideoRecord.processing_state + == FocusVideoStateInternal.SUBMITTED.value, + FocusVideoRecord.deleted_at.is_(None), + FocusVideoRecord.expected_reward_tao > MIN_REWARD_TAO, + # FocusVideoRecord.expected_reward_alpha > MIN_REWARD_ALPHA, + ) + .order_by(FocusVideoRecord.updated_at.asc()) + .limit(10) + ) + result = await db.execute(query) items = result.scalars().all() return [FocusVideoInternal.model_validate(record) for record in items] + async def _alpha_to_tao_rate() -> float: async with bittensor.AsyncSubtensor(network=NETWORK) as subtensor: subnet = await subtensor.subnet(NETUID) balance = subnet.alpha_to_tao(1) return balance.tao + async def _already_purchased_max_focus_tao() -> bool: async with get_db_context() as db: query = select(func.sum(FocusVideoRecord.earned_reward_tao)).filter( - FocusVideoRecord.processing_state == FocusVideoStateInternal.PURCHASED.value, - FocusVideoRecord.updated_at >= datetime.utcnow() - timedelta(hours=24) + FocusVideoRecord.processing_state + == FocusVideoStateInternal.PURCHASED.value, + FocusVideoRecord.updated_at >= datetime.utcnow() - timedelta(hours=24), ) - + result = await db.execute(query) total_earned_tao = result.scalar() or 0 effective_max_focus_alpha = await get_variable_reward_pool_alpha() @@ -90,22 +111,27 @@ async def _already_purchased_max_focus_tao() -> bool: return total_earned_tao >= effective_max_focus_tao + class MinerPurchaseStats(BaseModel): total_focus_points: float max_focus_points: float focus_points_percentage: float + async def _get_miner_purchase_stats() -> Dict[str, MinerPurchaseStats]: async with get_db_context() as db: query = select(FocusVideoRecord).filter( - FocusVideoRecord.processing_state == FocusVideoStateInternal.PURCHASED.value, - FocusVideoRecord.updated_at >= datetime.utcnow() - timedelta(hours=24) + FocusVideoRecord.processing_state + == FocusVideoStateInternal.PURCHASED.value, + FocusVideoRecord.updated_at >= datetime.utcnow() - timedelta(hours=24), ) result = await db.execute(query) purchased_videos_records = result.scalars().all() # Calculate total earned tao - total_earned_tao = sum(record.earned_reward_tao or 0 for record in purchased_videos_records) + total_earned_tao = sum( + record.earned_reward_tao or 0 for record in purchased_videos_records + ) # Group records by miner hotkey videos_by_miner = {} @@ -117,22 +143,33 @@ async def _get_miner_purchase_stats() -> Dict[str, MinerPurchaseStats]: # Process stats for each miner stats = {} for miner_hotkey, miner_videos in videos_by_miner.items(): - miner_earned_tao = sum(video_record.earned_reward_tao for video_record in miner_videos) - tao_percentage = miner_earned_tao / total_earned_tao if total_earned_tao > 0 else 0 + miner_earned_tao = sum( + video_record.earned_reward_tao for video_record in miner_videos + ) + tao_percentage = ( + miner_earned_tao / total_earned_tao if total_earned_tao > 0 else 0 + ) stats[miner_hotkey] = MinerPurchaseStats( total_focus_points=miner_earned_tao, max_focus_points=total_earned_tao, - focus_points_percentage=tao_percentage + focus_points_percentage=tao_percentage, ) return stats + class FocusVideoCache: def __init__(self): - self._available_focus_cache = CachedValue(fetch_func=_fetch_available_focus, update_interval=180) + self._available_focus_cache = CachedValue( + fetch_func=_fetch_available_focus, update_interval=180 + ) self._alpha_to_tao_cache = CachedValue(fetch_func=_alpha_to_tao_rate) 
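         # Each CachedValue above refreshes itself on a background loop (the
         # class's default interval applies where none is passed); get() raises
         # until the first fetch completes, so these caches are briefly
         # unavailable right after startup.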
- self._already_purchased_cache = CachedValue(fetch_func=_already_purchased_max_focus_tao) - self._miner_purchase_stats_cache = CachedValue(fetch_func=_get_miner_purchase_stats, update_interval=180) + self._already_purchased_cache = CachedValue( + fetch_func=_already_purchased_max_focus_tao + ) + self._miner_purchase_stats_cache = CachedValue( + fetch_func=_get_miner_purchase_stats, update_interval=180 + ) def get_all_available_focus(self): try: @@ -149,11 +186,11 @@ def alpha_to_tao_rate(self) -> float: def miner_purchase_stats(self) -> Dict[str, MinerPurchaseStats]: return self._miner_purchase_stats_cache.get() + async def get_video_owner_coldkey(db: AsyncSession, video_id: str) -> str: try: query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() @@ -167,7 +204,7 @@ async def get_video_owner_coldkey(db: AsyncSession, video_id: str) -> str: query = select(UserRecord).filter(UserRecord.email == user_email) result = await db.execute(query) user_record = result.scalar_one_or_none() - + if user_record is None: raise HTTPException(404, detail="User not found") @@ -177,37 +214,39 @@ async def get_video_owner_coldkey(db: AsyncSession, video_id: str) -> str: print(f"Error in get_video_owner_coldkey: {str(e)}") raise HTTPException(500, detail=f"Error retrieving video owner: {str(e)}") + async def check_availability( - db: AsyncSession, - video_id: str, - miner_hotkey: str, - with_lock: bool = False + db: AsyncSession, video_id: str, miner_hotkey: str, with_lock: bool = False ): try: # Use explicit loading strategy to avoid lazy loading issues query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None), - FocusVideoRecord.processing_state == FocusVideoStateInternal.SUBMITTED.value, # is available for purchase + FocusVideoRecord.processing_state + == FocusVideoStateInternal.SUBMITTED.value, # is available for purchase FocusVideoRecord.expected_reward_tao > MIN_REWARD_TAO, # FocusVideoRecord.expected_reward_alpha > MIN_REWARD_ALPHA, ) - + if with_lock: query = query.with_for_update() - + result = await db.execute(query) video_record = result.scalar_one_or_none() if video_record is None: return { - 'status': 'error', - 'message': f'video {video_id} not found or not available for purchase' + "status": "error", + "message": f"video {video_id} not found or not available for purchase", } if video_record.expected_reward_tao is None: - raise HTTPException(500, detail="The video record is missing the expected reward tao, investigate this bug") - + raise HTTPException( + 500, + detail="The video record is missing the expected reward tao, investigate this bug", + ) + # TODO: This is commented out because expected_reward_alpha is not filled in for all videos yet, need to migrate # if video_record.expected_reward_alpha is None: # raise HTTPException(500, detail="The video record is missing the expected reward alpha, investigate this bug") @@ -215,7 +254,7 @@ async def check_availability( # Create a copy of the values we need to avoid lazy loading issues expected_reward_tao = video_record.expected_reward_tao expected_reward_alpha = video_record.expected_reward_alpha - + # mark the purchase as pending i.e. 
a miner has claimed the video for purchase and now just needs to pay video_record.processing_state = FocusVideoStateInternal.PURCHASE_PENDING.value video_record.miner_hotkey = miner_hotkey @@ -228,9 +267,9 @@ async def check_availability( await db.commit() return { - 'status': 'success', - 'price': expected_reward_tao, - 'price_alpha': expected_reward_alpha, + "status": "success", + "price": expected_reward_tao, + "price_alpha": expected_reward_alpha, } except Exception as e: @@ -239,23 +278,24 @@ async def check_availability( await db.rollback() raise HTTPException(500, detail="Internal error") + async def check_video_metadata( - db: AsyncSession, - video_id: str, - user_email: str, - miner_hotkey: str + db: AsyncSession, video_id: str, user_email: str, miner_hotkey: str ): try: query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, FocusVideoRecord.user_email == user_email, FocusVideoRecord.miner_hotkey == miner_hotkey, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.deleted_at.is_(None), ) result = await db.execute(query) video_info = result.scalar_one_or_none() - if video_info is not None and video_info.processing_state == FocusVideoStateInternal.PURCHASED.value: + if ( + video_info is not None + and video_info.processing_state == FocusVideoStateInternal.PURCHASED.value + ): # # FV TODO: why do we need the task info? # task_info = db.query(models.Task).filter_by(id=video_info.task_id).first() @@ -269,7 +309,7 @@ async def check_video_metadata( # 'success': True, # 'score': video_score # } - + # return { # 'success': False, # 'message': 'No task found.' @@ -283,31 +323,27 @@ async def check_video_metadata( # print(f"Video score: {video_score}") video_score = video_info.video_score - return { - 'success': True, - 'score': video_score - } + return {"success": True, "score": video_score} - return { - 'success': False, - 'message': 'No video found.' 
- } + return {"success": False, "message": "No video found."} except Exception as e: print(e) - return { - 'success': False, - 'message': 'Internal Server Errror' - } + return {"success": False, "message": "Internal Server Errror"} + -async def set_focus_video_score(db: AsyncSession, video_id: str, score_details: VideoScore, embeddings: FocusVideoEmbeddings): +async def set_focus_video_score( + db: AsyncSession, + video_id: str, + score_details: VideoScore, + embeddings: FocusVideoEmbeddings, +): query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: raise HTTPException(404, detail="Focus video not found") @@ -319,29 +355,31 @@ async def set_focus_video_score(db: AsyncSession, video_id: str, score_details: video_record.embeddings = json.loads(embeddings.model_dump_json()) video_record.processing_state = FocusVideoStateInternal.READY.value video_record.updated_at = datetime.utcnow() - video_record.task_type = TaskType.BOOSTED if score_details.boosted_multiplier > 1.0 else TaskType.USER + video_record.task_type = ( + TaskType.BOOSTED if score_details.boosted_multiplier > 1.0 else TaskType.USER + ) db.add(video_record) await db.commit() + async def mark_video_rejected( db: AsyncSession, video_id: str, rejection_reason: str, - score_details: Optional[VideoScore]=None, - embeddings: Optional[FocusVideoEmbeddings]=None, - exception_string: Optional[str]=None, + score_details: Optional[VideoScore] = None, + embeddings: Optional[FocusVideoEmbeddings] = None, + exception_string: Optional[str] = None, ): query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: raise HTTPException(404, detail="Focus video not found") - video_details = { **video_record.video_details } + video_details = {**video_record.video_details} if score_details: video_details = { @@ -363,22 +401,32 @@ async def mark_video_rejected( db.add(video_record) await db.commit() -async def mark_video_submitted(db: AsyncSession, video_id: str, miner_hotkey: str, with_lock: bool = False): + +async def mark_video_submitted( + db: AsyncSession, video_id: str, miner_hotkey: str, with_lock: bool = False +): # Mark video as "SUBMITTED" if in the "PURCHASE_PENDING" state. - query = select(FocusVideoRecord).filter( + query = select( + FocusVideoRecord + ).filter( FocusVideoRecord.video_id == video_id, - FocusVideoRecord.processing_state == FocusVideoStateInternal.PURCHASE_PENDING.value, + FocusVideoRecord.processing_state + == FocusVideoStateInternal.PURCHASE_PENDING.value, FocusVideoRecord.deleted_at.is_(None), - FocusVideoRecord.miner_hotkey == miner_hotkey # make sure the miner requesting the cancellation is the one who was trying to buy it! + FocusVideoRecord.miner_hotkey + == miner_hotkey, # make sure the miner requesting the cancellation is the one who was trying to buy it! 
) if with_lock: query = query.with_for_update() - + result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: - raise HTTPException(404, detail="Focus video not found or not in the correct state: PURCHASE_PENDING") + raise HTTPException( + 404, + detail="Focus video not found or not in the correct state: PURCHASE_PENDING", + ) video_record.processing_state = FocusVideoStateInternal.SUBMITTED.value video_record.updated_at = datetime.utcnow() diff --git a/validator-api/validator_api/database/encrypted_json.py b/validator-api/validator_api/database/encrypted_json.py index 44fb26f5..9d5ae02d 100644 --- a/validator-api/validator_api/database/encrypted_json.py +++ b/validator-api/validator_api/database/encrypted_json.py @@ -19,7 +19,9 @@ class EncryptedJSON(TypeDecorator): # For MySQL, the default limit here is 64 kb. In the prod DB, I (Salman) set it to 4GB. impl = LargeBinary - def process_bind_param(self, value: Optional[JSONType], dialect: Dialect) -> Optional[bytes]: + def process_bind_param( + self, value: Optional[JSONType], dialect: Dialect + ) -> Optional[bytes]: if value is not None: try: return encrypt_data(value) @@ -27,7 +29,9 @@ def process_bind_param(self, value: Optional[JSONType], dialect: Dialect) -> Opt raise ValueError(f"Error encrypting data: {str(e)}") return None - def process_result_value(self, value: Optional[bytes], dialect: Dialect) -> Optional[JSONType]: + def process_result_value( + self, value: Optional[bytes], dialect: Dialect + ) -> Optional[JSONType]: if value is not None: try: return decrypt_data(value) @@ -56,14 +60,20 @@ def decrypt_data(encrypted_data: bytes) -> JSONType: class LargeEncryptedJSON(EncryptedJSON): - impl = LargeBinary(length=4 * 1024 * 1024 * 1024 - 1) # 4 GB - 1 byte because thats the MySQL max + impl = LargeBinary( + length=4 * 1024 * 1024 * 1024 - 1 + ) # 4 GB - 1 byte because thats the MySQL max + class MediumEncryptedJSON(EncryptedJSON): - impl = LargeBinary(length=16 * 1024 * 1024 - 1) # 16 MB - 1 byte (MySQL MEDIUMBLOB max size) + impl = LargeBinary( + length=16 * 1024 * 1024 - 1 + ) # 16 MB - 1 byte (MySQL MEDIUMBLOB max size) + def test_encrypted_json(): encrypted_json_type = EncryptedJSON() - + class FakeModel(BaseModel): name: str value: int @@ -80,18 +90,22 @@ class NestedFakeModel(BaseModel): 3.14, # float True, # bool None, # null - {"nested": {"list": [1, 2, 3], "dict": {"a": 1, "b": 2}}}, # complex nested structure + { + "nested": {"list": [1, 2, 3], "dict": {"a": 1, "b": 2}} + }, # complex nested structure FakeModel(name="Test", value=123), # Pydantic BaseModel - NestedFakeModel(nested=FakeModel(name="Nested", value=456)), # Nested Pydantic BaseModel + NestedFakeModel( + nested=FakeModel(name="Nested", value=456) + ), # Nested Pydantic BaseModel ] - + for case in test_cases: # Simulate database write encrypted = encrypted_json_type.process_bind_param(case, None) - + # Simulate database read decrypted = encrypted_json_type.process_result_value(encrypted, None) - + if isinstance(case, BaseModel): assert type(case)(**decrypted) == case, f"Failed for case: {case}" else: diff --git a/validator-api/validator_api/database/models/boosted_task.py b/validator-api/validator_api/database/models/boosted_task.py index 0d2bbceb..1077bd7e 100644 --- a/validator-api/validator_api/database/models/boosted_task.py +++ b/validator-api/validator_api/database/models/boosted_task.py @@ -2,8 +2,9 @@ from validator_api.database import Base from datetime import datetime + class BoostedTask(Base): - 
__tablename__ = 'boosted_tasks' + __tablename__ = "boosted_tasks" id = Column(Integer, primary_key=True) created_at = Column(DateTime, nullable=False, default=datetime.utcnow) diff --git a/validator-api/validator_api/database/models/focus_video_record.py b/validator-api/validator_api/database/models/focus_video_record.py index 6a05cc2c..57f31a87 100644 --- a/validator-api/validator_api/database/models/focus_video_record.py +++ b/validator-api/validator_api/database/models/focus_video_record.py @@ -11,11 +11,13 @@ import enum + class TaskType(enum.Enum): USER = "USER" BOOSTED = "BOOSTED" MARKETPLACE = "MARKETPLACE" + class FocusVideoStateExternal(enum.Enum): PROCESSING = "PROCESSING" PENDING_HUMAN_REVIEW = "PENDING_HUMAN_REVIEW" @@ -24,6 +26,7 @@ class FocusVideoStateExternal(enum.Enum): SUBMITTED = "SUBMITTED" REWARDED = "REWARDED" + class FocusVideoStateInternal(enum.Enum): # OMEGA Focus user facing states IN_PROGRESS = "IN_PROGRESS" @@ -32,7 +35,7 @@ class FocusVideoStateInternal(enum.Enum): READY = "READY" # Score has been calculated and task is eligible for submission REJECTED = "REJECTED" # Turns out that the task was NOT eligible for submission, lifecycle ended here SUBMITTED = "SUBMITTED" # User has pressed "Submit" and the task is now listed on the marketplace, for SN24 miners to buy - + # Miner purchase states PURCHASE_PENDING = "PURCHASE_PENDING" # a miner has request to buy the video, and we have sent them the amount of tao that they need to send the focus user PURCHASED = "PURCHASED" # our background cron has confirmed that the miner has bought the focus video @@ -58,15 +61,33 @@ def map_focus_video_state(state: FocusVideoStateInternal) -> FocusVideoStateExte else: raise ValueError(f"Invalid focus video state: {state}") + class FocusVideoRecord(Base): - __tablename__ = 'focus_videos' + __tablename__ = "focus_videos" - video_id = Column(String(DB_STRING_LENGTH), primary_key=True, default=lambda: str(uuid.uuid4()), nullable=False) + video_id = Column( + String(DB_STRING_LENGTH), + primary_key=True, + default=lambda: str(uuid.uuid4()), + nullable=False, + ) task_id = Column(String(DB_STRING_LENGTH), nullable=False) user_id = Column(String, nullable=False) user_email = Column(String, nullable=False) - processing_state = Column(Enum(*FocusVideoStateInternal.__members__, name='focus_videos_processing_state', schema='public'), nullable=False, default=FocusVideoStateInternal.PROCESSING) - task_type = Column(Enum(*TaskType.__members__, name='focus_videos_task_type', schema='public'), nullable=False, default=TaskType.USER) + processing_state = Column( + Enum( + *FocusVideoStateInternal.__members__, + name="focus_videos_processing_state", + schema="public", + ), + nullable=False, + default=FocusVideoStateInternal.PROCESSING, + ) + task_type = Column( + Enum(*TaskType.__members__, name="focus_videos_task_type", schema="public"), + nullable=False, + default=TaskType.USER, + ) video_score = Column(Float, nullable=True) video_details = Column(JSONB, nullable=True) embeddings = Column(JSONB, nullable=True) @@ -84,6 +105,7 @@ class FocusVideoRecord(Base): def get_duration(self) -> float: return float(self.video_details.get("duration", 0.0)) + class FocusVideoBase(BaseModel): video_id: str task_id: str @@ -99,6 +121,7 @@ class FocusVideoBase(BaseModel): updated_at: datetime deleted_at: Optional[datetime] + class FocusVideoInternal(FocusVideoBase): model_config = ConfigDict(from_attributes=True) diff --git a/validator-api/validator_api/database/models/miner_bans.py 
b/validator-api/validator_api/database/models/miner_bans.py index adad23a6..3d9fbc0b 100644 --- a/validator-api/validator_api/database/models/miner_bans.py +++ b/validator-api/validator_api/database/models/miner_bans.py @@ -9,74 +9,76 @@ from sqlalchemy.ext.asyncio import AsyncSession from datetime import timedelta + class MinerBan(Base): - __tablename__ = 'miner_bans' + __tablename__ = "miner_bans" miner_hotkey = Column(String(DB_STRING_LENGTH), primary_key=True, nullable=False) purchases_failed_in_a_row = Column(Integer, nullable=False) banned_until = Column(DateTime(timezone=True), nullable=True) + async def miner_banned_until(db: AsyncSession, miner_hotkey: str) -> Optional[datetime]: """ Check if a miner is currently banned and return their ban expiration time if they are. - + Args: db: Database session miner_hotkey: The miner's hotkey to check - + Returns: datetime: The banned_until time if the miner is currently banned None: If the miner is not currently banned """ query = select(MinerBan).filter( - MinerBan.miner_hotkey == miner_hotkey, - MinerBan.banned_until > datetime.utcnow() + MinerBan.miner_hotkey == miner_hotkey, MinerBan.banned_until > datetime.utcnow() ) result = await db.execute(query) ban = result.scalar_one_or_none() - + return ban.banned_until if ban else None + async def get_or_create_miner(db: AsyncSession, miner_hotkey: str) -> MinerBan: """ Get a miner's ban record or create it if it doesn't exist. - + Args: db: Database session miner_hotkey: The miner's hotkey - + Returns: MinerBan: The miner's ban record """ - query = select(MinerBan).filter( - MinerBan.miner_hotkey == miner_hotkey - ) + query = select(MinerBan).filter(MinerBan.miner_hotkey == miner_hotkey) result = await db.execute(query) miner = result.scalar_one_or_none() - + if not miner: miner = MinerBan( - miner_hotkey=miner_hotkey, - purchases_failed_in_a_row=0, - banned_until=None + miner_hotkey=miner_hotkey, purchases_failed_in_a_row=0, banned_until=None ) db.add(miner) await db.commit() - + return miner + async def increment_failed_purchases(db: AsyncSession, miner_hotkey: str): """ Increment the number of purchases failed in a row for a miner. Creates the miner record if it doesn't exist. - + """ miner = await get_or_create_miner(db, miner_hotkey) miner.purchases_failed_in_a_row += 1 - print(f"increment_failed_purchases | miner_hotkey <{miner_hotkey}> purchases_failed_in_a_row <{miner.purchases_failed_in_a_row}>") + print( + f"increment_failed_purchases | miner_hotkey <{miner_hotkey}> purchases_failed_in_a_row <{miner.purchases_failed_in_a_row}>" + ) check_and_ban_miner(miner) await db.commit() + async def reset_failed_purchases(db: AsyncSession, miner_hotkey: str): """ In the case of a successful purchase, reset the number of purchases failed in a row for a miner. @@ -87,7 +89,10 @@ async def reset_failed_purchases(db: AsyncSession, miner_hotkey: str): miner.banned_until = None await db.commit() + BAN_PURCHASES_FAILED_IN_A_ROW = 5 + + def check_and_ban_miner(miner: MinerBan): """ If a miner fails more than BAN_PURCHASES_FAILED_IN_A_ROW purchases in a row, ban them for 24 hours. 
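
The hunk ends at check_and_ban_miner's docstring, so its body is not shown here; from
that docstring and the columns on MinerBan, the logic is presumably close to the sketch
below, which uses the module's own imports (the strict "more than" comparison and the
in-place mutation are assumptions):

    def check_and_ban_miner(miner: MinerBan):
        # Sketch: once the failure streak passes the threshold, stamp a ban 24h
        # out. miner_banned_until() treats any future banned_until as an active
        # ban, and the caller (increment_failed_purchases) commits the change.
        if miner.purchases_failed_in_a_row > BAN_PURCHASES_FAILED_IN_A_ROW:
            miner.banned_until = datetime.utcnow() + timedelta(hours=24)
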
diff --git a/validator-api/validator_api/database/models/scoring.py b/validator-api/validator_api/database/models/scoring.py index 195fc988..d66d0cfa 100644 --- a/validator-api/validator_api/database/models/scoring.py +++ b/validator-api/validator_api/database/models/scoring.py @@ -1,36 +1,67 @@ from typing import List, Optional from pydantic import BaseModel, Field + class VideoTooShortError(Exception): pass + class VideoTooLongError(Exception): pass + class VideoUniquenessError(Exception): pass + class LegitimacyCheckError(Exception): pass + class TaskScoreBreakdown(BaseModel): - reasoning_steps: List[str] = Field(description="Steps of reasoning used to arrive at the final score. Before each step, write the text 'Step X: '") - final_score: float = Field(ge=0, le=1, description="Final score for the task, between 0.0 and 1.0") - rationale: str = Field(description="Compendious user-facing explanation for the given score") + reasoning_steps: List[str] = Field( + description="Steps of reasoning used to arrive at the final score. Before each step, write the text 'Step X: '" + ) + final_score: float = Field( + ge=0, le=1, description="Final score for the task, between 0.0 and 1.0" + ) + rationale: str = Field( + description="Compendious user-facing explanation for the given score" + ) + class DetailedVideoDescription(BaseModel): - applications_used: List[str] = Field(description="List of applications used in the video for completing the task") - completion_sequence_steps: List[str] = Field(description="Highly detailed step-by-step breakdown of the sequence of steps taken to complete the task") - user_feedback: str = Field(description="Feedback for the user to improve their task completion skills in the future") - description: str = Field(description="High-level summary description of the video content") + applications_used: List[str] = Field( + description="List of applications used in the video for completing the task" + ) + completion_sequence_steps: List[str] = Field( + description="Highly detailed step-by-step breakdown of the sequence of steps taken to complete the task" + ) + user_feedback: str = Field( + description="Feedback for the user to improve their task completion skills in the future" + ) + description: str = Field( + description="High-level summary description of the video content" + ) + class CompletionScore(BaseModel): - rationale: str = Field(description="Concise description of how well the user completed the task") - completion_score: float = Field(ge=0, le=1, description="Final completion score, between 0.0 and 1.0") + rationale: str = Field( + description="Concise description of how well the user completed the task" + ) + completion_score: float = Field( + ge=0, le=1, description="Final completion score, between 0.0 and 1.0" + ) + class CompletionScoreWithoutRange(BaseModel): - rationale: str = Field(description="Concise description of how well the user completed the task") - completion_score: float = Field(description="Final completion score, between 0.0 and 1.0") + rationale: str = Field( + description="Concise description of how well the user completed the task" + ) + completion_score: float = Field( + description="Final completion score, between 0.0 and 1.0" + ) + class VideoScore(BaseModel): # task and video scores @@ -48,15 +79,18 @@ class VideoScore(BaseModel): completion_score_breakdown: CompletionScore detailed_video_description: DetailedVideoDescription + class FocusVideoEmbeddings(BaseModel): # embeddings task_overview_embedding: Optional[List[float]] 
detailed_video_description_embedding: Optional[List[float]] video_embedding: Optional[List[float]] + class BoostedTaskIndex(BaseModel): index: int - + + class BoostedTaskData(BaseModel): title: str description: str diff --git a/validator-api/validator_api/database/models/task.py b/validator-api/validator_api/database/models/task.py index 68c6057a..0d25d65f 100644 --- a/validator-api/validator_api/database/models/task.py +++ b/validator-api/validator_api/database/models/task.py @@ -5,8 +5,9 @@ from pydantic import BaseModel, ConfigDict from typing import Optional + class TaskRecordPG(Base): - __tablename__ = 'tasks' + __tablename__ = "tasks" id = Column(String(DB_STRING_LENGTH), primary_key=True, nullable=False) info = Column(String(DB_STRING_LENGTH)) description = Column(String(DB_STRING_LENGTH)) diff --git a/validator-api/validator_api/database/models/user.py b/validator-api/validator_api/database/models/user.py index c25bc3bb..3f9cca11 100644 --- a/validator-api/validator_api/database/models/user.py +++ b/validator-api/validator_api/database/models/user.py @@ -8,7 +8,7 @@ class UserRecord(Base): - __tablename__ = 'users' + __tablename__ = "users" id = Column(String, primary_key=True, nullable=False) email = Column(String(DB_STRING_LENGTH), primary_key=True, nullable=False) diff --git a/validator-api/validator_api/database/schemas.py b/validator-api/validator_api/database/schemas.py index f3931d60..fcb13620 100644 --- a/validator-api/validator_api/database/schemas.py +++ b/validator-api/validator_api/database/schemas.py @@ -3,19 +3,22 @@ from typing import List, Optional from pydantic import BaseModel, Field + class TaskStatusEnum(enum.Enum): - Ready = 'Ready' - Running = 'Running' - Stopped = 'Stopped' - Completed = 'Completed' + Ready = "Ready" + Running = "Running" + Stopped = "Stopped" + Completed = "Completed" + class FocusVideoEnum(enum.Enum): - Uploaded = 'Uploaded' - Available = 'Available' - Pending = 'Pending' - Purchased = 'Purchased' - Submitted = 'Submitted' - Consumed = 'Consumed' + Uploaded = "Uploaded" + Available = "Available" + Pending = "Pending" + Purchased = "Purchased" + Submitted = "Submitted" + Consumed = "Consumed" + class TaskSchema(BaseModel): focusing_task: str = Field(...) @@ -28,23 +31,28 @@ class TaskSchema(BaseModel): score: float | None = None event: dict | None = None + class UserSchema(BaseModel): email: str = Field(...) password: str = Field(...) nick_name: str = Field(...) + class UserLoginSchema(BaseModel): email: str = Field(...) password: str = Field(...) - + + class IpfsUrlSchema(BaseModel): url: str = Field(...) miner_hotkey: str = Field(...) 
+ class TimeSlot(BaseModel): start: str end: str + class FocusTask(BaseModel): id: str name: str @@ -59,11 +67,13 @@ class FocusTask(BaseModel): totalDuration: str category: Optional[str] = None + class Metadata(BaseModel): date: str day: str lastUpdated: datetime + class DailySchedule(BaseModel): metadata: Metadata tasks: List[FocusTask] @@ -74,20 +84,28 @@ class Link(BaseModel): url: str = Field(..., description="URL of the website") name: str = Field(..., description="Name of the website") + class Step(BaseModel): title: str = Field(..., description="Title of the step") content: List[str] = Field(..., description="Content of the step in paragraphs") links: Optional[List[Link]] = Field(None, description="Relevant links for the step") + class KeyPoint(BaseModel): title: str = Field(..., description="Title of the key point") details: List[str] = Field(..., description="Details of the key point") - links: Optional[List[Link]] = Field(None, description="Relevant links for the key point") + links: Optional[List[Link]] = Field( + None, description="Relevant links for the key point" + ) + class Analysis(BaseModel): summary: str = Field(..., description="Summary of the analysis") points: List[str] = Field(..., description="Key points or recommendations") - links: Optional[List[Link]] = Field(None, description="Relevant links for the analysis") + links: Optional[List[Link]] = Field( + None, description="Relevant links for the analysis" + ) + class TextAnalysisReport(BaseModel): title: str = Field(..., description="Title of the report") @@ -96,8 +114,14 @@ class TextAnalysisReport(BaseModel): keypoints: List[KeyPoint] = Field(..., description="Key points or findings") analysis: Analysis = Field(..., description="Overall analysis or conclusion") metadata: List[str] = Field(..., description="Additional metadata about the report") - timestamp: str = Field(..., description="Timestamp of the report generation (ISO 8601 date string YYYY-MM-DDTHH:MM:SS-UTC)") - links: Optional[List[Link]] = Field(None, description="General links for the entire report") + timestamp: str = Field( + ..., + description="Timestamp of the report generation (ISO 8601 date string YYYY-MM-DDTHH:MM:SS-UTC)", + ) + links: Optional[List[Link]] = Field( + None, description="General links for the entire report" + ) + class FocusTask(BaseModel): id: str @@ -112,4 +136,3 @@ class FocusTask(BaseModel): isCompleted: bool totalDuration: str category: Optional[str] = None - \ No newline at end of file diff --git a/validator-api/validator_api/dataset_upload.py b/validator-api/validator_api/dataset_upload.py index 6754141e..2bafb2b1 100644 --- a/validator-api/validator_api/dataset_upload.py +++ b/validator-api/validator_api/dataset_upload.py @@ -27,10 +27,13 @@ def get_data_path(batch_ulid_str: str) -> str: def get_random_batch_size(batch_size: int) -> int: - return random.choice([ - batch_size // 2, - batch_size, - ]) + return random.choice( + [ + batch_size // 2, + batch_size, + ] + ) + def create_repo(name: str) -> None: try: @@ -38,12 +41,13 @@ def create_repo(name: str) -> None: repo_id=name, repo_type=config.REPO_TYPE, exist_ok=True, - token=config.HF_TOKEN + token=config.HF_TOKEN, ) print("Successfully created/verified repository") except Exception as e: print(f"Error creating repository: {e}") + class DatasetUploader: def __init__(self): self.current_batch = [] @@ -51,39 +55,52 @@ def __init__(self): self.min_batch_size = 32 def add_videos( - self, metadata: List[VideoMetadata], video_ids: List[str], - description_relevance_scores: 
List[float], query_relevance_scores: List[float], + self, + metadata: List[VideoMetadata], + video_ids: List[str], + description_relevance_scores: List[float], + query_relevance_scores: List[float], query: str, ) -> None: curr_time = datetime.now() - self.current_batch.extend([ - { - "video_id": vid_uuid, - "youtube_id": video.video_id, - "description": video.description, - "views": video.views, - "start_time": video.start_time, - "end_time": video.end_time, - "video_embed": video.video_emb, - "audio_embed": video.audio_emb, - "description_embed": video.description_emb, - "description_relevance_score": desc_score, - "query_relevance_score": query_score, - "query": query, - "submitted_at": int(curr_time.timestamp()), - } - for vid_uuid, video, desc_score, query_score - in zip(video_ids, metadata, description_relevance_scores, query_relevance_scores) - ]) - print(f"Added {len(metadata)} videos to batch, now have {len(self.current_batch)}") + self.current_batch.extend( + [ + { + "video_id": vid_uuid, + "youtube_id": video.video_id, + "description": video.description, + "views": video.views, + "start_time": video.start_time, + "end_time": video.end_time, + "video_embed": video.video_emb, + "audio_embed": video.audio_emb, + "description_embed": video.description_emb, + "description_relevance_score": desc_score, + "query_relevance_score": query_score, + "query": query, + "submitted_at": int(curr_time.timestamp()), + } + for vid_uuid, video, desc_score, query_score in zip( + video_ids, + metadata, + description_relevance_scores, + query_relevance_scores, + ) + ] + ) + print( + f"Added {len(metadata)} videos to batch, now have {len(self.current_batch)}" + ) if len(self.current_batch) >= self.desired_batch_size: self.submit() def submit(self) -> None: if len(self.current_batch) < self.min_batch_size: - print(f"Need at least {self.min_batch_size} videos to submit, but have {len(self.current_batch)}") + print( + f"Need at least {self.min_batch_size} videos to submit, but have {len(self.current_batch)}" + ) return - data = self.current_batch[:self.desired_batch_size] + data = self.current_batch[: self.desired_batch_size] print(f"Uploading batch of {len(self.current_batch)} videos") with BytesIO() as f: dataset = Dataset.from_list(data) @@ -99,9 +116,10 @@ def submit(self) -> None: print(f"Uploaded {num_bytes} bytes to Hugging Face") except Exception as e: print(f"Error uploading to Hugging Face: {e}") - self.current_batch = self.current_batch[self.desired_batch_size:] + self.current_batch = self.current_batch[self.desired_batch_size :] self.desired_batch_size = get_random_batch_size(config.UPLOAD_BATCH_SIZE) + class AudioDatasetUploader: def __init__(self): self.current_batch = [] @@ -117,46 +135,67 @@ def convert_audio_to_wav(self, audio_bytes: str) -> bytes: return temp_audiofile.read() def add_audios( - self, metadata: List[AudioMetadata], audio_ids: List[str], - inverse_der: float, audio_length_score: float, - audio_quality_total_score: float, audio_query_score: float, - query: str, total_score: float + self, + metadata: List[AudioMetadata], + audio_ids: List[str], + inverse_der: float, + audio_length_score: float, + audio_quality_total_score: float, + audio_query_score: float, + query: str, + total_score: float, ) -> None: curr_time = datetime.now() - audio_files = [self.convert_audio_to_wav(audio.audio_bytes) for audio in metadata] - - self.current_batch.extend([ - { - "audio_id": audio_uuid, - "youtube_id": audio.video_id, - # "audio_bytes": audio.audio_bytes, - "audio": {"path": audio_file, 
"array": sf.read(BytesIO(base64.b64decode(audio.audio_bytes)))[0], "sampling_rate": 16000}, - "start_time": audio.start_time, - "end_time": audio.end_time, - "audio_embed": audio.audio_emb, - "diar_timestamps_start": audio.diar_timestamps_start, - "diar_timestamps_end": audio.diar_timestamps_end, - "diar_speakers": audio.diar_speakers, - "inverse_der": inverse_der, - "audio_length_score": audio_length_score, - "audio_quality_score": audio_quality_total_score, - "query_relevance_score": audio_query_score, - "total_score": total_score, - "query": query, - "submitted_at": int(curr_time.timestamp()), - } - for audio_uuid, audio_file, audio in zip(audio_ids, audio_files, metadata) - ]) - print(f"Added {len(metadata)} audios to batch, now have {len(self.current_batch)}") + audio_files = [ + self.convert_audio_to_wav(audio.audio_bytes) for audio in metadata + ] + + self.current_batch.extend( + [ + { + "audio_id": audio_uuid, + "youtube_id": audio.video_id, + # "audio_bytes": audio.audio_bytes, + "audio": { + "path": audio_file, + "array": sf.read(BytesIO(base64.b64decode(audio.audio_bytes)))[ + 0 + ], + "sampling_rate": 16000, + }, + "start_time": audio.start_time, + "end_time": audio.end_time, + "audio_embed": audio.audio_emb, + "diar_timestamps_start": audio.diar_timestamps_start, + "diar_timestamps_end": audio.diar_timestamps_end, + "diar_speakers": audio.diar_speakers, + "inverse_der": inverse_der, + "audio_length_score": audio_length_score, + "audio_quality_score": audio_quality_total_score, + "query_relevance_score": audio_query_score, + "total_score": total_score, + "query": query, + "submitted_at": int(curr_time.timestamp()), + } + for audio_uuid, audio_file, audio in zip( + audio_ids, audio_files, metadata + ) + ] + ) + print( + f"Added {len(metadata)} audios to batch, now have {len(self.current_batch)}" + ) if len(self.current_batch) >= self.desired_batch_size: self.submit() def submit(self) -> None: if len(self.current_batch) < self.min_batch_size: - print(f"Need at least {self.min_batch_size} audios to submit, but have {len(self.current_batch)}") + print( + f"Need at least {self.min_batch_size} audios to submit, but have {len(self.current_batch)}" + ) return - data = self.current_batch[:self.desired_batch_size] + data = self.current_batch[: self.desired_batch_size] print(f"Uploading batch of {len(self.current_batch)} audios") with BytesIO() as f: dataset = Dataset.from_list(data) @@ -173,12 +212,10 @@ def submit(self) -> None: print(f"Uploaded {num_bytes} bytes to Hugging Face") except Exception as e: print(f"Error uploading to Hugging Face: {e}") - self.current_batch = self.current_batch[self.desired_batch_size:] + self.current_batch = self.current_batch[self.desired_batch_size :] self.desired_batch_size = get_random_batch_size(config.UPLOAD_AUDIO_BATCH_SIZE) - - audio_dataset_uploader = AudioDatasetUploader() video_dataset_uploader = DatasetUploader() @@ -186,7 +223,7 @@ def submit(self) -> None: if __name__ == "__main__": audio_wav_file = "../example.wav" with open(audio_wav_file, "rb") as f: - audio_bytes = base64.b64encode(f.read()).decode('utf-8') + audio_bytes = base64.b64encode(f.read()).decode("utf-8") for _ in range(100): audio_dataset_uploader.add_audios( metadata=[ @@ -201,7 +238,8 @@ def submit(self) -> None: diar_timestamps_end=[], diar_speakers=[], ) - ] * 10, + ] + * 10, audio_ids=list(range(10)), inverse_der=0.0, audio_length_score=0.0, @@ -213,5 +251,6 @@ def submit(self) -> None: # audio_dataset_uploader.submit() import psutil import os + process = 
psutil.Process(os.getpid()) print(f"Current RAM usage: {process.memory_info().rss / 1024 / 1024:.2f} MB") diff --git a/validator-api/validator_api/imagebind_loader.py b/validator-api/validator_api/imagebind_loader.py index 5ce68175..5d1deded 100644 --- a/validator-api/validator_api/imagebind_loader.py +++ b/validator-api/validator_api/imagebind_loader.py @@ -26,7 +26,7 @@ async def get_imagebind(self) -> ImageBind: raise HTTPException( status_code=503, - detail="ImageBind loading has started. Please try again later." + detail="ImageBind loading has started. Please try again later.", ) def _load_imagebind_blocking(self) -> ImageBind: @@ -39,8 +39,7 @@ async def _load_imagebind_wrapper(self) -> None: # Run the blocking operation in a thread pool loop = asyncio.get_running_loop() self._imagebind = await loop.run_in_executor( - self._thread_pool, - self._load_imagebind_blocking + self._thread_pool, self._load_imagebind_blocking ) finally: self._loading_task = None diff --git a/validator-api/validator_api/limiter.py b/validator-api/validator_api/limiter.py index b36e6306..51c942fa 100644 --- a/validator-api/validator_api/limiter.py +++ b/validator-api/validator_api/limiter.py @@ -3,6 +3,7 @@ from typing import Optional from fastapi import Request + def get_rate_limit_key(request: Request) -> str: """ Extracts a rate limiting key from the request. @@ -13,42 +14,45 @@ def get_rate_limit_key(request: Request) -> str: if user_id: print(f"Rate limiting key: user:{user_id}") return f"user:{user_id}" - + ip = _get_client_ip(request) print(f"Rate limiting key: ip:{ip}") return f"ip:{ip}" + def _extract_user_id(request: Request) -> Optional[str]: """ Extracts user ID from JWT token in Authorization header. Returns None if no valid token found. """ - auth_header = request.headers.get('authorization', '') - if not auth_header.startswith('Bearer '): + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): return None - + try: - token = auth_header.split(' ')[1] + token = auth_header.split(" ")[1] payload = jwt.decode(token, options={"verify_signature": False}) - return payload.get('sub') + return payload.get("sub") except (jwt.InvalidTokenError, IndexError): return None + def _get_client_ip(request: Request) -> str: """ Gets the original client IP from Cloudflare headers, falling back to X-Forwarded-For if CF headers aren't present. 
""" # Try Cloudflare-specific header first - cf_connecting_ip = request.headers.get('cf-connecting-ip') + cf_connecting_ip = request.headers.get("cf-connecting-ip") if cf_connecting_ip: return cf_connecting_ip - + # Fall back to X-Forwarded-For - forwarded_for = request.headers.get('x-forwarded-for') + forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: - return forwarded_for.split(',')[0].strip() - + return forwarded_for.split(",")[0].strip() + return request.client.host -limiter = Limiter(key_func=get_rate_limit_key) \ No newline at end of file + +limiter = Limiter(key_func=get_rate_limit_key) diff --git a/validator-api/validator_api/score.py b/validator-api/validator_api/score.py index 13dffe25..c2d0dd58 100644 --- a/validator-api/validator_api/score.py +++ b/validator-api/validator_api/score.py @@ -16,7 +16,9 @@ PINECONE_INDEX = Pinecone(api_key=config.PINECONE_API_KEY).Index(config.PINECONE_INDEX) -PINECONE_AUDIO_INDEX = Pinecone(api_key=config.PINECONE_API_KEY).Index(config.PINECONE_AUDIO_INDEX) +PINECONE_AUDIO_INDEX = Pinecone(api_key=config.PINECONE_API_KEY).Index( + config.PINECONE_AUDIO_INDEX +) GPU_SEMAPHORE = asyncio.Semaphore(1) DOWNLOAD_SEMAPHORE = asyncio.Semaphore(5) VIDEO_TYPE = "video" @@ -34,55 +36,67 @@ async def query_pinecone(vector: List[float]) -> float: }, ) if len(response["matches"]) > 0: - return 1 - response["matches"][0]["score"] + return 1 - response["matches"][0]["score"] else: print("No pinecone matches, returning 0") return 0 + async def get_pinecone_novelty(metadata: List[VideoMetadata]) -> List[float]: """ Take the top match from the Pinecone index. """ - novelty_scores = await asyncio.gather(*[ - query_pinecone( - vector=mdata.video_emb - ) - for mdata in metadata - ]) + novelty_scores = await asyncio.gather( + *[query_pinecone(vector=mdata.video_emb) for mdata in metadata] + ) return novelty_scores + def compute_novelty_score_among_batch(emb: Embeddings) -> List[float]: video_tensor = emb.video num_videos = video_tensor.shape[0] novelty_scores = [] for i in range(num_videos - 1): - similarity_score = F.cosine_similarity(video_tensor[[i]], video_tensor[i + 1:]).max() + similarity_score = F.cosine_similarity( + video_tensor[[i]], video_tensor[i + 1 :] + ).max() novelty_scores.append(1 - similarity_score.item()) novelty_scores.append(1.0) # last video is 100% novel return novelty_scores + async def async_zero() -> None: return 0 + async def compute_novelty_score(embeddings: Embeddings) -> Tuple[float, List[bool]]: local_novelty_scores = compute_novelty_score_among_batch(embeddings) - global_novelty_scores = await asyncio.gather(*[ - async_zero() if local_score < DIFFERENCE_THRESHOLD else # don't even query Pinecone if it's already too similar - query_pinecone(vector=embedding.tolist()) - for embedding, local_score in zip(embeddings.video, local_novelty_scores) - ]) + global_novelty_scores = await asyncio.gather( + *[ + async_zero() + if local_score < DIFFERENCE_THRESHOLD + # don't even query Pinecone if it's already too similar + else query_pinecone(vector=embedding.tolist()) + for embedding, local_score in zip(embeddings.video, local_novelty_scores) + ] + ) true_novelty_scores = [ - min(local_score, global_score) for local_score, global_score - in zip(local_novelty_scores, global_novelty_scores) + min(local_score, global_score) + for local_score, global_score in zip( + local_novelty_scores, global_novelty_scores + ) ] is_too_similar = [score < DIFFERENCE_THRESHOLD for score in true_novelty_scores] - novelty_score = sum([ - score for 
score, is_too_similar - in zip(true_novelty_scores, is_too_similar) - if not is_too_similar - ]) + novelty_score = sum( + [ + score + for score, is_too_similar in zip(true_novelty_scores, is_too_similar) + if not is_too_similar + ] + ) return novelty_score, is_too_similar + def upload_to_pinecone(embeddings: Embeddings, metadata: List[VideoMetadata]) -> None: video_ids = [str(uuid.uuid4()) for _ in range(len(metadata))] try: @@ -94,17 +108,21 @@ def upload_to_pinecone(embeddings: Embeddings, metadata: List[VideoMetadata]) -> "metadata": { "youtube_id": video.video_id, "modality_type": VIDEO_TYPE, - } + }, } - for video_uuid, video, embedding_vid - in zip(video_ids, metadata, embeddings.video) + for video_uuid, video, embedding_vid in zip( + video_ids, metadata, embeddings.video + ) ], ) except Exception as e: print(f"Failed to upload to Pinecone: {e}") return video_ids -def upload_to_pinecone_audio(embeddings: Embeddings, metadata: List[AudioMetadata]) -> None: + +def upload_to_pinecone_audio( + embeddings: Embeddings, metadata: List[AudioMetadata] +) -> None: audio_ids = [str(uuid.uuid4()) for _ in range(len(metadata))] try: PINECONE_AUDIO_INDEX.upsert( @@ -114,21 +132,23 @@ def upload_to_pinecone_audio(embeddings: Embeddings, metadata: List[AudioMetadat "values": embedding_aud.tolist(), "metadata": { "youtube_id": audio.video_id, - } + }, } - for audio_uuid, audio, embedding_aud - in zip(audio_ids, metadata, embeddings.audio) + for audio_uuid, audio, embedding_aud in zip( + audio_ids, metadata, embeddings.audio + ) ], ) except Exception as e: print(f"Failed to upload to Pinecone: {e}") return audio_ids + async def upload_video_metadata( - metadata: List[VideoMetadata], - description_relevance_scores: List[float], - query_relevance_scores: List[float], - query: str, + metadata: List[VideoMetadata], + description_relevance_scores: List[float], + query_relevance_scores: List[float], + query: str, ) -> None: # generate embeddings from our metadata embeddings = Embeddings( @@ -149,6 +169,7 @@ async def upload_video_metadata( ) return video_ids + class AudioMetadataUpload(BaseModel): metadata: List[AudioMetadata] inverse_der: float @@ -159,6 +180,7 @@ class AudioMetadataUpload(BaseModel): total_score: Optional[float] = None miner_hotkey: Optional[str] = None + def _add_audios(upload_data: AudioMetadataUpload, audio_ids: List[str]): audio_dataset_uploader.add_audios( upload_data.metadata, @@ -168,15 +190,19 @@ def _add_audios(upload_data: AudioMetadataUpload, audio_ids: List[str]): upload_data.audio_quality_total_score, upload_data.audio_query_score, upload_data.topic_query, - upload_data.total_score + upload_data.total_score, ) -async def upload_audio_metadata(request: Request) -> Tuple[List[str], AudioMetadataUpload]: + +async def upload_audio_metadata( + request: Request, +) -> Tuple[List[str], AudioMetadataUpload]: if audio_dataset_uploader.currently_uploading_at: - print(f"Currently uploading since {audio_dataset_uploader.currently_uploading_at}, waiting for {(datetime.now() - audio_dataset_uploader.currently_uploading_at).total_seconds()} seconds") + print( + f"Currently uploading since {audio_dataset_uploader.currently_uploading_at}, waiting for {(datetime.now() - audio_dataset_uploader.currently_uploading_at).total_seconds()} seconds" + ) raise HTTPException( - status_code=500, - detail="Memory usage is too high. Please try again later." + status_code=500, detail="Memory usage is too high. Please try again later." 
) try: @@ -189,7 +215,9 @@ async def upload_audio_metadata(request: Request) -> Tuple[List[str], AudioMetad audio=torch.stack([torch.tensor(v.audio_emb) for v in metadata]), description=None, ) - audio_ids = await asyncio.to_thread(upload_to_pinecone_audio, embeddings, metadata) + audio_ids = await asyncio.to_thread( + upload_to_pinecone_audio, embeddings, metadata + ) await asyncio.to_thread(_add_audios, upload_data, audio_ids) finally: audio_dataset_uploader.currently_uploading_at = None diff --git a/validator-api/validator_api/scoring/deepseek_chat.py b/validator-api/validator_api/scoring/deepseek_chat.py index eb022956..06c700e2 100644 --- a/validator-api/validator_api/scoring/deepseek_chat.py +++ b/validator-api/validator_api/scoring/deepseek_chat.py @@ -1,11 +1,11 @@ async def query_openai( messages: Iterable[ChatCompletionMessageParam], output_model: Optional[Type[BaseModel]] = None, - retries: int = 3 + retries: int = 3, ) -> Union[BaseModel, dict]: """ Query the OpenAI o1 model with retries. - + Args: messages: An iterable of chat completion messages following the OpenAI format. Each message should have 'role' and 'content' fields. @@ -31,19 +31,22 @@ async def query_openai( raise Exception("Empty response from API") parsed_data = json.loads(response.choices[0].message.content) - + if output_model is not None: return output_model.model_validate(parsed_data) return parsed_data - + except Exception as e: if attempt < retries - 1: - sleep_time = 2 ** attempt - print(f"OpenAI attempt {attempt + 1} failed: {str(e)}. Retrying in {sleep_time} seconds...") + sleep_time = 2**attempt + print( + f"OpenAI attempt {attempt + 1} failed: {str(e)}. Retrying in {sleep_time} seconds..." + ) await asyncio.sleep(sleep_time) continue raise e + async def query_llm( messages: Iterable[ChatCompletionMessageParam], output_model: Optional[Type[BaseModel]] = None, @@ -51,7 +54,7 @@ async def query_llm( ) -> Union[BaseModel, dict]: """ Query LLM models with fallback behavior. Tries DeepSeek first, falls back to OpenAI if DeepSeek fails. - + Args: messages: An iterable of chat completion messages following the OpenAI format. Each message should have 'role' and 'content' fields. @@ -70,6 +73,6 @@ async def query_llm( except Exception as e: if not openai_client: raise e - + print(f"DeepSeek failed, falling back to OpenAI: {str(e)}") - return await query_openai(messages, output_model, retries) \ No newline at end of file + return await query_openai(messages, output_model, retries) diff --git a/validator-api/validator_api/scoring/legitimacy_checks.py b/validator-api/validator_api/scoring/legitimacy_checks.py index 51296130..9ae97829 100644 --- a/validator-api/validator_api/scoring/legitimacy_checks.py +++ b/validator-api/validator_api/scoring/legitimacy_checks.py @@ -7,16 +7,21 @@ from validator_api.database.models.scoring import DetailedVideoDescription from validator_api.scoring.query_llm import query_llm + class LegitimacyCheck(ABC): @abstractmethod - async def passes_check(self, video_id: str, detailed_video_description: Optional[DetailedVideoDescription] = None) -> Tuple[bool, str]: + async def passes_check( + self, + video_id: str, + detailed_video_description: Optional[DetailedVideoDescription] = None, + ) -> Tuple[bool, str]: """ Check if the video passes this legitimacy check. 
- + Args: video_id: The ID of the video to check detailed_video_description: Optional pre-computed video description - + Returns: Tuple[bool, str]: (passed, failure_reason) - passed: True if check passed, False if failed @@ -24,15 +29,24 @@ async def passes_check(self, video_id: str, detailed_video_description: Optional """ pass + class ChatOnlyDetectionModel(BaseModel): rationale: str = Field(description="Detailed rationale for the score") - legitimate: bool = Field(description="False if the user is cheating by talking about completing a task, but not actually completing it, True otherwise") + legitimate: bool = Field( + description="False if the user is cheating by talking about completing a task, but not actually completing it, True otherwise" + ) + class ChatOnlyCheck(LegitimacyCheck): """ Fails if a user is talking about completing a task (e.g. in a notepad or AI chat), but not actually completing it. """ - async def passes_check(self, video_id: str, detailed_video_description: Optional[DetailedVideoDescription] = None) -> Tuple[bool, str]: + + async def passes_check( + self, + video_id: str, + detailed_video_description: Optional[DetailedVideoDescription] = None, + ) -> Tuple[bool, str]: chat_only_check_prompt = """You are an expert in analyzing task performance videos. Your current task is to determine if the user is cheating by talking about completing a task, but not actually completing it. Verify that the video shows actual evidence of task completion, not just chat interactions claiming completion. @@ -58,39 +72,52 @@ async def passes_check(self, video_id: str, detailed_video_description: Optional "legitimate": true/false; False if the user is cheating by talking about completing a task, but not actually completing it, True otherwise } """ -# important: above, we need to provide an example of the output JSON format - + # important: above, we need to provide an example of the output JSON format + # Use provided description if available, otherwise fetch from DB if detailed_video_description is None: async with get_db_context() as db: query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.deleted_at.is_(None), ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: raise ValueError(f"Video not found: {video_id}") - - if video_record.video_details and "detailed_video_description" in video_record.video_details: - detailed_video_description = DetailedVideoDescription.model_validate( - video_record.video_details["detailed_video_description"] + + if ( + video_record.video_details + and "detailed_video_description" in video_record.video_details + ): + detailed_video_description = ( + DetailedVideoDescription.model_validate( + video_record.video_details["detailed_video_description"] + ) ) else: - raise ValueError(f"Detailed video description not found for video: {video_id}") - + raise ValueError( + f"Detailed video description not found for video: {video_id}" + ) + messages = [ {"role": "system", "content": chat_only_check_prompt}, - {"role": "user", "content": f"Please analyze the following annotated transcript and determine if the user is cheating by talking about completing a task, but not actually completing it: {detailed_video_description}"} + { + "role": "user", + "content": f"Please analyze the following annotated transcript and determine if the user is cheating by talking about completing a task, but not actually completing it: 
{detailed_video_description}", + }, ] try: chat_only_detection_data = await query_llm(messages, ChatOnlyDetectionModel) - + print(f"[{video_id}] ChatOnlyCheck result: {chat_only_detection_data}") - - return chat_only_detection_data.legitimate, chat_only_detection_data.rationale + + return ( + chat_only_detection_data.legitimate, + chat_only_detection_data.rationale, + ) except Exception as e: print(f"[{video_id}] ❌ Error during chat-only check: {str(e)}") diff --git a/validator-api/validator_api/scoring/query_llm.py b/validator-api/validator_api/scoring/query_llm.py index 9cda6fc0..0ebc22ad 100644 --- a/validator-api/validator_api/scoring/query_llm.py +++ b/validator-api/validator_api/scoring/query_llm.py @@ -19,6 +19,7 @@ api_key=OPENAI_API_KEY, ) + async def query_llm( messages: Iterable[ChatCompletionMessageParam], output_model: Optional[Type[BaseModel]] = None, @@ -26,7 +27,7 @@ async def query_llm( ) -> Union[BaseModel, dict]: """ Query LLM models with fallback behavior. Tries DeepSeek first, falls back to OpenAI if DeepSeek fails. - + Args: messages: An iterable of chat completion messages following the OpenAI format. Each message should have 'role' and 'content' fields. @@ -46,14 +47,15 @@ async def query_llm( print(f"Chutes API DeepSeek call failed, falling back to OpenAI: {str(e)}") return await query_openai(messages, output_model, retries) + async def query_openai( messages: Iterable[ChatCompletionMessageParam], output_model: Optional[Type[BaseModel]] = None, - retries: int = 3 + retries: int = 3, ) -> Union[BaseModel, dict]: """ Query the OpenAI o1 model with retries. - + Args: messages: An iterable of chat completion messages following the OpenAI format. Each message should have 'role' and 'content' fields. @@ -79,27 +81,30 @@ async def query_openai( raise Exception("Empty response from API") parsed_data = json.loads(response.choices[0].message.content) - + if output_model is not None: return output_model.model_validate(parsed_data) return parsed_data - + except Exception as e: if attempt < retries - 1: - sleep_time = 2 ** attempt - print(f"OpenAI attempt {attempt + 1} failed: {str(e)}. Retrying in {sleep_time} seconds...") + sleep_time = 2**attempt + print( + f"OpenAI attempt {attempt + 1} failed: {str(e)}. Retrying in {sleep_time} seconds..." + ) await asyncio.sleep(sleep_time) continue raise e + async def query_deepseek( messages: Iterable[ChatCompletionMessageParam], output_model: Optional[Type[BaseModel]] = None, - retries: int = 3 + retries: int = 3, ) -> Union[BaseModel, dict]: """ Query the DeepSeek chat model via the Chutes API with streaming (non-streaming appears to be broken). - + This function sends a chat completion request to DeepSeek, processes the streamed response, and optionally validates it against a provided Pydantic model. 
 Your prompt must have the "json" keyword somewhere and an example; reference:
@@ -131,13 +136,13 @@ async def query_deepseek(
         - JSON responses are expected and enforced via the API's response_format parameter
     """
     last_exception = None
-
+
     for attempt in range(retries):
         try:
             async with httpx.AsyncClient(timeout=120.0) as client:
                 headers = {
                     "Authorization": f"Bearer {CHUTES_API_TOKEN}",
-                    "Content-Type": "application/json"
+                    "Content-Type": "application/json",
                 }
                 payload = {
                     "model": "deepseek-ai/DeepSeek-R1",
@@ -145,21 +150,19 @@ async def query_deepseek(
                     "stream": True,
                     "max_tokens": 1000,
                     "temperature": 0.5,
-                    "response_format": {
-                        "type": "json_object"
-                    }
+                    "response_format": {"type": "json_object"},
                 }
-
+
                 async with client.stream(
                     "POST",
                     "https://chutes-deepseek-ai-deepseek-r1.chutes.ai/v1/chat/completions",
                     headers=headers,
                     json=payload,
-                    timeout=120.0
+                    timeout=120.0,
                 ) as response:
                     response.raise_for_status()
                     content = ""
-
+
                     try:
                         async for line in response.aiter_lines():
                             if line.strip():
@@ -168,27 +171,35 @@ async def query_deepseek(
                                     line = line[6:]
                                 if line == "[DONE]":
                                     continue
-
+
                                 try:
                                     chunk = json.loads(line)
-                                    if delta_content := chunk.get("choices", [{}])[0].get("delta", {}).get("content"):
+                                    if (
+                                        delta_content := chunk.get("choices", [{}])[0]
+                                        .get("delta", {})
+                                        .get("content")
+                                    ):
                                         content += delta_content
                                 except json.JSONDecodeError as e:
                                     print(f"Failed to parse chunk: {e}")
                                     continue
                                 except IndexError:
-                                    print("Received malformed response chunk from Chutes API call")
+                                    print(
+                                        "Received malformed response chunk from Chutes API call"
+                                    )
                                     continue
-
+
                         if not content:  # Check if we got any content
                             raise ValueError("No content received from API")

-                        content = content.replace('```json', '').replace('```', '').strip()
-
-                        if '</think>' in content:
+                        content = (
+                            content.replace("```json", "").replace("```", "").strip()
+                        )
+
+                        if "</think>" in content:
                             # get the content after the </think> tag
-                            content = content.split('</think>')[-1].strip()
-
+                            content = content.split("</think>")[-1].strip()
+
                         # Parse JSON and optionally validate against output model
                         try:
                             parsed_data = json.loads(content)
@@ -197,21 +208,25 @@ async def query_deepseek(
                                 return output_model.model_validate(parsed_data)
                             return parsed_data
                         except json.JSONDecodeError:
-                            raise ValueError(f"Failed to parse response as JSON: {content}")
-
+                            raise ValueError(
+                                f"Failed to parse response as JSON: {content}"
+                            )
+
                     except httpx.ReadTimeout:
                         raise TimeoutError("Request timed out while reading the stream")
                     finally:
                         await response.aclose()
-
+
         except Exception as e:
            last_exception = e
            if attempt < retries - 1:
-                sleep_time = 2 ** attempt
-                print(f"Chutes API call attempt {attempt + 1} failed with error: {str(e)}. Retrying in {sleep_time} seconds...")
+                sleep_time = 2**attempt
+                print(
+                    f"Chutes API call attempt {attempt + 1} failed with error: {str(e)}. Retrying in {sleep_time} seconds..."
+ ) await asyncio.sleep(sleep_time) continue - + # If we get here, all retries failed print(f"All {retries} attempts failed in deepseek_model") raise last_exception diff --git a/validator-api/validator_api/scoring/scoring_service.py b/validator-api/validator_api/scoring/scoring_service.py index bd58679f..8f5e4387 100644 --- a/validator-api/validator_api/scoring/scoring_service.py +++ b/validator-api/validator_api/scoring/scoring_service.py @@ -1,18 +1,18 @@ """ Description of scoring system: -- Phase 0: generate detailed annotation for video; +- Phase 0: generate detailed annotation for video; - Phase 1: spam detection + rejection (can order from least to greatest cost) - - Working: + - Working: - length of video (too long or short) - - uniqueness detection (video embedding vector similarity) - - chat-only detection (openai o1 + text description) - - Not working: - - YouTube/movie video-watching detection (gemini + first and last video chunks) - - exploit/screen recording video watching detection (gemini + first and last video chunks) - - prompt can be found in old subnet commits - - automation detection (??) (I don't think this is reliably working yet) + - uniqueness detection (video embedding vector similarity) + - chat-only detection (openai o1 + text description) + - Not working: + - YouTube/movie video-watching detection (gemini + first and last video chunks) + - exploit/screen recording video watching detection (gemini + first and last video chunks) + - prompt can be found in old subnet commits + - automation detection (??) (I don't think this is reliably working yet) - Phase 2: actual scoring - - can be gemini evaluation on the whole video, but I think it's probably more cost-efficient to use a reasoning model with the task descriptions + - can be gemini evaluation on the whole video, but I think it's probably more cost-efficient to use a reasoning model with the task descriptions """ import asyncio @@ -27,34 +27,43 @@ from pydantic import BaseModel, ValidationError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select -from validator_api.config import (GOOGLE_CLOUD_BUCKET_NAME, GOOGLE_LOCATION, - GOOGLE_PROJECT_ID, OPENAI_API_KEY, - PINECONE_API_KEY) +from validator_api.config import ( + GOOGLE_CLOUD_BUCKET_NAME, + GOOGLE_LOCATION, + GOOGLE_PROJECT_ID, + OPENAI_API_KEY, + PINECONE_API_KEY, +) from validator_api.database import get_db_context from validator_api.database.models.boosted_task import BoostedTask from validator_api.database.models.focus_video_record import ( - FocusVideoInternal, FocusVideoRecord) -from validator_api.database.models.scoring import (CompletionScore, - CompletionScoreWithoutRange, - DetailedVideoDescription, - FocusVideoEmbeddings, - LegitimacyCheckError, - VideoScore, - VideoTooLongError, - VideoTooShortError, - VideoUniquenessError) + FocusVideoInternal, + FocusVideoRecord, +) +from validator_api.database.models.scoring import ( + CompletionScore, + CompletionScoreWithoutRange, + DetailedVideoDescription, + FocusVideoEmbeddings, + LegitimacyCheckError, + VideoScore, + VideoTooLongError, + VideoTooShortError, + VideoUniquenessError, +) from validator_api.database.models.task import TaskRecordPG from validator_api.scoring import focus_scoring_prompts from validator_api.scoring.legitimacy_checks import ChatOnlyCheck from validator_api.scoring.query_llm import query_llm from validator_api.utils import run_async, run_with_retries from vertexai.generative_models import Part -from vertexai.preview.generative_models import 
(GenerationConfig, - GenerativeModel, - HarmBlockThreshold, - HarmCategory) -from vertexai.vision_models import (MultiModalEmbeddingModel, Video, - VideoSegmentConfig) +from vertexai.preview.generative_models import ( + GenerationConfig, + GenerativeModel, + HarmBlockThreshold, + HarmCategory, +) +from vertexai.vision_models import MultiModalEmbeddingModel, Video, VideoSegmentConfig TWO_MINUTES = 120 # in seconds NINETY_MINUTES = 5400 # in seconds @@ -62,19 +71,21 @@ FOCUS_VIDEO_MAX_SCORE = 1.0 MIN_VIDEO_UNIQUENESS_SCORE = 0.02 -async def get_video_metadata(db: AsyncSession, video_id: str) -> Optional[FocusVideoInternal]: - query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id - ) + +async def get_video_metadata( + db: AsyncSession, video_id: str +) -> Optional[FocusVideoInternal]: + query = select(FocusVideoRecord).filter(FocusVideoRecord.video_id == video_id) result = await db.execute(query) video = result.scalar_one_or_none() - + if video and video.deleted_at is not None: print(f"Video {video_id} has been deleted") return None - + return video + async def _get_details_if_boosted(video_id: str) -> Optional[BoostedTask]: """ Retrieves the details of a boosted task from the database for a given video. @@ -99,7 +110,7 @@ async def _get_details_if_boosted(video_id: str) -> Optional[BoostedTask]: ) result = await db.execute(query) task = result.scalar_one_or_none() - + if task and task.boosted_id: query = select(BoostedTask).filter( BoostedTask.id == task.boosted_id, @@ -108,6 +119,7 @@ async def _get_details_if_boosted(video_id: str) -> Optional[BoostedTask]: return result.scalar_one_or_none() return None + async def get_video_duration_seconds(video_id: str) -> int: async with get_db_context() as db: video_metadata = await get_video_metadata(db, video_id) @@ -122,13 +134,18 @@ async def get_video_duration_seconds(video_id: str) -> int: return video_duration_seconds + def get_s3_path(video_id: str) -> str: return f"clips/{video_id}.webm" + def get_gcs_uri(video_id: str) -> str: return f"gs://{GOOGLE_CLOUD_BUCKET_NAME}/{get_s3_path(video_id)}" -async def _make_gemini_request(system_prompt: str, user_prompt: str, video_id: str, OutputClassSchema: BaseModel) -> GenerativeModel: + +async def _make_gemini_request( + system_prompt: str, user_prompt: str, video_id: str, OutputClassSchema: BaseModel +) -> GenerativeModel: """ Makes a request to the Gemini model with specified prompts and video content. @@ -149,7 +166,7 @@ async def _make_gemini_request(system_prompt: str, user_prompt: str, video_id: s HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, } - + model = GenerativeModel( model_name, system_instruction=system_prompt.strip(), @@ -169,10 +186,13 @@ async def _make_gemini_request(system_prompt: str, user_prompt: str, video_id: s response = await model.generate_content_async(parts) return OutputClassSchema(**json.loads(response.text)) -async def _make_gemini_request_with_retries(system_prompt: str, user_prompt: str, video_id: str, OutputClassSchema: BaseModel) -> str: + +async def _make_gemini_request_with_retries( + system_prompt: str, user_prompt: str, video_id: str, OutputClassSchema: BaseModel +) -> str: """ Makes a request to Gemini with automatic retries on failure. - + Handles JSON parsing errors, validation errors, and general exceptions with 3 retry attempts. Delays between retries. 
@@ -192,47 +212,70 @@ async def _make_gemini_request_with_retries(system_prompt: str, user_prompt: str for retry_idx in range(num_retries): try: start = time.time() - output = await _make_gemini_request(system_prompt, user_prompt, video_id, OutputClassSchema) - print(f"Got gemini output in {time.time() - start} seconds for {OutputClassSchema.__name__}") + output = await _make_gemini_request( + system_prompt, user_prompt, video_id, OutputClassSchema + ) + print( + f"Got gemini output in {time.time() - start} seconds for {OutputClassSchema.__name__}" + ) return output except json.JSONDecodeError as e: - print(f"Error parsing JSON from Gemini response for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})") + print( + f"Error parsing JSON from Gemini response for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})" + ) await asyncio.sleep(1) except ValidationError as e: - print(f"Error turning parsed JSON into Pydantic object for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})") + print( + f"Error turning parsed JSON into Pydantic object for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})" + ) await asyncio.sleep(1) except Exception as e: - print(f"Error making Gemini request for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})") + print( + f"Error making Gemini request for {OutputClassSchema.__name__}, trying again: {e} ({retry_idx + 1}/{num_retries})" + ) await asyncio.sleep(6) - raise Exception(f"Failed to turn Gemini response into JSON and then into Pydantic object for {OutputClassSchema.__name__} after {num_retries} attempts") + raise Exception( + f"Failed to turn Gemini response into JSON and then into Pydantic object for {OutputClassSchema.__name__} after {num_retries} attempts" + ) -async def get_detailed_video_description(video_id: str, task_overview: str, recompute: bool = False) -> DetailedVideoDescription: + +async def get_detailed_video_description( + video_id: str, task_overview: str, recompute: bool = False +) -> DetailedVideoDescription: if not recompute: - async with get_db_context() as db: # get already computed description from db if it exists + async with ( + get_db_context() as db + ): # get already computed description from db if it exists query = select(FocusVideoRecord).filter( FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.deleted_at.is_(None), ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: raise ValueError(f"Video not found: {video_id}") - - if video_record.video_details and "detailed_video_description" in video_record.video_details: + + if ( + video_record.video_details + and "detailed_video_description" in video_record.video_details + ): return DetailedVideoDescription.model_validate( video_record.video_details["detailed_video_description"] ) - + description = await _make_gemini_request_with_retries( system_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_SYSTEM_PROMPT, - user_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_USER_PROMPT.format(task_overview=task_overview), + user_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_USER_PROMPT.format( + task_overview=task_overview + ), video_id=video_id, OutputClassSchema=DetailedVideoDescription, ) return description + # async def get_task_score_from_gemini(self, task_overview: str) -> TaskScoreBreakdown: # return await _make_gemini_request_with_retries( # 
system_prompt=focus_scoring_prompts.TASK_SCORE_SYSTEM_PROMPT, @@ -241,6 +284,7 @@ async def get_detailed_video_description(video_id: str, task_overview: str, reco # OutputClassSchema=TaskScoreBreakdown, # ) + async def _get_completion_score_breakdown( task_overview: str, detailed_video_description: Optional[DetailedVideoDescription] = None, @@ -262,32 +306,40 @@ async def _get_completion_score_breakdown( """ messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt.format( - task_overview=task_overview, - applications_used=detailed_video_description.applications_used, - completion_sequence_steps=detailed_video_description.completion_sequence_steps, - )} + { + "role": "user", + "content": user_prompt.format( + task_overview=task_overview, + applications_used=detailed_video_description.applications_used, + completion_sequence_steps=detailed_video_description.completion_sequence_steps, + ), + }, ] try: completion_score_without_range = await query_llm( messages=messages, # OpenAI API doesn't like it when there's a range in the Pydantic model - output_model=CompletionScoreWithoutRange + output_model=CompletionScoreWithoutRange, ) return CompletionScore( rationale=completion_score_without_range.rationale, - completion_score=max(0.0, min(1.0, completion_score_without_range.completion_score)) + completion_score=max( + 0.0, min(1.0, completion_score_without_range.completion_score) + ), ) except Exception as e: print(f"Error getting completion score: {str(e)}") raise -async def get_video_embedding(video_id: str, video_duration_seconds: int) -> List[float]: + +async def get_video_embedding( + video_id: str, video_duration_seconds: int +) -> List[float]: """ Generates an embedding vector for a video segment using Google's multimodal embedding model. - + Takes a random 120-second segment from the video if the video is longer than 2 minutes. Args: @@ -297,6 +349,7 @@ async def get_video_embedding(video_id: str, video_duration_seconds: int) -> Lis Returns: List[float]: The embedding vector for the video segment """ + async def _internal_async(): model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding") start_offset_sec = random.randint(0, max(0, video_duration_seconds - 120)) @@ -307,16 +360,18 @@ async def _internal_async(): video_segment_config=VideoSegmentConfig( start_offset_sec=start_offset_sec, end_offset_sec=end_offset_sec, - interval_sec=end_offset_sec - start_offset_sec - ) + interval_sec=end_offset_sec - start_offset_sec, + ), ) return embeddings.video_embeddings[0].embedding + return await run_with_retries(_internal_async) + async def query_pinecone(pinecone_index: Pinecone, vector: List[float]) -> float: """ Queries a Pinecone index with a vector to find the most similar existing vector. - + Returns a uniqueness score based on the inverse of the similarity score (1 - similarity). Ensures the returned score is between 0 and 1. 
@@ -327,6 +382,7 @@ async def query_pinecone(pinecone_index: Pinecone, vector: List[float]) -> float Returns: float: The uniqueness score (1 - similarity score) """ + async def _internal_async(): response = await run_async( pinecone_index.query, @@ -346,23 +402,37 @@ async def _internal_async(): similarity_score = 0 similarity_score = max(0.0, min(similarity_score, 1.0)) return 1.0 - similarity_score + return await run_with_retries(_internal_async) + class FocusScoringService: def __init__(self): vertexai.init(project=GOOGLE_PROJECT_ID, location=GOOGLE_LOCATION) - self.task_overview_index = Pinecone(api_key=PINECONE_API_KEY).Index("focus-task-overview-index") - self.video_description_index = Pinecone(api_key=PINECONE_API_KEY).Index("focus-video-description-index") - self.completion_video_index = Pinecone(api_key=PINECONE_API_KEY).Index("focus-completion-video-index") + self.task_overview_index = Pinecone(api_key=PINECONE_API_KEY).Index( + "focus-task-overview-index" + ) + self.video_description_index = Pinecone(api_key=PINECONE_API_KEY).Index( + "focus-video-description-index" + ) + self.completion_video_index = Pinecone(api_key=PINECONE_API_KEY).Index( + "focus-completion-video-index" + ) self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY) self.legitimacy_checks = [ChatOnlyCheck()] # Pinecone is used for similarity search and scoring (uniqueness check) - async def _get_task_uniqueness_score(self, task_overview_embedding: List[float]) -> float: + async def _get_task_uniqueness_score( + self, task_overview_embedding: List[float] + ) -> float: return await query_pinecone(self.task_overview_index, task_overview_embedding) - async def get_description_uniqueness_score(self, detailed_video_description_embedding: List[float]) -> float: - return await query_pinecone(self.video_description_index, detailed_video_description_embedding) + async def get_description_uniqueness_score( + self, detailed_video_description_embedding: List[float] + ) -> float: + return await query_pinecone( + self.video_description_index, detailed_video_description_embedding + ) async def get_video_uniqueness_score(self, video_embedding: List[float]) -> float: return await query_pinecone(self.completion_video_index, video_embedding) @@ -371,7 +441,7 @@ async def get_video_uniqueness_score(self, video_embedding: List[float]) -> floa async def get_text_embedding(self, text: str) -> Optional[List[float]]: """ Generates an embedding vector for text using OpenAI's embedding model. - + Implements timeout and retry logic for reliability. 
Args: @@ -380,11 +450,14 @@ async def get_text_embedding(self, text: str) -> Optional[List[float]]: Returns: Optional[List[float]]: The embedding vector, or None if the request fails """ + async def _internal_async(): - response = await asyncio.wait_for(self.openai_client.embeddings.create( - input=text, - model="text-embedding-3-large" - ), timeout=10) + response = await asyncio.wait_for( + self.openai_client.embeddings.create( + input=text, model="text-embedding-3-large" + ), + timeout=10, + ) return response.data[0].embedding try: @@ -393,13 +466,17 @@ async def _internal_async(): print(f"Error getting text embedding: {e}") return None - async def embed_and_get_task_uniqueness_score(self, task_overview: str) -> Tuple[Optional[List[float]], Optional[float]]: + async def embed_and_get_task_uniqueness_score( + self, task_overview: str + ) -> Tuple[Optional[List[float]], Optional[float]]: embedding = await self.get_text_embedding(task_overview) if embedding is None: return None, None return embedding, await self._get_task_uniqueness_score(embedding) - async def embed_and_get_video_uniqueness_score(self, video_id: str, video_duration_seconds: int): + async def embed_and_get_video_uniqueness_score( + self, video_id: str, video_duration_seconds: int + ): try: embedding = await get_video_embedding(video_id, video_duration_seconds) return embedding, await self.get_video_uniqueness_score(embedding) @@ -407,24 +484,36 @@ async def embed_and_get_video_uniqueness_score(self, video_id: str, video_durati print(f"Failed to create video embedding for {video_id}: {str(e)}") return None, 0.1 # Assumes unique if we can't check - async def get_detailed_video_description_embedding_score(self, video_id, task_overview): - detailed_video_description = await get_detailed_video_description(video_id, task_overview) - embedding = await self.get_text_embedding(detailed_video_description.model_dump_json()) + async def get_detailed_video_description_embedding_score( + self, video_id, task_overview + ): + detailed_video_description = await get_detailed_video_description( + video_id, task_overview + ) + embedding = await self.get_text_embedding( + detailed_video_description.model_dump_json() + ) if embedding is None: return detailed_video_description, None, None - return detailed_video_description, embedding, await self.get_description_uniqueness_score(embedding) + return ( + detailed_video_description, + embedding, + await self.get_description_uniqueness_score(embedding), + ) - async def score_video(self, video_id: str, focusing_task: str, focusing_description: str): + async def score_video( + self, video_id: str, focusing_task: str, focusing_description: str + ): """ Generates a comprehensive score for a video submission based on multiple factors. - + The scoring process includes: 1. Checking video duration constraints 2. Computing task, description, and video uniqueness scores 3. Running legitimacy checks 4. Generating a completion score 5. Applying any boost multipliers - + Exceptions raised here make the video rejected. 
@@ -450,48 +539,65 @@ async def score_video(self, video_id: str, focusing_task: str, focusing_descript
         video_duration_seconds = await get_video_duration_seconds(video_id)
 
         if video_duration_seconds < TWO_MINUTES:
-            raise VideoTooShortError(f"Video duration is too short: {video_duration_seconds} seconds")
+            raise VideoTooShortError(
+                f"Video duration is too short: {video_duration_seconds} seconds"
+            )
 
         if video_duration_seconds > NINETY_MINUTES:
-            raise VideoTooLongError(f"Video duration is too long: {video_duration_seconds} seconds")
+            raise VideoTooLongError(
+                f"Video duration is too long: {video_duration_seconds} seconds"
+            )
 
         task_overview = f"# {focusing_task}\n\n{focusing_description}"
 
         (
             (task_overview_embedding, task_uniqueness_score),
             # task_score_breakdown,
-            (video_description, video_description_embedding, video_description_uniqueness_score),
+            (
+                video_description,
+                video_description_embedding,
+                video_description_uniqueness_score,
+            ),
             (video_embedding, video_uniqueness_score),
         ) = await asyncio.gather(
-            self.embed_and_get_task_uniqueness_score(task_overview),  # uses openai to get embedding
+            self.embed_and_get_task_uniqueness_score(
+                task_overview
+            ),  # uses openai to get embedding
             # self.get_task_score_from_gemini(task_overview),  # uses gemini to score task
-            self.get_detailed_video_description_embedding_score(video_id, task_overview),  # uses gemini to get detailed description
+            self.get_detailed_video_description_embedding_score(
+                video_id, task_overview
+            ),  # uses gemini to get detailed description
             self.embed_and_get_video_uniqueness_score(video_id, video_duration_seconds),
         )
-        
+
         if video_uniqueness_score < MIN_VIDEO_UNIQUENESS_SCORE:
             raise VideoUniquenessError("Video uniqueness score is too low.")
-        
+
         if self.legitimacy_checks:
             check_results = await asyncio.gather(
-                *(check.passes_check(video_id, video_description) for check in self.legitimacy_checks)
+                *(
+                    check.passes_check(video_id, video_description)
+                    for check in self.legitimacy_checks
+                )
             )
-        
+
             for passed, failure_reason in check_results:
                 if not passed:
-                    raise LegitimacyCheckError(f"Video failed legitimacy check: {failure_reason}. If you think this is a mistake, please contact us through the app or the OMEGA Focus Discord server.")
-        
+                    raise LegitimacyCheckError(
+                        f"Video failed legitimacy check: {failure_reason}. If you think this is a mistake, please contact us through the app or the OMEGA Focus Discord server."
+                    )
+
         completion_score_breakdown = await _get_completion_score_breakdown(
             task_overview,
             detailed_video_description=video_description,
         )
-        
+
         completion_gemini_score = completion_score_breakdown.completion_score
         final_score = completion_gemini_score * boosted_multiplier
-        
+
         print(f"Final score: {final_score}")
         print(f"completion score breakdown: {completion_score_breakdown}")
-        
+
         return VideoScore(
             task_uniqueness_score=task_uniqueness_score,
             video_completion_score=completion_gemini_score,
@@ -522,7 +628,9 @@ async def main():
         Read the Show-O paper to understand how they have trained a unified diffusion and autoregressive model for multimodal tokenization.
 """.strip()
-    score_details = await service.score_video(video_id, task_overview, "description")
+    score_details = await service.score_video(
+        video_id, task_overview, "description"
+    )
     print(score_details)
 
     # task_overview_embedding = await service.get_text_embedding(task_overview)
 
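Reviewer note: score_video above is essentially a pipeline of hard gates (duration bounds, video uniqueness, legitimacy checks) followed by a multiplicative score, final_score = completion_gemini_score * boosted_multiplier. A condensed sketch of that control flow (the uniqueness threshold below is a placeholder; the real constants live in the validator config):

TWO_MINUTES = 2 * 60  # 120 s lower bound, as in the hunk above
NINETY_MINUTES = 90 * 60  # 5400 s upper bound
MIN_VIDEO_UNIQUENESS = 0.02  # placeholder value, not the real constant

def gate_and_score(duration_s, uniqueness, completion_score, boost=1.0):
    if duration_s < TWO_MINUTES:
        raise ValueError("video too short")
    if duration_s > NINETY_MINUTES:
        raise ValueError("video too long")
    if uniqueness < MIN_VIDEO_UNIQUENESS:
        raise ValueError("video not unique enough")
    # Only a video that survives every gate receives a score.
    return completion_score * boost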
""".strip() - score_details = await service.score_video(video_id, task_overview, "description") + score_details = await service.score_video( + video_id, task_overview, "description" + ) print(score_details) # task_overview_embedding = await service.get_text_embedding(task_overview) diff --git a/validator-api/validator_api/scoring/video_description.py b/validator-api/validator_api/scoring/video_description.py index d055de13..0a3c9c09 100644 --- a/validator-api/validator_api/scoring/video_description.py +++ b/validator-api/validator_api/scoring/video_description.py @@ -5,11 +5,11 @@ from validator_api.scoring import focus_scoring_prompts from validator_api.scoring.gemini_client import _make_gemini_request_with_retries + async def get_task_overview(video_id: str) -> str: async with get_db_context() as db: query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() @@ -19,49 +19,58 @@ async def get_task_overview(video_id: str) -> str: focusing_task = video_record.video_details.get("focusing_task", "") focusing_description = video_record.video_details.get( - "focusing_description", "") + "focusing_description", "" + ) - task_overview = f"# Task Title: {focusing_task}\n\n Task Description:\n{focusing_description}" - return task_overview + task_overview = ( + f"# Task Title: {focusing_task}\n\n Task Description:\n{focusing_description}" + ) + return task_overview -async def get_detailed_video_description(video_id: str, task_overview: str) -> DetailedVideoDescription: + +async def get_detailed_video_description( + video_id: str, task_overview: str +) -> DetailedVideoDescription: async with get_db_context() as db: query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record is None: raise ValueError(f"Video not found: {video_id}") - - if video_record.video_details and "detailed_video_description" in video_record.video_details: + + if ( + video_record.video_details + and "detailed_video_description" in video_record.video_details + ): return DetailedVideoDescription.model_validate( video_record.video_details["detailed_video_description"] ) - + description = await _make_gemini_request_with_retries( system_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_SYSTEM_PROMPT, - user_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_USER_PROMPT.format(task_overview=task_overview), + user_prompt=focus_scoring_prompts.DETAILED_DESCRIPTION_USER_PROMPT.format( + task_overview=task_overview + ), video_id=video_id, OutputClassSchema=DetailedVideoDescription, ) - + # Cache the description in database async with get_db_context() as db: query = select(FocusVideoRecord).filter( - FocusVideoRecord.video_id == video_id, - FocusVideoRecord.deleted_at.is_(None) + FocusVideoRecord.video_id == video_id, FocusVideoRecord.deleted_at.is_(None) ) result = await db.execute(query) video_record = result.scalar_one_or_none() - + if video_record: video_details = video_record.video_details or {} video_details["detailed_video_description"] = description.model_dump() video_record.video_details = video_details db.add(video_record) await db.commit() - - return description \ No newline at 
diff --git a/validator-api/validator_api/utils/__init__.py b/validator-api/validator_api/utils/__init__.py
index 91942ac2..f5853e53 100644
--- a/validator-api/validator_api/utils/__init__.py
+++ b/validator-api/validator_api/utils/__init__.py
@@ -1,14 +1,16 @@
 import asyncio, functools
 
-RETRIES=3
-DELAY_SECS=2
+RETRIES = 3
+DELAY_SECS = 2
+
 
 def run_async(func, *args, **kwargs):
     loop = asyncio.get_event_loop()
     return loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
 
+
 async def run_with_retries(func, *args, **kwargs):
-    """ func can be sync or async, since we await the output if it's a coroutine """
+    """func can be sync or async, since we await the output if it's a coroutine"""
     for i in range(0, RETRIES):
         try:
             output = func(*args, **kwargs)
diff --git a/validator-api/validator_api/utils/marketplace.py b/validator-api/validator_api/utils/marketplace.py
index 6edc69f1..5db47c16 100644
--- a/validator-api/validator_api/utils/marketplace.py
+++ b/validator-api/validator_api/utils/marketplace.py
@@ -3,7 +3,11 @@
 import requests
 import bittensor as bt
 from validator_api.config import (
-    NETWORK, BT_TESTNET, NETUID, FOCUS_REWARDS_PERCENT, FIXED_ALPHA_USD_ESTIMATE,
+    NETWORK,
+    BT_TESTNET,
+    NETUID,
+    FOCUS_REWARDS_PERCENT,
+    FIXED_ALPHA_USD_ESTIMATE,
     BOOSTED_TASKS_PERCENTAGE,
 )
 from validator_api.utils import run_with_retries, run_async
@@ -22,6 +26,7 @@ async def get_subtensor() -> bt.subtensor:
 
     def _internal() -> bt.subtensor:
         return bt.subtensor(network=NETWORK)
+
     return await run_with_retries(_internal)
 
 
@@ -34,14 +39,13 @@ async def get_tao_price() -> float:
         )
     )
 
+
 # Global cache for max focus alpha
-max_focus_alpha_per_day_cache = {
-    'value': None,
-    'timestamp': 0
-}
+max_focus_alpha_per_day_cache = {"value": None, "timestamp": 0}
 
 CACHE_DURATION = 30 * 60  # 30 minutes in seconds
 
+
 async def get_max_focus_alpha_per_day() -> float:
     """
     https://docs.bittensor.com/dynamic-tao/emission
     """
     global max_focus_alpha_per_day_cache
 
     current_time = time.time()
-    if max_focus_alpha_per_day_cache['value'] is not None and current_time - max_focus_alpha_per_day_cache['timestamp'] < CACHE_DURATION:
-        return max_focus_alpha_per_day_cache['value']
+    if (
+        max_focus_alpha_per_day_cache["value"] is not None
+        and current_time - max_focus_alpha_per_day_cache["timestamp"] < CACHE_DURATION
+    ):
+        return max_focus_alpha_per_day_cache["value"]
 
     # If cache is invalid or empty, recalculate
     subtensor = await get_subtensor()
-    
+
     def _internal_sync():
         subnet = subtensor.subnet(netuid=NETUID)
         alpha_emission_per_block = subnet.alpha_out_emission.tao
@@ -73,8 +80,8 @@ async def _internal_async() -> float:
     max_focus_alpha_per_day = await run_with_retries(_internal_async)
     # print(f"max_focus_alpha_per_day: {max_focus_alpha_per_day}")
     # Update cache
-    max_focus_alpha_per_day_cache['value'] = max_focus_alpha_per_day
-    max_focus_alpha_per_day_cache['timestamp'] = current_time
+    max_focus_alpha_per_day_cache["value"] = max_focus_alpha_per_day
+    max_focus_alpha_per_day_cache["timestamp"] = current_time
 
     return max_focus_alpha_per_day
 
@@ -87,11 +94,9 @@ async def get_fixed_reward_pool_alpha() -> float:
     """
     async with get_db_context() as db:
         twenty_four_hours_ago = datetime.utcnow() - timedelta(days=1)
-        query = select(
-            func.sum(FocusVideoRecord.earned_reward_alpha)
-        ).where(
+        query = select(func.sum(FocusVideoRecord.earned_reward_alpha)).where(
             FocusVideoRecord.task_type == TaskType.MARKETPLACE.value,
-            FocusVideoRecord.updated_at >= twenty_four_hours_ago
+            FocusVideoRecord.updated_at >= twenty_four_hours_ago,
         )
         result = await db.execute(query)
         return result.scalar() or 0.0
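Reviewer note: max_focus_alpha_per_day_cache above implements a simple module-level time-based cache, so the subtensor is queried at most once per 30-minute window (CACHE_DURATION = 30 * 60). The pattern in isolation (generic names; the real code stores the max focus alpha):

import time

_cache = {"value": None, "timestamp": 0.0}
CACHE_DURATION = 30 * 60  # seconds

async def cached_value(recompute):
    now = time.time()
    if _cache["value"] is not None and now - _cache["timestamp"] < CACHE_DURATION:
        return _cache["value"]  # still fresh: skip the chain query
    _cache["value"] = await recompute()
    _cache["timestamp"] = now
    return _cache["value"]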
@@ -111,11 +116,13 @@ async def get_variable_reward_pool_alpha() -> float:
 
 
 def get_dollars_available_today(max_focus_alpha: float) -> float:
-    """ Use a fixed ΩTAO - USD estimate to keep consistent for the sake of miner rewards """
+    """Use a fixed ΩTAO - USD estimate to keep consistent for the sake of miner rewards"""
     return max_focus_alpha * FIXED_ALPHA_USD_ESTIMATE
 
+
 def get_max_focus_points_available_today(max_focus_alpha: float) -> float:
     # 1 point = 1 dollar
     return int(get_dollars_available_today(max_focus_alpha))
 
+
 MAX_TASK_REWARD_TAO = 0.1
diff --git a/validator-api/validator_api/utils/wallet.py b/validator-api/validator_api/utils/wallet.py
index 5514928f..8836efb5 100644
--- a/validator-api/validator_api/utils/wallet.py
+++ b/validator-api/validator_api/utils/wallet.py
@@ -1,17 +1,16 @@
 import bittensor as bt
+
 # import aiohttp
 # import time
 import asyncio
+
 # from validator_api.utils import run_with_retries, run_async
 from typing import List
 from validator_api.config import NETWORK
 
 # Global cache for TAO/USD rate
-tao_usd_cache = {
-    'rate': None,
-    'timestamp': 0
-}
+tao_usd_cache = {"rate": None, "timestamp": 0}
 
 CACHE_DURATION = 30 * 60  # 30 minutes in seconds
 
@@ -77,6 +76,7 @@
 }
 """
 
+
 async def get_transaction_from_block_hash(
     wallet_address: str,
     block_hash: str,
@@ -84,38 +84,45 @@ async def get_transaction_from_block_hash(
     """Get all transfers associated with the provided wallet address and block_hash."""
     transactions = []
     divisor = 1e9
-    
+
     async with bt.AsyncSubtensor(network=NETWORK) as subtensor:
         block = await subtensor.substrate.get_block(block_hash)
-        block_num = block['header']['number']
+        block_num = block["header"]["number"]
 
-        for extrinsic in block['extrinsics']:
+        for extrinsic in block["extrinsics"]:
             extrinsic = extrinsic.value
-            if 'call' in extrinsic and extrinsic['call']['call_module'] == 'Balances':
-                if extrinsic['call']['call_function'] in ['transfer', 'transfer_allow_death']:
-                    sender = extrinsic.get('address', 'Unknown')
-                    recipient = extrinsic['call']['call_args'][0]['value']
-                    amount = int(extrinsic['call']['call_args'][1]['value'])
+            if "call" in extrinsic and extrinsic["call"]["call_module"] == "Balances":
+                if extrinsic["call"]["call_function"] in [
+                    "transfer",
+                    "transfer_allow_death",
+                ]:
+                    sender = extrinsic.get("address", "Unknown")
+                    recipient = extrinsic["call"]["call_args"][0]["value"]
+                    amount = int(extrinsic["call"]["call_args"][1]["value"])
                     if sender == wallet_address or recipient == wallet_address:
-                        transactions.append({
-                            'id': extrinsic['extrinsic_hash'],
-                            'from': sender,
-                            'to': recipient,
-                            'amount': amount / divisor,
-                            # the Id is not actually supposed to be the hash, but we'll let it fly
-                            # for now cause all we need is a unique identifier, which the hash is
-                            'extrinsicId': extrinsic['extrinsic_hash'],
-                            'blockNumber': block_num
-                        })
+                        transactions.append(
+                            {
+                                "id": extrinsic["extrinsic_hash"],
+                                "from": sender,
+                                "to": recipient,
+                                "amount": amount / divisor,
+                                # the Id is not actually supposed to be the hash, but we'll let it fly
+                                # for now because all we need is a unique identifier, which the hash is
+                                "extrinsicId": extrinsic["extrinsic_hash"],
+                                "blockNumber": block_num,
+                            }
+                        )
 
     return transactions[::-1]
 
 
 if __name__ == "__main__":
     # get a recent transaction from https://taostats.io/transfers
-    result = asyncio.run(get_transaction_from_block_hash(
-        wallet_address="5CAjN1UcMXKWa8YHpoddjGFbrHdu182eYB6x5i1NkDiN4kej",
-        block_hash="0x95a2517045778b9bda7f309d45002f1c5fe03ff400f9f73b585da1c3d1bd9cb9"
-    ))
+    result = asyncio.run(
+        get_transaction_from_block_hash(
+            wallet_address="5CAjN1UcMXKWa8YHpoddjGFbrHdu182eYB6x5i1NkDiN4kej",
+            block_hash="0x95a2517045778b9bda7f309d45002f1c5fe03ff400f9f73b585da1c3d1bd9cb9",
+        )
+    )
     print(result)
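Reviewer note: in get_transaction_from_block_hash above, Balances transfers come off the chain denominated in rao, the base unit, and divisor = 1e9 converts them to TAO (1 TAO = 10^9 rao). A tiny worked example:

RAO_PER_TAO = 1e9  # the same constant as `divisor` above

def rao_to_tao(amount_rao: int) -> float:
    return amount_rao / RAO_PER_TAO

assert rao_to_tao(2_500_000_000) == 2.5  # a 2.5 TAO transfer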
wallet_address="5CAjN1UcMXKWa8YHpoddjGFbrHdu182eYB6x5i1NkDiN4kej", - block_hash="0x95a2517045778b9bda7f309d45002f1c5fe03ff400f9f73b585da1c3d1bd9cb9" - )) + result = asyncio.run( + get_transaction_from_block_hash( + wallet_address="5CAjN1UcMXKWa8YHpoddjGFbrHdu182eYB6x5i1NkDiN4kej", + block_hash="0x95a2517045778b9bda7f309d45002f1c5fe03ff400f9f73b585da1c3d1bd9cb9", + ) + ) print(result) From 4e9f983631aa790c2dd8cce4f438b672335fdd8d Mon Sep 17 00:00:00 2001 From: Eric Hasegawa Date: Tue, 4 Mar 2025 17:31:57 -0800 Subject: [PATCH 2/4] fix 1 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7e4def3d..e586e7fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,4 +31,4 @@ jobs: - name: Run Black run: | source env/bin/activate - ruff format --check \ No newline at end of file + ruff format --check From 4871c921b4e44c776b398a029bd4b57c179a90cb Mon Sep 17 00:00:00 2001 From: Eric Hasegawa Date: Tue, 4 Mar 2025 17:36:28 -0800 Subject: [PATCH 3/4] fix 2 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e8661637..27ed0e4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,4 @@ librosa==0.10.2.post1 substrate-interface==1.7.11 asyncpg==0.30.0 greenlet==3.1.1 -ruff=0.9.9 +ruff==0.9.9 From ddb54eaa951cbb7b9603613baac12949d8e2f02c Mon Sep 17 00:00:00 2001 From: Eric Hasegawa Date: Tue, 4 Mar 2025 17:39:47 -0800 Subject: [PATCH 4/4] fix 2 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e586e7fa..feebf4d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: source env/bin/activate uv pip install --pre -r requirements.txt uv pip install --pre -r requirements_api.txt - - name: Run Black + - name: Run Ruff formatting run: | source env/bin/activate ruff format --check