Skip to content

Commit

Permalink
added metric evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
YannisZa committed Apr 22, 2024
1 parent b135d4f commit a79221a
Showing 1 changed file with 161 additions and 112 deletions.
273 changes: 161 additions & 112 deletions notebooks/Reading outputs (work in progress).ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "c76b8620",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_282017/2303974116.py:12: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
" from IPython.core.display import display, HTML\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import glob\n",
Expand All @@ -41,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "b718b815",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -218,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "e21669cd",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -246,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "41e02907",
"metadata": {},
"outputs": [],
Expand All @@ -273,59 +264,17 @@
" \"group_by\":[],\n",
" \"filename_ending\":\"test\",\n",
" \"sample\":[\"intensity\",\"table\"],\n",
" \"validation_data\":{\"test_cells\":\"../data/inputs/DC/train_cells.txt\"},\n",
" \"validation_data\":{\"test_cells\":\"../data/inputs/DC/test_cells.txt\"},\n",
" \"force_reload\":False\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "833a9fad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"05:32.509 config INFO ----------------------------------------------------------------------------------\n",
"05:32.517 config INFO Parameter space size: \n",
" --- sigma: ['sigma', 'to_learn'] (3)\n",
"05:32.526 config INFO Total = 3.\n",
"05:32.534 config INFO ----------------------------------------------------------------------------------\n",
"05:32.553 outputs INFO //////////////////////////////////////////////////////////////////////////////////\n",
"05:32.561 outputs INFO Slicing coordinates:\n",
"05:32.570 outputs INFO loss_name == str(['dest_attraction_ts_likelihood_loss'])\n",
"05:32.578 outputs INFO //////////////////////////////////////////////////////////////////////////////////\n",
"05:32.587 outputs INFO Reading samples alpha, beta, log_destination_attraction, table.\n",
"05:54.974 outputs INFO Creating Data Collection for each group. \n",
"Grouping/Initialising Data Collection samples sequentially: 100%|██████████| 12/12 [00:00<00:00, 69615.00it/s]\n",
"Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 82782.32it/s]\n",
"Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73584.28it/s]\n",
"Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73584.28it/s]\n",
"Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73156.47it/s]\n",
"Slicing coordinates sequentially: 25%|██▌ | 3/12 [00:00<00:00, 10.59it/s]05:55.284 outputs INFO table: 12 collection ids kept out of 12.\n",
"05:55.292 outputs INFO log_destination_attraction: 12 collection ids kept out of 12.\n",
"05:55.300 outputs INFO beta: 12 collection ids kept out of 12.\n",
"05:55.308 outputs INFO alpha: 12 collection ids kept out of 12.\n",
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 experiments matched\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r"
]
}
],
"outputs": [],
"source": [
"# Initialise outputs\n",
"current_sweep_outputs = Outputs(\n",
Expand Down Expand Up @@ -453,24 +402,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "d5eb796d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# Sweeps: 3\n",
"ItemsView(Coordinates:\n",
" * id (id) object MultiIndex\n",
" * iter (id) int32 1 2 3 4 5 6 7 ... 99995 99996 99997 99998 99999 100000\n",
" * sweep (sweep) object MultiIndex\n",
" * sigma (sweep) object 'none'\n",
" * to_learn (sweep) object \"['alpha', 'beta', 'sigma']\")\n"
]
}
],
"outputs": [],
"source": [
"index = 0\n",
"current_data = current_sweep_outputs.get(index)\n",
Expand All @@ -480,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "fa8fb841",
"metadata": {},
"outputs": [],
Expand All @@ -489,60 +424,174 @@
" config = current_data.config\n",
")\n",
"ins.cast_to_xarray()\n",
"test_cells = current_data.get_sample('test_cells')"
"test_cells = current_data.get_sample('test_cells')\n",
"train_cells = current_data.get_sample('train_cells')"
]
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": null,
"id": "1a1e911f",
"metadata": {},
"outputs": [],
"source": [
"test_error = srmse(\n",
" prediction = current_data.data.table.mean('id'),\n",
"all_table_error = srmse(\n",
" prediction = current_data.data.table.mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table\n",
")\n",
"train_table_error = srmse(\n",
" prediction = current_data.data.table.mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" test_cells = test_cells\n",
" cells = train_cells\n",
")\n",
"all_error = srmse(\n",
" prediction = current_data.data.table.mean('id'),\n",
"test_table_error = srmse(\n",
" prediction = current_data.data.table.mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" cells = test_cells\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc905be7",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" all_table_error.values.squeeze().item(),\n",
" train_table_error.values.squeeze().item(),\n",
" test_table_error.values.squeeze().item()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7090ffd",
"metadata": {},
"outputs": [],
"source": [
"all_intensity_error = srmse(\n",
" prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table\n",
")\n",
"train_intensity_error = srmse(\n",
" prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" cells = train_cells\n",
")\n",
"test_intensity_error = srmse(\n",
" prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" cells = test_cells\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": null,
"id": "dcf96b7f",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" all_intensity_error.values.squeeze().item(),\n",
" train_intensity_error.values.squeeze().item(),\n",
" test_intensity_error.values.squeeze().item()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca0eecab",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [90]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_cp \u001b[38;5;241m=\u001b[39m \u001b[43mcoverage_probability\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mprediction\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcurrent_data\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mground_truth\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mins\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mground_truth_table\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mregion_mass\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.95\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_cells\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtest_cells\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m all_cp \u001b[38;5;241m=\u001b[39m coverage_probability(\n\u001b[1;32m 8\u001b[0m prediction \u001b[38;5;241m=\u001b[39m current_data\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mtable,\n\u001b[1;32m 9\u001b[0m ground_truth \u001b[38;5;241m=\u001b[39m ins\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mground_truth_table,\n\u001b[1;32m 10\u001b[0m region_mass \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.95\u001b[39m\n\u001b[1;32m 11\u001b[0m )\n",
"File \u001b[0;32m~/GeNSIT/gensit/utils/math_utils.py:323\u001b[0m, in \u001b[0;36mcoverage_probability\u001b[0;34m(prediction, ground_truth, **kwargs)\u001b[0m\n\u001b[1;32m 320\u001b[0m stacked_dims \u001b[38;5;241m=\u001b[39m deepcopy(prediction\u001b[38;5;241m.\u001b[39mdims)\n\u001b[1;32m 322\u001b[0m \u001b[38;5;66;03m# Sort all samples by iteration-seed\u001b[39;00m\n\u001b[0;32m--> 323\u001b[0m prediction[:] \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msort\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprediction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mstacked_dims\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mid\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;66;03m# Get lower and upper bound high posterior density regions\u001b[39;00m\n\u001b[1;32m 326\u001b[0m lower_bound_hpdr,upper_bound_hpdr \u001b[38;5;241m=\u001b[39m calculate_min_interval(\n\u001b[1;32m 327\u001b[0m prediction,\n\u001b[1;32m 328\u001b[0m alpha\n\u001b[1;32m 329\u001b[0m )\n",
"File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36msort\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/miniconda3/envs/gensit/lib/python3.10/site-packages/numpy/core/fromnumeric.py:1003\u001b[0m, in \u001b[0;36msort\u001b[0;34m(a, axis, kind, order)\u001b[0m\n\u001b[1;32m 1001\u001b[0m axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1002\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1003\u001b[0m a \u001b[38;5;241m=\u001b[39m \u001b[43masanyarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mK\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1004\u001b[0m a\u001b[38;5;241m.\u001b[39msort(axis\u001b[38;5;241m=\u001b[39maxis, kind\u001b[38;5;241m=\u001b[39mkind, order\u001b[38;5;241m=\u001b[39morder)\n\u001b[1;32m 1005\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m a\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# test_cp = coverage_probability(\n",
"# prediction = current_data.data.table,\n",
"# ground_truth = ins.data.ground_truth_table,\n",
"# region_mass = 0.95,\n",
"# test_cells = test_cells\n",
"# )\n",
"# all_cp = coverage_probability(\n",
"# prediction = current_data.data.table,\n",
"# ground_truth = ins.data.ground_truth_table,\n",
"# region_mass = 0.95\n",
"# )"
"outputs": [],
"source": [
"all_table_cp = coverage_probability(\n",
" prediction = current_data.data.table,\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95\n",
")\n",
"train_table_cp = coverage_probability(\n",
" prediction = current_data.data.table,\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95,\n",
" cells = train_cells\n",
")\n",
"test_table_cp = coverage_probability(\n",
" prediction = current_data.data.table,\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95,\n",
" cells = test_cells\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed196756",
"metadata": {},
"outputs": [],
"source": [
"all_cp = all_table_cp\n",
"test_cp = train_table_cp\n",
"test_cp = test_table_cp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6404451",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" all_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
" train_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
" test_table_cp.mean(['origin','destination'],skipna=True).values.item()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c60450c",
"metadata": {},
"outputs": [],
"source": [
"all_intensity_cp = coverage_probability(\n",
" prediction = current_data.get_sample('intensity'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95\n",
")\n",
"train_intensity_cp = coverage_probability(\n",
" prediction = current_data.get_sample('intensity'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95,\n",
" cells = train_cells\n",
")\n",
"test_intensity_cp = coverage_probability(\n",
" prediction = current_data.get_sample('intensity'),\n",
" ground_truth = ins.data.ground_truth_table,\n",
" region_mass = 0.95,\n",
" cells = test_cells\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30e55924",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" all_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
" train_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
" test_intensity_cp.mean(['origin','destination'],skipna=True).values.item()\n",
")"
]
},
{
Expand Down

0 comments on commit a79221a

Please sign in to comment.