Commit b2e42ff

SAM2 AMG cli and other QoL improvements (#1336)

1 parent 51c87b6 commit b2e42ff

8 files changed (+180, -69 lines)

examples/sam2_amg_server/README.md

Lines changed: 5 additions & 5 deletions

@@ -8,7 +8,7 @@ curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output
 Start the server
 
 ```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
 ```
 
 Collect the rles
@@ -58,7 +58,7 @@ Make sure you've installed https://github.com/facebookresearch/sam2
 
 Start server
 ```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --baseline
 ```
 
 Generate and save rles (one line per json via `-w "\n"`)
@@ -73,7 +73,7 @@ sys 0m4.137s
 ### 3. Start server with torchao variant of SAM2
 Start server
 ```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname>
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname>
 ```
 
 Generate and save rles (one line per json via `-w "\n"`)
@@ -88,7 +88,7 @@ sys 0m4.350s
 ### 4. Start server with torchao variant of SAM2 and `--fast` optimizations
 Start server
 ```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
 ```
 
 Generate and save rles (one line per json via `-w "\n"`)
@@ -103,7 +103,7 @@ sys 0m4.138s
 ### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations
 Start server
 ```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast --furious
 ```
 
 Generate and save rles (one line per json via `-w "\n"`)
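Every invocation above changes the same way: the server now takes a positional model type between the checkpoint folder and the flags. server.py (below) resolves it to a checkpoint file and Hydra config, downloading the weights on first use. Abridged from server.py, the accepted values map to checkpoints as follows:

```python
# Abridged from server.py below: accepted model types -> checkpoint files.
MODEL_TYPES_TO_MODEL = {
    "tiny": "sam2.1_hiera_tiny.pt",
    "small": "sam2.1_hiera_small.pt",
    "plus": "sam2.1_hiera_base_plus.pt",
    "large": "sam2.1_hiera_large.pt",
}
```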
examples/sam2_amg_server/cli.py

Lines changed: 48 additions & 0 deletions (new file)

@@ -0,0 +1,48 @@
+import fire
+import logging
+import matplotlib.pyplot as plt
+from server import file_bytes_to_image_tensor
+from server import show_anns
+from server import model_type_to_paths
+from server import MODEL_TYPES_TO_MODEL
+from torchao._models.sam2.build_sam import build_sam2
+from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from torchao._models.sam2.utils.amg import rle_to_mask
+from io import BytesIO
+
+def main_docstring():
+    return f"""
+    Args:
+        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
+        model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())}
+        input_path (str): Path to input image
+        output_path (str): Path to output image
+    """
+
+def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
+    device = "cuda"
+    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
+    if verbose:
+        print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
+    sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
+    mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
+    image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
+    if verbose:
+        print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
+    masks = mask_generator.generate(image_tensor)
+
+    # Save an example
+    plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
+    plt.imshow(image_tensor)
+    show_anns(masks, rle_to_mask)
+    plt.axis('off')
+    plt.tight_layout()
+    buf = BytesIO()
+    plt.savefig(buf, format=output_format)
+    buf.seek(0)
+    with open(output_path, "wb") as file:
+        file.write(buf.getvalue())
+
+main.__doc__ = main_docstring()
+if __name__ == "__main__":
+    fire.Fire(main)
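The CLI is driven by python-fire, so command-line arguments map one-to-one onto `main`'s signature. A minimal sketch of an invocation, with the checkpoint folder and image paths as hypothetical placeholders:

```python
# Hypothetical paths; equivalent to the shell invocation
#   python cli.py /data/checkpoints/sam2 large dog.jpg dog_masks.png --verbose
from cli import main

main(
    checkpoint_path="/data/checkpoints/sam2",  # folder holding sam2.1_hiera_*.pt
    model_type="large",                        # tiny | small | plus | large
    input_path="dog.jpg",
    output_path="dog_masks.png",
    points_per_batch=1024,
    verbose=True,
)
```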

examples/sam2_amg_server/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -7,3 +7,4 @@ hydra-core
 tqdm
 iopath
 python-multipart
+requests

examples/sam2_amg_server/server.py

Lines changed: 64 additions & 3 deletions

@@ -1,4 +1,5 @@
 import itertools
+import requests
 import uvicorn
 import fire
 import tempfile
@@ -37,6 +38,23 @@
 # torch._dynamo.config.capture_dynamic_output_shape_ops = True
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
+def download_file(url, download_dir):
+    # Create the directory if it doesn't exist
+    download_dir = Path(download_dir)
+    download_dir.mkdir(parents=True, exist_ok=True)
+    # Extract the file name from the URL
+    file_name = url.split('/')[-1]
+    # Define the full path for the downloaded file
+    file_path = download_dir / file_name
+    # Download the file
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Raise an error for bad responses
+    # Write the file to the specified directory
+    print(f"Downloading '{file_name}' to '{download_dir}'")
+    with open(file_path, 'wb') as file:
+        for chunk in response.iter_content(chunk_size=8192):
+            file.write(chunk)
+    print(f"Downloaded '{file_name}' to '{download_dir}'")
 
 def example_shapes():
     return [(848, 480, 3),
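A quick sketch of the helper in isolation, fetching the smallest checkpoint (the URL comes from MODEL_TYPES_TO_URL below; the target directory is hypothetical):

```python
# Downloads sam2.1_hiera_tiny.pt into checkpoints/sam2/, creating the
# directory if needed; streaming in 8 KiB chunks keeps large files
# from having to fit in memory.
download_file(
    "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
    "checkpoints/sam2",
)
```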
@@ -272,7 +290,51 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
     print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}")
 
 
+MODEL_TYPES_TO_CONFIG = {
+    "tiny": "sam2.1_hiera_t.yaml",
+    "small": "sam2.1_hiera_s.yaml",
+    "plus": "sam2.1_hiera_b+.yaml",
+    "large": "sam2.1_hiera_l.yaml",
+}
+
+MODEL_TYPES_TO_MODEL = {
+    "tiny": "sam2.1_hiera_tiny.pt",
+    "small": "sam2.1_hiera_small.pt",
+    "plus": "sam2.1_hiera_base_plus.pt",
+    "large": "sam2.1_hiera_large.pt",
+}
+
+
+MODEL_TYPES_TO_URL = {
+    "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
+    "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
+    "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
+    "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
+}
+
+
+def main_docstring():
+    return f"""
+    Args:
+        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
+        model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())}
+    """
+
+
+def model_type_to_paths(checkpoint_path, model_type):
+    if model_type not in MODEL_TYPES_TO_CONFIG.keys():
+        raise ValueError(f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}")
+    sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type])
+    if not sam2_checkpoint.exists():
+        print(f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading.")
+        download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path)
+    assert sam2_checkpoint.exists(), "Can't find downloaded file. Please open an issue."
+    model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
+    return sam2_checkpoint, model_cfg
+
+
 def main(checkpoint_path,
+         model_type,
          baseline=False,
          fast=False,
          furious=False,
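Given the three tables above, model_type_to_paths is a pure lookup plus an on-demand download. A sketch of what it returns, assuming checkpoints are stored under a hypothetical checkpoints/sam2 folder:

```python
# Resolves "large" to its checkpoint and Hydra config path; triggers
# download_file only when the .pt file is missing from the folder.
sam2_checkpoint, model_cfg = model_type_to_paths("checkpoints/sam2", "large")
# sam2_checkpoint -> Path("checkpoints/sam2/sam2.1_hiera_large.pt")
# model_cfg       -> "configs/sam2.1/sam2.1_hiera_l.yaml"
```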
@@ -306,9 +368,7 @@ def main(checkpoint_path,
     from torchao._models.sam2.utils.amg import rle_to_mask
 
     device = "cuda"
-    from pathlib import Path
-    sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt")
-    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
 
     logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
@@ -450,5 +510,6 @@ async def upload_image(image: UploadFile = File(...)):
     # uvicorn.run(app, host=host, port=port, log_level="info")
     uvicorn.run(app, host=host, port=port)
 
+main.__doc__ = main_docstring()
 if __name__ == "__main__":
     fire.Fire(main)
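The `main.__doc__ = main_docstring()` assignment (mirrored in cli.py) exists because python-fire renders `--help` from the function docstring, and the docstring here must be built at runtime to interpolate the supported model types. A self-contained sketch of the pattern, with hypothetical names:

```python
import fire

SIZES = {"tiny": "t", "large": "l"}  # stand-in for MODEL_TYPES_TO_MODEL

def run(model_type):
    return SIZES[model_type]

# f-strings can't appear in a literal docstring, so attach it afterwards;
# `python this_script.py --help` then lists the valid choices.
run.__doc__ = f"Args:\n    model_type (str): one of {', '.join(SIZES)}"

if __name__ == "__main__":
    fire.Fire(run)
```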

torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml

Lines changed: 14 additions & 14 deletions

@@ -2,18 +2,18 @@
 
 # Model
 model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
+  _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
   image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
     scalp: 1
     trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
+      _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
       embed_dim: 112
       num_heads: 2
     neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
       position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
         num_pos_feats: 256
         normalize: true
         scale: null
@@ -24,17 +24,17 @@ model:
       fpn_interp_model: nearest
 
   memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
+    _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
     d_model: 256
     pos_enc_at_input: true
     layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
@@ -45,7 +45,7 @@ model:
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
@@ -57,23 +57,23 @@ model:
     num_layers: 4
 
   memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
     out_dim: 64
     position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
+      _target_: torchao._models.sam2.modeling.memory_encoder.Fuser
      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3

torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml

Lines changed: 14 additions & 14 deletions

@@ -2,21 +2,21 @@
 
 # Model
 model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
+  _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
   image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
     scalp: 1
     trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
+      _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
       embed_dim: 96
      num_heads: 1
      stages: [1, 2, 11, 2]
      global_att_blocks: [7, 10, 13]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
@@ -27,17 +27,17 @@ model:
      fpn_interp_model: nearest
 
  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
+    _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
@@ -48,7 +48,7 @@ model:
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
    num_layers: 4
 
  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
+      _target_: torchao._models.sam2.modeling.memory_encoder.Fuser
      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
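Both config diffs are the same mechanical rewrite: every `_target_` is repointed from the upstream `sam2` package to torchao's vendored copy, so Hydra instantiates the patched modules that ship with torchao. A minimal sketch of how Hydra resolves such a node (the values mirror the position_encoding entries above; importing the vendored class assumes a torchao install that includes it):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# instantiate() imports the dotted _target_ path and calls it with the
# remaining keys as keyword arguments, recursing into nested nodes.
node = OmegaConf.create({
    "_target_": "torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine",
    "num_pos_feats": 256,
    "normalize": True,
    "scale": None,
    "temperature": 10000,
})
position_encoding = instantiate(node)
```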
