diff --git a/src/models/inference_relaxation.py b/src/models/inference_relaxation.py index 94793c1..8ac3c42 100644 --- a/src/models/inference_relaxation.py +++ b/src/models/inference_relaxation.py @@ -106,18 +106,29 @@ def relax_inference_results( ligand_filepaths = sorted(ligand_filepaths) if len(ligand_filepaths) < len(protein_filepaths): # NOTE: the performance of these loops could likely be improved - protein_filepaths = [ - protein_filepath - for protein_filepath in protein_filepaths - if any( - "_".join(protein_filepath.stem.split("_")[:2]) in ligand_filepath.stem - for ligand_filepath in ligand_filepaths - ) - or any( - "_".join(protein_filepath.stem.split("_")[:2]) in ligand_filepath.parent.stem - for ligand_filepath in ligand_filepaths - ) - ] + if cfg.method == "dynamicbind": + protein_filepaths = [ + protein_filepath + for protein_filepath in protein_filepaths + if any( + "_".join(protein_filepath.parent.parent.stem.split("_")[-3:]) + in ligand_filepath.parent.parent.stem + for ligand_filepath in ligand_filepaths + ) + ] + else: + protein_filepaths = [ + protein_filepath + for protein_filepath in protein_filepaths + if any( + "_".join(protein_filepath.stem.split("_")[:2]) in ligand_filepath.stem + for ligand_filepath in ligand_filepaths + ) + or any( + "_".join(protein_filepath.stem.split("_")[:2]) in ligand_filepath.parent.stem + for ligand_filepath in ligand_filepaths + ) + ] if cfg.method in ["diffdock", "rfaa"]: ligand_filepaths = [ ligand_filepath