From 93512f2feef6d0049153885165304748f2ea61fc Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 3 Mar 2024 08:02:30 -0500 Subject: [PATCH] add dfdx-mamba dep --- .gitignore | 1 + Cargo.toml | 22 +++++++++++----------- src/common/mamba.rs | 16 ++++++++++------ src/common/mod.rs | 2 +- src/wasm/non_ui.rs | 12 +++--------- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index ea8c4bf..96ef6c0 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 5da5b91..8f00971 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,19 +20,19 @@ required-features = ["native"] # dfdx version containing necessary PRs [dependencies.dfdx] -# path = "../dfdx/dfdx" git = 'https://github.com/swfsql/dfdx.git' -rev = "bff1b658aa3b91b0af57f037c7cc704d216d1f03" +branch = "this-main" +# rev = "c4a2995" default-features = false -features = [ - "safetensors", - "nightly", - # "std", - # "cpu", - # "fast-alloc", - # "cuda", - # "cudnn", -] +features = ["nightly", "safetensors"] + +[dependencies.dfdx-mamba] +git = 'https://github.com/swfsql/dfdx-mamba.git' +branch = "main" +# path = "../dfdx-mamba" +# rev = "3b7549845baff53c8fab51068a04d87e21ba0470" +features = ["nightly", "safetensors"] + [dependencies] anyhow = "1.0.0" diff --git a/src/common/mamba.rs b/src/common/mamba.rs index 8665d61..a67558a 100644 --- a/src/common/mamba.rs +++ b/src/common/mamba.rs @@ -54,9 +54,11 @@ pub mod types { pub type DInner = usize; /// A [MambaBlockConfig] set to runtime values. - pub type MambaBlockDynConfig = MambaBlockConfig; + pub type MambaBlockDynConfig = + dfdx_mamba::MambaBlockConfig; /// A [MambaBlock] set to runtime values. - pub type MambaBlockDyn = MambaBlock; + pub type MambaBlockDyn = + dfdx_mamba::MambaBlock; } #[derive(Default, Debug, Clone, CustomModule)] @@ -110,8 +112,9 @@ impl MambaConfig { layers: { let mut layers = Vec::with_capacity(n_layer); for _ in 0..n_layer { - let mamba_block = - MambaBlockConfig::new(d_model, d_state, dt_rank, d_conv, d_inner); + let mamba_block = dfdx_mamba::MambaBlockConfig::new( + d_model, d_state, dt_rank, d_conv, d_inner, + ); let norm = LayerRMSNorm1DConfig(d_model); let residual = ResidualAdd((norm, mamba_block)); let layer = ResidualMambaBlockConfig { res: residual }; @@ -193,7 +196,8 @@ pub mod stateful { pub type SingleInput = Tensor<(Batch, DModel), E, D, T>; /// A [MambaStateCache] set to runtime values. - pub type StateCache = MambaStateCache; + pub type StateCache = + dfdx_mamba::MambaStateCache; /// A list containing a [MambaStateCache] per [MambaBlock] (stateful). pub type MambaStatesDyn = Vec>; @@ -210,7 +214,7 @@ pub mod stateful { fn try_forward(&self, x: VocabInputWithStates) -> Result { let (x, states): ( VocabInput, - Vec>, + Vec>, ) = x; let mut x: SingleInput = self.embedding.try_forward(x)?; diff --git a/src/common/mod.rs b/src/common/mod.rs index 4839883..a690e47 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -77,7 +77,7 @@ impl MambaWrapper { let cpu = self.mamba.embedding.weight.device(); let mut states = vec![]; for _ in 0..self.mamba.layers.len() { - let state = cpu.try_build_module::(dfdx::nn::MambaStateCacheConfig::new( + let state = cpu.try_build_module::(dfdx_mamba::MambaStateCacheConfig::new( 1, 16, 4, diff --git a/src/wasm/non_ui.rs b/src/wasm/non_ui.rs index 8c36c9b..7208c9a 100644 --- a/src/wasm/non_ui.rs +++ b/src/wasm/non_ui.rs @@ -30,8 +30,7 @@ pub async fn run() -> anyhow::Result<()> { .get(&FilePath( hf::mamba_130m::FILE_PATH_MODEL_SAFETENSORS.into(), )) - .await - .unwrap(); + .await?; log::info!( "finished downloading/checking the mamba model in {}ms", timing.elapsed().as_millis() @@ -39,7 +38,7 @@ pub async fn run() -> anyhow::Result<()> { timing = web_time::Instant::now(); log::info!("loading tokenizer data"); - let tokenizer = api.load_bytes(&tokenizer_filename).await; + let tokenizer = api.load_bytes(&tokenizer_filename).await.unwrap(); log::info!( "tokenizer data loaded in {}ms", timing.elapsed().as_millis() @@ -53,7 +52,7 @@ pub async fn run() -> anyhow::Result<()> { timing = web_time::Instant::now(); log::info!("loading mamba data"); - let mamba_bytes = api.load_bytes(&mamba_filename).await; + let mamba_bytes = api.load_bytes(&mamba_filename).await.unwrap(); log::info!("mamba data loaded in {}ms", timing.elapsed().as_millis()); // ~2-3s let mamba = { let n_layer = 24; @@ -138,10 +137,5 @@ pub async fn run() -> anyhow::Result<()> { ); log::info!("{output}"); - // models.run_stateless("Mamba is the", 14, &mut processor)?; - // println!(); - // models.run_stateful("Mamba is the", 5000, &mut processor)?; - // println!(); - Ok(()) }