Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def get_sample_task_metadata():
    """
    Build a fully-populated TaskMetadata fixture.

    Every field is set: discoverable and interruptible are True, the runtime
    is FLYTE_SDK "1.0.0" with flavor "python", the timeout is one day, the
    retry strategy allows 3 retries, the discovery version is "0.1.1b0" and a
    deprecation message is attached.

    :rtype: flytekit.models.task.TaskMetadata
    """
    sdk_runtime = _task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        "1.0.0",
        "python",
    )
    return _task_model.TaskMetadata(
        discoverable=True,
        runtime=sdk_runtime,
        timeout=timedelta(days=1),
        retries=_literals.RetryStrategy(3),
        interruptible=True,
        discovery_version="0.1.1b0",
        deprecated_error_message="This is deprecated!",
    )
# Node input bindings: 'a' is bound to the literal integer 5, while 'b' and
# 'c' each reference the same-named output of node 'my_node'.
b0 = _literals.Binding(
    'a',
    _literals.BindingData(
        scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5),
        ),
    ),
)
b1 = _literals.Binding(
    'b',
    _literals.BindingData(
        promise=_types.OutputReference('my_node', 'b'),
    ),
)
b2 = _literals.Binding(
    'c',
    _literals.BindingData(
        promise=_types.OutputReference('my_node', 'c'),
    ),
)
# Metadata for node 'node1': a 10-second timeout and a zero-retry strategy.
node_metadata = _workflow.NodeMetadata(
    name='node1', timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0),
)
# Task metadata fixture. TaskMetadata takes (discoverable, runtime, timeout,
# retries, interruptible, discovery_version, deprecated_error_message);
# keywords are used so no value can silently shift into the wrong slot.
# Omitting `interruptible` would push the discovery version into that slot
# and leave deprecated_error_message unset.
task_metadata = _task.TaskMetadata(
    discoverable=True,
    runtime=_task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
    timeout=timedelta(days=1),
    retries=_literals.RetryStrategy(3),
    # Matches the other TaskMetadata fixtures in this file.
    interruptible=True,
    discovery_version="0.1.1b0",
    deprecated_error_message="This is deprecated!",
)
# A single CPU entry with value "1", used as both the request and the limit.
cpu_resource = _task.Resources.ResourceEntry(
    _task.Resources.ResourceName.CPU,
    "1",
)
resources = _task.Resources(
    requests=[cpu_resource],
    limits=[cpu_resource],
)
# NOTE(review): this TaskTemplate(...) call is truncated in this view; the
# body of the container=_task.Container( argument and the closing
# parentheses are not visible.
task = _task.TaskTemplate(
# Identifier for a TASK resource built from placeholder project/domain/name/version strings.
_identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
# presumably the task type string — confirm against TaskTemplate's signature
"python",
task_metadata,
# typed_interface is defined outside this view.
typed_interface,
# presumably the task's custom payload (a nested dict) — confirm
{'a': 1, 'b': {'c': 2, 'd': 3}},
container=_task.Container(
# Wrap the shared entry list (defined outside this view) so it can be paired
# with itself below.
LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES]
# One Resources fixture per (request, limit) pair drawn from the product of
# the entry lists with themselves.
LIST_OF_RESOURCES = [
    task.Resources(request, limit)
    for request, limit in product(
        LIST_OF_RESOURCE_ENTRY_LISTS,
        LIST_OF_RESOURCE_ENTRY_LISTS,
    )
]
# Runtime variants: an OTHER-typed "python" runtime and a FLYTE_SDK "golang"
# runtime with a pre-release version string.
LIST_OF_RUNTIME_METADATA = [
    task.RuntimeMetadata(
        type=task.RuntimeMetadata.RuntimeType.OTHER,
        version="1.0.0",
        flavor="python",
    ),
    task.RuntimeMetadata(
        type=task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        version="1.0.0b0",
        flavor="golang",
    ),
]
# Retry strategies for zero, one, a few, and many retries.
LIST_OF_RETRY_POLICIES = [
    literals.RetryStrategy(retries=i)
    for i in (0, 1, 3, 100)
]
# presumably one TaskMetadata per combination of the fixture lists above —
# NOTE(review): this comprehension is truncated in this view; its closing
# parenthesis and `for ...` clauses (which bind the names below) are not visible.
LIST_OF_TASK_METADATA = [
task.TaskMetadata(
discoverable,
runtime_metadata,
timeout,
retry_strategy,
# NOTE(review): only six arguments are visible here; the other TaskMetadata
# calls in this file pass `interruptible` between the retry strategy and the
# discovery version — confirm this call still matches the constructor.
discovery_version,
deprecated
# Build a fully-populated TaskMetadata, check every field, then verify that
# it survives a round trip through its flyte IDL form.
obj = task.TaskMetadata(
    discoverable=True,
    runtime=task.RuntimeMetadata(
        type=task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        version="1.0.0",
        flavor="python",
    ),
    timeout=timedelta(days=1),
    retries=literals.RetryStrategy(3),
    interruptible=True,
    discovery_version="0.1.1b0",
    deprecated_error_message="This is deprecated!",
)

# Top-level flags, retries and timeout.
assert obj.discoverable is True
assert obj.retries.retries == 3
assert obj.interruptible is True
assert obj.timeout == timedelta(days=1)

# Nested runtime metadata.
assert obj.runtime.flavor == "python"
assert obj.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
assert obj.runtime.version == "1.0.0"

# String fields.
assert obj.deprecated_error_message == "This is deprecated!"
assert obj.discovery_version == "0.1.1b0"

# Serializing and deserializing must yield an equal object.
assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
"""
# Use the training job model as a measure of type checking
self._training_job_model = _training_job_models.TrainingJob(
algorithm_specification=algorithm_specification,
training_job_config=training_job_config,
)
# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training
# job gracefully
timeout = _datetime.timedelta(seconds=0)
super(SdkSimpleTrainingJobTask, self).__init__(
type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=interruptible,
discovery_version=cache_version,
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs={
"static_hyperparameters": _interface_model.Variable(
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
description="",
),
# NOTE(review): fragment of SdkSimpleHPOJobTask.__init__; the signature and the
# rest of the super().__init__(...) call are outside this view.
# Build the HPO job model around the wrapped training job's model as a measure
# of type checking, then serialize it with to_flyte_idl().
hpo_job = _hpo_job_model.HPOJob(
max_number_of_training_jobs=max_number_of_training_jobs,
max_parallel_training_jobs=max_parallel_training_jobs,
training_job=training_job.training_job_model,
).to_flyte_idl()
# Set the Flyte-level timeout to 0 and let SageMaker respect the StoppingCondition of
# the underlying training job.
# TODO: Discuss whether this is a viable interface or contract
timeout = _datetime.timedelta(seconds=0)
super(SdkSimpleHPOJobTask, self).__init__(
type=SdkTaskType.SAGEMAKER_HPO_JOB_TASK,
metadata=_task_models.TaskMetadata(
# Runtime reports the SDK's own version with the 'sagemaker' flavor.
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
# The caller's cacheable/cache_version map onto TaskMetadata's
# discoverable/discovery_version fields.
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=interruptible,
discovery_version=cache_version,
# Empty string: no deprecation message for this task.
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs={
# The HPO job config input is typed as the HPOJobConfig protobuf message.
"hpo_job_config": _interface_model.Variable(
_sdk_types.Types.Proto(_hpo_job_pb2.HPOJobConfig).to_flyte_literal_type(), ""
),
:param Text discovery_version: String describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param datetime.timedelta timeout: Maximum run time for the task; a falsy value (e.g. None) is replaced by a zero timedelta when the task metadata is built.
"""
# Set as class fields which are used down below to configure implicit
# parameters. Any falsy value (e.g. None) is normalized to "".
self._routing_group = routing_group or ""
self._catalog = catalog or ""
self._schema = schema or ""
# Task metadata for the Presto task; keywords make each TaskMetadata slot explicit.
metadata = _task_model.TaskMetadata(
    discoverable=discoverable,
    # This needs to have the proper version reflected in it
    runtime=_task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        "python",
    ),
    # A falsy timeout (e.g. None) becomes a zero timedelta.
    timeout=timeout or _datetime.timedelta(seconds=0),
    retries=_literals.RetryStrategy(retries),
    interruptible=interruptible,
    discovery_version=discovery_version,
    # Per flyteidl, a non-empty deprecated_error_message flags the task as
    # deprecated. The previous test-fixture placeholder ("This is deprecated!")
    # marked every Presto task that way, so leave it empty as the SageMaker
    # tasks do.
    deprecated_error_message="",
)
# Query model for the Presto engine; missing routing_group/catalog/schema
# values are normalized to "".
presto_query = _presto_models.PrestoQuery(
    statement=statement,
    routing_group=routing_group or "",
    catalog=catalog or "",
    schema=schema or "",
)
# Here we set the routing_group, catalog, and schema as implicit parameters
# Record the task's settings in task_obj (presumably a dict created outside
# this view). Each value is assigned directly: a trailing comma here
# (`x = v,`) would store the 1-tuple `(v,)` instead of `v`.
task_obj['task_type'] = _common_constants.SdkTaskType.PYTHON_TASK
task_obj['retries'] = retries
task_obj['storage_request'] = storage_request
task_obj['cpu_request'] = cpu_request
task_obj['gpu_request'] = gpu_request
task_obj['memory_request'] = memory_request
task_obj['storage_limit'] = storage_limit
task_obj['cpu_limit'] = cpu_limit
task_obj['gpu_limit'] = gpu_limit
task_obj['memory_limit'] = memory_limit
task_obj['environment'] = environment
task_obj['custom'] = {}
# Positional arguments follow TaskMetadata's order: discoverable, runtime,
# timeout, retries, interruptible, discovery_version, deprecated_error_message.
metadata = _task_model.TaskMetadata(
    cache,
    _task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        # NOTE(review): hard-coded placeholder version; other tasks in this
        # file report __version__ here — confirm.
        '1.2.3',
        'python'
    ),
    # A falsy timeout (e.g. None) becomes a zero timedelta.
    timeout or _datetime.timedelta(seconds=0),
    _literal_models.RetryStrategy(retries),
    interruptible,
    cache_version,
    deprecated
)
interface = get_interface_from_task_info(fn.__annotations__, outputs or [])
task_instance = PythonTask(fn, interface, metadata, outputs, task_obj)
# TODO: One of the things I want to make sure to do is better naming support. At this point, we should already
# be able to determine the name of the task right? Can anyone think of situations where we can't?