Skip to content

Commit 5dda56b

Browse files
committed
initial commit
0 parents  commit 5dda56b

24 files changed

+3099
-0
lines changed

.gitignore

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
correlation.egg-info
2+
checkpoints/*
3+
data/
4+
images/*
5+
runs/*
6+
.vscode/*
7+
models/*
8+
*.png
9+
*.csv
10+
*.pyc
11+
12+
*.sh
13+
*.zip
14+
job_*.txt
15+
eval_all.log
16+
alt_cuda_corr/build/
17+
alt_cuda_corr/dist/
18+
*.flo

LICENSE

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
Copyright (c) 2023 Azin Jahedi. All Rights Reserved.
2+
3+
Redistribution and use in source and binary forms, with or without modification,
4+
are permitted provided that the following conditions are met:
5+
6+
1. Redistribution of source code must retain the above copyright notice, this
7+
list of conditions and the following disclaimer.
8+
9+
2. Redistribution in binary form must reproduce the above copyright notice,
10+
this list of conditions and the following disclaimer in the documentation
11+
and/or other materials provided with the distribution.
12+
13+
3. Neither the name of the copyright holder nor the names of its contributors
14+
may be used to endorse or promote products derived from this software without
15+
specific prior written permission.
16+
17+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20+
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
21+
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
22+
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24+
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
25+
OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
26+
OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
You acknowledge that this software is not designed, licensed or intended for use
29+
in the design, construction, operation or maintenance of any military facility.

README.md

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# CCMR
2+
This is the inference code of our `CCMR` optical flow estimation method.
3+
4+
**[CCMR: High Resolution Optical Flow Estimation via Coarse-to-Fine Context-Guided Motion Reasoning](https://openaccess.thecvf.com/content/WACV2024/papers/Jahedi_CCMR_High_Resolution_Optical_Flow_Estimation_via_Coarse-To-Fine_Context-Guided_Motion_WACV_2024_paper.pdf)**<br/>
5+
> _WACV 2024_ <br/>
6+
> Azin Jahedi, Maximilian Luz, Marc Rivinius and Andrés Bruhn
7+
8+
## Requirements
9+
10+
The code has been tested with PyTorch 1.10.2+cu113.
11+
Install the required dependencies via
12+
```
13+
pip install -r requirements.txt
14+
```
15+
16+
Alternatively you can also manually install the following packages in your virtual environment:
17+
- `torch`, `torchvision`, and `torchaudio` (e.g., with `--extra-index-url https://download.pytorch.org/whl/cu113` for CUDA 11.3)
18+
- `matplotlib`
19+
- `scipy`
20+
- `tensorboard`
21+
- `opencv-python`
22+
- `tqdm`
23+
- `parse`
24+
- `timm`
25+
- `flowpy`
26+
27+
28+
## Pre-Trained Checkpoints
29+
30+
You can find our trained models in the release assets.
31+
32+
33+
## Datasets
34+
35+
Datasets are expected to be located under `./data` in the following layout:
36+
```
37+
./data
38+
├── kitti15 # KITTI 2015
39+
│ └── dataset
40+
│ ├── testing/...
41+
│ └── training/...
42+
└── sintel # Sintel
43+
├── test/...
44+
└── training/...
45+
46+
```
47+
48+
## Running CCMR(+)
49+
50+
For running `CCMR+` on MPI Sintel images you need about 4.5 GB of GPU VRAM. `CCMR` (the 3-scale version) needs about 3 GBs of VRAM, using the following Cuda module.
51+
52+
To compile the CUDA correlation module run the following once:
53+
```Shell
54+
cd alt_cuda_corr && python setup.py install && cd ..
55+
```
56+
57+
And to reproduce our benchmark results after finetuning run:
58+
```Shell
59+
python evaluate.py --model_type "CCMR+" --model models/CCMR+_sintel.pth --dataset sintel_test
60+
python evaluate.py --model_type "CCMR+" --model models/CCMR+_kitti.pth --dataset kitti_test
61+
62+
python evaluate.py --model_type "CCMR" --model models/CCMR_sintel.pth --dataset sintel_test
63+
python evaluate.py --model_type "CCMR" --model models/CCMR_kitti.pth --dataset kitti_test
64+
```
65+
66+
## License
67+
- Our code is licensed under the BSD 3-Clause **No Military** License. See [LICENSE](LICENSE).
68+
- The provided checkpoints are under the [CC BY-NC-SA 3.0](https://creativecommons.org/licenses/by-nc-sa/3.0/) license.
69+
70+
## Acknowledgement
71+
Parts of this repository are adapted from [RAFT](https://github.com/princeton-vl/RAFT) ([license](licenses/RAFT/LICENSE)), [MS-RAFT+](https://github.com/cv-stuttgart/MS_RAFT_plus) ([license](https://github.com/cv-stuttgart/MS_RAFT_plus/blob/main/LICENSE)), and [XCiT](https://github.com/facebookresearch/xcit/tree/main) ([license](https://github.com/facebookresearch/xcit/blob/main/LICENSE)).
72+
We thank the authors.

alt_cuda_corr/correlation.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <torch/extension.h>
2+
#include <c10/cuda/CUDAGuard.h>
3+
#include <vector>
4+
5+
// CUDA forward declarations
6+
std::vector<torch::Tensor> corr_cuda_forward(
7+
torch::Tensor fmap1,
8+
torch::Tensor fmap2,
9+
torch::Tensor coords,
10+
int radius);
11+
12+
std::vector<torch::Tensor> corr_cuda_backward(
13+
torch::Tensor fmap1,
14+
torch::Tensor fmap2,
15+
torch::Tensor coords,
16+
torch::Tensor corr_grad,
17+
int radius);
18+
19+
// C++ interface
20+
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
21+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
22+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
23+
24+
std::vector<torch::Tensor> corr_forward(
25+
torch::Tensor fmap1,
26+
torch::Tensor fmap2,
27+
torch::Tensor coords,
28+
int radius) {
29+
CHECK_INPUT(fmap1);
30+
CHECK_INPUT(fmap2);
31+
CHECK_INPUT(coords);
32+
33+
const at::cuda::OptionalCUDAGuard device_guard(device_of(coords));
34+
35+
return corr_cuda_forward(fmap1, fmap2, coords, radius);
36+
}
37+
38+
39+
std::vector<torch::Tensor> corr_backward(
40+
torch::Tensor fmap1,
41+
torch::Tensor fmap2,
42+
torch::Tensor coords,
43+
torch::Tensor corr_grad,
44+
int radius) {
45+
CHECK_INPUT(fmap1);
46+
CHECK_INPUT(fmap2);
47+
CHECK_INPUT(coords);
48+
CHECK_INPUT(corr_grad);
49+
50+
const at::cuda::OptionalCUDAGuard device_guard(device_of(coords));
51+
52+
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
53+
}
54+
55+
56+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
57+
m.def("forward", &corr_forward, "CORR forward");
58+
m.def("backward", &corr_backward, "CORR backward");
59+
}

0 commit comments

Comments
 (0)