Commit 698242e

refactor: verbose error msgs for abstract modules
1 parent 60fb351 commit 698242e

File tree

cellseg_models_pytorch/modules/attention_modules.py
cellseg_models_pytorch/modules/base_modules.py
cellseg_models_pytorch/modules/conv_block.py
cellseg_models_pytorch/modules/self_attention/exact_attention.py

4 files changed: 64 additions & 21 deletions

cellseg_models_pytorch/modules/attention_modules.py

Lines changed: 7 additions & 1 deletion
@@ -325,7 +325,13 @@ def __init__(self, name: str, **kwargs) -> None:
             )
 
         if name is not None:
-            self.att = ATT_LOOKUP[name](**kwargs)
+            try:
+                self.att = ATT_LOOKUP[name](**kwargs)
+            except Exception as e:
+                raise Exception(
+                    "Encountered an error when trying to init chl attention function: "
+                    f"Attention(name='{name}'): {e.__class__.__name__}: {e}"
+                )
         else:
             self.att = Identity()
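The idiom is easiest to see in isolation. Below is a minimal, self-contained sketch of the same wrap-and-re-raise pattern; the two-entry registry and the build_attention helper are stand-ins for illustration, not code from this repo:

import torch.nn as nn

# Stand-in registry; the real ATT_LOOKUP maps attention names to modules.
ATT_LOOKUP = {"relu": nn.ReLU}

def build_attention(name: str, **kwargs) -> nn.Module:
    try:
        return ATT_LOOKUP[name](**kwargs)
    except Exception as e:
        # Re-raise with the wrapper name, the requested key and the original
        # exception type, instead of a bare error from inside the lookup target.
        raise Exception(
            "Encountered an error when trying to init chl attention function: "
            f"Attention(name='{name}'): {e.__class__.__name__}: {e}"
        )

# build_attention("relu", bogus=True) now fails with
# "Encountered an error when trying to init chl attention function: "
# "Attention(name='relu'): TypeError: ..." rather than a bare TypeError.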

cellseg_models_pytorch/modules/base_modules.py

Lines changed: 29 additions & 4 deletions
@@ -50,7 +50,13 @@ def __init__(self, name: str, **kwargs) -> None:
             try:
                 self.act = ACT_LOOKUP[name](**kwargs, inplace=True)
             except Exception:
-                self.act = ACT_LOOKUP[name](**kwargs)
+                try:
+                    self.act = ACT_LOOKUP[name](**kwargs)
+                except Exception as e:
+                    raise Exception(
+                        "Encountered an error when trying to init activation function: "
+                        f"Activation(name='{name}'): {e.__class__.__name__}: {e}"
+                    )
         else:
             self.act = Identity()
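The Activation change keeps the existing inplace-first fallback and only raises verbosely when both attempts fail. A self-contained sketch of that control flow (the two-entry ACT_LOOKUP below is a stand-in for the repo's registry):

import torch.nn as nn

ACT_LOOKUP = {"relu": nn.ReLU, "gelu": nn.GELU}

def build_activation(name: str, **kwargs) -> nn.Module:
    try:
        # First attempt: ask for an in-place variant where one exists (e.g. ReLU).
        return ACT_LOOKUP[name](**kwargs, inplace=True)
    except Exception:
        try:
            # Fallback for activations without an `inplace` flag (e.g. GELU).
            return ACT_LOOKUP[name](**kwargs)
        except Exception as e:
            # Only when both attempts fail is the verbose error raised.
            raise Exception(
                "Encountered an error when trying to init activation function: "
                f"Activation(name='{name}'): {e.__class__.__name__}: {e}"
            )

# build_activation("gelu") silently takes the fallback branch;
# build_activation("gelu", bogus=True) fails twice and re-raises verbosely.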

@@ -81,7 +87,13 @@ def __init__(self, name: str, **kwargs) -> None:
             )
 
         if name is not None:
-            self.norm = NORM_LOOKUP[name](**kwargs)
+            try:
+                self.norm = NORM_LOOKUP[name](**kwargs)
+            except Exception as e:
+                raise Exception(
+                    "Encountered an error when trying to init normalization function: "
+                    f"Norm(name='{name}'): {e.__class__.__name__}: {e}"
+                )
         else:
             self.norm = Identity()

@@ -118,7 +130,14 @@ def __init__(self, name: str, scale_factor: int = 2, **kwargs) -> None:
             kwargs["align_corners"] = True
 
         kwargs["scale_factor"] = scale_factor
-        self.up = UP_LOOKUP[name](**kwargs)
+
+        try:
+            self.up = UP_LOOKUP[name](**kwargs)
+        except Exception as e:
+            raise Exception(
+                "Encountered an error when trying to init upsampling function: "
+                f"Up(name='{name}'): {e.__class__.__name__}: {e}"
+            )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass for the upsampling function."""
@@ -146,7 +165,13 @@ def __init__(self, name: str, **kwargs) -> None:
                 f"Illegal convolution method given. Allowed: {allowed}. Got: '{name}'"
             )
 
-        self.conv = CONV_LOOKUP[name](**kwargs)
+        try:
+            self.conv = CONV_LOOKUP[name](**kwargs)
+        except Exception as e:
+            raise Exception(
+                "Encountered an error when trying to init convolution function: "
+                f"Conv(name='{name}'): {e.__class__.__name__}: {e}"
+            )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass for the convolution function."""

cellseg_models_pytorch/modules/conv_block.py

Lines changed: 21 additions & 15 deletions
@@ -161,21 +161,27 @@ def __init__(
             f"Illegal `short_skip` given. Got: '{short_skip}'. Allowed: {allowed}."
         )
 
-        self.block = CONVBLOCK_LOOKUP[name](
-            in_channels=in_channels,
-            out_channels=out_channels,
-            same_padding=same_padding,
-            normalization=normalization,
-            activation=activation,
-            convolution=convolution,
-            preactivate=preactivate,
-            kernel_size=kernel_size,
-            groups=groups,
-            bias=bias,
-            attention=attention,
-            preattend=preattend,
-            **kwargs,
-        )
+        try:
+            self.block = CONVBLOCK_LOOKUP[name](
+                in_channels=in_channels,
+                out_channels=out_channels,
+                same_padding=same_padding,
+                normalization=normalization,
+                activation=activation,
+                convolution=convolution,
+                preactivate=preactivate,
+                kernel_size=kernel_size,
+                groups=groups,
+                bias=bias,
+                attention=attention,
+                preattend=preattend,
+                **kwargs,
+            )
+        except Exception as e:
+            raise Exception(
+                "Encountered an error when trying to init ConvBlock module: "
+                f"ConvBlock(name='{name}'): {e.__class__.__name__}: {e}"
+            )
 
         self.downsample = None
         if short_skip == "residual" and in_channels != self.out_channels:
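At a call site, the effect is that a mis-configured sub-module now reports which block failed to build. A hedged usage sketch; the import path, the "basic" block name, and the bogus normalization value are assumptions for illustration, not taken from this diff:

from cellseg_models_pytorch.modules import ConvBlock  # assumed import path

try:
    block = ConvBlock(
        name="basic",                  # assumed to be a valid CONVBLOCK_LOOKUP key
        in_channels=32,
        out_channels=64,
        normalization="not_a_norm",    # deliberately invalid, to trigger the wrap
    )
except Exception as e:
    print(e)
    # Expected shape of the message after this commit (assuming the inner
    # Norm lookup rejects the name):
    # "Encountered an error when trying to init ConvBlock module: "
    # "ConvBlock(name='basic'): ValueError: Illegal normalization method given ..."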

cellseg_models_pytorch/modules/self_attention/exact_attention.py

Lines changed: 7 additions & 1 deletion
@@ -193,7 +193,13 @@ def forward(
             The self-attention matrix. Same shape as inputs.
         """
         if self.self_attention == "memeff":
-            attn = memory_efficient_attention(query, key, value)
+            if all([query.is_cuda, key.is_cuda, value.is_cuda]):
+                attn = memory_efficient_attention(query, key, value)
+            else:
+                raise RuntimeError(
+                    "`xformers.ops.memory_efficient_attention` is only implemented "
+                    "for cuda. Make sure your inputs & model devices are set to cuda."
+                )
         elif self.self_attention == "flash":
             raise NotImplementedError
         elif self.self_attention == "slice":
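The guard itself is a general pattern: verify device placement before dispatching to a CUDA-only kernel. A minimal sketch, assuming torch and xformers are installed (the helper name and wiring are illustrative, not the repo's code):

import torch
from xformers.ops import memory_efficient_attention  # CUDA-only kernel

def memeff_attention_or_raise(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Dispatch to xformers' memory-efficient attention only for CUDA tensors."""
    if all(t.is_cuda for t in (query, key, value)):
        return memory_efficient_attention(query, key, value)
    raise RuntimeError(
        "`xformers.ops.memory_efficient_attention` is only implemented "
        "for cuda. Make sure your inputs & model devices are set to cuda."
    )

# Callers move tensors to the GPU first, e.g.:
# q, k, v = (t.to("cuda") for t in (q, k, v))
# attn = memeff_attention_or_raise(q, k, v)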
