diff --git a/README.md b/README.md
index 24e1c6f..4878fbf 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,16 @@
 # simd_neuralnet
 Feed-forward neural network implementation in C with SIMD instructions.
 
-BTW: simd_neuralnet is a bit of a misnomer. It is not SIMD instructions apart from AVX and AVX2 instructions in some of the most heavy functions. Can anyone come up with a better name? It is a neural network library written in C. I'm not sure what else makes this library different from the other libraries out there.
+BTW: simd_neuralnet is a bit of a misnomer. It does not use SIMD instructions apart from AVX, AVX2
+and AVX-512 instructions in some of the most compute-heavy functions. Can anyone come up with a better name?
+It is a neural network library written in C. I'm not sure what else makes this library different
+from the other libraries out there.
+
+It is very fast compared to other libraries like Keras and PyTorch - just give it a try and you
+will see.
 
 ## The idea.
-At current this is just a study project for myself to improve my abilities to implement a
+This project started out as a study project for myself to improve my abilities to implement a
 feed forward neural network in C. A lot of this code will be based on code from some of my
 other projects. Hopefully this will be generally usable.
@@ -26,7 +32,8 @@ struggle with python bindings or slow memory transfers to GPU memory.
 ## Limitations
 To be able to achieve the above, we need to set some limitations.
- * **`float32` precision only!** The code will use SIMD instructions, so double precision will slow things down, and float16 is not precise enough and has limited support.
+ * **`float32` precision only!** The code will use SIMD instructions, so double precision will
+   slow things down, and `float16` is not precise enough and has limited support.
  * **Fully connect feed forward neural networks only!** No support for LSTM, convolutional layers,
    RNN or whatever.
 ### Loss functions implemented
diff --git a/examples/example_02.c b/examples/example_02.c
index 8ba339e..6d1b28e 100644
--- a/examples/example_02.c
+++ b/examples/example_02.c
@@ -66,6 +66,10 @@ int main( int argc, char *argv[] )
     printf("Test loss     : %5.5f\n", results[2] );
     printf("Test accuracy : %5.5f\n", results[3] );
 
+    /* Let's save the neural net and see if we can recreate the result from a saved nn.
+     * That test will be done in a separate source file (example_02b.c) */
+    neuralnet_save( nn, "mushroom-neuralnet.npz");
+
     /* Clean up the resources */
     neuralnet_free( nn );
     npy_array_list_free( filelist );
diff --git a/examples/example_02b.c b/examples/example_02b.c
new file mode 100644
index 0000000..4f76dfa
--- /dev/null
+++ b/examples/example_02b.c
@@ -0,0 +1,85 @@
+#include "npy_array.h"
+#include "npy_array_list.h"
+#include "neuralnet.h"
+#include "neuralnet_predict_batch.h"
+#include "simd.h"
+
+#include "evaluate.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+
+void local_evaluate( neuralnet_t *nn, const int n_valid_samples, const float *valid_X, const float *valid_Y,
+        metric_func metrics[], float *results )
+{
+    //const int n_input  = nn->layer[0].n_input;
+    const int n_output = nn->layer[nn->n_layers-1].n_output;
+
+    metric_func *mf_ptr = metrics;
+
+    int n_metrics = 0;
+    while ( *mf_ptr++ )
+        n_metrics++;
+
+    float predictions[ n_output * n_valid_samples ];
+    neuralnet_predict_batch( nn, n_valid_samples, valid_X, predictions);
+
+    float local_results[n_metrics];
+    memset( local_results, 0, n_metrics * sizeof(float));
+    for ( int i = 0; i < n_valid_samples; i++ ){
+        float *res = local_results;
+        for ( int j = 0; j < n_metrics; j++ ){
+            float _error = metrics[j]( n_output, predictions + (i*n_output), valid_Y + (i*n_output));
+            *res++ += _error;
+        }
+    }
+
+    float *res = results;
+    for ( int i = 0; i < n_metrics; i++ )
+        *res++ = local_results[i] / (float) n_valid_samples;
+}
+
+int main( int argc, char *argv[] )
+{
+    /* Read the datafile created in python+numpy */
+    npy_array_list_t *filelist = npy_array_list_load( "mushroom_train.npz" );
+    assert( filelist );
+
+    npy_array_list_t *iter = filelist;
+    npy_array_t *train_X = iter->array;  iter = iter->next;
+    npy_array_t *train_Y = iter->array;  iter = iter->next;
+    npy_array_t *test_X  = iter->array;  iter = iter->next;
+    npy_array_t *test_Y  = iter->array;
+
+    /* If any of these asserts fail, try to open the arrays in python and save them with
+     * np.ascontiguousarray( matrix ) */
+    assert( train_X->fortran_order == false );
+    assert( train_Y->fortran_order == false );
+    assert( test_X->fortran_order == false );
+    assert( test_Y->fortran_order == false );
+
+    /* Load the neural network saved by example_02 */
+    neuralnet_t *nn = neuralnet_load( "mushroom-neuralnet.npz");
+    assert( nn );
+
+    const int n_train_samples = train_X->shape[0];
+    const int n_test_samples  = test_X->shape[0];
+
+    metric_func metrics[] = { get_metric_func("binary_crossentropy"), get_metric_func( "binary_accuracy"), NULL };
+
+    float results[ 2 * 2 ];
+    local_evaluate( nn, n_train_samples, (float*) train_X->data, (float*) train_Y->data, metrics, results);
+    local_evaluate( nn, n_test_samples, (float*) test_X->data, (float*) test_Y->data, metrics, results + 2);
+
+    printf("Train loss    : %5.5f\n", results[0] );
+    printf("Train accuracy: %5.5f\n", results[1] );
+    printf("Test loss     : %5.5f\n", results[2] );
+    printf("Test accuracy : %5.5f\n", results[3] );
+
+    /* Clean up the resources */
+    neuralnet_free( nn );
+    npy_array_list_free( filelist );
+    return 0;
+}
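example_02 now saves the trained network and example_02b reloads it from disk and re-evaluates. For a quick in-process check of the same save/load round trip, something like the sketch below could be used. The helper name roundtrip_ok and the file name "roundtrip-test.npz" are made up for illustration; the neuralnet_* calls are the ones used in the two examples.

/* Sketch of an in-process round-trip check. Helper and file name are
 * illustrative; the neuralnet_* calls are those used in example_02/02b. */
#include "neuralnet.h"
#include "simd.h"
#include <math.h>

static int roundtrip_ok( neuralnet_t *nn, const float *sample )
{
    const int n_out = nn->layer[nn->n_layers-1].n_output;
    SIMD_ALIGN(float before[n_out]);
    SIMD_ALIGN(float after[n_out]);

    neuralnet_predict( nn, sample, before );     /* prediction from the live net */

    neuralnet_save( nn, "roundtrip-test.npz" );
    neuralnet_t *nn2 = neuralnet_load( "roundtrip-test.npz" );
    if ( !nn2 )
        return 0;
    neuralnet_predict( nn2, sample, after );     /* prediction from the reloaded net */
    neuralnet_free( nn2 );

    /* Weights are stored as float32, so the reloaded net should reproduce
     * the prediction exactly on the same machine. */
    for ( int i = 0; i < n_out; i++ )
        if ( fabsf( before[i] - after[i] ) != 0.0f )
            return 0;
    return 1;
}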
diff --git a/src/activation.c b/src/activation.c
index 1183c0c..90effd3 100644
--- a/src/activation.c
+++ b/src/activation.c
@@ -308,7 +308,7 @@ static void relu( const int n, float *y )
 #ifdef __AVX__
     const __m256 zero = _mm256_set1_ps(0.0f);
     __m256 YMM0, YMM1;
-    for ( ; !is_aligned( y + i ); i++)
+    for ( ; !is_aligned( y + i ) && i < n; i++)
         y[i] = fmaxf(0.0f, y[i]);
 
     for ( ; i <= ((n)-16); i += 16) {
@@ -456,6 +456,13 @@ static void softmax( const int n, float *ar )
     classification problems. If using two registers, I will lose all vectorization for
     classification problems with less than 16 classes. This is of course a trade off, and if you
     ever do a classification problem with 16 or more classes, you could consider re-writing. */
+
+    /* Skip the unaligned - FIXME: This code is specific for AVX2,
+     * yet it checks alignment for whatever the computer supports */
+    for (; !is_aligned( ar+j ) && j < n; j++)
diff --git a/src/evaluate.c b/src/evaluate.c
--- a/src/evaluate.c
+++ b/src/evaluate.c
@@ -11,7 +12,9 @@ void evaluate( neuralnet_t *nn, const int n_valid_samples, const float *valid_X,
                const float *valid_Y, metric_func metrics[], float *results )
 {
+#ifndef USE_CBLAS
     const int n_input = nn->layer[0].n_input;
+#endif
     const int n_output = nn->layer[nn->n_layers-1].n_output;
 
     metric_func *mf_ptr = metrics;
@@ -22,11 +25,19 @@ void evaluate( neuralnet_t *nn, const int n_valid_samples, const float *valid_X,
 
     float local_results[n_metrics];
     memset( local_results, 0, n_metrics * sizeof(float));
+#ifdef USE_CBLAS
+    float predictions[ n_output * n_valid_samples ];
+    memset( predictions, 0, n_output * n_valid_samples * sizeof(float));
+    neuralnet_predict_batch( nn, n_valid_samples, valid_X, predictions);
+#endif
 #pragma omp parallel for reduction(+:local_results[:])
     for ( int i = 0; i < n_valid_samples; i++ ){
+#ifdef USE_CBLAS
+        float *y_pred = predictions + (i*n_output);
+#else
         SIMD_ALIGN(float y_pred[n_output]);
         neuralnet_predict( nn, valid_X + (i*n_input), y_pred );
-
+#endif
         float *res = local_results;
         for ( int j = 0; j < n_metrics; j++ ){
             float _error = metrics[j]( n_output, y_pred, valid_Y + (i*n_output));
@@ -37,5 +48,4 @@ void evaluate( neuralnet_t *nn, const int n_valid_samples, const float *valid_X,
     float *res = results;
     for ( int i = 0; i < n_metrics; i++ )
         *res++ = local_results[i] / (float) n_valid_samples;
-
 }
diff --git a/src/neuralnet.c b/src/neuralnet.c
index 7ef3f6e..94195ad 100644
--- a/src/neuralnet.c
+++ b/src/neuralnet.c
@@ -243,7 +243,7 @@ void neuralnet_predict( const neuralnet_t *nn, const float *input, float *out )
 {
     /* These asserts are important - end user may forget to SIMD_ALIGN memory and then there
        is a extremly hard bug to find - Think before you remove these assert() */
-    assert( is_aligned( out ));
+    // assert( is_aligned( out ));
 
     /* Stack allocating memory */
     /* FIXME: Do this once and once only! */
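The relu and softmax hunks above apply the same fix: the scalar prologue that steps forward until the pointer is 32-byte aligned must also be bounded by n, or it walks past the end of short arrays. Here is the pattern in isolation, as a minimal sketch; is_aligned32() is a stand-in for the library's own is_aligned().

/* Scalar-prefix pattern from activation.c, in isolation: peel elements one
 * at a time until the pointer is 32-byte aligned (never past n), run the
 * aligned AVX loop, then finish the tail in scalar code. */
#include <immintrin.h>
#include <stdint.h>
#include <math.h>

static inline int is_aligned32( const void *ptr )
{
    return ((uintptr_t) ptr & 0x1f) == 0;
}

static void relu_sketch( const int n, float *y )
{
    int i = 0;
#ifdef __AVX__
    /* Scalar prefix - the 'i < n' bound is exactly the fix in the hunks above. */
    for ( ; !is_aligned32( y + i ) && i < n; i++ )
        y[i] = fmaxf( 0.0f, y[i] );

    const __m256 zero = _mm256_setzero_ps();
    for ( ; i <= n - 8; i += 8 )
        _mm256_store_ps( y + i, _mm256_max_ps( zero, _mm256_load_ps( y + i ) ) );
#endif
    /* Scalar tail, and the whole loop when AVX is not available. */
    for ( ; i < n; i++ )
        y[i] = fmaxf( 0.0f, y[i] );
}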
diff --git a/src/neuralnet_predict_batch.c b/src/neuralnet_predict_batch.c
new file mode 100644
index 0000000..6b52cc7
--- /dev/null
+++ b/src/neuralnet_predict_batch.c
@@ -0,0 +1,138 @@
+/* neuralnet_predict_batch.c - Øystein Schønning-Johansen 2023 */
+/*
+ vim: ts=4 sw=4 softtabstop=4 expandtab
+*/
+#include "neuralnet_predict_batch.h"
+#include "activation.h"
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+
+#ifdef USE_CBLAS
+#include <cblas.h>
+#endif
+
+#ifndef USE_CBLAS
+/* This is a primitive implementation using OpenMP to thread the forward calculation of several samples.
+ * The recommendation is to use the BLAS implementation, and then add the threading at a higher level in
+ * your application.
+ */
+void neuralnet_predict_batch( const neuralnet_t *nn, const int n_samples, const float *inputs, float *output )
+{
+    const int n_inputs = nn->layer[0].n_input;
+    const int n_output = nn->layer[nn->n_layers-1].n_output;
+#pragma omp parallel for
+    for ( int i = 0; i < n_samples; i++ )
+        neuralnet_predict( nn, inputs + i*n_inputs, output + i*n_output);
+}
+#else
+/* This number depends on your system - how much memory do you want to stack allocate?
+ * On a desktop or laptop you probably have plenty. If you ever run into a stack overflow,
+ * you can recompile with -DN_STACK_ALLOC_FLOATS=1024 (or an even lower number).
+ *
+ * So this function, `neuralnet_predict_batch()`, uses some work memory. I don't want to
+ * allocate this dynamically, as heap allocation will be a performance killer, and stack
+ * allocation can blow up the stack. The idea is therefore to allocate (on stack) some
+ * fixed size memory, N_STACK_ALLOC_FLOATS, and then check if we need more or less than
+ * this. If the size needed is more than the allocated, we simply split the set in two
+ * and recurse like a divide-and-conquer scheme.
+ */
+#ifndef N_STACK_ALLOC_FLOATS
+#define N_STACK_ALLOC_FLOATS (64 * 1024)
+#endif
+
+void neuralnet_predict_batch( const neuralnet_t *nn, const int n_samples, const float *inputs, float *output )
+{
+    /* Make some work memory on the stack. First calculate how much we need. */
+    int workmem_sz = 0;
+    for( int i = 0; i < nn->n_layers; i++)
+        workmem_sz += nn->layer[i].n_output;
+    workmem_sz *= n_samples;   /* This size is also in floats */
+
+#if 0
+    /* Let's see how often this fails. */
+    assert( N_STACK_ALLOC_FLOATS >= workmem_sz && "Stack size limit reached - "
+            "either recompile with a higher limit or find another way to handle work memory" );
+#endif
+
+    if( N_STACK_ALLOC_FLOATS < workmem_sz ){
+        int half = n_samples >> 1;
+        const int n_inputs = nn->layer[0].n_input;
+        const int n_output = nn->layer[nn->n_layers-1].n_output;
+        // fprintf(stderr, "Warning: Stack limit reached with %d samples - recursing.\n", n_samples);
+        neuralnet_predict_batch( nn, half, inputs, output );
+        neuralnet_predict_batch( nn, n_samples - half, inputs+(half*n_inputs), output+(half*n_output) );
+        return;
+    }
+
+    /* So the above lines make sure we don't blow the stack - but what to do when we hit the limit?
+     *
+     * BTW: What is the limit here? The stack is usually about one MB, so I have initially
+     * set the limit to 64k floats (256 kB). If we have a neural net with, say, 1000 n_outputs
+     * (cumulative over all layers) and n_samples is 600... that actually blows the stack.
+     *
+     * Options:
+     *  1. Divide and conquer: Divide the output into two halves and recurse. Cool!
+     *  2. Have a private (static) function that controls memory. Say:
+     *
+     *     `float *workmem = _get_workmemory( nn, n_samples );`
+     *
+     *     and this function returns a pointer to a preallocated area of memory.
+     *  3. Have a pointer input for work memory and let the caller take care of it.
+     *  4. Allocate on heap .... ?
+     *
+     * ... and then: even if I have the above limit, I can still get stack overflow.
+     *
+     * (So far we stay with option 1, but we allocate a fixed size on the stack such that it
+     * cannot overflow.)
+     */
+
+    float workmem[ N_STACK_ALLOC_FLOATS ];   /* can we blow the stack here? */
+    float *activations[nn->n_layers+1];
+    activations[0] = (float*) inputs;
+    activations[1] = workmem;
+
+    for( int i = 1; i < nn->n_layers-1; i++)
+        activations[i+1] = activations[i] + nn->layer[i-1].n_output * n_samples;
+
+    activations[nn->n_layers] = output;
+
+    /* Oh, I have to fill the activations with the biases */
+    for( int i = 0; i < nn->n_layers; i++){
+        const layer_t *layer_ptr = nn->layer + i;
+        const size_t size = layer_ptr->n_output * sizeof(float);
+        for( int j = 0; j < n_samples; j++)
+            memcpy( activations[i+1] + j*layer_ptr->n_output, layer_ptr->bias, size );
+    }
+
+    static activation_func softmax = NULL;   /* Keep it static such that get_activation_func() is called only once! */
+    if( !softmax )
+        softmax = get_activation_func( "softmax" );   /* Slow? */
+
+    /* Then we do the forward calculation */
+    for( int i = 0; i < nn->n_layers; i++){
+        const layer_t *layer_ptr = nn->layer + i;
+        /* Matrix multiplication */
+        cblas_sgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans,
+                n_samples, layer_ptr->n_output, layer_ptr->n_input,
+                1.0f,                  /* alpha (7)  */
+                activations[i],        /* A     (8)  */
+                layer_ptr->n_input,    /* lda   (9)  */
+                layer_ptr->weight,     /* B     (10) */
+                layer_ptr->n_output,   /* ldb   (11) */
+                1.0f,                  /* beta  (12) */
+                activations[i+1],      /* C     (13) */
+                layer_ptr->n_output    /* ldc   (14) */
+                );
+        /* Activation */
+        /* (I really hope the silly if-condition doesn't kill performance.) */
+        if ( layer_ptr->activation_func == softmax ){
+            /* Softmax normalizes over each sample, so it must be applied row by row. */
+            float *out = activations[i+1];
+            for ( int j = 0; j < n_samples; j++, out += layer_ptr->n_output)
+                layer_ptr->activation_func ( layer_ptr->n_output, out );
+        } else {
+            layer_ptr->activation_func ( layer_ptr->n_output * n_samples, activations[i+1] );
+        }
+    }
+}
+#endif /* USE_CBLAS */
diff --git a/src/neuralnet_predict_batch.h b/src/neuralnet_predict_batch.h
new file mode 100644
index 0000000..b4d9ed2
--- /dev/null
+++ b/src/neuralnet_predict_batch.h
@@ -0,0 +1,8 @@
+/* neuralnet_predict_batch.h - Øystein Schønning-Johansen 2023 */
+/*
+ vim: ts=4 sw=4 softtabstop=4 expandtab
+*/
+#pragma once
+
+#include "neuralnet.h"
+void neuralnet_predict_batch( const neuralnet_t *nn, const int n_samples, const float *inputs, float *output );
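A minimal driver for the new batch API, assuming the net and data files produced by example_02. Build with -DUSE_CBLAS and link a BLAS (for example -lopenblas) to get the sgemm path; without the define you get the OpenMP fallback. The file name demo.c and the BLAS choice are illustrative only.

/* Minimal usage sketch for neuralnet_predict_batch(). Assumes the files
 * written by example_02; compile e.g. with
 *     gcc -O2 -fopenmp -DUSE_CBLAS demo.c ... -lopenblas           */
#include "neuralnet.h"
#include "neuralnet_predict_batch.h"
#include "npy_array.h"
#include "npy_array_list.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    neuralnet_t *nn = neuralnet_load( "mushroom-neuralnet.npz" );
    assert( nn );

    npy_array_list_t *data = npy_array_list_load( "mushroom_train.npz" );
    assert( data );
    npy_array_t *X = data->array;   /* first array in the list is train_X */

    const int n_samples = X->shape[0];
    const int n_output  = nn->layer[nn->n_layers-1].n_output;

    /* The output buffer is ours; only the intermediate activations come out
     * of the N_STACK_ALLOC_FLOATS budget inside neuralnet_predict_batch(). */
    float *pred = malloc( n_samples * n_output * sizeof(float) );
    assert( pred );
    neuralnet_predict_batch( nn, n_samples, (const float*) X->data, pred );

    printf( "prediction for first sample: %5.5f\n", pred[0] );

    free( pred );
    npy_array_list_free( data );
    neuralnet_free( nn );
    return 0;
}

To see when the divide-and-conquer split in neuralnet_predict_batch() actually triggers: with layers of, say, 128, 64 and 1 outputs, each sample needs 193 floats of work memory, so the default budget of 64 * 1024 floats (256 kB) handles up to 65536 / 193 = 339 samples in a single call; a larger batch recurses until each leaf fits.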