Skip to content

Commit c7be438

Browse files
author
Matt Sokoloff
committed
str or path and format
1 parent 9b1f15f commit c7be438

File tree

5 files changed

+58
-40
lines changed

5 files changed

+58
-40
lines changed

labelbox/data/serialization/coco/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77

88

9-
def rle_decoding(rle_arr : List[int], w : int, h: int) -> np.ndarray:
9+
def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray:
1010
indices = []
1111
for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]):
1212
indices.extend(list(range(idx - 1,

labelbox/data/serialization/coco/converter.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Any
1+
from typing import Dict, Any, Union
22
from pathlib import Path
33
import os
44

@@ -7,16 +7,20 @@
77
from labelbox.data.serialization.coco.panoptic_dataset import CocoPanopticDataset
88

99

10-
def create_path_if_not_exists(path: Path, ignore_existing_data=False):
10+
def create_path_if_not_exists(path: Union[Path, str],
11+
ignore_existing_data=False):
12+
path = Path(path)
1113
if not path.exists():
1214
path.mkdir(parents=True, exist_ok=True)
1315
elif not ignore_existing_data and os.listdir(path):
1416
raise ValueError(
1517
f"Directory `{path}`` must be empty. Or set `ignore_existing_data=True`"
1618
)
19+
return path
1720

1821

19-
def validate_path(path: Path, name: str):
22+
def validate_path(path: Union[Path, str], name: str):
23+
path = Path(path)
2024
if not path.exists():
2125
raise ValueError(f"{name} `{path}` must exist")
2226

@@ -29,10 +33,12 @@ class COCOConverter:
2933
Subclasses are currently ignored.
3034
To use subclasses, manually flatten them before using the converter.
3135
"""
36+
3237
@staticmethod
3338
def serialize_instances(labels: LabelCollection,
34-
image_root: Path,
35-
ignore_existing_data=False, max_workers = 8) -> Dict[str, Any]:
39+
image_root: Union[Path, str],
40+
ignore_existing_data=False,
41+
max_workers=8) -> Dict[str, Any]:
3642
"""
3743
Convert a Labelbox LabelCollection into an mscoco dataset.
3844
This function will only convert masks, polygons, and rectangles.
@@ -48,16 +54,18 @@ def serialize_instances(labels: LabelCollection,
4854
Returns:
4955
A dictionary containing labels in the coco object format.
5056
"""
51-
create_path_if_not_exists(image_root, ignore_existing_data)
57+
image_root = create_path_if_not_exists(image_root, ignore_existing_data)
5258
return CocoInstanceDataset.from_common(labels=labels,
53-
image_root=image_root, max_workers = max_workers).dict()
59+
image_root=image_root,
60+
max_workers=max_workers).dict()
5461

5562
@staticmethod
5663
def serialize_panoptic(labels: LabelCollection,
57-
image_root: Path,
58-
mask_root: Path,
64+
image_root: Union[Path, str],
65+
mask_root: Union[Path, str],
5966
all_stuff: bool = False,
60-
ignore_existing_data=False, max_workers = 8) -> Dict[str, Any]:
67+
ignore_existing_data=False,
68+
max_workers=8) -> Dict[str, Any]:
6169
"""
6270
Convert a Labelbox LabelCollection into an mscoco dataset.
6371
This function will only convert masks, polygons, and rectangles.
@@ -76,16 +84,18 @@ def serialize_panoptic(labels: LabelCollection,
7684
Returns:
7785
A dictionary containing labels in the coco panoptic format.
7886
"""
79-
create_path_if_not_exists(image_root, ignore_existing_data)
80-
create_path_if_not_exists(mask_root, ignore_existing_data)
87+
image_root = create_path_if_not_exists(image_root, ignore_existing_data)
88+
mask_root = create_path_if_not_exists(mask_root, ignore_existing_data)
8189
return CocoPanopticDataset.from_common(labels=labels,
8290
image_root=image_root,
8391
mask_root=mask_root,
84-
all_stuff=all_stuff , max_workers = max_workers).dict()
92+
all_stuff=all_stuff,
93+
max_workers=max_workers).dict()
8594

8695
@staticmethod
87-
def deserialize_panoptic(json_data: Dict[str, Any], image_root: Path,
88-
mask_root: Path) -> LabelGenerator:
96+
def deserialize_panoptic(json_data: Dict[str, Any], image_root: Union[Path,
97+
str],
98+
mask_root: Union[Path, str]) -> LabelGenerator:
8999
"""
90100
Convert coco panoptic data into the labelbox format (as a LabelGenerator).
91101
@@ -96,8 +106,8 @@ def deserialize_panoptic(json_data: Dict[str, Any], image_root: Path,
96106
Returns:
97107
LabelGenerator
98108
"""
99-
validate_path(image_root, 'image_root')
100-
validate_path(mask_root, 'mask_root')
109+
image_root = validate_path(image_root, 'image_root')
110+
mask_root = validate_path(mask_root, 'mask_root')
101111
objs = CocoPanopticDataset(**json_data)
102112
gen = objs.to_common(image_root, mask_root)
103113
return LabelGenerator(data=gen)
@@ -114,7 +124,7 @@ def deserialize_instances(json_data: Dict[str, Any],
114124
Returns:
115125
LabelGenerator
116126
"""
117-
validate_path(image_root, 'image_root')
127+
image_root = validate_path(image_root, 'image_root')
118128
objs = CocoInstanceDataset(**json_data)
119129
gen = objs.to_common(image_root)
120130
return LabelGenerator(data=gen)

labelbox/data/serialization/coco/image.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,15 @@ def get_image_id(label: Label, idx: int) -> int:
2626
return idx
2727

2828

29-
def get_image(label: Label, image_root : Path, image_id: str) -> CocoImage:
29+
def get_image(label: Label, image_root: Path, image_id: str) -> CocoImage:
3030
path = Path(image_root, f"{image_id}.jpg")
3131
if not path.exists():
3232
im = Image.fromarray(label.data.value)
3333
im.save(path)
3434
w, h = im.size
3535
else:
3636
w, h = imagesize.get(str(path))
37-
return CocoImage(id=image_id,
38-
width=w,
39-
height=h,
40-
file_name=Path(path.name))
37+
return CocoImage(id=image_id, width=w, height=h, file_name=Path(path.name))
4138

4239

4340
def id_to_rgb(id: int) -> Tuple[int, int, int]:

labelbox/data/serialization/coco/instance_dataset.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from .image import CocoImage, get_image, get_image_id
1616

1717

18-
def mask_to_coco_object_annotation(annotation: ObjectAnnotation, annot_idx : int,
19-
image_id : int, category_id : int) -> COCOObjectAnnotation:
18+
def mask_to_coco_object_annotation(annotation: ObjectAnnotation, annot_idx: int,
19+
image_id: int,
20+
category_id: int) -> COCOObjectAnnotation:
2021
# This is going to fill any holes into the multipolygon
2122
# If you need to support holes use the panoptic data format
2223
shapely = annotation.value.shapely.simplify(1).buffer(0)
@@ -38,8 +39,9 @@ def mask_to_coco_object_annotation(annotation: ObjectAnnotation, annot_idx : int
3839
iscrowd=0)
3940

4041

41-
def vector_to_coco_object_annotation(annotation: ObjectAnnotation, annot_idx : int,
42-
image_id: int, category_id: int) -> COCOObjectAnnotation:
42+
def vector_to_coco_object_annotation(annotation: ObjectAnnotation,
43+
annot_idx: int, image_id: int,
44+
category_id: int) -> COCOObjectAnnotation:
4345
shapely = annotation.value.shapely
4446
xmin, ymin, xmax, ymax = shapely.bounds
4547
segmentation = []
@@ -62,15 +64,17 @@ def vector_to_coco_object_annotation(annotation: ObjectAnnotation, annot_idx : i
6264
iscrowd=0)
6365

6466

65-
def rle_to_common(class_annotations : COCOObjectAnnotation, class_name : str) -> ObjectAnnotation:
67+
def rle_to_common(class_annotations: COCOObjectAnnotation,
68+
class_name: str) -> ObjectAnnotation:
6669
mask = rle_decoding(class_annotations.segmentation.counts,
6770
*class_annotations.segmentation.size[::-1])
6871
return ObjectAnnotation(name=class_name,
6972
value=Mask(mask=MaskData.from_2D_arr(mask),
7073
color=[1, 1, 1]))
7174

7275

73-
def segmentations_to_common(class_annotations : COCOObjectAnnotation, class_name: str) -> List[ObjectAnnotation]:
76+
def segmentations_to_common(class_annotations: COCOObjectAnnotation,
77+
class_name: str) -> List[ObjectAnnotation]:
7478
# Technically it is polygons. But the key in coco is called segmentations..
7579
annotations = []
7680
for points in class_annotations.segmentation:
@@ -83,10 +87,12 @@ def segmentations_to_common(class_annotations : COCOObjectAnnotation, class_name
8387
return annotations
8488

8589

86-
def process_label(label: Label,
87-
idx : int,
88-
image_root :str,
89-
max_annotations_per_image=10000) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]:
90+
def process_label(
91+
label: Label,
92+
idx: int,
93+
image_root: str,
94+
max_annotations_per_image=10000
95+
) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]:
9096
annot_idx = idx * max_annotations_per_image
9197
image_id = get_image_id(label, idx)
9298
image = get_image(label, image_root, image_id)
@@ -119,7 +125,10 @@ class CocoInstanceDataset(BaseModel):
119125
categories: List[Categories]
120126

121127
@classmethod
122-
def from_common(cls, labels: LabelCollection, image_root : Path, max_workers = 8):
128+
def from_common(cls,
129+
labels: LabelCollection,
130+
image_root: Path,
131+
max_workers=8):
123132
all_coco_annotations = []
124133
categories = {}
125134
images = []

labelbox/data/serialization/coco/panoptic_dataset.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from concurrent.futures import ProcessPoolExecutor, as_completed
32
from typing import Dict, Any, List, Union
43
from pathlib import Path
@@ -104,8 +103,12 @@ class CocoPanopticDataset(BaseModel):
104103
categories: List[Categories]
105104

106105
@classmethod
107-
def from_common(cls, labels: LabelCollection, image_root, mask_root,
108-
all_stuff, max_workers = 8):
106+
def from_common(cls,
107+
labels: LabelCollection,
108+
image_root,
109+
mask_root,
110+
all_stuff,
111+
max_workers=8):
109112
all_coco_annotations = []
110113
coco_categories = {}
111114
coco_things = {}
@@ -167,8 +170,7 @@ def to_common(self, image_root: Path, mask_root: Path):
167170
raise ValueError(
168171
f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}"
169172
)
170-
mask = MaskData(
171-
file_path=Path(mask_root, annotation.file_name))
173+
mask = MaskData(file_path=Path(mask_root, annotation.file_name))
172174

173175
for segmentation in annotation.segments_info:
174176
category = category_lookup[segmentation.category_id]

0 commit comments

Comments
 (0)