Skip to content

Commit

Permalink
Add more relational constraints, tools, and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alperaltuntas committed Mar 28, 2024
1 parent 5ddcaa0 commit 8fa0c82
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'temp'))

def test_relational_constraints():
"""Check several relational constraints for the custom compset configuration."""
def test_constraint_violation_detection():
"""Confirm relational constraint violations are caught for the custom compset configuration."""

ConfigVar.reboot()
Stage.reboot()
Expand All @@ -46,24 +46,35 @@ def test_relational_constraints():
cvars['COMP_ATM'].value = "cam"

with pytest.raises(ConstraintViolation):
# CAM cannot be coupled with Data ICE
cvars['COMP_ICE'].value = "dice"

cvars['COMP_LND'].value = "clm"
cvars['COMP_ICE'].value = "cice"

cvars['COMP_ICE'].value = "cice"
with pytest.raises(ConstraintViolation):
# to enable CICE, must pick an active/data ocn
cvars['COMP_OCN'].value = "socn"

cvars['COMP_ICE'].value = "sice"
cvars['COMP_OCN'].value = "socn"

with pytest.raises(ConstraintViolation):
# cannot couple stub ocn with active wave
cvars['COMP_WAV'].value = "ww3"
assert cvars['COMP_WAV'].value == None

cvars['COMP_OCN'].value = "mom"
cvars['COMP_ICE'].value = "cice"

with pytest.raises(ConstraintViolation):
# MOM6 cannot be coupled with data wave component
cvars['COMP_WAV'].value = "dwav"
assert cvars['COMP_WAV'].value == None

cvars['COMP_ROF'].value = "mosart"
with pytest.raises(ConstraintViolation):
# MOSART cannot be run with slim
cvars['COMP_LND'].value = "slim"
assert cvars['COMP_LND'].value == "clm"

Expand All @@ -81,6 +92,7 @@ def test_relational_constraints():
cvars['COMP_ATM_OPTION'].value = "(none)"

with pytest.raises(ConstraintViolation):
# must pick a valid CLM option
cvars['COMP_LND_OPTION'].value = "(none)"
cvars['COMP_LND_OPTION'].value = "SP"

Expand Down Expand Up @@ -146,6 +158,6 @@ def test_multiple_reasons():
cvars['COMP_WAV'].value = "swav"

if __name__ == "__main__":
test_relational_constraints()
test_constraint_violation_detection()
test_multiple_reasons()

97 changes: 97 additions & 0 deletions tests/4_static/test_relational_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest

from z3 import And, Not, Implies, Or, Solver, sat, unsat

from ProConPy.config_var import ConfigVar, cvars
from ProConPy.config_var import cvars
from ProConPy.csp_solver import csp
from ProConPy.stage import Stage

from visualCaseGen.cime_interface import CIME_interface
from visualCaseGen.initialize_configvars import initialize_configvars
from visualCaseGen.initialize_widgets import initialize_widgets
from visualCaseGen.initialize_stages import initialize_stages
from visualCaseGen.specs.options import set_options
from visualCaseGen.specs.relational_constraints import get_relational_constraints


def test_initial_satisfiability():
"""Check that the relational constraints are satisfiable"""
ConfigVar.reboot()
Stage.reboot()
cime = CIME_interface()
initialize_configvars(cime)
initialize_widgets(cime)
initialize_stages(cime)
set_options(cime)
relational_constraints_dict = get_relational_constraints(cvars)
csp.initialize(cvars, relational_constraints_dict, Stage.first())

# check that relational constraints are satisfiable
s = Solver()
s.add([k for k in relational_constraints_dict.keys()])
assert s.check() != unsat, "Relational constraints are not satisfiable."

# check that initial options are all satisfiable
for varname, var in cvars.items():
if var.has_options():
s.add(Or([var == opt for opt in var.options]))
assert s.check() != unsat, f"Initial options for {varname} are not satisfiable."
elif var._options_spec:
opts = var._options_spec()
if opts[0] is not None:
s.add(Or([var == opt for opt in opts]))
assert s.check() != unsat, f"Initial options_spec for {varname} are not satisfiable."

# check that all initial options are satisfiable in some combination
for varname, var in cvars.items():
opts = []
if var.has_options():
opts = var.options
elif var._options_spec:
opts = var._options_spec()[0] or []
for opt in opts:
assert s.check(var == opt) == sat, f"Initial option {opt} for {varname} is not satisfiable."


def test_constraint_redundancy():
"""Check to see if any of the relational constraints is redundant
i.e., already implied by the preceding constraints."""

ConfigVar.reboot()
Stage.reboot()
cime = CIME_interface()
initialize_configvars(cime)
initialize_widgets(cime)
initialize_stages(cime)
set_options(cime)
relational_constraints_dict = get_relational_constraints(cvars)
csp.initialize(cvars, relational_constraints_dict, Stage.first())

constraints = [constr for constr, _ in relational_constraints_dict.items()]

for i in range(1,len(constraints)):
constraint = constraints[i]
s = Solver()
if s.check(Not(Implies(And(constraints[:i]), constraint))) == unsat:
raise AssertionError(f'Constraint "{constraint}" is redundant.')

def test_err_msg_repetition():
"""Check if any error messages are repeated in the relational constraints."""

relational_constraints = get_relational_constraints(cvars)

err_msg_list = [err_msg for _, err_msg in relational_constraints.items()]
err_msg_set = set(err_msg_list)

