Skip to content

Commit 5ab1588

Browse files
committed
wip
1 parent 91bb61b commit 5ab1588

File tree

2 files changed

+165
-2
lines changed

2 files changed

+165
-2
lines changed

dev/poc/export_yolo_onnx.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Requirements:
3+
pip install onnx onnxruntime
4+
pip install onnx-simplifier
5+
6+
Notes:
7+
# Move the YOLO models to the model distribution directory
8+
9+
/data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt
10+
11+
mkdir /home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9
12+
13+
cp /data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt \
14+
/home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt
15+
16+
cp /data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/train_config.yaml \
17+
/home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-train_config.yaml
18+
19+
cp /home/joncrall/code/YOLO-v9/yolov9-simplified.onnx \
20+
/home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-epoch=0032-step=000132-trainlosstrain_loss=7.603.onnx
21+
22+
"""
23+
import onnxruntime as ort
24+
import numpy as np
25+
import torch
26+
import kwutil
27+
import ubelt as ub
28+
from yolo.utils.kwcoco_utils import tensor_to_kwimage
29+
from yolo.utils.bounding_box_utils import create_converter
30+
from yolo.utils.model_utils import PostProcess
31+
from omegaconf.dictconfig import DictConfig
32+
from yolo.tools.solver import InferenceModel
33+
34+
# See ~/code/shitspotter/dev/poc/train_yolo_shitspotter.sh
35+
checkpoint_path = ub.Path('/data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt')
36+
train_config = ub.Path('/data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/train_config.yaml')
37+
38+
config = kwutil.Yaml.coerce(train_config, backend='pyyaml')
39+
cfg = DictConfig(config)
40+
cfg.weight = checkpoint_path
41+
model = InferenceModel(cfg)
42+
model.eval()
43+
model.post_process = PostProcess(model.vec2box, model.validation_cfg.nms)
44+
vec2box = create_converter(
45+
model.cfg.model.name, model.model, model.cfg.model.anchor, model.cfg.image_size, model.device
46+
)
47+
48+
input_tensor = np.random.randn(1, 3, 640, 640).astype(np.float32)
49+
50+
# Test a regular forward pass.
51+
model.cfg.task.nms = DictConfig(kwutil.Yaml.coerce(
52+
'''
53+
min_confidence: 0.01
54+
min_iou: 0.5
55+
max_bbox: 300
56+
''', backend='pyyaml'))
57+
post_process = PostProcess(vec2box, model.cfg.task.nms)
58+
torch_outputs = model.forward(torch.Tensor(input_tensor))
59+
60+
predicts = post_process(torch_outputs)
61+
classes = cfg.dataset.class_list
62+
detections = [
63+
tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy()
64+
for yolo_annot_tensor in predicts]
65+
66+
67+
# Convert to onnx
68+
device = torch.device('cpu')
69+
dummy_input = torch.randn(1, 3, 640, 640).to(device) # Adjust image size as needed
70+
torch.onnx.export(
71+
model, # The loaded YOLO model
72+
dummy_input, # Example input tensor
73+
"yolov9.onnx", # Output ONNX file
74+
export_params=True, # Store trained weights
75+
opset_version=12, # ONNX opset version
76+
do_constant_folding=True, # Optimize the graph
77+
input_names=['input'], # Input name
78+
output_names=['output'], # Output name
79+
dynamic_axes={
80+
'input': {0: 'batch_size'}, # Enable dynamic batch size
81+
'output': {0: 'batch_size'}
82+
}
83+
)
84+
85+
86+
# ub.cmd('python -m onnxsim yolov9.onnx yolov9-simplified.onnx')
87+
# ort_session = ort.InferenceSession("yolov9-simplified.onnx")
88+
89+
ort_session = ort.InferenceSession("yolov9.onnx")
90+
91+
# Simulate an image tensor
92+
onnx_outputs = ort_session.run(None, {'input': input_tensor})
93+
94+
torch_walker = ub.IndexableWalker(torch_outputs)
95+
onnx_walker = ub.IndexableWalker(onnx_outputs)
96+
97+
98+
def walker_to_nx(walker):
99+
import networkx as nx
100+
graph = nx.DiGraph()
101+
102+
# root
103+
node = tuple()
104+
v = walker.data
105+
graph.add_node(node, label=f'.: {type(v).__name__}')
106+
107+
for p, v in walker:
108+
node = tuple(p)
109+
parent = node[0:-1]
110+
graph.add_node(node)
111+
graph.add_edge(parent, node)
112+
if not isinstance(v, (list, dict)):
113+
if hasattr(v, 'shape'):
114+
graph.nodes[node]['label'] = f'{node}: {type(v)}[{v.shape}]'
115+
else:
116+
graph.nodes[node]['label'] = f'{node}: {type(v)}'
117+
else:
118+
graph.nodes[node]['label'] = f'{node}: {type(v).__name__}'
119+
nx.write_network_text(graph)
120+
121+
print('Torch Output:')
122+
walker_to_nx(torch_walker)
123+
print('ONNX Output:')
124+
walker_to_nx(onnx_walker)
125+
126+
127+
onnx_outputs[0].shape
128+
129+
torch_outputs['Main'][0][0].shape
130+
131+
# Split the ONNX output back into its tuple-like structure
132+
recon_outputs = {}
133+
onnx_outputs_ = [torch.Tensor(a) for a in onnx_outputs]
134+
recon_outputs['Main'] = list(ub.chunks(onnx_outputs_[0:9], chunksize=3))
135+
recon_outputs['Aux'] = list(ub.chunks(onnx_outputs_[9:], chunksize=3))
136+
137+
recon_walker = ub.IndexableWalker(recon_outputs)
138+
print('ONNX Recon Output:')
139+
walker_to_nx(recon_walker)
140+
141+
predicts = post_process(recon_outputs)
142+
classes = cfg.dataset.class_list
143+
detections = [
144+
tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy()
145+
for yolo_annot_tensor in predicts]
146+
147+
# Now see: ~/code/shitspotter/tpl/scatspotter_app/explore.rst

papers/neurips-2025/scripts/compress_pdf.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ class CompressPdfCLI(scfg.DataConfig):
1515
pdf_fpath = scfg.Value(None, help='pdf_fpath', position=1)
1616
start_page = scfg.Value(0, help='The index of the first page to start from (zero indexed inclusive)')
1717
stop_page = scfg.Value(None, help='The index of the last page to start from (zero indexed exclusive)')
18+
quality = scfg.Value('default', help=(
19+
'Compression quality level. '
20+
'Choices: "screen", "ebook", "printer", "prepress", "default". '
21+
'"screen" is highest compression, "prepress" is least.'),
22+
choices=['screen', 'ebook', 'printer', 'prepress', 'default'])
1823

1924
@classmethod
2025
def main(cls, cmdline=1, **kwargs):
@@ -49,6 +54,17 @@ def compress_pdf(config):
4954
'-dQUIET',
5055
'-dBATCH',
5156
]
57+
58+
# Apply compression quality
59+
quality_map = {
60+
'screen': '/screen',
61+
'ebook': '/ebook',
62+
'printer': '/printer',
63+
'prepress': '/prepress',
64+
'default': '/default',
65+
}
66+
gs_options += [f'-dPDFSETTINGS={quality_map[config.quality]}']
67+
5268
if config.start_page != 0:
5369
gs_options += [f'-dFirstPage={config.start_page + 1}']
5470

@@ -65,11 +81,11 @@ def compress_pdf(config):
6581
ub.cmd(cmd_list)
6682
return output_pdf_fpath
6783

84+
6885
if __name__ == '__main__':
6986
"""
70-
7187
CommandLine:
72-
python ~/code/shitspotter/papers/neurips-2025/scripts/compress_pdf.py ~/code/shitspotter/papers/neurips-2025/main.pdf
88+
python ~/code/shitspotter/papers/neurips-2025/scripts/compress_pdf.py main.pdf --quality=screen
7389
python -m compress_pdf
7490
"""
7591
__cli__.main()

0 commit comments

Comments
 (0)