from typing import Dict, List
from copy import deepcopy
from dflow import (
InputParameter,
Inputs,
InputArtifact,
Outputs,
OutputArtifact,
Step,
Steps,
argo_range,
argo_len,
)
from dflow.python import(
PythonOPTemplate,
OP,
Slices,
)
from rid.utils import init_executor
class Label(Steps):
    r"""Label SuperOP.

    This SuperOP chains four steps: a check-label-inputs OP, a prep-label OP,
    a run-label OP and a label-statistics OP. The steps themselves are built
    and attached by the module-level ``_label`` helper.

    Parameters
    ----------
    name
        Name of this ``Steps`` template.
    check_input_op
        OP used for the ``check-label-inputs`` step.
    prep_op
        OP used for the sliced ``prep-label`` step.
    run_op
        OP used for the sliced ``run-label`` step.
    stats_op
        OP used for the ``label-outputs-stats`` step.
    prep_config
        Step configuration for the check, prep and stats steps. Must contain
        the keys ``'template_config'`` and ``'executor'``; the remaining
        entries are forwarded to ``Step`` as keyword arguments.
    run_config
        Same layout as ``prep_config``, used for the ``run-label`` step.
    upload_python_package
        Forwarded as ``python_packages`` to every ``PythonOPTemplate``.
    retry_times
        Forwarded as ``retry_on_transient_error`` to every
        ``PythonOPTemplate``.
    """

    def __init__(
        self,
        name: str,
        check_input_op: OP,
        prep_op: OP,
        run_op: OP,
        stats_op: OP,
        prep_config: Dict,
        run_config: Dict,
        upload_python_package = None,
        retry_times = None
    ):
        self._input_parameters = {
            "label_config": InputParameter(type=Dict),
            "cv_config": InputParameter(type=Dict),
            "tail": InputParameter(type=float, value=0.9),
            "std_threshold": InputParameter(type=float, value=5.0),
            "block_tag" : InputParameter(type=str, value="")
        }
        # Only "confs" is mandatory; every other artifact is optional.
        self._input_artifacts = {
            "topology" : InputArtifact(optional=True),
            "models" : InputArtifact(optional=True),
            "forcefield" : InputArtifact(optional=True),
            "inputfile": InputArtifact(optional=True),
            "confs": InputArtifact(),
            "at": InputArtifact(optional=True),
            "index_file": InputArtifact(optional=True),
            "dp_files": InputArtifact(optional=True),
            "cv_file": InputArtifact(optional=True),
            "conf_tags": InputArtifact(optional=True)
        }
        self._output_parameters = {
        }
        self._output_artifacts = {
            "md_log": OutputArtifact(),
            "cv_forces": OutputArtifact()
        }

        super().__init__(
            name=name,
            inputs=Inputs(
                parameters=self._input_parameters,
                artifacts=self._input_artifacts
            ),
            outputs=Outputs(
                parameters=self._output_parameters,
                artifacts=self._output_artifacts
            ),
        )

        # Each step key is prefixed with the string form of the "block_tag"
        # input parameter (presumably rendered by dflow as a template
        # reference that is resolved at run time -- confirm against dflow).
        step_keys = {
            "check_label_inputs": "{}-check-label-inputs".format(self.inputs.parameters["block_tag"]),
            "prep_label": "{}-prep-label".format(self.inputs.parameters["block_tag"]),
            "run_label": "{}-run-label".format(self.inputs.parameters["block_tag"]),
            "label_stats": "{}-label-stats".format(self.inputs.parameters["block_tag"])
        }
        # Keep the mapping so that the ``keys`` property has something to
        # return; previously ``_keys`` was never assigned in this class.
        self._keys = step_keys

        # ``_label`` adds the steps to ``self`` in place and returns the same
        # object, so the return value does not need to be captured.
        _label(
            self,
            step_keys,
            check_input_op,
            prep_op,
            run_op,
            stats_op,
            prep_config = prep_config,
            run_config = run_config,
            upload_python_package = upload_python_package,
            retry_times = retry_times
        )

    @property
    def input_parameters(self):
        """Dict of the input parameters declared by this SuperOP."""
        return self._input_parameters

    @property
    def input_artifacts(self):
        """Dict of the input artifacts declared by this SuperOP."""
        return self._input_artifacts

    @property
    def output_parameters(self):
        """Dict of the output parameters declared by this SuperOP (empty)."""
        return self._output_parameters

    @property
    def output_artifacts(self):
        """Dict of the output artifacts declared by this SuperOP."""
        return self._output_artifacts

    @property
    def keys(self):
        """Dict mapping each logical step name to its step key string."""
        return self._keys
