"""Inference worker."""
logger = medaka.common.get_named_logger('PWorker')
remainder_regions = list()
loader = DataLoader(
4 * batch_size, bam, regions, feature_encoder,
chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
enable_chunking=enable_chunking)
batches = medaka.common.grouper(loader, batch_size)
total_region_mbases = sum(r.size for r in regions) / 1e6
logger.info(
"Running inference for {:.1f}M draft bases.".format(
total_region_mbases))
with medaka.datastore.DataStore(output, 'a') as ds:
mbases_done = 0
cache_size_log_interval = 5
t0 = now()
tlast = t0
tcache = t0
for data in batches:
if now() - tcache > cache_size_log_interval:
logger.info("Samples in cache: {}.".format(
loader.results.qsize()))
tcache = now()
x_data = np.stack([x.features for x in data])
class_probs = model.predict_on_batch(x_data)
# calculate bases done taking into account overlap
new_bases = 0
for x in data:
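        # Illustrative arithmetic (not from the original source): with
        # chunk_len=1000 and chunk_ovlp=200 each chunk shares a 200-base
        # prefix with its predecessor, so every chunk after the first
        # adds roughly 1000 - 200 = 800 new draft bases to `new_bases`.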
logger = medaka.common.get_named_logger('Prepare')
if args.chunk_ovlp >= args.chunk_len:
    raise ValueError(
        'chunk_ovlp {} is not smaller than chunk_len {}'.format(
            args.chunk_ovlp, args.chunk_len))
regions = medaka.common.get_regions(args.bam, args.regions)
reg_str = '\n'.join(['\t\t\t{}'.format(r) for r in regions])
logger.info('Got regions:\n{}'.format(reg_str))
if args.truth is None:
    logger.warning(
        'Running medaka features without a truth bam, '
        'unlabelled data will be produced. Is this intended?')
    time.sleep(3)

no_data = False
with medaka.datastore.DataStore(args.output, 'w') as ds:
    # write feature options to file
    logger.info("Writing meta data to file.")
    num_qstrat = args.feature_encoder_args.get('num_qstrat')
    max_run = args.label_scheme_args.get('max_run')
    # If one of them is set, set the other to agree.
    # If both are set, force them to agree.
    # If neither is set, or they are the same, continue merrily.
    if max_run is None and num_qstrat is not None:
        args.label_scheme_args['max_run'] = num_qstrat
    elif max_run is not None and num_qstrat is None:
        args.feature_encoder_args['num_qstrat'] = max_run
    elif max_run != num_qstrat:
        raise ValueError(
            'num_qstrat in feature_encoder_args must agree '
            'with max_run in label_scheme_args')
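    # Illustrative truth table for the reconciliation above (a sketch,
    # not part of the original module); the dicts stand for
    # feature_encoder_args and label_scheme_args respectively:
    #
    #   {'num_qstrat': 5}, {}             -> {'num_qstrat': 5}, {'max_run': 5}
    #   {}, {'max_run': 3}                -> {'num_qstrat': 3}, {'max_run': 3}
    #   {'num_qstrat': 5}, {'max_run': 3} -> ValueError
    #   {}, {}                            -> both left unchanged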
def sample_to_x_y_bq_worker(sample, label_scheme):
    """Convert a `common.Sample` object into a training x, y tuple.

    :param sample: (sample key, filename).
    :param label_scheme: `LabelScheme` obj.

    :returns: (np.ndarray of inputs, np.ndarray of labels)
    """
    sample_key, sample_file = sample
    with medaka.datastore.DataStore(sample_file) as ds:
        s = ds.load_sample(sample_key)
    if s.labels is None:
        raise ValueError("Sample {} in {} has no labels.".format(
            sample_key, sample_file))
    x = s.features
    y = label_scheme.encoded_labels_to_training_vectors(s.labels)
    return x, y
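# Hypothetical usage sketch (not in the original module): stack worker
# outputs into contiguous training arrays. `sample_items` is assumed to
# be a list of (sample key, filename) pairs of equal-length chunks, and
# `label_scheme` an object loaded from a model file.
def _stack_training_batch(sample_items, label_scheme):
    """Stack (x, y) pairs from the worker into batch arrays."""
    pairs = [sample_to_x_y_bq_worker(s, label_scheme) for s in sample_items]
    xs, ys = zip(*pairs)
    return np.stack(xs), np.stack(ys)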
def is_rle_encoder(model_name):
    """Return True if the model uses a run-length-encoded feature encoder."""
    # encoders producing run-length-encoded features; isinstance also
    # matches subclasses
    rle_encoders = (medaka.features.HardRLEFeatureEncoder,)
    with medaka.datastore.DataStore(model_name) as model:
        encoder = model.get_meta('feature_encoder')
    return isinstance(encoder, rle_encoders)
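# Usage sketch (hedged; the CLI wiring is an assumption): callers can
# branch on the encoder type, e.g. to demand a run-length-compressed
# reference (`rle_ref` is passed to SampleGenerator elsewhere in this
# file) before generating features.
#
#   if is_rle_encoder(args.model) and args.rle_ref is None:
#       raise ValueError('An RLE model requires an RLE reference.')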
def on_epoch_end(self, epoch, logs=None):
    """Perform actions at the end of an epoch."""
    super(ModelMetaCheckpoint, self).on_epoch_end(epoch, logs)
    filepath = self.filepath.format(epoch=epoch + 1, **logs)
    with medaka.datastore.DataStore(filepath, 'a') as ds:
        for k, v in self.medaka_meta.items():
            ds.set_meta(v, k)
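# Usage sketch (hedged): assuming ModelMetaCheckpoint subclasses
# tf.keras.callbacks.ModelCheckpoint and is built from a `medaka_meta`
# dict of {meta key: object} pairs (the attribute read above), it slots
# into training as an ordinary callback; the constructor signature here
# is an assumption.
#
#   meta = {'feature_encoder': fe, 'label_scheme': ls}
#   checkpoint = ModelMetaCheckpoint(meta, 'model-{epoch:02d}.hdf5')
#   model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])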
        args.chunk_len, args.chunk_ovlp,  # these won't be used
        batch_size=args.batch_size, save_features=args.save_features,
        tag_name=args.tag_name, tag_value=args.tag_value,
        tag_keep_missing=args.tag_keep_missing,
        enable_chunking=False)
if len(new_remainders) > 0:
    # shouldn't get here
    ignored = [x[0] for x in new_remainders]
    n_ignored = len(ignored)
    logger.warning("{} regions were not processed: {}.".format(
        n_ignored, ignored))

logger.info("Finished processing all regions.")

if args.check_output:
    logger.info("Validating and finalising output data.")
    with medaka.datastore.DataStore(args.output, 'a') as ds:
        pass
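        # The empty body is deliberate: DataStore appears to validate its
        # contents when closed (cf. `verify_on_close=False` passed
        # elsewhere in this file), so opening the file in append mode and
        # immediately leaving the `with` block triggers the check.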
"""Inference program."""
logger_level = logging.getLogger(__package__).level
if logger_level > logging.DEBUG:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow.keras import backend as K
args.regions = medaka.common.get_regions(
args.bam, region_strs=args.regions)
logger = medaka.common.get_named_logger('Predict')
logger.info('Processing region(s): {}'.format(
' '.join(str(r) for r in args.regions)))
# create output and copy meta
with medaka.datastore.DataStore(args.model) as ds:
ds.copy_meta(args.output)
feature_encoder = ds.get_meta('feature_encoder')
feature_encoder.tag_name = args.tag_name
feature_encoder.tag_value = args.tag_value
feature_encoder.tag_keep_missing = args.tag_keep_missing
feature_encoder.read_group = args.RG
logger.info("Setting tensorflow threads to {}.".format(args.threads))
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
K.set_session(tf.Session(
config=tf.ConfigProto(
intra_op_parallelism_threads=args.threads,
inter_op_parallelism_threads=args.threads)
))
if tf.test.is_gpu_available(cuda_only=True):
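# Note (assumption, not from the original source): under TensorFlow 2 the
# graph-session API above is gone; the eager-mode equivalent would be
#   tf.config.threading.set_intra_op_parallelism_threads(args.threads)
#   tf.config.threading.set_inter_op_parallelism_threads(args.threads)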
def sample_gen():
    # enclosing definition restored from the `sample_gen()` call below
    # and the free names `region`/`regions` used in this fragment
    for region in regions:
        data_gen = medaka.features.SampleGenerator(
            bam, region, model_file, rle_ref, read_fraction,
            chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
            tag_name=tag_name, tag_value=tag_value,
            tag_keep_missing=tag_keep_missing,
            enable_chunking=enable_chunking)
        yield from data_gen.samples
        remainder_regions.extend(data_gen._quarantined)

batches = medaka.common.background_generator(
    medaka.common.grouper(sample_gen(), batch_size), 10)
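# Note (inferred from names and arguments, not stated in the source):
# `grouper` appears to collect the sample stream into lists of
# `batch_size` items, while `background_generator` prefetches up to 10
# such batches on a worker thread so feature generation overlaps with
# model inference.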
total_region_mbases = sum(r.size for r in regions) / 1e6
logger.info(
    "Running inference for {:.1f}M draft bases.".format(
        total_region_mbases))

with medaka.datastore.DataStore(output, 'a', verify_on_close=False) as ds:
    mbases_done = 0

    t0 = now()
    tlast = t0
    for data in batches:
        x_data = np.stack([x.features for x in data])
        class_probs = model.predict_on_batch(x_data)
        # calculate bases done taking into account overlap
        new_bases = 0
        for x in data:
            if chunk_ovlp < x.size:
                new_bases += x.last_pos[0] - x._get_pos(chunk_ovlp)[0]
            else:
                new_bases += x.span
        mbases_done += new_bases / 1e6
        # clamp just to avoid a funny log message
        mbases_done = min(mbases_done, total_region_mbases)
    h['medaka_features_kwargs'][()])
if 'normalise' in medaka_features_kwargs:
    normalise = medaka_features_kwargs['normalise']

# delete existing metadata
for i in ['medaka_feature_decoding',
          'medaka_features_kwargs',
          'medaka_label_counts',
          'medaka_label_decoding',
          'medaka_model_kwargs',
          'medaka_model_name']:
    if i in h:
        del h[i]

# write new-style metadata
with medaka.datastore.DataStore(args.output, mode='a') as ds:
    ds.set_meta(medaka.labels.HaploidLabelScheme(), 'label_scheme')
    ds.set_meta(
        medaka.features.CountsFeatureEncoder(normalise=normalise),
        'feature_encoder')
    ds.set_meta(
        functools.partial(
            build_model, feat_len, num_classes,
            gru_size=gru_size, classify_activation=classify_activation),
        'model_function')
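# Round-trip sketch (hypothetical, names taken from this fragment): the
# stored builder can be loaded back and called to reconstruct the network.
#
#   with medaka.datastore.DataStore(args.output) as ds:
#       build = ds.get_meta('model_function')
#       model = build()  # i.e. build_model(feat_len, num_classes, ...)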