Skip to content

Commit

Permalink
Merge pull request #1369 from rstudio/fix/generator-deadlock
Browse files Browse the repository at this point in the history
Fix deadlock when passing generators to `fit()` with TF 2.13
  • Loading branch information
t-kalinowski authored Jul 7, 2023
2 parents f0d9835 + 5c4fcb2 commit d8c515f
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 66 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
on:
workflow_dispatch:
push:
branches:
- main
Expand Down Expand Up @@ -35,7 +36,7 @@ jobs:
- {os: 'windows-latest', tf: 'release', r: 'release'}
- {os: 'macOS-latest' , tf: 'release', r: 'release'}

- {os: 'ubuntu-latest', tf: '2.13.0rc1', r: 'release'}
- {os: 'ubuntu-latest', tf: '2.13', r: 'release'}
- {os: 'ubuntu-latest', tf: '2.12', r: 'release'}
- {os: 'ubuntu-latest', tf: '2.11', r: 'release'}
- {os: 'ubuntu-latest', tf: '2.10', r: 'release'}
Expand Down
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Depends:
R (>= 3.4)
Imports:
generics (>= 0.0.1),
reticulate (> 1.22),
reticulate (>= 1.30.9000),
tensorflow (>= 2.8.0),
tfruns (>= 1.0),
magrittr,
Expand All @@ -51,3 +51,5 @@ Suggests:
Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.2.3
VignetteBuilder: knitr
Remotes:
rstudio/reticulate
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ importFrom(reticulate,dict)
importFrom(reticulate,import)
importFrom(reticulate,import_builtins)
importFrom(reticulate,import_from_path)
importFrom(reticulate,iter_next)
importFrom(reticulate,iterate)
importFrom(reticulate,py_call)
importFrom(reticulate,py_capture_output)
Expand Down
3 changes: 2 additions & 1 deletion R/model-persistence.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ model_from_yaml <- function(yaml, custom_objects = NULL) {
#' @export
serialize_model <- function(model, include_optimizer = TRUE) {

if (!inherits(model, "keras.engine.training.Model"))
if (!inherits(model, c("keras.engine.training.Model",
"keras.src.engine.training.Model")))
stop("You must pass a Keras model object to serialize_model")

# write hdf5 file to temp file
Expand Down
47 changes: 16 additions & 31 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ resolve_input_data <- function(x, y = NULL) {
args$x <- as_generator(x)
} else if (inherits(x, "python.builtin.iterator")) {
args$x <- x
} else if (inherits(x, "keras.utils.data_utils.Sequence")) {
} else if (inherits(x, c("keras.src.utils.data_utils.Sequence",
"keras.utils.data_utils.Sequence"))) {
args$x <- x
} else {
if (!is.null(x))
Expand All @@ -576,7 +577,8 @@ resolve_validation_data <- function(validation_data) {
args$validation_data <- as_generator(validation_data)
else if (inherits(validation_data, "python.builtin.iterator"))
args$validation_data <- validation_data
else if (inherits(validation_data, "keras.utils.data_utils.Sequence"))
else if (inherits(validation_data, c("keras.src.utils.data_utils.Sequence",
"keras.utils.data_utils.Sequence")))
args$validation_data <- validation_data
else {
args$validation_data <- keras_array(validation_data)
Expand All @@ -593,32 +595,12 @@ resolve_main_thread_generators <- function(x, callback_type = "on_train_batch_be
stop("Using generators that call R functions is not supported in TensorFlow 2.1 ",
"Please upgrade your TF installation or downgrade to 2.0", call. = FALSE)

# we need a hack to make sure the generator is evaluated in the main thread.
python_path <- system.file("python", package = "keras")
tools <- reticulate::import_from_path("kerastools", path = python_path)

# as_generator will return a tuple with 2 elements.
# (1) a python generator that just consumes
# a queue.
# (2) a function that evaluates the next element of the generator
# and adds to the queue. This function should be called in the main
# thread.
# we add a `on_train_batch_begin` to call this function.
o <- tools$model$as_generator(x)

callback <- list(function(batch, logs) {
o[[2]]()
})
names(callback) <- callback_type

if (callback_type == "on_test_batch_begin") {
callback[[2]] <- callback[[1]]
names(callback)[[2]] <- "on_test_begin"
}

callback <- do.call(callback_lambda, callback)

list(generator = o[[1]], callback = callback)
# This used to house a mechanism for adding a keras callback that pumps
# the R generator from the main thread (e.g., from 'on_train_batch_begin').
# This has since been fixed upstream, by adding a `prefetch` arg to
# reticulate::py_iterator()
# TODO: remove `resolve_main_thread_generators()` from package
list(generator = x, callback = NULL)
}

#' Train a Keras model
Expand Down Expand Up @@ -1289,7 +1271,7 @@ as_generator.tensorflow.python.data.ops.dataset_ops.DatasetV2 <- function(x) {
as_generator.function <- function(x) {
python_path <- system.file("python", package = "keras")
tools <- reticulate::import_from_path("kerastools", path = python_path)
iter <- reticulate::py_iterator(function() {
reticulate::py_iterator(function() {
elem <- keras_array(x())

# deals with the case where the generator is used for prediction and only
Expand All @@ -1298,8 +1280,8 @@ as_generator.function <- function(x) {
elem[[2]] <- list()

do.call(reticulate::tuple, elem)
})
tools$generator$iter_generator(iter)
}, prefetch = 1L)

}

as_generator.keras_preprocessing.sequence.TimeseriesGenerator <- function(x) {
Expand Down Expand Up @@ -1354,6 +1336,9 @@ is_main_thread_generator.keras_preprocessing.sequence.TimeseriesGenerator <- fun
FALSE
}

is_main_thread_generator.keras.src.preprocessing.sequence.TimeseriesGenerator <-
is_main_thread_generator.keras_preprocessing.sequence.TimeseriesGenerator

is_tensorflow_dataset <- function(x) {
inherits(x, "tensorflow.python.data.ops.dataset_ops.DatasetV2") ||
inherits(x, "tensorflow.python.data.ops.dataset_ops.Dataset")
Expand Down
3 changes: 2 additions & 1 deletion R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ keras <- NULL

# let KerasTensor inherit all the S3 methods of tf.Tensor, but
# KerasTensor methods take precedence.
if("keras.engine.keras_tensor.KerasTensor" %in% classes)
if(any(c("keras.src.engine.keras_tensor.KerasTensor",
"keras.engine.keras_tensor.KerasTensor") %in% classes))
classes <- unique(c("keras.engine.keras_tensor.KerasTensor",
"tensorflow.tensor",
classes))
Expand Down
53 changes: 41 additions & 12 deletions R/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -615,12 +615,20 @@ image_array_save <- function(img, path, data_format = NULL, file_format = NULL,



#' Generate batches of image data with real-time data augmentation. The data will be
#' looped over (in batches).
#' [Deprecated] Generate batches of image data with real-time data augmentation.
#' The data will be looped over (in batches).
#'
#' Deprecated: `image_data_generator` is not
#' recommended for new code. Prefer loading images with
#' `image_dataset_from_directory` and transforming the output
#' TF Dataset with preprocessing layers. For more information, see the
#' tutorials for loading images and augmenting images, as well as the
#' preprocessing layer guide.
#'
#' @param featurewise_center Set input mean to 0 over the dataset, feature-wise.
#' @param samplewise_center Boolean. Set each sample mean to 0.
#' @param featurewise_std_normalization Divide inputs by std of the dataset, feature-wise.
#' @param featurewise_std_normalization Divide inputs by std of the dataset,
#' feature-wise.
#' @param samplewise_std_normalization Divide each input by its std.
#' @param zca_whitening apply ZCA whitening.
#' @param zca_epsilon Epsilon for ZCA whitening. Default is 1e-6.
Expand All @@ -630,12 +638,11 @@ image_array_save <- function(img, path, data_format = NULL, file_format = NULL,
#' @param brightness_range the range of brightness to apply
#' @param shear_range shear intensity (shear angle in radians).
#' @param zoom_range amount of zoom. if scalar z, zoom will be randomly picked
#' in the range `[1-z, 1+z]`. A sequence of two can be passed instead to select
#' this range.
#' in the range `[1-z, 1+z]`. A sequence of two can be passed instead to
#' select this range.
#' @param channel_shift_range shift range for each channels.
#' @param fill_mode One of "constant", "nearest", "reflect" or "wrap".
#' Points outside the boundaries of the input are filled according to
#' the given mode:
#' @param fill_mode One of "constant", "nearest", "reflect" or "wrap". Points
#' outside the boundaries of the input are filled according to the given mode:
#' - "constant": `kkkkkkkk|abcd|kkkkkkkk` (`cval=k`)
#' - "nearest": `aaaaaaaa|abcd|dddddddd`
#' - "reflect": `abcddcba|abcd|dcbaabcd`
Expand All @@ -649,14 +656,15 @@ image_array_save <- function(img, path, data_format = NULL, file_format = NULL,
#' other transformation).
#' @param preprocessing_function function that will be implied on each input.
#' The function will run before any other modification on it. The function
#' should take one argument: one image (tensor with rank 3), and should
#' output a tensor with the same shape.
#' should take one argument: one image (tensor with rank 3), and should output
#' a tensor with the same shape.
#' @param data_format 'channels_first' or 'channels_last'. In 'channels_first'
#' mode, the channels dimension (the depth) is at index 1, in 'channels_last'
#' mode it is at index 3. It defaults to the `image_data_format` value found
#' in your Keras config file at `~/.keras/keras.json`. If you never set it,
#' then it will be "channels_last".
#' @param validation_split fraction of images reserved for validation (strictly between 0 and 1).
#' @param validation_split fraction of images reserved for validation (strictly
#' between 0 and 1).
#'
#' @export
image_data_generator <- function(featurewise_center = FALSE, samplewise_center = FALSE,
Expand Down Expand Up @@ -685,13 +693,19 @@ image_data_generator <- function(featurewise_center = FALSE, samplewise_center =
preprocessing_function = preprocessing_function,
data_format = data_format
)

if (keras_version() >= "2.0.4")
args$zca_epsilon <- zca_epsilon
if (keras_version() >= "2.1.5") {
args$brightness_range <- brightness_range
args$validation_split <- validation_split
}

if(is.function(preprocessing_function) &&
!inherits(preprocessing_function, "python.builtin.object"))
args$preprocessing_function <-
reticulate::py_main_thread_func(preprocessing_function)

do.call(keras$preprocessing$image$ImageDataGenerator, args)

}
Expand Down Expand Up @@ -766,6 +780,7 @@ fit_image_data_generator <- function(object, x, augment = FALSE, rounds = 1, see
#'
#' @family image preprocessing
#'
#' @importFrom reticulate iter_next
#' @export
flow_images_from_data <- function(
x, y = NULL, generator = image_data_generator(), batch_size = 32,
Expand All @@ -790,7 +805,21 @@ flow_images_from_data <- function(
if (keras_version() >= "2.2.0")
args$sample_weight <- sample_weight

do.call(generator$flow, args)
iterator <- do.call(generator$flow, args)

if(!is.null(generator$preprocessing_function)) {
# user supplied a custom preprocessing function, which likely is an R
# function that must be called from the main thread. Wrap this in
# py_iterator(prefetch=1) to ensure we don't end in a deadlock.
iter_env <- new.env(parent = parent.env(environment())) # pkg namespace
iter_env$.iterator <- iterator
expr <- substitute(py_iterator(function() iter_next(iterator), prefetch=1L),
list(iterator = quote(.iterator)))
iterator <- eval(expr, iter_env)
}

iterator

}

#' Generates batches of data from images in a directory (with optional
Expand Down
2 changes: 1 addition & 1 deletion R/py-classes.R
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ print.py_R6ClassGenerator <- function(x, ...) {
#' @export
`$.py_R6ClassGenerator` <- function(x, name) {
if (identical(name, "new"))
return(self)
return(x)
NextMethod()
}

Expand Down
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ py_to_r_wrapper.keras.src.engine.training.Model <- py_to_r_wrapper.keras.engine.

#' @export
summary.keras.src.engine.training.Model <- summary.keras.engine.training.Model

as_generator.keras.src.utils.data_utils.Sequence <- as_generator.keras_preprocessing.sequence.TimeseriesGenerator
31 changes: 18 additions & 13 deletions man/image_data_generator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions tests/testthat/helper-utils.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 1)
# Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 1)
# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed

if(reticulate::virtualenv_exists("r-tensorflow"))
if(!reticulate::py_available() && reticulate::virtualenv_exists("r-tensorflow"))
reticulate::use_virtualenv("r-tensorflow")

if(reticulate::py_available()) {
print(reticulate::py_config())
} else {
setHook("reticulate.onPyInit", function() print(reticulate::py_config()))
}

# Sys.setenv(RETICULATE_PYTHON = "~/.local/share/r-miniconda/envs/tf-2.7-cpu/bin/python")
# Sys.setenv(RETICULATE_PYTHON = "~/.local/share/r-miniconda/envs/tf-nightly-cpu/bin/python")
# reticulate::use_condaenv("tf-2.5-cpu", required = TRUE)
Expand Down Expand Up @@ -147,3 +154,4 @@ local_tf_device <- function(device_name = "CPU") {
withr::defer_parent(device$`__exit__`())
invisible(device)
}

6 changes: 4 additions & 2 deletions tests/testthat/test-callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ if (tensorflow::tf_version() <= "2.1")
test_callback("progbar_logger", callback_progbar_logger())


test_callback("model_checkpoint", callback_model_checkpoint(tempfile(fileext = ".h5")), h5py = TRUE)
test_callback("model_checkpoint",
callback_model_checkpoint(tempfile(fileext = ".keras")),
h5py = TRUE)

if(tf_version() >= "2.8")
test_callback("backup_and_restore", callback_backup_and_restore(tempfile()))
Expand Down Expand Up @@ -252,7 +254,7 @@ test_succeeds("on predict/evaluation callbacks", {

warns <- capture_warnings(
out <- capture_output(
pred <- predict(model, gen, callbacks = cc, steps = 1)
pred <- predict(model, gen, callbacks = cc, steps = 5)
)
)
expect_warns_and_out(warns, out)
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ test_metric <- function(metric, ...) {
m <- metric(...)

expect_s3_class(m, c("keras.metrics.Metric",
'keras.metrics.base_metric.Metric'))
'keras.metrics.base_metric.Metric',
'keras.src.metrics.base_metric.Metric'))

define_model() %>%
compile(loss = loss,
Expand Down

0 comments on commit d8c515f

Please sign in to comment.