Skip to content

Commit b99e210

Browse files
authored
Akash/resolution in yaml (#22)
1 parent e61894c commit b99e210

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,17 @@ python3 -m cube3d.generate \
6565
--prompt "sleek vintage green couch with clean lines and velvet material"
6666
```
6767

68+
> **Note**: `--fast-inference` is optional and may not be available on GPUs with limited VRAM. This flag also does not work on macOS.
69+
6870
The output will be an `.obj` file saved in the specified `output` directory.
6971

7072
To render a turntable GIF of the mesh, use the `--render-gif` flag, which will render the animation
7173
and save it as `turntable.gif` in the specified `output` directory.
7274

7375
> **Note**: You must have Blender installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.
7476
77+
> **Note**: If shape decoding is slow, you can try specifying a lower resolution using the `--resolution-base` flag. A lower resolution produces a coarser, lower-quality output mesh but decodes faster. Values between 4.0 and 9.0 are recommended.
78+
7579
#### 2. Shape Tokenization and De-tokenization
7680

7781
To tokenize a 3D shape into token indices and reconstruct it back, you can use the following command:

cube3d/generate.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@
1717
logging.basicConfig(level=logging.INFO)
1818

1919

20-
def generate_mesh(engine, prompt, output_dir, output_name, disable_postprocess=False):
21-
mesh_v_f = engine.t2s([prompt], use_kv_cache=True)
20+
def generate_mesh(
21+
engine,
22+
prompt,
23+
output_dir,
24+
output_name,
25+
resolution_base=8.0,
26+
disable_postprocess=False,
27+
):
28+
mesh_v_f = engine.t2s([prompt], use_kv_cache=True, resolution_base=resolution_base)
2229
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
2330
obj_path = os.path.join(output_dir, f"{output_name}.obj")
2431
if PYMESHLAB_AVAILABLE:
@@ -89,6 +96,12 @@ def generate_mesh(engine, prompt, output_dir, output_name, disable_postprocess=F
8996
default=False,
9097
action="store_true",
9198
)
99+
parser.add_argument(
100+
"--resolution-base",
101+
type=float,
102+
default=8.0,
103+
help="Resolution base for the shape decoder.",
104+
)
92105
args = parser.parse_args()
93106
os.makedirs(args.output_dir, exist_ok=True)
94107
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -110,7 +123,12 @@ def generate_mesh(engine, prompt, output_dir, output_name, disable_postprocess=F
110123

111124
# Generate meshes based on input source
112125
obj_path = generate_mesh(
113-
engine, args.prompt, args.output_dir, "output", args.disable_postprocessing
126+
engine,
127+
args.prompt,
128+
args.output_dir,
129+
"output",
130+
args.resolution_base,
131+
args.disable_postprocessing,
114132
)
115133
if args.render_gif:
116134
gif_path = renderer.render_turntable(obj_path, args.output_dir)

0 commit comments

Comments
 (0)