Add float16 reduction to MPI
angeloskath committed Aug 29, 2024
1 parent d1b4019 commit b32ce2c
Showing 1 changed file with 67 additions and 5 deletions.
mlx/distributed/mpi/mpi.cpp (72 changes: 67 additions & 5 deletions)
@@ -32,8 +32,30 @@ array ensure_row_contiguous(const array& arr) {
   }
 }
 
+// TODO: Change to a vectorized sum
+template <typename T>
+void simple_sum(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc += *in;
+    acc++;
+    in++;
+  }
+}
+template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
+    initialized_ = false;
+
     libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
     if (libmpi_handle_ == nullptr) {
       return;
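
The simple_sum template above has the MPI_User_function shape that MPI_Op_create expects, which is what lets MPI apply it element-wise during a reduction; in the commit the element types are mlx's float16_t and bfloat16_t. A minimal standalone sketch of the same pattern (not part of this commit; it assumes a compiler with the _Float16 extension, and names such as f16_sum and mpi_f16 are illustrative):

    // Illustrative sketch: a user-defined float16 sum registered with MPI.
    // The built-in MPI_SUM cannot be used here, because predefined reduction
    // ops are only valid for predefined MPI datatypes.
    #include <mpi.h>
    #include <vector>

    static void f16_sum(void* in, void* inout, int* len, MPI_Datatype*) {
      auto* a = static_cast<_Float16*>(in);
      auto* b = static_cast<_Float16*>(inout);
      for (int i = 0; i < *len; ++i) {
        b[i] = static_cast<_Float16>(
            static_cast<float>(a[i]) + static_cast<float>(b[i]));
      }
    }

    int main(int argc, char** argv) {
      MPI_Init(&argc, &argv);

      // Describe float16 to MPI as two contiguous bytes, as this commit does.
      MPI_Datatype mpi_f16;
      MPI_Type_contiguous(2, MPI_UINT8_T, &mpi_f16);
      MPI_Type_commit(&mpi_f16);

      // Register the element-wise sum as a commutative user-defined op.
      MPI_Op f16_sum_op;
      MPI_Op_create(&f16_sum, /*commute=*/1, &f16_sum_op);

      std::vector<_Float16> send(8, static_cast<_Float16>(1.0f)), recv(8);
      MPI_Allreduce(
          send.data(), recv.data(), 8, mpi_f16, f16_sum_op, MPI_COMM_WORLD);

      MPI_Op_free(&f16_sum_op);
      MPI_Type_free(&mpi_f16);
      MPI_Finalize();
    }
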
@@ -50,6 +72,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
+    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
+    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +104,24 @@ struct MPIWrapper {
     if (!is_available()) {
       return false;
     }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
+    bool success = init(nullptr, nullptr) == MPI_SUCCESS;
+
+    // Initialize custom types and ops
+    if (success && !initialized_) {
+      // Custom float16 dtypes
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
+      mpi_type_commit(&mpi_float16_);
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
+      mpi_type_commit(&mpi_bfloat16_);
+
+      // Custom sum ops
+      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
+      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+
+      initialized_ = true;
+    }
+
+    return success;
   }
 
   void finalize_safe() {
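
The mpi_op_create calls above register simple_sum<float16_t> and simple_sum<bfloat16_t>; those instantiations compile because mlx's 16-bit float types provide arithmetic operators such as operator+=. A simplified standalone illustration of what such a type involves (an assumption for illustration only, not mlx's actual implementation; production conversions typically round to nearest even, while this sketch truncates):

    // Simplified bfloat16: the top 16 bits of a float32, so += can
    // round-trip through float arithmetic.
    #include <cstdint>
    #include <cstring>

    struct bf16 {
      std::uint16_t bits;

      float to_float() const {
        std::uint32_t u = std::uint32_t(bits) << 16; // low mantissa bits are zero
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
      }

      static bf16 from_float(float f) {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return bf16{std::uint16_t(u >> 16)}; // truncation, for brevity
      }

      bf16& operator+=(const bf16& other) {
        *this = from_float(to_float() + other.to_float());
        return *this;
      }
    };
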
Expand Down Expand Up @@ -117,13 +159,21 @@ struct MPIWrapper {
case complex64:
return mpi_complex_;
case float16:
return mpi_float16_;
case bfloat16:
throw std::runtime_error("MPI doesn't support 16-bit floats");
return mpi_bfloat16_;
}
}

MPI_Op op_sum() {
return op_sum_;
MPI_Op op_sum(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_sum_f16_;
case bfloat16:
return op_sum_bf16_;
default:
return op_sum_;
}
}

void* libmpi_handle_;
@@ -152,6 +202,8 @@ struct MPIWrapper {
 
   // Ops
   MPI_Op op_sum_;
+  MPI_Op op_sum_f16_;
+  MPI_Op op_sum_bf16_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -165,6 +217,16 @@ struct MPIWrapper {
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
   MPI_Datatype mpi_complex_;
+  MPI_Datatype mpi_float16_;
+  MPI_Datatype mpi_bfloat16_;
+
+ private:
+  bool initialized_;
+
+  // Private API
+  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
+  int (*mpi_type_commit)(MPI_Datatype*);
+  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
 };
 
 MPIWrapper& mpi() {
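
The function pointers added under // Private API are what the new LOAD_SYMBOL calls in the constructor fill in: mpi.cpp resolves MPI symbols at runtime through dlopen/dlsym rather than linking against MPI directly, so MPI can remain optional. A minimal sketch of that pattern (illustrative names, not the actual LOAD_SYMBOL macro):

    // Illustrative sketch of the dlopen/dlsym pattern (hypothetical names).
    #include <dlfcn.h>
    #include <mpi.h>

    static int (*op_create_fn)(MPI_User_function*, int, MPI_Op*) = nullptr;

    bool load_mpi_op_create() {
      void* handle = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
      if (handle == nullptr) {
        return false; // MPI is not installed; treat it as unavailable.
      }
      op_create_fn = reinterpret_cast<int (*)(MPI_User_function*, int, MPI_Op*)>(
          dlsym(handle, "MPI_Op_create"));
      return op_create_fn != nullptr;
    }
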
@@ -273,7 +335,7 @@ void all_sum(Group group, const array& input_, array& output) {
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
-      mpi().op_sum(),
+      mpi().op_sum(input),
       to_comm(group));
 }
 