|
| 1 | +""" |
| 2 | +Requirements: |
| 3 | + pip install onnx onnxruntime |
| 4 | + pip install onnx-simplifier |
| 5 | +
|
| 6 | +Notes: |
| 7 | + # Move the YOLO models to the model distribution directory |
| 8 | +
|
| 9 | + /data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt |
| 10 | +
|
| 11 | + mkdir /home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9 |
| 12 | +
|
| 13 | + cp /data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt \ |
| 14 | + /home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt |
| 15 | +
|
| 16 | + cp /data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/train_config.yaml \ |
| 17 | + /home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-train_config.yaml |
| 18 | +
|
| 19 | + cp /home/joncrall/code/YOLO-v9/yolov9-simplified.onnx \ |
| 20 | + /home/joncrall/code/shitspotter/shitspotter_dvc/models/yolo-v9/shitspotter-simple-v3-run-v06-epoch=0032-step=000132-trainlosstrain_loss=7.603.onnx |
| 21 | +
|
| 22 | +""" |
| 23 | +import onnxruntime as ort |
| 24 | +import numpy as np |
| 25 | +import torch |
| 26 | +import kwutil |
| 27 | +import ubelt as ub |
| 28 | +from yolo.utils.kwcoco_utils import tensor_to_kwimage |
| 29 | +from yolo.utils.bounding_box_utils import create_converter |
| 30 | +from yolo.utils.model_utils import PostProcess |
| 31 | +from omegaconf.dictconfig import DictConfig |
| 32 | +from yolo.tools.solver import InferenceModel |
| 33 | + |
| 34 | +# See ~/code/shitspotter/dev/poc/train_yolo_shitspotter.sh |
| 35 | +checkpoint_path = ub.Path('/data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/lightning_logs/version_1/checkpoints/epoch=0032-step=000132-trainlosstrain_loss=7.603.ckpt.ckpt') |
| 36 | +train_config = ub.Path('/data/joncrall/dvc-repos/shitspotter_expt_dvc/training/toothbrush/joncrall/ShitSpotter/runs/shitspotter-simple-v3-run-v06/train/shitspotter-simple-v3-run-v06/train_config.yaml') |
| 37 | + |
| 38 | +config = kwutil.Yaml.coerce(train_config, backend='pyyaml') |
| 39 | +cfg = DictConfig(config) |
| 40 | +cfg.weight = checkpoint_path |
| 41 | +model = InferenceModel(cfg) |
| 42 | +model.eval() |
| 43 | +model.post_process = PostProcess(model.vec2box, model.validation_cfg.nms) |
| 44 | +vec2box = create_converter( |
| 45 | + model.cfg.model.name, model.model, model.cfg.model.anchor, model.cfg.image_size, model.device |
| 46 | +) |
| 47 | + |
| 48 | +input_tensor = np.random.randn(1, 3, 640, 640).astype(np.float32) |
| 49 | + |
| 50 | +# Test a regular forward pass. |
| 51 | +model.cfg.task.nms = DictConfig(kwutil.Yaml.coerce( |
| 52 | + ''' |
| 53 | + min_confidence: 0.01 |
| 54 | + min_iou: 0.5 |
| 55 | + max_bbox: 300 |
| 56 | + ''', backend='pyyaml')) |
| 57 | +post_process = PostProcess(vec2box, model.cfg.task.nms) |
| 58 | +torch_outputs = model.forward(torch.Tensor(input_tensor)) |
| 59 | + |
| 60 | +predicts = post_process(torch_outputs) |
| 61 | +classes = cfg.dataset.class_list |
| 62 | +detections = [ |
| 63 | + tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy() |
| 64 | + for yolo_annot_tensor in predicts] |
| 65 | + |
| 66 | + |
| 67 | +# Convert to onnx |
| 68 | +device = torch.device('cpu') |
| 69 | +dummy_input = torch.randn(1, 3, 640, 640).to(device) # Adjust image size as needed |
| 70 | +torch.onnx.export( |
| 71 | + model, # The loaded YOLO model |
| 72 | + dummy_input, # Example input tensor |
| 73 | + "yolov9.onnx", # Output ONNX file |
| 74 | + export_params=True, # Store trained weights |
| 75 | + opset_version=12, # ONNX opset version |
| 76 | + do_constant_folding=True, # Optimize the graph |
| 77 | + input_names=['input'], # Input name |
| 78 | + output_names=['output'], # Output name |
| 79 | + dynamic_axes={ |
| 80 | + 'input': {0: 'batch_size'}, # Enable dynamic batch size |
| 81 | + 'output': {0: 'batch_size'} |
| 82 | + } |
| 83 | +) |
| 84 | + |
| 85 | + |
| 86 | +# ub.cmd('python -m onnxsim yolov9.onnx yolov9-simplified.onnx') |
| 87 | +# ort_session = ort.InferenceSession("yolov9-simplified.onnx") |
| 88 | + |
| 89 | +ort_session = ort.InferenceSession("yolov9.onnx") |
| 90 | + |
| 91 | +# Simulate an image tensor |
| 92 | +onnx_outputs = ort_session.run(None, {'input': input_tensor}) |
| 93 | + |
| 94 | +torch_walker = ub.IndexableWalker(torch_outputs) |
| 95 | +onnx_walker = ub.IndexableWalker(onnx_outputs) |
| 96 | + |
| 97 | + |
| 98 | +def walker_to_nx(walker): |
| 99 | + import networkx as nx |
| 100 | + graph = nx.DiGraph() |
| 101 | + |
| 102 | + # root |
| 103 | + node = tuple() |
| 104 | + v = walker.data |
| 105 | + graph.add_node(node, label=f'.: {type(v).__name__}') |
| 106 | + |
| 107 | + for p, v in walker: |
| 108 | + node = tuple(p) |
| 109 | + parent = node[0:-1] |
| 110 | + graph.add_node(node) |
| 111 | + graph.add_edge(parent, node) |
| 112 | + if not isinstance(v, (list, dict)): |
| 113 | + if hasattr(v, 'shape'): |
| 114 | + graph.nodes[node]['label'] = f'{node}: {type(v)}[{v.shape}]' |
| 115 | + else: |
| 116 | + graph.nodes[node]['label'] = f'{node}: {type(v)}' |
| 117 | + else: |
| 118 | + graph.nodes[node]['label'] = f'{node}: {type(v).__name__}' |
| 119 | + nx.write_network_text(graph) |
| 120 | + |
| 121 | +print('Torch Output:') |
| 122 | +walker_to_nx(torch_walker) |
| 123 | +print('ONNX Output:') |
| 124 | +walker_to_nx(onnx_walker) |
| 125 | + |
| 126 | + |
| 127 | +onnx_outputs[0].shape |
| 128 | + |
| 129 | +torch_outputs['Main'][0][0].shape |
| 130 | + |
| 131 | +# Split the ONNX output back into its tuple-like structure |
| 132 | +recon_outputs = {} |
| 133 | +onnx_outputs_ = [torch.Tensor(a) for a in onnx_outputs] |
| 134 | +recon_outputs['Main'] = list(ub.chunks(onnx_outputs_[0:9], chunksize=3)) |
| 135 | +recon_outputs['Aux'] = list(ub.chunks(onnx_outputs_[9:], chunksize=3)) |
| 136 | + |
| 137 | +recon_walker = ub.IndexableWalker(recon_outputs) |
| 138 | +print('ONNX Recon Output:') |
| 139 | +walker_to_nx(recon_walker) |
| 140 | + |
| 141 | +predicts = post_process(recon_outputs) |
| 142 | +classes = cfg.dataset.class_list |
| 143 | +detections = [ |
| 144 | + tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy() |
| 145 | + for yolo_annot_tensor in predicts] |
| 146 | + |
| 147 | +# Now see: ~/code/shitspotter/tpl/scatspotter_app/explore.rst |
0 commit comments