From 91c6c80fa0f0664c914c9a6a2e34c3e8ba47d73c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 24 Jul 2024 12:19:19 +0800 Subject: [PATCH] Fix task_cache example (#1712) * Fix task_cache example Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- .../development_lifecycle/task_cache.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/development_lifecycle/development_lifecycle/task_cache.py b/examples/development_lifecycle/development_lifecycle/task_cache.py index d33b35124..712e06db3 100644 --- a/examples/development_lifecycle/development_lifecycle/task_cache.py +++ b/examples/development_lifecycle/development_lifecycle/task_cache.py @@ -5,10 +5,15 @@ # %% [markdown] # For any {py:func}`flytekit.task` in Flyte, there is always one required import, which is: # %% -from flytekit import HashMethod, task, workflow +from flytekit import HashMethod, ImageSpec, task, workflow from flytekit.core.node_creation import create_node from typing_extensions import Annotated +image_spec = ImageSpec( + registry="ghcr.io/flyteorg", + packages=["pandas"], +) + # Task caching is disabled by default to avoid unintended consequences of # caching tasks with side effects. To enable caching and control its behavior, @@ -19,7 +24,7 @@ # Bumping the `cache_version` is akin to invalidating the cache. # You can manually update this version and Flyte caches the next execution # instead of relying on the old cache. -@task(cache=True, cache_version="1.0") # noqa: F841 +@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F841 def square(n: int) -> int: """ Parameters: @@ -36,14 +41,14 @@ def square(n: int) -> int: # Caching of Non-flyte Offloaded Objects # The default behavior displayed by Flyte's memoization feature might not # match the user intuition. For example, this code makes use of pandas dataframes: -@task +@task(container_image=image_spec) def foo(a: int, b: str) -> pandas.DataFrame: df = pandas.DataFrame(...) ... return df -@task(cache=True, cache_version="1.0") +@task(cache=True, cache_version="1.0", container_image=image_spec) def bar(df: pandas.DataFrame) -> int: return 1 @@ -65,7 +70,7 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str: return str(pandas.util.hash_pandas_object(df)) -@task +@task(container_image=image_spec) def foo_1( # noqa: F811 a: int, b: str, # noqa: F821 @@ -75,7 +80,7 @@ def foo_1( # noqa: F811 return df -@task(cache=True, cache_version="1.0") # noqa: F811 +@task(cache=True, cache_version="1.0", container_image=image_spec) # noqa: F811 def bar_1(df: pandas.DataFrame) -> int: # noqa: F811 return 1 @@ -99,18 +104,18 @@ def hash_pandas_dataframe(df: pandas.DataFrame) -> str: return str(pandas.util.hash_pandas_object(df)) -@task +@task(container_image=image_spec) def uncached_data_reading_task() -> Annotated[pandas.DataFrame, HashMethod(hash_pandas_dataframe)]: return pandas.DataFrame({"column_1": [1, 2, 3]}) -@task(cache=True, cache_version="1.0") +@task(cache=True, cache_version="1.0", container_image=image_spec) def cached_data_processing_task(df: pandas.DataFrame) -> pandas.DataFrame: time.sleep(1) return df * 2 -@task +@task(container_image=image_spec) def compare_dataframes(df1: pandas.DataFrame, df2: pandas.DataFrame): assert df1.equals(df2)