diff --git a/README.md b/README.md index d244e61..7b01cef 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,15 @@ dependencies: Look at the spec for most functions ```crystal -require "crystal-fann" ann = Fann::Network::Standard.new(2, [2], 1) -500.times do - ann.train_single([1.0, 0.1], [0.5]) +ann.randomize_weights(0.0, 1.0) +3000.times do + ann.train_single([1.0, 0.0], [0.5]) end -result = ann.run([1.0, 0.1]) +result = ann.run([1.0, 0.0]) # Remember to close the network when done to free allocated C memory ann.close +(result < [0.55] && result > [0.45]).should be_true ``` ```crystal @@ -38,15 +39,13 @@ input = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]] output = [[0.0], [1.0], [1.0], [0.0]] train_data = Fann::TrainData.new(input, output) data = train_data.train_data -ann.train_algorithm(LibFANN::TrainEnum::TrainRprop) -ann.set_hidden_layer_activation_func(LibFANN::ActivationfuncEnum::Linear) -ann.set_output_layer_activation_func(LibFANN::ActivationfuncEnum::Linear) +ann.randomize_weights(0.0, 1.0) if data - ann.train_batch(data, {:max_runs => 8000, :desired_mse => 0.001_f64, :log_each => 1000}) + ann.train_batch(data, {:max_runs => 8000, :desired_mse => 0.001, :log_each => 1000}) end result = ann.run([1.0, 1.0]) ann.close -(result < [0.1]).should be_true +(result < [0.15]).should be_true ``` ```crystal @@ -57,10 +56,9 @@ output = [[0.0], [1.0], [1.0], [0.0]] train_data = Fann::TrainData.new(input, output) data = train_data.train_data ann.train_algorithm(LibFANN::TrainEnum::TrainRprop) -ann.set_hidden_layer_activation_func(LibFANN::ActivationfuncEnum::Linear) -ann.set_output_layer_activation_func(LibFANN::ActivationfuncEnum::Linear) +ann.randomize_weights(0.0, 1.0) if data - ann.train_batch(data, {:max_neurons => 500, :desired_mse => 0.1_f64, :log_each => 10}) + ann.train_batch(data, {:max_neurons => 500, :desired_mse => 0.001, :log_each => 10}) end result = ann.run([1.0, 1.0]) ann.close diff --git a/spec/network_spec.cr b/spec/network_spec.cr index 0b011b3..ec25c5a 100644 --- a/spec/network_spec.cr +++ b/spec/network_spec.cr @@ -49,10 +49,7 @@ describe Fann::Network do output = [[0.0], [1.0], [1.0], [0.0]] train_data = Fann::TrainData.new(input, output) data = train_data.train_data - # ann.train_algorithem(LibFANN::TrainEnum::TrainSarprop) ann.randomize_weights(0.0, 1.0) - # ann.set_hidden_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu) - # ann.set_output_layer_activation_func(LibFANN::ActivationfuncEnum::LeakyRelu) if data ann.train_batch(data, {:max_runs => 8000, :desired_mse => 0.001, :log_each => 1000}) end