
Mismatched data type in point_pool #11

@AIBluefisher


Hi, @chrischoy

I found a data type mismatch when passing downsample_max_num_points to the point_pool(...) method. First, I created an instance of the Point class. By default, its batched_coordinates are float, but when point_pool builds the pooled voxel output it asserts that batched_coordinates are integers:

  File "/home/chenyu/Projects/DOGE/conerf/model/generation/utils.py", line 270, in colored_pc_to_warp_conv_points
    sparse_tensor, _ = point_pool(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 258, in point_pool
    return _pool_by_max_num_points(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 194, in _pool_by_max_num_points
    out_pc = RETURN_CLS(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/types/voxels.py", line 55, in __init__
    batched_coordinates = IntCoords(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/integer.py", line 63, in __init__
    self.check()
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/integer.py", line 67, in check
    assert self.batched_tensor.dtype in [
AssertionError: Discrete coordinates must be integers
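A minimal sketch of this first failure, for illustration only. check_discrete_coords is a hypothetical stand-in for IntCoords.check(), not warpconvnet code, and the accepted dtypes (int32/int64) are an assumption because the list is cut off in the traceback.

```python
import torch

# Assumption: the real dtype list in IntCoords.check() is not visible in the
# traceback; int32/int64 are used here purely for illustration.
ASSUMED_INT_DTYPES = (torch.int32, torch.int64)


def check_discrete_coords(coords: torch.Tensor) -> None:
    """Hypothetical stand-in for the failing assertion in IntCoords.check()."""
    assert coords.dtype in ASSUMED_INT_DTYPES, "Discrete coordinates must be integers"


coords = torch.rand(4, 3) * 10  # float32 by default, like the Point coordinates

try:
    check_discrete_coords(coords)  # fails: float32 is not an integer dtype
except AssertionError as err:
    print(err)

check_discrete_coords(coords.int())  # passes once cast to int32
```

Note that .int() truncates toward zero, so this cast only makes sense when the coordinates already represent voxel indices.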

Then I cast batched_coordinates to an integer type, and another error occurred:

  File "/home/chenyu/Projects/DOGE/conerf/model/generation/utils.py", line 269, in colored_pc_to_warp_conv_points
    sparse_tensor, _ = point_pool(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 258, in point_pool
    return _pool_by_max_num_points(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/nn/functional/point_pool.py", line 135, in _pool_by_max_num_points
    knn_down_indices = batched_knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 134, in batched_knn_search
    neighbor_index = knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 73, in knn_search
    neighbor_indices = _chunked_knn_search(
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 43, in _chunked_knn_search
    chunk_neighbor_indices = _knn_search(ref_positions, chunk_out_positions, k)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/warpconvnet-0.3.6-py3.10-linux-x86_64.egg/warpconvnet/geometry/coords/search/knn.py", line 24, in _knn_search
    dists = torch.cdist(query_positions, ref_positions)
  File "/home/chenyu/anaconda3/envs/dogs/lib/python3.10/site-packages/torch/functional.py", line 1483, in cdist
    return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
RuntimeError: cdist only supports floating-point dtypes, X1 got: Int

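This happens because query_positions and ref_positions in dists = torch.cdist(query_positions, ref_positions) are both integer tensors, while torch.cdist only accepts floating-point inputs. So the same coordinates must be integers for the voxel output but floating point for the k-NN search.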
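A minimal sketch of one possible fix, assuming the k-NN step is a plain torch.cdist plus nearest-neighbor selection as the traceback suggests. knn_indices is a hypothetical helper, not the actual warpconvnet _knn_search; it promotes integer coordinates to float32 only for the distance computation, so the caller's integer coordinates stay untouched.

```python
import torch


def knn_indices(ref_positions: torch.Tensor, query_positions: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical float-safe variant of a cdist-based k-NN search.

    torch.cdist rejects integer tensors, so integer voxel coordinates are
    promoted to float32 before computing pairwise Euclidean distances.
    """
    if not torch.is_floating_point(ref_positions):
        ref_positions = ref_positions.float()
    if not torch.is_floating_point(query_positions):
        query_positions = query_positions.float()
    dists = torch.cdist(query_positions, ref_positions)  # shape (num_query, num_ref)
    # Indices of the k smallest distances for each query point
    return dists.topk(k, dim=1, largest=False).indices


ref = torch.tensor([[0, 0, 1], [2, 2, 3], [5, 5, 5]], dtype=torch.int32)
query = torch.tensor([[0, 0, 0], [2, 2, 2]], dtype=torch.int32)

try:
    torch.cdist(query, ref)  # reproduces the reported RuntimeError
except RuntimeError as err:
    print(err)

print(knn_indices(ref, query, k=1))  # nearest reference point per query
```

Whether the cast belongs inside the search function or in _pool_by_max_num_points before batched_knn_search is a design choice for the maintainers; casting at the search site keeps integer coordinates intact everywhere else.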

I'm using torch 2.6.0 (and warpconvnet 0.3.6, per the traceback paths). Could you try to reproduce this issue and fix it?
