Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix task_cache example #1712

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
# %% [markdown]
# For any {py:func}`flytekit.task` in Flyte, there is always one required import, which is:
# %%
from flytekit import HashMethod, task, workflow
from flytekit import HashMethod, ImageSpec, task, workflow
from flytekit.core.node_creation import create_node
from typing_extensions import Annotated

image_spec = ImageSpec(
registry="ghcr.io/flyteorg",
packages=["pandas"],
)


# Task caching is disabled by default to avoid unintended consequences of
# caching tasks with side effects. To enable caching and control its behavior,
Expand All @@ -19,7 +24,7 @@
# Bumping the `cache_version` is akin to invalidating the cache.
# When you manually update this version, Flyte runs the task again and caches
# the result of that next execution instead of relying on the old cache.
@task(cache=True, cache_version="1.0") # noqa: F841
@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F841
def square(n: int) -> int:
"""
Parameters:
Expand All @@ -36,14 +41,14 @@ def square(n: int) -> int:
# Caching of Non-flyte Offloaded Objects
# The default behavior of Flyte's memoization feature might not
# match the user's intuition. For example, this code makes use of pandas dataframes:
@task
@task(container_image=image_spec)
def foo(a: int, b: str) -> pandas.DataFrame:
df = pandas.DataFrame(...)
...
return df


@task(cache=True, cache_version="1.0")
@task(cache=True, cache_version="1.0", container_image=image_spec)
def bar(df: pandas.DataFrame) -> int:
return 1

Expand All @@ -65,7 +70,7 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str:
return str(pandas.util.hash_pandas_object(df))


@task
@task(container_image=image_spec)
def foo_1( # noqa: F811
a: int,
b: str, # noqa: F821
Expand All @@ -75,7 +80,7 @@ def foo_1( # noqa: F811
return df


@task(cache=True, cache_version="1.0") # noqa: F811
@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F811
def bar_1(df: pandas.DataFrame) -> int: # noqa: F811
return 1

Expand All @@ -99,18 +104,18 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str:
return str(pandas.util.hash_pandas_object(df))


@task
@task(container_image=image_spec)
def uncached_data_reading_task() -> Annotated[pandas.DataFrame, HashMethod(hash_pandas_dataframe)]:
return pandas.DataFrame({"column_1": [1, 2, 3]})


@task(cache=True, cache_version="1.0")
@task(cache=True, cache_version="1.0", container_image=image_spec)
def cached_data_processing_task(df: pandas.DataFrame) -> pandas.DataFrame:
time.sleep(1)
return df * 2


@task
@task(container_image=image_spec)
def compare_dataframes(df1: pandas.DataFrame, df2: pandas.DataFrame):
assert df1.equals(df2)

Expand Down
Loading