Commit 2c19720

add fastapi for ml
1 parent 7b9948d commit 2c19720

6 files changed, with 630 additions and 2 deletions

Chapter5/.gitignore

Lines changed: 2 additions & 1 deletion
@@ -4,4 +4,5 @@
44
.env
55
delta_lake*
66
employees
7-
mlruns
7+
mlruns
8+
*.joblib

Chapter5/machine_learning.ipynb

Lines changed: 226 additions & 0 deletions
@@ -2826,6 +2826,232 @@
28262826
"source": [
28272827
"[Learn more about MLFlow Models](https://bit.ly/46y6gpF)."
28282828
]
2829+
},
2830+
{
2831+
"cell_type": "code",
2832+
"execution_count": null,
2833+
"id": "8acadfca",
2834+
"metadata": {
2835+
"tags": [
2836+
"hide-cell"
2837+
]
2838+
},
2839+
"outputs": [],
2840+
"source": [
2841+
"!pip3 install joblib \"fastapi[standard]\""
2842+
]
2843+
},
2844+
{
2845+
"cell_type": "markdown",
2846+
"id": "e815f01b",
2847+
"metadata": {},
2848+
"source": [
2849+
"Imagine this scenario: You have just built a machine learning (ML) model with great performance, and you want to share this model with your team members so that they can develop a web application on top of your model.\n",
2850+
"\n",
2851+
"One way to share the model with your team members is to save the model to a file (e.g., using pickle, joblib, or framework-specific methods) and share the file directly:\n",
2852+
"\n",
2853+
"\n",
2854+
"```python\n",
2855+
"import joblib\n",
2856+
"\n",
2857+
"model = ...\n",
2858+
"\n",
2859+
"# Save model\n",
2860+
"joblib.dump(model, \"model.joblib\")\n",
2861+
"\n",
2862+
"# Load model\n",
2863+
"model = joblib.load(\"model.joblib\")\n",
2864+
"```\n",
2865+
"\n",
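2866+
"However, this approach requires the recipient to have the same environment and dependencies, and it poses a security risk: loading a pickle-based file (including joblib) from an untrusted source can execute arbitrary code.\n"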
2867+
]
2868+
},
2869+
{
2870+
"cell_type": "markdown",
2871+
"id": "b364f9fc",
2872+
"metadata": {},
2873+
"source": [
2874+
"An alternative is creating an API for your ML model. APIs define how software components interact, allowing:\n",
2875+
"\n",
2876+
"1. Access from various programming languages and platforms\n",
2877+
"2. Easier integration for developers unfamiliar with ML or Python\n",
2878+
"3. Versatile use across different applications (web, mobile, etc.)\n",
2879+
"\n",
2880+
"This approach simplifies model sharing and usage, making it more accessible for diverse development needs.\n",
2881+
"\n",
2882+
"Let's learn how to create an ML API with FastAPI, a modern and fast web framework for building APIs with Python. \n",
2883+
"\n",
2884+
"Before we begin constructing an API for a machine learning model, let's first develop a basic model that our API will use. In this example, we'll create a model that predicts the median house price in California."
2885+
]
2886+
},
2887+
{
2888+
"cell_type": "code",
2889+
"execution_count": 12,
2890+
"id": "d7ea435d",
2891+
"metadata": {},
2892+
"outputs": [
2893+
{
2894+
"name": "stdout",
2895+
"output_type": "stream",
2896+
"text": [
2897+
"Mean squared error: 0.56\n"
2898+
]
2899+
},
2900+
{
2901+
"data": {
2902+
"text/plain": [
2903+
"['lr.joblib']"
2904+
]
2905+
},
2906+
"execution_count": 12,
2907+
"metadata": {},
2908+
"output_type": "execute_result"
2909+
}
2910+
],
2911+
"source": [
2912+
"from sklearn.datasets import fetch_california_housing\n",
2913+
"from sklearn.model_selection import train_test_split\n",
2914+
"from sklearn.linear_model import LinearRegression\n",
2915+
"from sklearn.metrics import mean_squared_error\n",
2916+
"import joblib\n",
2917+
"\n",
2918+
"# Load dataset\n",
2919+
"X, y = fetch_california_housing(as_frame=True, return_X_y=True)\n",
2920+
"\n",
2921+
"# Split dataset into training and test sets\n",
2922+
"X_train, X_test, y_train, y_test = train_test_split(\n",
2923+
"    X, y, test_size=0.2, random_state=42\n",
2924+
")\n",
2925+
"\n",
2926+
"# Initialize and train the linear regression model\n",
2927+
"model = LinearRegression()\n",
2928+
"model.fit(X_train, y_train)\n",
2929+
"\n",
2930+
"# Predict and evaluate the model\n",
2931+
"y_pred = model.predict(X_test)\n",
2932+
"mse = mean_squared_error(y_test, y_pred)\n",
2933+
"print(f\"Mean squared error: {mse:.2f}\")\n",
2934+
"\n",
2935+
"# Save model\n",
2936+
"joblib.dump(model, \"lr.joblib\")"
2937+
]
2938+
},
2939+
{
2940+
"cell_type": "markdown",
2941+
"id": "a5aaad8e",
2942+
"metadata": {},
2943+
"source": [
2944+
"Once we have our model, we can create an API for it using FastAPI. We'll define a POST endpoint that accepts the input features and returns the model's prediction.\n",
2945+
"\n",
2946+
"Here's an example of how to create an API for a machine learning model using FastAPI:"
2947+
]
2948+
},
2949+
{
2950+
"cell_type": "code",
2951+
"execution_count": 8,
2952+
"id": "581f789f",
2953+
"metadata": {},
2954+
"outputs": [
2955+
{
2956+
"name": "stdout",
2957+
"output_type": "stream",
2958+
"text": [
2959+
"Overwriting ml_app.py\n"
2960+
]
2961+
}
2962+
],
2963+
"source": [
2964+
"%%writefile ml_app.py\n",
2965+
"from fastapi import FastAPI\n",
2966+
"import joblib\n",
2967+
"import pandas as pd \n",
2968+
"\n",
2969+
"# Create a FastAPI application instance\n",
2970+
"app = FastAPI()\n",
2971+
"\n",
2972+
"# Load the pre-trained machine learning model\n",
2973+
"model = joblib.load(\"lr.joblib\")\n",
2974+
"\n",
2975+
"# Define a POST endpoint for making predictions\n",
2976+
"@app.post(\"/predict/\")\n",
2977+
"def predict(data: list[float]):\n",
2978+
"    # Define the column names for the input features\n",
2979+
"    columns = [\n",
2980+
"        \"MedInc\",\n",
2981+
"        \"HouseAge\",\n",
2982+
"        \"AveRooms\",\n",
2983+
"        \"AveBedrms\",\n",
2984+
"        \"Population\",\n",
2985+
"        \"AveOccup\",\n",
2986+
"        \"Latitude\",\n",
2987+
"        \"Longitude\",\n",
2988+
"    ]\n",
2989+
"    \n",
2990+
"    # Create a pandas DataFrame from the input data\n",
2991+
"    features = pd.DataFrame([data], columns=columns)\n",
2992+
"    \n",
2993+
"    # Use the model to make a prediction\n",
2994+
"    prediction = model.predict(features)[0]\n",
2995+
"    \n",
2996+
"    # Return the prediction as a JSON object, rounding to 2 decimal places\n",
2997+
"    return {\"price\": round(prediction, 2)}"
2998+
]
2999+
},
3000+
{
3001+
"cell_type": "markdown",
3002+
"id": "34aba2f2",
3003+
"metadata": {},
3004+
"source": [
3005+
"To run your FastAPI app for development, use the `fastapi dev` command:\n",
3006+
"```bash\n",
3007+
"$ fastapi dev ml_app.py\n",
3008+
"``` "
3009+
]
3010+
},
3011+
{
3012+
"cell_type": "code",
3013+
"execution_count": null,
3014+
"id": "375b4bce",
3015+
"metadata": {
3016+
"tags": [
3017+
"remove-cell"
3018+
]
3019+
},
3020+
"outputs": [],
3021+
"source": [
3022+
"!fastapi dev ml_app.py"
3023+
]
3024+
},
3025+
{
3026+
"cell_type": "markdown",
3027+
"id": "3fff7352",
3028+
"metadata": {},
3029+
"source": [
3030+
"This starts the development server at `http://127.0.0.1:8000` and serves interactive API documentation at `http://127.0.0.1:8000/docs`.\n",
3031+
"\n",
3032+
"You can now use the API to make predictions by sending a POST request to the `/predict/` endpoint with the input data. For example:"
3033+
]
3034+
},
3035+
{
3036+
"cell_type": "markdown",
3037+
"id": "8fb49c7c",
3038+
"metadata": {},
3039+
"source": [
3040+
"Run this cURL command in your terminal:\n",
3041+
"```bash\n",
3042+
"curl -X 'POST' \\\n",
3043+
"  'http://127.0.0.1:8000/predict/' \\\n",
3044+
"  -H 'accept: application/json' \\\n",
3045+
"  -H 'Content-Type: application/json' \\\n",
3046+
"  -d '[\n",
3047+
"  1.68, 25, 4, 2, 1400, 3, 36.06, -119.01\n",
3048+
"]'\n",
3049+
"```\n",
3050+
"This returns the predicted median house value (in units of $100,000) as a JSON object, rounded to 2 decimal places:\n",
3051+
"```json\n",
3052+
"{\"price\":1.51}\n",
3053+
"```"
3053+
"```"
3054+
]
28293055
}
28303056
],
28313057
"metadata": {

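Note on the notebook's remark that sharing a serialized model file can pose security risks: joblib files are pickle-based, and unpickling can call arbitrary functions. The following is a minimal sketch using only the standard-library `pickle` module; the `Payload` class is hypothetical and exists only to illustrate the mechanism, with a harmless `print` standing in for whatever an attacker might run.

```python
import pickle


class Payload:
    """Hypothetical object whose pickle runs a callable when loaded."""

    def __reduce__(self):
        # pickle calls print(...) while loading. An attacker could put any
        # importable callable here instead of print.
        return (print, ("code ran during unpickling",))


# Serialize the object; the bytes only reference print and its argument.
blob = pickle.dumps(Payload())

# Loading the bytes calls print(...) and returns its result (None),
# not a Payload instance and not a model.
result = pickle.loads(blob)
```

This is why serving predictions over an API, as this commit does, avoids asking consumers to load a model file they did not create.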
Chapter5/ml_app.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
1+
from fastapi import FastAPI
2+
import joblib
3+
import pandas as pd
4+
5+
# Create a FastAPI application instance
6+
app = FastAPI()
7+
8+
# Load the pre-trained machine learning model
9+
model = joblib.load("lr.joblib")
10+
11+
# Define a POST endpoint for making predictions
12+
@app.post("/predict/")
13+
def predict(data: list[float]):
14+
    # Define the column names for the input features
15+
    columns = [
16+
        "MedInc",
17+
        "HouseAge",
18+
        "AveRooms",
19+
        "AveBedrms",
20+
        "Population",
21+
        "AveOccup",
22+
        "Latitude",
23+
        "Longitude",
24+
    ]
25+
26+
    # Create a pandas DataFrame from the input data
27+
    features = pd.DataFrame([data], columns=columns)
28+
29+
    # Use the model to make a prediction
30+
    prediction = model.predict(features)[0]
31+
32+
    # Return the prediction as a JSON object, rounding to 2 decimal places
33+
    return {"price": round(prediction, 2)}
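
Note on calling the `/predict/` endpoint from Python rather than cURL: the sketch below uses only the standard library. The URL assumes the local address used in the notebook's cURL example, and the helper names `build_request` and `fetch_prediction` are illustrative, not part of this commit. Building the request is separated from sending it, so the payload can be inspected without a running server.

```python
import json
import urllib.request

# Assumed values: the local development address from the cURL example and
# the eight features in the column order used by ml_app.py.
URL = "http://127.0.0.1:8000/predict/"
features = [1.68, 25, 4, 2, 1400, 3, 36.06, -119.01]


def build_request(url: str, data: list[float]) -> urllib.request.Request:
    """Build a JSON POST request whose body is the raw feature list."""
    body = json.dumps(data).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Content-Type": "application/json",
            "accept": "application/json",
        },
        method="POST",
    )


def fetch_prediction(url: str, data: list[float]) -> dict:
    """Send the request and decode the JSON reply (server must be running)."""
    with urllib.request.urlopen(build_request(url, data)) as response:
        return json.load(response)


# Build the request without sending it, so this runs without a server.
request = build_request(URL, features)
print(request.get_method(), request.full_url)

# With `fastapi dev ml_app.py` running, the endpoint could be queried with:
# print(fetch_prediction(URL, features))
```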
