Skip to content

Commit

Permalink
Adding NCCL clique to the RAFT handle
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Sep 11, 2024
1 parent 8a71b98 commit ea6b160
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,13 @@ class std_comms : public comms_iface {
// Wait for a UCXX progress thread roundtrip, prevent waiting for longer
// than 10ms for each operation, will retry in next iteration.
ucxx::utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([&callbackNotifierPre]() { callbackNotifierPre.set(); },
10000000 /* 10ms */);
(void)worker->registerGenericPre(
[&callbackNotifierPre]() { callbackNotifierPre.set(); }, 10000000 /* 10ms */);
callbackNotifierPre.wait();

ucxx::utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); },
10000000 /* 10ms */);
(void)worker->registerGenericPost(
[&callbackNotifierPost]() { callbackNotifierPost.set(); }, 10000000 /* 10ms */);
callbackNotifierPost.wait();
} else {
// Causes UCXX to progress through the send/recv message queue
Expand Down
118 changes: 118 additions & 0 deletions cpp/include/raft/comms/nccl_clique.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/comms/nccl_clique.hpp>
#include <raft/core/device_resources.hpp>

/**
* @brief Error checking macro for NCCL runtime API functions.
*
* Invokes a NCCL runtime API function call, if the call does not return ncclSuccess, throws an
* exception detailing the NCCL error that occurred
*/
#define RAFT_NCCL_TRY(call) \
do { \
ncclResult_t const status = (call); \
if (ncclSuccess != status) { \
std::string msg{}; \
SET_ERROR_MSG(msg, \
"NCCL error encountered at: ", \
"call='%s', Reason=%d:%s", \
#call, \
status, \
ncclGetErrorString(status)); \
throw raft::logic_error(msg); \
} \
} while (0);

namespace raft::comms {
void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank);
}

namespace raft::comms {

nccl_clique::nccl_clique(int percent_of_free_memory)
: root_rank_(0),
percent_of_free_memory_(percent_of_free_memory),
per_device_pools_(0),
device_resources_(0)
{
cudaGetDeviceCount(&num_ranks_);
device_ids_.resize(num_ranks_);
std::iota(device_ids_.begin(), device_ids_.end(), 0);
nccl_comms_.resize(num_ranks_);
nccl_clique_init();
}

nccl_clique::nccl_clique(const std::vector<int>& device_ids, int percent_of_free_memory)
: root_rank_(0),
num_ranks_(device_ids.size()),
percent_of_free_memory_(percent_of_free_memory),
device_ids_(device_ids),
nccl_comms_(device_ids.size()),
per_device_pools_(0),
device_resources_(0)
{
nccl_clique_init();
}

void nccl_clique::nccl_clique_init()
{
RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data()));

for (int rank = 0; rank < num_ranks_; rank++) {
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank]));

// create a pool memory resource for each device
auto old_mr = rmm::mr::get_current_device_resource();
per_device_pools_.push_back(std::make_unique<pool_mr>(
old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory_)));
rmm::cuda_device_id id(device_ids_[rank]);
rmm::mr::set_per_device_resource(id, per_device_pools_.back().get());

// create a device resource handle for each device
device_resources_.emplace_back();

// add NCCL communications to the device resource handle
raft::comms::build_comms_nccl_only(
&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank);
}

for (int rank = 0; rank < num_ranks_; rank++) {
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank]));
raft::resource::sync_stream(device_resources_[rank]);
}
}

const raft::device_resources& nccl_clique::set_current_device_to_root_rank() const
{
int root_device_id = device_ids_[root_rank_];
RAFT_CUDA_TRY(cudaSetDevice(root_device_id));
return device_resources_[root_rank_];
}

nccl_clique::~nccl_clique()
{
#pragma omp parallel for // necessary to avoid hangs
for (int rank = 0; rank < num_ranks_; rank++) {
cudaSetDevice(device_ids_[rank]);
ncclCommDestroy(nccl_comms_[rank]);
rmm::cuda_device_id id(device_ids_[rank]);
rmm::mr::set_per_device_resource(id, nullptr);
}
}

} // namespace raft::comms
71 changes: 71 additions & 0 deletions cpp/include/raft/comms/nccl_clique.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <nccl.h>

namespace raft {
struct device_resources;
}

namespace raft::comms {

struct nccl_clique {
using pool_mr = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;

/**
* Instantiates a NCCL clique with all available GPUs
*
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool
*
*/
nccl_clique(int percent_of_free_memory = 80);

/**
* Instantiates a NCCL clique
*
* Usage example:
* @code{.cpp}
* int n_devices;
* cudaGetDeviceCount(&n_devices);
* std::vector<int> device_ids(n_devices);
* std::iota(device_ids.begin(), device_ids.end(), 0);
* cuvs::neighbors::mg::nccl_clique& clique(device_ids); // first device is the root rank
* @endcode
*
* @param[in] device_ids list of device IDs to be used to initiate the clique
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool
*
*/
nccl_clique(const std::vector<int>& device_ids, int percent_of_free_memory = 80);

void nccl_clique_init();
const raft::device_resources& set_current_device_to_root_rank() const;
~nccl_clique();

int root_rank_;
int num_ranks_;
int percent_of_free_memory_;
std::vector<int> device_ids_;
std::vector<ncclComm_t> nccl_comms_;
std::vector<std::shared_ptr<pool_mr>> per_device_pools_;
std::vector<raft::device_resources> device_resources_;
};

} // namespace raft::comms
6 changes: 6 additions & 0 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <raft/core/resource/device_id.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/core/resource/nccl_clique_handle.hpp>
#include <raft/core/resource/sub_comms.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
Expand Down Expand Up @@ -123,6 +124,11 @@ class device_resources : public resources {

rmm::exec_policy_nosync& get_thrust_policy() const { return resource::get_thrust_policy(*this); }

raft::comms::nccl_clique get_nccl_clique_handle() const
{
return resource::get_nccl_clique_handle(*this);
}

/**
* @brief synchronize a stream on the current container
*/
Expand Down
67 changes: 67 additions & 0 deletions cpp/include/raft/core/resource/nccl_clique_handle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/comms/nccl_clique.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <memory>

namespace raft::resource {

class nccl_clique_resource : public resource {
public:
nccl_clique_resource() : clique_() {}
~nccl_clique_resource() noexcept override {}
void* get_resource() override { return &clique_; }

private:
raft::comms::nccl_clique clique_;
};

/** Factory that knows how to construct a specific raft::resource to populate the res_t. */
class nccl_clique_resource_factory : public resource_factory {
public:
resource_type get_resource_type() override { return resource_type::NCCL_CLIQUE; }
resource* make_resource() override { return new nccl_clique_resource(); }
};

/**
* @defgroup nccl_clique_resource resource functions
* @{
*/

/**
* Retrieves a NCCL clique from raft res if it exists, otherwise initializes it and return it.
*
* @param[in] res the raft resources object
* @return NCCL clique
*/
inline raft::comms::nccl_clique get_nccl_clique_handle(resources const& res)
{
if (!res.has_resource_factory(resource_type::NCCL_CLIQUE)) {
res.add_resource_factory(std::make_shared<nccl_clique_resource_factory>());
}
auto ret = *res.get_resource<raft::comms::nccl_clique>(resource_type::NCCL_CLIQUE);
return ret;
};

/**
* @}
*/

} // namespace raft::resource
1 change: 1 addition & 0 deletions cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ enum resource_type {
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations
NCCL_CLIQUE, // nccl clique

LAST_KEY // reserved for the last key
};
Expand Down

0 comments on commit ea6b160

Please sign in to comment.