Hi!
Is there any limitation on running your code with a mask that has an odd number of channels?
I have a 3-channel mask, and when I run `fista.run(img)` I get:
```
/tmp/ipykernel_28645/373462679.py in <module>
----> 1 out_img = fista.run(img)

~/SpectralDiffuserCam/Python/fista_spectral_cupy.py in run(self, inputs)
    162         # Start FISTA loop
    163         for i in range(0,self.iters):
--> 164             vk, tk, xk, l = self.fista_update(vk, tk, xk, inputs)
    165 
    166             llist.append(l)

~/SpectralDiffuserCam/Python/fista_spectral_cupy.py in fista_update(self, vk, tk, xk, inputs)
    144         grads = self.Hadj(error)
    145 
--> 146         xup = self.prox(vk - 1/self.L * grads)
    147         tup = 1 + np.sqrt(1 + 4*tk**2)/2
    148         vup = xup + (tk-1)/tup * (xup-xk)

~/SpectralDiffuserCam/Python/fista_spectral_cupy.py in prox(self, x)
    113     def prox(self,x):
    114         if self.prox_method == 'tv':
--> 115             x = 0.5*(np.maximum(x,0) + tv.tv3dApproxHaar(x, self.tv_lambda/self.L, self.tv_lambdaw))
    116         if self.prox_method == 'native':
    117             x = np.maximum(x,0) + self.soft_thresh(x, self.tau)
...

cupy/core/_kernel.pyx in cupy.core._kernel.ufunc.__call__()

cupy/core/_kernel.pyx in cupy.core._kernel._get_out_args()

ValueError: Out shape is mismatched
```
However, when I run `fista = fista_spectral_numpy(psf, mask[:,:,0:-1])` (i.e. dropping the last channel, so the mask has 2 channels), it runs without a problem.
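My guess at the cause (an assumption; I have not checked it against the actual `tv3dApproxHaar` code): a Haar-based TV approximation pairs neighbouring samples (0,1), (2,3), ... along each axis, including the spectral one, and that pairing only lines up when the axis has even length. The NumPy sketch below uses a hypothetical helper, `haar_shrink_axis`, that does a single-level orthonormal Haar analysis, soft-thresholds the detail coefficients, and synthesizes back along one axis. With 2 channels it reconstructs the input exactly at zero threshold; with 3 channels the odd-indexed output slots cannot hold the coefficients, and NumPy raises a `ValueError`, which looks analogous to the CuPy shape error above.

```python
import numpy as np


def haar_shrink_axis(x, thresh, axis=-1):
    """One-level orthonormal Haar analysis, soft-threshold, synthesis.

    Hypothetical illustration only, NOT the repository's tv3dApproxHaar.
    It pairs samples (0,1), (2,3), ... along `axis`, so it assumes
    x.shape[axis] is even.
    """
    x = np.moveaxis(x, axis, -1)
    even, odd = x[..., 0::2], x[..., 1::2]
    # For an odd-length axis, `odd` has one fewer slice than `even`;
    # with 3 channels it has size 1 and silently broadcasts here.
    approx = (even + odd) / np.sqrt(2)  # low-pass coefficients
    detail = (even - odd) / np.sqrt(2)  # high-pass coefficients
    # Soft-threshold the detail coefficients (the TV-like shrinkage step).
    detail = np.sign(detail) * np.maximum(np.abs(detail) - thresh, 0)
    out = np.empty_like(x)
    out[..., 0::2] = (approx + detail) / np.sqrt(2)
    # For an odd-length axis this slot is one slice short -> ValueError.
    out[..., 1::2] = (approx - detail) / np.sqrt(2)
    return np.moveaxis(out, -1, axis)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    two = rng.random((4, 4, 2))
    print(np.allclose(haar_shrink_axis(two, 0.0, axis=2), two))  # True
    three = rng.random((4, 4, 3))
    try:
        haar_shrink_axis(three, 0.0, axis=2)
    except ValueError as err:
        print("3 channels:", err)
```

If this is what happens, it would also explain why trimming the mask to 2 channels avoids the error.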
Thanks