Skip to content

Commit b6d870c

Browse files
committed
add optional adaptive rmsnorm on text embed conditioning
1 parent ac67dd5 commit b6d870c

File tree

6 files changed

+100
-4
lines changed

6 files changed

+100
-4
lines changed

.github/workflows/test.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
name: Pytest
2+
on: [push, pull_request]
3+
4+
jobs:
5+
build:
6+
7+
runs-on: ubuntu-latest
8+
9+
steps:
10+
- uses: actions/checkout@v4
11+
- name: Set up Python 3.10
12+
uses: actions/setup-python@v5
13+
with:
14+
python-version: "3.10"
15+
- name: Install dependencies
16+
run: |
17+
python -m pip install --upgrade pip
18+
python -m pip install -e .[test]
19+
- name: Test with pytest
20+
run: |
21+
python -m pytest tests/

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ def forward(
10751075
return recon_faces, total_loss, loss_breakdown
10761076

10771077
@save_load(version = __version__)
1078-
class MeshTransformer(Module,PyTorchModelHubMixin):
1078+
class MeshTransformer(Module, PyTorchModelHubMixin):
10791079
@typecheck
10801080
def __init__(
10811081
self,
@@ -1094,12 +1094,13 @@ def __init__(
10941094
cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
10951095
dropout = 0.,
10961096
coarse_pre_gateloop_depth = 2,
1097+
coarse_adaptive_rmsnorm = False,
10971098
fine_pre_gateloop_depth = 2,
10981099
gateloop_use_heinsen = False,
10991100
fine_attn_depth = 2,
11001101
fine_attn_dim_head = 32,
11011102
fine_attn_heads = 8,
1102-
fine_cross_attend_text = False,
1103+
fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens
11031104
pad_id = -1,
11041105
num_sos_tokens = None,
11051106
condition_on_text = False,
@@ -1177,6 +1178,8 @@ def __init__(
11771178
# main autoregressive attention network
11781179
# attending to a face token
11791180

1181+
self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
1182+
11801183
self.decoder = Decoder(
11811184
dim = dim,
11821185
depth = attn_depth,
@@ -1185,6 +1188,8 @@ def __init__(
11851188
attn_flash = flash_attn,
11861189
attn_dropout = dropout,
11871190
ff_dropout = dropout,
1191+
use_adaptive_rmsnorm = coarse_adaptive_rmsnorm,
1192+
dim_condition = dim_text,
11881193
cross_attend = condition_on_text,
11891194
cross_attn_dim_context = cross_attn_dim_context,
11901195
cross_attn_num_mem_kv = cross_attn_num_mem_kv,
@@ -1458,6 +1463,11 @@ def forward_on_codes(
14581463
context_mask = text_mask
14591464
)
14601465

1466+
if self.coarse_adaptive_rmsnorm:
1467+
attn_context_kwargs.update(
1468+
condition = pooled_text_embed
1469+
)
1470+
14611471
# take care of codes that may be flattened
14621472

14631473
if codes.ndim > 2:

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.22'
1+
__version__ = '1.4.0'

setup.cfg

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[aliases]
2+
test=pytest
3+
4+
[tool:pytest]
5+
addopts = --verbose -s
6+
python_files = tests/*.py
7+
python_paths = "."

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@
3838
'torch_geometric',
3939
'tqdm',
4040
'vector-quantize-pytorch>=1.14.22',
41-
'x-transformers>=1.30.6',
41+
'x-transformers>=1.30.19',
42+
],
43+
setup_requires=[
44+
'pytest-runner',
45+
],
46+
tests_require=[
47+
'pytest'
4248
],
4349
classifiers=[
4450
'Development Status :: 4 - Beta',

tests/test_meshgpt.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
import torch
3+
4+
from meshgpt_pytorch import (
5+
MeshAutoencoder,
6+
MeshTransformer
7+
)
8+
9+
@pytest.mark.parametrize('adaptive_rmsnorm', (True, False))
10+
def test_readme(adaptive_rmsnorm):
11+
12+
autoencoder = MeshAutoencoder(
13+
num_discrete_coors = 128
14+
)
15+
16+
# mock inputs
17+
18+
vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3))
19+
faces = torch.randint(0, 121, (2, 64, 3)) # (batch, num faces, vertices (3))
20+
21+
# forward in the faces
22+
23+
loss = autoencoder(
24+
vertices = vertices,
25+
faces = faces
26+
)
27+
28+
loss.backward()
29+
30+
# after much training...
31+
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices
32+
33+
transformer = MeshTransformer(
34+
autoencoder,
35+
dim = 512,
36+
max_seq_len = 768,
37+
num_sos_tokens = 1,
38+
fine_cross_attend_text = True,
39+
text_cond_with_film = False,
40+
condition_on_text = True,
41+
coarse_adaptive_rmsnorm = adaptive_rmsnorm
42+
)
43+
44+
loss = transformer(
45+
vertices = vertices,
46+
faces = faces,
47+
texts = ['a high chair', 'a small teapot']
48+
)
49+
50+
loss.backward()
51+
52+
faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)

0 commit comments

Comments (0)