diff --git a/python/config.cfg b/python/config.cfg index efc184e7..4bc44384 100644 --- a/python/config.cfg +++ b/python/config.cfg @@ -3,7 +3,7 @@ # specification file of network architecture fnet_spec = ../networks/N4.znn # file of data spec -fdata_spec = ../dataset/Piriform/dataset.spec +fdata_spec = ../dataset/ISBI2012/dataset.spec # number of threads. if <=0, the thread number will be equal to # the number of concurrent threads supported by the implementation. num_threads = 0 @@ -18,10 +18,10 @@ logging = no # saved network file name. will automatically add iteration number # saved file name example: net_21000.h5, net_current.h5 # the net_current.h5 will always be the latest network -train_net = ../experiments/piriform/N4/net.h5 +train_net = ../experiments/ISBI/N4/net.h5 # sample ID range for train # example: 2-3,7 -train_range = 2 +train_range = 1 # sample ID range for validate/test during training # example: 1,4-6,8 test_range = 1 diff --git a/python/front_end/zsample.py b/python/front_end/zsample.py index 156cdd99..54f517c2 100644 --- a/python/front_end/zsample.py +++ b/python/front_end/zsample.py @@ -65,9 +65,6 @@ def __init__(self, config, pars, sample_id, net, \ if not is_forward: print "\ncreate label image class..." for name,setsz_out in self.setsz_outs.iteritems(): - #Allowing for users to abstain from specifying labels - if not config.has_option(self.sec_name, name): - continue #Finding the section of the config file imid = config.getint(self.sec_name, name) imsec_name = "label%d" % (imid,) diff --git a/python/test.py b/python/test.py index bf132951..aa32b244 100644 --- a/python/test.py +++ b/python/test.py @@ -16,9 +16,9 @@ def _single_test(net, pars, sample): props = net.forward( vol_ins ) # cost function and accumulate errors - props, err, grdts = pars['cost_fn']( props, lbl_outs ) + props, err, grdts = pars['cost_fn']( props, lbl_outs, msks ) # pixel classification error - cls = cost_fn.get_cls(props, lbl_outs) + cls = cost_fn.get_cls(props, lbl_outs, msks) # rand error re = pyznn.get_rand_error(props.values()[0], lbl_outs.values()[0]) diff --git a/python/train.py b/python/train.py index e848786c..1717c410 100644 --- a/python/train.py +++ b/python/train.py @@ -132,12 +132,12 @@ def main( args ): # cost function and accumulate errors props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks ) err += cerr - cls += cost_fn.get_cls(props, lbl_outs) + cls += cost_fn.get_cls(props, lbl_outs, msks) # compute rand error if pars['is_debug']: assert not np.all(lbl_outs.values()[0]==0) re += pyznn.get_rand_error( props.values()[0], lbl_outs.values()[0] ) - num_mask_voxels += utils.sum_over_dict(msks) + num_mask_voxels += utils.get_total_num_mask(msks) # check whether there is a NaN here! if pars['is_debug']: @@ -182,8 +182,8 @@ def main( args ): err = err / vn / pars['Num_iter_per_show'] cls = cls / vn / pars['Num_iter_per_show'] else: - err = err / num_mask_voxels / pars['Num_iter_per_show'] - cls = cls / num_mask_voxels / pars['Num_iter_per_show'] + err = err / num_mask_voxels + cls = cls / num_mask_voxels re = re / pars['Num_iter_per_show'] lc.append_train(i, err, cls, re) diff --git a/python/utils.py b/python/utils.py index a6436f68..bf7ed5c4 100644 --- a/python/utils.py +++ b/python/utils.py @@ -211,6 +211,19 @@ def get_total_num(outputs): n = n + np.prod(sz) return n + +def get_total_num_mask(masks, props=None): + '''Returns the total number of active voxels in a forward pass''' + s = 0 + for name, mask in masks.iteritems(): + #full mask can correspond to empty array + if mask.size == 0 and props is not None: + s += props[name].size + else: + s += np.count_nonzero(mask) + return s + + def sum_over_dict(dict_vol): s = 0 for name, vol in dict_vol.iteritems():