A robust system for large-scale text inference using Vertex AI (Gemini).
This repository hosts a modular framework to orchestrate large-scale batch and live inference requests to Gemini models.
- Installation
- Quick Start
- Core Functions
- Example Usage
- Package Overview
- System Architecture
- Key Concepts and Features
- Core Components
- Table Schema
- License
You can install the package directly from GitHub:
pip install git+https://github.com/ericzhao28/easyinference.git
- Set up your credentials for GCP and Vertex AI:
gcloud auth application-default login
- Configure the necessary environment variables:
# Google Cloud Platform Configuration
export GCP_PROJECT_ID="your-project-id"
export GCP_PROJECT_NUM="123456789012"
export GCP_REGION="us-central1"
export VERTEX_BUCKET="your-gcs-bucket"
export GEMINI_API_KEY=""
# SQL Configuration
export TABLE_NAME="your-table"
export SQL_DATABASE_NAME="your-database"
export SQL_USER="db-user"
export SQL_PASSWORD="your-password"
export SQL_INSTANCE_CONNECTION_NAME="project-id:region:instance-name"
export POOL_SIZE="50"
# Local Postgres Configuration (Optional)
export DB_TYPE="local"
export LOCAL_POSTGRES_HOST="localhost"
export LOCAL_POSTGRES_PORT="5432"
# Additional Configuration
export COOLDOWN_SECONDS="1.0"
export MAX_RETRIES="8"
export BATCH_TIMEOUT_HOURS="3"
export ROUND_ROBIN_ENABLED="false"
Alternatively, you can use the provided example.env file:
- Copy example.env to .env
- Update the values in .env with your configuration
- Use python-dotenv to load these variables in your code
- Make sure you set your environment variables before importing easyinference. Otherwise, run easyinference.reload_config() after setting your environment variables.
- Initialize the database connection:
from easyinference import initialize_query_connection
# Initialize the database connection before using any inference functions
initialize_query_connection()
- Import and use the package:
from dotenv import load_dotenv # pip install python-dotenv
# Load environment variables from .env file (if using this approach)
load_dotenv()
from easyinference import inference, individual_inference, run_clearing_inference, reload_config, initialize_query_connection
# Initialize the database connection
initialize_query_connection()
Main async function for batch processing multiple datapoints
async def inference(
prompt_functions: List[Callable[[Any], str]], # Functions that convert datapoints to prompt text
datapoints: List[Any], # List of data items to process
tags: Optional[List[str]] = None, # Identifier tags for tracking
duplication_indices: Optional[List[int]] = None, # Indices for running datapoints multiple times
run_fast: bool = True, # If True, makes direct API calls; if False, queues for batch
allow_failure: bool = False, # If True, continues after max retries with error messages
attempts_cap: int = 8, # Maximum number of retry attempts
temperature: float = 0, # Temperature parameter for generation
max_output_tokens: int = 65535, # Maximum tokens to generate in response
thinking_budget_tokens: int = 32768, # Token budget for the model's internal reasoning ("thinking")
system_prompt: str = "", # System prompt to guide model behavior
model: str = "gemini-2.5-pro-preview-06-05", # Generative model to use
batch_size: int = 1000, # Max concurrent requests or batch job size
run_fast_timeout: float = 200, # Timeout in seconds for fast mode calls
cooldown_seconds: float = 1.0, # Base wait time between retries
batch_timeout_hours: int = 3, # Max runtime before restarting
round_robin_enabled: bool = False, # Whether to cycle through regions
round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"], # Region options for cycling
initial_histories: Optional[List[dict]] = None, # Starting conversation histories for the inference sessions
) -> tuple[List[tuple], str] # Returns ([([response 1, response 2, ...], [query 1, query 2, ...]), ... one pair per datapoint], launch_timestamp_tag)
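The return annotation above is easier to read with concrete data. The sketch below builds a stand-in value of the documented shape (one ([responses], [queries]) pair per datapoint) and flattens it into records. The helper name flatten_results and the placeholder strings are hypothetical and not part of easyinference.

```python
from typing import List, Tuple

def flatten_results(results: List[Tuple[List[str], List[str]]]) -> List[Tuple[int, str, str]]:
    """Hypothetical helper: turn the per-datapoint (responses, queries) pairs
    returned by inference() into flat (datapoint_index, query, response) records."""
    records = []
    for dp_index, (responses, queries) in enumerate(results):
        # One query per prompt function, paired with the response it produced
        for query, response in zip(queries, responses):
            records.append((dp_index, query, response))
    return records

# Stand-in data with the documented shape (placeholder strings, not model output)
results = [
    (["response A1", "response A2"], ["query A1", "query A2"]),
    (["response B1", "response B2"], ["query B1", "query B2"]),
]
for dp_index, query, response in flatten_results(results):
    print(dp_index, query, "->", response)
```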
For processing a single datapoint through multiple prompt functions
async def individual_inference(
prompt_functions: List[Callable[[Any], str]], # Functions that convert datapoint to prompt text
datapoint: Any, # Data to process
tags: Optional[List[str]] = None, # Identifier tags for tracking
optional_tags: Optional[List[str]] = None, # Additional tags not used for lookup
duplication_index: int = 0, # Index to distinguish duplicate runs
run_fast: bool = True, # If True, makes direct API calls; if False, queues for batch
allow_failure: bool = False, # If True, continues after max retries with error messages
attempts_cap: int = 8, # Maximum number of retry attempts
temperature: float = 0, # Temperature parameter for generation
max_output_tokens: int = 65535, # Maximum tokens to generate in response
thinking_budget_tokens: int = 32768, # Token budget for the model's internal reasoning ("thinking")
system_prompt: str = "", # System prompt to guide model behavior
model: str = "gemini-2.5-pro-preview-06-05", # Generative model to use
run_fast_timeout: float = 200, # Timeout in seconds for fast mode calls
cooldown_seconds: float = 1.0, # Base wait time between retries
round_robin_enabled: bool = False, # Whether to cycle through regions
round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"], # Region options for cycling
initial_history_json: Optional[dict] = None, # Starting conversation history for the inference session
) -> tuple[List[str], List[str]] # Returns [[response 1, response 2, ...], [query 1, query 2, ...]]
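individual_inference takes a list of prompt functions, each mapping the same datapoint to one prompt. A minimal sketch of such a list for a two-step conversation, assuming (as the multi-prompt description suggests) that the functions are applied in order and each produces the next user turn. The datapoint fields and prompt wording are made up for illustration.

```python
from typing import Any, Callable, Dict, List

# Hypothetical datapoint; any object works as long as the prompt functions understand it
datapoint: Dict[str, Any] = {"topic": "gradient descent", "audience": "high-school students"}

def ask_for_explanation(dp: Dict[str, Any]) -> str:
    # First turn: request an explanation of the topic
    return f"Explain {dp['topic']}."

def ask_for_simplification(dp: Dict[str, Any]) -> str:
    # Second turn: follow up within the same conversation
    return f"Now rewrite your explanation for {dp['audience']}."

prompt_functions: List[Callable[[Any], str]] = [ask_for_explanation, ask_for_simplification]

# The query text each step would send (computed locally, no API call)
queries = [fn(datapoint) for fn in prompt_functions]
print(queries)

# With the package configured, the call would look like:
# responses, queries = await individual_inference(prompt_functions=prompt_functions, datapoint=datapoint)
```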
For managing batch inference jobs
async def run_clearing_inference(
tag: str, # Unique identifier tag for the batch
batch_size: int, # Maximum number of requests per batch job
run_batch_jobs: bool, # Whether to launch new batch jobs
batch_timeout_hours: int = 3 # Maximum runtime hours before restarting
) -> None
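Because run_clearing_inference both submits batch jobs and monitors them, one way to drive it is a simple periodic loop. This is an assumption about usage rather than a documented pattern: the sketch takes the clearing coroutine as a parameter so it can run against a stub, and the tag, interval, and round count are arbitrary placeholders (in practice the tag would be the launch_timestamp_tag returned by inference).

```python
import asyncio
from typing import Awaitable, Callable

async def clear_periodically(
    clearing_fn: Callable[..., Awaitable[None]],
    tag: str,
    batch_size: int,
    rounds: int,
    interval_seconds: float,
) -> int:
    """Call clearing_fn (e.g. easyinference.run_clearing_inference) a fixed number of times."""
    completed = 0
    for i in range(rounds):
        await clearing_fn(tag=tag, batch_size=batch_size, run_batch_jobs=True)
        completed += 1
        if i < rounds - 1:
            # Wait before checking on the batch jobs again
            await asyncio.sleep(interval_seconds)
    return completed

# Stub with the documented signature, so the loop runs without GCP or a database
calls = []

async def fake_clearing(tag: str, batch_size: int, run_batch_jobs: bool, batch_timeout_hours: int = 3) -> None:
    calls.append((tag, batch_size, run_batch_jobs))

completed = asyncio.run(clear_periodically(fake_clearing, "my-launch-tag", 1000, rounds=3, interval_seconds=0.01))
print(completed, calls[0])
```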
For reloading the config after setting environment variables
def reload_config() -> None
import asyncio
from dotenv import load_dotenv
from easyinference import inference, reload_config, initialize_query_connection
load_dotenv()
reload_config()
# Initialize the database connection before using any inference functions
initialize_query_connection()
async def process_data():
    # Define data and prompt function
    datapoints = [
        {"text": "What is machine learning?"},
        {"text": "Explain neural networks"}
    ]
    def create_prompt(dp):
        return f"Please explain: {dp['text']}"
    # Run inference
    results, timestamp = await inference(
        prompt_functions=[create_prompt],
        datapoints=datapoints,
        tags=["explanation", "v1"],
        run_fast=True
    )
    # Process results: each datapoint's result is ([responses], [queries])
    first_datapoint_result, second_datapoint_result = results
    responses, queries = first_datapoint_result
    for query, response in zip(queries, responses):
        print(f"Query: {query}")
        print(f"Response: {response}")
    return results
# Run the async function
results = asyncio.run(process_data())
Goal: We provide a scalable and robust pipeline to handle:
- Conversation-based inference requests to Gemini models
- Failure tracking and retry logic to ensure stable operation
- Asynchronous or synchronous methods for generating text from the model
We accomplish this by:
- Storing every inference "step" in a PostgreSQL table, which captures the query text, model parameters, conversation history, and final responses (or errors).
- Separating "fast" live calls vs. "slow" batch-based calls.
- Monitoring the status of batch inference jobs, so you can schedule or restart them if they take too long.
- Allowing different usage patterns: single datapoint or bulk processing, with multi-prompt sequences, concurrency caps, and retries.
                ┌────────────────────┐
                │  Your Application  │
                └─────────┬──────────┘
                          │
            ┌─────────────┴──────────────┐
            ▼                            ▼
┌────────────────────────┐   ┌────────────────────────┐
│  Individual Inference  │   │       Inference        │
│         (Fast)         │◄──│                        │
└───────────┬────────────┘   └───────────┬────────────┘
            │                            ▼
            │                ┌────────────────────────┐
            │                │     Batch Clearing     │
            │                │      (monitoring)      │
            │                └───────────┬────────────┘
            ▼                            ▼
┌────────────────────────┐   ┌────────────────────────┐
│ Vertex AI (Gemini API) │   │ Vertex AI (Gemini API) │
│      (Live Calls)      │   │      (Batch Job)       │
└───────────┬────────────┘   └───────────┬────────────┘
            └─────────────┬──────────────┘
                          ▼
                ┌────────────────────┐
                │     PostgreSQL     │
                │    Master Table    │
                └────────────────────┘
- Individual Inference manages a single datapoint and a sequence of prompts.
- Inference is a bulk orchestrator that calls individual inference on multiple datapoints.
- Clearing Inference takes unprocessed or failed rows and triggers additional attempts (live or batch). It also monitors batch jobs and handles timeouts.
Conversation history is stored in PostgreSQL under history_json as a JSON object:
{
"history": [
{"role": "user", "parts": {"text": "Hello, how are you?"}},
{"role": "model", "parts": {"text": "I am fine. How can I help?"}}
]
}
This helps Vertex continue the same conversation context across multiple queries without duplication.
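A minimal sketch of building a history_json value in the shape shown above, appending one message per turn. The helper name append_turn is hypothetical; it only mirrors the documented JSON layout and is not the package's own code.

```python
import json
from typing import Any, Dict

def append_turn(history_json: Dict[str, Any], role: str, text: str) -> Dict[str, Any]:
    """Return a new history_json with one more message appended (the input is left unchanged)."""
    if role not in ("user", "model"):
        raise ValueError(f"unexpected role: {role}")
    messages = list(history_json.get("history", []))
    messages.append({"role": role, "parts": {"text": text}})
    return {"history": messages}

history = {"history": []}
history = append_turn(history, "user", "Hello, how are you?")
history = append_turn(history, "model", "I am fine. How can I help?")
print(json.dumps(history, indent=2))
```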
Generation parameters are stored under generation_params_json (JSON):
{
"temperature": 0.7,
"max_output_tokens": 65535,
"system_prompt": "You are a helpful assistant..."
}
The duplication index is an integer that distinguishes an intentional repeat of an otherwise identical request (e.g., a re-run) from earlier rows. Defaults to 0.
Tags are a list of strings (alphabetically sorted) representing categories or labels applied to a request (e.g. ["admin", "api-v1"]). They help with filtering or grouping requests by usage scenario.
The request cause is either "intentional" (an explicit user request) or "backup" (an automatic fallback).
Last Status can be "PENDING", "RUNNING", "FAILED", "SUCCEEDED", or "WAITING". Failure Count tracks how many attempts have failed so far, and Attempts Cap sets the maximum allowed.
The content hash is a hash of (Model, History, Query, GenerationParams, DuplicationIndex), used for deduplicating or resuming requests.
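A sketch of computing such a hash, assuming the five fields are serialized as canonical JSON (sorted keys, no extra whitespace) and then hashed with SHA-256. The exact serialization easyinference uses is not documented here, so values from this sketch will not necessarily match stored content_hash values; the function name is hypothetical.

```python
import hashlib
import json
from typing import Any, Dict

def content_hash(model: str, history: Dict[str, Any], query: str,
                 generation_params: Dict[str, Any], duplication_index: int) -> str:
    # Canonical JSON (sort_keys applies to nested dicts too), so key order never changes the hash
    payload = json.dumps(
        [model, history, query, generation_params, duplication_index],
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

params = {"temperature": 0.7, "max_output_tokens": 65535, "system_prompt": ""}
h0 = content_hash("gemini-2.5-pro-preview-06-05", {"history": []}, "What is ML?", params, 0)
h1 = content_hash("gemini-2.5-pro-preview-06-05", {"history": []}, "What is ML?", params, 1)
# Same request with a different duplication index gets a different hash
print(h0 != h1)
```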
- Run Fast: calls the Vertex API directly, returning the result in real time.
- Run Slow: queues up the request for a batch job. The run_clearing_inference function handles job submission and monitoring.
Before using any inference functions, you must initialize the database connection by calling:
from easyinference import initialize_query_connection
initialize_query_connection()
This sets up the necessary connections to the PostgreSQL database for tracking inference requests.
- Defines a ConvoRow data class that mirrors each column in the table.
- Defines enumerations for RequestStatus and RequestCause.
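A sketch of the two enumerations, plus a hypothetical guard for the rule that a row already marked SUCCEEDED is not overwritten with FAILED (the concurrency check on database writes). The string values come from the statuses and causes documented in this README; the member names and the may_overwrite function are assumptions.

```python
from enum import Enum

class RequestStatus(str, Enum):
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    FAILED = "FAILED"
    SUCCEEDED = "SUCCEEDED"
    WAITING = "WAITING"

class RequestCause(str, Enum):
    INTENTIONAL = "intentional"  # explicit user request
    BACKUP = "backup"            # automatic fallback

def may_overwrite(current: RequestStatus, new: RequestStatus) -> bool:
    """Hypothetical guard: never replace a SUCCEEDED status with FAILED."""
    return not (current is RequestStatus.SUCCEEDED and new is RequestStatus.FAILED)

print(may_overwrite(RequestStatus.RUNNING, RequestStatus.FAILED))
print(may_overwrite(RequestStatus.SUCCEEDED, RequestStatus.FAILED))
```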
EasyInference supports three database configuration options:
- Google Cloud SQL (default)
- Local Postgres for development or when using your own database infrastructure
- No Database for simple inference without tracking or batch processing
- Choose your database type by setting the DB_TYPE environment variable:
# Use Google Cloud SQL (default)
export DB_TYPE="gcp"
# Use local Postgres
export DB_TYPE="local"
export LOCAL_POSTGRES_HOST="localhost" # Or your Postgres server address
export LOCAL_POSTGRES_PORT="5432" # Or your Postgres server port
# Use no database
export DB_TYPE="none"
- For Google Cloud SQL or local Postgres, set the required database parameters:
# Required for both GCP and local Postgres options
export SQL_DATABASE_NAME="your-database"
export SQL_USER="db-user"
export SQL_PASSWORD="your-password"
export TABLE_NAME="your-table"
# Only required for GCP option
export SQL_INSTANCE_CONNECTION_NAME="project-id:region:instance-name"
- Initialize the database connection as usual:
from easyinference import initialize_query_connection
initialize_query_connection()
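Before calling initialize_query_connection(), it can help to fail early when a required variable is missing. The sketch below checks only the variables listed above as required (database name, user, password, and table name for gcp and local, plus the instance connection name for gcp; nothing for none), treating gcp as the default. The function name missing_db_settings is hypothetical.

```python
import os
from typing import List, Mapping

REQUIRED_FOR_SQL = ["SQL_DATABASE_NAME", "SQL_USER", "SQL_PASSWORD", "TABLE_NAME"]

def missing_db_settings(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the names of required database variables that are unset or empty."""
    db_type = env.get("DB_TYPE", "gcp")  # Google Cloud SQL is the documented default
    if db_type == "none":
        return []
    if db_type not in ("gcp", "local"):
        raise ValueError(f"Unknown DB_TYPE: {db_type!r}")
    required = list(REQUIRED_FOR_SQL)
    if db_type == "gcp":
        required.append("SQL_INSTANCE_CONNECTION_NAME")
    return [name for name in required if not env.get(name)]

missing = missing_db_settings()
if missing:
    print("Missing database settings:", ", ".join(missing))
```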
You can easily switch between database types in your Python code:
import os
from easyinference import reload_config, initialize_query_connection
# Switch to local Postgres
os.environ["DB_TYPE"] = "local"
reload_config()
initialize_query_connection()
# Later, switch to Google Cloud SQL
os.environ["DB_TYPE"] = "gcp"
reload_config()
initialize_query_connection()
# Or disable database operations entirely
os.environ["DB_TYPE"] = "none"
reload_config()
initialize_query_connection()
When DB_TYPE="none", EasyInference operates without any database tracking. In this mode:
- No database connection is established
- Batch inference is not available (will raise an error)
- Tagged inference is not available (will raise an error)
- Only direct (live) inference calls without tags are supported
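The restrictions above amount to a pre-flight check on each call. A sketch of that check, assuming it looks only at DB_TYPE, run_fast, and tags; check_no_db_call is hypothetical, and the exception type easyinference actually raises may differ.

```python
from typing import List, Optional

def check_no_db_call(db_type: str, run_fast: bool, tags: Optional[List[str]]) -> None:
    """Raise if a call needs database tracking while DB_TYPE is "none"."""
    if db_type != "none":
        return  # with a database, every mode is available
    if not run_fast:
        raise RuntimeError("Batch inference (run_fast=False) requires a database")
    if tags:
        raise RuntimeError("Tagged inference requires a database")

check_no_db_call("none", run_fast=True, tags=None)  # allowed: direct call without tags
try:
    check_no_db_call("none", run_fast=True, tags=["v1"])
except RuntimeError as err:
    print(err)
```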
- Helper functions to insert, update, or read rows from PostgreSQL.
- Includes concurrency checks so that a "SUCCEEDED" row is never overwritten with "FAILED".
- Functions for connecting to PostgreSQL, creating tables, and querying data.
- Implements both the individual_inference and inference functions
- Contains run_chat_inference_async for "fast" calls with built-in retry/backoff
- Implements run_clearing_inference, which handles both batch submission and monitoring
- Manages the logic for deduplicating (by content hash), incrementing failure counts, and handling partial successes
- Configuration settings for database connections, retry logic, and batch operations.
- Contains defaults for constants like MAX_RETRIES, BATCH_TIMEOUT_HOURS, and various connection parameters.
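A sketch of reading the documented environment variables with the example values from the configuration list as defaults, together with an exponential backoff schedule built on COOLDOWN_SECONDS. The Settings class and the doubling schedule are assumptions for illustration; the package's retry/backoff may compute delays differently (for example with jitter).

```python
import os
from dataclasses import dataclass
from typing import List, Mapping

@dataclass(frozen=True)
class Settings:
    cooldown_seconds: float
    max_retries: int
    batch_timeout_hours: int
    round_robin_enabled: bool
    pool_size: int

def load_settings(env: Mapping[str, str] = os.environ) -> Settings:
    # Defaults mirror the example environment variable values
    return Settings(
        cooldown_seconds=float(env.get("COOLDOWN_SECONDS", "1.0")),
        max_retries=int(env.get("MAX_RETRIES", "8")),
        batch_timeout_hours=int(env.get("BATCH_TIMEOUT_HOURS", "3")),
        round_robin_enabled=env.get("ROUND_ROBIN_ENABLED", "false").strip().lower() == "true",
        pool_size=int(env.get("POOL_SIZE", "50")),
    )

def backoff_delays(settings: Settings) -> List[float]:
    """Delay before each retry: cooldown * 2**attempt (assumed doubling schedule)."""
    return [settings.cooldown_seconds * (2 ** attempt) for attempt in range(settings.max_retries)]

settings = load_settings({"COOLDOWN_SECONDS": "0.5", "MAX_RETRIES": "4"})
print(backoff_delays(settings))  # [0.5, 1.0, 2.0, 4.0]
```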
Your master PostgreSQL table has the following columns:
| Column Name | Type | Description |
|---|---|---|
| row_id | INTEGER | Auto-incrementing primary key |
| content_hash | STRING | SHA-256 hash of key fields for deduplication |
| history_json | JSON | JSON storing prior conversation messages under the key "history" |
| query | STRING | User's latest query that needs a response |
| model | STRING | Full path of the model (e.g. "gemini-2.5-pro-preview-06-05") |
| generation_params_json | JSON | JSON storing generation settings, e.g. {"temperature":0.7,"max_output_tokens":8192,"system_prompt":"..."} |
| duplication_index | INTEGER | Used to mark re-runs or explicit duplicates. Defaults to 0 |
| tags | ARRAY(STRING) | A sorted list of tags (e.g. ["api-v1","testing"]) |
| request_cause | STRING | "intentional" or "backup". Uses the RequestCause enum |
| request_timestamp | STRING | ISO 8601 timestamp (e.g. "2025-02-25T12:34:56Z") |
| access_timestamps | ARRAY(STRING) | List of ISO 8601 timestamps of each read/update |
| attempts_metadata_json | ARRAY(JSON) | JSON array of prior attempts, storing batch info and error messages |
| response_json | JSON | JSON containing the final successful response, if available. Example: {"text":"...response..."} |
| current_batch | STRING | The ID of any currently running batch job. Can be NULL |
| last_status | STRING | "PENDING", "RUNNING", "FAILED", "SUCCEEDED", or "WAITING" |
| failure_count | INTEGER | How many times this row has failed so far |
| attempts_cap | INTEGER | The maximum number of times we will retry |
| notes | STRING | Optional free-text notes |
| insertion_timestamp | TIMESTAMP | When the row was inserted into the database |
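A sketch of a ConvoRow-style data class mirroring the table above. Field names follow the column list; the Python types (str for ISO 8601 timestamps, Optional for nullable or database-assigned columns) and the defaults for last_status, request_cause, and attempts_cap are assumptions, and the package's own ConvoRow may differ.

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

@dataclass
class ConvoRow:
    content_hash: str
    history_json: Dict[str, Any]
    query: str
    model: str
    generation_params_json: Dict[str, Any]
    duplication_index: int = 0
    tags: List[str] = field(default_factory=list)
    request_cause: str = "intentional"
    request_timestamp: str = ""
    access_timestamps: List[str] = field(default_factory=list)
    attempts_metadata_json: List[Dict[str, Any]] = field(default_factory=list)
    response_json: Optional[Dict[str, Any]] = None   # set once a response succeeds
    current_batch: Optional[str] = None               # ID of a running batch job, if any
    last_status: str = "PENDING"
    failure_count: int = 0
    attempts_cap: int = 8
    notes: Optional[str] = None
    row_id: Optional[int] = None                      # assigned by the database
    insertion_timestamp: Optional[datetime] = None    # assigned on insert

    def __post_init__(self) -> None:
        # The table stores tags as a sorted list
        self.tags = sorted(self.tags)

    def can_retry(self) -> bool:
        # Another attempt is allowed while failures stay below the cap
        return self.failure_count < self.attempts_cap

row = ConvoRow(
    content_hash="0" * 64,
    history_json={"history": []},
    query="What is machine learning?",
    model="gemini-2.5-pro-preview-06-05",
    generation_params_json={"temperature": 0},
    tags=["testing", "api-v1"],
)
print(row.tags, row.can_retry())
```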
- SHA-256 over the combination of (Model, History, Query, GenerationParams, DuplicationIndex).
- Ensures we don't re-run the same content multiple times unless we want to.
- A query can have tags like ["api-v1","admin-request"]. The system enforces that the stored tag list is alphabetically sorted (here ["admin-request","api-v1"]).
- For batch mode, a timestamp tag is automatically added for tracking.
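A sketch of tag normalization, assuming that enforcing sorted order means sorting the supplied tags and that, in batch mode, a launch timestamp tag is appended before sorting. The "launch-" prefix and timestamp format are hypothetical; this README does not document the package's actual launch tag format.

```python
from datetime import datetime, timezone
from typing import List, Optional

def normalize_tags(tags: Optional[List[str]], batch_mode: bool = False,
                   now: Optional[datetime] = None) -> List[str]:
    result = list(tags or [])
    if batch_mode:
        # Hypothetical launch tag format, e.g. "launch-20250225T123456Z"
        now = now or datetime.now(timezone.utc)
        result.append("launch-" + now.strftime("%Y%m%dT%H%M%SZ"))
    return sorted(result)

print(normalize_tags(["api-v1", "admin-request"]))  # ['admin-request', 'api-v1']
```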
This project is provided under the MIT License.
Feel free to modify or extend the code to suit your deployment and usage requirements.