def _label(
    label_steps,
    step_keys,
    check_label_input_op : OP,
    prep_label_op : OP,
    run_label_op : OP,
    label_stats_op: OP,
    prep_config : Dict,
    run_config : Dict,
    upload_python_package : str = None,
    retry_times: int = None
):
    """Attach the four label steps to ``label_steps`` and wire its outputs.

    Steps are added in this order: ``check-label-inputs``, ``prep-label``,
    ``run-label``, ``label-outputs-stats``.

    Args:
        label_steps: The steps object to populate; its ``inputs`` supply the
            parameters/artifacts and its ``outputs`` artifacts are linked
            to the created steps.
        step_keys: Mapping with the entries ``check_label_inputs``,
            ``prep_label``, ``run_label`` and ``label_stats``.
        check_label_input_op: OP for the ``check-label-inputs`` step.
        prep_label_op: OP for the sliced ``prep-label`` step.
        run_label_op: OP for the sliced ``run-label`` step.
        label_stats_op: OP for the ``label-outputs-stats`` step.
        prep_config: Must hold ``'template_config'`` and ``'executor'``; the
            remaining entries become ``Step`` keyword arguments. Used for
            the check, prep and stats steps. The caller's dict is not
            modified (a deep copy is taken).
        run_config: Same layout as ``prep_config``; used for ``run-label``.
        upload_python_package: Passed as ``python_packages`` to each
            ``PythonOPTemplate``.
        retry_times: Passed as ``retry_on_transient_error`` to each
            ``PythonOPTemplate``.

    Returns:
        The same ``label_steps`` object that was passed in.

    Raises:
        KeyError: If either config lacks ``'template_config'`` or
            ``'executor'``.
    """
    # Private copies: the pops below must not leak back into the caller's dicts.
    prep_step_opts = deepcopy(prep_config)
    run_step_opts = deepcopy(run_config)
    prep_template_opts = prep_step_opts.pop('template_config')
    run_template_opts = run_step_opts.pop('template_config')
    prep_executor = init_executor(prep_step_opts.pop('executor'))
    run_executor = init_executor(run_step_opts.pop('executor'))

    in_params = label_steps.inputs.parameters
    in_arts = label_steps.inputs.artifacts

    # ---- step 1: check-label-inputs -------------------------------------
    check_template = PythonOPTemplate(
        check_label_input_op,
        python_packages=upload_python_package,
        retry_on_transient_error=retry_times,
        **prep_template_opts,
    )
    check_label_inputs = Step(
        'check-label-inputs',
        template=check_template,
        parameters={},
        artifacts={
            "conf_tags": in_arts['conf_tags'],
            "confs": in_arts['confs'],
        },
        key=step_keys['check_label_inputs'],
        executor=prep_executor,
        **prep_step_opts,
    )
    label_steps.add(check_label_inputs)
    checked = check_label_inputs.outputs.parameters

    # ---- step 2: prep-label (one slice per conf tag) --------------------
    # Explicit group_size/pool_size are given to Slices only when the
    # executor does not report merge_sliced_step.
    if prep_executor is not None:
        merge_prep = prep_executor.merge_sliced_step
    else:
        merge_prep = False
    prep_slice_extra = {} if merge_prep else {"group_size": 10, "pool_size": 1}
    prep_template = PythonOPTemplate(
        prep_label_op,
        python_packages=upload_python_package,
        retry_on_transient_error=retry_times,
        slices=Slices(
            "{{item}}",
            input_parameter=["task_name"],
            input_artifact=["conf", "at"],
            output_artifact=["task_path"],
            **prep_slice_extra,
        ),
        **prep_template_opts,
    )
    prep_label = Step(
        'prep-label',
        template=prep_template,
        parameters={
            "label_config": in_params['label_config'],
            "cv_config": in_params['cv_config'],
            "task_name": checked['conf_tags'],
        },
        artifacts={
            "topology": in_arts['topology'],
            "conf": in_arts['confs'],
            "at": in_arts['at'],
            "cv_file": in_arts['cv_file'],
        },
        key=step_keys['prep_label'] + "-{{item}}",
        executor=prep_executor,
        with_param=argo_range(argo_len(checked['conf_tags'])),
        # Only runs when the check step's "if_continue" output is positive.
        when="%s > 0" % (checked["if_continue"]),
        **prep_step_opts,
    )
    label_steps.add(prep_label)

    # ---- step 3: run-label (one slice per conf tag) ---------------------
    if run_executor is not None:
        merge_run = run_executor.merge_sliced_step
    else:
        merge_run = False
    run_slice_extra = {} if merge_run else {"group_size": 10, "pool_size": 1}
    run_template = PythonOPTemplate(
        run_label_op,
        python_packages=upload_python_package,
        retry_on_transient_error=retry_times,
        slices=Slices(
            "{{item}}",
            input_parameter=["task_name"],
            input_artifact=["task_path", "at"],
            output_artifact=[
                "plm_out", "cv_forces", "mf_info", "mf_fig", "md_log", "trajectory",
            ],
            **run_slice_extra,
        ),
        **run_template_opts,
    )
    # NOTE(review): unlike prep-label, this step carries no `when` guard on
    # "if_continue" -- confirm that is intended.
    run_label = Step(
        'run-label',
        template=run_template,
        parameters={
            "label_config": in_params["label_config"],
            "cv_config": in_params['cv_config'],
            "task_name": checked['conf_tags'],
            "tail": in_params['tail'],
        },
        artifacts={
            "forcefield": in_arts['forcefield'],
            "task_path": prep_label.outputs.artifacts["task_path"],
            "index_file": in_arts['index_file'],
            "dp_files": in_arts['dp_files'],
            "cv_file": in_arts['cv_file'],
            "inputfile": in_arts['inputfile'],
            "at": in_arts['at'],
        },
        key=step_keys['run_label'] + "-{{item}}",
        executor=run_executor,
        with_param=argo_range(argo_len(checked['conf_tags'])),
        # Presumably lets the step count as successful when at least 75% of
        # the slices succeed -- confirm against the dflow Step documentation.
        continue_on_success_ratio=0.75,
        **run_step_opts,
    )
    label_steps.add(run_label)

    # ---- step 4: label-outputs-stats ------------------------------------
    # Reuses the prep template options, executor and step options.
    stats_template = PythonOPTemplate(
        label_stats_op,
        python_packages=upload_python_package,
        retry_on_transient_error=retry_times,
        **prep_template_opts,
    )
    label_outputs_stats = Step(
        'label-outputs-stats',
        template=stats_template,
        parameters={
            "std_threshold": in_params["std_threshold"],
        },
        artifacts={
            "cv_forces": run_label.outputs.artifacts["cv_forces"],
            "mf_info": run_label.outputs.artifacts["mf_info"],
        },
        key=step_keys["label_stats"],
        executor=prep_executor,
        **prep_step_opts,
    )
    label_steps.add(label_outputs_stats)

    # Expose results: cv_forces comes from the stats step, md_log straight
    # from the run step.
    out_arts = label_steps.outputs.artifacts
    out_arts["cv_forces"]._from = label_outputs_stats.outputs.artifacts["cv_forces"]
    out_arts["md_log"]._from = run_label.outputs.artifacts["md_log"]

    return label_steps