Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_sklearn_to_flow_list_of_lists(self):
    """Round-trip an OrdinalEncoder whose ``categories`` is a list of lists.

    Serializes the encoder to a flow, publishes it, downloads it again with
    ``reinstantiate=True``, and checks that the nested ``categories``
    parameter survives both the serialized parameter string and the
    reinstantiated model.
    """
    import os

    from sklearn.preprocessing import OrdinalEncoder
    ordinal_encoder = OrdinalEncoder(categories=[[0, 1], [0, 1]])

    extension = openml.extensions.sklearn.SklearnExtension()

    # Test serialization works
    flow = extension.model_to_flow(ordinal_encoder)

    # Test flow is accepted by server.
    # The helper returns a (flow, sentinel) tuple; use the returned flow
    # instead of relying on in-place mutation of the argument.
    flow, _ = self._add_sentinel_to_flow_name(flow)
    flow.publish()
    TestBase._mark_entity_for_removal('flow', (flow.flow_id, flow.name))
    # os.path.basename is portable; splitting on '/' yields the full path on
    # Windows, where __file__ uses backslashes.
    TestBase.logger.info("collected from {}: {}".format(os.path.basename(__file__), flow.flow_id))

    # Test deserialization works
    server_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
    self.assertEqual(server_flow.parameters['categories'], '[[0, 1], [0, 1]]')
    self.assertEqual(server_flow.model.categories, flow.model.categories)
def test_run_flow_on_task_downloaded_flow(self):
    """Run a flow that was downloaded back from the server on a task.

    Publishes a RandomForest flow (with ``raise_error_if_exists=False``),
    fetches it again via ``openml.flows.get_flow``, runs the downloaded flow
    on task 119 with ``upload_flow=False``, and publishes the resulting run.
    Both published entities are registered for removal.
    """
    import os

    # Derive the module name once so both log lines agree; os.path.basename
    # is portable where splitting on '/' is not (Windows uses backslashes).
    module_name = os.path.basename(__file__)

    model = sklearn.ensemble.RandomForestClassifier(n_estimators=33)
    flow = self.extension.model_to_flow(model)
    flow.publish(raise_error_if_exists=False)
    TestBase._mark_entity_for_removal('flow', (flow.flow_id, flow.name))
    TestBase.logger.info("collected from {}: {}".format(module_name, flow.flow_id))

    downloaded_flow = openml.flows.get_flow(flow.flow_id)
    task = openml.tasks.get_task(119)  # diabetes
    run = openml.runs.run_flow_on_task(
        flow=downloaded_flow,
        task=task,
        avoid_duplicate_runs=False,
        upload_flow=False,
    )

    run.publish()
    TestBase._mark_entity_for_removal('run', run.run_id)
    TestBase.logger.info("collected from {}: {}".format(module_name, run.run_id))
sklearn.feature_selection.VarianceThreshold(),
),
('classifier', sklearn.tree.DecisionTreeClassifier())
]
complicated = sklearn.pipeline.Pipeline(steps=steps)
for classifier in [nb, complicated]:
flow = self.extension.model_to_flow(classifier)
flow, _ = self._add_sentinel_to_flow_name(flow, None)
# publish the flow
flow = flow.publish()
TestBase._mark_entity_for_removal('flow', (flow.flow_id, flow.name))
TestBase.logger.info("collected from {}: {}".format(__file__.split('/')[-1],
flow.flow_id))
# redownload the flow
flow = openml.flows.get_flow(flow.flow_id)
# check if flow exists can find it
flow = openml.flows.get_flow(flow.flow_id)
downloaded_flow_id = openml.flows.flow_exists(
flow.name,
flow.external_version,
)
self.assertEqual(downloaded_flow_id, flow.flow_id)
def test_get_flow(self):
    """Download production flow 4024 and verify its nested component chain.

    Asserts that flow 4024 has a single component (4025), which itself has a
    single component (4026), and pins the id and parameter count of each
    level plus one parameter value of the middle level.
    """
    # We need the production server here because flow 4024 is not on the
    # test server.
    openml.config.server = self.production_server

    # One row per nesting level, outermost first:
    # (expected flow_id, expected number of parameters,
    #  parameter values to spot-check,
    #  whether to assert exactly one component and descend into it)
    expected_levels = [
        (4024, 24, {}, True),
        (4025, 14, {'E': 'CC'}, True),
        (4026, 13, {}, False),
    ]

    current = openml.flows.get_flow(4024)
    for expected_id, n_parameters, spot_checks, descend in expected_levels:
        self.assertIsInstance(current, openml.OpenMLFlow)
        self.assertEqual(current.flow_id, expected_id)
        self.assertEqual(len(current.parameters), n_parameters)
        for param_name, param_value in spot_checks.items():
            self.assertEqual(current.parameters[param_name], param_value)
        if descend:
            self.assertEqual(len(current.components), 1)
            # Exactly one component was asserted above; take it.
            current = next(iter(current.components.values()))
def test_get_flow_reinstantiate_model_no_extension(self):
    """Reinstantiating a flow for which no extension exists must fail.

    Flow 10 is a WEKA flow, so ``get_flow(..., reinstantiate=True)`` is
    expected to raise a RuntimeError whose message names the flow.
    """
    with self.assertRaisesRegex(
        RuntimeError,
        "No extension could be found for flow 10: weka.SMO",
    ):
        openml.flows.get_flow(flow_id=10, reinstantiate=True)
def test_get_structure(self):
    """Check ``get_structure`` against ``get_subflow`` for flow 4024.

    Also exercises ``get_subflow``: every non-root path reported by
    ``get_structure`` must lead to a subflow whose name (or flow_id) equals
    the key under which that path was reported.
    """
    # We need the production server here because flow 4024 is not on the
    # test server.
    openml.config.server = self.production_server
    flow = openml.flows.get_flow(4024)
    structure_by_name = flow.get_structure('name')
    structure_by_id = flow.get_structure('flow_id')

    # Four flows in total: root (filteredclassifier), multisearch,
    # loginboost, reptree.
    self.assertEqual(len(structure_by_name), 4)
    self.assertEqual(len(structure_by_id), 4)

    # Each mapping is key -> path; resolve the path and compare the matching
    # attribute of the subflow with the key.
    checks = ((structure_by_name, 'name'), (structure_by_id, 'flow_id'))
    for structure, attribute in checks:
        for expected, path in structure.items():
            if len(path) == 0:  # empty path is the root flow; skip it
                continue
            subflow = flow.get_subflow(path)
            self.assertEqual(getattr(subflow, attribute), expected)
def test_can_handle_flow(self):
    """The extension must reject flow 6794 (an R flow) and accept flow 7660.

    Both flows are fetched from the production server. The server setting is
    switched back to the test server in a ``finally`` so that a failing
    assertion does not leave the production server configured.
    """
    openml.config.server = self.production_server
    try:
        r_flow = openml.flows.get_flow(6794)
        self.assertFalse(self.extension.can_handle_flow(r_flow))
        old_3rd_party_flow = openml.flows.get_flow(7660)
        self.assertTrue(self.extension.can_handle_flow(old_3rd_party_flow))
    finally:
        # Restore even on failure; the original restored only on success.
        openml.config.server = self.test_server