
[RFC]: Support Prithvi geospatial model in serving mode #20234


Description

Motivation.

The landscape of foundation models is expanding to use cases where the input and output data are images rather than text. A notable example is Prithvi, a geospatial foundation model developed through a collaboration between IBM and NASA. Prithvi has recently been added to vLLM as a pooling model (#12830), but the current integration does not address the complexities associated with deploying the model in an online setting.

The goal of this RFC is to highlight the blockers in running Prithvi via the vLLM server and propose potential solutions.

Investigating the use of the vLLM Pooling API, I identified two main blockers:

  1. The vLLM server does not handle tensors as an input modality. This is the type of data provided and generated when using Prithvi. The model:
    • takes two tensors as input: one representing the geospatial image and one containing the location coordinates.
    • produces an image mask, which is also represented as a tensor (the shapes involved are sketched right after this list).
  2. The vLLM server assumes that the model uses a tokenizer. This assumption no longer holds for Prithvi.
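
For concreteness, and reusing the shapes from the example script further below, the tensors involved look roughly like this:

    import torch

    # Input 1: the geospatial image, e.g. 6 spectral bands over a 512x512 tile.
    pixel_values = torch.zeros((6, 512, 512), dtype=torch.float16)

    # Input 2: the location coordinates associated with the tile.
    location_coords = torch.zeros((1, 2), dtype=torch.float16)

    # Output: the predicted mask, also a tensor, e.g. two classes over the same grid.
    mask = torch.zeros((1, 2, 512, 512), dtype=torch.float16)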

In the current design, the vLLM server accepts multi-modal data such as images, videos, audio, or image embeddings, but the concept of input tensors is not entirely new. In the case of Image Embedding Inputs, users can specify multiple base64-encoded tensors as part of the image_embeds modality. Here is an example from the documentation where, to query Qwen/Qwen2-VL-2B-Instruct, the user must add image_grid_thw in the form of an encoded tensor (link). This capability would solve the first challenge but, in my opinion, image_embeds does not correctly represent the type of input data that we pass to Prithvi during inference.
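
For reference, an image-embeddings request of that kind looks roughly as follows (paraphrased from the documentation example; the two values are base64-encoded tensors, using the same placeholder convention as the prompt examples below):

    prompt = {
        "model": "Qwen/Qwen2-VL-2B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": {
                            # Both fields are base64-encoded tensors.
                            "image_embeds": <base64_image_embedding>,
                            "image_grid_thw": <base64_image_grid_thw>,
                        },
                    }
                ],
            }
        ],
    }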

Proposed Change.

This RFC aims at extending the current vLLM server to:

  • handle models that do not initialise a tokenizer.
  • support a new modality named tensors for the submission of generic input tensors.

A reference implementation of the proposed changes can be found in this branch: tensor_inputs. This is conditional on the acceptance of pull request #20072.

Add support for Tensor Inputs

We extend the chat utils used by the vLLM server to handle a new modality named tensors, which allows users to pass base64-encoded tensors as input.

Moreover, the processing of the encoded tensors is consolidated between the tensors and image_embeds modalities to avoid code duplication (ref).
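
As a rough sketch (the helper names below are illustrative, not the actual functions in the branch), the shared decoding path amounts to:

    import base64
    import io

    import torch

    def decode_base64_tensor(data: str) -> torch.Tensor:
        # Decode a base64 string back into the tensor that was serialised
        # with torch.save on the client side.
        return torch.load(io.BytesIO(base64.b64decode(data)))

    def parse_tensors_content(part: dict) -> dict[str, torch.Tensor]:
        # A "tensors" content part maps field names to encoded tensors,
        # e.g. {"pixel_values": "<b64>", "location_coords": "<b64>"}.
        return {name: decode_base64_tensor(enc) for name, enc in part["tensors"].items()}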

Additional features were explored to enable tensor inputs; see the reference branch for details.

Add support for models with uninitialised tokenizers

The request preprocessing pipeline must account for scenarios where the tokenizer is not initialised in order to support Prithvi in serving mode. This can easily be done by handling the case where skip-tokenizer-init is set to True (ref1, ref2).

The lack of a tokenizer means that we must provide or generate prompt_token_ids. To overcome this, we can require the user to specify additional_data["prompt_token_ids"]. Such an approach keeps the EmbeddingChatRequest unchanged. We handle this in the _preprocess_chat method by checking for prompt_token_ids when the tokenizer is None (ref).

Note: without a tokenizer, I had to use a string placeholder for the request_prompt (ref). There might be a better way to handle this.
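
Put together, the tokenizer-less path boils down to something like the following (illustrative sketch only; names and structure are simplified with respect to the reference branch):

    def resolve_prompt_token_ids(request, tokenizer, request_prompt):
        if tokenizer is not None:
            # Usual path: tokenize the rendered chat prompt.
            return tokenizer(request_prompt).input_ids

        # --skip-tokenizer-init path: the client must supply the ids explicitly.
        additional_data = getattr(request, "additional_data", None) or {}
        prompt_token_ids = additional_data.get("prompt_token_ids")
        if prompt_token_ids is None:
            raise ValueError(
                "prompt_token_ids must be provided via additional_data "
                "when the model is served with --skip-tokenizer-init")
        return prompt_token_ids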

In conclusion, with these proposed changes a user would be able to issue a request targeting the Prithvi model using the following prompt:

    prompt = {
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        "encoding_format": "tensor",
        "additional_data": {
            "prompt_token_ids": [1]
        },
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tensors",
                        "tensors": {
                            "pixel_values": <base64_tensor_embedding>,
                            "location_coords": <base64_coord_embedding>
                        },
                    }
                ],
            }
        ]
    }

Feedback Period.

2 weeks

CC List.

@DarkLight1337 @maxdebayser @njhill @christian-pinto

Any Other Things.

Try it out

From the branch tensor_inputs, start the server as follows:

python -m vllm.entrypoints.openai.api_server \
    --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' \
    --task embed --trust-remote-code --dtype float16 \
    --skip-tokenizer-init --enforce-eager  --disable-log-stats

The following is a simple script that requests an inference:

import base64
import io

import numpy as np
import requests
import torch

torch.set_default_dtype(torch.float16)


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client", "Content-Type": "application/json"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def decompress(output):
    # The pooling output is returned as a base64-encoded float16 buffer;
    # reshape it into the predicted mask (1, 2, 512, 512).
    np_result = np.frombuffer(
        base64.b64decode(output), dtype=np.float16)
    return np_result.reshape(1, 2, 512, 512)
    # Try the following return statement with encode_tensor instead of encode_base64:
    # return torch.load(io.BytesIO(base64.b64decode(output)))


def main():
    api_url = "http://localhost:8000/pooling"
    model_name = 'christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'

    # Dummy inputs: a 6-band 512x512 image and a single coordinate pair.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    # Serialise each tensor with torch.save and base64-encode the bytes.
    buffer_tiff = io.BytesIO()
    torch.save(pixel_values, buffer_tiff)
    buffer_tiff.seek(0)
    binary_data = buffer_tiff.read()
    base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8')

    buffer_coord = io.BytesIO()
    torch.save(location_coords, buffer_coord)
    buffer_coord.seek(0)
    binary_data = buffer_coord.read()
    base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8')

    prompt = {
        "model": model_name,
        "encoding_format": "tensor",
        "additional_data": {
            "prompt_token_ids": [1]
        },
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tensors",
                        "tensors": {
                            "pixel_values": base64_tensor_embedding,
                            "location_coords": base64_coord_embedding
                        },
                    }
                ],
            }
        ]
    }

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    numpy_data = decompress(pooling_response.json()["data"][0]["data"])
    print(f"Returned result: {numpy_data}")


if __name__ == "__main__":
    main()