# If any error message is repeated, find out which ones are repeated and raise an AssertionError
if len(err_msg_list) != len(err_msg_set):
count = {err_msg: 0 for err_msg in err_msg_set}
for err_msg in err_msg_list:
count[err_msg] += 1
repeated_err_msgs = {err_msg: count[err_msg] for err_msg in err_msg_set if count[err_msg] > 1}
raise AssertionError(f"Error messages are repeated: {repeated_err_msgs}")


if __name__ == "__main__":
test_initial_satisfiability()
57 changes: 0 additions & 57 deletions tests/4_static/test_satisfiability.py

This file was deleted.

File renamed without changes.
File renamed without changes.
67 changes: 67 additions & 0 deletions tools/pattern_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python

from visualCaseGen.cime_interface import CIME_interface

cime = CIME_interface()

def rn(s):
"""Remove numeric characters from a string."""
return ''.join([i for i in s if not i.isdigit()])


def get_models(compset_lname):
"""Given a compset lname, return the list of models that are coupled."""
compset_components = cime.get_components_from_compset_lname(compset_lname)
model_list = []
for comp_class, compset_component in compset_components.items():
if not compset_component.startswith('X'):
phys = compset_component.split('%')[0]
model = next((model for model in cime.models[comp_class] if phys in cime.comp_phys[model]), None)
model_list.append(model)
return model_list

def compset_pattern_finder():
"""Find the models that are never coupled and always coupled with each other.
This is useful for identifying patterns and adding them as relational constraints."""

all_models = {model for model_list in cime.models.values() for model in model_list if model[0] != 'x'}

never = {model: all_models - {model} for model in all_models}
always = {model: all_models - {model} for model in all_models}

for compset in cime.compsets.values():

compset_models = get_models(compset.lname)

for model in compset_models:
if model:
never[model].difference_update(compset_models)
always[model].intersection_update(compset_models)


# From never, remove the models that are in the same component class as the model:
for model, never_coupled in never.items():
comp_class = None
for cc, models in cime.models.items():
if model in models:
comp_class = cc
break
if comp_class:
never[model] = never_coupled - set(cime.models[comp_class])

# Print the results
print('Never coupled:')
print('----------------------------------------')
for model, never_coupled in never.items():
if never_coupled:
print(f'{model}: {never_coupled}')

print('\nAlways coupled:')
print('----------------------------------------')
for model, always_coupled in always.items():
if always_coupled:
print(f'{model}: {always_coupled}')


if __name__ == '__main__':
compset_pattern_finder()
27 changes: 21 additions & 6 deletions visualCaseGen/specs/relational_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,32 @@ def get_relational_constraints(cvars):
Implies(COMP_WAV=="ww3", In(COMP_OCN, ["mom", "pop"])) :
"WW3 can only be selected if either POP2 or MOM6 is the ocean component.",

Implies(Or(COMP_ROF=="rtm", COMP_ROF=="mosart"), COMP_LND=='clm') :
"RTM or MOSART can only be selected if CLM is the land component.",
Implies(Or(COMP_ROF=="rtm", COMP_ROF=="mosart", COMP_ROF=="mizuroute"), COMP_LND=='clm') :
"Active runoff models can only be selected if CLM is the land component.",

Implies(And(In(COMP_OCN, ["pop", "mom"]), COMP_ATM=="datm"), COMP_LND=="slnd") :
"When MOM|POP is coupled with data atmosphere (datm), LND component must be stub (slnd).",

Implies(And(COMP_ATM=="datm", COMP_LND=="clm"), And(COMP_ICE=="sice", COMP_OCN=="socn")) :
"If CLM is coupled with DATM, then both ICE and OCN must be stub.",

Implies(In(COMP_OCN, ["mom", "pop"]), COMP_ATM!="satm") :
"An active or data atm model is needed to force active ocean models.",
Implies(COMP_ATM=="satm", And(COMP_ICE=="sice", COMP_ROF=="srof", COMP_OCN=="socn")) :
"An active or data atmosphere model is needed to force ocean, ice, and/or runoff models.",

Implies(COMP_LND=="slnd", COMP_GLC=="sglc") :
"GLC cannot be coupled with a stub land model.",

Implies(COMP_LND=="slim", And(COMP_GLC=="sglc", COMP_ROF=="srof", COMP_WAV=="swav")) :
"GLC, ROF, and WAV cannot be coupled with SLIM.",

Implies(COMP_OCN=="socn", COMP_ICE=="sice") :
"When ocean is made stub, ice must also be stub.",

Implies(COMP_LND=="clm", COMP_ROF!="drof") :
"CLM cannot be coupled with a data runoff model.",

Implies(COMP_LND=="dlnd", COMP_ATM!="cam") : # TODO: check this constraint.
"Data land model cannot be coupled with CAM.",

Implies(COMP_OCN=="docn", COMP_OCN_OPTION != "(none)"):
"Must pick a valid DOCN option.",
Expand Down Expand Up @@ -132,10 +147,10 @@ def get_relational_constraints(cvars):
"Global ocean domains must be reentrant in the x-direction.",

Implies(OCN_GRID_EXTENT=="Global", OCN_LENX==360.0):
"Global ocean model domains musth have a length of 360 degrees in the x-direction.",
"Global ocean model domains must have a length of 360 degrees in the x-direction.",

Implies(OCN_GRID_EXTENT=="Global", And(OCN_LENY>0.0, OCN_LENY<=180.0) ):
"OCN grid length in Y direction must be less than or equal to 180.0 when OCN grid extent is global.",
"OCN grid length in Y direction must be <= 180.0 when OCN grid extent is global.",

# Custom lnd grid constraints ------------------
Implies(COMP_LND!="clm", LND_GRID_MODE=="Standard"):
Expand Down

0 comments on commit 8fa0c82

Please sign in to comment.