Skip to content
This repository was archived by the owner on Feb 6, 2020. It is now read-only.

Cls error masking #67

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/config.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# specification file of network architecture
fnet_spec = ../networks/N4.znn
# file of data spec
fdata_spec = ../dataset/Piriform/dataset.spec
fdata_spec = ../dataset/ISBI2012/dataset.spec
# number of threads. if <=0, the thread number will be equal to
# the number of concurrent threads supported by the implementation.
num_threads = 0
Expand All @@ -18,10 +18,10 @@ logging = no
# saved network file name. will automatically add iteration number
# saved file name example: net_21000.h5, net_current.h5
# the net_current.h5 will always be the latest network
train_net = ../experiments/piriform/N4/net.h5
train_net = ../experiments/ISBI/N4/net.h5
# sample ID range for train
# example: 2-3,7
train_range = 2
train_range = 1
# sample ID range for validate/test during training
# example: 1,4-6,8
test_range = 1
Expand Down
3 changes: 0 additions & 3 deletions python/front_end/zsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def __init__(self, config, pars, sample_id, net, \
if not is_forward:
print "\ncreate label image class..."
for name,setsz_out in self.setsz_outs.iteritems():
#Allowing for users to abstain from specifying labels
if not config.has_option(self.sec_name, name):
continue
#Finding the section of the config file
imid = config.getint(self.sec_name, name)
imsec_name = "label%d" % (imid,)
Expand Down
4 changes: 2 additions & 2 deletions python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def _single_test(net, pars, sample):
props = net.forward( vol_ins )

# cost function and accumulate errors
props, err, grdts = pars['cost_fn']( props, lbl_outs )
props, err, grdts = pars['cost_fn']( props, lbl_outs, msks )
# pixel classification error
cls = cost_fn.get_cls(props, lbl_outs)
cls = cost_fn.get_cls(props, lbl_outs, msks)
# rand error
re = pyznn.get_rand_error(props.values()[0], lbl_outs.values()[0])

Expand Down
8 changes: 4 additions & 4 deletions python/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ def main( args ):
# cost function and accumulate errors
props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks )
err += cerr
cls += cost_fn.get_cls(props, lbl_outs)
cls += cost_fn.get_cls(props, lbl_outs, msks)
# compute rand error
if pars['is_debug']:
assert not np.all(lbl_outs.values()[0]==0)
re += pyznn.get_rand_error( props.values()[0], lbl_outs.values()[0] )
num_mask_voxels += utils.sum_over_dict(msks)
num_mask_voxels += utils.get_total_num_mask(msks)

# check whether there is a NaN here!
if pars['is_debug']:
Expand Down Expand Up @@ -182,8 +182,8 @@ def main( args ):
err = err / vn / pars['Num_iter_per_show']
cls = cls / vn / pars['Num_iter_per_show']
else:
err = err / num_mask_voxels / pars['Num_iter_per_show']
cls = cls / num_mask_voxels / pars['Num_iter_per_show']
err = err / num_mask_voxels
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nicholasturner1 could you explain this change a little bit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. The num_mask_voxels variable accumulates over the number of iterations anyway, so we don't need to further normalize it by the number of iterations.

For example, if there were 10 rounds of 10 mask voxels, the num_mask_voxels value would be 100, and diving by 10 at the end obscures the resulting error value.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this makes sense, tks.

cls = cls / num_mask_voxels
re = re / pars['Num_iter_per_show']
lc.append_train(i, err, cls, re)

Expand Down
13 changes: 13 additions & 0 deletions python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,19 @@ def get_total_num(outputs):
n = n + np.prod(sz)
return n


def get_total_num_mask(masks, props=None):
'''Returns the total number of active voxels in a forward pass'''
s = 0
for name, mask in masks.iteritems():
#full mask can correspond to empty array
if mask.size == 0 and props is not None:
s += props[name].size
else:
s += np.count_nonzero(mask)
return s


def sum_over_dict(dict_vol):
s = 0
for name, vol in dict_vol.iteritems():
Expand Down