Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def next_pred_label(model, data_generator, verbose=False):
    """
    Pull one batch from ``data_generator``, run ``model`` on it, and derive
    the max labels from the prediction.

    Parameters:
        model: object exposing ``predict``, applied to the batch input
        data_generator: iterator whose items have the model input at index 0;
            that input may be a single volume or a list/tuple whose first
            entry is the primary volume
        verbose: forwarded to ``timer.Timer`` around the predict call

    Returns:
        ``(sample, pred) + max_labels``, where ``max_labels`` is the value
        returned by ``pred_to_label`` (concatenated onto the tuple, so it is
        presumably itself a tuple).
    """
    batch = next(data_generator)
    model_input = batch[0]
    with timer.Timer('prediction', verbose):
        prediction = model.predict(model_input)

    # label extraction works on the primary volume only, not on extra inputs
    if isinstance(model_input, (list, tuple)):
        primary_input = model_input[0]
    else:
        primary_input = model_input

    label_info = pred_to_label(primary_input, prediction)
    return (batch, prediction) + label_info
def _load_medical_volume(filename, ext, verbose=False):
    """
    Load a medical volume from one of several supported file types.

    Parameters:
        filename: path of the file to load
        ext: extension selecting the loader; one of '.npz', '.npy', '.mgz',
            '.nii', '.nii.gz' ('npy' without the dot is still accepted for
            backward compatibility)
        verbose: timing output for the load is requested when ``verbose >= 2``

    Returns:
        the volume data as a numpy array

    Raises:
        ValueError: if ``ext`` is not one of the supported extensions
    """
    with timer.Timer('load_vol', verbose >= 2):
        if ext == '.npz':
            # an NpzFile keeps the archive open; close it once the array is read
            with np.load(filename) as vol_file:
                vol_data = vol_file['vol_data']
        elif ext in ('.npy', 'npy'):
            # '.npy' now matches the dotted form used by every other branch
            vol_data = np.load(filename)
        elif ext in ('.mgz', '.nii', '.nii.gz'):
            vol_med = nib.load(filename)
            # get_data() is deprecated/removed in newer nibabel; this is the
            # documented replacement and keeps the on-disk dtype
            # (unlike get_fdata(), which casts to float)
            vol_data = np.asanyarray(vol_med.dataobj)
        else:
            raise ValueError("Unexpected extension %s" % ext)
    return vol_data
