# This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy"
# model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
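# Workaround sketch (assumes the ds/args/steps_per_epoch/callbacks referenced in the commented
# call above are in scope): with tf.distribute strategies, pass the tf.data.Dataset directly to
# model.fit instead of using fit_generator.
model.fit(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)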
if args.export_dir:
print("exporting model to: {}".format(args.export_dir))
compat.export_saved_model(model, args.export_dir, ctx.job_name == 'chief')
tf_feed.terminate()
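# Optional verification sketch (not part of the original test): reload the export with the
# standard TF2 SavedModel loader, assuming args.export_dir was set above.
if args.export_dir:
  import tensorflow as tf
  loaded = tf.saved_model.load(args.export_dir)
  print("signatures: {}".format(list(loaded.signatures.keys())))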
# create a Spark DataFrame of training examples (features, labels)
rdd = self.sc.parallelize(self.train_examples, 2)
trainDF = rdd.toDF(['col1', 'col2'])
# train and export model
args = {}
estimator = TFEstimator(_spark_train, args) \
    .setInputMapping({'col1': 'x', 'col2': 'y_'}) \
    .setModelDir(self.model_dir) \
    .setExportDir(self.export_dir) \
    .setClusterSize(self.num_workers) \
    .setNumPS(0) \
    .setBatchSize(1) \
    .setEpochs(1)
model = estimator.fit(trainDF)
self.assertTrue(os.path.isdir(self.export_dir))
# create a Spark DataFrame of test examples (features, labels)
testDF = self.spark.createDataFrame(self.test_examples, ['c1', 'c2'])
# test saved_model using exported signature
model.setTagSet('serve') \
     .setSignatureDefKey('serving_default')
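# What typically follows (a sketch, not the elided original code): map DataFrame columns to the
# signature's input/output tensors and run inference via the Spark ML API. The tensor names
# 'dense_input' and 'dense' below are assumptions about the Keras model's serving signature.
preds = model.setInputMapping({'c1': 'dense_input'}) \
             .setOutputMapping({'dense': 'prediction'}) \
             .transform(testDF)
preds.show()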
def test_basic_tf(self):
"""Single-node TF graph (w/ args) running independently on multiple executors."""
def _map_fun(args, ctx):
import tensorflow as tf
x = tf.constant(args['x'])
y = tf.constant(args['y'])
sum = tf.math.add(x, y)
assert sum.numpy() == 3
args = {'x': 1, 'y': 2}
cluster = TFCluster.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, num_ps=0)
cluster.shutdown()
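# For reference, the eager-mode computation each executor performs in _map_fun above,
# runnable locally without Spark (a standalone sketch using the same tf_args values):
import tensorflow as tf
assert tf.math.add(tf.constant(1), tf.constant(2)).numpy() == 3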
"""Normalization of absolution & relative string paths depending on filesystem"""
cwd = os.getcwd()
user = getpass.getuser()
fs = ["file://", "hdfs://", "viewfs://"]
paths = {
"hdfs://foo/bar": ["hdfs://foo/bar", "hdfs://foo/bar", "hdfs://foo/bar"],
"viewfs://foo/bar": ["viewfs://foo/bar", "viewfs://foo/bar", "viewfs://foo/bar"],
"file://foo/bar": ["file://foo/bar", "file://foo/bar", "file://foo/bar"],
"/foo/bar": ["file:///foo/bar", "hdfs:///foo/bar", "viewfs:///foo/bar"],
"foo/bar": ["file://{}/foo/bar".format(cwd), "hdfs:///user/{}/foo/bar".format(user), "viewfs:///user/{}/foo/bar".format(user)],
}
for i in range(len(fs)):
  ctx = type('MockContext', (), {'defaultFS': fs[i], 'working_dir': cwd})
  for path, expected in paths.items():
    final_path = TFNode.hdfs_path(ctx, path)
    self.assertEqual(final_path, expected[i], "fs({}) + path({}) => {}, expected {}".format(fs[i], path, final_path, expected[i]))
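# Standalone usage sketch of the same normalization (reuses the MockContext idea above;
# per the paths table, an absolute path on an 'hdfs://' filesystem becomes 'hdfs:///foo/bar'):
demo_ctx = type('MockContext', (), {'defaultFS': 'hdfs://', 'working_dir': cwd})
print(TFNode.hdfs_path(demo_ctx, '/foo/bar'))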
def _spark_train(args, ctx):
"""Basic linear regression in a distributed TF cluster using InputMode.SPARK"""
import tensorflow as tf
from tensorflowonspark import TFNode
tf.compat.v1.reset_default_graph()
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model = Sequential()
model.add(Dense(1, activation='linear', input_shape=[2]))
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.2), loss='mse', metrics=['mse'])
model.summary()
tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping)
def rdd_generator():
while not tf_feed.should_stop():
batch = tf_feed.next_batch(1)
if len(batch['x']) > 0:
features = batch['x'][0]
label = batch['y_'][0]
yield (features, label)
else:
return
ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1])))
# disable auto-sharding since we're feeding from an RDD generator
options = tf.data.Options()
compat.disable_auto_shard(options)
ds = ds.with_options(options)
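  # Minimal sketch of the next step (epoch/step counts here are hypothetical): under a
  # distribution strategy the tf.data.Dataset is passed straight to model.fit, which is the
  # workaround for the fit_generator limitation noted earlier.
  model.fit(ds, epochs=1, steps_per_epoch=10)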
def _map_fun(args, ctx):
  import tensorflow as tf
  import time
  tf_feed = TFNode.DataFeed(ctx.mgr, False)
  while not tf_feed.should_stop():
    batch = tf_feed.next_batch(10)
    if len(batch) > 0:
      squares = tf.math.square(batch)
      tf_feed.batch_results(squares.numpy())

  # simulate post-feed actions that raise an exception
  time.sleep(2)
  raise Exception("FAKE exception after feeding")
def _map_fun(args, ctx):
  import tensorflow as tf
  tf_feed = TFNode.DataFeed(ctx.mgr, False)
  while not tf_feed.should_stop():
    batch = tf_feed.next_batch(10)
    if len(batch) > 0:
      squares = tf.math.square(batch)
      tf_feed.batch_results(squares.numpy())
      raise Exception("FAKE exception during feeding")
def test_dfutils(self):
# create a DataFrame of a single row consisting of standard types (str, int, int_array, float, float_array, binary)
row1 = ('text string', 1, [2, 3, 4, 5], -1.1, [-2.2, -3.3, -4.4, -5.5], bytearray(b'\xff\xfe\xfd\xfc'))
rdd = self.sc.parallelize([row1])
df1 = self.spark.createDataFrame(rdd, ['a', 'b', 'c', 'd', 'e', 'f'])
print("schema: {}".format(df1.schema))
# save the DataFrame as TFRecords
dfutil.saveAsTFRecords(df1, self.tfrecord_dir)
self.assertTrue(os.path.isdir(self.tfrecord_dir))
# reload the DataFrame from exported TFRecords
df2 = dfutil.loadTFRecords(self.sc, self.tfrecord_dir, binary_features=['f'])
row2 = df2.take(1)[0]
print("row_saved: {}".format(row1))
print("row_loaded: {}".format(row2))
# confirm loaded values match original/saved values
self.assertEqual(row1[0], row2['a'])
self.assertEqual(row1[1], row2['b'])
self.assertEqual(row1[2], row2['c'])
self.assertAlmostEqual(row1[3], row2['d'], 6)
for i in range(len(row1[4])):
self.assertAlmostEqual(row1[4][i], row2['e'][i], 6)
print("type(f): {}".format(type(row2['f'])))
for i in range(len(row1[5])):
self.assertEqual(row1[5][i], row2['f'][i])
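  # Optional spot-check of the files written above, independent of dfutil (a sketch;
  # the "part-*" glob is an assumption about how the TFRecord part files are named):
  import glob
  import tensorflow as tf
  files = glob.glob(os.path.join(self.tfrecord_dir, "part-*"))
  for rec in tf.data.TFRecordDataset(files).take(1):
    print(tf.train.Example.FromString(rec.numpy()))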
def test_inputmode_spark(self):
"""Distributed TF cluster w/ InputMode.SPARK"""
def _map_fun(args, ctx):
import tensorflow as tf
tf_feed = TFNode.DataFeed(ctx.mgr, False)
while not tf_feed.should_stop():
batch = tf_feed.next_batch(batch_size=10)
print("batch: {}".format(batch))
squares = tf.math.square(batch)
print("squares: {}".format(squares))
tf_feed.batch_results(squares.numpy())
input = [[x] for x in range(1000)] # set up input as tensors of shape [1] to match placeholder
rdd = self.sc.parallelize(input, 10)
cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK)
rdd_out = cluster.inference(rdd)
rdd_sum = rdd_out.sum()
self.assertEqual(rdd_sum, sum([x * x for x in range(1000)]))
cluster.shutdown()
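  # Sanity check of the expected value used above (pure Python, no Spark required):
  # sum of squares of 0..999 == 999 * 1000 * 1999 / 6 == 332833500
  assert sum(x * x for x in range(1000)) == 332833500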
def test_reservation_environment_not_exists_get_server_ip_return_actual_host_ip(self):
  tfso_server = Server(5)
  assert tfso_server.get_server_ip() == util.get_ip_address()
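  # For reference, one common way to determine a host's outbound IP (a sketch; this is
  # not necessarily how util.get_ip_address is implemented):
  import socket
  s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  s.connect(("8.8.8.8", 80))    # UDP connect picks a route but sends no packets
  print(s.getsockname()[0])     # local interface address chosen for that route
  s.close()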