
Commit 5057eb0

Merge pull request #428 from smart-on-fhir/mikix/vllm
Convert our previous hugging face support from TGI to vLLM
2 parents 6ba1875 + 2cdbf2b commit 5057eb0

File tree

7 files changed: +179 −231 lines changed

compose.yaml

Lines changed: 13 additions & 11 deletions
```diff
@@ -40,7 +40,6 @@ services:
       - AZURE_OPENAI_API_KEY
       - AZURE_OPENAI_ENDPOINT
       # Internal environment variobles
-      - CUMULUS_HUGGING_FACE_URL=http://llama2:8086/
       - URL_CTAKES_REST=http://ctakes-covid:8080/ctakes-web-rest/service/analyze
       - URL_CNLP_NEGATION=http://cnlpt-negation:8000/negation/process
       - URL_CNLP_TERM_EXISTS=http://cnlpt-term-exists:8000/termexists/process
@@ -60,6 +59,7 @@ services:
   cumulus-etl-gpu:
     extends: cumulus-etl-base
     environment:
+      - CUMULUS_LLAMA2_URL=http://llama2:8086/v1
       - URL_CNLP_NEGATION=http://cnlpt-negation-gpu:8000/negation/process
       - URL_CNLP_TERM_EXISTS=http://cnlpt-term-exists-gpu:8000/termexists/process
     profiles:
@@ -140,21 +140,23 @@ services:
   # This is a WIP llama2 setup, currently suitable for running in a g5.xlarge AWS instance.
   llama2:
     extends: common-base
-    image: ghcr.io/huggingface/text-generation-inference:1.0.1
+    image: vllm/vllm-openai:v0.10.0
     environment:
+      - HF_TOKEN
+      - HUGGING_FACE_HUB_TOKEN
+    command:
       # If you update anything here that could affect NLP results, consider updating the
       # task_version of any tasks that use this docker.
-      - HUGGING_FACE_HUB_TOKEN
-      - MODEL_ID=meta-llama/Llama-2-13b-chat-hf
-      - QUANTIZE=bitsandbytes-nf4  # 4bit
-      - PORT=8086
-      - REVISION=0ba94ac9b9e1d5a0037780667e8b219adde1908c
+      - --download-dir=/data
+      - --model=meta-llama/Llama-2-13b-chat-hf
+      - --port=8086
+      - --quantization=bitsandbytes  # 4bit
+      - --revision=a2cb7a712bb6e5e736ca7f8cd98167f81a0b5bd8
     healthcheck:
-      # There's no curl or wget inside this container, but there is python3!
-      test: ["CMD", "python3", "-c", "import socket; socket.create_connection(('localhost', 8086))"]
+      test: ["CMD", "wget", "localhost:8086/health", "--output-document=/dev/null"]
       start_period: 20m  # give plenty of time for startup, since we may be downloading a model
     volumes:
-      - hf-data:/data
+      - vllm-data:/data
     profiles:
       - hf-test
     networks:
@@ -257,4 +259,4 @@ networks:
 
 volumes:
   ctakes-overrides:
-  hf-data:
+  vllm-data:
```
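With this change, the `llama2` service exposes vLLM's standard OpenAI-compatible HTTP API instead of TGI's custom one, so any OpenAI-style client can talk to it. A minimal sketch, assuming the container is up (`docker compose up llama2 --wait`) and port 8086 is reachable from the host; the clinical note text here is invented:

```python
# Sketch: query the vLLM container through its OpenAI-compatible API.
# An empty api_key is fine for this local server.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncClient(base_url="http://localhost:8086/v1", api_key="")
    response = await client.chat.completions.create(
        model="meta-llama/Llama-2-13b-chat-hf",
        messages=[
            {"role": "system", "content": "Reply with a short summary of the note."},
            {"role": "user", "content": "55yo M with stable angina, started on aspirin."},
        ],
        temperature=0,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```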

cumulus_etl/etl/studies/hftest/hf_tasks.py

Lines changed: 19 additions & 35 deletions
```diff
@@ -1,11 +1,9 @@
 """Define tasks for the hftest study"""
 
-import cumulus_fhir_support as cfs
-import httpx
 import pyarrow
 import rich.progress
 
-from cumulus_etl import common, errors, nlp
+from cumulus_etl import common, nlp
 from cumulus_etl.etl import tasks
 
 
@@ -17,48 +15,35 @@ class HuggingFaceTestTask(tasks.BaseNlpTask):
     task_version = 0
     # Task Version History:
     # ** 0 **
-    # This is fluid until we actually promote this to a real task - feel free to update without bumping the version.
-    # container: ghcr.io/huggingface/text-generation-inference
-    # container reversion: 09eca6422788b1710c54ee0d05dd6746f16bb681
+    # This is fluid until we actually promote this to a real task - feel free to update without
+    # bumping the version.
+    # container: vllm/vllm-openai
+    # container revision: v0.10.0
     # container properties:
-    #   QUANTIZE=bitsandbytes-nf4
+    #   QUANTIZE=bitsandbytes
     # model: meta-llama/Llama-2-13b-chat-hf
-    # model revision: 0ba94ac9b9e1d5a0037780667e8b219adde1908c
+    # model revision: a2cb7a712bb6e5e736ca7f8cd98167f81a0b5bd8
     # system prompt:
-    #   "You will be given a clinical note, and you should reply with a short summary of that note."
+    #   "You will be given a clinical note, and you should reply with a short summary of that
+    #   note."
     # user prompt: a clinical note
 
     @classmethod
     async def init_check(cls) -> None:
-        try:
-            raw_info = await nlp.hf_info()
-        except cfs.NetworkError:
-            errors.fatal(
-                "Llama2 NLP server is unreachable.\n Try running 'docker compose up llama2 --wait'.",
-                errors.SERVICE_MISSING,
-            )
-
-        # Sanity check a few of the properties, to make sure we don't accidentally get pointed at an unexpected model.
-        expected_info_present = (
-            raw_info.get("model_id") == "meta-llama/Llama-2-13b-chat-hf"
-            and raw_info.get("model_sha") == "0ba94ac9b9e1d5a0037780667e8b219adde1908c"
-            and raw_info.get("sha") == "09eca6422788b1710c54ee0d05dd6746f16bb681"
-        )
-        if not expected_info_present:
-            errors.fatal(
-                "LLama2 NLP server is using an unexpected model setup.",
-                errors.SERVICE_MISSING,
-            )
+        await nlp.Llama2Model().check()
 
     async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
         """Passes clinical notes through HF and returns any symptoms found"""
-        http_client = httpx.AsyncClient(timeout=300)
+        client = nlp.Llama2Model()
 
         async for _, docref, clinical_note in self.read_notes(progress=progress):
             timestamp = common.datetime_now().isoformat()
 
             # If you change this prompt, consider updating task_version.
-            system_prompt = "You will be given a clinical note, and you should reply with a short summary of that note."
+            system_prompt = (
+                "You will be given a clinical note, "
+                "and you should reply with a short summary of that note."
+            )
             user_prompt = clinical_note
 
             summary = await nlp.cache_wrapper(
@@ -67,18 +52,17 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
                 clinical_note,
                 lambda x: x,  # from file: just store the string
                 lambda x: x,  # to file: just read it back
-                nlp.llama2_prompt,
+                client.prompt,
                 system_prompt,
                 user_prompt,
-                client=http_client,
             )
 
             # Debugging
-            # logging.warning("\n\n\n\n" "**********************************************************")
+            # logging.warning("\n\n\n\n" "********************************************************")
             # logging.warning(user_prompt)
-            # logging.warning("==========================================================")
+            # logging.warning("========================================================")
             # logging.warning(summary)
-            # logging.warning("**********************************************************")
+            # logging.warning("********************************************************")
 
             yield {
                 "id": docref["id"],  # just copy the docref
```

cumulus_etl/nlp/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 """Support code for NLP servers"""
 
 from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity
-from .huggingface import hf_info, hf_prompt, llama2_prompt
+from .openai import Llama2Model
 from .utils import cache_wrapper, get_docref_info, is_docref_valid
 from .watcher import (
     check_ctakes,
```

cumulus_etl/nlp/huggingface.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

cumulus_etl/nlp/openai.py

Lines changed: 69 additions & 0 deletions
```diff
@@ -0,0 +1,69 @@
+"""Abstraction layer for Hugging Face's inference API"""
+
+import abc
+import os
+
+import openai
+
+from cumulus_etl import errors
+
+
+class OpenAIModel(abc.ABC):
+    COMPOSE_ID = None
+    MODEL_NAME = None
+
+    @property
+    @abc.abstractmethod
+    def url(self) -> str:
+        """The OpenAI compatible URL to talk to (where's the server?)"""
+
+    @property
+    @abc.abstractmethod
+    def api_key(self) -> str:
+        """The API key to use (empty string for local servers)"""
+
+    def __init__(self):
+        self.client = openai.AsyncClient(base_url=self.url, api_key=self.api_key)
+
+    async def check(self) -> None:
+        try:
+            models = self.client.models.list()
+            names = {model.id async for model in models}
+        except openai.APIError:
+            errors.fatal(
+                f"NLP server '{self.COMPOSE_ID}' is unreachable.\n"
+                f"Try running 'docker compose up {self.COMPOSE_ID} --wait'.",
+                errors.SERVICE_MISSING,
+            )
+
+        if self.MODEL_NAME not in names:
+            errors.fatal(
+                f"NLP server '{self.COMPOSE_ID}' is using an unexpected model setup.",
+                errors.SERVICE_MISSING,
+            )
+
+    async def prompt(self, system: str, user: str) -> str:
+        response = await self.client.responses.create(
+            model=self.MODEL_NAME,
+            instructions=system,
+            input=user,
+            temperature=0,
+        )
+        return response.output_text.strip()
+
+
+class LocalModel(OpenAIModel, abc.ABC):
+    @property
+    def api_key(self) -> str:
+        return ""
+
+
+class Llama2Model(LocalModel):
+    COMPOSE_ID = "llama2"
+    MODEL_NAME = "meta-llama/Llama-2-13b-chat-hf"
+
+    @property
+    def url(self) -> str:
+        # 8000 and 8080 are both used as defaults in ctakesclient (cnlp & ctakes respectively).
+        # 8086 is used as a joking reference to Hugging Face (HF = 86).
+        return os.environ.get("CUMULUS_LLAMA2_URL") or "http://localhost:8086/v1"
```
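The `LocalModel` split makes adding another self-hosted model mostly declarative. A sketch of a hypothetical subclass; the model choice, compose service name, port, and environment variable below are invented and not part of this commit:

```python
# Hypothetical example of plugging a second local model into this hierarchy.
import os

from cumulus_etl.nlp.openai import LocalModel


class MistralModel(LocalModel):
    COMPOSE_ID = "mistral"  # invented compose service name
    MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # invented model choice

    @property
    def url(self) -> str:
        # Invented env var and port, following the Llama2Model pattern above.
        return os.environ.get("CUMULUS_MISTRAL_URL") or "http://localhost:8087/v1"
```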

docs/chart-review.md

Lines changed: 0 additions & 1 deletion
```diff
@@ -64,7 +64,6 @@ docker compose run --rm \
 
 The above command will take all the DiagnosticReports and DocumentReferences
 in Group `67890` from the EHR,
-mark the notes with the default NLP dictionary,
 anonymize the notes with `philter`,
 and then push the results to your Label Studio project number `3`.
 
```
