feat(openai): add openai flag dynamic js snippets
j-mendez committed Mar 20, 2024
1 parent 473ed26 commit 83c7e07
Showing 8 changed files with 217 additions and 21 deletions.
11 changes: 6 additions & 5 deletions Cargo.toml
@@ -1,19 +1,20 @@
[package]
edition = "2021"
name = "spider_rs"
version = "0.0.27"
description = "The fastest web crawler written in Rust ported to nodejs."
repository = "https://github.com/spider-rs/spider-nodejs"
version = "0.0.30"
repository = "https://github.com/spider-rs/spider-py"
license = "MIT"

[lib]
crate-type = ["cdylib"]

[dependencies]
indexmap = "2.1.0"
num_cpus = "1.16.0"
spider = { version = "1.85.4", features = ["budget", "cron", "regex", "cookies", "socks", "chrome", "control", "smart", "chrome_intercept", "cache" ] }
pyo3 = { version = "0.20.3", features = ["extension-module"] }
spider = { version = "1.86.11", features = ["budget", "cron", "regex", "cookies", "socks", "chrome", "control", "smart", "chrome_intercept", "cache", "serde", "openai" ] }
pyo3 = { version = "0.20.3", features = ["extension-module", "serde"] }
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
serde_json = "1.0.114"

[target.x86_64-unknown-linux-gnu.dependencies]
openssl-sys = { version = "0.9.96", features = ["vendored"] }
41 changes: 41 additions & 0 deletions book/src/website.md
@@ -236,6 +236,47 @@ async def main():
asyncio.run(main())
```
### OpenAI
Use OpenAI to generate dynamic JavaScript snippets to run with headless Chrome. Make sure to set the `OPENAI_API_KEY` environment variable.
```py
import asyncio
from spider_rs import Website

async def main():
    website = Website("https://choosealicense.com").with_openai({
        "model": "gpt-3.5-turbo",
        "prompt": "Search for movies",
        "max_tokens": 300
    })

asyncio.run(main())
```
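The snippet above only builds the configuration; the generated snippets run while crawling with headless Chrome. A minimal sketch, assuming the third argument of `crawl` toggles headless as in `examples/screenshot.py` from this commit:
```py
import asyncio
from spider_rs import Website

async def main():
    website = Website("https://choosealicense.com").with_openai({
        "model": "gpt-3.5-turbo",
        "prompt": "Search for movies",
        "max_tokens": 300
    })
    # third positional argument enables headless Chrome (assumption,
    # mirroring examples/screenshot.py in this commit)
    website.crawl(None, None, True)
    print(website.get_links())

asyncio.run(main())
```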
### Screenshots
Take screenshots of pages during the crawl when using headless Chrome.
```py
import asyncio
from spider_rs import Website

async def main():
    website = (
        Website("https://choosealicense.com", False)
        .with_screenshot({
            "params": {
                "cdp_params": None,
                "full_page": True,
                "omit_background": False
            },
            "bytes": False,
            "save": True,
            "output_dir": None
        })
    )

asyncio.run(main())
```
### Http2 Prior Knowledge
Use http2 prior knowledge to connect if you know the website's server supports it.
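The rest of this section is collapsed in the diff. A minimal sketch, assuming the binding exposes the Rust crate's `with_http2_prior_knowledge(bool)` builder under the same name:
```py
import asyncio
from spider_rs import Website

async def main():
    # with_http2_prior_knowledge(bool) is assumed to mirror the Rust crate
    website = Website("https://choosealicense.com").with_http2_prior_knowledge(True)
    website.crawl()
    print(website.get_links())

asyncio.run(main())
```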
9 changes: 7 additions & 2 deletions examples/builder.py
@@ -3,8 +3,13 @@
from spider_rs import Website

async def main():
    website = Website("https://choosealicense.com", False).with_agent("BotBot").with_headers({ "authorization": "Something "})
    website = (
        Website("https://choosealicense.com", False)
        .with_user_agent("BotBot")
        .with_headers({"authorization": "Something "})
    )
    website.crawl()
    print(website.get_links())
asyncio.run(main())

asyncio.run(main())
29 changes: 29 additions & 0 deletions examples/screenshot.py
@@ -0,0 +1,29 @@
import asyncio

from spider_rs import Website

async def main():
    website = (
        Website("https://choosealicense.com", False)
        .with_screenshot({
            "params": {
                "cdp_params": {
                    "format": None,
                    "quality": None,
                    "clip": None,
                    "from_surface": None,
                    "capture_beyond_viewport": None
                },
                "full_page": True,
                "omit_background": False
            },
            "bytes": False,
            "save": True,
            "output_dir": None
        })
    )
    website.crawl(None, None, True)
    print(website.get_links())


asyncio.run(main())
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -1,10 +1,14 @@
[build-system]
requires = ["maturin>=1,<2"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"

[tool.maturin]
features = ["pyo3/extension-module"]

[project]
name = "spider_rs"
requires-python = ">=3.7"
summary = "The fastest web crawler written in Rust"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -11,11 +11,13 @@ pub mod npage;
pub mod nwebsite;
pub mod page;
pub mod shortcut;
pub mod utils;
pub mod website;

pub use npage::{new_page, page_title, NPage};
pub use nwebsite::NWebsite;
pub use page::Page;
pub use utils::pydict_to_json_value;
pub use website::Website;

#[pyfunction]
70 changes: 70 additions & 0 deletions src/utils.rs
@@ -0,0 +1,70 @@
use pyo3::types::{PyAny, PyDict, PyList};
use pyo3::PyResult;
use serde_json::Value as JsonValue;

/// convert pyobject to json value
pub fn pyobj_to_json_value(obj: &PyAny) -> PyResult<JsonValue> {
  // Handle None
  if obj.is_none() {
    return Ok(JsonValue::Null);
  }
  // Handle boolean
  else if let Ok(val) = obj.extract::<bool>() {
    return Ok(JsonValue::Bool(val));
  }
  // Handle integers
  else if let Ok(val) = obj.extract::<i64>() {
    return Ok(JsonValue::Number(val.into()));
  }
  // Handle floats
  else if let Ok(val) = obj.extract::<f64>() {
    if let Some(num) = serde_json::Number::from_f64(val) {
      return Ok(JsonValue::Number(num));
    } else {
      return Err(pyo3::exceptions::PyValueError::new_err(
        "Float value out of range",
      ));
    }
  }
  // Handle strings
  else if let Ok(val) = obj.extract::<&str>() {
    return Ok(JsonValue::String(val.to_string()));
  }
  // Handle lists
  else if let Ok(list) = obj.downcast::<PyList>() {
    let mut vec = Vec::new();
    for item in list.iter() {
      vec.push(pyobj_to_json_value(item)?);
    }
    return Ok(JsonValue::Array(vec));
  }
  // Handle dictionaries
  else if let Ok(dict) = obj.downcast::<PyDict>() {
    let mut map = serde_json::Map::new();
    for (k, v) in dict.iter() {
      let key: &str = k.extract()?;
      let value = pyobj_to_json_value(v)?;
      map.insert(key.to_string(), value);
    }
    return Ok(JsonValue::Object(map));
  }
  // Catch-all for unsupported types
  else {
    Err(pyo3::exceptions::PyTypeError::new_err(
      "Unsupported Python type",
    ))
  }
}

/// convert pydict to json value
pub fn pydict_to_json_value(py_dict: &pyo3::types::PyDict) -> PyResult<JsonValue> {
  let mut map = serde_json::Map::new();

  for (k, v) in py_dict.iter() {
    let key: &str = k.extract()?;
    let value: JsonValue = pyobj_to_json_value(v)?;
    map.insert(key.to_string(), value);
  }

  Ok(serde_json::Value::Object(map))
}
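From the Python side, these branches define which value types a config dict may contain before it reaches `serde_json`; anything else raises a `TypeError`. A sketch exercising each supported branch (the keys are made up for illustration, not a real config struct):
```py
# Illustrative only: the value types a config dict may contain, per
# pyobj_to_json_value's branches. Keys here are hypothetical and do not
# correspond to a real spider config struct.
config = {
    "enabled": True,          # bool  -> JSON true
    "retries": 3,             # int   -> JSON number
    "timeout": 1.5,           # float -> JSON number
    "label": "demo",          # str   -> JSON string
    "steps": [1, 2, 3],       # list  -> JSON array
    "nested": {"key": None},  # dict  -> JSON object, None -> JSON null
}
# unsupported values (e.g. a set or a custom class) raise TypeError
```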
70 changes: 57 additions & 13 deletions src/website.rs
@@ -1,6 +1,7 @@
use crate::{new_page, NPage, BUFFER};
use crate::{new_page, pydict_to_json_value, NPage, BUFFER};
use indexmap::IndexMap;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use spider::compact_str::CompactString;
use spider::configuration::WaitForIdleNetwork;
use spider::tokio::select;
@@ -741,24 +742,67 @@ impl Website {
    slf
  }

  /// Take a screenshot of the page when using chrome.
  pub fn with_screenshot<'a>(
    mut slf: PyRefMut<'a, Self>,
    screenshot_configs: Option<&'a PyDict>,
  ) -> PyRefMut<'a, Self> {
    if let Some(py_obj) = screenshot_configs {
      if let Ok(config_json) = pydict_to_json_value(py_obj) {
        match serde_json::from_value::<spider::configuration::ScreenShotConfig>(config_json) {
          Ok(configs) => {
            slf.inner.with_screenshot(Some(configs));
          }
          Err(e) => {
            spider::utils::log("", e.to_string());
          }
        }
      } else {
        spider::utils::log("Error extracting String from PyAny", "");
      }
    }

    slf
  }

  /// Use OpenAI to generate dynamic javascript snippets. Make sure to set the `OPENAI_API_KEY` env variable.
  pub fn with_openai<'a>(
    mut slf: PyRefMut<'a, Self>,
    openai_configs: Option<&'a PyDict>,
  ) -> PyRefMut<'a, Self> {
    if let Some(py_obj) = openai_configs {
      if let Ok(config_json) = pydict_to_json_value(py_obj) {
        match serde_json::from_value::<spider::configuration::GPTConfigs>(config_json) {
          Ok(configs) => {
            slf.inner.with_openai(Some(configs));
          }
          Err(e) => {
            spider::utils::log("", e.to_string());
          }
        }
      } else {
        spider::utils::log("Error extracting String from PyAny for OpenAI config", "");
      }
    }

    slf
  }

  /// Regex black list urls from the crawl
  pub fn with_blacklist_url(
    mut slf: PyRefMut<'_, Self>,
    blacklist_url: Option<Vec<String>>,
  ) -> PyRefMut<'_, Self> {
    slf
      .inner
      .configuration
      .with_blacklist_url(match blacklist_url {
        Some(v) => {
          let mut blacklist: Vec<CompactString> = Vec::new();
          for item in v {
            blacklist.push(CompactString::new(item));
          }
          Some(blacklist)
        }
        _ => None,
      });
    slf.inner.with_blacklist_url(match blacklist_url {
      Some(v) => {
        let mut blacklist: Vec<CompactString> = Vec::new();
        for item in v {
          blacklist.push(CompactString::new(item));
        }
        Some(blacklist)
      }
      _ => None,
    });

    slf
  }
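For context, a minimal sketch of driving this builder from Python; the call pattern follows the book examples above, and the regex pattern itself is illustrative:
```py
import asyncio
from spider_rs import Website

async def main():
    # "/licenses/" is an illustrative regex pattern, not from the commit
    website = Website("https://choosealicense.com").with_blacklist_url(["/licenses/"])
    website.crawl()
    print(website.get_links())

asyncio.run(main())
```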