Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
silky1708 committed Nov 10, 2023
1 parent 3158a4d commit b92c97c
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# LOCATE
[BMVC 2023] Official repository for "LOCATE: Self-supervised Object Discovery via Flow-guided Graph-cut and Bootstrapped Self-training"
*Silky Singh, Shripad Deshmukh, Mausoom Sarkar, Balaji Krishnamurthy.*
*Silky Singh, Shripad Deshmukh, Mausoom Sarkar, Balaji Krishnamurthy.*

[project page](https://silky1708.github.io/LOCATE/) | [arXiv](https://arxiv.org/abs/2308.11239) | [bibtex](https://github.com/silky1708/LOCATE/tree/main#citation)


![qual results](assets/locate_VOS_qual.png)

Our self-supervised method LOCATE trained on video datasets can perform object segmentation on standalone images.
Our self-supervised framework LOCATE trained on video datasets can perform object segmentation on standalone images.

<!-- ![model pipeline](assets/model_pipeline.png) -->

Expand All @@ -19,9 +22,7 @@ conda activate locate

The code has been tested with `python=3.8`, `pytorch=1.12.1`, `torchvision=0.13.1` with `cudatoolkit=11.3` on Nvidia A100 machine.

Use the official Pytorch installation instructions provided [here](https://pytorch.org/get-started/previous-versions/).

Other dependencies can be installed following the [guess-what-moves](https://github.com/karazijal/guess-what-moves) repository. It is mentioned below for completeness.
Use the official Pytorch installation instructions provided [here](https://pytorch.org/get-started/previous-versions/). Other dependencies can be installed following the [guess-what-moves](https://github.com/karazijal/guess-what-moves) repository. It is mentioned below for completeness.

```
conda install -y pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
Expand All @@ -40,37 +41,31 @@ We have tested our method on video object segmentation datasets (DAVIS 2016, FBM

### Step 1. Graph Cut

We utilise the MaskCut algorithm from the CutLER's repository [[link](https://github.com/facebookresearch/CutLER)] with `N=1` to get the segmentation mask for the salient object in all the video frames independently. We modify the pipeline to take in optical flow features of the video frame, and combine both image and flow features in a linear combination to produce edge weights. The modified code can be found at: `/path/to/new/graphcut`.
We utilise the MaskCut algorithm from the CutLER's repository [[link](https://github.com/facebookresearch/CutLER)] with `N=1` to get the segmentation mask for the salient object in all the video frames independently. We modify the pipeline to take in optical flow features of the video frame, and combine both image and flow feature similarities in a linear combination to produce edge weights. The modified code can be found in the `CutLER` directory.

We perform a single round of post-processing using Conditional Random Fields (CRF) to get pixel-level segmentation masks. The initial segmentation masks will be released for all the datasets.
We perform a single round of post-processing using Conditional Random Fields (CRF) to get pixel-level segmentation masks. The graphcut masks for all the datasets are released [here](https://www.dropbox.com/scl/fo/wdr6jxutv9x4zte1n8jyz/h?rlkey=ayfmd4dp03tjdg6a2m0xg4iac&dl=0). We use [ARFlow](https://github.com/lliuz/ARFlow) trained on the synthetic Sintel dataset to compute the optical flow between video frames.


### Step 2. Bootstrapped Self-training

Using segmentation masks from previous step as pseudo-ground-truth, we train a [MaskFormer](https://github.com/facebookresearch/MaskFormer) network.

In the `src` directory, run the following command for training:
```
python main.py GWM.DATASET <dataset> LOG_ID <log_id>
```
where `dataset` (e.g., `DAVIS`), `log_id` (e.g., `davis`) need to be set.

Using segmentation masks from previous step as pseudo-ground-truth, we train a [segmentation](https://github.com/facebookresearch/MaskFormer) network. In the root directory, run `train.sh`.

## Testing/Inference
## Inference

Use the test script for running inference: `/path/to/test.py`
Use the test script for running inference: `python test.py`


## Model Checkpoints

| Dataset | Checkpoint |
| Dataset | Checkpoint path |
| ------- | ---------- |
| DAVIS16 | `/path/to/davis/ckpt` |
| SegTrackv2 | `/path/to/segtrack/ckpt` |
| FBMS59 | `/path/to/fbms/ckpt` |
| Combined | `/path/to/combined/ckpt` |
| DAVIS16 | `locate_checkpoints/davis2016.pth` |
| SegTrackv2 | `locate_checkpoints/segtrackv2.pth` |
| FBMS59 (graph-cut masks) | `locate_checkpoints/fbms59_graphcut.pth` |
| FBMS59 (zero-shot) | `locate_checkpoints/fbms59_zero_shot.pth` |
| DAVIS16+STv2+FBMS | `locate_checkpoints/combined.pth` |

The Combined checkpoint refers to the model trained on all the video datasets (DAVIS16, SegTrackv2, FBMS59) combined.
The checkpoints are released [here](https://www.dropbox.com/scl/fo/v2akgrbzyyvkgtr98x2ok/h?rlkey=wfhmcm26fb3ivirdpx6pdkdxb&dl=0). The `combined.pth` checkpoint refers to the model trained on all the video datasets (DAVIS16, SegTrackv2, FBMS59) combined.

## Acknowledgments

Expand Down

0 comments on commit b92c97c

Please sign in to comment.