diff --git a/requirements.txt b/requirements.txt index 7fa0bf7..dc40a91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ sentencepiece==0.2.0 protobuf==3.20.3 wandb==0.19.1 numpy==2.0.2 +PyJWT==2.9.0 python-dotenv==1.0.1 google-generativeai boto3 diff --git a/validator-api/app.py b/validator-api/app.py index f5c78c6..331153d 100644 --- a/validator-api/app.py +++ b/validator-api/app.py @@ -1,57 +1,63 @@ -import mysql.connector -from validator_api.limiter import limiter -from validator_api.check_blocking import detect_blocking -from validator_api.dataset_upload import video_dataset_uploader, audio_dataset_uploader import asyncio -import os import json -from datetime import datetime -import time -from typing import Annotated, List, Optional, Dict, Any +import os import random -import json -from pydantic import BaseModel +import time import traceback +from datetime import datetime from tempfile import TemporaryDirectory -import huggingface_hub -from datasets import load_dataset -import ulid from traceback import print_exception +from typing import Annotated, Any, Dict, List, Optional + +import aiohttp import bittensor +import huggingface_hub +import mysql.connector +import sentry_sdk +import ulid import uvicorn -from fastapi import FastAPI, HTTPException, Depends, Body, Path, Security, BackgroundTasks, Request -from fastapi.security import HTTPBasicCredentials, HTTPBasic +from datasets import load_dataset +from fastapi import (BackgroundTasks, Body, Depends, FastAPI, HTTPException, + Path, Request, Security) +from fastapi.responses import FileResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security.api_key import APIKeyHeader from fastapi.staticfiles import StaticFiles -from fastapi.responses import FileResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session from starlette import status from substrateinterface import Keypair -import sentry_sdk -from sqlalchemy.orm import Session +from 
validator_api.check_blocking import detect_blocking +from validator_api.communex._common import get_node_url +from validator_api.communex.client import CommuneClient +from validator_api.config import (API_KEY_NAME, API_KEYS, COMMUNE_NETUID, + COMMUNE_NETWORK, DB_CONFIG, ENABLE_COMMUNE, + FIXED_ALPHA_TAO_ESTIMATE, FOCUS_API_KEYS, + FOCUS_API_URL, FOCUS_REWARDS_PERCENT, + IMPORT_SCORE, IS_PROD, NETUID, NETWORK, PORT, + PROXY_LIST, SENTRY_DSN, TOPICS_LIST) +# from validator_api.utils.marketplace import get_max_focus_tao, get_purchase_max_focus_tao +from validator_api.cron.confirm_purchase import (confirm_transfer, + confirm_video_purchased) from validator_api.database import get_db, get_db_context from validator_api.database.crud.focusvideo import ( - get_all_available_focus, check_availability, get_video_owner_coldkey, - already_purchased_max_focus_tao, get_miner_purchase_stats, MinerPurchaseStats, - set_focus_video_score, mark_video_rejected, mark_video_submitted, TaskType, - alpha_to_tao_rate, -) -from validator_api.utils.marketplace import TASK_TYPE_MAP, get_max_focus_alpha_per_day, get_purchase_max_focus_alpha -# from validator_api.utils.marketplace import get_max_focus_tao, get_purchase_max_focus_tao -from validator_api.cron.confirm_purchase import confirm_transfer, confirm_video_purchased -from validator_api.scoring.scoring_service import FocusScoringService, VideoUniquenessError, LegitimacyCheckError -from validator_api.communex.client import CommuneClient -from validator_api.communex._common import get_node_url -from omega.protocol import Videos, VideoMetadata, AudioMetadata -import aiohttp -from validator_api.config import ( - NETWORK, NETUID, PORT, - ENABLE_COMMUNE, COMMUNE_NETWORK, COMMUNE_NETUID, - API_KEY_NAME, API_KEYS, DB_CONFIG, - TOPICS_LIST, PROXY_LIST, IS_PROD, - FOCUS_REWARDS_PERCENT, FOCUS_API_KEYS, FOCUS_API_URL, - SENTRY_DSN, IMPORT_SCORE, - FIXED_ALPHA_TAO_ESTIMATE, -) + MinerPurchaseStats, TaskType, alpha_to_tao_rate, + 
already_purchased_max_focus_tao, check_availability, + get_all_available_focus, get_miner_purchase_stats, get_video_owner_coldkey, + mark_video_rejected, mark_video_submitted, set_focus_video_score) +from validator_api.dataset_upload import (audio_dataset_uploader, + video_dataset_uploader) +from validator_api.limiter import limiter +from validator_api.scoring.scoring_service import (FocusScoringService, + LegitimacyCheckError, + VideoTooLongError, + VideoTooShortError, + VideoUniquenessError) +from validator_api.utils.marketplace import (TASK_TYPE_MAP, + get_max_focus_alpha_per_day, + get_purchase_max_focus_alpha) + +from omega.protocol import AudioMetadata, VideoMetadata print("IMPORT_SCORE:", IMPORT_SCORE) @@ -282,7 +288,11 @@ async def run_focus_scoring( print(f"Error scoring focus video <{video_id}>: {error_string}") # Determine appropriate rejection reason based on error type - if isinstance(e, VideoUniquenessError): + if isinstance(e, VideoTooShortError): + rejection_reason = "Video is too short. Please ensure the video is at least 10 seconds long." + elif isinstance(e, VideoTooLongError): + rejection_reason = "Video is too long. Please ensure the video is less than 10 minutes long." + elif isinstance(e, VideoUniquenessError): rejection_reason = "Task recording is not unique. If you believe this is an error, please contact a team member." elif isinstance(e, LegitimacyCheckError): rejection_reason = "An anomaly was detected in the video. If you believe this is an error, please contact a team member via the OMEGA Focus Discord channel." 
"""Rate-limiter setup for the validator API.

Requests are bucketed per user when a bearer JWT with a ``sub`` claim is
present, and per client IP address otherwise.
"""
from typing import Optional

import jwt
from fastapi import Request
from slowapi import Limiter

# Used when the ASGI server reports no client address. This is the same
# fallback value slowapi's own ``get_remote_address`` helper returns.
_FALLBACK_CLIENT_HOST = "127.0.0.1"


def get_rate_limit_key(request: Request) -> str:
    """Return the bucket key used to rate limit ``request``.

    Args:
        request: The incoming request.

    Returns:
        ``"user:<sub>"`` when a user id can be read from the bearer token,
        otherwise ``"ip:<address>"``.
    """
    user_id = _extract_user_id(request)
    if user_id:
        print(f"Rate limiting key: user:{user_id}")
        return f"user:{user_id}"

    ip = _get_client_ip(request)
    print(f"Rate limiting key: ip:{ip}")
    return f"ip:{ip}"


def _extract_user_id(request: Request) -> Optional[str]:
    """Read the ``sub`` claim from a bearer JWT in the Authorization header.

    Args:
        request: The incoming request.

    Returns:
        The ``sub`` claim as a string, or ``None`` when the header is
        missing, is not a ``Bearer <token>`` pair, the token cannot be
        decoded, or the claim is absent or empty.
    """
    auth_header = request.headers.get('authorization', '')
    parts = auth_header.split()
    # Require exactly "<scheme> <token>"; HTTP auth schemes are
    # case-insensitive, so "bearer" and "Bearer" are both accepted.
    if len(parts) != 2 or parts[0].lower() != 'bearer':
        return None

    try:
        # SECURITY NOTE(review): the signature is NOT verified, so ``sub`` is
        # fully client-controlled. A caller can mint a token with a fresh
        # ``sub`` per request and land in a new bucket each time. Verifying
        # the token needs the issuer's key, which is not available in this
        # module -- TODO confirm how tokens are issued and verify here.
        payload = jwt.decode(parts[1], options={"verify_signature": False})
    except jwt.InvalidTokenError:
        return None

    subject = payload.get('sub')
    if subject is None or subject == '':
        return None
    # Claims are arbitrary JSON; normalise so the declared return type holds.
    return str(subject)


def _get_client_ip(request: Request) -> str:
    """Return the best available client address for ``request``.

    Preference order: ``CF-Connecting-IP``, then the first entry of
    ``X-Forwarded-For``, then the socket peer address.

    NOTE(review): both headers are supplied by the client unless a trusted
    proxy overwrites them -- confirm the service is only reachable through
    such a proxy.

    Args:
        request: The incoming request.

    Returns:
        The client address, or ``"127.0.0.1"`` when none can be determined.
    """
    # Try Cloudflare-specific header first
    cf_connecting_ip = request.headers.get('cf-connecting-ip', '').strip()
    if cf_connecting_ip:
        return cf_connecting_ip

    # Fall back to X-Forwarded-For; skip it when the first entry is blank so
    # unrelated clients do not collapse into a shared "ip:" bucket.
    forwarded_for = request.headers.get('x-forwarded-for', '')
    first_hop = forwarded_for.split(',')[0].strip()
    if first_hop:
        return first_hop

    # ``request.client`` may be None (e.g. some test clients / transports).
    client = request.client
    if client and client.host:
        return client.host
    return _FALLBACK_CLIENT_HOST


limiter = Limiter(key_func=get_rate_limit_key)
async def embed_and_get_video_uniqueness_score(self, video_id: str, video_duration_seconds: int):
    """Embed a video and score how unique that embedding is.

    Both steps are best-effort: a failure in either one is logged and
    replaced by a fallback instead of aborting the scoring run.

    Args:
        video_id: Identifier of the video to embed.
        video_duration_seconds: Duration passed through to the embedder.

    Returns:
        A ``(embedding, uniqueness_score)`` tuple. ``embedding`` is ``None``
        when embedding failed. ``uniqueness_score`` is the fallback value
        when either step failed.
    """
    # Score reported when uniqueness cannot be checked ("assume unique").
    # NOTE(review): the caller compares this against
    # MIN_VIDEO_UNIQUENESS_SCORE -- confirm 0.1 is above that threshold,
    # otherwise every failure here is rejected as a duplicate.
    fallback_score = 0.1

    try:
        embedding = await get_video_embedding(video_id, video_duration_seconds)
    except Exception as e:
        print(f"Failed to create video embedding for {video_id}: {str(e)}")
        return None, fallback_score

    # Handled separately so a failed lookup neither throws away the
    # embedding computed above nor gets logged as an embedding failure.
    try:
        uniqueness_score = await self.get_video_uniqueness_score(embedding)
    except Exception as e:
        print(f"Failed to get video uniqueness score for {video_id}: {str(e)}")
        return embedding, fallback_score

    return embedding, uniqueness_score
get_video_duration_seconds(video_id) if video_duration_seconds < TWO_MINUTES: - raise ValueError(f"Video duration is too short: {video_duration_seconds} seconds") + raise VideoTooShortError(f"Video duration is too short: {video_duration_seconds} seconds") if video_duration_seconds > NINETY_MINUTES: - raise ValueError(f"Video duration is too long: {video_duration_seconds} seconds") + raise VideoTooLongError(f"Video duration is too long: {video_duration_seconds} seconds") task_overview = f"# {focusing_task}\n\n{focusing_description}" @@ -448,8 +459,8 @@ async def score_video(self, video_id: str, focusing_task: str, focusing_descript self.embed_and_get_video_uniqueness_score(video_id, video_duration_seconds), ) - # if video_uniqueness_score < MIN_VIDEO_UNIQUENESS_SCORE: - # raise VideoUniquenessError("Video uniqueness score is too low.") + if video_uniqueness_score < MIN_VIDEO_UNIQUENESS_SCORE: + raise VideoUniquenessError("Video uniqueness score is too low.") if self.legitimacy_checks: check_results = await asyncio.gather(