Hi, @chrischoy
I found a data type mismatch when passing `downsample_max_num_points` to the `point_pool(...)` method. I first created an instance of the `Point` class. By default, its `batched_coordinates` have a float dtype, but the method asserts that `batched_coordinates` is an integer type:
```
  File "/home/chenyu/Projects/DOGE/conerf/model/generation/utils.py", line 270, in colored_pc_to_warp_conv_points
    sparse_tensor, _ = point_pool(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 258, in point_pool
    return _pool_by_max_num_points(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 194, in _pool_by_max_num_points
    out_pc = RETURN_CLS(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/types/voxels.py", line 55, in __init__
    batched_coordinates = IntCoords(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/integer.py", line 63, in __init__
    self.check()
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/integer.py", line 67, in check
    assert self.batched_tensor.dtype in [
AssertionError: Discrete coordinates must be integers
```
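For illustration, the failing check behaves roughly like the sketch below: coordinates that are not integers are rejected, so continuous positions would have to be quantized into integer voxel indices before they can be used as discrete coordinates. This is a hypothetical, pure-Python stand-in; `check_integer_coords`, `quantize`, and the `voxel_size` parameter are assumptions for this sketch, not warpconvnet's API, and flooring is only one possible quantization scheme.

```python
import math
from typing import List, Sequence, Tuple


def check_integer_coords(coords: Sequence[Sequence[object]]) -> None:
    """Hypothetical stand-in for the integer-dtype assertion in the traceback."""
    for point in coords:
        for value in point:
            # bool is a subclass of int in Python, so reject it explicitly.
            if not isinstance(value, int) or isinstance(value, bool):
                raise AssertionError("Discrete coordinates must be integers")


def quantize(points: Sequence[Sequence[float]], voxel_size: float) -> List[Tuple[int, ...]]:
    """Map continuous positions to integer voxel indices by flooring (assumed scheme)."""
    return [tuple(math.floor(c / voxel_size) for c in p) for p in points]


points = [(0.12, 0.55, 1.02), (-0.3, 0.0, 0.49)]
try:
    check_integer_coords(points)  # float coordinates are rejected
except AssertionError as err:
    print(err)  # Discrete coordinates must be integers

voxels = quantize(points, voxel_size=0.25)
print(voxels)  # [(0, 2, 4), (-2, 0, 1)]
check_integer_coords(voxels)  # integer voxel indices pass the check
```

Flooring (rather than truncating with `int()`) keeps negative positions in the correct voxel, e.g. `-0.3` maps to index `-2` instead of `-1`.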
Then I cast `batched_coordinates` to an int type, and another error occurred:
```
  File "/home/chenyu/Projects/DOGE/conerf/model/generation/utils.py", line 269, in colored_pc_to_warp_conv_points
    sparse_tensor, _ = point_pool(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 258, in point_pool
    return _pool_by_max_num_points(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 135, in _pool_by_max_num_points
    knn_down_indices = batched_knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 134, in batched_knn_search
    neighbor_index = knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 73, in knn_search
    neighbor_indices = _chunked_knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 43, in _chunked_knn_search
    chunk_neighbor_indices = _knn_search(ref_positions, chunk_out_positions, k)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 24, in _knn_search
    dists = torch.cdist(query_positions, ref_positions)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/functional.py", line 1483, in cdist
    return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
RuntimeError: cdist only supports floating-point dtypes, X1 got: Int
```
The error occurs because both `query_positions` and `ref_positions` are int tensors in `dists = torch.cdist(query_positions, ref_positions)`, while `torch.cdist` only accepts floating-point dtypes.
I'm using torch 2.6.0. Could you reproduce this issue and try to fix it?