# with the number of patches in batch matching output of gen
"""
# get prior
if prior_type == 'location':
prior_vol = nd.volsize2ndgrid(vol_size)
prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
prior_vol = np.expand_dims(prior_vol, axis=0) # reshape for model
elif prior_type == 'file': # assumes a npz filename passed in prior_file
with timer.Timer('loading prior', True):
data = np.load(prior_file)
prior_vol = data['prior'].astype('float16')
else: # assumes a volume
with timer.Timer('loading prior', True):
prior_vol = prior_file.astype('float16')
if force_binary:
nb_labels = prior_vol.shape[-1]
prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)
nb_channels = prior_vol.shape[-1]
if extract_slice is not None:
if isinstance(extract_slice, int):
prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
else: # assume slices
prior_vol = prior_vol[:, :, extract_slice, :]
# get the prior to have the right volume [x, y, z, nb_channels]
def next_vol_pred(model, data_generator, verbose=False):
    """
    Fetch the next batch from the generator and run the model on its input.

    Parameters:
        model: object exposing ``predict``, applied to the batch input
        data_generator: iterator whose items hold the model input at index 0
            and the target at index 1
        verbose: forwarded to ``timer.Timer`` around the predict call

    Returns:
        ``(input_vol, y_true, y_pred)`` when the batch input is a single
        array, or ``(input_vol, y_true, y_pred, prior)`` when the batch input
        is a list/tuple, in which case ``input_vol`` and ``prior`` are its
        first and second entries.
    """
    batch = next(data_generator)
    batch_in = batch[0]
    with timer.Timer('prediction', verbose):
        y_pred = model.predict(batch_in)

    if not isinstance(batch_in, (list, tuple)):
        return (batch_in, batch[1], y_pred)

    # a list/tuple input means extra inputs (e.g. a prior) were supplied
    return (batch_in[0], batch[1], y_pred, batch_in[1])
def on_model_save(self, epoch, iter, logs=None):
""" save the model to hdf5. Code mostly from keras core """
with timer.Timer('model save callback', self.verbose):
logs = logs or {}
self.steps_since_last_save += 1
if self.steps_since_last_save >= self.period:
self.steps_since_last_save = 0
filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print('Epoch %05d Iter%05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, iter, self.monitor, self.best,
patch_rand=False,
patch_rand_seed=None):
"""
#
# add a prior generator to a given generator
# with the number of patches in batch matching output of gen
"""
# get prior
if prior_type == 'location':
prior_vol = nd.volsize2ndgrid(vol_size)
prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
prior_vol = np.expand_dims(prior_vol, axis=0) # reshape for model
elif prior_type == 'file': # assumes a npz filename passed in prior_file
with timer.Timer('loading prior', True):
data = np.load(prior_file)
prior_vol = data['prior'].astype('float16')
else: # assumes a volume
with timer.Timer('loading prior', True):
prior_vol = prior_file.astype('float16')
if force_binary:
nb_labels = prior_vol.shape[-1]
prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)
nb_channels = prior_vol.shape[-1]
if extract_slice is not None:
if isinstance(extract_slice, int):
def on_metric_call(self, epoch, iter, logs={}):
""" compute metrics on several predictions """
with timer.Timer('predict metrics callback', self.verbose):
# prepare metric
met = np.zeros((self.nb_samples, self.nb_labels, len(self.metrics)))
# generate predictions
# the idea is to predict either a full volume or just a slice,
# depending on what we need
gen = _generate_predictions(self.model,
self.data_generator,
self.batch_size,
self.nb_samples,
self.vol_params)
batch_idx = 0
for (vol_true, vol_pred) in gen:
for idx, metric in enumerate(self.metrics):
met[batch_idx, :, idx] = metric(vol_true, vol_pred)
force_binary=force_binary,
verbose=verbose,
patch_size=patch_size,
patch_stride=patch_stride,
batch_size=batch_size,
vol_rand_seed=vol_rand_seed,
nb_input_feats=nb_input_feats)
# get prior
if prior_type == 'location':
prior_vol = nd.volsize2ndgrid(vol_size)
prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
prior_vol = np.expand_dims(prior_vol, axis=0) # reshape for model
elif prior_type == 'file': # assumes a npz filename passed in prior_file
with timer.Timer('loading prior', True):
data = np.load(prior_file)
prior_vol = data['prior'].astype('float16')
else : # assumes a volume
with timer.Timer('astyping prior', verbose):
prior_vol = prior_file
if not (prior_vol.dtype == 'float16'):
prior_vol = prior_vol.astype('float16')
if force_binary:
nb_labels = prior_vol.shape[-1]
prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)
nb_channels = prior_vol.shape[-1]
if extract_slice is not None:
def _load_medical_volume(filename, ext, verbose=False):
    """
    Load a medical volume from one of several supported file types.

    Parameters:
        filename: path of the file to load
        ext: extension selecting the loader; one of '.npz', '.npy', '.mgz',
            '.nii', '.nii.gz' ('npy' without the dot is still accepted for
            backward compatibility)
        verbose: timing output for the load is requested when ``verbose >= 2``

    Returns:
        the volume data as a numpy array

    Raises:
        ValueError: if ``ext`` is not one of the supported extensions
    """
    with timer.Timer('load_vol', verbose >= 2):
        if ext == '.npz':
            # an NpzFile keeps the archive open; close it once the array is read
            with np.load(filename) as vol_file:
                vol_data = vol_file['vol_data']
        elif ext in ('.npy', 'npy'):
            # '.npy' now matches the dotted form used by every other branch
            vol_data = np.load(filename)
        elif ext in ('.mgz', '.nii', '.nii.gz'):
            vol_med = nib.load(filename)
            # get_data() is deprecated/removed in newer nibabel; this is the
            # documented replacement and keeps the on-disk dtype
            # (unlike get_fdata(), which casts to float)
            vol_data = np.asanyarray(vol_med.dataobj)
        else:
            raise ValueError("Unexpected extension %s" % ext)
    return vol_data