|
2826 | 2826 | "source": [
|
2827 | 2827 | "[Learn more about MLflow Models](https://bit.ly/46y6gpF)."
|
2828 | 2828 | ]
|
| 2829 | + }, |
| 2830 | + { |
| 2831 | + "cell_type": "code", |
| 2832 | + "execution_count": null, |
| 2833 | + "id": "8acadfca", |
| 2834 | + "metadata": { |
| 2835 | + "tags": [ |
| 2836 | + "hide-cell" |
| 2837 | + ] |
| 2838 | + }, |
| 2839 | + "outputs": [], |
| 2840 | + "source": [ |
| 2841 | + "!pip3 install joblib \"fastapi[standard]\"" |
| 2842 | + ] |
| 2843 | + }, |
| 2844 | + { |
| 2845 | + "cell_type": "markdown", |
| 2846 | + "id": "e815f01b", |
| 2847 | + "metadata": {}, |
| 2848 | + "source": [ |
| 2849 | + "Imagine this scenario: You have just built a machine learning (ML) model with great performance, and you want to share this model with your team members so that they can develop a web application on top of your model.\n", |
| 2850 | + "\n", |
| 2851 | + "One way to share the model with your team members is to save the model to a file (e.g., using pickle, joblib, or framework-specific methods) and share the file directly:\n", |
| 2852 | + "\n", |
| 2853 | + "\n", |
| 2854 | + "```python\n", |
| 2855 | + "import joblib\n", |
| 2856 | + "\n", |
| 2857 | + "model = ...\n", |
| 2858 | + "\n", |
| 2859 | + "# Save model\n", |
| 2860 | + "joblib.dump(model, \"model.joblib\")\n", |
| 2861 | + "\n", |
| 2862 | + "# Load model\n", |
| 2863 | + "model = joblib.load(\"model.joblib\")\n", |
| 2864 | + "```\n", |
| 2865 | + "\n", |
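| 2866 | + "However, this approach requires the recipient to have a compatible environment (the same Python and library versions used to save the model), and it poses a security risk: loading a pickle or joblib file can execute arbitrary code, so such files should only be loaded from trusted sources.\n" |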
| 2867 | + ] |
| 2868 | + }, |
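| | + { |
| | + "cell_type": "markdown", |
| | + "id": "pickle-risk-sketch", |
| | + "metadata": {}, |
| | + "source": [ |
| | + "To make the security risk concrete, here is a minimal sketch using only the standard library's `pickle` module (joblib builds on pickle). `Payload` is a hypothetical class, and the harmless `eval('6 * 7')` stands in for whatever code an attacker might embed:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import pickle\n", |
| | + "\n", |
| | + "\n", |
| | + "class Payload:\n", |
| | + "    # pickle calls __reduce__ to learn how to rebuild an object;\n", |
| | + "    # the callable it returns is invoked at load time.\n", |
| | + "    def __reduce__(self):\n", |
| | + "        # Harmless stand-in for attacker-controlled code\n", |
| | + "        return (eval, ('6 * 7',))\n", |
| | + "\n", |
| | + "\n", |
| | + "blob = pickle.dumps(Payload())  # bytes a malicious file could contain\n", |
| | + "result = pickle.loads(blob)  # loading runs eval('6 * 7')\n", |
| | + "print(result)\n", |
| | + "```\n", |
| | + "\n", |
| | + "Loading the bytes returns `42` instead of a `Payload` object, which shows that unpickling can run code chosen by whoever created the file." |
| | + ] |
| | + }, |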
| 2869 | + { |
| 2870 | + "cell_type": "markdown", |
| 2871 | + "id": "b364f9fc", |
| 2872 | + "metadata": {}, |
| 2873 | + "source": [ |
| 2874 | + "An alternative is creating an API for your ML model. APIs define how software components interact, allowing:\n", |
| 2875 | + "\n", |
| 2876 | + "1. Access from various programming languages and platforms\n", |
| 2877 | + "2. Easier integration for developers unfamiliar with ML or Python\n", |
| 2878 | + "3. Versatile use across different applications (web, mobile, etc.)\n", |
| 2879 | + "\n", |
| 2880 | + "This approach simplifies model sharing and usage, making it more accessible for diverse development needs.\n", |
| 2881 | + "\n", |
| 2882 | + "Let's learn how to create an ML API with FastAPI, a modern and fast web framework for building APIs with Python. \n", |
| 2883 | + "\n", |
| 2884 | + "Before we begin constructing an API for a machine learning model, let's first develop a basic model that our API will use. In this example, we'll create a model that predicts the median house price in California." |
| 2885 | + ] |
| 2886 | + }, |
| 2887 | + { |
| 2888 | + "cell_type": "code", |
| 2889 | + "execution_count": 12, |
| 2890 | + "id": "d7ea435d", |
| 2891 | + "metadata": {}, |
| 2892 | + "outputs": [ |
| 2893 | + { |
| 2894 | + "name": "stdout", |
| 2895 | + "output_type": "stream", |
| 2896 | + "text": [ |
| 2897 | + "Mean squared error: 0.56\n" |
| 2898 | + ] |
| 2899 | + }, |
| 2900 | + { |
| 2901 | + "data": { |
| 2902 | + "text/plain": [ |
| 2903 | + "['lr.joblib']" |
| 2904 | + ] |
| 2905 | + }, |
| 2906 | + "execution_count": 12, |
| 2907 | + "metadata": {}, |
| 2908 | + "output_type": "execute_result" |
| 2909 | + } |
| 2910 | + ], |
| 2911 | + "source": [ |
| 2912 | + "from sklearn.datasets import fetch_california_housing\n", |
| 2913 | + "from sklearn.model_selection import train_test_split\n", |
| 2914 | + "from sklearn.linear_model import LinearRegression\n", |
| 2915 | + "from sklearn.metrics import mean_squared_error\n", |
| 2916 | + "import joblib\n", |
| 2917 | + "\n", |
| 2918 | + "# Load dataset\n", |
| 2919 | + "X, y = fetch_california_housing(as_frame=True, return_X_y=True)\n", |
| 2920 | + "\n", |
| 2921 | + "# Split dataset into training and test sets\n", |
| 2922 | + "X_train, X_test, y_train, y_test = train_test_split(\n", |
| 2923 | + " X, y, test_size=0.2, random_state=42\n", |
| 2924 | + ")\n", |
| 2925 | + "\n", |
| 2926 | + "# Initialize and train the linear regression model\n", |
| 2927 | + "model = LinearRegression()\n", |
| 2928 | + "model.fit(X_train, y_train)\n", |
| 2929 | + "\n", |
| 2930 | + "# Predict and evaluate the model\n", |
| 2931 | + "y_pred = model.predict(X_test)\n", |
| 2932 | + "mse = mean_squared_error(y_test, y_pred)\n", |
| 2933 | + "print(f\"Mean squared error: {mse:.2f}\")\n", |
| 2934 | + "\n", |
| 2935 | + "# Save model\n", |
| 2936 | + "joblib.dump(model, \"lr.joblib\")" |
| 2937 | + ] |
| 2938 | + }, |
| 2939 | + { |
| 2940 | + "cell_type": "markdown", |
| 2941 | + "id": "a5aaad8e", |
| 2942 | + "metadata": {}, |
| 2943 | + "source": [ |
| 2944 | + "Once we have our model, we can create an API for it using FastAPI. We'll define a POST endpoint that receives the input features and returns the model's prediction.\n", |
| 2945 | + "\n", |
| 2946 | + "Here's an example of how to create an API for a machine learning model using FastAPI:" |
| 2947 | + ] |
| 2948 | + }, |
| 2949 | + { |
| 2950 | + "cell_type": "code", |
| 2951 | + "execution_count": 8, |
| 2952 | + "id": "581f789f", |
| 2953 | + "metadata": {}, |
| 2954 | + "outputs": [ |
| 2955 | + { |
| 2956 | + "name": "stdout", |
| 2957 | + "output_type": "stream", |
| 2958 | + "text": [ |
| 2959 | + "Overwriting ml_app.py\n" |
| 2960 | + ] |
| 2961 | + } |
| 2962 | + ], |
| 2963 | + "source": [ |
| 2964 | + "%%writefile ml_app.py\n", |
| 2965 | + "from fastapi import FastAPI\n", |
| 2966 | + "import joblib\n", |
| 2967 | + "import pandas as pd \n", |
| 2968 | + "\n", |
| 2969 | + "# Create a FastAPI application instance\n", |
| 2970 | + "app = FastAPI()\n", |
| 2971 | + "\n", |
| 2972 | + "# Load the pre-trained machine learning model\n", |
| 2973 | + "model = joblib.load(\"lr.joblib\")\n", |
| 2974 | + "\n", |
| 2975 | + "# Define a POST endpoint for making predictions\n", |
| 2976 | + "@app.post(\"/predict/\")\n", |
| 2977 | + "def predict(data: list[float]):\n", |
| 2978 | + " # Define the column names for the input features\n", |
| 2979 | + " columns = [\n", |
| 2980 | + " \"MedInc\",\n", |
| 2981 | + " \"HouseAge\",\n", |
| 2982 | + " \"AveRooms\",\n", |
| 2983 | + " \"AveBedrms\",\n", |
| 2984 | + " \"Population\",\n", |
| 2985 | + " \"AveOccup\",\n", |
| 2986 | + " \"Latitude\",\n", |
| 2987 | + " \"Longitude\",\n", |
| 2988 | + " ]\n", |
| 2989 | + " \n", |
| 2990 | + " # Create a pandas DataFrame from the input data\n", |
| 2991 | + " features = pd.DataFrame([data], columns=columns)\n", |
| 2992 | + " \n", |
| 2993 | + " # Use the model to make a prediction\n", |
| 2994 | + " prediction = model.predict(features)[0]\n", |
| 2995 | + " \n", |
| 2996 | + " # Return the prediction as a JSON object, rounding to 2 decimal places\n", |
| 2997 | + " return {\"price\": round(prediction, 2)}" |
| 2998 | + ] |
| 2999 | + }, |
| 3000 | + { |
| 3001 | + "cell_type": "markdown", |
| 3002 | + "id": "34aba2f2", |
| 3003 | + "metadata": {}, |
| 3004 | + "source": [ |
| 3005 | + "To run your FastAPI app for development, use the `fastapi dev` command:\n", |
| 3006 | + "```bash\n", |
| 3007 | + "$ fastapi dev ml_app.py\n", |
| 3008 | + "``` " |
| 3009 | + ] |
| 3010 | + }, |
| 3011 | + { |
| 3012 | + "cell_type": "code", |
| 3013 | + "execution_count": null, |
| 3014 | + "id": "375b4bce", |
| 3015 | + "metadata": { |
| 3016 | + "tags": [ |
| 3017 | + "remove-cell" |
| 3018 | + ] |
| 3019 | + }, |
| 3020 | + "outputs": [], |
| 3021 | + "source": [ |
| 3022 | + "!fastapi dev ml_app.py" |
| 3023 | + ] |
| 3024 | + }, |
| 3025 | + { |
| 3026 | + "cell_type": "markdown", |
| 3027 | + "id": "3fff7352", |
| 3028 | + "metadata": {}, |
| 3029 | + "source": [ |
| 3030 | + "This starts the development server at `http://127.0.0.1:8000` with auto-reload enabled. The interactive API documentation is available at `http://127.0.0.1:8000/docs`.\n", |
| 3031 | + "\n", |
| 3032 | + "You can now use the API to make predictions by sending a POST request to the `/predict/` endpoint with the input data. For example:" |
| 3033 | + ] |
| 3034 | + }, |
| 3035 | + { |
| 3036 | + "cell_type": "markdown", |
| 3037 | + "id": "8fb49c7c", |
| 3038 | + "metadata": {}, |
| 3039 | + "source": [ |
| 3040 | + "Run this cURL command in your terminal:\n", |
| 3041 | + "```bash\n", |
| 3042 | + "curl -X 'POST' \\\n", |
| 3043 | + " 'http://127.0.0.1:8000/predict/' \\\n", |
| 3044 | + " -H 'accept: application/json' \\\n", |
| 3045 | + " -H 'Content-Type: application/json' \\\n", |
| 3046 | + " -d '[\n", |
| 3047 | + " 1.68, 25, 4, 2, 1400, 3, 36.06, -119.01\n", |
| 3048 | + "]'\n", |
| 3049 | + "```\n", |
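| | + "The same request can also be sent from Python. Below is a minimal sketch using only the standard library, assuming the server started by `fastapi dev` is listening on `127.0.0.1:8000`; `API_URL`, `build_predict_request`, and `predict_remote` are hypothetical names chosen for illustration:\n", |
| | + "```python\n", |
| | + "import json\n", |
| | + "from urllib import request\n", |
| | + "\n", |
| | + "# Assumed address of the server started by `fastapi dev`\n", |
| | + "API_URL = 'http://127.0.0.1:8000/predict/'\n", |
| | + "\n", |
| | + "\n", |
| | + "def build_predict_request(features):\n", |
| | + "    # Encode the eight feature values as a JSON array, as the endpoint expects\n", |
| | + "    body = json.dumps(features).encode('utf-8')\n", |
| | + "    headers = {'Content-Type': 'application/json', 'accept': 'application/json'}\n", |
| | + "    return request.Request(API_URL, data=body, headers=headers, method='POST')\n", |
| | + "\n", |
| | + "\n", |
| | + "def predict_remote(features):\n", |
| | + "    # Send the request and parse the JSON response (the server must be running)\n", |
| | + "    with request.urlopen(build_predict_request(features)) as response:\n", |
| | + "        return json.loads(response.read())\n", |
| | + "\n", |
| | + "\n", |
| | + "# Build the request without sending it, to inspect what would go over the wire\n", |
| | + "req = build_predict_request([1.68, 25, 4, 2, 1400, 3, 36.06, -119.01])\n", |
| | + "print(req.get_method(), req.full_url)\n", |
| | + "```\n", |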
| 3050 | + "This will return the predicted price as a JSON object, rounded to 2 decimal places:\n", |
| 3051 | + "```json\n", |
| 3052 | + "{\"price\":1.51}\n", |
| 3053 | + "```" |
| 3054 | + ] |
2829 | 3055 | }
|
2830 | 3056 | ],
|
2831 | 3057 | "metadata": {
|
|