|
| 1 | +#!/usr/bin/python3 |
| 2 | +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# |
| 4 | +# Redistribution and use in source and binary forms, with or without |
| 5 | +# modification, are permitted provided that the following conditions |
| 6 | +# are met: |
| 7 | +# * Redistributions of source code must retain the above copyright |
| 8 | +# notice, this list of conditions and the following disclaimer. |
| 9 | +# * Redistributions in binary form must reproduce the above copyright |
| 10 | +# notice, this list of conditions and the following disclaimer in the |
| 11 | +# documentation and/or other materials provided with the distribution. |
| 12 | +# * Neither the name of NVIDIA CORPORATION nor the names of its |
| 13 | +# contributors may be used to endorse or promote products derived |
| 14 | +# from this software without specific prior written permission. |
| 15 | +# |
| 16 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY |
| 17 | +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 18 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR |
| 19 | +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR |
| 20 | +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, |
| 21 | +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, |
| 22 | +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR |
| 23 | +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY |
| 24 | +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 25 | +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 26 | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 27 | + |
| 28 | +import sys |
| 29 | + |
| 30 | +sys.path.append("../common") |
| 31 | + |
| 32 | +import argparse |
| 33 | +import json |
| 34 | + |
| 35 | +import requests |
| 36 | + |
| 37 | + |
| 38 | +# To run the test, have tritonserver running and run this script with the endpoint as a flag. |
| 39 | +# |
| 40 | +# Example: |
| 41 | +# ``` |
| 42 | +# python3 orca_header_test.py http://localhost:8000/v2/models/ensemble/generate |
| 43 | +# ``` |
| 44 | +def get_endpoint_header(url, data, request_header=None): |
| 45 | + """ |
| 46 | + Sends a POST request to the given URL with the provided data and returns the value of the "endpoint-load-metrics" header, |
| 47 | + or None if the request fails. |
| 48 | + """ |
| 49 | + HEADER_KEY = "endpoint-load-metrics" |
| 50 | + try: |
| 51 | + response = None |
| 52 | + if request_header: |
| 53 | + response = requests.post(url, json=data, headers=request_header) |
| 54 | + else: |
| 55 | + response = requests.post(url, json=data) |
| 56 | + response.raise_for_status() |
| 57 | + return response.headers.get(HEADER_KEY, "") |
| 58 | + except requests.exceptions.RequestException as e: |
| 59 | + print(f"Error making request: {e}") |
| 60 | + return None |
| 61 | + |
| 62 | + |
| 63 | +def parse_header_data(header, orca_format): |
| 64 | + """ |
| 65 | + Parses the header data into a dictionary based on the given format. |
| 66 | + """ |
| 67 | + METRIC_KEY = "named_metrics" |
| 68 | + try: |
| 69 | + if orca_format == "json": |
| 70 | + # Parse the header in JSON format |
| 71 | + data = json.loads(header.replace("JSON ", "")) |
| 72 | + if METRIC_KEY in data: |
| 73 | + return data[METRIC_KEY] |
| 74 | + else: |
| 75 | + print(f"No key '{METRIC_KEY}' in header data: {data}") |
| 76 | + return None |
| 77 | + elif orca_format == "text": |
| 78 | + # Parse the header in TEXT format |
| 79 | + data = {} |
| 80 | + for key_value_pair in header.replace("TEXT ", "").split(", "): |
| 81 | + key, value = key_value_pair.split("=") |
| 82 | + if "." in key: |
| 83 | + prefix, nested_key = key.split(".", 1) |
| 84 | + if prefix == METRIC_KEY: |
| 85 | + data[nested_key] = float(value) |
| 86 | + if not data: |
| 87 | + print(f"Could not parse any keys from header: {header}") |
| 88 | + return None |
| 89 | + return data |
| 90 | + else: |
| 91 | + print(f"Invalid ORCA format: {orca_format}") |
| 92 | + return None |
| 93 | + except (json.JSONDecodeError, ValueError, KeyError): |
| 94 | + print("Error: Invalid data in the header.") |
| 95 | + return None |
| 96 | + |
| 97 | + |
| 98 | +def check_for_keys(data, desired_keys, orca_format): |
| 99 | + """ |
| 100 | + Checks if all desired keys are present in the given data dictionary. |
| 101 | + """ |
| 102 | + if all(key in data for key in desired_keys): |
| 103 | + print( |
| 104 | + "ORCA header present in ", |
| 105 | + orca_format, |
| 106 | + "format with" "kv_cache_utilization:", |
| 107 | + [k + ": " + str(data[k]) for k in desired_keys], |
| 108 | + ) |
| 109 | + return True |
| 110 | + else: |
| 111 | + print(f"Missing keys in header: {', '.join(set(desired_keys) - set(data))}") |
| 112 | + return False |
| 113 | + |
| 114 | + |
| 115 | +def request_header(orca_format): |
| 116 | + return {"endpoint-load-metrics-format": orca_format} if orca_format else None |
| 117 | + |
| 118 | + |
| 119 | +def test_header_type(url, data, orca_format): |
| 120 | + req_header = request_header(orca_format) |
| 121 | + response_header = get_endpoint_header(args.url, TEST_DATA, req_header) |
| 122 | + |
| 123 | + desired_keys = { |
| 124 | + "kv_cache_utilization", |
| 125 | + "max_token_capacity", |
| 126 | + } # Just the keys, no need to initialize with None |
| 127 | + |
| 128 | + if response_header is None: |
| 129 | + print(f"Request to endpoint: '{args.url}' failed.") |
| 130 | + return False |
| 131 | + elif response_header == "": |
| 132 | + if orca_format: |
| 133 | + print( |
| 134 | + f"response header empty, endpoint-load-metrics-format={orca_format} is not a valid ORCA metric format" |
| 135 | + ) |
| 136 | + return False |
| 137 | + else: |
| 138 | + # No request header set <=> no response header. Intended behavior. |
| 139 | + print(f"response header empty, endpoint-load-metrics-format is not set") |
| 140 | + return True |
| 141 | + |
| 142 | + data = parse_header_data(response_header, orca_format) |
| 143 | + if data: |
| 144 | + return check_for_keys(data, desired_keys, orca_format) |
| 145 | + else: |
| 146 | + print(f"Unexpected response header value: {response_header}") |
| 147 | + return False |
| 148 | + |
| 149 | + |
| 150 | +if __name__ == "__main__": |
| 151 | + parser = argparse.ArgumentParser( |
| 152 | + description="Make a POST request to generate endpoint to test the ORCA metrics header." |
| 153 | + ) |
| 154 | + parser.add_argument("url", help="The model URL to send the request to.") |
| 155 | + args = parser.parse_args() |
| 156 | + TEST_DATA = json.loads( |
| 157 | + '{"text_input": "hello world", "max_tokens": 20, "bad_words": "", "stop_words": ""}' |
| 158 | + ) |
| 159 | + passed = True |
| 160 | + |
| 161 | + for format in ["json", "text", None]: |
| 162 | + print("Checking response header for ORCA format:", format) |
| 163 | + if not test_header_type(args.url, TEST_DATA, format): |
| 164 | + print("FAIL on format:", format) |
| 165 | + passed = False |
| 166 | + |
| 167 | + sys.exit(0 if passed else 1) |
0 commit comments