Skip to content

Commit

Permalink
Support black
Browse files Browse the repository at this point in the history
  • Loading branch information
sbsekiguchi committed Aug 5, 2024
1 parent 33330cd commit 850181f
Show file tree
Hide file tree
Showing 391 changed files with 11,569 additions and 9,945 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: check python code format with autopep8
run: |
autopep8 --diff .
- name: check python code format with black
uses: psf/black@stable
with:
options: "--check --verbose"
src: "."
typing:
runs-on: ubuntu-latest
timeout-minutes: 3
Expand Down
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ $ pip install -e .

### Code format guidelines

We use [autopep8](https://github.com/hhatto/autopep8) and [isort](https://github.com/PyCQA/isort) to keep consistent coding style. After finishing developing the code, run autopep8 and isort to ensure that your code is correctly formatted.
We use [black](https://github.com/psf/black) and [isort](https://github.com/PyCQA/isort) to keep consistent coding style. After finishing developing the code, run black and isort to ensure that your code is correctly formatted.

You can run autopep8 and isort as follows.
You can run black and isort as follows.

```sh
cd <nnabla-rl root directory>
autopep8 .
black .
```

```sh
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
[![Build status](https://github.com/sony/nnabla-rl/workflows/Build%20nnabla-rl/badge.svg)](https://github.com/sony/nnabla-rl/actions)
[![Documentation Status](https://readthedocs.org/projects/nnabla-rl/badge/?version=latest)](https://nnabla-rl.readthedocs.io/en/latest/?badge=latest)
[![Doc style](https://img.shields.io/badge/%20style-google-3666d6.svg)](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

# Deep Reinforcement Learning Library built on top of Neural Network Libraries

Expand Down
32 changes: 16 additions & 16 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,14 +18,14 @@

import sphinx_rtd_theme

sys.path.insert(0, os.path.abspath('../../'))
sys.path.insert(0, os.path.abspath("../../"))

import nnabla_rl # noqa

# -- Project information -----------------------------------------------------
project = 'nnablaRL'
copyright = '2021, Sony Group Corporation'
author = 'Sony Group Corporation'
project = "nnablaRL"
copyright = "2021, Sony Group Corporation"
author = "Sony Group Corporation"
release = nnabla_rl.__version__

# -- General configuration ---------------------------------------------------
Expand All @@ -34,22 +34,22 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.doctest',
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.mathjax',
'sphinx.ext.autosummary',
'sphinx.ext.viewcode'
"sphinx.ext.doctest",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.mathjax",
"sphinx.ext.autosummary",
"sphinx.ext.viewcode",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['ntemplates']
templates_path = ["ntemplates"]

source_suffix = '.rst'
source_suffix = ".rst"
# The master toctree document.
master_doc = 'index'
master_doc = "index"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
Expand Down
8 changes: 4 additions & 4 deletions examples/evaluate_trained_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,16 +22,16 @@

def build_env():
try:
env = gym.make('Pendulum-v0')
env = gym.make("Pendulum-v0")
except gym.error.DeprecatedEnv:
env = gym.make('Pendulum-v1')
env = gym.make("Pendulum-v1")
env = NumpyFloat32Env(env)
env = ScreenRenderEnv(env)
return env


def main():
snapshot_dir = './pendulum_v0_snapshot/iteration-10000'
snapshot_dir = "./pendulum_v0_snapshot/iteration-10000"
env = build_env()

algorithm = serializers.load_snapshot(snapshot_dir, env)
Expand Down
6 changes: 3 additions & 3 deletions examples/hook_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,15 +24,15 @@ def __init__(self):
super().__init__(timing=1)

def on_hook_called(self, algorithm):
print('hello!!')
print("hello!!")


class PrintOnlyEvenIteraion(Hook):
def __init__(self):
super().__init__(timing=2)

def on_hook_called(self, algorithm):
print('even iteration -> {}'.format(algorithm.iteration_num))
print("even iteration -> {}".format(algorithm.iteration_num))


def main():
Expand Down
62 changes: 31 additions & 31 deletions examples/recurrent_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,17 +65,17 @@ def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]:
shapes: Dict[str, nn.Variable] = {}
# You can use arbitral (but distinguishable) key as name
# Use same key for same state
shapes['my_lstm_h'] = (self._lstm_state_size, )
shapes['my_lstm_c'] = (self._lstm_state_size, )
shapes["my_lstm_h"] = (self._lstm_state_size,)
shapes["my_lstm_c"] = (self._lstm_state_size,)
return shapes

def get_internal_states(self) -> Dict[str, nn.Variable]:
# Return current internal states
states: Dict[str, nn.Variable] = {}
# You can use arbitral (but distinguishable) key as name
# Use same key for same state.
states['my_lstm_h'] = self._h
states['my_lstm_c'] = self._c
states["my_lstm_h"] = self._h
states["my_lstm_c"] = self._c
return states

def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
Expand All @@ -88,8 +88,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None):
else:
# Otherwise, set given states
# Use the key defined in internal_state_shapes() for getting the states
self._h = states['my_lstm_h']
self._c = states['my_lstm_c']
self._h = states["my_lstm_h"]
self._c = states["my_lstm_c"]

def _create_internal_states(self, batch_size):
self._h = nn.Variable((batch_size, self._lstm_state_size))
Expand All @@ -103,11 +103,9 @@ def _is_internal_state_created(self) -> bool:


class QFunctionWithRNNBuilder(ModelBuilder[QFunction]):
def build_model(self,
scope_name: str,
env_info: EnvironmentInfo,
algorithm_config: AlgorithmConfig,
**kwargs) -> QFunction:
def build_model(
self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs
) -> QFunction:
action_num = env_info.action_dim
return QFunctionWithRNN(scope_name, action_num)

Expand All @@ -118,7 +116,7 @@ def build_solver(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmCon


def build_env(seed=None):
env = gym.make('MountainCar-v0')
env = gym.make("MountainCar-v0")
env = NumpyFloat32Env(env)
env = ScreenRenderEnv(env)
env.seed(seed)
Expand All @@ -133,30 +131,32 @@ def main():
# Here, we use DRQN algorithm
# For the list of algorithms that support RNN layers see
# https://github.com/sony/nnabla-rl/tree/master/nnabla_rl/algorithms
config = A.DRQNConfig(gpu_id=0,
learning_rate=1e-2,
gamma=0.9,
learner_update_frequency=1,
target_update_frequency=200,
start_timesteps=200,
replay_buffer_size=10000,
max_explore_steps=10000,
initial_epsilon=1.0,
final_epsilon=0.001,
test_epsilon=0.001,
grad_clip=None,
unroll_steps=2) # Unroll only for 2 timesteps for fast iteration. Because this is an example
drqn = A.DRQN(train_env,
config=config,
q_func_builder=QFunctionWithRNNBuilder(),
q_solver_builder=AdamSolverBuilder())
config = A.DRQNConfig(
gpu_id=0,
learning_rate=1e-2,
gamma=0.9,
learner_update_frequency=1,
target_update_frequency=200,
start_timesteps=200,
replay_buffer_size=10000,
max_explore_steps=10000,
initial_epsilon=1.0,
final_epsilon=0.001,
test_epsilon=0.001,
grad_clip=None,
unroll_steps=2,
) # Unroll only for 2 timesteps for fast iteration. Because this is an example
drqn = A.DRQN(
train_env, config=config, q_func_builder=QFunctionWithRNNBuilder(), q_solver_builder=AdamSolverBuilder()
)

# Optional: Add hooks to check the training progress
eval_env = build_env(seed=100)
evaluation_hook = H.EvaluationHook(
eval_env,
timing=1000,
writer=W.FileWriter(outdir='./mountain_car_v0_drqn_results', file_prefix='evaluation_result'))
writer=W.FileWriter(outdir="./mountain_car_v0_drqn_results", file_prefix="evaluation_result"),
)
iteration_num_hook = H.IterationNumHook(timing=100)
drqn.set_hooks(hooks=[iteration_num_hook, evaluation_hook])

Expand Down
8 changes: 4 additions & 4 deletions examples/rl_project_template/environment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sony Group Corporation.
# Copyright 2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,15 +20,15 @@ def __init__(self, max_episode_steps=100):
# max_episode_steps is the maximum possible steps that the rl agent interacts with this environment.
# You can set this value to None if the is no limits.
# The first argument is the name of this environment used when registering this environment to gym.
self.spec = EnvSpec('template-v0', max_episode_steps=max_episode_steps)
self.spec = EnvSpec("template-v0", max_episode_steps=max_episode_steps)
self._episode_steps = 0

# Use gym's spaces to define the shapes and ranges of states and actions.
# observation_space: definition of states's shape and its ranges
# action_space: definition of actions's shape and its ranges
observation_shape = (10, ) # Example 10 dimensional state with range of [0.0, 1.0] each.
observation_shape = (10,) # Example 10 dimensional state with range of [0.0, 1.0] each.
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=observation_shape)
action_shape = (1, ) # 1 dimensional continuous action with range of [0.0, 1.0].
action_shape = (1,) # 1 dimensional continuous action with range of [0.0, 1.0].
self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=action_shape)

def reset(self):
Expand Down
4 changes: 2 additions & 2 deletions examples/rl_project_template/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sony Group Corporation.
# Copyright 2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,7 +58,7 @@ def pi(self, state: nn.Variable) -> Distribution:
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=256, name="linear2")
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=self._action_dim*2, name="linear3")
h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear3")
reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim))

# Split the output into mean and variance of the Gaussian distribution.
Expand Down
12 changes: 6 additions & 6 deletions examples/rl_project_template/training.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Sony Group Corporation.
# Copyright 2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -75,7 +75,7 @@ def run_training(args):
# Save trained parameters every "timing" steps.
# Without this, the parameters will not be saved.
# We recommend saving parameters at every evaluation timing.
outdir = pathlib.Path(args.save_dir) / 'snapshots'
outdir = pathlib.Path(args.save_dir) / "snapshots"
save_snapshot_hook = H.SaveSnapshotHook(outdir=outdir, timing=1000)

# All instantiated hooks should be set at once.
Expand All @@ -87,13 +87,13 @@ def run_training(args):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--save-dir', type=str, default=str(pathlib.Path(__file__).parent))
parser.add_argument('--seed', type=int, default=0)
parser.add_argument("--gpu", type=int, default=-1)
parser.add_argument("--save-dir", type=str, default=str(pathlib.Path(__file__).parent))
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()

run_training(args)


if __name__ == '__main__':
if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions examples/save_load_snapshot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,7 +23,7 @@ def main():
env = E.DummyContinuous()
ddpg = A.DDPG(env, config=config)

outdir = './save_load_snapshot'
outdir = "./save_load_snapshot"

# This actually saves the model and solver state right after the algorithm construction
snapshot_dir = serializers.save_snapshot(outdir, ddpg)
Expand Down
13 changes: 6 additions & 7 deletions examples/training_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,9 +23,9 @@

def build_env(seed=None):
try:
env = gym.make('Pendulum-v0')
env = gym.make("Pendulum-v0")
except gym.error.DeprecatedEnv:
env = gym.make('Pendulum-v1')
env = gym.make("Pendulum-v1")
env = NumpyFloat32Env(env)
env = ScreenRenderEnv(env)
env.seed(seed)
Expand All @@ -36,15 +36,14 @@ def main():
# Evaluate the trained network (Optional)
eval_env = build_env(seed=100)
evaluation_hook = H.EvaluationHook(
eval_env,
timing=1000,
writer=W.FileWriter(outdir='./pendulum_v0_ddpg_results', file_prefix='evaluation_result'))
eval_env, timing=1000, writer=W.FileWriter(outdir="./pendulum_v0_ddpg_results", file_prefix="evaluation_result")
)

# Pring iteration number every 100 iterations.
iteration_num_hook = H.IterationNumHook(timing=100)

# Save trained algorithm snapshot (Optional)
save_snapshot_hook = H.SaveSnapshotHook('./pendulum_v0_ddpg_results', timing=1000)
save_snapshot_hook = H.SaveSnapshotHook("./pendulum_v0_ddpg_results", timing=1000)

train_env = build_env()
config = A.DDPGConfig(start_timesteps=200, gpu_id=0)
Expand Down
Loading

0 comments on commit 850181f

Please sign in to comment.