Skip to content

Commit 6638b73

Browse files
committed
fix(server): correctly retrieve image
1 parent 7509a1f commit 6638b73

File tree

4 files changed

+532
-469
lines changed

4 files changed

+532
-469
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ dev = [
8383
"typeguard>=4.4.1",
8484
"pyfakefs>=5.7.3",
8585
"mongomock>=4.3.0",
86+
"qwen-vl-utils>=0.0.10",
8687
]
8788
triton = [
8889
"triton>=3.1.0",

server/app.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def list_dictionaries():
9191
return client.list_saes(sae_series=sae_series, has_analyses=True)
9292

9393

94-
@app.get("/images/{dataset_name}/{context_idx}/{image_idx}")
95-
def get_image(dataset_name: str, context_idx: int, image_idx: int):
96-
dataset = get_dataset(dataset_name)
94+
@app.get("/images/{dataset_name}")
95+
def get_image(dataset_name: str, context_idx: int, image_idx: int, shard_idx: int = 0, n_shards: int = 1):
96+
dataset = get_dataset(dataset_name, shard_idx, n_shards)
9797
data = dataset[int(context_idx)]
9898

9999
image_key = "image" if "image" in data else "images" if "images" in data else None
@@ -144,17 +144,19 @@ def get_feature(name: str, feature_index: str | int):
144144
dataset_name = sampling.dataset_name[i]
145145
model_name = sampling.model_name[i]
146146
model = get_model(model_name)
147-
data = get_dataset(
148-
dataset_name,
149-
sampling.shard_idx[i] if sampling.shard_idx is not None else 0,
150-
sampling.n_shards[i] if sampling.n_shards is not None else 1,
151-
)[context_idx]
147+
shard_idx = sampling.shard_idx[i] if sampling.shard_idx is not None else 0
148+
n_shards = sampling.n_shards[i] if sampling.n_shards is not None else 1
149+
data = get_dataset(dataset_name, shard_idx, n_shards)[context_idx]
150+
152151
_, token_origins = model.to_tokens_with_origins(data)
153152

154153
# Replace image_key with image_url
155154
image_key = "image" if "image" in data else "images" if "images" in data else None
156155
if image_key is not None:
157-
image_urls = [f"/images/{dataset_name}/{context_idx}/{i}" for i in range(len(data[image_key]))]
156+
image_urls = [
157+
f"/images/{dataset_name}?context_idx={context_idx}&shard_idx={shard_idx}&n_shards={n_shards}&image_idx={image_idx}"
158+
for image_idx in range(len(data[image_key]))
159+
]
158160
del data[image_key]
159161
data["images"] = image_urls
160162

src/lm_saes/resource_loaders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def load_dataset_shard(
2626
dataset = datasets.load_from_disk(cfg.dataset_name_or_path)
2727
dataset = cast(datasets.Dataset, dataset)
2828
dataset = dataset.shard(num_shards=n_shards, index=shard_idx, contiguous=True)
29+
dataset = dataset.with_format("torch")
2930
return dataset
3031

3132

0 commit comments

Comments
 (0)