Skip to content

Commit

Permalink
Respects torch.device(0) new behavior without breaking backward
Browse files Browse the repository at this point in the history
compatibilty.
  • Loading branch information
Narsil committed Jul 31, 2024
1 parent c00471e commit f665edb
Showing 1 changed file with 22 additions and 45 deletions.
67 changes: 22 additions & 45 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ enum Device {
Npu(usize),
Xpu(usize),
Xla(usize),
/// User didn't specify acceletor, torch
/// is responsible for choosing.
Anonymous(usize),
}

/// Parsing the device index.
fn parse_device(name: &str) -> PyResult<usize> {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(device)
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}

impl<'source> FromPyObject<'source> for Device {
Expand All @@ -279,56 +295,16 @@ impl<'source> FromPyObject<'source> for Device {
"npu" => Ok(Device::Npu(0)),
"xpu" => Ok(Device::Xpu(0)),
"xla" => Ok(Device::Xla(0)),
name if name.starts_with("cuda:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Cuda(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("npu:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Npu(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("xpu:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Xpu(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("xla:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Xla(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
name if name.starts_with("xla:") => parse_device(name).map(Device::Xla),
name => Err(SafetensorError::new_err(format!(
"device {name} is invalid"
))),
}
} else if let Ok(number) = ob.extract::<usize>() {
Ok(Device::Cuda(number))
Ok(Device::Anonymous(number))
} else {
Err(SafetensorError::new_err(format!("device {ob} is invalid")))
}
Expand All @@ -344,6 +320,7 @@ impl IntoPy<PyObject> for Device {
Device::Npu(n) => format!("npu:{n}").into_py(py),
Device::Xpu(n) => format!("xpu:{n}").into_py(py),
Device::Xla(n) => format!("xla:{n}").into_py(py),
Device::Anonymous(n) => format!("{n}").into_py(py),
}
}
}
Expand Down

0 comments on commit f665edb

Please sign in to comment.