Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DSAD for newer flux.jl explicit API #9

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

rolling-robot
Copy link

@rolling-robot rolling-robot commented Apr 16, 2023

I was tinkering around with Flux and anomaly detection and found that Flux changes its API starting from 0.13 to what they call "explicit style". So these are changes to support Flux 0.13 and above for DSAD algorithm.

More:
https://fluxml.ai/Flux.jl/stable/training/training/#Model-Gradients and https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1

My testing code:

using MLJ
using PyCall
using Flux
using OutlierDetection
using OutlierDetectionInterface: Labels, Data
using Plots
using CategoricalArrays

n_train = 200
n_test = 200

ocnn = @load DSADDetector pkg=OutlierDetectionNetworks verbosity=0

skl_ds = pyimport("sklearn.datasets")
data, labels = skl_ds.make_moons(n_train, noise=0.1)
fig = scatter(data[:,1], data[:,2], marker=:+)

anomaly_labels = map((x ->  "normal"),labels)
push!(anomaly_labels, "outlier")
data = vcat(data, [1.5;0.5] |> transpose)
push!(anomaly_labels, "outlier")
data = vcat(data, [-0.5;0.] |> permutedims)

encoder = Chain(
    Dense(2 => 4, relu, bias=false),
    Dense(4 => 8, relu, bias=false),
    Dense(8 => 15, relu, bias=false)
)
decoder = Chain(
    Dense(15 => 8, relu, bias=false),
    Dense(8 => 4, relu, bias=false),
    Dense(4 => 2, relu, bias=false))

loss_log = Vector()
detector = ocnn(
    encoder=encoder,
    decoder=decoder,
    epochs=120,
    callback = (
        ((m, x) -> ()),
        ((m, x, y) -> push!(loss_log, mean(m(x)))))
)

model, score = OutlierDetection.fit(detector,
    data |> permutedims,
    CategoricalArray(anomaly_labels, levels=["normal", "outlier"]),
    verbosity=2)

test = rand(Float64, (n_test,2)) * 11 - ones(n_test,2)*5
ŷ = OutlierDetection.transform(detector, model, test |> permutedims)
contour!(fig,
    range(-2, stop = 3, length = 200),
    range(-2, stop = 3, length = 200),
    (x, y) -> first(OutlierDetection.transform(detector, model, [x y] |> permutedims)),
         levels=30)

Noticeable changes:

  1. DSADModel now contains only encoder and radius for hidden variables (decoder seemed to be redundant there)
  2. DSADModel can be executed and returns svddScore for better integration with Flux.gradient
  3. callback signature is changed. It is now tuple with two functions with different signatures.

More:
https://fluxml.ai/Flux.jl/stable/training/training/#Model-Gradients
and https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1

strat refactoring to explicit Flux api

rename to ae_model for clarity

working prototype on new api

add loops to handle high-level callbacks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant