Skip to content

Commit

Permalink
network: implemented set_optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
MSallermann committed Nov 16, 2023
1 parent ae0b329 commit e6b6fa5
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
5 changes: 4 additions & 1 deletion examples/mnist/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ int main()
// No. of trainable params
network.summary();
network.loss_tol = 2e-2;
network.fit( x_train, y_train, 1000, 0.01, true );

auto opt = Robbie::Optimizers::StochasticGradientDescent( 0.01 );
network.set_optimizer( &opt );
network.fit( x_train, y_train, 1000, true );

fmt::print( "Loss on test set = {:.3e}\n", network.loss( x_test, y_test ) );

Expand Down
19 changes: 11 additions & 8 deletions include/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <vector>

Expand All @@ -22,9 +23,6 @@ class Network
public:
std::optional<scalar> loss_tol = std::nullopt;

std::unique_ptr<Optimizers::Optimizer<scalar>> opt
= std::make_unique<Optimizers::StochasticGradientDescent<scalar>>( 0.00001 );

Network() = default;

template<typename LayerT, typename... T>
Expand Down Expand Up @@ -67,10 +65,10 @@ class Network
return loss;
}

// void set_optimizer( const Optimizers::Optimizer<scalar> & opt )
// {
// this->opt = s
// }
// Set the optimizer that fit() will use; fit() throws std::runtime_error
// if this has not been called (the stored member defaults to nullptr).
// NOTE(review): this stores a non-owning raw pointer — callers in this commit
// pass the address of a stack-local optimizer, so the caller must keep the
// optimizer alive for as long as this Network may train with it. Passing
// nullptr effectively "unsets" the optimizer again.
void set_optimizer( Optimizers::Optimizer<scalar> * opt )
{
this->opt = opt;
}

void register_optimizer_variables()
{
Expand All @@ -93,8 +91,12 @@ class Network

void
fit( const std::vector<Matrix<scalar>> & x_train, const std::vector<Matrix<scalar>> & y_train, size_t epochs,
scalar learning_rate, bool print_progress = false )
bool print_progress = false )
{

if( this->opt == nullptr )
throw std::runtime_error( "Optimizer has not been set!" );

register_optimizer_variables();

auto n_samples = x_train.size();
Expand Down Expand Up @@ -194,6 +196,7 @@ class Network

private:
std::vector<std::unique_ptr<Layer<scalar>>> layers;
Optimizers::Optimizer<scalar> * opt = nullptr;
};

} // namespace Robbie
3 changes: 2 additions & 1 deletion include/robbie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
#include "fc_layer.hpp"
#include "layer.hpp"
#include "loss_functions.hpp"
#include "network.hpp"
#include "network.hpp"
#include "optimizers.hpp"
5 changes: 4 additions & 1 deletion main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ int main()
fmt::print( "x_train[10] = {}\n", fmt::streamed( x_train[10] ) );
fmt::print( "y_train[10] = {}\n", fmt::streamed( y_train[10] ) );

auto opt = Optimizers::StochasticGradientDescent<scalar>( 0.00001 );

auto network = Network<scalar, LossFunctions::MeanSquareError>();
network.set_optimizer( &opt );
network.add<FCLayer<scalar>>( input_size, 100 );
network.add<ActivationLayer<scalar, ActivationFunctions::Tanh>>();
network.add<DropoutLayer<scalar>>( 0.5 );
Expand All @@ -57,7 +60,7 @@ int main()
network.add<FCLayer<scalar>>( 30, 10 );
network.summary();

network.fit( x_train, y_train, 300, 0.00001, true );
network.fit( x_train, y_train, 300, true );

fmt::print( "Loss on test set = {:.3e}\n", network.loss( x_test, y_test ) );
}
7 changes: 5 additions & 2 deletions test/test_xor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ TEST_CASE( "Test_XOR" )
y_train[2] << 1;
y_train[3] << 0;

auto network = Network<double, LossFunctions::MeanSquareError>();
auto opt = Optimizers::StochasticGradientDescent<double>( 0.01 );
auto network = Network<double, LossFunctions::MeanSquareError>();
network.set_optimizer( &opt );

network.loss_tol = 5e-4;

network.add<FCLayer<double>>( 2, 10 );
network.add<ActivationLayer<double, ActivationFunctions::Tanh>>();
network.add<FCLayer<double>>( 10, 1 );
network.add<ActivationLayer<double, ActivationFunctions::Tanh>>();

network.fit( x_train, y_train, 50000, 0.01 );
network.fit( x_train, y_train, 50000 );

auto out = network.predict( x_train );

Expand Down

0 comments on commit e6b6fa5

Please sign in to comment.