Hi @Requiem8, see here.
Should it be `input_.dtype` instead of `input_.dtype()`?
From the 9th code cell of the `extending_functionality` tutorial, running on Google Colab with a T4 GPU:
```python
trainer, model = build_trainer(rpu_config, log="two_pass")
print(model)
fit_model(trainer, model)
plot_loss(trainer, "FP update with non-ideal two-pass forward.")
```
Output:

```
LitAnalogModel(
  (analog_model): AnalogSequential(
    (0): AnalogLinearMapped(
      in_features=784, out_features=256, bias=True, TwoPassTorchInferenceRPUConfig
      (analog_module): TileModuleArray(
        (array): ModuleList(
          (0-1): 2 x ModuleList(
            (0): TorchInferenceTile(
              (tile): TwoPassTorchSimulatorTile(256, 392, cpu)
            )
          )
        )
      )
    )
    (1): Sigmoid()
    (2): AnalogLinearMapped(
      in_features=256, out_features=128, bias=True, TwoPassTorchInferenceRPUConfig
      (analog_module): TorchInferenceTile(
        (tile): TwoPassTorchSimulatorTile(128, 256, cpu)
      )
    )
    (3): Sigmoid()
    (4): AnalogLinearMapped(
      in_features=128, out_features=10, bias=True, TwoPassTorchInferenceRPUConfig
      (analog_module): TorchInferenceTile(
        (tile): TwoPassTorchSimulatorTile(10, 128, cpu)
      )
    )
    (5): LogSoftmax(dim=1)
  )
)
Sanity Checking DataLoader 0:   0%   0/2 [00:00<?, ?it/s]

TypeError Traceback (most recent call last)
in <cell line: 3>()
      1 trainer, model = build_trainer(rpu_config, log="two_pass")
      2 print(model)
----> 3 fit_model(trainer, model)
      4 plot_loss(trainer, "FP update with non-ideal two-pass forward.")

27 frames
/usr/local/lib/python3.10/dist-packages/aihwkit/simulator/tiles/analog_mvm.py in matmul(cls, weight, input_, io_pars, trans, is_test, **fwd_pars)
     92 ):
     93 # - Shortcut, output would be all zeros
---> 94 return zeros(size=out_size, device=input_.device, dtype=input_.dtype())
     95
     96 if isinstance(nm_scale_values, Tensor):

TypeError: 'torch.dtype' object is not callable
```
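For reference, here is a minimal sketch in plain PyTorch (no aihwkit needed) that reproduces the error and shows the change proposed above. `Tensor.dtype` is an attribute holding a `torch.dtype` object, so calling it raises the `TypeError` from the traceback, while passing it without parentheses works. The input tensor, its shape and dtype, and the `out_size` value are arbitrary stand-ins chosen for illustration; they are not taken from aihwkit.

```python
import torch

# Hypothetical stand-ins: any tensor and output shape will do here.
input_ = torch.randn(4, 784, dtype=torch.float64)
out_size = (4, 256)

# `Tensor.dtype` is an attribute whose value is a torch.dtype object.
print(isinstance(input_.dtype, torch.dtype))  # True

# Calling it, as line 94 of analog_mvm.py does, reproduces the error.
try:
    input_.dtype()
except TypeError as err:
    print(err)  # 'torch.dtype' object is not callable

# The all-zeros shortcut with the attribute passed directly (no parentheses).
out = torch.zeros(out_size, device=input_.device, dtype=input_.dtype)
print(out.shape, out.dtype)  # torch.Size([4, 256]) torch.float64
```

With the parentheses removed, the shortcut returns a zero tensor whose device and dtype match the input, which is what that branch appears intended to do.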