Skip to content

Commit

Permalink
Fix task_cache example (#1712)
Browse files Browse the repository at this point in the history
* Fix task_cache example

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

---------

Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw authored Jul 24, 2024
1 parent d091fbd commit 91c6c80
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions examples/development_lifecycle/development_lifecycle/task_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
# %% [markdown]
# For any {py:func}`flytekit.task` in Flyte, there is always one required import, which is:
# %%
from flytekit import HashMethod, task, workflow
from flytekit import HashMethod, ImageSpec, task, workflow
from flytekit.core.node_creation import create_node
from typing_extensions import Annotated

# Image definition shared by the tasks in this example: each task receives it
# through `container_image=image_spec`. It names the registry to use and lists
# pandas as a package, since the tasks below work with pandas DataFrames.
image_spec = ImageSpec(
registry="ghcr.io/flyteorg",
packages=["pandas"],
)


# Task caching is disabled by default to avoid unintended consequences of
# caching tasks with side effects. To enable caching and control its behavior,
Expand All @@ -19,7 +24,7 @@
# Bumping the `cache_version` is akin to invalidating the cache.
# When you manually update this version, Flyte runs the task again and caches
# that execution's result instead of reusing the old cached result.
@task(cache=True, cache_version="1.0") # noqa: F841
@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F841
def square(n: int) -> int:
"""
Parameters:
Expand All @@ -36,14 +41,14 @@ def square(n: int) -> int:
# Caching of Non-flyte Offloaded Objects
# The default behavior of Flyte's memoization feature might not
# match the user's intuition. For example, this code makes use of pandas DataFrames:
@task
@task(container_image=image_spec)
def foo(a: int, b: str) -> pandas.DataFrame:
df = pandas.DataFrame(...)
...
return df


@task(cache=True, cache_version="1.0")
@task(cache=True, cache_version="1.0", container_image=image_spec)
def bar(df: pandas.DataFrame) -> int:
return 1

Expand All @@ -65,7 +70,7 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str:
return str(pandas.util.hash_pandas_object(df))


@task
@task(container_image=image_spec)
def foo_1( # noqa: F811
a: int,
b: str, # noqa: F821
Expand All @@ -75,7 +80,7 @@ def foo_1( # noqa: F811
return df


@task(cache=True, cache_version="1.0") # noqa: F811
@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F811
def bar_1(df: pandas.DataFrame) -> int: # noqa: F811
return 1

Expand All @@ -99,18 +104,18 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str:
return str(pandas.util.hash_pandas_object(df))


@task
@task(container_image=image_spec)
def uncached_data_reading_task() -> Annotated[pandas.DataFrame, HashMethod(hash_pandas_dataframe)]:
return pandas.DataFrame({"column_1": [1, 2, 3]})


@task(cache=True, cache_version="1.0")
@task(cache=True, cache_version="1.0", container_image=image_spec)
def cached_data_processing_task(df: pandas.DataFrame) -> pandas.DataFrame:
time.sleep(1)
return df * 2


@task
@task(container_image=image_spec)
def compare_dataframes(df1: pandas.DataFrame, df2: pandas.DataFrame):
assert df1.equals(df2)

Expand Down

0 comments on commit 91c6c80

Please sign in to comment.