diff --git a/.travis.yml b/.travis.yml
index 9ccc4436..60d771eb 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -53,19 +53,19 @@ script:
   # check patch using single precision, this check will create the net_current.h5 file for testing loading
   - cd python
   # test affinity training patch, network initialization, network with even field of view
-  - python train.py -c ../testsuit/affinity/config.cfg -d single -k yes
+  - python ztrain.py -c ../testsuit/affinity/config.cfg -d single -k yes
   # test boundary map training, patch matching, network initialization
-  - python train.py -c ../testsuit/boundary/config.cfg -d single -k yes
+  - python ztrain.py -c ../testsuit/boundary/config.cfg -d single -k yes
   # second check to test the network loading
-  - python train.py -c ../testsuit/boundary/config.cfg -d single -k yes
+  - python ztrain.py -c ../testsuit/boundary/config.cfg -d single -k yes
   # check the double precision
   # compile the core with double precision
   - cd core; make double -j 4
   # return to `python`
   - cd ..
-  - python train.py -c ../testsuit/boundary/config.cfg -d double -k yes
+  - python ztrain.py -c ../testsuit/boundary/config.cfg -d double -k yes
   # test forward pass
-  - python forward.py -c ../testsuit/forward/config.cfg
+  - python zforward.py -c ../testsuit/forward/config.cfg
   # return to root directory
   - cd ..
diff --git a/python/tifffile.py b/python/emirt/tifffile.py
similarity index 100%
rename from python/tifffile.py
rename to python/emirt/tifffile.py
diff --git a/python/front_end/__init__.py b/python/front_end/__init__.py
index 98d85cd3..65575ce6 100644
--- a/python/front_end/__init__.py
+++ b/python/front_end/__init__.py
@@ -3,4 +3,9 @@
 import zlog
 import znetio
 import zsample
-import zshow
\ No newline at end of file
+import zshow
+import zutils
+import zaws
+import zcost
+import ztest
+import zcheck
\ No newline at end of file
diff --git a/python/front_end/zaws.py b/python/front_end/zaws.py
new file mode 100644
index 00000000..754cbe0b
--- /dev/null
+++ b/python/front_end/zaws.py
@@ -0,0 +1,31 @@
+import os
+
+def s3download(s3fname, tmpdir="/tmp/"):
+    """
+    download an AWS S3 file
+
+    params:
+    - s3fname: string, file name in s3
+    - tmpdir: string, temporary directory
+
+    return:
+    - lcfname: string, local file name
+    """
+    if s3fname and "s3://" in s3fname:
+        # base name
+        bn = os.path.basename(s3fname)
+        # local directory, mirroring the s3 path
+        lcdir = os.path.dirname( s3fname )
+        lcdir = lcdir.replace("s3://", "")
+        lcdir = os.path.join(tmpdir, lcdir)
+        # make the directory if it does not exist yet
+        if not os.path.exists(lcdir):
+            os.makedirs(lcdir)
+        # local file name inside the mirrored directory
+        lcfname = os.path.join( lcdir, bn )
+        # copy file from s3
+        os.system("aws s3 cp {} {}".format(s3fname, lcfname))
+        return lcfname
+    else:
+        # this is not an s3 file, just a local file
+        return s3fname
\ No newline at end of file
diff --git a/python/zcheck.py b/python/front_end/zcheck.py
similarity index 98%
rename from python/zcheck.py
rename to python/front_end/zcheck.py
index 90550027..57e8c3ca 100644
--- a/python/zcheck.py
+++ b/python/front_end/zcheck.py
@@ -5,7 +5,7 @@
 """
 import os
 import numpy as np
-import utils
+import zutils
 import emirt
 
 def check_gradient(pars, net, smp, h=0.00001):
@@ -21,7 +21,7 @@ def check_gradient(pars, net, smp, h=0.00001):
 
     # numerical gradient
     # apply the transformations in memory rather than array view
-    vol_ins = utils.make_continuous(vol_ins)
+    vol_ins = zutils.make_continuous(vol_ins)
    # shift the input to compute the analytical gradient
     vol_ins1 = dict()
     vol_ins2 = dict()
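The new `zaws.s3download` helper is the basis for the S3 support threaded through the rest of this patch. A minimal usage sketch, assuming the repo's `python/` directory is on the path; bucket and file names are made up, and the AWS CLI must be installed and configured for the copy to succeed:

```python
from front_end import zaws

# an s3 path is mirrored under /tmp/ and the local copy's path is returned
net_spec = zaws.s3download("s3://my-bucket/nets/n4.cfg")

# a plain local path passes straight through unchanged
cfg = zaws.s3download("../testsuit/boundary/config.cfg")
```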
diff --git a/python/front_end/zconfig.py b/python/front_end/zconfig.py
index aec16105..bdda2ba2 100644
--- a/python/front_end/zconfig.py
+++ b/python/front_end/zconfig.py
@@ -10,9 +10,10 @@
 import ConfigParser
 import numpy as np
 import os
-import cost_fn
-import utils
+import zcost
+import zutils
 from emirt import volume_util
+import zaws
 
 def parser(conf_fname):
     # parse config file to get parameters
@@ -78,11 +79,11 @@ def parse_cfg( conf_fname ):
     #TRAINING OPTIONS
     #Samples to use for training
     if config.has_option('parameters', 'train_range'):
-        pars['train_range'] = utils.parseIntSet( config.get('parameters', 'train_range') )
+        pars['train_range'] = zutils.parseIntSet( config.get('parameters', 'train_range') )
 
     #Samples to use for cross-validation
     if config.has_option('parameters', 'test_range'):
-        pars['test_range'] = utils.parseIntSet( config.get('parameters', 'test_range') )
+        pars['test_range'] = zutils.parseIntSet( config.get('parameters', 'test_range') )
 
     #Learning Rate
     if config.has_option('parameters', 'eta'):
         pars['eta'] = config.getfloat('parameters', 'eta')
@@ -241,7 +242,7 @@ def parse_cfg( conf_fname ):
 
     #FULL FORWARD PASS PARAMETERS
     #Which samples to use
-    pars['forward_range'] = utils.parseIntSet( config.get('parameters', 'forward_range') )
+    pars['forward_range'] = zutils.parseIntSet( config.get('parameters', 'forward_range') )
     #Which network file to load
     pars['forward_net'] = config.get('parameters', 'forward_net')
     #Output Patch Size
@@ -258,26 +259,36 @@ def autoset_pars(pars):
         # automatic choosing of cost function
         if 'boundary' in pars['out_type']:
             pars['cost_fn_str'] = 'softmax_loss'
-            pars['cost_fn'] = cost_fn.softmax_loss
+            pars['cost_fn'] = zcost.softmax_loss
         elif 'affin' in pars['out_type']:
             pars['cost_fn_str'] = 'binomial_cross_entropy'
-            pars['cost_fn'] = cost_fn.binomial_cross_entropy
+            pars['cost_fn'] = zcost.binomial_cross_entropy
         elif 'semantic' in pars['out_type']:
             pars['cost_fn_str'] = 'softmax_loss'
-            pars['cost_fn'] = cost_fn.softmax_loss
+            pars['cost_fn'] = zcost.softmax_loss
         else:
             raise NameError("no matching cost function for out_type!")
     elif "square-square" in pars['cost_fn_str']:
-        pars['cost_fn'] = cost_fn.square_square_loss
+        pars['cost_fn'] = zcost.square_square_loss
     elif "square" in pars['cost_fn_str']:
-        pars['cost_fn'] = cost_fn.square_loss
+        pars['cost_fn'] = zcost.square_loss
     elif "binomial" in pars['cost_fn_str']:
-        pars['cost_fn'] = cost_fn.binomial_cross_entropy
+        pars['cost_fn'] = zcost.binomial_cross_entropy
     elif "softmax" in pars['cost_fn_str']:
-        pars['cost_fn'] = cost_fn.softmax_loss
+        pars['cost_fn'] = zcost.softmax_loss
     else:
         raise NameError('unknown type of cost function')
 
+    # AWS S3 file handling
+    pars['fnet_spec'] = zaws.s3download( pars['fnet_spec'] )
+    pars['fdata_spec'] = zaws.s3download( pars['fdata_spec'] )
+    # local file name
+    if "s3://" in pars['train_net_prefix']:
+        # copy the path as a backup
+        pars['s3_train_net_prefix'] = pars['train_net_prefix']
+        bn = os.path.basename( pars['train_net_prefix'] )
+        # replace with local path
+        pars['train_net_prefix'] = "/tmp/{}".format(bn)
     return pars
 
 def check_pars(pars):
@@ -353,6 +364,9 @@ def autoset_dspec(pars, dspec):
 
 # parse args
 def parse_args(args):
+    # s3 to local
+    args['config'] = zaws.s3download( args['config'] )
+    args['seed'] = zaws.s3download( args['seed'] )
     #%% parameters
     if not os.path.exists( args['config'] ):
         raise NameError("config file not exist!")
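The `autoset_pars` hunk swaps an `s3://` output prefix for a local one while remembering the original destination for later uploads. A standalone sketch of that logic with a made-up bucket, using only the stdlib:

```python
import os

# made-up example value for pars['train_net_prefix']
pars = {'train_net_prefix': "s3://my-bucket/experiments/net"}

if "s3://" in pars['train_net_prefix']:
    # keep the s3 destination for the uploads done in zutils
    pars['s3_train_net_prefix'] = pars['train_net_prefix']
    # ... but write checkpoints locally during training
    bn = os.path.basename(pars['train_net_prefix'])
    pars['train_net_prefix'] = "/tmp/{}".format(bn)

assert pars['train_net_prefix'] == "/tmp/net"
```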
diff --git a/python/cost_fn.py b/python/front_end/zcost.py
similarity index 92%
rename from python/cost_fn.py
rename to python/front_end/zcost.py
index 98e82060..29ad340f 100644
--- a/python/cost_fn.py
+++ b/python/front_end/zcost.py
@@ -5,7 +5,7 @@
 """
 import numpy as np
 import emirt
-import utils
+import zutils
 from core import pyznn
 
 def get_cls(props, lbls, mask=None):
@@ -25,8 +25,8 @@ def get_cls(props, lbls, mask=None):
     c = 0.0
 
     #Applying mask if it exists
-    props = utils.mask_dict_vol(props, mask)
-    lbls = utils.mask_dict_vol(lbls, mask)
+    props = zutils.mask_dict_vol(props, mask)
+    lbls = zutils.mask_dict_vol(lbls, mask)
 
     for name, prop in props.iteritems():
         lbl = lbls[name]
@@ -53,8 +53,8 @@ def square_loss(props, lbls, mask=None):
     err = 0
 
     #Applying mask if it exists
-    props = utils.mask_dict_vol(props, mask)
-    lbls = utils.mask_dict_vol(lbls, mask)
+    props = zutils.mask_dict_vol(props, mask)
+    lbls = zutils.mask_dict_vol(lbls, mask)
 
     for name, prop in props.iteritems():
         lbl = lbls[name]
@@ -74,8 +74,8 @@ def square_square_loss(props, lbls, mask=None, margin=0.2):
     error = 0
 
     #Applying mask if it exists
-    props = utils.mask_dict_vol(props, mask)
-    lbls = utils.mask_dict_vol(lbls, mask)
+    props = zutils.mask_dict_vol(props, mask)
+    lbls = zutils.mask_dict_vol(lbls, mask)
 
     for name, propagation in props.iteritems():
         lbl = lbls[name]
@@ -120,8 +120,8 @@ def binomial_cross_entropy(props, lbls, mask=None):
         entropy[name] = -lbl*np.log(prop) - (1-lbl)*np.log(1-prop)
 
     #Applying mask if it exists
-    grdts = utils.mask_dict_vol(grdts, mask)
-    entropy = utils.mask_dict_vol(entropy, mask)
+    grdts = zutils.mask_dict_vol(grdts, mask)
+    entropy = zutils.mask_dict_vol(entropy, mask)
 
     for name, vol in entropy.iteritems():
         err += np.sum( vol )
@@ -186,7 +186,7 @@ def multinomial_cross_entropy(props, lbls, mask=None):
         entropy[name] = -lbl * np.log(prop)
 
     #Applying mask if it exists
-    entropy = utils.mask_dict_vol(entropy, mask)
+    entropy = zutils.mask_dict_vol(entropy, mask)
 
     for name, vol in entropy.iteritems():
         cost += np.sum( vol )
@@ -399,7 +399,7 @@ def get_grdt(pars, history, props, lbl_outs, msks, wmsks, vn):
 #        history['re'] += pyznn.get_rand_error( props.values(), lbl_outs.values() )
 #        print 're: {}'.format( history['re'] )
 
-    num_mask_voxels = utils.sum_over_dict(msks)
+    num_mask_voxels = zutils.sum_over_dict(msks)
     if num_mask_voxels > 0:
         history['err'] += cerr / num_mask_voxels
         history['cls'] += get_cls(props, lbl_outs) / num_mask_voxels
@@ -408,24 +408,24 @@ def get_grdt(pars, history, props, lbl_outs, msks, wmsks, vn):
         history['cls'] += get_cls(props, lbl_outs) / vn
 
     if pars['is_debug']:
-        c2 = utils.check_dict_nan(lbl_outs)
-        c3 = utils.check_dict_nan(msks)
-        c4 = utils.check_dict_nan(wmsks)
-        c5 = utils.check_dict_nan(props)
-        c6 = utils.check_dict_nan(grdts)
+        c2 = zutils.check_dict_nan(lbl_outs)
+        c3 = zutils.check_dict_nan(msks)
+        c4 = zutils.check_dict_nan(wmsks)
+        c5 = zutils.check_dict_nan(props)
+        c6 = zutils.check_dict_nan(grdts)
         if not ( c2 and c3 and c4 and c5 and c6):
             # stop training
             raise NameError('nan encountered!')
 
     # gradient reweighting
-    grdts = utils.dict_mul( grdts, msks )
+    grdts = zutils.dict_mul( grdts, msks )
     if pars['rebalance_mode']:
-        grdts = utils.dict_mul( grdts, wmsks )
+        grdts = zutils.dict_mul( grdts, wmsks )
 
     if pars['is_malis'] :
-        malis_weights, rand_errors, num_non_bdr = cost_fn.malis_weight(pars, props, lbl_outs)
-        grdts = utils.dict_mul(grdts, malis_weights)
-        dmc, dme = utils.get_malis_cost( props, lbl_outs, malis_weights )
+        malis_weights, rand_errors, num_non_bdr = malis_weight(pars, props, lbl_outs)
+        grdts = zutils.dict_mul(grdts, malis_weights)
+        dmc, dme = zutils.get_malis_cost( props, lbl_outs, malis_weights )
         if num_mask_voxels > 0:
             history['mc'] += dmc.values()[0] / num_mask_voxels
             history['me'] += dme.values()[0] / num_mask_voxels
@@ -433,5 +433,5 @@ def get_grdt(pars, history, props, lbl_outs, msks, wmsks, vn):
             history['mc'] += dmc.values()[0] / vn
             history['me'] += dme.values()[0] / vn
 
-    grdts = utils.make_continuous(grdts)
+    grdts = zutils.make_continuous(grdts)
     return props, grdts, history
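For reference, a toy check of the binomial cross entropy formula that `zcost.binomial_cross_entropy` applies per output volume, here on plain numpy arrays rather than the `props`/`lbls` dictionaries (values chosen arbitrarily):

```python
import numpy as np

prop = np.array([0.9, 0.2])   # network output, strictly inside (0, 1)
lbl  = np.array([1.0, 0.0])   # binary ground truth
entropy = -lbl * np.log(prop) - (1 - lbl) * np.log(1 - prop)
err = np.sum(entropy)         # ~0.105 + ~0.223 = ~0.329
```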
diff --git a/python/front_end/zdataset.py b/python/front_end/zdataset.py
index 7165dae0..155f4bff 100644
--- a/python/front_end/zdataset.py
+++ b/python/front_end/zdataset.py
@@ -10,7 +10,7 @@
 import sys
 import numpy as np
 import emirt
-import utils
+import zutils
 
 class CDataset(object):
 
@@ -383,7 +383,7 @@ def __init__(self, dspec, pars, sec_name, \
         if pars['is_bd_mirror']:
             if self.pars['is_debug']:
                 print "data shape before mirror: ", self.data.shape
-            self.data = utils.boundary_mirror(self.data, self.mapsz)
+            self.data = zutils.boundary_mirror(self.data, self.mapsz)
             #Modifying the deviation boundaries for the modified dataset
             self.calculate_sizes( )
         if self.pars['is_debug']:
@@ -450,9 +450,6 @@ def __init__(self, dspec, pars, sec_name, outsz, setsz, mapsz ):
         self.sublbl = None
         self.submsk = None
 
-        # rename data as lbl
-        self.lbl = self.data
-
         # deal with mask
         self.msk = np.array([])
         if dspec[sec_name].has_key('fmasks'):
@@ -498,7 +495,13 @@ def get_dataset(self):
         return the whole label for examination
         """
         return self.data
 
+    def get_lbl(self):
+        return self.data
+
+    def get_msk(self):
+        return self.msk
+
     def get_candidate_loc( self, low, high ):
         """
         find the candidate location of subvolume
diff --git a/python/front_end/zlog.py b/python/front_end/zlog.py
index 28b62fc1..20e59ca4 100644
--- a/python/front_end/zlog.py
+++ b/python/front_end/zlog.py
@@ -10,8 +10,7 @@
 import ConfigParser
 import numpy as np
 import os
-import cost_fn
-import utils
+import zutils
 from emirt import volume_util
 
 
@@ -28,7 +27,7 @@ def record_config_file(params=None, config_filename=None, net_save_filename=None
 
     #Need to specify either a params object, or all of the other optional args
     #"ALL" optional args excludes train
-    utils.assert_arglist(params,
+    zutils.assert_arglist(params,
             [config_filename, net_save_filename]
             )
 
@@ -57,7 +56,7 @@ def record_config_file(params=None, config_filename=None, net_save_filename=None
 
     #Deriving destination filename information
     if timestamp is None:
-        timestamp = utils.timestamp()
+        timestamp = zutils.timestamp()
     mode = "train" if train else "forward"
 
     #Actually saving
@@ -76,7 +75,7 @@ def make_logfile_name(params=None, net_save_filename=None, timestamp = None, tra
     '''
 
     #Need to specify either a params object, or the net save prefix
-    utils.assert_arglist(params,
+    zutils.assert_arglist(params,
             [net_save_filename])
 
     if params is not None:
@@ -94,7 +93,7 @@ def make_logfile_name(params=None, net_save_filename=None, timestamp = None, tra
     assert(save_prefix_valid)
 
     if timestamp is None:
-        timestamp = utils.timestamp()
+        timestamp = zutils.timestamp()
 
     mode = "train" if train else "forward"
     directory_name = os.path.dirname( save_prefix )
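The `zdataset.py` hunks remove the `self.lbl = self.data` alias in favor of accessor methods; `zsample.py` (next file) is updated to call them. A self-contained sketch of the pattern, with an illustrative class name rather than the real `CDataset` subclass:

```python
import numpy as np

class LabelDataset(object):
    """minimal stand-in for the label dataset object"""
    def __init__(self, data, msk=None):
        self.data = data
        self.msk = msk if msk is not None else np.array([])

    def get_lbl(self):
        # the raw data volume doubles as the label
        return self.data

    def get_msk(self):
        return self.msk

out = LabelDataset(np.zeros((1, 4, 4, 4)))
lbl, msk = out.get_lbl(), out.get_msk()   # was: out.lbl, out.msk
```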
diff --git a/python/front_end/zsample.py b/python/front_end/zsample.py
index e163fa71..419d2625 100644
--- a/python/front_end/zsample.py
+++ b/python/front_end/zsample.py
@@ -13,7 +13,7 @@
 import sys
 import numpy as np
 import emirt
-import utils
+import zutils
 from zdataset import *
 import os
 
@@ -111,11 +111,11 @@ def _data_aug(self):
         if self.pars['is_data_aug']:
             rft = (np.random.rand(4)>0.5)
             for key, subinput in self.subimgs.iteritems():
-                self.subimgs[key] = utils.data_aug_transform(subinput, rft )
+                self.subimgs[key] = zutils.data_aug_transform(subinput, rft )
             for key, sublbl in self.sublbls.iteritems():
                 submsk = self.submsks[key]
-                self.sublbls[key] = utils.data_aug_transform(sublbl, rft )
-                self.submsks[key] = utils.data_aug_transform(submsk, rft )
+                self.sublbls[key] = zutils.data_aug_transform(sublbl, rft )
+                self.submsks[key] = zutils.data_aug_transform(submsk, rft )
 
     def get_random_sample(self):
         '''Fetches a matching random sample from all input and output volumes'''
@@ -139,7 +139,7 @@ def get_random_sample(self):
         self._data_aug()
         # make sure that the input image is continuous in memory
         # the C++ core can not deal with numpy view
-        self.subimgs = utils.make_continuous(self.subimgs)
+        self.subimgs = zutils.make_continuous(self.subimgs)
         return ( self.subimgs, self.sublbls, self.submsks )
 
     def _get_balance_weight(self, arr, msk=None):
@@ -225,8 +225,8 @@ def write_request_to_log(self, dev):
         if self.log is not None:
             log_line1 = self.name
             log_line2 = "subvolume: [{},{},{}] requested".format(dev[0],dev[1],dev[2])
-            utils.write_to_log(self.log, log_line1)
-            utils.write_to_log(self.log, log_line2)
+            zutils.write_to_log(self.log, log_line1)
+            zutils.write_to_log(self.log, log_line2)
 
 class CAffinitySample(CSample):
     """
@@ -255,8 +255,8 @@ def __init__(self, dspec, pars, sample_id, net, outsz, log=None, is_forward=Fals
         self.taffs = dict()
         self.tmsks = dict()
         for k, out in self.outs.iteritems():
-            self.taffs[k] = self._seg2aff( out.lbl )
-            self.tmsks[k] = self._msk2affmsk( out.msk )
+            self.taffs[k] = self._seg2aff( out.get_lbl() )
+            self.tmsks[k] = self._msk2affmsk( out.get_msk() )
         self._prepare_rebalance_weights( self.taffs, self.tmsks )
         return
 
@@ -415,7 +415,7 @@ def _prepare_rebalance_weights(self):
         self.wps = dict()
         self.wzs = dict()
         for key, out in self.outs.iteritems():
-            self.wps[key], self.wzs[key] = self._get_balance_weight( out.lbl,out.msk )
+            self.wps[key], self.wzs[key] = self._get_balance_weight( out.get_lbl(),out.get_msk() )
 
     def _binary_class(self, lbl):
         """
diff --git a/python/test.py b/python/front_end/ztest.py
similarity index 88%
rename from python/test.py
rename to python/front_end/ztest.py
index 031eb8f5..3f92f635 100644
--- a/python/test.py
+++ b/python/front_end/ztest.py
@@ -3,11 +3,11 @@
 
 Jingpeng Wu , 2015
 """
-import utils
-import cost_fn
+import zutils
 import numpy as np
 from core import pyznn
 import time
+import zcost
 
 def _single_test(net, pars, sample, vn):
     # return errors as a dictionary
@@ -15,20 +15,20 @@ def _single_test(net, pars, sample, vn):
     vol_ins, lbl_outs, msks, wmsks = sample.get_random_sample()
 
     # forward pass
-    vol_ins = utils.make_continuous(vol_ins)
+    vol_ins = zutils.make_continuous(vol_ins)
     props = net.forward( vol_ins )
 
     # cost function and accumulate errors
     props, derr['err'], grdts = pars['cost_fn']( props, lbl_outs )
     # pixel classification error
-    derr['cls'] = cost_fn.get_cls(props, lbl_outs)
+    derr['cls'] = zcost.get_cls(props, lbl_outs)
     # rand error
     #derr['re'] = pyznn.get_rand_error(props.values()[0], lbl_outs.values()[0])
 
     if pars['is_malis']:
-        malis_weights, rand_errors, num_non_bdr = cost_fn.malis_weight( pars, props, lbl_outs )
+        malis_weights, rand_errors, num_non_bdr = zcost.malis_weight( pars, props, lbl_outs )
         # dictionary of malis classification error
-        dmc, dme = utils.get_malis_cost( props, lbl_outs, malis_weights )
+        dmc, dme = zutils.get_malis_cost( props, lbl_outs, malis_weights )
         derr['mc'] = dmc.values()[0]
         derr['me'] = dme.values()[0]
     # normalization
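The `_data_aug` hunk relies on one property: a single random draw `rft` decides the transforms, and the same flags are applied to image, label, and mask so they stay spatially aligned. `data_aug_transform` itself lives in `zutils`; the stand-in below only illustrates the shared-flag idea using axis flips:

```python
import numpy as np

rft = (np.random.rand(4) > 0.5)   # one draw of transform flags

def flip_like(vol, rft):
    # stand-in for zutils.data_aug_transform: flip a 3D volume
    # along each axis whose flag is set
    for axis, flag in enumerate(rft[:3]):
        if flag:
            vol = np.flip(vol, axis=axis)
    return vol

img = flip_like(np.random.rand(8, 8, 8), rft)
lbl = flip_like(np.random.rand(8, 8, 8), rft)  # same rft keeps them aligned
```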
diff --git a/python/utils.py b/python/front_end/zutils.py
similarity index 94%
rename from python/utils.py
rename to python/front_end/zutils.py
index 97225d24..583bedd2 100644
--- a/python/utils.py
+++ b/python/front_end/zutils.py
@@ -325,6 +325,10 @@ def init_save(pars, lc, net, iter_last):
         os.remove( fname )
     lc.save(pars, fname)
     znetio.save_network(net, fname, pars['is_stdio'])
+    if pars.has_key('s3_train_net_prefix'):
+        # should transfer local network to s3
+        s3fname = pars['s3_train_net_prefix'] + "_init_{}.h5".format(iter_last)
+        os.system("aws s3 cp {} {}".format(fname, s3fname))
 
 # save the intermediate networks while training
 def inter_save(pars, lc, net, vol_ins, props, lbl_outs, grdts, wmsks, it):
@@ -355,3 +359,10 @@ def inter_save(pars, lc, net, vol_ins, props, lbl_outs, grdts, wmsks, it):
 
     # Overwriting most current file with completely saved version
     shutil.copyfile(filename, filename_current)
+
+    if pars.has_key('s3_train_net_prefix'):
+        # should transfer local network to s3
+        s3fname = pars['s3_train_net_prefix'] + "_{}.h5".format(it)
+        os.system("aws s3 cp {} {}".format(filename, s3fname))
+        s3fname = pars['s3_train_net_prefix'] + "_current.h5"
+        os.system("aws s3 cp {} {}".format(filename, s3fname))
diff --git a/python/tests/test_seg.py b/python/tests/test_seg.py
deleted file mode 100644
index d2fc9ffa..00000000
--- a/python/tests/test_seg.py
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/usr/bin/env python
-__doc__ = """
-
-Jingpeng Wu , 2015
-"""
-
-# read affinity
-fname = "/usr/people/jingpeng/seungmount/research/Jingpeng/09_pypipeline/znn_merged.h5"
-import emirt
-affs = emirt.emio.imread( fname )
-
-seg = emirt.volume_util.seg_affs(affs, threshold=0.8)
-
-#%%
-com = emirt.show.CompareVol((affs[0, :,:,:], seg))
-com.vol_compare_slice()
-
-#%%
-emirt.show.random_color_show( seg[4,:,:] )
\ No newline at end of file
diff --git a/python/forward.py b/python/zforward.py
similarity index 97%
rename from python/forward.py
rename to python/zforward.py
index 60b3ad80..0001a800 100644
--- a/python/forward.py
+++ b/python/zforward.py
@@ -32,7 +32,6 @@
 import numpy as np
 import os
 from front_end import *
-import utils
 from emirt import emio
 
 def parse_args( args ):
@@ -45,7 +44,7 @@ def parse_args( args ):
     assert( os.path.exists( params['forward_net'] ) )
 
     if args['range']:
-        params['forward_range'] = utils.parseIntSet( args['range'] )
+        params['forward_range'] = zutils.parseIntSet( args['range'] )
 
     return dspec, params
 
@@ -90,7 +89,7 @@ def forward_pass( params, Dataset, network, verbose=True ):
         if verbose:
             print "Output patch #{} of {}".format(i+1, num_patches) # i is just an index
         input_patches, junk = Dataset.get_next_patch()
-        vol_ins = utils.make_continuous(input_patches)
+        vol_ins = zutils.make_continuous(input_patches)
         output = network.forward( vol_ins )
         Output.set_next_patch( output )
         if params['is_check']:
@@ -107,12 +106,10 @@ def run_softmax( sample_output ):
     Performs a softmax calculation over the output volumes for a given
     sample output
     '''
-    from cost_fn import softmax
-
     for dname, dataset in sample_output.output_volumes.iteritems():
         props = {'dataset':dataset.data}
-        props = softmax(props)
+        props = zcost.softmax(props)
         dataset.data = props.values()[0]
         sample_output.output_volumes[dname] = dataset
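The `zutils` hunks are the upload mirror of `zaws.s3download`: checkpoints are written locally, then pushed back with the AWS CLI (`aws s3 cp`). A minimal sketch with made-up file names and bucket; it assumes a configured AWS CLI on the training machine:

```python
import os

filename = "/tmp/net_1000.h5"                  # locally saved checkpoint
s3_prefix = "s3://my-bucket/experiments/net"   # pars['s3_train_net_prefix']

# push the iteration snapshot and the rolling "current" copy
os.system("aws s3 cp {} {}".format(filename, s3_prefix + "_1000.h5"))
os.system("aws s3 cp {} {}".format(filename, s3_prefix + "_current.h5"))
```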
diff --git a/python/zstatistics.py b/python/zstatistics.py
index 30abba82..7b0718b7 100644
--- a/python/zstatistics.py
+++ b/python/zstatistics.py
@@ -6,7 +6,7 @@
 import numpy as np
 import time
 import os
-import utils
+from front_end import *
 
 class CLearnCurve:
     def __init__(self, fname=None):
@@ -194,7 +194,7 @@ def plot(self, w=3, plotmode='matplotlib'):
         fig = plt.figure()
 
         # number of subplots
-        nsp = len(self.train)-1
+        nsp = 2
         print "number of subplots: {}".format(nsp)
 
         # print the maximum iteration
@@ -214,7 +214,7 @@ def plot(self, w=3, plotmode='matplotlib'):
         # plot data
         idx = 0
         for key in self.train.keys():
-            if key == 'it':
+            if key in ['it', 'elapse', 'eta', 're', 'num_mask_voxels']:
                 continue
             idx += 1
             ax = fig.add_subplot(1,nsp,idx)
diff --git a/python/train.py b/python/ztrain.py
similarity index 88%
rename from python/train.py
rename to python/ztrain.py
index 93c3521d..7c8bc29e 100644
--- a/python/train.py
+++ b/python/ztrain.py
@@ -5,12 +5,9 @@
 """
 import time
 from front_end import *
-import cost_fn
-import utils
 import zstatistics
 import os
 import numpy as np
-import test
 
 def main( args ):
 
@@ -19,7 +16,7 @@ def main( args ):
     net, lc = znetio.create_net(pars)
 
     # total voxel number of output volumes
-    vn = utils.get_total_num(net.get_outputs_setsz())
+    vn = zutils.get_total_num(net.get_outputs_setsz())
 
     # initialize samples
     print "\n\ncreate train samples..."
@@ -28,7 +25,6 @@ def main( args ):
     smp_tst = zsample.CSamples(dspec, pars, pars['test_range'], net, pars['train_outsz'], logfile)
 
     if pars['is_check']:
-        import zcheck
         zcheck.check_patch(pars, smp_trn)
 
     # initialize history recording
@@ -40,7 +36,7 @@ def main( args ):
     #Saving initial/seeded network
     # get file name
     fname, fname_current = znetio.get_net_fname( pars['train_net_prefix'], iter_last, suffix="init" )
-    utils.init_save(pars, lc, net, iter_last)
+    zutils.init_save(pars, lc, net, iter_last)
 
     # start time cumulation
     print "start training..."
@@ -59,7 +55,7 @@ def main( args ):
         #print props
 
         # get gradient and record history
-        props, grdts, history = cost_fn.get_grdt(pars, history, props, lbl_outs, msks, wmsks, vn)
+        props, grdts, history = zcost.get_grdt(pars, history, props, lbl_outs, msks, wmsks, vn)
 
         #print props
         #print lbl_outs
@@ -69,10 +65,10 @@ def main( args ):
         # post backward pass processing
         history, net, lc, start, total_time = zstatistics.process_history(pars, history, \
                                                 lc, net, it, start, total_time)
-        utils.inter_save(pars, lc, net, vol_ins, props, \
+        zutils.inter_save(pars, lc, net, vol_ins, props, \
                          lbl_outs, grdts, wmsks, it)
-        lc, start, total_time = test.run_test(net, pars, smp_tst, \
+        lc, start, total_time = ztest.run_test(net, pars, smp_tst, \
                                               vn, it, lc, start, total_time)
 
         # stop the iteration at checking mode
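The hardcoded `nsp = 2` in `zstatistics.plot` pairs with the widened key filter: after skipping the bookkeeping series, only the cost (`err`) and pixel classification error (`cls`) remain to plot. A small sketch of the filter on an example history dictionary; the key set mirrors the hunk above, though real histories may carry extra keys such as the malis errors:

```python
train = {'it': [], 'err': [], 'cls': [], 'elapse': [],
         'eta': [], 're': [], 'num_mask_voxels': []}

plotted = [k for k in train.keys()
           if k not in ['it', 'elapse', 'eta', 're', 'num_mask_voxels']]

assert sorted(plotted) == ['cls', 'err']   # hence two subplots
```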