-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
76 lines (53 loc) · 2.83 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import models
def parse_arguments(argv=None) -> argparse.Namespace:
    """
    Parse the command-line arguments for the HazeSpace2M inference script.

    Example invocation (shown wrapped; pass everything on one command line)::

        python inference.py --gt_folder <path_to_gt>
            --hazy_folder <path_to_hazy> --output_dir <output_dir>
            --classifier <path_to_classifier> --cloudSD <path_to_cloudSD>
            --ehSD <path_to_ehSD> --fogSD <path_to_fogSD>

    Every option is optional and falls back to the default declared below.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which is the
            original command-line behaviour. Passing an explicit list lets
            other Python code (or tests) drive the parser without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options, with attributes
        ``gt_folder``, ``hazy_folder``, ``output_dir``, ``classifier``,
        ``cloudSD``, ``ehSD`` and ``fogSD`` (all strings).
    """
    parser = argparse.ArgumentParser(description="Run the HazeSpace2M inference")

    # NOTE(review): the two input defaults below are absolute Windows paths on
    # one specific machine; on any other system they must be overridden.
    parser.add_argument('--gt_folder', type=str,
                        help='Path to the Ground Truth (GT) folder or image.',
                        default=r'F:\Research\HazeSpace2M\data\Fog\Haze\OOTSEHL1_1.jpg',
                        required=False)
    parser.add_argument('--hazy_folder', type=str,
                        help='Path to the hazy folder or image. It can handle both single image and folder of images.',
                        default=r'F:\Research\HazeSpace2M\data\Fog\Haze\OOTSEHL1_3.jpg',
                        required=False)
    parser.add_argument('--output_dir', type=str,
                        help='Directory to save dehazed images.',
                        default=r"./storage/Test1/",
                        required=False)

    # Haze-type classifier used to pick which specialised dehazer to apply.
    parser.add_argument('--classifier', type=str,
                        help='Path to the classifier model. The classifier you want to use for predicting the haze type and conditional dehazing.',
                        default="./pretrained_weights/classifiers/ResNet152.pth",
                        required=False)

    # One specialised dehazer checkpoint per haze type.
    parser.add_argument('--cloudSD', type=str,
                        help='Path to the cloud specialized dehazer model.',
                        default=r"./pretrained_weights/dehazers/LD_Net_Cloud.pth",
                        required=False)
    parser.add_argument('--ehSD', type=str,
                        help='Path to the EH specialized dehazer model.',
                        default=r"./pretrained_weights/dehazers/LD_Net_EH.pth",
                        required=False)
    parser.add_argument('--fogSD', type=str,
                        help='Path to the Fog specialized dehazer model.',
                        default=r"./pretrained_weights/dehazers/LD_Net_Fog.pth",
                        required=False)

    # argv=None makes argparse fall back to sys.argv[1:], preserving CLI use.
    return parser.parse_args(argv)
def main():
    """
    Script entry point: read the CLI options and hand them to the
    conditional-dehazing pipeline in ``models``.

    The three specialised dehazer paths are passed as a list in the fixed
    order cloud, EH, fog. NOTE(review): ``models.conditionalDehazing`` is not
    visible here; presumably it relies on this order to map a predicted haze
    type to a dehazer. Verify against ``models`` before reordering.
    """
    options = parse_arguments()

    specialised_dehazers = [
        options.cloudSD,
        options.ehSD,
        options.fogSD,
    ]

    models.conditionalDehazing(
        gt_image=options.gt_folder,
        hazy_image=options.hazy_folder,
        dehazers=specialised_dehazers,
        classifier=options.classifier,
        output_dir=options.output_dir,
    )
# Run the pipeline only when executed as a script, not when imported.
if "__main__" == __name__:
    main()