# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for setting JobConf Environment Variables on a per-step basis
"""
from mrjob.compat import jobconf_from_env
from mrjob.job import MRJob
from mrjob.step import MRStep

JOBCONF_LIST = [
    'mapred.map.tasks',
    'mapreduce.job.local.dir',
    'user.defined',
]


class MRTestPerStepJobConf(MRJob):

    def mapper_init(self):
        self.increment_counter('count', 'mapper_init', 1)
        for jobconf in JOBCONF_LIST:
            yield ((self.options.step_num, jobconf),
                   jobconf_from_env(jobconf, None))

    def mapper(self, key, value):
        yield key, value

    def steps(self):
        return [
            MRStep(mapper_init=self.mapper_init),
            MRStep(mapper_init=self.mapper_init,
                   mapper=self.mapper,
                   jobconf={'user.defined': 'nothing'}),
        ]
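

# mrjob job modules are run as scripts, so the class needs the standard
# entry point at module level (this is mrjob's usual convention):
if __name__ == '__main__':
    MRTestPerStepJobConf.run()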


def _test_explicit(self, m=False, c=False, r=False, **kwargs):
    s = MRStep(**kwargs)
    self.assertEqual(s.has_explicit_mapper, m)
    self.assertEqual(s.has_explicit_combiner, c)
    self.assertEqual(s.has_explicit_reducer, r)
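

# A minimal sketch of how the helper above is typically driven (the
# identity mapper here is illustrative, not a fixture from the original
# suite): passing any mapper to MRStep should flag it as explicit.
def test_explicit_mapper(self):
    def identity_mapper(key, value):
        yield key, value

    self._test_explicit(m=True, mapper=identity_mapper)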


input2_file.write(input_file_bytes)

with open(manifest_filename, 'w') as manifest_file:
    manifest_file.writelines([
        '%s\n' % input1_filename, '%s\n' % input2_filename])

job = self._harness_job(
    MRNickNackWithHadoopInputFormat, runner_alias='local',
    spark_conf=self.SPARK_CONF, input_paths=[manifest_filename])

with job.make_runner() as runner:
    runner.run()

    output_counts = dict(
        line.strip().split(b'\t')
        for line in to_lines(runner.cat_output()))

expected_output_counts = {b'"tomato"': b'2', b'"potato"': b'2'}
self.assertEqual(expected_output_counts, output_counts)

job = self._harness_job(
    MRNickNack, input_bytes=input_bytes, runner_alias='local',
    spark_conf=self.SPARK_CONF, compression_codec=compression_codec)

with job.make_runner() as runner:
    runner.run()

    self.assertTrue(runner.fs.exists(
        join(runner.get_output_dir(), 'o', 'part*.gz')))
    self.assertTrue(runner.fs.exists(
        join(runner.get_output_dir(), 't', 'part*.gz')))

    output_counts = dict(
        line.strip().split(b'\t')
        for line in to_lines(runner.cat_output()))

expected_output_counts = {b'"one"': b'2', b'"two"': b'2'}
self.assertEqual(expected_output_counts, output_counts)


def test_cant_override_libjars_on_command_line(self):
    with patch.object(MRJob, 'libjars', return_value=['honey.jar']):
        job = MRJob(['--libjars', 'cookie.jar'])

        # ignore switch, don't resolve relative path
        self.assertEqual(job._runner_kwargs()['libjars'],
                         ['honey.jar', 'cookie.jar'])
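

# Hedged companion check (not from the original module): with no
# command-line jars, _runner_kwargs() should surface only the jars
# declared by the libjars() override.
def test_libjars_from_method_only(self):
    with patch.object(MRJob, 'libjars', return_value=['honey.jar']):
        job = MRJob([])

        self.assertEqual(job._runner_kwargs()['libjars'],
                         ['honey.jar'])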


def test_is_task(self):
    self.assertEqual(MRJob([]).is_task(), False)
    self.assertEqual(MRJob(['--mapper']).is_task(), True)
    self.assertEqual(MRJob(['--reducer']).is_task(), True)
    self.assertEqual(MRJob(['--combiner']).is_task(), True)
    self.assertEqual(MRJob(['--spark']).is_task(), True)


class MRBoringJob2(MRBoringJob):
    INPUT_PROTOCOL = StandardJSONProtocol
    INTERNAL_PROTOCOL = PickleProtocol
    OUTPUT_PROTOCOL = ReprProtocol


class MRBoringJob3(MRBoringJob):

    def internal_protocol(self):
        return ReprProtocol()


class MRBoringJob4(MRBoringJob):
    INTERNAL_PROTOCOL = ReprProtocol


class MRTrivialJob(MRJob):
    OUTPUT_PROTOCOL = RawValueProtocol

    def mapper(self, key, value):
        yield key, value
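

# Helper jobs like MRTrivialJob can be exercised without Hadoop via
# mrjob's sandbox and the inline runner. A minimal sketch (the sample
# input bytes are made up; assumes to_lines from mrjob.util is imported
# at module level, as it is used elsewhere in this suite):
def _demo_run_trivial_job():
    from io import BytesIO

    job = MRTrivialJob(['-r', 'inline', '--no-conf'])
    job.sandbox(stdin=BytesIO(b'one\ntwo\n'))

    with job.make_runner() as runner:
        runner.run()
        return list(to_lines(runner.cat_output()))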


def assertMethodsEqual(self, fs, gs):
    # we're going to use this to match bound against unbound methods
    self.assertEqual([_im_func(f) for f in fs],
                     [_im_func(g) for g in gs])


def test_default_protocols(self):
    mr_job = MRBoringJob([])

    self.assertMethodsEqual(
        mr_job.pick_protocols(0, 'mapper'),
        (RawValueProtocol.read, JSONProtocol.write))
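

# Hedged companion check (not in the excerpt above): for a single-step
# job, step 0's reducer reads the internal protocol and writes the
# output protocol, both of which default to JSON in mrjob.
def test_default_reducer_protocols(self):
    mr_job = MRBoringJob([])

    self.assertMethodsEqual(
        mr_job.pick_protocols(0, 'reducer'),
        (JSONProtocol.read, JSONProtocol.write))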


def test_init_does_not_require_tzset(self):
    MRJob()


def test_region_nobucket_nomatchexists(self):
    # aws_region specified, no bucket specified, no buckets have
    # matching region
    self.bucket1.set_location('PUPPYLAND')
    j = EMRJobRunner(aws_region='KITTYLAND',
                     s3_endpoint='KITTYLAND',
                     conf_path=False)
    self.assertNotEqual(j._opts['s3_scratch_uri'], self.bucket1_uri)


def make_runner(self):
    self.add_mock_s3_data({'walrus': {}})
    self.runner = EMRJobRunner(cloud_fs_sync_secs=0,
                               cloud_tmp_dir='s3://walrus/tmp',
                               conf_paths=[])
    self.runner._s3_log_dir_uri = BUCKET_URI + LOG_DIR
    self.prepare_runner_for_ssh(self.runner)
    self.output_dir = tempfile.mkdtemp(prefix='mrboss_wd')
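

# make_runner() leaves a temp dir behind; a matching cleanup hook is the
# usual pattern. A sketch only: it assumes `import shutil` at module
# level, and the real suite's tearDown would also call its superclass.
def tearDown(self):
    shutil.rmtree(self.output_dir, ignore_errors=True)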