diff --git a/configs/analysis/inference_analysis_casp.yaml b/configs/analysis/inference_analysis_casp.yaml index 9120dad..5d78ca6 100644 --- a/configs/analysis/inference_analysis_casp.yaml +++ b/configs/analysis/inference_analysis_casp.yaml @@ -12,3 +12,4 @@ fault_tolerant: true # whether to continue processing targets if an error occurs skip_existing: true # whether to skip processing targets for which output already exists score_relaxed_structures: true # whether to score relaxed structures in addition to the original (unrelaxed) structures repeat_index: 1 # the run index to use for scoring predictions +no_pretraining: false # whether to score a model without pretraining diff --git a/configs/model/ensemble_generation.yaml b/configs/model/ensemble_generation.yaml index 9574df2..e7d752b 100644 --- a/configs/model/ensemble_generation.yaml +++ b/configs/model/ensemble_generation.yaml @@ -96,6 +96,7 @@ neuralplexer_use_template: true # whether to use the input template protein stru neuralplexer_separate_pdb: true # whether to separate the predicted protein structures into dedicated PDB files neuralplexer_rank_outputs_by_confidence: true # whether to rank the output conformations, by default, by ligand confidence (if available) and by protein confidence otherwise neuralplexer_plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`) +neuralplexer_no_pretraining: false # whether to avoid loading pretrained weights # RoseTTAFold-All-Atom inference arguments: rfaa_python_exec_path: ${oc.env:PROJECT_ROOT}/forks/RoseTTAFold-All-Atom/RFAA/bin/python3 # the Python executable to use rfaa_exec_dir: ${oc.env:PROJECT_ROOT}/forks/RoseTTAFold-All-Atom # the RoseTTAFold-All-Atom directory in which to execute the inference scripts diff --git a/src/analysis/inference_analysis_casp.py b/src/analysis/inference_analysis_casp.py index fd79e75..365bc06 100644 --- a/src/analysis/inference_analysis_casp.py +++ b/src/analysis/inference_analysis_casp.py @@ -183,6 +183,11 @@ def main(cfg: DictConfig): :param cfg: Configuration dictionary from the hydra YAML file. """ + if cfg.no_pretraining: + with open_dict(cfg): + cfg.predictions_dir = cfg.predictions_dir.replace( + "_ensemble_predictions", "_npt_ensemble_predictions" + ) if cfg.method == "vina": with open_dict(cfg): cfg.predictions_dir = cfg.predictions_dir.replace( diff --git a/src/models/ensemble_generation.py b/src/models/ensemble_generation.py index a588ed8..5284d20 100644 --- a/src/models/ensemble_generation.py +++ b/src/models/ensemble_generation.py @@ -815,7 +815,7 @@ def get_method_predictions( elif method == "neuralplexer": ensemble_benchmarking_output_dir = ( Path(cfg.input_dir if cfg.input_dir else cfg.neuralplexer_out_path).parent - / f"neuralplexer_{cfg.ensemble_benchmarking_dataset}_outputs_{cfg.ensemble_benchmarking_repeat_index}" + / f"neuralplexer{'_npt' if cfg.neuralplexer_no_pretraining else ''}_{cfg.ensemble_benchmarking_dataset}_outputs_{cfg.ensemble_benchmarking_repeat_index}" if cfg.ensemble_benchmarking else (cfg.input_dir if cfg.input_dir else cfg.neuralplexer_out_path) )