Commit

add dfdx-mamba dep
swfsql committed Mar 3, 2024
1 parent eb9396c commit 93512f2
Showing 5 changed files with 26 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
/target
Cargo.lock
22 changes: 11 additions & 11 deletions Cargo.toml
@@ -20,19 +20,19 @@ required-features = ["native"]

# dfdx version containing necessary PRs
[dependencies.dfdx]
# path = "../dfdx/dfdx"
git = 'https://github.com/swfsql/dfdx.git'
rev = "bff1b658aa3b91b0af57f037c7cc704d216d1f03"
branch = "this-main"
# rev = "c4a2995"
default-features = false
features = [
    "safetensors",
    "nightly",
    # "std",
    # "cpu",
    # "fast-alloc",
    # "cuda",
    # "cudnn",
]
features = ["nightly", "safetensors"]

[dependencies.dfdx-mamba]
git = 'https://github.com/swfsql/dfdx-mamba.git'
branch = "main"
# path = "../dfdx-mamba"
# rev = "3b7549845baff53c8fab51068a04d87e21ba0470"
features = ["nightly", "safetensors"]


[dependencies]
anyhow = "1.0.0"
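
Note: the new [dependencies.dfdx-mamba] table above pulls the Mamba layers from a dedicated crate instead of from the dfdx fork itself. As a minimal sketch (assuming nothing beyond the item names that appear elsewhere in this diff), Cargo exposes the hyphenated package dfdx-mamba to Rust code under the identifier dfdx_mamba:

```rust
// Sketch only: these are the items this commit switches over to; Cargo maps
// the package name `dfdx-mamba` to the crate identifier `dfdx_mamba`.
use dfdx_mamba::{MambaBlock, MambaBlockConfig, MambaStateCache, MambaStateCacheConfig};
```
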
16 changes: 10 additions & 6 deletions src/common/mamba.rs
@@ -54,9 +54,11 @@ pub mod types {
    pub type DInner = usize;

    /// A [MambaBlockConfig] set to runtime values.
    pub type MambaBlockDynConfig = MambaBlockConfig<DModel, DState, DtRank, DConv, DInner>;
    pub type MambaBlockDynConfig =
        dfdx_mamba::MambaBlockConfig<DModel, DState, DtRank, DConv, DInner>;
    /// A [MambaBlock] set to runtime values.
    pub type MambaBlockDyn<E, D> = MambaBlock<DModel, DState, DtRank, DConv, DInner, E, D>;
    pub type MambaBlockDyn<E, D> =
        dfdx_mamba::MambaBlock<DModel, DState, DtRank, DConv, DInner, E, D>;
}

#[derive(Default, Debug, Clone, CustomModule)]
@@ -110,8 +112,9 @@ impl MambaConfig {
            layers: {
                let mut layers = Vec::with_capacity(n_layer);
                for _ in 0..n_layer {
                    let mamba_block =
                        MambaBlockConfig::new(d_model, d_state, dt_rank, d_conv, d_inner);
                    let mamba_block = dfdx_mamba::MambaBlockConfig::new(
                        d_model, d_state, dt_rank, d_conv, d_inner,
                    );
                    let norm = LayerRMSNorm1DConfig(d_model);
                    let residual = ResidualAdd((norm, mamba_block));
                    let layer = ResidualMambaBlockConfig { res: residual };
@@ -193,7 +196,8 @@ pub mod stateful {
    pub type SingleInput<E, D, T> = Tensor<(Batch, DModel), E, D, T>;

    /// A [MambaStateCache] set to runtime values.
    pub type StateCache<E, D, T> = MambaStateCache<Batch, DState, DConv, DInner, E, D, T>;
    pub type StateCache<E, D, T> =
        dfdx_mamba::MambaStateCache<Batch, DState, DConv, DInner, E, D, T>;

    /// A list containing a [MambaStateCache] per [MambaBlock] (stateful).
    pub type MambaStatesDyn<E, D, T> = Vec<StateCache<E, D, T>>;
@@ -210,7 +214,7 @@ pub mod stateful {
        fn try_forward(&self, x: VocabInputWithStates<E, D, T>) -> Result<Self::Output, Error> {
            let (x, states): (
                VocabInput<D, T>,
                Vec<MambaStateCache<Batch, DState, DConv, DInner, E, D, T>>,
                Vec<dfdx_mamba::MambaStateCache<Batch, DState, DConv, DInner, E, D, T>>,
            ) = x;

            let mut x: SingleInput<E, D, T> = self.embedding.try_forward(x)?;
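
Note: the hunks above re-point the type aliases and the block constructor at the dfdx_mamba crate. A minimal sketch of calling the relocated constructor, assuming mamba-130m-style dimensions (d_model = 768, d_inner = 2 * d_model, dt_rank = ceil(d_model / 16)) rather than values taken from this diff:

```rust
// Hedged sketch: building one block config through the dfdx-mamba path
// introduced by this commit. The dimension values are illustrative assumptions.
fn example_block_config() {
    let d_model: usize = 768;                  // assumed, mamba-130m style
    let d_state: usize = 16;
    let d_conv: usize = 4;
    let d_inner: usize = 2 * d_model;          // 1536
    let dt_rank: usize = d_model.div_ceil(16); // 48, the common Mamba default
    let _cfg = dfdx_mamba::MambaBlockConfig::new(d_model, d_state, dt_rank, d_conv, d_inner);
}
```
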
2 changes: 1 addition & 1 deletion src/common/mod.rs
@@ -77,7 +77,7 @@ impl MambaWrapper {
        let cpu = self.mamba.embedding.weight.device();
        let mut states = vec![];
        for _ in 0..self.mamba.layers.len() {
            let state = cpu.try_build_module::<f32>(dfdx::nn::MambaStateCacheConfig::new(
            let state = cpu.try_build_module::<f32>(dfdx_mamba::MambaStateCacheConfig::new(
                1,
                16,
                4,
12 changes: 3 additions & 9 deletions src/wasm/non_ui.rs
@@ -30,16 +30,15 @@ pub async fn run() -> anyhow::Result<()> {
        .get(&FilePath(
            hf::mamba_130m::FILE_PATH_MODEL_SAFETENSORS.into(),
        ))
        .await
        .unwrap();
        .await?;
    log::info!(
        "finished downloading/checking the mamba model in {}ms",
        timing.elapsed().as_millis()
    ); // ~180s/2s

    timing = web_time::Instant::now();
    log::info!("loading tokenizer data");
    let tokenizer = api.load_bytes(&tokenizer_filename).await;
    let tokenizer = api.load_bytes(&tokenizer_filename).await.unwrap();
    log::info!(
        "tokenizer data loaded in {}ms",
        timing.elapsed().as_millis()
@@ -53,7 +52,7 @@ pub async fn run() -> anyhow::Result<()> {

    timing = web_time::Instant::now();
    log::info!("loading mamba data");
    let mamba_bytes = api.load_bytes(&mamba_filename).await;
    let mamba_bytes = api.load_bytes(&mamba_filename).await.unwrap();
    log::info!("mamba data loaded in {}ms", timing.elapsed().as_millis()); // ~2-3s
    let mamba = {
        let n_layer = 24;
@@ -138,10 +137,5 @@ pub async fn run() -> anyhow::Result<()> {
    );
    log::info!("{output}");

    // models.run_stateless("Mamba is the", 14, &mut processor)?;
    // println!();
    // models.run_stateful("Mamba is the", 5000, &mut processor)?;
    // println!();

    Ok(())
}
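
Note: in the hunks above the model download now propagates its error with the ? operator, while the two load_bytes results are unwrapped at the call site; the surrounding timing/logging pattern is unchanged. A minimal sketch of that pattern, assuming only the web_time and log crates already used in this file:

```rust
// Sketch of the measure-and-log pattern used throughout `run()`:
// `web_time::Instant` mirrors `std::time::Instant` but also works on wasm targets.
fn timed_step() {
    let timing = web_time::Instant::now();
    // ... the work being measured goes here ...
    log::info!("step finished in {}ms", timing.elapsed().as_millis());
}
```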
